1use std::fmt;
9
10use crate::error::{Result, SchemaError};
11use crate::span::Span;
12use crate::token::{Token, TokenKind};
13
/// Binary comparison operator usable inside an `@check` expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=` or `<>` in the source; always rendered as `<>`.
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
}

impl fmt::Display for CmpOp {
    /// Writes the SQL spelling of the operator (inequality is emitted as `<>`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::Ne => "<>",
            Self::Lt => "<",
            Self::Gt => ">",
            Self::Le => "<=",
            Self::Ge => ">=",
        };
        f.write_str(symbol)
    }
}
43
/// One side of a comparison, or one entry of an `IN [...]` list.
#[derive(Debug, Clone, PartialEq)]
pub enum Operand {
    /// A bare identifier outside an `IN` list, naming a field.
    Field(String),
    /// A numeric literal, stored as the lexer's string form of the number.
    Number(String),
    /// A string literal; holds the contents without surrounding quotes.
    StringLit(String),
    /// A boolean literal (`true` / `false`).
    Bool(bool),
    /// A bare identifier inside an `IN [...]` list, taken as an enum variant name.
    EnumVariant(String),
}

impl fmt::Display for Operand {
    /// Renders the operand as a SQL fragment.
    ///
    /// Fields and numbers are written verbatim, booleans as `TRUE`/`FALSE`.
    /// String literals and enum variants become single-quoted SQL string
    /// literals. Embedded `'` characters are doubled (`O'Brien` -> `'O''Brien'`),
    /// so the result is always well-formed and cannot terminate the literal
    /// early.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Writes `text` as a standard SQL string literal, doubling quotes.
        ///
        /// NOTE(review): assumes the lexer has already unescaped string tokens
        /// (the stored text is the literal's real contents) — confirm.
        fn write_sql_string(f: &mut fmt::Formatter<'_>, text: &str) -> fmt::Result {
            f.write_str("'")?;
            for (i, chunk) in text.split('\'').enumerate() {
                if i > 0 {
                    // Each split point was a single quote; emit its SQL escape.
                    f.write_str("''")?;
                }
                f.write_str(chunk)?;
            }
            f.write_str("'")
        }

