Skip to main content

nautilus_schema/
sql_expr.rs

1//! Lightweight SQL expression parser for `@computed` and future attribute
2//! expressions.
3//!
4//! The parser operates on tokens already produced by the schema [`Lexer`] and
5//! builds a small AST ([`SqlExpr`]) via recursive descent with operator
6//! precedence climbing.  The AST implements [`Display`] so it can be
7//! round-tripped back to SQL text.
8
9use std::fmt;
10
11use crate::error::{Result, SchemaError};
12use crate::span::Span;
13use crate::token::{Token, TokenKind};
14
15// ---------------------------------------------------------------------------
16// AST
17// ---------------------------------------------------------------------------
18
/// A binary operator in a SQL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
    /// `||` (string concatenation)
    Concat,
    /// `<`
    Lt,
    /// `>`
    Gt,
}

impl fmt::Display for BinOp {
    /// Writes the operator's SQL spelling (e.g. `||` for [`BinOp::Concat`]).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            BinOp::Add => "+",
            BinOp::Sub => "-",
            BinOp::Mul => "*",
            BinOp::Div => "/",
            BinOp::Mod => "%",
            BinOp::Concat => "||",
            BinOp::Lt => "<",
            BinOp::Gt => ">",
        })
    }
}

/// A unary (prefix) operator in a SQL expression.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// `-` (negation)
    Neg,
    /// `+` (no-op, explicit positive)
    Pos,
}

impl fmt::Display for UnaryOp {
    /// Writes the operator's SQL spelling (`-` or `+`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            UnaryOp::Neg => "-",
            UnaryOp::Pos => "+",
        })
    }
}

/// A parsed SQL expression node.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlExpr {
    /// Column reference or SQL keyword (e.g. `price`, `COALESCE`).
    Ident(String),
    /// Numeric literal, kept as its source text (e.g. `42`, `3.14`).
    Number(String),
    /// String literal; holds the contents without the surrounding quotes
    /// (e.g. `"hello"` is stored as `hello`).
    StringLit(String),
    /// Boolean literal (`true` / `false`).
    Bool(bool),
    /// Binary operation (e.g. `price * quantity`).
    BinaryOp {
        /// Left-hand side.
        left: Box<SqlExpr>,
        /// Operator.
        op: BinOp,
        /// Right-hand side.
        right: Box<SqlExpr>,
    },
    /// Unary operation (e.g. `-amount`).
    UnaryOp {
        /// Operator.
        op: UnaryOp,
        /// Operand.
        operand: Box<SqlExpr>,
    },
    /// Function call (e.g. `COALESCE(a, b)`).
    FnCall {
        /// Function name.
        name: String,
        /// Argument list.
        args: Vec<SqlExpr>,
    },
    /// Parenthesised sub-expression, preserved so round-tripping keeps the
    /// author's explicit grouping.
    Paren(Box<SqlExpr>),
}

// ---------------------------------------------------------------------------
// Display – reconstruct SQL text
// ---------------------------------------------------------------------------

