1mod tokenizer;
20
21use super::expr::{BinaryOp, SqlExpr};
22use nodedb_types::Value;
23use tokenizer::{Token, TokenKind, tokenize};
24
25pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), String> {
30 let tokens = tokenize(text)?;
31 let mut pos = 0;
32 let expr = parse_expr(&tokens, &mut pos, &mut 0)?;
33 if pos < tokens.len() {
34 return Err(format!(
35 "unexpected token after expression: '{}'",
36 tokens[pos].text
37 ));
38 }
39
40 validate_deterministic(&expr)?;
42
43 let mut deps = Vec::new();
45 collect_columns(&expr, &mut deps);
46 deps.sort();
47 deps.dedup();
48
49 Ok((expr, deps))
50}
51
/// Upper bound for the nesting counter that the parser threads through its
/// functions as `depth`; `parse_primary` rejects a parenthesized group that
/// would push the counter past this value.
const MAX_EXPR_DEPTH: usize = 128;
56
/// Parses one full expression starting at `tokens[*pos]`.
///
/// Delegates to `parse_or`, the loosest-binding level of the operator chain
/// in this file. `pos` is advanced past the consumed tokens and `depth` is the
/// shared nesting counter.
fn parse_expr(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
    parse_or(tokens, pos, depth)
}
60
61fn parse_or(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
62 let mut left = parse_and(tokens, pos, depth)?;
63 while peek_keyword(tokens, *pos, "OR") {
64 *pos += 1;
65 let right = parse_and(tokens, pos, depth)?;
66 left = SqlExpr::BinaryOp {
67 left: Box::new(left),
68 op: BinaryOp::Or,
69 right: Box::new(right),
70 };
71 }
72 Ok(left)
73}
74
75fn parse_and(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
76 let mut left = parse_comparison(tokens, pos, depth)?;
77 while peek_keyword(tokens, *pos, "AND") {
78 *pos += 1;
79 let right = parse_comparison(tokens, pos, depth)?;
80 left = SqlExpr::BinaryOp {
81 left: Box::new(left),
82 op: BinaryOp::And,
83 right: Box::new(right),
84 };
85 }
86 Ok(left)
87}
88
89fn parse_comparison(
90 tokens: &[Token],
91 pos: &mut usize,
92 depth: &mut usize,
93) -> Result<SqlExpr, String> {
94 let left = parse_additive(tokens, pos, depth)?;
95 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
96 let op = match tokens[*pos].text.as_str() {
97 "=" => BinaryOp::Eq,
98 "!=" | "<>" => BinaryOp::NotEq,
99 "<" => BinaryOp::Lt,
100 "<=" => BinaryOp::LtEq,
101 ">" => BinaryOp::Gt,
102 ">=" => BinaryOp::GtEq,
103 _ => return Ok(left),
104 };
105 *pos += 1;
106 let right = parse_additive(tokens, pos, depth)?;
107 return Ok(SqlExpr::BinaryOp {
108 left: Box::new(left),
109 op,
110 right: Box::new(right),
111 });
112 }
113 Ok(left)
114}
115
116fn parse_additive(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
117 let mut left = parse_multiplicative(tokens, pos, depth)?;
118 while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
119 let op = match tokens[*pos].text.as_str() {
120 "+" => BinaryOp::Add,
121 "-" => BinaryOp::Sub,
122 "||" => BinaryOp::Concat,
123 _ => break,
124 };
125 *pos += 1;
126 let right = parse_multiplicative(tokens, pos, depth)?;
127 left = SqlExpr::BinaryOp {
128 left: Box::new(left),
129 op,
130 right: Box::new(right),
131 };
132 }
133 Ok(left)
134}
135
136fn parse_multiplicative(
137 tokens: &[Token],
138 pos: &mut usize,
139 depth: &mut usize,
140) -> Result<SqlExpr, String> {
141 let mut left = parse_unary(tokens, pos, depth)?;
142 while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
143 let op = match tokens[*pos].text.as_str() {
144 "*" => BinaryOp::Mul,
145 "/" => BinaryOp::Div,
146 "%" => BinaryOp::Mod,
147 _ => break,
148 };
149 *pos += 1;
150 let right = parse_unary(tokens, pos, depth)?;
151 left = SqlExpr::BinaryOp {
152 left: Box::new(left),
153 op,
154 right: Box::new(right),
155 };
156 }
157 Ok(left)
158}
159
160fn parse_unary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
161 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
162 *pos += 1;
163 let expr = parse_primary(tokens, pos, depth)?;
164 return Ok(SqlExpr::Negate(Box::new(expr)));
165 }
166 if peek_keyword(tokens, *pos, "NOT") {
167 *pos += 1;
168 let expr = parse_primary(tokens, pos, depth)?;
169 return Ok(SqlExpr::Negate(Box::new(expr)));
170 }
171 parse_primary(tokens, pos, depth)
172}
173
174fn parse_primary(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
175 if *pos >= tokens.len() {
176 return Err("unexpected end of expression".into());
177 }
178
179 let token = &tokens[*pos];
180
181 match token.kind {
182 TokenKind::LParen => {
183 *depth += 1;
184 if *depth > MAX_EXPR_DEPTH {
185 return Err(format!(
186 "expression nesting depth exceeds maximum of {MAX_EXPR_DEPTH}"
187 ));
188 }
189 *pos += 1;
190 let expr = parse_expr(tokens, pos, depth)?;
191 *depth -= 1;
192 expect_token(tokens, pos, TokenKind::RParen, ")")?;
193 Ok(expr)
194 }
195
196 TokenKind::Number => {
197 *pos += 1;
198 if let Ok(i) = token.text.parse::<i64>() {
199 Ok(SqlExpr::Literal(Value::Integer(i)))
200 } else if let Ok(f) = token.text.parse::<f64>() {
201 Ok(SqlExpr::Literal(Value::Float(f)))
202 } else {
203 Err(format!("invalid number: '{}'", token.text))
204 }
205 }
206
207 TokenKind::StringLit => {
208 *pos += 1;
209 Ok(SqlExpr::Literal(Value::String(token.text.clone())))
210 }
211
212 TokenKind::Ident => {
213 let name = token.text.clone();
214 let upper = name.to_uppercase();
215 *pos += 1;
216
217 match upper.as_str() {
218 "NULL" => Ok(SqlExpr::Literal(Value::Null)),
219 "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))),
220 "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))),
221 "CASE" => parse_case(tokens, pos, depth),
222 "COALESCE" => {
223 let args = parse_arg_list(tokens, pos, depth)?;
224 Ok(SqlExpr::Coalesce(args))
225 }
226 _ => {
227 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
228 let args = parse_arg_list(tokens, pos, depth)?;
229 Ok(SqlExpr::Function {
230 name: name.to_lowercase(),
231 args,
232 })
233 } else {
234 Ok(SqlExpr::Column(name.to_lowercase()))
235 }
236 }
237 }
238 }
239
240 _ => Err(format!("unexpected token: '{}'", token.text)),
241 }
242}
243
244fn parse_case(tokens: &[Token], pos: &mut usize, depth: &mut usize) -> Result<SqlExpr, String> {
245 let mut when_thens = Vec::new();
246 let mut else_expr = None;
247
248 loop {
249 if peek_keyword(tokens, *pos, "WHEN") {
250 *pos += 1;
251 let cond = parse_expr(tokens, pos, depth)?;
252 expect_keyword(tokens, pos, "THEN")?;
253 let then = parse_expr(tokens, pos, depth)?;
254 when_thens.push((cond, then));
255 } else if peek_keyword(tokens, *pos, "ELSE") {
256 *pos += 1;
257 else_expr = Some(Box::new(parse_expr(tokens, pos, depth)?));
258 } else if peek_keyword(tokens, *pos, "END") {
259 *pos += 1;
260 break;
261 } else {
262 return Err("expected WHEN, ELSE, or END in CASE expression".into());
263 }
264 }
265
266 if when_thens.is_empty() {
267 return Err("CASE requires at least one WHEN clause".into());
268 }
269
270 Ok(SqlExpr::Case {
271 operand: None,
272 when_thens,
273 else_expr,
274 })
275}
276
277fn parse_arg_list(
278 tokens: &[Token],
279 pos: &mut usize,
280 depth: &mut usize,
281) -> Result<Vec<SqlExpr>, String> {
282 expect_token(tokens, pos, TokenKind::LParen, "(")?;
283 let mut args = Vec::new();
284 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
285 *pos += 1;
286 return Ok(args);
287 }
288 loop {
289 args.push(parse_expr(tokens, pos, depth)?);
290 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
291 *pos += 1;
292 } else {
293 break;
294 }
295 }
296 expect_token(tokens, pos, TokenKind::RParen, ")")?;
297 Ok(args)
298}
299
300fn peek_keyword(tokens: &[Token], pos: usize, keyword: &str) -> bool {
303 pos < tokens.len()
304 && tokens[pos].kind == TokenKind::Ident
305 && tokens[pos].text.eq_ignore_ascii_case(keyword)
306}
307
308fn expect_keyword(tokens: &[Token], pos: &mut usize, keyword: &str) -> Result<(), String> {
309 if peek_keyword(tokens, *pos, keyword) {
310 *pos += 1;
311 Ok(())
312 } else {
313 let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
314 Err(format!("expected '{keyword}', got '{got}'"))
315 }
316}
317
318fn expect_token(
319 tokens: &[Token],
320 pos: &mut usize,
321 kind: TokenKind,
322 expected: &str,
323) -> Result<(), String> {
324 if *pos < tokens.len() && tokens[*pos].kind == kind {
325 *pos += 1;
326 Ok(())
327 } else {
328 let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
329 Err(format!("expected '{expected}', got '{got}'"))
330 }
331}
332
/// Function names that `validate_deterministic` rejects.
///
/// Entries are lowercase: the parser lowercases a function name before storing
/// it in `SqlExpr::Function`, so the check is a plain string comparison.
const NON_DETERMINISTIC: &[&str] = &[
    "now",
    "current_timestamp",
    "random",
    "nextval",
    "uuid",
    "uuid_v4",
    "uuid_v7",
    "gen_random_uuid",
    "ulid",
    "cuid2",
    "nanoid",
];
348
349fn validate_deterministic(expr: &SqlExpr) -> Result<(), String> {
350 match expr {
351 SqlExpr::Function { name, args } => {
352 if NON_DETERMINISTIC.contains(&name.as_str()) {
353 return Err(format!(
354 "non-deterministic function '{name}()' not allowed in GENERATED ALWAYS AS"
355 ));
356 }
357 for arg in args {
358 validate_deterministic(arg)?;
359 }
360 Ok(())
361 }
362 SqlExpr::BinaryOp { left, right, .. } => {
363 validate_deterministic(left)?;
364 validate_deterministic(right)
365 }
366 SqlExpr::Negate(inner) => validate_deterministic(inner),
367 SqlExpr::Coalesce(args) => {
368 for arg in args {
369 validate_deterministic(arg)?;
370 }
371 Ok(())
372 }
373 SqlExpr::Case {
374 operand,
375 when_thens,
376 else_expr,
377 } => {
378 if let Some(op) = operand {
379 validate_deterministic(op)?;
380 }
381 for (cond, then) in when_thens {
382 validate_deterministic(cond)?;
383 validate_deterministic(then)?;
384 }
385 if let Some(e) = else_expr {
386 validate_deterministic(e)?;
387 }
388 Ok(())
389 }
390 SqlExpr::Cast { expr, .. } => validate_deterministic(expr),
391 SqlExpr::NullIf(a, b) => {
392 validate_deterministic(a)?;
393 validate_deterministic(b)
394 }
395 SqlExpr::IsNull { expr, .. } => validate_deterministic(expr),
396 SqlExpr::Column(_) | SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => Ok(()),
397 }
398}
399
400fn collect_columns(expr: &SqlExpr, deps: &mut Vec<String>) {
401 match expr {
402 SqlExpr::Column(name) => deps.push(name.clone()),
403 SqlExpr::BinaryOp { left, right, .. } => {
404 collect_columns(left, deps);
405 collect_columns(right, deps);
406 }
407 SqlExpr::Negate(inner) => collect_columns(inner, deps),
408 SqlExpr::Function { args, .. } => {
409 for arg in args {
410 collect_columns(arg, deps);
411 }
412 }
413 SqlExpr::Coalesce(args) => {
414 for arg in args {
415 collect_columns(arg, deps);
416 }
417 }
418 SqlExpr::Case {
419 operand,
420 when_thens,
421 else_expr,
422 } => {
423 if let Some(op) = operand {
424 collect_columns(op, deps);
425 }
426 for (cond, then) in when_thens {
427 collect_columns(cond, deps);
428 collect_columns(then, deps);
429 }
430 if let Some(e) = else_expr {
431 collect_columns(e, deps);
432 }
433 }
434 SqlExpr::Cast { expr, .. } => collect_columns(expr, deps),
435 SqlExpr::NullIf(a, b) => {
436 collect_columns(a, deps);
437 collect_columns(b, deps);
438 }
439 SqlExpr::IsNull { expr, .. } => collect_columns(expr, deps),
440 SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => {}
441 }
442}
443
// Tests for `parse_generated_expr`: parsing, dependency extraction,
// evaluation of the parsed tree, and rejection of non-deterministic calls.
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    /// Parses `text`, panicking if the parser returns an error.
    fn parse_ok(text: &str) -> (SqlExpr, Vec<String>) {
        parse_generated_expr(text).unwrap()
    }

    // Arithmetic with a parenthesized sub-expression; dependencies come back
    // sorted.
    #[test]
    fn simple_arithmetic() {
        let (expr, deps) = parse_ok("price * (1 + tax_rate)");
        assert_eq!(deps, vec!["price", "tax_rate"]);
        let doc = Value::from(serde_json::json!({"price": 100.0, "tax_rate": 0.08}));
        let result = expr.eval(&doc);
        assert_eq!(result.as_f64(), Some(108.0));
    }

    // A function call with an expression argument and a literal argument.
    #[test]
    fn round_function() {
        let (expr, deps) = parse_ok("ROUND(price * (1 + tax_rate), 2)");
        assert_eq!(deps, vec!["price", "tax_rate"]);
        let doc = Value::from(serde_json::json!({"price": 99.99, "tax_rate": 0.08}));
        let result = expr.eval(&doc);
        assert_eq!(result, Value::Float(107.99));
    }

    // Dependencies are sorted by name, not listed in order of appearance.
    #[test]
    fn concat_function() {
        let (expr, deps) = parse_ok("CONCAT(name, ' ', brand)");
        assert_eq!(deps, vec!["brand", "name"]);
        let doc = Value::from(serde_json::json!({"name": "Shoe", "brand": "Nike"}));
        assert_eq!(expr.eval(&doc), Value::String("Shoe Nike".into()));
    }

    // COALESCE falls through a null column to the empty-string literal.
    #[test]
    fn coalesce() {
        let (expr, _) = parse_ok("COALESCE(description, '')");
        let doc = Value::from(serde_json::json!({"description": null}));
        assert_eq!(expr.eval(&doc), Value::String("".into()));
    }

    // Searched CASE with an ELSE branch; both branches are evaluated against
    // separate documents.
    #[test]
    fn case_when() {
        let (expr, deps) =
            parse_ok("CASE WHEN discount > 0 THEN price * (1 - discount) ELSE price END");
        assert!(deps.contains(&"discount".to_string()));
        assert!(deps.contains(&"price".to_string()));

        let doc = Value::from(serde_json::json!({"price": 100.0, "discount": 0.2}));
        assert_eq!(expr.eval(&doc).as_f64(), Some(80.0));

        let doc2 = Value::from(serde_json::json!({"price": 100.0, "discount": 0}));
        assert_eq!(expr.eval(&doc2).as_f64(), Some(100.0));
    }

    // The next three tests check that names in `NON_DETERMINISTIC` are
    // rejected regardless of the case used in the input.
    #[test]
    fn rejects_now() {
        assert!(parse_generated_expr("NOW()").is_err());
    }

    #[test]
    fn rejects_random() {
        assert!(parse_generated_expr("RANDOM()").is_err());
    }

    #[test]
    fn rejects_uuid() {
        assert!(parse_generated_expr("UUID()").is_err());
    }

    // String literals as function arguments.
    #[test]
    fn string_literal() {
        let (expr, _) = parse_ok("CONCAT(name, ' - ', 'default')");
        let doc = Value::from(serde_json::json!({"name": "Product"}));
        assert_eq!(expr.eval(&doc), Value::String("Product - default".into()));
    }

    // The NULL keyword parses as a literal inside an argument list.
    #[test]
    fn null_literal() {
        let (expr, _) = parse_ok("COALESCE(x, NULL, 0)");
        let doc = Value::from(serde_json::json!({"x": null}));
        assert_eq!(expr.eval(&doc), Value::Integer(0));
    }

    // A function call nested inside parentheses inside another call. The
    // document has no `discount` key; the expected result equals `price`.
    #[test]
    fn nested_functions() {
        let (expr, _) = parse_ok("ROUND(price * (1 - COALESCE(discount, 0)), 2)");
        let doc = Value::from(serde_json::json!({"price": 49.99}));
        assert_eq!(expr.eval(&doc), Value::Float(49.99));
    }

    // 10 000 nested parentheses must produce an error from the depth guard
    // instead of exhausting the stack.
    #[test]
    fn deeply_nested_parentheses_return_error_not_stack_overflow() {
        let depth = 10_000;
        let input = format!("{}x{}", "(".repeat(depth), ")".repeat(depth));
        let result = parse_generated_expr(&input);
        assert!(
            result.is_err(),
            "parse_generated_expr must return Err for {depth}-deep nesting"
        );
    }

    // Multi-byte UTF-8 inside a string literal.
    #[test]
    fn cjk_string_in_concat() {
        let (expr, _) = parse_ok("CONCAT('你好', name)");
        let doc = Value::from(serde_json::json!({"name": "world"}));
        assert_eq!(expr.eval(&doc), Value::String("你好world".into()));
    }

    // A comparison operator followed by a multi-byte string literal.
    #[test]
    fn comparison_with_utf8_literal() {
        let (expr, deps) = parse_ok("name != '禁止'");
        assert_eq!(deps, vec!["name"]);
        let doc = Value::from(serde_json::json!({"name": "allowed"}));
        assert_eq!(expr.eval(&doc), Value::Bool(true));
    }
}