Skip to main content

nautilus_schema/
bool_expr.rs

//! Boolean expression parser for `@check` and `@@check` constraint attributes.
//!
//! The parser operates on tokens already produced by the schema
//! [`Lexer`](crate::lexer::Lexer) and builds a small AST ([`BoolExpr`]) by
//! recursive descent, with one parsing function per precedence level
//! (`OR` < `AND` < `NOT` < comparisons / `IN`).  The AST implements
//! [`fmt::Display`] to render expressions back to schema syntax, and
//! [`BoolExpr::to_sql`] to render them as SQL for DDL `CHECK (...)` clauses.
7
8use std::fmt;
9
10use crate::error::{Result, SchemaError};
11use crate::span::Span;
12use crate::token::{Token, TokenKind};
13
/// Comparison operator placed between two operands of a check expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=` / `<>`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
}

impl fmt::Display for CmpOp {
    /// Writes the SQL spelling of the operator; inequality is always `<>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::Ne => "<>",
            Self::Lt => "<",
            Self::Gt => ">",
            Self::Le => "<=",
            Self::Ge => ">=",
        };
        f.write_str(symbol)
    }
}
43
44/// An operand (leaf value) in a boolean expression.
45#[derive(Debug, Clone, PartialEq)]
46pub enum Operand {
47    /// A field reference (e.g. `age`, `status`).
48    Field(String),
49    /// A numeric literal (e.g. `18`, `3.14`).
50    Number(String),
51    /// A string literal (e.g. `"hello"`).
52    StringLit(String),
53    /// A boolean literal (`true` / `false`).
54    Bool(bool),
55    /// A bare identifier inside an `IN [...]` list — treated as an enum variant.
56    EnumVariant(String),
57}
58
59impl fmt::Display for Operand {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        match self {
62            Operand::Field(name) => write!(f, "{}", name),
63            Operand::Number(n) => write!(f, "{}", n),
64            Operand::StringLit(s) => write!(f, "'{}'", s),
65            Operand::Bool(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
66            Operand::EnumVariant(v) => write!(f, "'{}'", v),
67        }
68    }
69}
70
71/// A parsed boolean expression node.
72#[derive(Debug, Clone, PartialEq)]
73pub enum BoolExpr {
74    /// A comparison (e.g. `age > 18`).
75    Comparison {
76        /// Left-hand operand.
77        left: Operand,
78        /// Comparison operator.
79        op: CmpOp,
80        /// Right-hand operand.
81        right: Operand,
82    },
83    /// Logical AND (e.g. `a AND b`).
84    And(Box<BoolExpr>, Box<BoolExpr>),
85    /// Logical OR (e.g. `a OR b`).
86    Or(Box<BoolExpr>, Box<BoolExpr>),
87    /// Logical NOT (e.g. `NOT a`).
88    Not(Box<BoolExpr>),
89    /// IN list (e.g. `status IN [ACTIVE, PENDING]`).
90    In {
91        /// Field being tested.
92        field: String,
93        /// List of values to test against.
94        values: Vec<Operand>,
95    },
96    /// Parenthesised sub-expression.
97    Paren(Box<BoolExpr>),
98}
99
100impl fmt::Display for BoolExpr {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match self {
103            BoolExpr::Comparison { left, op, right } => write!(f, "{} {} {}", left, op, right),
104            BoolExpr::And(left, right) => write!(f, "{} AND {}", left, right),
105            BoolExpr::Or(left, right) => write!(f, "{} OR {}", left, right),
106            BoolExpr::Not(inner) => write!(f, "NOT {}", inner),
107            BoolExpr::In { field, values } => {
108                write!(f, "{} IN [", field)?;
109                for (i, val) in values.iter().enumerate() {
110                    if i > 0 {
111                        write!(f, ", ")?;
112                    }
113                    // In schema text, enum variants are bare identifiers; to_sql() quotes them.
114                    match val {
115                        Operand::EnumVariant(v) => write!(f, "{}", v)?,
116                        other => write!(f, "{}", other)?,
117                    }
118                }
119                write!(f, "]")
120            }
121            BoolExpr::Paren(inner) => write!(f, "({})", inner),
122        }
123    }
124}
125
126impl BoolExpr {
127    /// Collect all field references in this expression.
128    pub fn field_references(&self) -> Vec<&str> {
129        let mut refs = Vec::new();
130        self.collect_field_refs(&mut refs);
131        refs
132    }
133
134    fn collect_field_refs<'a>(&'a self, refs: &mut Vec<&'a str>) {
135        match self {
136            BoolExpr::Comparison { left, right, .. } => {
137                if let Operand::Field(name) = left {
138                    refs.push(name);
139                }
140                if let Operand::Field(name) = right {
141                    refs.push(name);
142                }
143            }
144            BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
145                l.collect_field_refs(refs);
146                r.collect_field_refs(refs);
147            }
148            BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
149                inner.collect_field_refs(refs);
150            }
151            BoolExpr::In { field, .. } => {
152                refs.push(field);
153            }
154        }
155    }
156
157    /// Collect `(field_name, [variant_names])` pairs from `IN` nodes where values
158    /// are enum variants. Used by the validator for enum checking.
159    pub fn enum_in_lists(&self) -> Vec<(&str, Vec<&str>)> {
160        let mut result = Vec::new();
161        self.collect_enum_in_lists(&mut result);
162        result
163    }
164
165    fn collect_enum_in_lists<'a>(&'a self, result: &mut Vec<(&'a str, Vec<&'a str>)>) {
166        match self {
167            BoolExpr::In { field, values } => {
168                let variants: Vec<&str> = values
169                    .iter()
170                    .filter_map(|v| match v {
171                        Operand::EnumVariant(name) => Some(name.as_str()),
172                        _ => None,
173                    })
174                    .collect();
175                if !variants.is_empty() {
176                    result.push((field.as_str(), variants));
177                }
178            }
179            BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
180                l.collect_enum_in_lists(result);
181                r.collect_enum_in_lists(result);
182            }
183            BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
184                inner.collect_enum_in_lists(result);
185            }
186            BoolExpr::Comparison { .. } => {}
187        }
188    }
189
190    /// Render this expression as valid SQL (for DDL `CHECK (...)` clauses).
191    ///
192    /// This differs from `Display` in that `IN` lists use SQL `IN (...)` syntax
193    /// and enum variants are rendered as quoted string literals.
194    pub fn to_sql(&self) -> String {
195        match self {
196            BoolExpr::Comparison { left, op, right } => {
197                format!("{} {} {}", left, op, right)
198            }
199            BoolExpr::And(left, right) => format!("{} AND {}", left.to_sql(), right.to_sql()),
200            BoolExpr::Or(left, right) => format!("{} OR {}", left.to_sql(), right.to_sql()),
201            BoolExpr::Not(inner) => format!("NOT {}", inner.to_sql()),
202            BoolExpr::In { field, values } => {
203                let vals: Vec<String> = values.iter().map(|v| v.to_string()).collect();
204                format!("{} IN ({})", field, vals.join(", "))
205            }
206            BoolExpr::Paren(inner) => format!("({})", inner.to_sql()),
207        }
208    }
209}
210
211/// Recursive-descent parser that converts a slice of schema tokens into a
212/// [`BoolExpr`] tree.
213struct BoolExprParser<'a> {
214    tokens: &'a [Token],
215    pos: usize,
216    /// Span used for error reporting when there are no more tokens.
217    fallback_span: Span,
218}
219
220impl<'a> BoolExprParser<'a> {
221    fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
222        Self {
223            tokens,
224            pos: 0,
225            fallback_span,
226        }
227    }
228
229    fn peek(&self) -> Option<&TokenKind> {
230        self.tokens.get(self.pos).map(|t| &t.kind)
231    }
232
233    fn span(&self) -> Span {
234        self.tokens
235            .get(self.pos)
236            .map(|t| t.span)
237            .unwrap_or(self.fallback_span)
238    }
239
240    fn advance(&mut self) -> &Token {
241        let tok = &self.tokens[self.pos];
242        self.pos += 1;
243        tok
244    }
245
246    fn at_end(&self) -> bool {
247        self.pos >= self.tokens.len()
248    }
249
250    fn is_keyword(&self, kw: &str) -> bool {
251        matches!(self.peek(), Some(TokenKind::Ident(s)) if s.eq_ignore_ascii_case(kw))
252    }
253
254    fn parse_expr(&mut self) -> Result<BoolExpr> {
255        self.parse_or()
256    }
257
258    /// OR has the lowest precedence.
259    fn parse_or(&mut self) -> Result<BoolExpr> {
260        let mut left = self.parse_and()?;
261        while self.is_keyword("OR") {
262            self.advance();
263            let right = self.parse_and()?;
264            left = BoolExpr::Or(Box::new(left), Box::new(right));
265        }
266        Ok(left)
267    }
268
269    fn parse_and(&mut self) -> Result<BoolExpr> {
270        let mut left = self.parse_not()?;
271        while self.is_keyword("AND") {
272            self.advance();
273            let right = self.parse_not()?;
274            left = BoolExpr::And(Box::new(left), Box::new(right));
275        }
276        Ok(left)
277    }
278
279    fn parse_not(&mut self) -> Result<BoolExpr> {
280        if self.is_keyword("NOT") {
281            self.advance();
282            let inner = self.parse_not()?;
283            return Ok(BoolExpr::Not(Box::new(inner)));
284        }
285        self.parse_primary()
286    }
287
288    fn parse_primary(&mut self) -> Result<BoolExpr> {
289        if self.at_end() {
290            return Err(SchemaError::Parse(
291                "Unexpected end of check expression".to_string(),
292                self.span(),
293            ));
294        }
295
296        if matches!(self.peek(), Some(TokenKind::LParen)) {
297            self.advance();
298            let inner = self.parse_expr()?;
299            match self.peek() {
300                Some(TokenKind::RParen) => {
301                    self.advance();
302                    return Ok(BoolExpr::Paren(Box::new(inner)));
303                }
304                _ => {
305                    return Err(SchemaError::Parse(
306                        "Expected ')' after parenthesised expression".to_string(),
307                        self.span(),
308                    ));
309                }
310            }
311        }
312
313        if matches!(self.peek(), Some(TokenKind::True)) {
314            self.advance();
315            return Ok(BoolExpr::Comparison {
316                left: Operand::Bool(true),
317                op: CmpOp::Eq,
318                right: Operand::Bool(true),
319            });
320        }
321        if matches!(self.peek(), Some(TokenKind::False)) {
322            self.advance();
323            return Ok(BoolExpr::Comparison {
324                left: Operand::Bool(false),
325                op: CmpOp::Eq,
326                right: Operand::Bool(true),
327            });
328        }
329
330        let left = self.parse_operand(false)?;
331
332        if self.is_keyword("IN") {
333            let field_name = match &left {
334                Operand::Field(name) => name.clone(),
335                _ => {
336                    return Err(SchemaError::Parse(
337                        "Left side of IN must be a field reference".to_string(),
338                        self.span(),
339                    ));
340                }
341            };
342            self.advance();
343            let values = self.parse_in_list()?;
344            return Ok(BoolExpr::In {
345                field: field_name,
346                values,
347            });
348        }
349
350        let op = self.parse_cmp_op()?;
351        let right = self.parse_operand(false)?;
352
353        Ok(BoolExpr::Comparison { left, op, right })
354    }
355
356    /// Parse a single operand (field reference, literal, etc.).
357    /// When `in_list` is true, bare identifiers are treated as enum variants.
358    fn parse_operand(&mut self, in_list: bool) -> Result<Operand> {
359        if self.at_end() {
360            return Err(SchemaError::Parse(
361                "Expected operand in check expression".to_string(),
362                self.span(),
363            ));
364        }
365
366        match self.peek().cloned() {
367            Some(TokenKind::Number(n)) => {
368                self.advance();
369                Ok(Operand::Number(n))
370            }
371            Some(TokenKind::String(s)) => {
372                self.advance();
373                Ok(Operand::StringLit(s))
374            }
375            Some(TokenKind::True) => {
376                self.advance();
377                Ok(Operand::Bool(true))
378            }
379            Some(TokenKind::False) => {
380                self.advance();
381                Ok(Operand::Bool(false))
382            }
383            Some(TokenKind::Ident(name)) => {
384                self.advance();
385                if in_list {
386                    Ok(Operand::EnumVariant(name))
387                } else {
388                    Ok(Operand::Field(name))
389                }
390            }
391            // Schema keywords (e.g. `String`, `Int`) are valid field/enum names inside expressions.
392            Some(k) if k.is_keyword() => {
393                let tok = self.advance();
394                let name = tok.kind.to_string();
395                if in_list {
396                    Ok(Operand::EnumVariant(name))
397                } else {
398                    Ok(Operand::Field(name))
399                }
400            }
401            Some(other) => Err(SchemaError::Parse(
402                format!("Unexpected token '{}' in check expression", other),
403                self.span(),
404            )),
405            None => Err(SchemaError::Parse(
406                "Unexpected end of check expression".to_string(),
407                self.span(),
408            )),
409        }
410    }
411
412    fn parse_cmp_op(&mut self) -> Result<CmpOp> {
413        if self.at_end() {
414            return Err(SchemaError::Parse(
415                "Expected comparison operator".to_string(),
416                self.span(),
417            ));
418        }
419
420        match self.peek() {
421            Some(TokenKind::Equal) => {
422                self.advance();
423                Ok(CmpOp::Eq)
424            }
425            Some(TokenKind::BangEqual) => {
426                self.advance();
427                Ok(CmpOp::Ne)
428            }
429            Some(TokenKind::LAngle) => {
430                self.advance();
431                // `<>` is accepted as an alternative to `!=`.
432                if matches!(self.peek(), Some(TokenKind::RAngle)) {
433                    self.advance();
434                    Ok(CmpOp::Ne)
435                } else {
436                    Ok(CmpOp::Lt)
437                }
438            }
439            Some(TokenKind::RAngle) => {
440                self.advance();
441                Ok(CmpOp::Gt)
442            }
443            Some(TokenKind::LessEqual) => {
444                self.advance();
445                Ok(CmpOp::Le)
446            }
447            Some(TokenKind::GreaterEqual) => {
448                self.advance();
449                Ok(CmpOp::Ge)
450            }
451            Some(other) => Err(SchemaError::Parse(
452                format!(
453                    "Expected comparison operator (=, !=, <, >, <=, >=), got '{}'",
454                    other
455                ),
456                self.span(),
457            )),
458            None => Err(SchemaError::Parse(
459                "Expected comparison operator".to_string(),
460                self.span(),
461            )),
462        }
463    }
464
465    fn parse_in_list(&mut self) -> Result<Vec<Operand>> {
466        match self.peek() {
467            Some(TokenKind::LBracket) => {
468                self.advance();
469            }
470            _ => {
471                return Err(SchemaError::Parse(
472                    "Expected '[' after IN".to_string(),
473                    self.span(),
474                ));
475            }
476        }
477
478        let mut values = Vec::new();
479
480        if !matches!(self.peek(), Some(TokenKind::RBracket)) {
481            values.push(self.parse_operand(true)?);
482            while matches!(self.peek(), Some(TokenKind::Comma)) {
483                self.advance();
484                values.push(self.parse_operand(true)?);
485            }
486        }
487
488        match self.peek() {
489            Some(TokenKind::RBracket) => {
490                self.advance();
491                Ok(values)
492            }
493            _ => Err(SchemaError::Parse(
494                "Expected ']' to close IN list".to_string(),
495                self.span(),
496            )),
497        }
498    }
499}
500
501/// Parse a slice of schema tokens into a validated [`BoolExpr`] tree.
502///
503/// The token slice should contain **only** the expression tokens (i.e. without
504/// the surrounding `@check(` ... `)` scaffolding).
505///
506/// `fallback_span` is used for error reporting when the slice is empty.
507pub fn parse_bool_expr(tokens: &[Token], fallback_span: Span) -> Result<BoolExpr> {
508    if tokens.is_empty() {
509        return Err(SchemaError::Parse(
510            "@check expression is empty".to_string(),
511            fallback_span,
512        ));
513    }
514
515    let mut parser = BoolExprParser::new(tokens, fallback_span);
516    let expr = parser.parse_expr()?;
517
518    if !parser.at_end() {
519        return Err(SchemaError::Parse(
520            format!(
521                "Unexpected token '{}' after check expression",
522                parser.tokens[parser.pos].kind
523            ),
524            parser.span(),
525        ));
526    }
527
528    Ok(expr)
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lex `src` into the token slice `parse_bool_expr` expects: newlines are
    /// dropped and lexing stops at (and excludes) EOF.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        std::iter::from_fn(|| {
            let tok = lexer.next_token().expect("lex error");
            if matches!(tok.kind, TokenKind::Eof) {
                None
            } else {
                Some(tok)
            }
        })
        .filter(|tok| !matches!(tok.kind, TokenKind::Newline))
        .collect()
    }

    /// Parse `src`, panicking on any lex or parse error.
    fn parse(src: &str) -> BoolExpr {
        parse_bool_expr(&tokenize(src), Span::new(0, 0)).expect("parse error")
    }

    /// Parse `src`, expecting failure, and return the rendered error message.
    fn parse_err(src: &str) -> String {
        match parse_bool_expr(&tokenize(src), Span::new(0, 0)) {
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
            Err(e) => e.to_string(),
        }
    }

    #[track_caller]
    fn assert_display(src: &str, expected: &str) {
        assert_eq!(parse(src).to_string(), expected);
    }

    #[track_caller]
    fn assert_sql(src: &str, expected: &str) {
        assert_eq!(parse(src).to_sql(), expected);
    }

    #[track_caller]
    fn assert_err_contains(src: &str, needle: &str) {
        let err = parse_err(src);
        assert!(err.contains(needle), "error {:?} should mention {:?}", err, needle);
    }

    #[test]
    fn simple_comparison() {
        assert_display("age > 18", "age > 18");
    }

    #[test]
    fn less_equal() {
        assert_display("age <= 150", "age <= 150");
    }

    #[test]
    fn greater_equal() {
        assert_display("score >= 0", "score >= 0");
    }

    #[test]
    fn not_equal() {
        assert_display("status != 0", "status <> 0");
    }

    #[test]
    fn equality() {
        assert_display("active = true", "active = TRUE");
    }

    #[test]
    fn and_expression() {
        assert_display("age > 18 AND age <= 150", "age > 18 AND age <= 150");
    }

    #[test]
    fn or_expression() {
        assert_display("age < 18 OR age > 65", "age < 18 OR age > 65");
    }

    #[test]
    fn not_expression() {
        assert_display("NOT age < 0", "NOT age < 0");
    }

    #[test]
    fn in_with_enum_variants() {
        assert_display("status IN [ACTIVE, PENDING]", "status IN [ACTIVE, PENDING]");
    }

    #[test]
    fn in_with_numbers() {
        assert_display("priority IN [1, 2, 3]", "priority IN [1, 2, 3]");
    }

    #[test]
    fn in_with_strings() {
        assert_display(
            "role IN [\"admin\", \"moderator\"]",
            "role IN ['admin', 'moderator']",
        );
    }

    #[test]
    fn complex_and_or() {
        assert_display(
            "age > 18 AND status IN [ACTIVE, PENDING]",
            "age > 18 AND status IN [ACTIVE, PENDING]",
        );
    }

    #[test]
    fn parenthesised() {
        assert_display(
            "(age > 18 OR admin = true) AND active = true",
            "(age > 18 OR admin = TRUE) AND active = TRUE",
        );
    }

    #[test]
    fn sql_output() {
        assert_sql("status IN [ACTIVE, PENDING]", "status IN ('ACTIVE', 'PENDING')");
    }

    #[test]
    fn sql_output_complex() {
        assert_sql(
            "age > 18 AND status IN [ACTIVE, PENDING]",
            "age > 18 AND status IN ('ACTIVE', 'PENDING')",
        );
    }

    #[test]
    fn field_references() {
        let expr = parse("age > 18 AND status IN [ACTIVE]");
        assert_eq!(expr.field_references(), vec!["age", "status"]);
    }

    #[test]
    fn enum_in_lists() {
        let expr = parse("status IN [ACTIVE, PENDING] AND role IN [ADMIN]");
        assert_eq!(
            expr.enum_in_lists(),
            vec![("status", vec!["ACTIVE", "PENDING"]), ("role", vec!["ADMIN"])]
        );
    }

    #[test]
    fn empty_is_error() {
        assert!(parse_bool_expr(&[], Span::new(0, 0)).is_err());
    }

    #[test]
    fn missing_operator_is_error() {
        assert_err_contains("age 18", "Expected comparison operator");
    }

    #[test]
    fn unclosed_in_list_is_error() {
        assert_err_contains("status IN [ACTIVE, PENDING", "Expected ']'");
    }

    #[test]
    fn missing_in_bracket_is_error() {
        assert_err_contains("status IN ACTIVE", "Expected '['");
    }
}