        match self {
            Operand::Field(name) => write!(f, "{}", name),
            Operand::Number(n) => write!(f, "{}", n),
            Operand::StringLit(s) => write_sql_string(f, s),
            Operand::Bool(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
            Operand::EnumVariant(v) => write_sql_string(f, v),
        }
    }
}
70
71#[derive(Debug, Clone, PartialEq)]
73pub enum BoolExpr {
74 Comparison {
76 left: Operand,
78 op: CmpOp,
80 right: Operand,
82 },
83 And(Box<BoolExpr>, Box<BoolExpr>),
85 Or(Box<BoolExpr>, Box<BoolExpr>),
87 Not(Box<BoolExpr>),
89 In {
91 field: String,
93 values: Vec<Operand>,
95 },
96 Paren(Box<BoolExpr>),
98}
99
100impl fmt::Display for BoolExpr {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self {
103 BoolExpr::Comparison { left, op, right } => write!(f, "{} {} {}", left, op, right),
104 BoolExpr::And(left, right) => write!(f, "{} AND {}", left, right),
105 BoolExpr::Or(left, right) => write!(f, "{} OR {}", left, right),
106 BoolExpr::Not(inner) => write!(f, "NOT {}", inner),
107 BoolExpr::In { field, values } => {
108 write!(f, "{} IN [", field)?;
109 for (i, val) in values.iter().enumerate() {
110 if i > 0 {
111 write!(f, ", ")?;
112 }
113 match val {
115 Operand::EnumVariant(v) => write!(f, "{}", v)?,
116 other => write!(f, "{}", other)?,
117 }
118 }
119 write!(f, "]")
120 }
121 BoolExpr::Paren(inner) => write!(f, "({})", inner),
122 }
123 }
124}
125
126impl BoolExpr {
127 pub fn field_references(&self) -> Vec<&str> {
129 let mut refs = Vec::new();
130 self.collect_field_refs(&mut refs);
131 refs
132 }
133
134 fn collect_field_refs<'a>(&'a self, refs: &mut Vec<&'a str>) {
135 match self {
136 BoolExpr::Comparison { left, right, .. } => {
137 if let Operand::Field(name) = left {
138 refs.push(name);
139 }
140 if let Operand::Field(name) = right {
141 refs.push(name);
142 }
143 }
144 BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
145 l.collect_field_refs(refs);
146 r.collect_field_refs(refs);
147 }
148 BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
149 inner.collect_field_refs(refs);
150 }
151 BoolExpr::In { field, .. } => {
152 refs.push(field);
153 }
154 }
155 }
156
157 pub fn enum_in_lists(&self) -> Vec<(&str, Vec<&str>)> {
160 let mut result = Vec::new();
161 self.collect_enum_in_lists(&mut result);
162 result
163 }
164
165 fn collect_enum_in_lists<'a>(&'a self, result: &mut Vec<(&'a str, Vec<&'a str>)>) {
166 match self {
167 BoolExpr::In { field, values } => {
168 let variants: Vec<&str> = values
169 .iter()
170 .filter_map(|v| match v {
171 Operand::EnumVariant(name) => Some(name.as_str()),
172 _ => None,
173 })
174 .collect();
175 if !variants.is_empty() {
176 result.push((field.as_str(), variants));
177 }
178 }
179 BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
180 l.collect_enum_in_lists(result);
181 r.collect_enum_in_lists(result);
182 }
183 BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
184 inner.collect_enum_in_lists(result);
185 }
186 BoolExpr::Comparison { .. } => {}
187 }
188 }
189
190 pub fn to_sql(&self) -> String {
195 self.to_sql_mapped(&|name| name.to_string())
196 }
197
198 pub fn to_sql_mapped<F>(&self, map_field: &F) -> String
201 where
202 F: Fn(&str) -> String,
203 {
204 match self {
205 BoolExpr::Comparison { left, op, right } => {
206 let left_s = match left {
207 Operand::Field(name) => map_field(name),
208 other => other.to_string(),
209 };
210 let right_s = match right {
211 Operand::Field(name) => map_field(name),
212 other => other.to_string(),
213 };
214 format!("{} {} {}", left_s, op, right_s)
215 }
216 BoolExpr::And(left, right) => format!(
217 "{} AND {}",
218 left.to_sql_mapped(map_field),
219 right.to_sql_mapped(map_field)
220 ),
221 BoolExpr::Or(left, right) => format!(
222 "{} OR {}",
223 left.to_sql_mapped(map_field),
224 right.to_sql_mapped(map_field)
225 ),
226 BoolExpr::Not(inner) => format!("NOT {}", inner.to_sql_mapped(map_field)),
227 BoolExpr::In { field, values } => {
228 let vals: Vec<String> = values.iter().map(|v| v.to_string()).collect();
229 format!("{} IN ({})", map_field(field), vals.join(", "))
230 }
231 BoolExpr::Paren(inner) => format!("({})", inner.to_sql_mapped(map_field)),
232 }
233 }
234}
235
236struct BoolExprParser<'a> {
239 tokens: &'a [Token],
240 pos: usize,
241 fallback_span: Span,
243}
244
245impl<'a> BoolExprParser<'a> {
246 fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
247 Self {
248 tokens,
249 pos: 0,
250 fallback_span,
251 }
252 }
253
254 fn peek(&self) -> Option<&TokenKind> {
255 self.tokens.get(self.pos).map(|t| &t.kind)
256 }
257
258 fn span(&self) -> Span {
259 self.tokens
260 .get(self.pos)
261 .map(|t| t.span)
262 .unwrap_or(self.fallback_span)
263 }
264
265 fn advance(&mut self) -> &Token {
266 let tok = &self.tokens[self.pos];
267 self.pos += 1;
268 tok
269 }
270
271 fn at_end(&self) -> bool {
272 self.pos >= self.tokens.len()
273 }
274
275 fn is_keyword(&self, kw: &str) -> bool {
276 matches!(self.peek(), Some(TokenKind::Ident(s)) if s.eq_ignore_ascii_case(kw))
277 }
278
279 fn parse_expr(&mut self) -> Result<BoolExpr> {
280 self.parse_or()
281 }
282
283 fn parse_or(&mut self) -> Result<BoolExpr> {
285 let mut left = self.parse_and()?;
286 while self.is_keyword("OR") {
287 self.advance();
288 let right = self.parse_and()?;
289 left = BoolExpr::Or(Box::new(left), Box::new(right));
290 }
291 Ok(left)
292 }
293
294 fn parse_and(&mut self) -> Result<BoolExpr> {
295 let mut left = self.parse_not()?;
296 while self.is_keyword("AND") {
297 self.advance();
298 let right = self.parse_not()?;
299 left = BoolExpr::And(Box::new(left), Box::new(right));
300 }
301 Ok(left)
302 }
303
304 fn parse_not(&mut self) -> Result<BoolExpr> {
305 if self.is_keyword("NOT") {
306 self.advance();
307 let inner = self.parse_not()?;
308 return Ok(BoolExpr::Not(Box::new(inner)));
309 }
310 self.parse_primary()
311 }
312
313 fn parse_primary(&mut self) -> Result<BoolExpr> {
314 if self.at_end() {
315 return Err(SchemaError::Parse(
316 "Unexpected end of check expression".to_string(),
317 self.span(),
318 ));
319 }
320
321 if matches!(self.peek(), Some(TokenKind::LParen)) {
322 self.advance();
323 let inner = self.parse_expr()?;
324 match self.peek() {
325 Some(TokenKind::RParen) => {
326 self.advance();
327 return Ok(BoolExpr::Paren(Box::new(inner)));
328 }
329 _ => {
330 return Err(SchemaError::Parse(
331 "Expected ')' after parenthesised expression".to_string(),
332 self.span(),
333 ));
334 }
335 }
336 }
337
338 if matches!(self.peek(), Some(TokenKind::True)) {
339 self.advance();
340 return Ok(BoolExpr::Comparison {
341 left: Operand::Bool(true),
342 op: CmpOp::Eq,
343 right: Operand::Bool(true),
344 });
345 }
346 if matches!(self.peek(), Some(TokenKind::False)) {
347 self.advance();
348 return Ok(BoolExpr::Comparison {
349 left: Operand::Bool(false),
350 op: CmpOp::Eq,
351 right: Operand::Bool(true),
352 });
353 }
354
355 let left = self.parse_operand(false)?;
356
357 if self.is_keyword("IN") {
358 let field_name = match &left {
359 Operand::Field(name) => name.clone(),
360 _ => {
361 return Err(SchemaError::Parse(
362 "Left side of IN must be a field reference".to_string(),
363 self.span(),
364 ));
365 }
366 };
367 self.advance();
368 let values = self.parse_in_list()?;
369 return Ok(BoolExpr::In {
370 field: field_name,
371 values,
372 });
373 }
374
375 let op = self.parse_cmp_op()?;
376 let right = self.parse_operand(false)?;
377
378 Ok(BoolExpr::Comparison { left, op, right })
379 }
380
381 fn parse_operand(&mut self, in_list: bool) -> Result<Operand> {
384 if self.at_end() {
385 return Err(SchemaError::Parse(
386 "Expected operand in check expression".to_string(),
387 self.span(),
388 ));
389 }
390
391 match self.peek().cloned() {
392 Some(TokenKind::Number(n)) => {
393 self.advance();
394 Ok(Operand::Number(n))
395 }
396 Some(TokenKind::String(s)) => {
397 self.advance();
398 Ok(Operand::StringLit(s))
399 }
400 Some(TokenKind::True) => {
401 self.advance();
402 Ok(Operand::Bool(true))
403 }
404 Some(TokenKind::False) => {
405 self.advance();
406 Ok(Operand::Bool(false))
407 }
408 Some(TokenKind::Ident(name)) => {
409 self.advance();
410 if in_list {
411 Ok(Operand::EnumVariant(name))
412 } else {
413 Ok(Operand::Field(name))
414 }
415 }
416 Some(k) if k.is_keyword() => {
418 let tok = self.advance();
419 let name = tok.kind.to_string();
420 if in_list {
421 Ok(Operand::EnumVariant(name))
422 } else {
423 Ok(Operand::Field(name))
424 }
425 }
426 Some(other) => Err(SchemaError::Parse(
427 format!("Unexpected token '{}' in check expression", other),
428 self.span(),
429 )),
430 None => Err(SchemaError::Parse(
431 "Unexpected end of check expression".to_string(),
432 self.span(),
433 )),
434 }
435 }
436
437 fn parse_cmp_op(&mut self) -> Result<CmpOp> {
438 if self.at_end() {
439 return Err(SchemaError::Parse(
440 "Expected comparison operator".to_string(),
441 self.span(),
442 ));
443 }
444
445 match self.peek() {
446 Some(TokenKind::Equal) => {
447 self.advance();
448 Ok(CmpOp::Eq)
449 }
450 Some(TokenKind::BangEqual) => {
451 self.advance();
452 Ok(CmpOp::Ne)
453 }
454 Some(TokenKind::LAngle) => {
455 self.advance();
456 if matches!(self.peek(), Some(TokenKind::RAngle)) {
458 self.advance();
459 Ok(CmpOp::Ne)
460 } else {
461 Ok(CmpOp::Lt)
462 }
463 }
464 Some(TokenKind::RAngle) => {
465 self.advance();
466 Ok(CmpOp::Gt)
467 }
468 Some(TokenKind::LessEqual) => {
469 self.advance();
470 Ok(CmpOp::Le)
471 }
472 Some(TokenKind::GreaterEqual) => {
473 self.advance();
474 Ok(CmpOp::Ge)
475 }
476 Some(other) => Err(SchemaError::Parse(
477 format!(
478 "Expected comparison operator (=, !=, <, >, <=, >=), got '{}'",
479 other
480 ),
481 self.span(),
482 )),
483 None => Err(SchemaError::Parse(
484 "Expected comparison operator".to_string(),
485 self.span(),
486 )),
487 }
488 }
489
490 fn parse_in_list(&mut self) -> Result<Vec<Operand>> {
491 match self.peek() {
492 Some(TokenKind::LBracket) => {
493 self.advance();
494 }
495 _ => {
496 return Err(SchemaError::Parse(
497 "Expected '[' after IN".to_string(),
498 self.span(),
499 ));
500 }
501 }
502
503 let mut values = Vec::new();
504
505 if !matches!(self.peek(), Some(TokenKind::RBracket)) {
506 values.push(self.parse_operand(true)?);
507 while matches!(self.peek(), Some(TokenKind::Comma)) {
508 self.advance();
509 values.push(self.parse_operand(true)?);
510 }
511 }
512
513 match self.peek() {
514 Some(TokenKind::RBracket) => {
515 self.advance();
516 Ok(values)
517 }
518 _ => Err(SchemaError::Parse(
519 "Expected ']' to close IN list".to_string(),
520 self.span(),
521 )),
522 }
523 }
524}
525
526pub fn parse_bool_expr(tokens: &[Token], fallback_span: Span) -> Result<BoolExpr> {
533 if tokens.is_empty() {
534 return Err(SchemaError::Parse(
535 "@check expression is empty".to_string(),
536 fallback_span,
537 ));
538 }
539
540 let mut parser = BoolExprParser::new(tokens, fallback_span);
541 let expr = parser.parse_expr()?;
542
543 if !parser.at_end() {
544 return Err(SchemaError::Parse(
545 format!(
546 "Unexpected token '{}' after check expression",
547 parser.tokens[parser.pos].kind
548 ),
549 parser.span(),
550 ));
551 }
552
553 Ok(expr)
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src` into the token stream the parser expects: newline tokens
    /// are dropped and lexing stops at (without including) the EOF token.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        let mut collected = Vec::new();
        loop {
            let token = lexer.next_token().expect("lex error");
            if matches!(token.kind, TokenKind::Eof) {
                return collected;
            }
            if !matches!(token.kind, TokenKind::Newline) {
                collected.push(token);
            }
        }
    }

    /// Lexes and parses `src`, panicking on any error.
    fn parse(src: &str) -> BoolExpr {
        let tokens = tokenize(src);
        parse_bool_expr(&tokens, Span::new(0, 0)).expect("parse error")
    }

    /// Lexes and parses `src`, returning the rendered error message; panics if
    /// parsing unexpectedly succeeds.
    fn parse_err(src: &str) -> String {
        let tokens = tokenize(src);
        match parse_bool_expr(&tokens, Span::new(0, 0)) {
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
            Err(e) => e.to_string(),
        }
    }

    /// Asserts that `src` parses and its `Display` output equals `expected`.
    fn assert_renders(src: &str, expected: &str) {
        assert_eq!(parse(src).to_string(), expected);
    }

    /// Asserts that `src` fails to parse with a message containing `needle`.
    fn assert_fails_with(src: &str, needle: &str) {
        let message = parse_err(src);
        assert!(
            message.contains(needle),
            "error {:?} does not mention {:?}",
            message,
            needle
        );
    }

    #[test]
    fn simple_comparison() {
        assert_renders("age > 18", "age > 18");
    }

    #[test]
    fn less_equal() {
        assert_renders("age <= 150", "age <= 150");
    }

    #[test]
    fn greater_equal() {
        assert_renders("score >= 0", "score >= 0");
    }

    #[test]
    fn not_equal() {
        assert_renders("status != 0", "status <> 0");
    }

    #[test]
    fn equality() {
        assert_renders("active = true", "active = TRUE");
    }

    #[test]
    fn and_expression() {
        assert_renders("age > 18 AND age <= 150", "age > 18 AND age <= 150");
    }

    #[test]
    fn or_expression() {
        assert_renders("age < 18 OR age > 65", "age < 18 OR age > 65");
    }

    #[test]
    fn not_expression() {
        assert_renders("NOT age < 0", "NOT age < 0");
    }

    #[test]
    fn in_with_enum_variants() {
        assert_renders("status IN [ACTIVE, PENDING]", "status IN [ACTIVE, PENDING]");
    }

    #[test]
    fn in_with_numbers() {
        assert_renders("priority IN [1, 2, 3]", "priority IN [1, 2, 3]");
    }

    #[test]
    fn in_with_strings() {
        assert_renders(
            "role IN [\"admin\", \"moderator\"]",
            "role IN ['admin', 'moderator']",
        );
    }

    #[test]
    fn complex_and_or() {
        assert_renders(
            "age > 18 AND status IN [ACTIVE, PENDING]",
            "age > 18 AND status IN [ACTIVE, PENDING]",
        );
    }

    #[test]
    fn parenthesised() {
        assert_renders(
            "(age > 18 OR admin = true) AND active = true",
            "(age > 18 OR admin = TRUE) AND active = TRUE",
        );
    }

    #[test]
    fn sql_output() {
        let sql = parse("status IN [ACTIVE, PENDING]").to_sql();
        assert_eq!(sql, "status IN ('ACTIVE', 'PENDING')");
    }

    #[test]
    fn sql_output_complex() {
        let sql = parse("age > 18 AND status IN [ACTIVE, PENDING]").to_sql();
        assert_eq!(sql, "age > 18 AND status IN ('ACTIVE', 'PENDING')");
    }

    #[test]
    fn field_references() {
        let expr = parse("age > 18 AND status IN [ACTIVE]");
        assert_eq!(expr.field_references(), vec!["age", "status"]);
    }

    #[test]
    fn enum_in_lists() {
        let expr = parse("status IN [ACTIVE, PENDING] AND role IN [ADMIN]");
        assert_eq!(
            expr.enum_in_lists(),
            vec![
                ("status", vec!["ACTIVE", "PENDING"]),
                ("role", vec!["ADMIN"]),
            ]
        );
    }

    #[test]
    fn empty_is_error() {
        let no_tokens: Vec<Token> = Vec::new();
        assert!(parse_bool_expr(&no_tokens, Span::new(0, 0)).is_err());
    }

    #[test]
    fn missing_operator_is_error() {
        assert_fails_with("age 18", "Expected comparison operator");
    }

    #[test]
    fn unclosed_in_list_is_error() {
        assert_fails_with("status IN [ACTIVE, PENDING", "Expected ']'");
    }

    #[test]
    fn missing_in_bracket_is_error() {
        assert_fails_with("status IN ACTIVE", "Expected '['");
    }
}