Skip to main content

nautilus_schema/
sql_expr.rs

1//! Lightweight SQL expression parser for `@computed` and future attribute
2//! expressions.
3//!
4//! The parser operates on tokens already produced by the schema [`Lexer`] and
5//! builds a small AST ([`SqlExpr`]) via recursive descent with operator
6//! precedence climbing.  The AST implements [`Display`] so it can be
7//! round-tripped back to SQL text.
8
9use std::fmt;
10
11use crate::error::{Result, SchemaError};
12use crate::span::Span;
13use crate::token::{Token, TokenKind};
/// Infix (two-operand) operators that may appear in a SQL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// Addition: `+`
    Add,
    /// Subtraction: `-`
    Sub,
    /// Multiplication: `*`
    Mul,
    /// Division: `/`
    Div,
    /// Modulo: `%`
    Mod,
    /// String concatenation: `||`
    Concat,
    /// Less-than comparison: `<`
    Lt,
    /// Greater-than comparison: `>`
    Gt,
}

impl fmt::Display for BinOp {
    /// Writes the operator's SQL symbol (for example `||` for
    /// [`BinOp::Concat`]) with no surrounding whitespace.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let symbol = match *self {
            Self::Add => "+",
            Self::Sub => "-",
            Self::Mul => "*",
            Self::Div => "/",
            Self::Mod => "%",
            Self::Concat => "||",
            Self::Lt => "<",
            Self::Gt => ">",
        };
        f.write_str(symbol)
    }
}
49
/// Prefix (single-operand) operators that may appear in a SQL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation: `-`
    Neg,
    /// Explicit positive sign: `+` (leaves the operand's value unchanged)
    Pos,
}

