1use super::error::ExprParseError;
22use super::tokenizer::{Token, TokenKind, tokenize};
23use crate::expr::{BinaryOp, SqlExpr};
24use nodedb_types::Value;
25
26pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), ExprParseError> {
31 let tokens = tokenize(text)?;
32 let mut pos = 0;
33 let expr = parse_expr(&tokens, &mut pos, &mut 0)?;
34 if pos < tokens.len() {
35 return Err(ExprParseError::UnexpectedToken {
36 found: tokens[pos].text.clone(),
37 pos,
38 });
39 }
40
41 validate_deterministic(&expr)?;
43
44 let mut deps = Vec::new();
46 collect_columns(&expr, &mut deps);
47 deps.sort();
48 deps.dedup();
49
50 Ok((expr, deps))
51}
52
/// Maximum nesting depth the parser accepts.
///
/// Recursive constructs that nest (for example parenthesised sub-expressions
/// in `parse_primary`) increment the shared `depth` counter. They fail with
/// `ExprParseError::DepthLimitExceeded { max: MAX_EXPR_DEPTH }` once the new
/// level would exceed this value, so pathological input such as thousands of
/// `(` returns an error instead of exhausting the stack.
const MAX_EXPR_DEPTH: usize = 128;
57
58fn parse_expr(
59 tokens: &[Token],
60 pos: &mut usize,
61 depth: &mut usize,
62) -> Result<SqlExpr, ExprParseError> {
63 parse_or(tokens, pos, depth)
64}
65
66fn parse_or(
67 tokens: &[Token],
68 pos: &mut usize,
69 depth: &mut usize,
70) -> Result<SqlExpr, ExprParseError> {
71 let mut left = parse_and(tokens, pos, depth)?;
72 while peek_keyword(tokens, *pos, "OR") {
73 *pos += 1;
74 let right = parse_and(tokens, pos, depth)?;
75 left = SqlExpr::BinaryOp {
76 left: Box::new(left),
77 op: BinaryOp::Or,
78 right: Box::new(right),
79 };
80 }
81 Ok(left)
82}
83
84fn parse_and(
85 tokens: &[Token],
86 pos: &mut usize,
87 depth: &mut usize,
88) -> Result<SqlExpr, ExprParseError> {
89 let mut left = parse_comparison(tokens, pos, depth)?;
90 while peek_keyword(tokens, *pos, "AND") {
91 *pos += 1;
92 let right = parse_comparison(tokens, pos, depth)?;
93 left = SqlExpr::BinaryOp {
94 left: Box::new(left),
95 op: BinaryOp::And,
96 right: Box::new(right),
97 };
98 }
99 Ok(left)
100}
101
102fn parse_comparison(
103 tokens: &[Token],
104 pos: &mut usize,
105 depth: &mut usize,
106) -> Result<SqlExpr, ExprParseError> {
107 let left = parse_additive(tokens, pos, depth)?;
108 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
109 let op = match tokens[*pos].text.as_str() {
110 "=" => BinaryOp::Eq,
111 "!=" | "<>" => BinaryOp::NotEq,
112 "<" => BinaryOp::Lt,
113 "<=" => BinaryOp::LtEq,
114 ">" => BinaryOp::Gt,
115 ">=" => BinaryOp::GtEq,
116 _ => return Ok(left),
117 };
118 *pos += 1;
119 let right = parse_additive(tokens, pos, depth)?;
120 return Ok(SqlExpr::BinaryOp {
121 left: Box::new(left),
122 op,
123 right: Box::new(right),
124 });
125 }
126 Ok(left)
127}
128
129fn parse_additive(
130 tokens: &[Token],
131 pos: &mut usize,
132 depth: &mut usize,
133) -> Result<SqlExpr, ExprParseError> {
134 let mut left = parse_multiplicative(tokens, pos, depth)?;
135 while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
136 let op = match tokens[*pos].text.as_str() {
137 "+" => BinaryOp::Add,
138 "-" => BinaryOp::Sub,
139 "||" => BinaryOp::Concat,
140 _ => break,
141 };
142 *pos += 1;
143 let right = parse_multiplicative(tokens, pos, depth)?;
144 left = SqlExpr::BinaryOp {
145 left: Box::new(left),
146 op,
147 right: Box::new(right),
148 };
149 }
150 Ok(left)
151}
152
153fn parse_multiplicative(
154 tokens: &[Token],
155 pos: &mut usize,
156 depth: &mut usize,
157) -> Result<SqlExpr, ExprParseError> {
158 let mut left = parse_unary(tokens, pos, depth)?;
159 while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
160 let op = match tokens[*pos].text.as_str() {
161 "*" => BinaryOp::Mul,
162 "/" => BinaryOp::Div,
163 "%" => BinaryOp::Mod,
164 _ => break,
165 };
166 *pos += 1;
167 let right = parse_unary(tokens, pos, depth)?;
168 left = SqlExpr::BinaryOp {
169 left: Box::new(left),
170 op,
171 right: Box::new(right),
172 };
173 }
174 Ok(left)
175}
176
177fn parse_unary(
178 tokens: &[Token],
179 pos: &mut usize,
180 depth: &mut usize,
181) -> Result<SqlExpr, ExprParseError> {
182 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
183 *pos += 1;
184 let expr = parse_primary(tokens, pos, depth)?;
185 return Ok(SqlExpr::Negate(Box::new(expr)));
186 }
187 if peek_keyword(tokens, *pos, "NOT") {
188 *pos += 1;
189 let expr = parse_primary(tokens, pos, depth)?;
190 return Ok(SqlExpr::Negate(Box::new(expr)));
191 }
192 parse_primary(tokens, pos, depth)
193}
194
195fn parse_primary(
196 tokens: &[Token],
197 pos: &mut usize,
198 depth: &mut usize,
199) -> Result<SqlExpr, ExprParseError> {
200 if *pos >= tokens.len() {
201 return Err(ExprParseError::UnexpectedEof);
202 }
203
204 let token = &tokens[*pos];
205
206 match token.kind {
207 TokenKind::LParen => {
208 *depth += 1;
209 if *depth > MAX_EXPR_DEPTH {
210 return Err(ExprParseError::DepthLimitExceeded {
211 max: MAX_EXPR_DEPTH,
212 });
213 }
214 *pos += 1;
215 let expr = parse_expr(tokens, pos, depth)?;
216 *depth -= 1;
217 expect_token(tokens, pos, TokenKind::RParen, ")")?;
218 Ok(expr)
219 }
220
221 TokenKind::Number => {
222 *pos += 1;
223 if let Ok(i) = token.text.parse::<i64>() {
224 Ok(SqlExpr::Literal(Value::Integer(i)))
225 } else if let Ok(f) = token.text.parse::<f64>() {
226 Ok(SqlExpr::Literal(Value::Float(f)))
227 } else {
228 Err(ExprParseError::InvalidLiteral {
229 detail: format!("invalid number: '{}'", token.text),
230 })
231 }
232 }
233
234 TokenKind::StringLit => {
235 *pos += 1;
236 Ok(SqlExpr::Literal(Value::String(token.text.clone())))
237 }
238
239 TokenKind::Ident => {
240 let name = token.text.clone();
241 let upper = name.to_uppercase();
242 *pos += 1;
243
244 match upper.as_str() {
245 "NULL" => Ok(SqlExpr::Literal(Value::Null)),
246 "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))),
247 "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))),
248 "CASE" => parse_case(tokens, pos, depth),
249 "COALESCE" => {
250 let args = parse_arg_list(tokens, pos, depth)?;
251 Ok(SqlExpr::Coalesce(args))
252 }
253 _ => {
254 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
255 let args = parse_arg_list(tokens, pos, depth)?;
256 Ok(SqlExpr::Function {
257 name: name.to_lowercase(),
258 args,
259 })
260 } else {
261 Ok(SqlExpr::Column(name.to_lowercase()))
262 }
263 }
264 }
265 }
266
267 _ => Err(ExprParseError::UnexpectedToken {
268 found: token.text.clone(),
269 pos: *pos,
270 }),
271 }
272}
273
274fn parse_case(
275 tokens: &[Token],
276 pos: &mut usize,
277 depth: &mut usize,
278) -> Result<SqlExpr, ExprParseError> {
279 let mut when_thens = Vec::new();
280 let mut else_expr = None;
281
282 loop {
283 if peek_keyword(tokens, *pos, "WHEN") {
284 *pos += 1;
285 let cond = parse_expr(tokens, pos, depth)?;
286 expect_keyword(tokens, pos, "THEN")?;
287 let then = parse_expr(tokens, pos, depth)?;
288 when_thens.push((cond, then));
289 } else if peek_keyword(tokens, *pos, "ELSE") {
290 *pos += 1;
291 else_expr = Some(Box::new(parse_expr(tokens, pos, depth)?));
292 } else if peek_keyword(tokens, *pos, "END") {
293 *pos += 1;
294 break;
295 } else {
296 return Err(ExprParseError::UnexpectedToken {
297 found: tokens
298 .get(*pos)
299 .map_or("EOF".to_string(), |t| t.text.clone()),
300 pos: *pos,
301 });
302 }
303 }
304
305 if when_thens.is_empty() {
306 return Err(ExprParseError::InvalidLiteral {
307 detail: "CASE requires at least one WHEN clause".to_string(),
308 });
309 }
310
311 Ok(SqlExpr::Case {
312 operand: None,
313 when_thens,
314 else_expr,
315 })
316}
317
318fn parse_arg_list(
319 tokens: &[Token],
320 pos: &mut usize,
321 depth: &mut usize,
322) -> Result<Vec<SqlExpr>, ExprParseError> {
323 expect_token(tokens, pos, TokenKind::LParen, "(")?;
324 let mut args = Vec::new();
325 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
326 *pos += 1;
327 return Ok(args);
328 }
329 loop {
330 args.push(parse_expr(tokens, pos, depth)?);
331 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
332 *pos += 1;
333 } else {
334 break;
335 }
336 }
337 expect_token(tokens, pos, TokenKind::RParen, ")")?;
338 Ok(args)
339}
340
341fn peek_keyword(tokens: &[Token], pos: usize, keyword: &str) -> bool {
344 pos < tokens.len()
345 && tokens[pos].kind == TokenKind::Ident
346 && tokens[pos].text.eq_ignore_ascii_case(keyword)
347}
348
349fn expect_keyword(tokens: &[Token], pos: &mut usize, keyword: &str) -> Result<(), ExprParseError> {
350 if peek_keyword(tokens, *pos, keyword) {
351 *pos += 1;
352 Ok(())
353 } else {
354 let got = tokens
355 .get(*pos)
356 .map_or_else(|| "EOF".to_string(), |t| t.text.clone());
357 Err(ExprParseError::UnexpectedToken {
358 found: got,
359 pos: *pos,
360 })
361 }
362}
363
364fn expect_token(
365 tokens: &[Token],
366 pos: &mut usize,
367 kind: TokenKind,
368 _expected: &str,
369) -> Result<(), ExprParseError> {
370 if *pos < tokens.len() && tokens[*pos].kind == kind {
371 *pos += 1;
372 Ok(())
373 } else {
374 let got = tokens
375 .get(*pos)
376 .map_or_else(|| "EOF".to_string(), |t| t.text.clone());
377 Err(ExprParseError::UnexpectedToken {
378 found: got,
379 pos: *pos,
380 })
381 }
382}
383
/// Function names whose result can differ between calls with the same
/// arguments. `validate_deterministic` rejects them anywhere inside a
/// `GENERATED ALWAYS AS` expression.
///
/// Entries are lowercase and compared exactly against the function name,
/// which the parser lowercases. The list covers clock reads, randomness,
/// sequence access and ID generators.
const NON_DETERMINISTIC: &[&str] = &[
    // Clock / time-of-day.
    "now",
    "current_timestamp",
    "current_date",
    "current_time",
    "localtime",
    "localtimestamp",
    "clock_timestamp",
    "statement_timestamp",
    "transaction_timestamp",
    "timeofday",
    // Randomness.
    "random",
    "rand",
    "setseed",
    // Sequences.
    "nextval",
    "currval",
    "lastval",
    "setval",
    // Identifier generators.
    "uuid",
    "uuid_v4",
    "uuid_v7",
    "gen_random_uuid",
    "ulid",
    "cuid2",
    "nanoid",
];
399
400fn validate_deterministic(expr: &SqlExpr) -> Result<(), ExprParseError> {
401 match expr {
402 SqlExpr::Function { name, args } => {
403 if NON_DETERMINISTIC.contains(&name.as_str()) {
404 return Err(ExprParseError::InvalidLiteral {
405 detail: format!(
406 "non-deterministic function '{name}()' not allowed in GENERATED ALWAYS AS"
407 ),
408 });
409 }
410 for arg in args {
411 validate_deterministic(arg)?;
412 }
413 Ok(())
414 }
415 SqlExpr::BinaryOp { left, right, .. } => {
416 validate_deterministic(left)?;
417 validate_deterministic(right)
418 }
419 SqlExpr::Negate(inner) => validate_deterministic(inner),
420 SqlExpr::Coalesce(args) => {
421 for arg in args {
422 validate_deterministic(arg)?;
423 }
424 Ok(())
425 }
426 SqlExpr::Case {
427 operand,
428 when_thens,
429 else_expr,
430 } => {
431 if let Some(op) = operand {
432 validate_deterministic(op)?;
433 }
434 for (cond, then) in when_thens {
435 validate_deterministic(cond)?;
436 validate_deterministic(then)?;
437 }
438 if let Some(e) = else_expr {
439 validate_deterministic(e)?;
440 }
441 Ok(())
442 }
443 SqlExpr::Cast { expr, .. } => validate_deterministic(expr),
444 SqlExpr::NullIf(a, b) => {
445 validate_deterministic(a)?;
446 validate_deterministic(b)
447 }
448 SqlExpr::IsNull { expr, .. } => validate_deterministic(expr),
449 SqlExpr::Column(_)
450 | SqlExpr::Literal(_)
451 | SqlExpr::OldColumn(_)
452 | SqlExpr::ExcludedColumn(_) => Ok(()),
453 }
454}
455
456pub(crate) fn collect_columns(expr: &SqlExpr, deps: &mut Vec<String>) {
457 match expr {
458 SqlExpr::Column(name) => deps.push(name.clone()),
459 SqlExpr::BinaryOp { left, right, .. } => {
460 collect_columns(left, deps);
461 collect_columns(right, deps);
462 }
463 SqlExpr::Negate(inner) => collect_columns(inner, deps),
464 SqlExpr::Function { args, .. } => {
465 for arg in args {
466 collect_columns(arg, deps);
467 }
468 }
469 SqlExpr::Coalesce(args) => {
470 for arg in args {
471 collect_columns(arg, deps);
472 }
473 }
474 SqlExpr::Case {
475 operand,
476 when_thens,
477 else_expr,
478 } => {
479 if let Some(op) = operand {
480 collect_columns(op, deps);
481 }
482 for (cond, then) in when_thens {
483 collect_columns(cond, deps);
484 collect_columns(then, deps);
485 }
486 if let Some(e) = else_expr {
487 collect_columns(e, deps);
488 }
489 }
490 SqlExpr::Cast { expr, .. } => collect_columns(expr, deps),
491 SqlExpr::NullIf(a, b) => {
492 collect_columns(a, deps);
493 collect_columns(b, deps);
494 }
495 SqlExpr::IsNull { expr, .. } => collect_columns(expr, deps),
496 SqlExpr::Literal(_) | SqlExpr::OldColumn(_) | SqlExpr::ExcludedColumn(_) => {}
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    fn parse_ok(text: &str) -> (SqlExpr, Vec<String>) {
        parse_generated_expr(text).unwrap()
    }

    #[test]
    fn simple_arithmetic() {
        let (expr, deps) = parse_ok("price * (1 + tax_rate)");
        assert_eq!(deps, vec!["price", "tax_rate"]);
        let doc = Value::from(serde_json::json!({"price": 100.0, "tax_rate": 0.08}));
        let result = expr.eval(&doc);
        assert_eq!(result.as_f64(), Some(108.0));
    }

    #[test]
    fn round_function() {
        let (expr, deps) = parse_ok("ROUND(price * (1 + tax_rate), 2)");
        assert_eq!(deps, vec!["price", "tax_rate"]);
        let doc = Value::from(serde_json::json!({"price": 99.99, "tax_rate": 0.08}));
        let result = expr.eval(&doc);
        assert_eq!(result, Value::Float(107.99));
    }

    #[test]
    fn concat_function() {
        let (expr, deps) = parse_ok("CONCAT(name, ' ', brand)");
        assert_eq!(deps, vec!["brand", "name"]);
        let doc = Value::from(serde_json::json!({"name": "Shoe", "brand": "Nike"}));
        assert_eq!(expr.eval(&doc), Value::String("Shoe Nike".into()));
    }

    #[test]
    fn coalesce() {
        let (expr, _) = parse_ok("COALESCE(description, '')");
        let doc = Value::from(serde_json::json!({"description": null}));
        assert_eq!(expr.eval(&doc), Value::String("".into()));
    }

    #[test]
    fn case_when() {
        let (expr, deps) =
            parse_ok("CASE WHEN discount > 0 THEN price * (1 - discount) ELSE price END");
        assert!(deps.contains(&"discount".to_string()));
        assert!(deps.contains(&"price".to_string()));

        let doc = Value::from(serde_json::json!({"price": 100.0, "discount": 0.2}));
        assert_eq!(expr.eval(&doc).as_f64(), Some(80.0));

        let doc2 = Value::from(serde_json::json!({"price": 100.0, "discount": 0}));
        assert_eq!(expr.eval(&doc2).as_f64(), Some(100.0));
    }

    #[test]
    fn rejects_now() {
        assert!(parse_generated_expr("NOW()").is_err());
    }

    #[test]
    fn rejects_random() {
        assert!(parse_generated_expr("RANDOM()").is_err());
    }

    #[test]
    fn rejects_uuid() {
        assert!(parse_generated_expr("UUID()").is_err());
    }

    #[test]
    fn rejects_nested_non_deterministic_call() {
        // The determinism check must descend into function arguments.
        assert!(parse_generated_expr("ROUND(NOW(), 2)").is_err());
    }

    #[test]
    fn string_literal() {
        let (expr, _) = parse_ok("CONCAT(name, ' - ', 'default')");
        let doc = Value::from(serde_json::json!({"name": "Product"}));
        assert_eq!(expr.eval(&doc), Value::String("Product - default".into()));
    }

    #[test]
    fn null_literal() {
        let (expr, _) = parse_ok("COALESCE(x, NULL, 0)");
        let doc = Value::from(serde_json::json!({"x": null}));
        assert_eq!(expr.eval(&doc), Value::Integer(0));
    }

    #[test]
    fn nested_functions() {
        let (expr, _) = parse_ok("ROUND(price * (1 - COALESCE(discount, 0)), 2)");
        let doc = Value::from(serde_json::json!({"price": 49.99}));
        assert_eq!(expr.eval(&doc), Value::Float(49.99));
    }

    #[test]
    fn deeply_nested_parentheses_return_error_not_stack_overflow() {
        let depth = 10_000;
        let input = format!("{}x{}", "(".repeat(depth), ")".repeat(depth));
        let result = parse_generated_expr(&input);
        assert!(
            result.is_err(),
            "parse_generated_expr must return Err for {depth}-deep nesting"
        );
    }

    #[test]
    fn column_names_are_lowercased_and_deduplicated() {
        let (_, deps) = parse_ok("Price + price * PRICE");
        assert_eq!(deps, vec!["price"]);
    }

    #[test]
    fn not_prefix_is_accepted() {
        let (_, deps) = parse_ok("NOT active");
        assert_eq!(deps, vec!["active"]);
    }

    #[test]
    fn rejects_empty_input() {
        assert!(parse_generated_expr("").is_err());
    }

    #[test]
    fn rejects_trailing_tokens() {
        assert!(parse_generated_expr("price tax_rate").is_err());
    }

    #[test]
    fn rejects_unbalanced_parenthesis() {
        assert!(parse_generated_expr("(price + 1").is_err());
    }

    #[test]
    fn rejects_case_without_when() {
        assert!(parse_generated_expr("CASE ELSE price END").is_err());
    }

    #[test]
    fn cjk_string_in_concat() {
        let (expr, _) = parse_ok("CONCAT('你好', name)");
        let doc = Value::from(serde_json::json!({"name": "world"}));
        assert_eq!(expr.eval(&doc), Value::String("你好world".into()));
    }

    #[test]
    fn comparison_with_utf8_literal() {
        let (expr, deps) = parse_ok("name != '禁止'");
        assert_eq!(deps, vec!["name"]);
        let doc = Value::from(serde_json::json!({"name": "allowed"}));
        assert_eq!(expr.eval(&doc), Value::Bool(true));
    }
}