1use std::fmt;
9
10use crate::error::{Result, SchemaError};
11use crate::span::Span;
12use crate::token::{Token, TokenKind};
13
/// A binary comparison operator usable inside an `@check` expression.
///
/// The `Display` impl renders each variant with its SQL spelling; note that
/// `Ne` is rendered as `<>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=` or `<>`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
}
30
31impl fmt::Display for CmpOp {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 f.write_str(match self {
34 CmpOp::Eq => "=",
35 CmpOp::Ne => "<>",
36 CmpOp::Lt => "<",
37 CmpOp::Gt => ">",
38 CmpOp::Le => "<=",
39 CmpOp::Ge => ">=",
40 })
41 }
42}
43
/// One side of a comparison, or one entry of an `IN [...]` list.
///
/// Textual payloads are stored as `String`s taken from the lexer's tokens;
/// numbers in particular are kept as text rather than parsed into a numeric type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Operand {
    /// An identifier (or keyword) outside an `IN` list, treated as a field reference.
    Field(String),
    /// A numeric literal, kept as the lexer's text.
    Number(String),
    /// A string literal's contents (without the surrounding quotes).
    StringLit(String),
    /// A boolean literal.
    Bool(bool),
    /// An identifier (or keyword) inside an `IN [...]` list, treated as an enum variant.
    EnumVariant(String),
}
58
59impl fmt::Display for Operand {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 match self {
62 Operand::Field(name) => write!(f, "{}", name),
63 Operand::Number(n) => write!(f, "{}", n),
64 Operand::StringLit(s) => write!(f, "'{}'", s),
65 Operand::Bool(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
66 Operand::EnumVariant(v) => write!(f, "'{}'", v),
67 }
68 }
69}
70
71#[derive(Debug, Clone, PartialEq)]
73pub enum BoolExpr {
74 Comparison {
76 left: Operand,
78 op: CmpOp,
80 right: Operand,
82 },
83 And(Box<BoolExpr>, Box<BoolExpr>),
85 Or(Box<BoolExpr>, Box<BoolExpr>),
87 Not(Box<BoolExpr>),
89 In {
91 field: String,
93 values: Vec<Operand>,
95 },
96 Paren(Box<BoolExpr>),
98}
99
100impl fmt::Display for BoolExpr {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 match self {
103 BoolExpr::Comparison { left, op, right } => write!(f, "{} {} {}", left, op, right),
104 BoolExpr::And(left, right) => write!(f, "{} AND {}", left, right),
105 BoolExpr::Or(left, right) => write!(f, "{} OR {}", left, right),
106 BoolExpr::Not(inner) => write!(f, "NOT {}", inner),
107 BoolExpr::In { field, values } => {
108 write!(f, "{} IN [", field)?;
109 for (i, val) in values.iter().enumerate() {
110 if i > 0 {
111 write!(f, ", ")?;
112 }
113 match val {
115 Operand::EnumVariant(v) => write!(f, "{}", v)?,
116 other => write!(f, "{}", other)?,
117 }
118 }
119 write!(f, "]")
120 }
121 BoolExpr::Paren(inner) => write!(f, "({})", inner),
122 }
123 }
124}
125
126impl BoolExpr {
127 pub fn field_references(&self) -> Vec<&str> {
129 let mut refs = Vec::new();
130 self.collect_field_refs(&mut refs);
131 refs
132 }
133
134 fn collect_field_refs<'a>(&'a self, refs: &mut Vec<&'a str>) {
135 match self {
136 BoolExpr::Comparison { left, right, .. } => {
137 if let Operand::Field(name) = left {
138 refs.push(name);
139 }
140 if let Operand::Field(name) = right {
141 refs.push(name);
142 }
143 }
144 BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
145 l.collect_field_refs(refs);
146 r.collect_field_refs(refs);
147 }
148 BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
149 inner.collect_field_refs(refs);
150 }
151 BoolExpr::In { field, .. } => {
152 refs.push(field);
153 }
154 }
155 }
156
157 pub fn enum_in_lists(&self) -> Vec<(&str, Vec<&str>)> {
160 let mut result = Vec::new();
161 self.collect_enum_in_lists(&mut result);
162 result
163 }
164
165 fn collect_enum_in_lists<'a>(&'a self, result: &mut Vec<(&'a str, Vec<&'a str>)>) {
166 match self {
167 BoolExpr::In { field, values } => {
168 let variants: Vec<&str> = values
169 .iter()
170 .filter_map(|v| match v {
171 Operand::EnumVariant(name) => Some(name.as_str()),
172 _ => None,
173 })
174 .collect();
175 if !variants.is_empty() {
176 result.push((field.as_str(), variants));
177 }
178 }
179 BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
180 l.collect_enum_in_lists(result);
181 r.collect_enum_in_lists(result);
182 }
183 BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
184 inner.collect_enum_in_lists(result);
185 }
186 BoolExpr::Comparison { .. } => {}
187 }
188 }
189
190 pub fn to_sql(&self) -> String {
195 match self {
196 BoolExpr::Comparison { left, op, right } => {
197 format!("{} {} {}", left, op, right)
198 }
199 BoolExpr::And(left, right) => format!("{} AND {}", left.to_sql(), right.to_sql()),
200 BoolExpr::Or(left, right) => format!("{} OR {}", left.to_sql(), right.to_sql()),
201 BoolExpr::Not(inner) => format!("NOT {}", inner.to_sql()),
202 BoolExpr::In { field, values } => {
203 let vals: Vec<String> = values.iter().map(|v| v.to_string()).collect();
204 format!("{} IN ({})", field, vals.join(", "))
205 }
206 BoolExpr::Paren(inner) => format!("({})", inner.to_sql()),
207 }
208 }
209}
210
211struct BoolExprParser<'a> {
214 tokens: &'a [Token],
215 pos: usize,
216 fallback_span: Span,
218}
219
220impl<'a> BoolExprParser<'a> {
221 fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
222 Self {
223 tokens,
224 pos: 0,
225 fallback_span,
226 }
227 }
228
229 fn peek(&self) -> Option<&TokenKind> {
230 self.tokens.get(self.pos).map(|t| &t.kind)
231 }
232
233 fn span(&self) -> Span {
234 self.tokens
235 .get(self.pos)
236 .map(|t| t.span)
237 .unwrap_or(self.fallback_span)
238 }
239
240 fn advance(&mut self) -> &Token {
241 let tok = &self.tokens[self.pos];
242 self.pos += 1;
243 tok
244 }
245
246 fn at_end(&self) -> bool {
247 self.pos >= self.tokens.len()
248 }
249
250 fn is_keyword(&self, kw: &str) -> bool {
251 matches!(self.peek(), Some(TokenKind::Ident(s)) if s.eq_ignore_ascii_case(kw))
252 }
253
254 fn parse_expr(&mut self) -> Result<BoolExpr> {
255 self.parse_or()
256 }
257
258 fn parse_or(&mut self) -> Result<BoolExpr> {
260 let mut left = self.parse_and()?;
261 while self.is_keyword("OR") {
262 self.advance();
263 let right = self.parse_and()?;
264 left = BoolExpr::Or(Box::new(left), Box::new(right));
265 }
266 Ok(left)
267 }
268
269 fn parse_and(&mut self) -> Result<BoolExpr> {
270 let mut left = self.parse_not()?;
271 while self.is_keyword("AND") {
272 self.advance();
273 let right = self.parse_not()?;
274 left = BoolExpr::And(Box::new(left), Box::new(right));
275 }
276 Ok(left)
277 }
278
279 fn parse_not(&mut self) -> Result<BoolExpr> {
280 if self.is_keyword("NOT") {
281 self.advance();
282 let inner = self.parse_not()?;
283 return Ok(BoolExpr::Not(Box::new(inner)));
284 }
285 self.parse_primary()
286 }
287
288 fn parse_primary(&mut self) -> Result<BoolExpr> {
289 if self.at_end() {
290 return Err(SchemaError::Parse(
291 "Unexpected end of check expression".to_string(),
292 self.span(),
293 ));
294 }
295
296 if matches!(self.peek(), Some(TokenKind::LParen)) {
297 self.advance();
298 let inner = self.parse_expr()?;
299 match self.peek() {
300 Some(TokenKind::RParen) => {
301 self.advance();
302 return Ok(BoolExpr::Paren(Box::new(inner)));
303 }
304 _ => {
305 return Err(SchemaError::Parse(
306 "Expected ')' after parenthesised expression".to_string(),
307 self.span(),
308 ));
309 }
310 }
311 }
312
313 if matches!(self.peek(), Some(TokenKind::True)) {
314 self.advance();
315 return Ok(BoolExpr::Comparison {
316 left: Operand::Bool(true),
317 op: CmpOp::Eq,
318 right: Operand::Bool(true),
319 });
320 }
321 if matches!(self.peek(), Some(TokenKind::False)) {
322 self.advance();
323 return Ok(BoolExpr::Comparison {
324 left: Operand::Bool(false),
325 op: CmpOp::Eq,
326 right: Operand::Bool(true),
327 });
328 }
329
330 let left = self.parse_operand(false)?;
331
332 if self.is_keyword("IN") {
333 let field_name = match &left {
334 Operand::Field(name) => name.clone(),
335 _ => {
336 return Err(SchemaError::Parse(
337 "Left side of IN must be a field reference".to_string(),
338 self.span(),
339 ));
340 }
341 };
342 self.advance();
343 let values = self.parse_in_list()?;
344 return Ok(BoolExpr::In {
345 field: field_name,
346 values,
347 });
348 }
349
350 let op = self.parse_cmp_op()?;
351 let right = self.parse_operand(false)?;
352
353 Ok(BoolExpr::Comparison { left, op, right })
354 }
355
356 fn parse_operand(&mut self, in_list: bool) -> Result<Operand> {
359 if self.at_end() {
360 return Err(SchemaError::Parse(
361 "Expected operand in check expression".to_string(),
362 self.span(),
363 ));
364 }
365
366 match self.peek().cloned() {
367 Some(TokenKind::Number(n)) => {
368 self.advance();
369 Ok(Operand::Number(n))
370 }
371 Some(TokenKind::String(s)) => {
372 self.advance();
373 Ok(Operand::StringLit(s))
374 }
375 Some(TokenKind::True) => {
376 self.advance();
377 Ok(Operand::Bool(true))
378 }
379 Some(TokenKind::False) => {
380 self.advance();
381 Ok(Operand::Bool(false))
382 }
383 Some(TokenKind::Ident(name)) => {
384 self.advance();
385 if in_list {
386 Ok(Operand::EnumVariant(name))
387 } else {
388 Ok(Operand::Field(name))
389 }
390 }
391 Some(k) if k.is_keyword() => {
393 let tok = self.advance();
394 let name = tok.kind.to_string();
395 if in_list {
396 Ok(Operand::EnumVariant(name))
397 } else {
398 Ok(Operand::Field(name))
399 }
400 }
401 Some(other) => Err(SchemaError::Parse(
402 format!("Unexpected token '{}' in check expression", other),
403 self.span(),
404 )),
405 None => Err(SchemaError::Parse(
406 "Unexpected end of check expression".to_string(),
407 self.span(),
408 )),
409 }
410 }
411
412 fn parse_cmp_op(&mut self) -> Result<CmpOp> {
413 if self.at_end() {
414 return Err(SchemaError::Parse(
415 "Expected comparison operator".to_string(),
416 self.span(),
417 ));
418 }
419
420 match self.peek() {
421 Some(TokenKind::Equal) => {
422 self.advance();
423 Ok(CmpOp::Eq)
424 }
425 Some(TokenKind::BangEqual) => {
426 self.advance();
427 Ok(CmpOp::Ne)
428 }
429 Some(TokenKind::LAngle) => {
430 self.advance();
431 if matches!(self.peek(), Some(TokenKind::RAngle)) {
433 self.advance();
434 Ok(CmpOp::Ne)
435 } else {
436 Ok(CmpOp::Lt)
437 }
438 }
439 Some(TokenKind::RAngle) => {
440 self.advance();
441 Ok(CmpOp::Gt)
442 }
443 Some(TokenKind::LessEqual) => {
444 self.advance();
445 Ok(CmpOp::Le)
446 }
447 Some(TokenKind::GreaterEqual) => {
448 self.advance();
449 Ok(CmpOp::Ge)
450 }
451 Some(other) => Err(SchemaError::Parse(
452 format!(
453 "Expected comparison operator (=, !=, <, >, <=, >=), got '{}'",
454 other
455 ),
456 self.span(),
457 )),
458 None => Err(SchemaError::Parse(
459 "Expected comparison operator".to_string(),
460 self.span(),
461 )),
462 }
463 }
464
465 fn parse_in_list(&mut self) -> Result<Vec<Operand>> {
466 match self.peek() {
467 Some(TokenKind::LBracket) => {
468 self.advance();
469 }
470 _ => {
471 return Err(SchemaError::Parse(
472 "Expected '[' after IN".to_string(),
473 self.span(),
474 ));
475 }
476 }
477
478 let mut values = Vec::new();
479
480 if !matches!(self.peek(), Some(TokenKind::RBracket)) {
481 values.push(self.parse_operand(true)?);
482 while matches!(self.peek(), Some(TokenKind::Comma)) {
483 self.advance();
484 values.push(self.parse_operand(true)?);
485 }
486 }
487
488 match self.peek() {
489 Some(TokenKind::RBracket) => {
490 self.advance();
491 Ok(values)
492 }
493 _ => Err(SchemaError::Parse(
494 "Expected ']' to close IN list".to_string(),
495 self.span(),
496 )),
497 }
498 }
499}
500
501pub fn parse_bool_expr(tokens: &[Token], fallback_span: Span) -> Result<BoolExpr> {
508 if tokens.is_empty() {
509 return Err(SchemaError::Parse(
510 "@check expression is empty".to_string(),
511 fallback_span,
512 ));
513 }
514
515 let mut parser = BoolExprParser::new(tokens, fallback_span);
516 let expr = parser.parse_expr()?;
517
518 if !parser.at_end() {
519 return Err(SchemaError::Parse(
520 format!(
521 "Unexpected token '{}' after check expression",
522 parser.tokens[parser.pos].kind
523 ),
524 parser.span(),
525 ));
526 }
527
528 Ok(expr)
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src` into the token list the parser expects: newline tokens are
    /// dropped and lexing stops at (without including) `Eof`.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        let mut collected = Vec::new();
        loop {
            let token = lexer.next_token().expect("lex error");
            if matches!(token.kind, TokenKind::Eof) {
                return collected;
            }
            if !matches!(token.kind, TokenKind::Newline) {
                collected.push(token);
            }
        }
    }

    /// Parses `src`, panicking if lexing or parsing fails.
    fn parse(src: &str) -> BoolExpr {
        let tokens = tokenize(src);
        parse_bool_expr(&tokens, Span::new(0, 0)).expect("parse error")
    }

    /// Parses `src`, which must fail, and returns the error's display text.
    fn parse_err(src: &str) -> String {
        let tokens = tokenize(src);
        match parse_bool_expr(&tokens, Span::new(0, 0)) {
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
            Err(e) => e.to_string(),
        }
    }

    /// Asserts that `src` parses and that its `Display` rendering equals `expected`.
    fn assert_renders(src: &str, expected: &str) {
        assert_eq!(parse(src).to_string(), expected);
    }

    /// Asserts that `src` parses and that its SQL rendering equals `expected`.
    fn assert_sql(src: &str, expected: &str) {
        assert_eq!(parse(src).to_sql(), expected);
    }

    #[test]
    fn simple_comparison() {
        assert_renders("age > 18", "age > 18");
    }

    #[test]
    fn less_equal() {
        assert_renders("age <= 150", "age <= 150");
    }

    #[test]
    fn greater_equal() {
        assert_renders("score >= 0", "score >= 0");
    }

    #[test]
    fn not_equal() {
        assert_renders("status != 0", "status <> 0");
    }

    #[test]
    fn equality() {
        assert_renders("active = true", "active = TRUE");
    }

    #[test]
    fn and_expression() {
        assert_renders("age > 18 AND age <= 150", "age > 18 AND age <= 150");
    }

    #[test]
    fn or_expression() {
        assert_renders("age < 18 OR age > 65", "age < 18 OR age > 65");
    }

    #[test]
    fn not_expression() {
        assert_renders("NOT age < 0", "NOT age < 0");
    }

    #[test]
    fn in_with_enum_variants() {
        assert_renders("status IN [ACTIVE, PENDING]", "status IN [ACTIVE, PENDING]");
    }

    #[test]
    fn in_with_numbers() {
        assert_renders("priority IN [1, 2, 3]", "priority IN [1, 2, 3]");
    }

    #[test]
    fn in_with_strings() {
        assert_renders(
            "role IN [\"admin\", \"moderator\"]",
            "role IN ['admin', 'moderator']",
        );
    }

    #[test]
    fn complex_and_or() {
        assert_renders(
            "age > 18 AND status IN [ACTIVE, PENDING]",
            "age > 18 AND status IN [ACTIVE, PENDING]",
        );
    }

    #[test]
    fn parenthesised() {
        assert_renders(
            "(age > 18 OR admin = true) AND active = true",
            "(age > 18 OR admin = TRUE) AND active = TRUE",
        );
    }

    #[test]
    fn sql_output() {
        assert_sql("status IN [ACTIVE, PENDING]", "status IN ('ACTIVE', 'PENDING')");
    }

    #[test]
    fn sql_output_complex() {
        assert_sql(
            "age > 18 AND status IN [ACTIVE, PENDING]",
            "age > 18 AND status IN ('ACTIVE', 'PENDING')",
        );
    }

    #[test]
    fn field_references() {
        let expr = parse("age > 18 AND status IN [ACTIVE]");
        assert_eq!(expr.field_references(), vec!["age", "status"]);
    }

    #[test]
    fn enum_in_lists() {
        let expr = parse("status IN [ACTIVE, PENDING] AND role IN [ADMIN]");
        assert_eq!(
            expr.enum_in_lists(),
            vec![("status", vec!["ACTIVE", "PENDING"]), ("role", vec!["ADMIN"])]
        );
    }

    #[test]
    fn empty_is_error() {
        assert!(parse_bool_expr(&[], Span::new(0, 0)).is_err());
    }

    #[test]
    fn missing_operator_is_error() {
        assert!(parse_err("age 18").contains("Expected comparison operator"));
    }

    #[test]
    fn unclosed_in_list_is_error() {
        assert!(parse_err("status IN [ACTIVE, PENDING").contains("Expected ']'"));
    }

    #[test]
    fn missing_in_bracket_is_error() {
        assert!(parse_err("status IN ACTIVE").contains("Expected '['"));
    }
}