impl fmt::Display for UnaryOp {
    /// Writes the operator's sign character (`-` or `+`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sign = match *self {
            Self::Neg => "-",
            Self::Pos => "+",
        };
        f.write_str(sign)
    }
}
67
68/// A parsed SQL expression node.
69#[derive(Debug, Clone, PartialEq)]
70pub enum SqlExpr {
71    /// Column reference or SQL keyword (e.g. `price`, `COALESCE`).
72    Ident(String),
73    /// Numeric literal (e.g. `42`, `3.14`).
74    Number(String),
75    /// String literal (e.g. `"hello"`).
76    StringLit(String),
77    /// Boolean literal (`true` / `false`).
78    Bool(bool),
79    /// Binary operation (e.g. `price * quantity`).
80    BinaryOp {
81        /// Left-hand side.
82        left: Box<SqlExpr>,
83        /// Operator.
84        op: BinOp,
85        /// Right-hand side.
86        right: Box<SqlExpr>,
87    },
88    /// Unary operation (e.g. `-amount`).
89    UnaryOp {
90        /// Operator.
91        op: UnaryOp,
92        /// Operand.
93        operand: Box<SqlExpr>,
94    },
95    /// Function call (e.g. `COALESCE(a, b)`).
96    FnCall {
97        /// Function name.
98        name: String,
99        /// Argument list.
100        args: Vec<SqlExpr>,
101    },
102    /// Parenthesised sub-expression.
103    Paren(Box<SqlExpr>),
104}
105impl fmt::Display for SqlExpr {
106    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107        match self {
108            SqlExpr::Ident(name) => write!(f, "{}", name),
109            SqlExpr::Number(n) => write!(f, "{}", n),
110            SqlExpr::StringLit(s) => write!(f, "\"{}\"", s),
111            SqlExpr::Bool(b) => write!(f, "{}", b),
112            SqlExpr::BinaryOp { left, op, right } => {
113                write!(f, "{} {} {}", left, op, right)
114            }
115            SqlExpr::UnaryOp { op, operand } => write!(f, "{}{}", op, operand),
116            SqlExpr::FnCall { name, args } => {
117                write!(f, "{}(", name)?;
118                for (i, arg) in args.iter().enumerate() {
119                    if i > 0 {
120                        write!(f, ", ")?;
121                    }
122                    write!(f, "{}", arg)?;
123                }
124                write!(f, ")")
125            }
126            SqlExpr::Paren(inner) => write!(f, "({})", inner),
127        }
128    }
129}
130
131impl SqlExpr {
132    /// Render this expression, mapping logical field identifiers to their
133    /// physical database column names using the provided function.
134    pub fn to_sql_mapped<F>(&self, map_field: &F) -> String
135    where
136        F: Fn(&str) -> String,
137    {
138        match self {
139            SqlExpr::Ident(name) => map_field(name),
140            SqlExpr::Number(n) => n.clone(),
141            SqlExpr::StringLit(s) => format!("\"{}\"", s),
142            SqlExpr::Bool(b) => b.to_string(),
143            SqlExpr::BinaryOp { left, op, right } => format!(
144                "{} {} {}",
145                left.to_sql_mapped(map_field),
146                op,
147                right.to_sql_mapped(map_field)
148            ),
149            SqlExpr::UnaryOp { op, operand } => {
150                format!("{}{}", op, operand.to_sql_mapped(map_field))
151            }
152            SqlExpr::FnCall { name, args } => {
153                let args_s: Vec<String> = args.iter().map(|a| a.to_sql_mapped(map_field)).collect();
154                format!("{}({})", name, args_s.join(", "))
155            }
156            SqlExpr::Paren(inner) => format!("({})", inner.to_sql_mapped(map_field)),
157        }
158    }
159}
160/// Recursive-descent parser that converts a slice of schema tokens into a
161/// [`SqlExpr`] tree.
162struct SqlExprParser<'a> {
163    tokens: &'a [Token],
164    pos: usize,
165    /// Span used for error reporting when there are no more tokens.
166    fallback_span: Span,
167}
168
169impl<'a> SqlExprParser<'a> {
170    fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
171        Self {
172            tokens,
173            pos: 0,
174            fallback_span,
175        }
176    }
177
178    fn peek(&self) -> Option<&TokenKind> {
179        self.tokens.get(self.pos).map(|t| &t.kind)
180    }
181
182    fn span(&self) -> Span {
183        self.tokens
184            .get(self.pos)
185            .map(|t| t.span)
186            .unwrap_or(self.fallback_span)
187    }
188
189    fn advance(&mut self) -> &Token {
190        let tok = &self.tokens[self.pos];
191        self.pos += 1;
192        tok
193    }
194
195    fn at_end(&self) -> bool {
196        self.pos >= self.tokens.len()
197    }
198
199    /// Operator precedence (higher = tighter binding).
200    fn precedence(op: &BinOp) -> u8 {
201        match op {
202            BinOp::Concat => 1,
203            BinOp::Lt | BinOp::Gt => 2,
204            BinOp::Add | BinOp::Sub => 3,
205            BinOp::Mul | BinOp::Div | BinOp::Mod => 4,
206        }
207    }
208
209    fn token_to_binop(kind: &TokenKind) -> Option<BinOp> {
210        match kind {
211            TokenKind::Plus => Some(BinOp::Add),
212            TokenKind::Minus => Some(BinOp::Sub),
213            TokenKind::Star => Some(BinOp::Mul),
214            TokenKind::Slash => Some(BinOp::Div),
215            TokenKind::Percent => Some(BinOp::Mod),
216            TokenKind::DoublePipe => Some(BinOp::Concat),
217            TokenKind::LAngle => Some(BinOp::Lt),
218            TokenKind::RAngle => Some(BinOp::Gt),
219            _ => None,
220        }
221    }
222
223    fn parse_expr(&mut self) -> Result<SqlExpr> {
224        self.parse_binary(0)
225    }
226
227    fn parse_binary(&mut self, min_prec: u8) -> Result<SqlExpr> {
228        let mut left = self.parse_unary()?;
229
230        while let Some(kind) = self.peek().cloned() {
231            let Some(op) = Self::token_to_binop(&kind) else {
232                break;
233            };
234            let prec = Self::precedence(&op);
235            if prec < min_prec {
236                break;
237            }
238            self.advance();
239            let right = self.parse_binary(prec + 1)?;
240            left = SqlExpr::BinaryOp {
241                left: Box::new(left),
242                op,
243                right: Box::new(right),
244            };
245        }
246
247        Ok(left)
248    }
249
250    fn parse_unary(&mut self) -> Result<SqlExpr> {
251        match self.peek() {
252            Some(TokenKind::Minus) => {
253                self.advance();
254                let operand = self.parse_unary()?;
255                Ok(SqlExpr::UnaryOp {
256                    op: UnaryOp::Neg,
257                    operand: Box::new(operand),
258                })
259            }
260            Some(TokenKind::Plus) => {
261                self.advance();
262                let operand = self.parse_unary()?;
263                Ok(SqlExpr::UnaryOp {
264                    op: UnaryOp::Pos,
265                    operand: Box::new(operand),
266                })
267            }
268            _ => self.parse_primary(),
269        }
270    }
271
272    fn parse_primary(&mut self) -> Result<SqlExpr> {
273        if self.at_end() {
274            return Err(SchemaError::Parse(
275                "Unexpected end of SQL expression".to_string(),
276                self.span(),
277            ));
278        }
279
280        match self.peek().cloned() {
281            Some(TokenKind::Number(n)) => {
282                self.advance();
283                Ok(SqlExpr::Number(n))
284            }
285            Some(TokenKind::String(s)) => {
286                self.advance();
287                Ok(SqlExpr::StringLit(s))
288            }
289            Some(TokenKind::True) => {
290                self.advance();
291                Ok(SqlExpr::Bool(true))
292            }
293            Some(TokenKind::False) => {
294                self.advance();
295                Ok(SqlExpr::Bool(false))
296            }
297            Some(TokenKind::Ident(_)) => self.parse_ident_or_call(),
298            // Keywords used as identifiers inside SQL expressions
299            Some(k) if k.is_keyword() => self.parse_ident_or_call(),
300            Some(TokenKind::LParen) => {
301                self.advance();
302                let inner = self.parse_expr()?;
303                match self.peek() {
304                    Some(TokenKind::RParen) => {
305                        self.advance();
306                        Ok(SqlExpr::Paren(Box::new(inner)))
307                    }
308                    _ => Err(SchemaError::Parse(
309                        "Expected ')' after parenthesised expression".to_string(),
310                        self.span(),
311                    )),
312                }
313            }
314            Some(other) => Err(SchemaError::Parse(
315                format!("Unexpected token '{}' in SQL expression", other),
316                self.span(),
317            )),
318            None => Err(SchemaError::Parse(
319                "Unexpected end of SQL expression".to_string(),
320                self.span(),
321            )),
322        }
323    }
324
325    fn parse_ident_or_call(&mut self) -> Result<SqlExpr> {
326        let tok = self.advance();
327        let name = match &tok.kind {
328            TokenKind::Ident(s) => s.clone(),
329            // Allow schema keywords as SQL identifiers (e.g. `model`, `enum`)
330            other => other.to_string(),
331        };
332
333        if self.peek() == Some(&TokenKind::LParen) {
334            self.advance();
335            let mut args = Vec::new();
336            if self.peek() != Some(&TokenKind::RParen) {
337                args.push(self.parse_expr()?);
338                while self.peek() == Some(&TokenKind::Comma) {
339                    self.advance();
340                    args.push(self.parse_expr()?);
341                }
342            }
343            match self.peek() {
344                Some(TokenKind::RParen) => {
345                    self.advance();
346                    Ok(SqlExpr::FnCall { name, args })
347                }
348                _ => Err(SchemaError::Parse(
349                    format!("Expected ')' after arguments of function '{}'", name),
350                    self.span(),
351                )),
352            }
353        } else {
354            Ok(SqlExpr::Ident(name))
355        }
356    }
357}
358/// Parse a slice of schema tokens into a validated [`SqlExpr`] tree.
359///
360/// The token slice should contain **only** the expression tokens (i.e. without
361/// the surrounding `@computed(` ... `, Stored)` scaffolding).
362///
363/// `fallback_span` is used for error reporting when the slice is empty.
364pub fn parse_sql_expr(tokens: &[Token], fallback_span: Span) -> Result<SqlExpr> {
365    if tokens.is_empty() {
366        return Err(SchemaError::Parse(
367            "@computed expression is empty".to_string(),
368            fallback_span,
369        ));
370    }
371
372    let mut parser = SqlExprParser::new(tokens, fallback_span);
373    let expr = parser.parse_expr()?;
374
375    if !parser.at_end() {
376        return Err(SchemaError::Parse(
377            format!(
378                "Unexpected token '{}' after SQL expression",
379                parser.tokens[parser.pos].kind
380            ),
381            parser.span(),
382        ));
383    }
384
385    Ok(expr)
386}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Lexes `src` into tokens, dropping newlines and stopping at end of file.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        let mut tokens = Vec::new();
        loop {
            let tok = lexer.next_token().expect("lex error");
            if matches!(tok.kind, TokenKind::Eof) {
                break;
            }
            if matches!(tok.kind, TokenKind::Newline) {
                continue;
            }
            tokens.push(tok);
        }
        tokens
    }

    /// Parses `src`, panicking if it is not a valid expression.
    fn parse(src: &str) -> SqlExpr {
        parse_sql_expr(&tokenize(src), Span::new(0, 0)).expect("parse error")
    }

    /// Parses `src`, which must be rejected, and returns the error message.
    fn parse_err(src: &str) -> String {
        match parse_sql_expr(&tokenize(src), Span::new(0, 0)) {
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
            Err(e) => e.to_string(),
        }
    }

    /// Asserts that parsing `src` and rendering the result reproduces `src`.
    fn assert_round_trip(src: &str) {
        assert_eq!(parse(src).to_string(), src);
    }

    #[test]
    fn simple_ident() {
        assert_round_trip("price");
    }

    #[test]
    fn binary_mul() {
        assert_round_trip("price * quantity");
    }

    #[test]
    fn precedence_add_mul() {
        // `*` binds tighter than `+`, so the root of the tree is the addition.
        let root = parse("a + b * c");
        assert!(matches!(root, SqlExpr::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn concat_operator() {
        assert_round_trip("first_name || \" \" || last_name");
    }

    #[test]
    fn function_call() {
        let call = parse("COALESCE(a, b)");
        assert!(matches!(call, SqlExpr::FnCall { .. }));
        assert_eq!(call.to_string(), "COALESCE(a, b)");
    }

    #[test]
    fn nested_function() {
        assert_round_trip("UPPER(TRIM(name))");
    }

    #[test]
    fn paren_expr() {
        assert_round_trip("(a + b) * c");
    }

    #[test]
    fn unary_neg() {
        assert_round_trip("-amount");
    }

    #[test]
    fn number_literal() {
        assert_round_trip("score * 10");
    }

    #[test]
    fn boolean_literal() {
        assert_round_trip("true");
    }

    #[test]
    fn complex_expr() {
        assert_round_trip("(price * quantity) - COALESCE(discount, 0)");
    }

    #[test]
    fn empty_is_error() {
        let no_tokens: Vec<Token> = Vec::new();
        assert!(parse_sql_expr(&no_tokens, Span::new(0, 0)).is_err());
    }

    #[test]
    fn only_operators_is_error() {
        assert!(parse_err("* * *").contains("Unexpected token"));
    }

    #[test]
    fn trailing_operator_is_error() {
        assert!(parse_err("a +").contains("Unexpected end"));
    }

    #[test]
    fn unclosed_paren_is_error() {
        assert!(parse_err("(a + b").contains("Expected ')'"));
    }

    #[test]
    fn double_operator_is_error() {
        assert!(parse_err("a + * b").contains("Unexpected token"));
    }
}