Skip to main content

nautilus_schema/
bool_expr.rs

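//! Boolean expression parser for `@check` and `@@check` constraint attributes.
//!
//! The parser operates on tokens already produced by the schema
//! [`Lexer`](crate::lexer::Lexer) and builds a small AST ([`BoolExpr`]) by
//! recursive descent, with one parsing function per precedence level
//! (`OR` < `AND` < `NOT` < comparisons and `IN` tests).  [`BoolExpr`]
//! implements [`Display`](std::fmt::Display), which renders text close to the
//! schema syntax (e.g. `IN [...]` lists with bare enum variants), while
//! [`BoolExpr::to_sql`] renders SQL suitable for DDL `CHECK (...)` clauses.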
7
8use std::fmt;
9
10use crate::error::{Result, SchemaError};
11use crate::span::Span;
12use crate::token::{Token, TokenKind};
13
/// Relational operator that can sit between the two operands of a
/// `@check` comparison.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// Equality, written `=`.
    Eq,
    /// Inequality, written `!=` or `<>` in the schema.
    Ne,
    /// Strictly less than, written `<`.
    Lt,
    /// Strictly greater than, written `>`.
    Gt,
    /// Less than or equal, written `<=`.
    Le,
    /// Greater than or equal, written `>=`.
    Ge,
}

impl fmt::Display for CmpOp {
    /// Writes the SQL spelling of the operator. Inequality is always emitted
    /// as `<>`, whichever form the schema used.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            Self::Eq => "=",
            Self::Ne => "<>",
            Self::Lt => "<",
            Self::Gt => ">",
            Self::Le => "<=",
            Self::Ge => ">=",
        };
        f.write_str(symbol)
    }
}
43
/// An operand (leaf value) in a boolean expression.
#[derive(Debug, Clone, PartialEq)]
pub enum Operand {
    /// A field reference (e.g. `age`, `status`).
    Field(String),
    /// A numeric literal, stored as text (e.g. `18`, `3.14`).
    Number(String),
    /// A string literal (e.g. `"hello"`).
    StringLit(String),
    /// A boolean literal (`true` / `false`).
    Bool(bool),
    /// A bare identifier inside an `IN [...]` list — treated as an enum variant.
    EnumVariant(String),
}

impl fmt::Display for Operand {
    /// Renders the operand as SQL text.
    ///
    /// String literals and enum variants become single-quoted SQL string
    /// literals with every embedded `'` doubled, so a value such as
    /// `O'Brien` cannot terminate the literal early and produce malformed
    /// (or injected) DDL. Booleans render as `TRUE` / `FALSE`; fields and
    /// numbers are written verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Operand::Field(name) => f.write_str(name),
            Operand::Number(n) => f.write_str(n),
            Operand::StringLit(s) | Operand::EnumVariant(s) => write_sql_string_literal(f, s),
            Operand::Bool(b) => f.write_str(if *b { "TRUE" } else { "FALSE" }),
        }
    }
}

/// Writes `value` as a single-quoted SQL string literal, doubling each
/// embedded single quote (the standard SQL escape for `'`).
fn write_sql_string_literal(f: &mut fmt::Formatter<'_>, value: &str) -> fmt::Result {
    f.write_str("'")?;
    // Every chunk except possibly the last ends with a quote; emit one extra
    // quote after it to escape it.
    for chunk in value.split_inclusive('\'') {
        f.write_str(chunk)?;
        if chunk.ends_with('\'') {
            f.write_str("'")?;
        }
    }
    f.write_str("'")
}
70
71/// A parsed boolean expression node.
72#[derive(Debug, Clone, PartialEq)]
73pub enum BoolExpr {
74    /// A comparison (e.g. `age > 18`).
75    Comparison {
76        /// Left-hand operand.
77        left: Operand,
78        /// Comparison operator.
79        op: CmpOp,
80        /// Right-hand operand.
81        right: Operand,
82    },
83    /// Logical AND (e.g. `a AND b`).
84    And(Box<BoolExpr>, Box<BoolExpr>),
85    /// Logical OR (e.g. `a OR b`).
86    Or(Box<BoolExpr>, Box<BoolExpr>),
87    /// Logical NOT (e.g. `NOT a`).
88    Not(Box<BoolExpr>),
89    /// IN list (e.g. `status IN [ACTIVE, PENDING]`).
90    In {
91        /// Field being tested.
92        field: String,
93        /// List of values to test against.
94        values: Vec<Operand>,
95    },
96    /// Parenthesised sub-expression.
97    Paren(Box<BoolExpr>),
98}
99
100impl fmt::Display for BoolExpr {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        match self {
103            BoolExpr::Comparison { left, op, right } => write!(f, "{} {} {}", left, op, right),
104            BoolExpr::And(left, right) => write!(f, "{} AND {}", left, right),
105            BoolExpr::Or(left, right) => write!(f, "{} OR {}", left, right),
106            BoolExpr::Not(inner) => write!(f, "NOT {}", inner),
107            BoolExpr::In { field, values } => {
108                write!(f, "{} IN [", field)?;
109                for (i, val) in values.iter().enumerate() {
110                    if i > 0 {
111                        write!(f, ", ")?;
112                    }
113                    // In schema text, enum variants are bare identifiers; to_sql() quotes them.
114                    match val {
115                        Operand::EnumVariant(v) => write!(f, "{}", v)?,
116                        other => write!(f, "{}", other)?,
117                    }
118                }
119                write!(f, "]")
120            }
121            BoolExpr::Paren(inner) => write!(f, "({})", inner),
122        }
123    }
124}
125
126impl BoolExpr {
127    /// Collect all field references in this expression.
128    pub fn field_references(&self) -> Vec<&str> {
129        let mut refs = Vec::new();
130        self.collect_field_refs(&mut refs);
131        refs
132    }
133
134    fn collect_field_refs<'a>(&'a self, refs: &mut Vec<&'a str>) {
135        match self {
136            BoolExpr::Comparison { left, right, .. } => {
137                if let Operand::Field(name) = left {
138                    refs.push(name);
139                }
140                if let Operand::Field(name) = right {
141                    refs.push(name);
142                }
143            }
144            BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
145                l.collect_field_refs(refs);
146                r.collect_field_refs(refs);
147            }
148            BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
149                inner.collect_field_refs(refs);
150            }
151            BoolExpr::In { field, .. } => {
152                refs.push(field);
153            }
154        }
155    }
156
157    /// Collect `(field_name, [variant_names])` pairs from `IN` nodes where values
158    /// are enum variants. Used by the validator for enum checking.
159    pub fn enum_in_lists(&self) -> Vec<(&str, Vec<&str>)> {
160        let mut result = Vec::new();
161        self.collect_enum_in_lists(&mut result);
162        result
163    }
164
165    fn collect_enum_in_lists<'a>(&'a self, result: &mut Vec<(&'a str, Vec<&'a str>)>) {
166        match self {
167            BoolExpr::In { field, values } => {
168                let variants: Vec<&str> = values
169                    .iter()
170                    .filter_map(|v| match v {
171                        Operand::EnumVariant(name) => Some(name.as_str()),
172                        _ => None,
173                    })
174                    .collect();
175                if !variants.is_empty() {
176                    result.push((field.as_str(), variants));
177                }
178            }
179            BoolExpr::And(l, r) | BoolExpr::Or(l, r) => {
180                l.collect_enum_in_lists(result);
181                r.collect_enum_in_lists(result);
182            }
183            BoolExpr::Not(inner) | BoolExpr::Paren(inner) => {
184                inner.collect_enum_in_lists(result);
185            }
186            BoolExpr::Comparison { .. } => {}
187        }
188    }
189
190    /// Render this expression as valid SQL (for DDL `CHECK (...)` clauses).
191    ///
192    /// This differs from `Display` in that `IN` lists use SQL `IN (...)` syntax
193    /// and enum variants are rendered as quoted string literals.
194    pub fn to_sql(&self) -> String {
195        self.to_sql_mapped(&|name| name.to_string())
196    }
197
198    /// Render this expression as valid SQL, mapping logical field names to their
199    /// physical database column names using the provided function.
200    pub fn to_sql_mapped<F>(&self, map_field: &F) -> String
201    where
202        F: Fn(&str) -> String,
203    {
204        match self {
205            BoolExpr::Comparison { left, op, right } => {
206                let left_s = match left {
207                    Operand::Field(name) => map_field(name),
208                    other => other.to_string(),
209                };
210                let right_s = match right {
211                    Operand::Field(name) => map_field(name),
212                    other => other.to_string(),
213                };
214                format!("{} {} {}", left_s, op, right_s)
215            }
216            BoolExpr::And(left, right) => format!(
217                "{} AND {}",
218                left.to_sql_mapped(map_field),
219                right.to_sql_mapped(map_field)
220            ),
221            BoolExpr::Or(left, right) => format!(
222                "{} OR {}",
223                left.to_sql_mapped(map_field),
224                right.to_sql_mapped(map_field)
225            ),
226            BoolExpr::Not(inner) => format!("NOT {}", inner.to_sql_mapped(map_field)),
227            BoolExpr::In { field, values } => {
228                let vals: Vec<String> = values.iter().map(|v| v.to_string()).collect();
229                format!("{} IN ({})", map_field(field), vals.join(", "))
230            }
231            BoolExpr::Paren(inner) => format!("({})", inner.to_sql_mapped(map_field)),
232        }
233    }
234}
235
236/// Recursive-descent parser that converts a slice of schema tokens into a
237/// [`BoolExpr`] tree.
238struct BoolExprParser<'a> {
239    tokens: &'a [Token],
240    pos: usize,
241    /// Span used for error reporting when there are no more tokens.
242    fallback_span: Span,
243}
244
245impl<'a> BoolExprParser<'a> {
246    fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
247        Self {
248            tokens,
249            pos: 0,
250            fallback_span,
251        }
252    }
253
254    fn peek(&self) -> Option<&TokenKind> {
255        self.tokens.get(self.pos).map(|t| &t.kind)
256    }
257
258    fn span(&self) -> Span {
259        self.tokens
260            .get(self.pos)
261            .map(|t| t.span)
262            .unwrap_or(self.fallback_span)
263    }
264
265    fn advance(&mut self) -> &Token {
266        let tok = &self.tokens[self.pos];
267        self.pos += 1;
268        tok
269    }
270
271    fn at_end(&self) -> bool {
272        self.pos >= self.tokens.len()
273    }
274
275    fn is_keyword(&self, kw: &str) -> bool {
276        matches!(self.peek(), Some(TokenKind::Ident(s)) if s.eq_ignore_ascii_case(kw))
277    }
278
279    fn parse_expr(&mut self) -> Result<BoolExpr> {
280        self.parse_or()
281    }
282
283    /// OR has the lowest precedence.
284    fn parse_or(&mut self) -> Result<BoolExpr> {
285        let mut left = self.parse_and()?;
286        while self.is_keyword("OR") {
287            self.advance();
288            let right = self.parse_and()?;
289            left = BoolExpr::Or(Box::new(left), Box::new(right));
290        }
291        Ok(left)
292    }
293
294    fn parse_and(&mut self) -> Result<BoolExpr> {
295        let mut left = self.parse_not()?;
296        while self.is_keyword("AND") {
297            self.advance();
298            let right = self.parse_not()?;
299            left = BoolExpr::And(Box::new(left), Box::new(right));
300        }
301        Ok(left)
302    }
303
304    fn parse_not(&mut self) -> Result<BoolExpr> {
305        if self.is_keyword("NOT") {
306            self.advance();
307            let inner = self.parse_not()?;
308            return Ok(BoolExpr::Not(Box::new(inner)));
309        }
310        self.parse_primary()
311    }
312
313    fn parse_primary(&mut self) -> Result<BoolExpr> {
314        if self.at_end() {
315            return Err(SchemaError::Parse(
316                "Unexpected end of check expression".to_string(),
317                self.span(),
318            ));
319        }
320
321        if matches!(self.peek(), Some(TokenKind::LParen)) {
322            self.advance();
323            let inner = self.parse_expr()?;
324            match self.peek() {
325                Some(TokenKind::RParen) => {
326                    self.advance();
327                    return Ok(BoolExpr::Paren(Box::new(inner)));
328                }
329                _ => {
330                    return Err(SchemaError::Parse(
331                        "Expected ')' after parenthesised expression".to_string(),
332                        self.span(),
333                    ));
334                }
335            }
336        }
337
338        if matches!(self.peek(), Some(TokenKind::True)) {
339            self.advance();
340            return Ok(BoolExpr::Comparison {
341                left: Operand::Bool(true),
342                op: CmpOp::Eq,
343                right: Operand::Bool(true),
344            });
345        }
346        if matches!(self.peek(), Some(TokenKind::False)) {
347            self.advance();
348            return Ok(BoolExpr::Comparison {
349                left: Operand::Bool(false),
350                op: CmpOp::Eq,
351                right: Operand::Bool(true),
352            });
353        }
354
355        let left = self.parse_operand(false)?;
356
357        if self.is_keyword("IN") {
358            let field_name = match &left {
359                Operand::Field(name) => name.clone(),
360                _ => {
361                    return Err(SchemaError::Parse(
362                        "Left side of IN must be a field reference".to_string(),
363                        self.span(),
364                    ));
365                }
366            };
367            self.advance();
368            let values = self.parse_in_list()?;
369            return Ok(BoolExpr::In {
370                field: field_name,
371                values,
372            });
373        }
374
375        let op = self.parse_cmp_op()?;
376        let right = self.parse_operand(false)?;
377
378        Ok(BoolExpr::Comparison { left, op, right })
379    }
380
381    /// Parse a single operand (field reference, literal, etc.).
382    /// When `in_list` is true, bare identifiers are treated as enum variants.
383    fn parse_operand(&mut self, in_list: bool) -> Result<Operand> {
384        if self.at_end() {
385            return Err(SchemaError::Parse(
386                "Expected operand in check expression".to_string(),
387                self.span(),
388            ));
389        }
390
391        match self.peek().cloned() {
392            Some(TokenKind::Number(n)) => {
393                self.advance();
394                Ok(Operand::Number(n))
395            }
396            Some(TokenKind::String(s)) => {
397                self.advance();
398                Ok(Operand::StringLit(s))
399            }
400            Some(TokenKind::True) => {
401                self.advance();
402                Ok(Operand::Bool(true))
403            }
404            Some(TokenKind::False) => {
405                self.advance();
406                Ok(Operand::Bool(false))
407            }
408            Some(TokenKind::Ident(name)) => {
409                self.advance();
410                if in_list {
411                    Ok(Operand::EnumVariant(name))
412                } else {
413                    Ok(Operand::Field(name))
414                }
415            }
416            // Schema keywords (e.g. `String`, `Int`) are valid field/enum names inside expressions.
417            Some(k) if k.is_keyword() => {
418                let tok = self.advance();
419                let name = tok.kind.to_string();
420                if in_list {
421                    Ok(Operand::EnumVariant(name))
422                } else {
423                    Ok(Operand::Field(name))
424                }
425            }
426            Some(other) => Err(SchemaError::Parse(
427                format!("Unexpected token '{}' in check expression", other),
428                self.span(),
429            )),
430            None => Err(SchemaError::Parse(
431                "Unexpected end of check expression".to_string(),
432                self.span(),
433            )),
434        }
435    }
436
437    fn parse_cmp_op(&mut self) -> Result<CmpOp> {
438        if self.at_end() {
439            return Err(SchemaError::Parse(
440                "Expected comparison operator".to_string(),
441                self.span(),
442            ));
443        }
444
445        match self.peek() {
446            Some(TokenKind::Equal) => {
447                self.advance();
448                Ok(CmpOp::Eq)
449            }
450            Some(TokenKind::BangEqual) => {
451                self.advance();
452                Ok(CmpOp::Ne)
453            }
454            Some(TokenKind::LAngle) => {
455                self.advance();
456                // `<>` is accepted as an alternative to `!=`.
457                if matches!(self.peek(), Some(TokenKind::RAngle)) {
458                    self.advance();
459                    Ok(CmpOp::Ne)
460                } else {
461                    Ok(CmpOp::Lt)
462                }
463            }
464            Some(TokenKind::RAngle) => {
465                self.advance();
466                Ok(CmpOp::Gt)
467            }
468            Some(TokenKind::LessEqual) => {
469                self.advance();
470                Ok(CmpOp::Le)
471            }
472            Some(TokenKind::GreaterEqual) => {
473                self.advance();
474                Ok(CmpOp::Ge)
475            }
476            Some(other) => Err(SchemaError::Parse(
477                format!(
478                    "Expected comparison operator (=, !=, <, >, <=, >=), got '{}'",
479                    other
480                ),
481                self.span(),
482            )),
483            None => Err(SchemaError::Parse(
484                "Expected comparison operator".to_string(),
485                self.span(),
486            )),
487        }
488    }
489
490    fn parse_in_list(&mut self) -> Result<Vec<Operand>> {
491        match self.peek() {
492            Some(TokenKind::LBracket) => {
493                self.advance();
494            }
495            _ => {
496                return Err(SchemaError::Parse(
497                    "Expected '[' after IN".to_string(),
498                    self.span(),
499                ));
500            }
501        }
502
503        let mut values = Vec::new();
504
505        if !matches!(self.peek(), Some(TokenKind::RBracket)) {
506            values.push(self.parse_operand(true)?);
507            while matches!(self.peek(), Some(TokenKind::Comma)) {
508                self.advance();
509                values.push(self.parse_operand(true)?);
510            }
511        }
512
513        match self.peek() {
514            Some(TokenKind::RBracket) => {
515                self.advance();
516                Ok(values)
517            }
518            _ => Err(SchemaError::Parse(
519                "Expected ']' to close IN list".to_string(),
520                self.span(),
521            )),
522        }
523    }
524}
525
526/// Parse a slice of schema tokens into a validated [`BoolExpr`] tree.
527///
528/// The token slice should contain **only** the expression tokens (i.e. without
529/// the surrounding `@check(` ... `)` scaffolding).
530///
531/// `fallback_span` is used for error reporting when the slice is empty.
532pub fn parse_bool_expr(tokens: &[Token], fallback_span: Span) -> Result<BoolExpr> {
533    if tokens.is_empty() {
534        return Err(SchemaError::Parse(
535            "@check expression is empty".to_string(),
536            fallback_span,
537        ));
538    }
539
540    let mut parser = BoolExprParser::new(tokens, fallback_span);
541    let expr = parser.parse_expr()?;
542
543    if !parser.at_end() {
544        return Err(SchemaError::Parse(
545            format!(
546                "Unexpected token '{}' after check expression",
547                parser.tokens[parser.pos].kind
548            ),
549            parser.span(),
550        ));
551    }
552
553    Ok(expr)
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src` into the token slice `parse_bool_expr` expects: newline
    /// tokens are dropped and lexing stops at (and excludes) `Eof`.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        std::iter::from_fn(|| {
            let tok = lexer.next_token().expect("lex error");
            if matches!(tok.kind, TokenKind::Eof) {
                None
            } else {
                Some(tok)
            }
        })
        .filter(|tok| !matches!(tok.kind, TokenKind::Newline))
        .collect()
    }

    /// Parses `src`, panicking if it is rejected.
    fn parse(src: &str) -> BoolExpr {
        let tokens = tokenize(src);
        parse_bool_expr(&tokens, Span::new(0, 0)).expect("parse error")
    }

    /// Parses `src` and returns the error message, panicking if parsing
    /// unexpectedly succeeds.
    fn parse_err(src: &str) -> String {
        let tokens = tokenize(src);
        match parse_bool_expr(&tokens, Span::new(0, 0)) {
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
            Err(e) => e.to_string(),
        }
    }

    /// Asserts that `src` parses and renders back via `Display` as `expected`.
    fn assert_renders(src: &str, expected: &str) {
        assert_eq!(parse(src).to_string(), expected);
    }

    #[test]
    fn simple_comparison() {
        assert_renders("age > 18", "age > 18");
    }

    #[test]
    fn less_equal() {
        assert_renders("age <= 150", "age <= 150");
    }

    #[test]
    fn greater_equal() {
        assert_renders("score >= 0", "score >= 0");
    }

    #[test]
    fn not_equal() {
        assert_renders("status != 0", "status <> 0");
    }

    #[test]
    fn equality() {
        assert_renders("active = true", "active = TRUE");
    }

    #[test]
    fn and_expression() {
        assert_renders("age > 18 AND age <= 150", "age > 18 AND age <= 150");
    }

    #[test]
    fn or_expression() {
        assert_renders("age < 18 OR age > 65", "age < 18 OR age > 65");
    }

    #[test]
    fn not_expression() {
        assert_renders("NOT age < 0", "NOT age < 0");
    }

    #[test]
    fn in_with_enum_variants() {
        assert_renders("status IN [ACTIVE, PENDING]", "status IN [ACTIVE, PENDING]");
    }

    #[test]
    fn in_with_numbers() {
        assert_renders("priority IN [1, 2, 3]", "priority IN [1, 2, 3]");
    }

    #[test]
    fn in_with_strings() {
        assert_renders(
            "role IN [\"admin\", \"moderator\"]",
            "role IN ['admin', 'moderator']",
        );
    }

    #[test]
    fn complex_and_or() {
        assert_renders(
            "age > 18 AND status IN [ACTIVE, PENDING]",
            "age > 18 AND status IN [ACTIVE, PENDING]",
        );
    }

    #[test]
    fn parenthesised() {
        assert_renders(
            "(age > 18 OR admin = true) AND active = true",
            "(age > 18 OR admin = TRUE) AND active = TRUE",
        );
    }

    #[test]
    fn sql_output() {
        let expr = parse("status IN [ACTIVE, PENDING]");
        assert_eq!(expr.to_sql(), "status IN ('ACTIVE', 'PENDING')");
    }

    #[test]
    fn sql_output_complex() {
        let expr = parse("age > 18 AND status IN [ACTIVE, PENDING]");
        assert_eq!(
            expr.to_sql(),
            "age > 18 AND status IN ('ACTIVE', 'PENDING')"
        );
    }

    #[test]
    fn field_references() {
        let expr = parse("age > 18 AND status IN [ACTIVE]");
        assert_eq!(expr.field_references(), vec!["age", "status"]);
    }

    #[test]
    fn enum_in_lists() {
        let expr = parse("status IN [ACTIVE, PENDING] AND role IN [ADMIN]");
        let lists = expr.enum_in_lists();
        assert_eq!(lists.len(), 2);
        assert_eq!(lists[0], ("status", vec!["ACTIVE", "PENDING"]));
        assert_eq!(lists[1], ("role", vec!["ADMIN"]));
    }

    #[test]
    fn empty_is_error() {
        let tokens: Vec<Token> = Vec::new();
        assert!(parse_bool_expr(&tokens, Span::new(0, 0)).is_err());
    }

    #[test]
    fn missing_operator_is_error() {
        assert!(parse_err("age 18").contains("Expected comparison operator"));
    }

    #[test]
    fn unclosed_in_list_is_error() {
        assert!(parse_err("status IN [ACTIVE, PENDING").contains("Expected ']'"));
    }

    #[test]
    fn missing_in_bracket_is_error() {
        assert!(parse_err("status IN ACTIVE").contains("Expected '['"));
    }
}