Skip to main content

nodedb_query/expr_parse/
mod.rs

//! SQL expression text → SqlExpr AST parser.
//!
//! Parses the subset of SQL expressions used in `GENERATED ALWAYS AS (expr)`
//! column definitions. Supports:
//! - Column references: `price`, `tax_rate`
//! - Numeric literals: `42`, `3.14`, `-1`
//! - String literals: `'hello'`, `''escaped''`
//! - Binary operators: `+`, `-`, `*`, `/`, `%`, and `||` (concatenation)
//! - Comparison: `=`, `!=`, `<>`, `<`, `>`, `<=`, `>=`
//! - Logical: `AND`, `OR`, `NOT`
//! - Parenthesized sub-expressions: `(a + b) * c`
//! - Function calls: `ROUND(price * 1.08, 2)`, `CONCAT(a, ' ', b)`
//! - COALESCE: `COALESCE(a, b, '')`
//! - CASE WHEN (searched form only): `CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END`
//! - NULL literal, and the boolean literals `TRUE` / `FALSE`
//!
//! Determinism validation: rejects calls to non-deterministic functions such as
//! `NOW()`, `RANDOM()`, `NEXTVAL()`, `UUID()` (the full list is `NON_DETERMINISTIC`).
18
19mod tokenizer;
20
21use super::expr::{BinaryOp, SqlExpr};
22use nodedb_types::Value;
23use tokenizer::{Token, TokenKind, tokenize};
24
25/// Parse a SQL expression string into an SqlExpr AST.
26///
27/// Returns the parsed expression and a list of column names it references
28/// (the `depends_on` set for generated columns).
29pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), String> {
30    let tokens = tokenize(text)?;
31    let mut pos = 0;
32    let expr = parse_expr(&tokens, &mut pos, &mut 0)?;
33    if pos < tokens.len() {
34        return Err(format!(
35            "unexpected token after expression: '{}'",
36            tokens[pos].text
37        ));
38    }
39
40    // Validate determinism.
41    validate_deterministic(&expr)?;
42
43    // Collect column references.
44    let mut deps = Vec::new();
45    collect_columns(&expr, &mut deps);
46    deps.sort();
47    deps.dedup();
48
49    Ok((expr, deps))
50}
51
52// ── Recursive descent parser ──────────────────────────────────────────
53
/// Upper bound on the nesting depth the parser tracks while descending into
/// sub-expressions; exceeding it yields a parse error rather than letting the
/// recursion continue.
const MAX_EXPR_DEPTH: usize = 128;
56
57fn parse_expr(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
58    parse_or(tokens, pos, depth)
59}
60
61fn parse_or(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
62    let mut left = parse_and(tokens, pos, depth)?;
63    while peek_keyword(tokens, *pos, "OR") {
64        *pos += 1;
65        let right = parse_and(tokens, pos, depth)?;
66        left = SqlExpr::BinaryOp {
67            left: Box::new(left),
68            op: BinaryOp::Or,
69            right: Box::new(right),
70        };
71    }
72    Ok(left)
73}
74
75fn parse_and(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
76    let mut left = parse_comparison(tokens, pos, depth)?;
77    while peek_keyword(tokens, *pos, "AND") {
78        *pos += 1;
79        let right = parse_comparison(tokens, pos, depth)?;
80        left = SqlExpr::BinaryOp {
81            left: Box::new(left),
82            op: BinaryOp::And,
83            right: Box::new(right),
84        };
85    }
86    Ok(left)
87}
88
89fn parse_comparison(
90    tokens: &[Token],
91    pos: &mut usize,
92    depth: &mut usize,
93) -> Result<SqlExpr, String> {
94    let left = parse_additive(tokens, pos, depth)?;
95    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
96        let op = match tokens[*pos].text.as_str() {
97            "=" => BinaryOp::Eq,
98            "!=" | "<>" => BinaryOp::NotEq,
99            "<" => BinaryOp::Lt,
100            "<=" => BinaryOp::LtEq,
101            ">" => BinaryOp::Gt,
102            ">=" => BinaryOp::GtEq,
103            _ => return Ok(left),
104        };
105        *pos += 1;
106        let right = parse_additive(tokens, pos, depth)?;
107        return Ok(SqlExpr::BinaryOp {
108            left: Box::new(left),
109            op,
110            right: Box::new(right),
111        });
112    }
113    Ok(left)
114}
115
116fn parse_additive(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
117    let mut left = parse_multiplicative(tokens, pos, depth)?;
118    while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
119        let op = match tokens[*pos].text.as_str() {
120            "+" => BinaryOp::Add,
121            "-" => BinaryOp::Sub,
122            "||" => BinaryOp::Concat,
123            _ => break,
124        };
125        *pos += 1;
126        let right = parse_multiplicative(tokens, pos, depth)?;
127        left = SqlExpr::BinaryOp {
128            left: Box::new(left),
129            op,
130            right: Box::new(right),
131        };
132    }
133    Ok(left)
134}
135
136fn parse_multiplicative(
137    tokens: &[Token],
138    pos: &mut usize,
139    depth: &mut usize,
140) -> Result<SqlExpr, String> {
141    let mut left = parse_unary(tokens, pos, depth)?;
142    while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
143        let op = match tokens[*pos].text.as_str() {
144            "*" => BinaryOp::Mul,
145            "/" => BinaryOp::Div,
146            "%" => BinaryOp::Mod,
147            _ => break,
148        };
149        *pos += 1;
150        let right = parse_unary(tokens, pos, depth)?;
151        left = SqlExpr::BinaryOp {
152            left: Box::new(left),
153            op,
154            right: Box::new(right),
155        };
156    }
157    Ok(left)
158}
159
160fn parse_unary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
161    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
162        *pos += 1;
163        let expr = parse_primary(tokens, pos, depth)?;
164        return Ok(SqlExpr::Negate(Box::new(expr)));
165    }
166    if peek_keyword(tokens, *pos, "NOT") {
167        *pos += 1;
168        let expr = parse_primary(tokens, pos, depth)?;
169        return Ok(SqlExpr::Negate(Box::new(expr)));
170    }
171    parse_primary(tokens, pos, depth)
172}
173
174fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
175    if *pos >= tokens.len() {
176        return Err("unexpected end of expression".into());
177    }
178
179    let token = &tokens[*pos];
180
181    match token.kind {
182        TokenKind::LParen => {
183            *depth += 1;
184            if *depth > MAX_EXPR_DEPTH {
185                return Err(format!(
186                    "expression nesting depth exceeds maximum of {MAX_EXPR_DEPTH}"
187                ));
188            }
189            *pos += 1;
190            let expr = parse_expr(tokens, pos, depth)?;
191            *depth -= 1;
192            expect_token(tokens, pos, TokenKind::RParen, ")")?;
193            Ok(expr)
194        }
195
196        TokenKind::Number => {
197            *pos += 1;
198            if let Ok(i) = token.text.parse::<i64>() {
199                Ok(SqlExpr::Literal(Value::Integer(i)))
200            } else if let Ok(f) = token.text.parse::<f64>() {
201                Ok(SqlExpr::Literal(Value::Float(f)))
202            } else {
203                Err(format!("invalid number: '{}'", token.text))
204            }
205        }
206
207        TokenKind::StringLit => {
208            *pos += 1;
209            Ok(SqlExpr::Literal(Value::String(token.text.clone())))
210        }
211
212        TokenKind::Ident => {
213            let name = token.text.clone();
214            let upper = name.to_uppercase();
215            *pos += 1;
216
217            match upper.as_str() {
218                "NULL" => Ok(SqlExpr::Literal(Value::Null)),
219                "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))),
220                "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))),
221                "CASE" => parse_case(tokens, pos, depth),
222                "COALESCE" => {
223                    let args = parse_arg_list(tokens, pos, depth)?;
224                    Ok(SqlExpr::Coalesce(args))
225                }
226                _ => {
227                    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
228                        let args = parse_arg_list(tokens, pos, depth)?;
229                        Ok(SqlExpr::Function {
230                            name: name.to_lowercase(),
231                            args,
232                        })
233                    } else {
234                        Ok(SqlExpr::Column(name.to_lowercase()))
235                    }
236                }
237            }
238        }
239
240        _ => Err(format!("unexpected token: '{}'", token.text)),
241    }
242}
243
244fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
245    let mut when_thens = Vec::new();
246    let mut else_expr = None;
247
248    loop {
249        if peek_keyword(tokens, *pos, "WHEN") {
250            *pos += 1;
251            let cond = parse_expr(tokens, pos, depth)?;
252            expect_keyword(tokens, pos, "THEN")?;
253            let then = parse_expr(tokens, pos, depth)?;
254            when_thens.push((cond, then));
255        } else if peek_keyword(tokens, *pos, "ELSE") {
256            *pos += 1;
257            else_expr = Some(Box::new(parse_expr(tokens, pos, depth)?));
258        } else if peek_keyword(tokens, *pos, "END") {
259            *pos += 1;
260            break;
261        } else {
262            return Err("expected WHEN, ELSE, or END in CASE expression".into());
263        }
264    }
265
266    if when_thens.is_empty() {
267        return Err("CASE requires at least one WHEN clause".into());
268    }
269
270    Ok(SqlExpr::Case {
271        operand: None,
272        when_thens,
273        else_expr,
274    })
275}
276
277fn parse_arg_list(
278    tokens: &[Token],
279    pos: &mut usize,
280    depth: &mut usize,
281) -> Result<Vec<SqlExpr>, String> {
282    expect_token(tokens, pos, TokenKind::LParen, "(")?;
283    let mut args = Vec::new();
284    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
285        *pos += 1;
286        return Ok(args);
287    }
288    loop {
289        args.push(parse_expr(tokens, pos, depth)?);
290        if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
291            *pos += 1;
292        } else {
293            break;
294        }
295    }
296    expect_token(tokens, pos, TokenKind::RParen, ")")?;
297    Ok(args)
298}
299
300// ── Helpers ───────────────────────────────────────────────────────────
301
302fn peek_keyword(tokens: &[Token], pos: usize, keyword: &str) -> bool {
303    pos < tokens.len()
304        && tokens[pos].kind == TokenKind::Ident
305        && tokens[pos].text.eq_ignore_ascii_case(keyword)
306}
307
308fn expect_keyword(tokens: &[Token], pos: &mut usize, keyword: &str) -> Result<(), String> {
309    if peek_keyword(tokens, *pos, keyword) {
310        *pos += 1;
311        Ok(())
312    } else {
313        let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
314        Err(format!("expected '{keyword}', got '{got}'"))
315    }
316}
317
318fn expect_token(
319    tokens: &[Token],
320    pos: &mut usize,
321    kind: TokenKind,
322    expected: &str,
323) -> Result<(), String> {
324    if *pos < tokens.len() && tokens[*pos].kind == kind {
325        *pos += 1;
326        Ok(())
327    } else {
328        let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
329        Err(format!("expected '{expected}', got '{got}'"))
330    }
331}
332
333// ── Validation ────────────────────────────────────────────────────────
334
/// Lower-cased names of functions that are rejected inside a generated-column
/// expression. Compared against the lower-cased function name stored in the
/// parsed tree.
const NON_DETERMINISTIC: &[&str] = &[
    // Clock
    "now",
    "current_timestamp",
    // Randomness
    "random",
    // Sequences
    "nextval",
    // Identifier generators
    "uuid",
    "uuid_v4",
    "uuid_v7",
    "gen_random_uuid",
    "ulid",
    "cuid2",
    "nanoid",
];
348
349fn validate_deterministic(expr: &SqlExpr) -> Result<(), String> {
350    match expr {
351        SqlExpr::Function { name, args } => {
352            if NON_DETERMINISTIC.contains(&name.as_str()) {
353                return Err(format!(
354                    "non-deterministic function '{name}()' not allowed in GENERATED ALWAYS AS"
355                ));
356            }
357            for arg in args {
358                validate_deterministic(arg)?;
359            }
360            Ok(())
361        }
362        SqlExpr::BinaryOp { left, right, .. } => {
363            validate_deterministic(left)?;
364            validate_deterministic(right)
365        }
366        SqlExpr::Negate(inner) => validate_deterministic(inner),
367        SqlExpr::Coalesce(args) => {
368            for arg in args {
369                validate_deterministic(arg)?;
370            }
371            Ok(())
372        }
373        SqlExpr::Case {
374            operand,
375            when_thens,
376            else_expr,
377        } => {
378            if let Some(op) = operand {
379                validate_deterministic(op)?;
380            }
381            for (cond, then) in when_thens {
382                validate_deterministic(cond)?;
383                validate_deterministic(then)?;
384            }
385            if let Some(e) = else_expr {
386                validate_deterministic(e)?;
387            }
388            Ok(())
389        }
390        SqlExpr::Cast { expr, .. } => validate_deterministic(expr),
391        SqlExpr::NullIf(a, b) => {
392            validate_deterministic(a)?;
393            validate_deterministic(b)
394        }
395        SqlExpr::IsNull { expr, .. } => validate_deterministic(expr),
396        SqlExpr::Column(_) | SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => Ok(()),
397    }
398}
399
400fn collect_columns(expr: &SqlExpr, deps: &mut Vec<String>) {
401    match expr {
402        SqlExpr::Column(name) => deps.push(name.clone()),
403        SqlExpr::BinaryOp { left, right, .. } => {
404            collect_columns(left, deps);
405            collect_columns(right, deps);
406        }
407        SqlExpr::Negate(inner) => collect_columns(inner, deps),
408        SqlExpr::Function { args, .. } => {
409            for arg in args {
410                collect_columns(arg, deps);
411            }
412        }
413        SqlExpr::Coalesce(args) => {
414            for arg in args {
415                collect_columns(arg, deps);
416            }
417        }
418        SqlExpr::Case {
419            operand,
420            when_thens,
421            else_expr,
422        } => {
423            if let Some(op) = operand {
424                collect_columns(op, deps);
425            }
426            for (cond, then) in when_thens {
427                collect_columns(cond, deps);
428                collect_columns(then, deps);
429            }
430            if let Some(e) = else_expr {
431                collect_columns(e, deps);
432            }
433        }
434        SqlExpr::Cast { expr, .. } => collect_columns(expr, deps),
435        SqlExpr::NullIf(a, b) => {
436            collect_columns(a, deps);
437            collect_columns(b, deps);
438        }
439        SqlExpr::IsNull { expr, .. } => collect_columns(expr, deps),
440        SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => {}
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    /// Parse `text`, panicking if the parser rejects it.
    fn parse_ok(text: &str) -> (SqlExpr, Vec<String>) {
        parse_generated_expr(text).unwrap()
    }

    /// Build a document value from a JSON literal.
    fn row(json: serde_json::Value) -> Value {
        Value::from(json)
    }

    #[test]
    fn simple_arithmetic() {
        let (expr, deps) = parse_ok("price * (1 + tax_rate)");
        assert_eq!(deps, vec!["price", "tax_rate"]);

        let input = row(serde_json::json!({"price": 100.0, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&input).as_f64(), Some(108.0));
    }

    #[test]
    fn round_function() {
        let (expr, deps) = parse_ok("ROUND(price * (1 + tax_rate), 2)");
        assert_eq!(deps, vec!["price", "tax_rate"]);

        let input = row(serde_json::json!({"price": 99.99, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&input), Value::Float(107.99));
    }

    #[test]
    fn concat_function() {
        let (expr, deps) = parse_ok("CONCAT(name, ' ', brand)");
        // Dependencies come back sorted, not in source order.
        assert_eq!(deps, vec!["brand", "name"]);

        let input = row(serde_json::json!({"name": "Shoe", "brand": "Nike"}));
        assert_eq!(expr.eval(&input), Value::String("Shoe Nike".into()));
    }

    #[test]
    fn coalesce() {
        let (expr, _) = parse_ok("COALESCE(description, '')");
        let input = row(serde_json::json!({"description": null}));
        assert_eq!(expr.eval(&input), Value::String("".into()));
    }

    #[test]
    fn case_when() {
        let (expr, deps) =
            parse_ok("CASE WHEN discount > 0 THEN price * (1 - discount) ELSE price END");
        assert!(deps.iter().any(|dep| dep == "discount"));
        assert!(deps.iter().any(|dep| dep == "price"));

        let discounted = row(serde_json::json!({"price": 100.0, "discount": 0.2}));
        assert_eq!(expr.eval(&discounted).as_f64(), Some(80.0));

        let full_price = row(serde_json::json!({"price": 100.0, "discount": 0}));
        assert_eq!(expr.eval(&full_price).as_f64(), Some(100.0));
    }

    #[test]
    fn rejects_now() {
        assert!(parse_generated_expr("NOW()").is_err());
    }

    #[test]
    fn rejects_random() {
        assert!(parse_generated_expr("RANDOM()").is_err());
    }

    #[test]
    fn rejects_uuid() {
        assert!(parse_generated_expr("UUID()").is_err());
    }

    #[test]
    fn string_literal() {
        let (expr, _) = parse_ok("CONCAT(name, ' - ', 'default')");
        let input = row(serde_json::json!({"name": "Product"}));
        assert_eq!(expr.eval(&input), Value::String("Product - default".into()));
    }

    #[test]
    fn null_literal() {
        let (expr, _) = parse_ok("COALESCE(x, NULL, 0)");
        let input = row(serde_json::json!({"x": null}));
        assert_eq!(expr.eval(&input), Value::Integer(0));
    }

    #[test]
    fn nested_functions() {
        let (expr, _) = parse_ok("ROUND(price * (1 - COALESCE(discount, 0)), 2)");
        // `discount` is absent from the document, so COALESCE supplies 0.
        let input = row(serde_json::json!({"price": 49.99}));
        assert_eq!(expr.eval(&input), Value::Float(49.99));
    }

    #[test]
    fn deeply_nested_parentheses_return_error_not_stack_overflow() {
        let depth = 10_000;
        let opening = "(".repeat(depth);
        let closing = ")".repeat(depth);
        let input = format!("{opening}x{closing}");

        let result = parse_generated_expr(&input);
        assert!(
            result.is_err(),
            "parse_generated_expr must return Err for {depth}-deep nesting"
        );
    }

    #[test]
    fn cjk_string_in_concat() {
        let (expr, _) = parse_ok("CONCAT('你好', name)");
        let input = row(serde_json::json!({"name": "world"}));
        assert_eq!(expr.eval(&input), Value::String("你好world".into()));
    }

    #[test]
    fn comparison_with_utf8_literal() {
        let (expr, deps) = parse_ok("name != '禁止'");
        assert_eq!(deps, vec!["name"]);

        let input = row(serde_json::json!({"name": "allowed"}));
        assert_eq!(expr.eval(&input), Value::Bool(true));
    }
}
561}