impl fmt::Display for SqlExpr {
    /// Renders the expression back to SQL text.
    ///
    /// Binary operators are surrounded by single spaces. No parentheses are
    /// added beyond explicit [`SqlExpr::Paren`] nodes, so the rendered grouping
    /// matches the tree only when the tree came from the parser, which records
    /// source parentheses as `Paren` nodes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SqlExpr::Ident(name) => f.write_str(name),
            SqlExpr::Number(n) => f.write_str(n),
            // NOTE(review): contents are written verbatim between double quotes;
            // an embedded `"` is not escaped. Confirm the lexer cannot produce one
            // and that consumers expect double-quoted strings (standard SQL uses
            // single quotes for string literals).
            SqlExpr::StringLit(s) => write!(f, "\"{}\"", s),
            SqlExpr::Bool(b) => write!(f, "{}", b),
            SqlExpr::BinaryOp { left, op, right } => write!(f, "{} {} {}", left, op, right),
            SqlExpr::UnaryOp { op, operand } => {
                let operand = operand.to_string();
                // `--` opens a line comment in SQL. When a negation's operand itself
                // starts with `-` (e.g. a nested negation), emit `- -x` instead of
                // `--x` so the rest of the expression is not commented out.
                if *op == UnaryOp::Neg && operand.starts_with('-') {
                    write!(f, "{} {}", op, operand)
                } else {
                    write!(f, "{}{}", op, operand)
                }
            }
            SqlExpr::FnCall { name, args } => {
                write!(f, "{}(", name)?;
                for (i, arg) in args.iter().enumerate() {
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    write!(f, "{}", arg)?;
                }
                f.write_str(")")
            }
            SqlExpr::Paren(inner) => write!(f, "({})", inner),
        }
    }
}
140
141// ---------------------------------------------------------------------------
142// Parser
143// ---------------------------------------------------------------------------
144
145/// Recursive-descent parser that converts a slice of schema tokens into a
146/// [`SqlExpr`] tree.
147struct SqlExprParser<'a> {
148    tokens: &'a [Token],
149    pos: usize,
150    /// Span used for error reporting when there are no more tokens.
151    fallback_span: Span,
152}
153
154impl<'a> SqlExprParser<'a> {
155    fn new(tokens: &'a [Token], fallback_span: Span) -> Self {
156        Self {
157            tokens,
158            pos: 0,
159            fallback_span,
160        }
161    }
162
163    fn peek(&self) -> Option<&TokenKind> {
164        self.tokens.get(self.pos).map(|t| &t.kind)
165    }
166
167    fn span(&self) -> Span {
168        self.tokens
169            .get(self.pos)
170            .map(|t| t.span)
171            .unwrap_or(self.fallback_span)
172    }
173
174    fn advance(&mut self) -> &Token {
175        let tok = &self.tokens[self.pos];
176        self.pos += 1;
177        tok
178    }
179
180    fn at_end(&self) -> bool {
181        self.pos >= self.tokens.len()
182    }
183
184    /// Operator precedence (higher = tighter binding).
185    fn precedence(op: &BinOp) -> u8 {
186        match op {
187            BinOp::Concat => 1,
188            BinOp::Lt | BinOp::Gt => 2,
189            BinOp::Add | BinOp::Sub => 3,
190            BinOp::Mul | BinOp::Div | BinOp::Mod => 4,
191        }
192    }
193
194    fn token_to_binop(kind: &TokenKind) -> Option<BinOp> {
195        match kind {
196            TokenKind::Plus => Some(BinOp::Add),
197            TokenKind::Minus => Some(BinOp::Sub),
198            TokenKind::Star => Some(BinOp::Mul),
199            TokenKind::Slash => Some(BinOp::Div),
200            TokenKind::Percent => Some(BinOp::Mod),
201            TokenKind::DoublePipe => Some(BinOp::Concat),
202            TokenKind::LAngle => Some(BinOp::Lt),
203            TokenKind::RAngle => Some(BinOp::Gt),
204            _ => None,
205        }
206    }
207
208    fn parse_expr(&mut self) -> Result<SqlExpr> {
209        self.parse_binary(0)
210    }
211
212    fn parse_binary(&mut self, min_prec: u8) -> Result<SqlExpr> {
213        let mut left = self.parse_unary()?;
214
215        while let Some(kind) = self.peek().cloned() {
216            let Some(op) = Self::token_to_binop(&kind) else {
217                break;
218            };
219            let prec = Self::precedence(&op);
220            if prec < min_prec {
221                break;
222            }
223            self.advance();
224            let right = self.parse_binary(prec + 1)?;
225            left = SqlExpr::BinaryOp {
226                left: Box::new(left),
227                op,
228                right: Box::new(right),
229            };
230        }
231
232        Ok(left)
233    }
234
235    fn parse_unary(&mut self) -> Result<SqlExpr> {
236        match self.peek() {
237            Some(TokenKind::Minus) => {
238                self.advance();
239                let operand = self.parse_unary()?;
240                Ok(SqlExpr::UnaryOp {
241                    op: UnaryOp::Neg,
242                    operand: Box::new(operand),
243                })
244            }
245            Some(TokenKind::Plus) => {
246                self.advance();
247                let operand = self.parse_unary()?;
248                Ok(SqlExpr::UnaryOp {
249                    op: UnaryOp::Pos,
250                    operand: Box::new(operand),
251                })
252            }
253            _ => self.parse_primary(),
254        }
255    }
256
257    fn parse_primary(&mut self) -> Result<SqlExpr> {
258        if self.at_end() {
259            return Err(SchemaError::Parse(
260                "Unexpected end of SQL expression".to_string(),
261                self.span(),
262            ));
263        }
264
265        match self.peek().cloned() {
266            Some(TokenKind::Number(n)) => {
267                self.advance();
268                Ok(SqlExpr::Number(n))
269            }
270            Some(TokenKind::String(s)) => {
271                self.advance();
272                Ok(SqlExpr::StringLit(s))
273            }
274            Some(TokenKind::True) => {
275                self.advance();
276                Ok(SqlExpr::Bool(true))
277            }
278            Some(TokenKind::False) => {
279                self.advance();
280                Ok(SqlExpr::Bool(false))
281            }
282            Some(TokenKind::Ident(_)) => self.parse_ident_or_call(),
283            // Keywords used as identifiers inside SQL expressions
284            Some(k) if k.is_keyword() => self.parse_ident_or_call(),
285            Some(TokenKind::LParen) => {
286                self.advance();
287                let inner = self.parse_expr()?;
288                match self.peek() {
289                    Some(TokenKind::RParen) => {
290                        self.advance();
291                        Ok(SqlExpr::Paren(Box::new(inner)))
292                    }
293                    _ => Err(SchemaError::Parse(
294                        "Expected ')' after parenthesised expression".to_string(),
295                        self.span(),
296                    )),
297                }
298            }
299            Some(other) => Err(SchemaError::Parse(
300                format!("Unexpected token '{}' in SQL expression", other),
301                self.span(),
302            )),
303            None => Err(SchemaError::Parse(
304                "Unexpected end of SQL expression".to_string(),
305                self.span(),
306            )),
307        }
308    }
309
310    fn parse_ident_or_call(&mut self) -> Result<SqlExpr> {
311        let tok = self.advance();
312        let name = match &tok.kind {
313            TokenKind::Ident(s) => s.clone(),
314            // Allow schema keywords as SQL identifiers (e.g. `model`, `enum`)
315            other => other.to_string(),
316        };
317
318        if self.peek() == Some(&TokenKind::LParen) {
319            self.advance();
320            let mut args = Vec::new();
321            if self.peek() != Some(&TokenKind::RParen) {
322                args.push(self.parse_expr()?);
323                while self.peek() == Some(&TokenKind::Comma) {
324                    self.advance();
325                    args.push(self.parse_expr()?);
326                }
327            }
328            match self.peek() {
329                Some(TokenKind::RParen) => {
330                    self.advance();
331                    Ok(SqlExpr::FnCall { name, args })
332                }
333                _ => Err(SchemaError::Parse(
334                    format!("Expected ')' after arguments of function '{}'", name),
335                    self.span(),
336                )),
337            }
338        } else {
339            Ok(SqlExpr::Ident(name))
340        }
341    }
342}
343
344// ---------------------------------------------------------------------------
345// Public API
346// ---------------------------------------------------------------------------
347
348/// Parse a slice of schema tokens into a validated [`SqlExpr`] tree.
349///
350/// The token slice should contain **only** the expression tokens (i.e. without
351/// the surrounding `@computed(` ... `, Stored)` scaffolding).
352///
353/// `fallback_span` is used for error reporting when the slice is empty.
354pub fn parse_sql_expr(tokens: &[Token], fallback_span: Span) -> Result<SqlExpr> {
355    if tokens.is_empty() {
356        return Err(SchemaError::Parse(
357            "@computed expression is empty".to_string(),
358            fallback_span,
359        ));
360    }
361
362    let mut parser = SqlExprParser::new(tokens, fallback_span);
363    let expr = parser.parse_expr()?;
364
365    if !parser.at_end() {
366        return Err(SchemaError::Parse(
367            format!(
368                "Unexpected token '{}' after SQL expression",
369                parser.tokens[parser.pos].kind
370            ),
371            parser.span(),
372        ));
373    }
374
375    Ok(expr)
376}
377
378// ---------------------------------------------------------------------------
379// Tests
380// ---------------------------------------------------------------------------
381
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexer::Lexer;

    /// Tokenise a raw string (skipping newlines) for expression parsing.
    fn tokenize(src: &str) -> Vec<Token> {
        let mut lexer = Lexer::new(src);
        let mut tokens = Vec::new();
        loop {
            let tok = lexer.next_token().expect("lex error");
            match tok.kind {
                TokenKind::Eof => break,
                TokenKind::Newline => continue,
                _ => tokens.push(tok),
            }
        }
        tokens
    }

    /// Parses `src`, panicking on any error.
    fn parse(src: &str) -> SqlExpr {
        let tokens = tokenize(src);
        parse_sql_expr(&tokens, Span::new(0, 0)).expect("parse error")
    }

    /// Parses `src`, expecting failure, and returns the rendered error.
    fn parse_err(src: &str) -> String {
        let tokens = tokenize(src);
        match parse_sql_expr(&tokens, Span::new(0, 0)) {
            Err(e) => format!("{}", e),
            Ok(expr) => panic!("Expected error, got: {:?}", expr),
        }
    }

    // ---- valid expressions ----

    #[test]
    fn simple_ident() {
        assert_eq!(parse("price").to_string(), "price");
    }

    #[test]
    fn binary_mul() {
        let expr = parse("price * quantity");
        assert_eq!(expr.to_string(), "price * quantity");
    }

    #[test]
    fn precedence_add_mul() {
        // a + b * c  =>  a + (b * c)
        let expr = parse("a + b * c");
        assert!(matches!(expr, SqlExpr::BinaryOp { op: BinOp::Add, .. }));
    }

    #[test]
    fn subtraction_is_left_associative() {
        // a - b - c  =>  (a - b) - c
        let expr = parse("a - b - c");
        assert_eq!(expr.to_string(), "a - b - c");
        match expr {
            SqlExpr::BinaryOp {
                left,
                op: BinOp::Sub,
                right,
            } => {
                assert!(matches!(*left, SqlExpr::BinaryOp { op: BinOp::Sub, .. }));
                assert_eq!(*right, SqlExpr::Ident("c".to_string()));
            }
            other => panic!("unexpected shape: {:?}", other),
        }
    }

    #[test]
    fn concat_operator() {
        let expr = parse("first_name || \" \" || last_name");
        assert_eq!(expr.to_string(), "first_name || \" \" || last_name");
    }

    #[test]
    fn function_call() {
        let expr = parse("COALESCE(a, b)");
        assert!(matches!(expr, SqlExpr::FnCall { .. }));
        assert_eq!(expr.to_string(), "COALESCE(a, b)");
    }

    #[test]
    fn function_call_without_args() {
        let expr = parse("NOW()");
        assert!(matches!(&expr, SqlExpr::FnCall { args, .. } if args.is_empty()));
        assert_eq!(expr.to_string(), "NOW()");
    }

    #[test]
    fn nested_function() {
        let expr = parse("UPPER(TRIM(name))");
        assert_eq!(expr.to_string(), "UPPER(TRIM(name))");
    }

    #[test]
    fn paren_expr() {
        let expr = parse("(a + b) * c");
        assert_eq!(expr.to_string(), "(a + b) * c");
    }

    #[test]
    fn unary_neg() {
        let expr = parse("-amount");
        assert_eq!(expr.to_string(), "-amount");
    }

    #[test]
    fn unary_pos() {
        let expr = parse("+amount");
        assert!(matches!(expr, SqlExpr::UnaryOp { op: UnaryOp::Pos, .. }));
        assert_eq!(expr.to_string(), "+amount");
    }

    #[test]
    fn number_literal() {
        let expr = parse("score * 10");
        assert_eq!(expr.to_string(), "score * 10");
    }

    #[test]
    fn boolean_literal() {
        let expr = parse("true");
        assert_eq!(expr.to_string(), "true");
    }

    #[test]
    fn complex_expr() {
        let expr = parse("(price * quantity) - COALESCE(discount, 0)");
        assert_eq!(
            expr.to_string(),
            "(price * quantity) - COALESCE(discount, 0)"
        );
    }

    // ---- invalid expressions ----

    #[test]
    fn empty_is_error() {
        let tokens: Vec<Token> = vec![];
        assert!(parse_sql_expr(&tokens, Span::new(0, 0)).is_err());
    }

    #[test]
    fn only_operators_is_error() {
        let err = parse_err("* * *");
        assert!(err.contains("Unexpected token"));
    }

    #[test]
    fn trailing_operator_is_error() {
        let err = parse_err("a +");
        assert!(err.contains("Unexpected end"));
    }

    #[test]
    fn unclosed_paren_is_error() {
        let err = parse_err("(a + b");
        assert!(err.contains("Expected ')'"));
    }

    #[test]
    fn double_operator_is_error() {
        let err = parse_err("a + * b");
        assert!(err.contains("Unexpected token"));
    }

    #[test]
    fn trailing_token_is_error() {
        let err = parse_err("a b");
        assert!(err.contains("after SQL expression"));
    }

    #[test]
    fn trailing_comma_in_call_is_error() {
        let err = parse_err("COALESCE(a,)");
        assert!(err.contains("Unexpected token"));
    }

    #[test]
    fn unclosed_call_is_error() {
        let err = parse_err("COALESCE(a, b");
        assert!(err.contains("Expected ')' after arguments"));
    }
}