Skip to main content

nodedb_query/
expr_parse.rs

1//! SQL expression text → SqlExpr AST parser.
2//!
3//! Parses the subset of SQL expressions used in `GENERATED ALWAYS AS (expr)`
4//! column definitions. Supports:
5//! - Column references: `price`, `tax_rate`
6//! - Numeric literals: `42`, `3.14`, `-1`
//! - String literals: `'hello'`, `'it''s'` (a doubled `''` is an escaped quote)
//! - Binary operators: `+`, `-`, `*`, `/`, `%`, `||` (string concatenation)
9//! - Comparison: `=`, `!=`, `<>`, `<`, `>`, `<=`, `>=`
10//! - Logical: `AND`, `OR`, `NOT`
11//! - Parenthesized sub-expressions: `(a + b) * c`
12//! - Function calls: `ROUND(price * 1.08, 2)`, `CONCAT(a, ' ', b)`
13//! - COALESCE: `COALESCE(a, b, '')`
14//! - CASE WHEN: `CASE WHEN x > 0 THEN 'positive' ELSE 'non-positive' END`
15//! - NULL literal
16//!
17//! Determinism validation: rejects `NOW()`, `RANDOM()`, `NEXTVAL()`, `UUID()`.
18
19use super::expr::{BinaryOp, SqlExpr};
20use nodedb_types::Value;
21
22/// Parse a SQL expression string into an SqlExpr AST.
23///
24/// Returns the parsed expression and a list of column names it references
25/// (the `depends_on` set for generated columns).
26pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), String> {
27    let tokens = tokenize(text)?;
28    let mut pos = 0;
29    let expr = parse_expr(&tokens, &mut pos)?;
30    if pos < tokens.len() {
31        return Err(format!(
32            "unexpected token after expression: '{}'",
33            tokens[pos].text
34        ));
35    }
36
37    // Validate determinism.
38    validate_deterministic(&expr)?;
39
40    // Collect column references.
41    let mut deps = Vec::new();
42    collect_columns(&expr, &mut deps);
43    deps.sort();
44    deps.dedup();
45
46    Ok((expr, deps))
47}
48
49// ── Tokenizer ─────────────────────────────────────────────────────────
50
/// A lexical token produced by [`tokenize`].
#[derive(Debug, Clone)]
struct Token {
    /// Source text of the token. For string literals this is the decoded
    /// contents: surrounding quotes removed, `''` collapsed to `'`.
    text: String,
    kind: TokenKind,
}

/// Lexical category of a [`Token`].
#[derive(Debug, Clone, Copy, PartialEq)]
enum TokenKind {
    /// Identifier or keyword (the parser matches keywords case-insensitively).
    Ident,
    /// Unsigned numeric literal: ASCII digits and `.`.
    Number,
    /// Single-quoted string literal.
    StringLit,
    LParen,
    RParen,
    Comma,
    /// Arithmetic, comparison, or concatenation operator.
    Op,
}

