Skip to main content

nodedb_query/expr_parse/
parser.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! SQL expression text → SqlExpr AST parser.
4//!
5//! Parses the subset of SQL expressions used in `GENERATED ALWAYS AS (expr)`
6//! column definitions. Supports:
7//! - Column references: `price`, `tax_rate`
8//! - Numeric literals: `42`, `3.14`, `-1`
//! - String literals: `'hello'`, `'it''s'` (a doubled `''` escapes a quote)
//! - Binary operators: `+`, `-`, `*`, `/`, `%`, `||` (concatenation)
11//! - Comparison: `=`, `!=`, `<>`, `<`, `>`, `<=`, `>=`
12//! - Logical: `AND`, `OR`, `NOT`
13//! - Parenthesized sub-expressions: `(a + b) * c`
14//! - Function calls: `ROUND(price * 1.08, 2)`, `CONCAT(a, ' ', b)`
15//! - COALESCE: `COALESCE(a, b, '')`
16//! - CASE WHEN: `CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END`
17//! - NULL literal
18//!
//! Determinism validation: rejects calls to non-deterministic functions such as
//! `NOW()`, `RANDOM()`, `NEXTVAL()` and `UUID()` (full list in `NON_DETERMINISTIC`).
20
21use super::error::ExprParseError;
22use super::tokenizer::{Token, TokenKind, tokenize};
23use crate::expr::{BinaryOp, SqlExpr};
24use nodedb_types::Value;
25
26/// Parse a SQL expression string into an SqlExpr AST.
27///
28/// Returns the parsed expression and a list of column names it references
29/// (the `depends_on` set for generated columns).
30pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), ExprParseError> {
31    let tokens = tokenize(text)?;
32    let mut pos = 0;
33    let expr = parse_expr(&tokens, &mut pos, &mut 0)?;
34    if pos < tokens.len() {
35        return Err(ExprParseError::UnexpectedToken {
36            found: tokens[pos].text.clone(),
37            pos,
38        });
39    }
40
41    // Validate determinism.
42    validate_deterministic(&expr)?;
43
44    // Collect column references.
45    let mut deps = Vec::new();
46    collect_columns(&expr, &mut deps);
47    deps.sort();
48    deps.dedup();
49
50    Ok((expr, deps))
51}
52
53// ── Recursive descent parser ──────────────────────────────────────────
54
55/// Maximum recursion depth for nested parentheses / sub-expressions.
56const MAX_EXPR_DEPTH: usize = 128;
57
58fn parse_expr(
59    tokens: &[Token],
60    pos: &mut usize,
61    depth: &mut usize,
62) -> Result<SqlExpr, ExprParseError> {
63    parse_or(tokens, pos, depth)
64}
65
66fn parse_or(
67    tokens: &[Token],
68    pos: &mut usize,
69    depth: &mut usize,
70) -> Result<SqlExpr, ExprParseError> {
71    let mut left = parse_and(tokens, pos, depth)?;
72    while peek_keyword(tokens, *pos, "OR") {
73        *pos += 1;
74        let right = parse_and(tokens, pos, depth)?;
75        left = SqlExpr::BinaryOp {
76            left: Box::new(left),
77            op: BinaryOp::Or,
78            right: Box::new(right),
79        };
80    }
81    Ok(left)
82}
83
84fn parse_and(
85    tokens: &[Token],
86    pos: &mut usize,
87    depth: &mut usize,
88) -> Result<SqlExpr, ExprParseError> {
89    let mut left = parse_comparison(tokens, pos, depth)?;
90    while peek_keyword(tokens, *pos, "AND") {
91        *pos += 1;
92        let right = parse_comparison(tokens, pos, depth)?;
93        left = SqlExpr::BinaryOp {
94            left: Box::new(left),
95            op: BinaryOp::And,
96            right: Box::new(right),
97        };
98    }
99    Ok(left)
100}
101
102fn parse_comparison(
103    tokens: &[Token],
104    pos: &mut usize,
105    depth: &mut usize,
106) -> Result<SqlExpr, ExprParseError> {
107    let left = parse_additive(tokens, pos, depth)?;
108    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
109        let op = match tokens[*pos].text.as_str() {
110            "=" => BinaryOp::Eq,
111            "!=" | "<>" => BinaryOp::NotEq,
112            "<" => BinaryOp::Lt,
113            "<=" => BinaryOp::LtEq,
114            ">" => BinaryOp::Gt,
115            ">=" => BinaryOp::GtEq,
116            _ => return Ok(left),
117        };
118        *pos += 1;
119        let right = parse_additive(tokens, pos, depth)?;
120        return Ok(SqlExpr::BinaryOp {
121            left: Box::new(left),
122            op,
123            right: Box::new(right),
124        });
125    }
126    Ok(left)
127}
128
129fn parse_additive(
130    tokens: &[Token],
131    pos: &mut usize,
132    depth: &mut usize,
133) -> Result<SqlExpr, ExprParseError> {
134    let mut left = parse_multiplicative(tokens, pos, depth)?;
135    while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
136        let op = match tokens[*pos].text.as_str() {
137            "+" => BinaryOp::Add,
138            "-" => BinaryOp::Sub,
139            "||" => BinaryOp::Concat,
140            _ => break,
141        };
142        *pos += 1;
143        let right = parse_multiplicative(tokens, pos, depth)?;
144        left = SqlExpr::BinaryOp {
145            left: Box::new(left),
146            op,
147            right: Box::new(right),
148        };
149    }
150    Ok(left)
151}
152
153fn parse_multiplicative(
154    tokens: &[Token],
155    pos: &mut usize,
156    depth: &mut usize,
157) -> Result<SqlExpr, ExprParseError> {
158    let mut left = parse_unary(tokens, pos, depth)?;
159    while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
160        let op = match tokens[*pos].text.as_str() {
161            "*" => BinaryOp::Mul,
162            "/" => BinaryOp::Div,
163            "%" => BinaryOp::Mod,
164            _ => break,
165        };
166        *pos += 1;
167        let right = parse_unary(tokens, pos, depth)?;
168        left = SqlExpr::BinaryOp {
169            left: Box::new(left),
170            op,
171            right: Box::new(right),
172        };
173    }
174    Ok(left)
175}
176
177fn parse_unary(
178    tokens: &[Token],
179    pos: &mut usize,
180    depth: &mut usize,
181) -> Result<SqlExpr, ExprParseError> {
182    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
183        *pos += 1;
184        let expr = parse_primary(tokens, pos, depth)?;
185        return Ok(SqlExpr::Negate(Box::new(expr)));
186    }
187    if peek_keyword(tokens, *pos, "NOT") {
188        *pos += 1;
189        let expr = parse_primary(tokens, pos, depth)?;
190        return Ok(SqlExpr::Negate(Box::new(expr)));
191    }
192    parse_primary(tokens, pos, depth)
193}
194
195fn parse_primary(
196    tokens: &[Token],
197    pos: &mut usize,
198    depth: &mut usize,
199) -> Result<SqlExpr, ExprParseError> {
200    if *pos >= tokens.len() {
201        return Err(ExprParseError::UnexpectedEof);
202    }
203
204    let token = &tokens[*pos];
205
206    match token.kind {
207        TokenKind::LParen => {
208            *depth += 1;
209            if *depth > MAX_EXPR_DEPTH {
210                return Err(ExprParseError::DepthLimitExceeded {
211                    max: MAX_EXPR_DEPTH,
212                });
213            }
214            *pos += 1;
215            let expr = parse_expr(tokens, pos, depth)?;
216            *depth -= 1;
217            expect_token(tokens, pos, TokenKind::RParen, ")")?;
218            Ok(expr)
219        }
220
221        TokenKind::Number => {
222            *pos += 1;
223            if let Ok(i) = token.text.parse::<i64>() {
224                Ok(SqlExpr::Literal(Value::Integer(i)))
225            } else if let Ok(f) = token.text.parse::<f64>() {
226                Ok(SqlExpr::Literal(Value::Float(f)))
227            } else {
228                Err(ExprParseError::InvalidLiteral {
229                    detail: format!("invalid number: '{}'", token.text),
230                })
231            }
232        }
233
234        TokenKind::StringLit => {
235            *pos += 1;
236            Ok(SqlExpr::Literal(Value::String(token.text.clone())))
237        }
238
239        TokenKind::Ident => {
240            let name = token.text.clone();
241            let upper = name.to_uppercase();
242            *pos += 1;
243
244            match upper.as_str() {
245                "NULL" => Ok(SqlExpr::Literal(Value::Null)),
246                "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))),
247                "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))),
248                "CASE" => parse_case(tokens, pos, depth),
249                "COALESCE" => {
250                    let args = parse_arg_list(tokens, pos, depth)?;
251                    Ok(SqlExpr::Coalesce(args))
252                }
253                _ => {
254                    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
255                        let args = parse_arg_list(tokens, pos, depth)?;
256                        Ok(SqlExpr::Function {
257                            name: name.to_lowercase(),
258                            args,
259                        })
260                    } else {
261                        Ok(SqlExpr::Column(name.to_lowercase()))
262                    }
263                }
264            }
265        }
266
267        _ => Err(ExprParseError::UnexpectedToken {
268            found: token.text.clone(),
269            pos: *pos,
270        }),
271    }
272}
273
274fn parse_case(
275    tokens: &[Token],
276    pos: &mut usize,
277    depth: &mut usize,
278) -> Result<SqlExpr, ExprParseError> {
279    let mut when_thens = Vec::new();
280    let mut else_expr = None;
281
282    loop {
283        if peek_keyword(tokens, *pos, "WHEN") {
284            *pos += 1;
285            let cond = parse_expr(tokens, pos, depth)?;
286            expect_keyword(tokens, pos, "THEN")?;
287            let then = parse_expr(tokens, pos, depth)?;
288            when_thens.push((cond, then));
289        } else if peek_keyword(tokens, *pos, "ELSE") {
290            *pos += 1;
291            else_expr = Some(Box::new(parse_expr(tokens, pos, depth)?));
292        } else if peek_keyword(tokens, *pos, "END") {
293            *pos += 1;
294            break;
295        } else {
296            return Err(ExprParseError::UnexpectedToken {
297                found: tokens
298                    .get(*pos)
299                    .map_or("EOF".to_string(), |t| t.text.clone()),
300                pos: *pos,
301            });
302        }
303    }
304
305    if when_thens.is_empty() {
306        return Err(ExprParseError::InvalidLiteral {
307            detail: "CASE requires at least one WHEN clause".to_string(),
308        });
309    }
310
311    Ok(SqlExpr::Case {
312        operand: None,
313        when_thens,
314        else_expr,
315    })
316}
317
318fn parse_arg_list(
319    tokens: &[Token],
320    pos: &mut usize,
321    depth: &mut usize,
322) -> Result<Vec<SqlExpr>, ExprParseError> {
323    expect_token(tokens, pos, TokenKind::LParen, "(")?;
324    let mut args = Vec::new();
325    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
326        *pos += 1;
327        return Ok(args);
328    }
329    loop {
330        args.push(parse_expr(tokens, pos, depth)?);
331        if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
332            *pos += 1;
333        } else {
334            break;
335        }
336    }
337    expect_token(tokens, pos, TokenKind::RParen, ")")?;
338    Ok(args)
339}
340
341// ── Helpers ───────────────────────────────────────────────────────────
342
343fn peek_keyword(tokens: &[Token], pos: usize, keyword: &str) -> bool {
344    pos < tokens.len()
345        && tokens[pos].kind == TokenKind::Ident
346        && tokens[pos].text.eq_ignore_ascii_case(keyword)
347}
348
349fn expect_keyword(tokens: &[Token], pos: &mut usize, keyword: &str) -> Result<(), ExprParseError> {
350    if peek_keyword(tokens, *pos, keyword) {
351        *pos += 1;
352        Ok(())
353    } else {
354        let got = tokens
355            .get(*pos)
356            .map_or_else(|| "EOF".to_string(), |t| t.text.clone());
357        Err(ExprParseError::UnexpectedToken {
358            found: got,
359            pos: *pos,
360        })
361    }
362}
363
364fn expect_token(
365    tokens: &[Token],
366    pos: &mut usize,
367    kind: TokenKind,
368    _expected: &str,
369) -> Result<(), ExprParseError> {
370    if *pos < tokens.len() && tokens[*pos].kind == kind {
371        *pos += 1;
372        Ok(())
373    } else {
374        let got = tokens
375            .get(*pos)
376            .map_or_else(|| "EOF".to_string(), |t| t.text.clone());
377        Err(ExprParseError::UnexpectedToken {
378            found: got,
379            pos: *pos,
380        })
381    }
382}
383
384// ── Validation ────────────────────────────────────────────────────────
385
/// Lowercased names of functions whose result can differ between calls with
/// the same arguments. `validate_deterministic` rejects calls to them inside
/// `GENERATED ALWAYS AS`.
const NON_DETERMINISTIC: &[&str] = &[
    // Clock readings.
    "now",
    "current_timestamp",
    // Random values and sequences.
    "random",
    "nextval",
    // Unique-identifier generators.
    "uuid",
    "uuid_v4",
    "uuid_v7",
    "gen_random_uuid",
    "ulid",
    "cuid2",
    "nanoid",
];
399
400fn validate_deterministic(expr: &SqlExpr) -> Result<(), ExprParseError> {
401    match expr {
402        SqlExpr::Function { name, args } => {
403            if NON_DETERMINISTIC.contains(&name.as_str()) {
404                return Err(ExprParseError::InvalidLiteral {
405                    detail: format!(
406                        "non-deterministic function '{name}()' not allowed in GENERATED ALWAYS AS"
407                    ),
408                });
409            }
410            for arg in args {
411                validate_deterministic(arg)?;
412            }
413            Ok(())
414        }
415        SqlExpr::BinaryOp { left, right, .. } => {
416            validate_deterministic(left)?;
417            validate_deterministic(right)
418        }
419        SqlExpr::Negate(inner) => validate_deterministic(inner),
420        SqlExpr::Coalesce(args) => {
421            for arg in args {
422                validate_deterministic(arg)?;
423            }
424            Ok(())
425        }
426        SqlExpr::Case {
427            operand,
428            when_thens,
429            else_expr,
430        } => {
431            if let Some(op) = operand {
432                validate_deterministic(op)?;
433            }
434            for (cond, then) in when_thens {
435                validate_deterministic(cond)?;
436                validate_deterministic(then)?;
437            }
438            if let Some(e) = else_expr {
439                validate_deterministic(e)?;
440            }
441            Ok(())
442        }
443        SqlExpr::Cast { expr, .. } => validate_deterministic(expr),
444        SqlExpr::NullIf(a, b) => {
445            validate_deterministic(a)?;
446            validate_deterministic(b)
447        }
448        SqlExpr::IsNull { expr, .. } => validate_deterministic(expr),
449        SqlExpr::Column(_)
450        | SqlExpr::Literal(_)
451        | SqlExpr::OldColumn(_)
452        | SqlExpr::ExcludedColumn(_) => Ok(()),
453    }
454}
455
456pub(crate) fn collect_columns(expr: &SqlExpr, deps: &mut Vec<String>) {
457    match expr {
458        SqlExpr::Column(name) => deps.push(name.clone()),
459        SqlExpr::BinaryOp { left, right, .. } => {
460            collect_columns(left, deps);
461            collect_columns(right, deps);
462        }
463        SqlExpr::Negate(inner) => collect_columns(inner, deps),
464        SqlExpr::Function { args, .. } => {
465            for arg in args {
466                collect_columns(arg, deps);
467            }
468        }
469        SqlExpr::Coalesce(args) => {
470            for arg in args {
471                collect_columns(arg, deps);
472            }
473        }
474        SqlExpr::Case {
475            operand,
476            when_thens,
477            else_expr,
478        } => {
479            if let Some(op) = operand {
480                collect_columns(op, deps);
481            }
482            for (cond, then) in when_thens {
483                collect_columns(cond, deps);
484                collect_columns(then, deps);
485            }
486            if let Some(e) = else_expr {
487                collect_columns(e, deps);
488            }
489        }
490        SqlExpr::Cast { expr, .. } => collect_columns(expr, deps),
491        SqlExpr::NullIf(a, b) => {
492            collect_columns(a, deps);
493            collect_columns(b, deps);
494        }
495        SqlExpr::IsNull { expr, .. } => collect_columns(expr, deps),
496        SqlExpr::Literal(_) | SqlExpr::OldColumn(_) | SqlExpr::ExcludedColumn(_) => {}
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    /// Parse `text`, panicking if the parser rejects it.
    fn parse_ok(text: &str) -> (SqlExpr, Vec<String>) {
        parse_generated_expr(text).unwrap()
    }

    /// Build the document `Value` that `SqlExpr::eval` reads from.
    fn row(json: serde_json::Value) -> Value {
        Value::from(json)
    }

    // ── Arithmetic and functions ──────────────────────────────────────

    #[test]
    fn simple_arithmetic() {
        let (expr, deps) = parse_ok("price * (1 + tax_rate)");
        assert_eq!(deps, vec!["price", "tax_rate"]);

        let input = row(serde_json::json!({"price": 100.0, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&input).as_f64(), Some(108.0));
    }

    #[test]
    fn round_function() {
        let (expr, deps) = parse_ok("ROUND(price * (1 + tax_rate), 2)");
        assert_eq!(deps, vec!["price", "tax_rate"]);

        let input = row(serde_json::json!({"price": 99.99, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&input), Value::Float(107.99));
    }

    #[test]
    fn concat_function() {
        let (expr, deps) = parse_ok("CONCAT(name, ' ', brand)");
        // Dependencies come back sorted, not in source order.
        assert_eq!(deps, vec!["brand", "name"]);

        let input = row(serde_json::json!({"name": "Shoe", "brand": "Nike"}));
        assert_eq!(expr.eval(&input), Value::String("Shoe Nike".into()));
    }

    #[test]
    fn coalesce() {
        let (expr, _) = parse_ok("COALESCE(description, '')");
        let input = row(serde_json::json!({"description": null}));
        assert_eq!(expr.eval(&input), Value::String("".into()));
    }

    #[test]
    fn case_when() {
        let (expr, deps) =
            parse_ok("CASE WHEN discount > 0 THEN price * (1 - discount) ELSE price END");
        assert!(deps.contains(&"discount".to_string()));
        assert!(deps.contains(&"price".to_string()));

        let discounted = row(serde_json::json!({"price": 100.0, "discount": 0.2}));
        assert_eq!(expr.eval(&discounted).as_f64(), Some(80.0));

        let full_price = row(serde_json::json!({"price": 100.0, "discount": 0}));
        assert_eq!(expr.eval(&full_price).as_f64(), Some(100.0));
    }

    #[test]
    fn nested_functions() {
        let (expr, _) = parse_ok("ROUND(price * (1 - COALESCE(discount, 0)), 2)");
        let input = row(serde_json::json!({"price": 49.99}));
        assert_eq!(expr.eval(&input), Value::Float(49.99));
    }

    // ── Determinism validation ────────────────────────────────────────

    #[test]
    fn rejects_now() {
        assert!(parse_generated_expr("NOW()").is_err());
    }

    #[test]
    fn rejects_random() {
        assert!(parse_generated_expr("RANDOM()").is_err());
    }

    #[test]
    fn rejects_uuid() {
        assert!(parse_generated_expr("UUID()").is_err());
    }

    // ── Literals ──────────────────────────────────────────────────────

    #[test]
    fn string_literal() {
        let (expr, _) = parse_ok("CONCAT(name, ' - ', 'default')");
        let input = row(serde_json::json!({"name": "Product"}));
        assert_eq!(expr.eval(&input), Value::String("Product - default".into()));
    }

    #[test]
    fn null_literal() {
        let (expr, _) = parse_ok("COALESCE(x, NULL, 0)");
        let input = row(serde_json::json!({"x": null}));
        assert_eq!(expr.eval(&input), Value::Integer(0));
    }

    // ── Robustness ────────────────────────────────────────────────────

    #[test]
    fn deeply_nested_parentheses_return_error_not_stack_overflow() {
        let depth = 10_000;
        let input = format!("{}x{}", "(".repeat(depth), ")".repeat(depth));
        assert!(
            parse_generated_expr(&input).is_err(),
            "parse_generated_expr must return Err for {depth}-deep nesting"
        );
    }

    // ── UTF-8 handling ────────────────────────────────────────────────

    #[test]
    fn cjk_string_in_concat() {
        let (expr, _) = parse_ok("CONCAT('你好', name)");
        let input = row(serde_json::json!({"name": "world"}));
        assert_eq!(expr.eval(&input), Value::String("你好world".into()));
    }

    #[test]
    fn comparison_with_utf8_literal() {
        let (expr, deps) = parse_ok("name != '禁止'");
        assert_eq!(deps, vec!["name"]);

        let input = row(serde_json::json!({"name": "allowed"}));
        assert_eq!(expr.eval(&input), Value::Bool(true));
    }
}