/// Split `input` into tokens.
///
/// The scan works on bytes. Every character that can start or continue a
/// non-string token is ASCII, so the scan position always sits on a UTF-8
/// character boundary. Non-ASCII text is accepted only inside string literals,
/// where it is copied through unchanged.
///
/// # Errors
///
/// Returns an error for an unterminated string literal or for any character
/// that cannot start a token.
fn tokenize(input: &str) -> Result<Vec<Token>, String> {
    let bytes = input.as_bytes();
    let mut tokens = Vec::new();
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];

        // Skip whitespace.
        if b.is_ascii_whitespace() {
            i += 1;
            continue;
        }

        // Single-char punctuation.
        let punct = match b {
            b'(' => Some(TokenKind::LParen),
            b')' => Some(TokenKind::RParen),
            b',' => Some(TokenKind::Comma),
            _ => None,
        };
        if let Some(kind) = punct {
            tokens.push(Token {
                text: (b as char).to_string(),
                kind,
            });
            i += 1;
            continue;
        }

        // Two-char operators. Match on raw bytes: slicing `input[i..i + 2]`
        // up front would panic whenever `i + 2` lands inside a multi-byte char.
        let two_char = match (b, bytes.get(i + 1).copied()) {
            (b'<', Some(b'=')) => Some("<="),
            (b'>', Some(b'=')) => Some(">="),
            (b'!', Some(b'=')) => Some("!="),
            (b'<', Some(b'>')) => Some("<>"),
            (b'|', Some(b'|')) => Some("||"),
            _ => None,
        };
        if let Some(op) = two_char {
            tokens.push(Token {
                text: op.into(),
                kind: TokenKind::Op,
            });
            i += 2;
            continue;
        }

        // Single-char operators.
        if matches!(b, b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>') {
            tokens.push(Token {
                text: (b as char).to_string(),
                kind: TokenKind::Op,
            });
            i += 1;
            continue;
        }

        // String literal: `'...'`, where `''` is an escaped quote.
        if b == b'\'' {
            i += 1;
            let mut text = String::new();
            // Start of the current run of literal characters. Runs are copied
            // as `&str` slices so multi-byte UTF-8 text survives intact. Each
            // run boundary is adjacent to an ASCII `'`, and that byte never
            // occurs inside a multi-byte sequence, so slicing cannot panic.
            let mut run_start = i;
            let mut closed = false;
            while i < bytes.len() {
                if bytes[i] != b'\'' {
                    i += 1;
                    continue;
                }
                text.push_str(&input[run_start..i]);
                if bytes.get(i + 1) == Some(&b'\'') {
                    // Doubled quote: literal `'`, keep scanning.
                    text.push('\'');
                    i += 2;
                    run_start = i;
                    continue;
                }
                // Closing quote.
                i += 1;
                closed = true;
                break;
            }
            if !closed {
                return Err("unterminated string literal".into());
            }
            tokens.push(Token {
                text,
                kind: TokenKind::StringLit,
            });
            continue;
        }

        // Number.
        if b.is_ascii_digit() || (b == b'.' && i + 1 < bytes.len() && bytes[i + 1].is_ascii_digit())
        {
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_digit() || bytes[i] == b'.') {
                i += 1;
            }
            tokens.push(Token {
                text: input[start..i].to_string(),
                kind: TokenKind::Number,
            });
            continue;
        }

        // Identifier or keyword.
        if b.is_ascii_alphabetic() || b == b'_' {
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                i += 1;
            }
            tokens.push(Token {
                text: input[start..i].to_string(),
                kind: TokenKind::Ident,
            });
            continue;
        }

        // `i` is on a char boundary here (see the function docs), so decode
        // the whole offending character rather than reporting its first byte.
        let ch = input[i..]
            .chars()
            .next()
            .unwrap_or(char::REPLACEMENT_CHARACTER);
        return Err(format!("unexpected character: '{ch}'"));
    }

    Ok(tokens)
}
195
196// ── Recursive descent parser ──────────────────────────────────────────
197
198/// Parse an expression (lowest precedence: OR).
199fn parse_expr(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
200    parse_or(tokens, pos)
201}
202
203fn parse_or(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
204    let mut left = parse_and(tokens, pos)?;
205    while peek_keyword(tokens, *pos, "OR") {
206        *pos += 1;
207        let right = parse_and(tokens, pos)?;
208        left = SqlExpr::BinaryOp {
209            left: Box::new(left),
210            op: BinaryOp::Or,
211            right: Box::new(right),
212        };
213    }
214    Ok(left)
215}
216
217fn parse_and(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
218    let mut left = parse_comparison(tokens, pos)?;
219    while peek_keyword(tokens, *pos, "AND") {
220        *pos += 1;
221        let right = parse_comparison(tokens, pos)?;
222        left = SqlExpr::BinaryOp {
223            left: Box::new(left),
224            op: BinaryOp::And,
225            right: Box::new(right),
226        };
227    }
228    Ok(left)
229}
230
231fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
232    let left = parse_additive(tokens, pos)?;
233    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
234        let op = match tokens[*pos].text.as_str() {
235            "=" => BinaryOp::Eq,
236            "!=" | "<>" => BinaryOp::NotEq,
237            "<" => BinaryOp::Lt,
238            "<=" => BinaryOp::LtEq,
239            ">" => BinaryOp::Gt,
240            ">=" => BinaryOp::GtEq,
241            _ => return Ok(left),
242        };
243        *pos += 1;
244        let right = parse_additive(tokens, pos)?;
245        return Ok(SqlExpr::BinaryOp {
246            left: Box::new(left),
247            op,
248            right: Box::new(right),
249        });
250    }
251    Ok(left)
252}
253
254fn parse_additive(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
255    let mut left = parse_multiplicative(tokens, pos)?;
256    while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
257        let op = match tokens[*pos].text.as_str() {
258            "+" => BinaryOp::Add,
259            "-" => BinaryOp::Sub,
260            "||" => BinaryOp::Concat,
261            _ => break,
262        };
263        *pos += 1;
264        let right = parse_multiplicative(tokens, pos)?;
265        left = SqlExpr::BinaryOp {
266            left: Box::new(left),
267            op,
268            right: Box::new(right),
269        };
270    }
271    Ok(left)
272}
273
274fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
275    let mut left = parse_unary(tokens, pos)?;
276    while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
277        let op = match tokens[*pos].text.as_str() {
278            "*" => BinaryOp::Mul,
279            "/" => BinaryOp::Div,
280            "%" => BinaryOp::Mod,
281            _ => break,
282        };
283        *pos += 1;
284        let right = parse_unary(tokens, pos)?;
285        left = SqlExpr::BinaryOp {
286            left: Box::new(left),
287            op,
288            right: Box::new(right),
289        };
290    }
291    Ok(left)
292}
293
294fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
295    // Unary minus.
296    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
297        *pos += 1;
298        let expr = parse_primary(tokens, pos)?;
299        return Ok(SqlExpr::Negate(Box::new(expr)));
300    }
301    // NOT
302    if peek_keyword(tokens, *pos, "NOT") {
303        *pos += 1;
304        let expr = parse_primary(tokens, pos)?;
305        return Ok(SqlExpr::Negate(Box::new(expr)));
306    }
307    parse_primary(tokens, pos)
308}
309
310fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
311    if *pos >= tokens.len() {
312        return Err("unexpected end of expression".into());
313    }
314
315    let token = &tokens[*pos];
316
317    match token.kind {
318        // Parenthesized expression.
319        TokenKind::LParen => {
320            *pos += 1;
321            let expr = parse_expr(tokens, pos)?;
322            expect_token(tokens, pos, TokenKind::RParen, ")")?;
323            Ok(expr)
324        }
325
326        // Number literal.
327        TokenKind::Number => {
328            *pos += 1;
329            if let Ok(i) = token.text.parse::<i64>() {
330                Ok(SqlExpr::Literal(Value::Integer(i)))
331            } else if let Ok(f) = token.text.parse::<f64>() {
332                Ok(SqlExpr::Literal(Value::Float(f)))
333            } else {
334                Err(format!("invalid number: '{}'", token.text))
335            }
336        }
337
338        // String literal.
339        TokenKind::StringLit => {
340            *pos += 1;
341            Ok(SqlExpr::Literal(Value::String(token.text.clone())))
342        }
343
344        // Identifier: column ref, function call, keyword (NULL, TRUE, FALSE, CASE, COALESCE).
345        TokenKind::Ident => {
346            let name = token.text.clone();
347            let upper = name.to_uppercase();
348            *pos += 1;
349
350            match upper.as_str() {
351                "NULL" => Ok(SqlExpr::Literal(Value::Null)),
352                "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))),
353                "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))),
354                "CASE" => parse_case(tokens, pos),
355                "COALESCE" => {
356                    let args = parse_arg_list(tokens, pos)?;
357                    Ok(SqlExpr::Coalesce(args))
358                }
359                _ => {
360                    // Function call: IDENT(args).
361                    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
362                        let args = parse_arg_list(tokens, pos)?;
363                        Ok(SqlExpr::Function {
364                            name: name.to_lowercase(),
365                            args,
366                        })
367                    } else {
368                        // Column reference.
369                        Ok(SqlExpr::Column(name.to_lowercase()))
370                    }
371                }
372            }
373        }
374
375        _ => Err(format!("unexpected token: '{}'", token.text)),
376    }
377}
378
379/// Parse `CASE WHEN cond THEN result [WHEN ... THEN ...] [ELSE result] END`.
380fn parse_case(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
381    let mut when_thens = Vec::new();
382    let mut else_expr = None;
383
384    loop {
385        if peek_keyword(tokens, *pos, "WHEN") {
386            *pos += 1;
387            let cond = parse_expr(tokens, pos)?;
388            expect_keyword(tokens, pos, "THEN")?;
389            let then = parse_expr(tokens, pos)?;
390            when_thens.push((cond, then));
391        } else if peek_keyword(tokens, *pos, "ELSE") {
392            *pos += 1;
393            else_expr = Some(Box::new(parse_expr(tokens, pos)?));
394        } else if peek_keyword(tokens, *pos, "END") {
395            *pos += 1;
396            break;
397        } else {
398            return Err("expected WHEN, ELSE, or END in CASE expression".into());
399        }
400    }
401
402    if when_thens.is_empty() {
403        return Err("CASE requires at least one WHEN clause".into());
404    }
405
406    Ok(SqlExpr::Case {
407        operand: None,
408        when_thens,
409        else_expr,
410    })
411}
412
413/// Parse a parenthesized, comma-separated argument list: `(expr, expr, ...)`.
414fn parse_arg_list(tokens: &[Token], pos: &mut usize) -> Result<Vec<SqlExpr>, String> {
415    expect_token(tokens, pos, TokenKind::LParen, "(")?;
416    let mut args = Vec::new();
417    if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
418        *pos += 1;
419        return Ok(args);
420    }
421    loop {
422        args.push(parse_expr(tokens, pos)?);
423        if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
424            *pos += 1;
425        } else {
426            break;
427        }
428    }
429    expect_token(tokens, pos, TokenKind::RParen, ")")?;
430    Ok(args)
431}
432
433// ── Helpers ───────────────────────────────────────────────────────────
434
435fn peek_keyword(tokens: &[Token], pos: usize, keyword: &str) -> bool {
436    pos < tokens.len()
437        && tokens[pos].kind == TokenKind::Ident
438        && tokens[pos].text.eq_ignore_ascii_case(keyword)
439}
440
441fn expect_keyword(tokens: &[Token], pos: &mut usize, keyword: &str) -> Result<(), String> {
442    if peek_keyword(tokens, *pos, keyword) {
443        *pos += 1;
444        Ok(())
445    } else {
446        let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
447        Err(format!("expected '{keyword}', got '{got}'"))
448    }
449}
450
451fn expect_token(
452    tokens: &[Token],
453    pos: &mut usize,
454    kind: TokenKind,
455    expected: &str,
456) -> Result<(), String> {
457    if *pos < tokens.len() && tokens[*pos].kind == kind {
458        *pos += 1;
459        Ok(())
460    } else {
461        let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
462        Err(format!("expected '{expected}', got '{got}'"))
463    }
464}
465
466// ── Validation ────────────────────────────────────────────────────────
467
/// Non-deterministic functions that are rejected in GENERATED ALWAYS AS.
///
/// Matched against the lowercased name of a function *call* (`name(...)`).
/// The list covers the common PostgreSQL, MySQL, and SQLite spellings of
/// clock, randomness, sequence, and identifier-generator functions.
const NON_DETERMINISTIC: &[&str] = &[
    // Current date / time.
    "now",
    "current_timestamp",
    "current_date",
    "current_time",
    "localtime",
    "localtimestamp",
    "clock_timestamp",
    "statement_timestamp",
    "transaction_timestamp",
    "timeofday",
    "sysdate",
    "curdate",
    "curtime",
    // Randomness.
    "random",
    "rand",
    "randomblob",
    // Sequences.
    "nextval",
    // Identifier generators.
    "uuid",
    "uuid_v4",
    "uuid_v7",
    "gen_random_uuid",
    "ulid",
    "cuid2",
    "nanoid",
];
482
483fn validate_deterministic(expr: &SqlExpr) -> Result<(), String> {
484    match expr {
485        SqlExpr::Function { name, args } => {
486            if NON_DETERMINISTIC.contains(&name.as_str()) {
487                return Err(format!(
488                    "non-deterministic function '{name}()' not allowed in GENERATED ALWAYS AS"
489                ));
490            }
491            for arg in args {
492                validate_deterministic(arg)?;
493            }
494            Ok(())
495        }
496        SqlExpr::BinaryOp { left, right, .. } => {
497            validate_deterministic(left)?;
498            validate_deterministic(right)
499        }
500        SqlExpr::Negate(inner) => validate_deterministic(inner),
501        SqlExpr::Coalesce(args) => {
502            for arg in args {
503                validate_deterministic(arg)?;
504            }
505            Ok(())
506        }
507        SqlExpr::Case {
508            operand,
509            when_thens,
510            else_expr,
511        } => {
512            if let Some(op) = operand {
513                validate_deterministic(op)?;
514            }
515            for (cond, then) in when_thens {
516                validate_deterministic(cond)?;
517                validate_deterministic(then)?;
518            }
519            if let Some(e) = else_expr {
520                validate_deterministic(e)?;
521            }
522            Ok(())
523        }
524        SqlExpr::Cast { expr, .. } => validate_deterministic(expr),
525        SqlExpr::NullIf(a, b) => {
526            validate_deterministic(a)?;
527            validate_deterministic(b)
528        }
529        SqlExpr::IsNull { expr, .. } => validate_deterministic(expr),
530        SqlExpr::Column(_) | SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => Ok(()),
531    }
532}
533
534fn collect_columns(expr: &SqlExpr, deps: &mut Vec<String>) {
535    match expr {
536        SqlExpr::Column(name) => deps.push(name.clone()),
537        SqlExpr::BinaryOp { left, right, .. } => {
538            collect_columns(left, deps);
539            collect_columns(right, deps);
540        }
541        SqlExpr::Negate(inner) => collect_columns(inner, deps),
542        SqlExpr::Function { args, .. } => {
543            for arg in args {
544                collect_columns(arg, deps);
545            }
546        }
547        SqlExpr::Coalesce(args) => {
548            for arg in args {
549                collect_columns(arg, deps);
550            }
551        }
552        SqlExpr::Case {
553            operand,
554            when_thens,
555            else_expr,
556        } => {
557            if let Some(op) = operand {
558                collect_columns(op, deps);
559            }
560            for (cond, then) in when_thens {
561                collect_columns(cond, deps);
562                collect_columns(then, deps);
563            }
564            if let Some(e) = else_expr {
565                collect_columns(e, deps);
566            }
567        }
568        SqlExpr::Cast { expr, .. } => collect_columns(expr, deps),
569        SqlExpr::NullIf(a, b) => {
570            collect_columns(a, deps);
571            collect_columns(b, deps);
572        }
573        SqlExpr::IsNull { expr, .. } => collect_columns(expr, deps),
574        SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => {}
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    /// Parse `text`, panicking if the parser rejects it.
    fn parse_ok(text: &str) -> (SqlExpr, Vec<String>) {
        parse_generated_expr(text).expect("expression should parse")
    }

    /// Build a document `Value` from a JSON value.
    fn doc(json: serde_json::Value) -> Value {
        Value::from(json)
    }

    #[test]
    fn simple_arithmetic() {
        let (expr, deps) = parse_ok("price * (1 + tax_rate)");
        assert_eq!(deps, vec!["price", "tax_rate"]);
        let row = doc(serde_json::json!({"price": 100.0, "tax_rate": 0.08}));
        // eval may return an integer for whole-number results, so compare as f64.
        assert_eq!(expr.eval(&row).as_f64(), Some(108.0));
    }

    #[test]
    fn round_function() {
        let (expr, deps) = parse_ok("ROUND(price * (1 + tax_rate), 2)");
        assert_eq!(deps, vec!["price", "tax_rate"]);
        let row = doc(serde_json::json!({"price": 99.99, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&row), Value::Float(107.99));
    }

    #[test]
    fn concat_function() {
        let (expr, deps) = parse_ok("CONCAT(name, ' ', brand)");
        assert_eq!(deps, vec!["brand", "name"]);
        let row = doc(serde_json::json!({"name": "Shoe", "brand": "Nike"}));
        assert_eq!(expr.eval(&row), Value::String("Shoe Nike".into()));
    }

    #[test]
    fn coalesce() {
        let (expr, _) = parse_ok("COALESCE(description, '')");
        let row = doc(serde_json::json!({"description": null}));
        assert_eq!(expr.eval(&row), Value::String("".into()));
    }

    #[test]
    fn case_when() {
        let (expr, deps) =
            parse_ok("CASE WHEN discount > 0 THEN price * (1 - discount) ELSE price END");
        assert!(deps.iter().any(|d| d == "discount"));
        assert!(deps.iter().any(|d| d == "price"));

        let discounted = doc(serde_json::json!({"price": 100.0, "discount": 0.2}));
        assert_eq!(expr.eval(&discounted).as_f64(), Some(80.0));

        let full_price = doc(serde_json::json!({"price": 100.0, "discount": 0}));
        assert_eq!(expr.eval(&full_price).as_f64(), Some(100.0));
    }

    #[test]
    fn rejects_now() {
        let outcome = parse_generated_expr("NOW()");
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_random() {
        let outcome = parse_generated_expr("RANDOM()");
        assert!(outcome.is_err());
    }

    #[test]
    fn rejects_uuid() {
        let outcome = parse_generated_expr("UUID()");
        assert!(outcome.is_err());
    }

    #[test]
    fn string_literal() {
        let (expr, _) = parse_ok("CONCAT(name, ' - ', 'default')");
        let row = doc(serde_json::json!({"name": "Product"}));
        assert_eq!(expr.eval(&row), Value::String("Product - default".into()));
    }

    #[test]
    fn null_literal() {
        let (expr, _) = parse_ok("COALESCE(x, NULL, 0)");
        let row = doc(serde_json::json!({"x": null}));
        assert_eq!(expr.eval(&row), Value::Integer(0));
    }

    #[test]
    fn nested_functions() {
        let (expr, _) = parse_ok("ROUND(price * (1 - COALESCE(discount, 0)), 2)");
        let row = doc(serde_json::json!({"price": 49.99}));
        assert_eq!(expr.eval(&row), Value::Float(49.99));
    }
}