1use super::expr::{BinaryOp, SqlExpr};
20use nodedb_types::Value;
21
22pub fn parse_generated_expr(text: &str) -> Result<(SqlExpr, Vec<String>), String> {
27 let tokens = tokenize(text)?;
28 let mut pos = 0;
29 let expr = parse_expr(&tokens, &mut pos)?;
30 if pos < tokens.len() {
31 return Err(format!(
32 "unexpected token after expression: '{}'",
33 tokens[pos].text
34 ));
35 }
36
37 validate_deterministic(&expr)?;
39
40 let mut deps = Vec::new();
42 collect_columns(&expr, &mut deps);
43 deps.sort();
44 deps.dedup();
45
46 Ok((expr, deps))
47}
48
/// A single lexical token produced by [`tokenize`].
#[derive(Debug, Clone, PartialEq, Eq)]
struct Token {
    /// Token text. For string literals this is the unquoted, unescaped content.
    text: String,
    /// Lexical category of the token.
    kind: TokenKind,
}

/// Lexical categories recognized by [`tokenize`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TokenKind {
    /// Identifier or keyword: ASCII letters, digits and `_`, not starting with a digit.
    Ident,
    /// Numeric literal made of ASCII digits and `.`.
    Number,
    /// Single-quoted string literal.
    StringLit,
    /// `(`
    LParen,
    /// `)`
    RParen,
    /// `,`
    Comma,
    /// Arithmetic, comparison or concatenation operator.
    Op,
}

/// Splits `input` into tokens.
///
/// Recognizes parentheses, commas, the two-character operators `<=`, `>=`,
/// `!=`, `<>` and `||`, the single-character operators `+ - * / % = < >`,
/// single-quoted string literals (with `''` as an escaped quote), numbers and
/// ASCII identifiers. ASCII whitespace is skipped.
///
/// # Errors
///
/// Returns a message when a string literal is not closed or when a character
/// that starts no token is found. Non-ASCII text is only accepted inside string
/// literals.
fn tokenize(input: &str) -> Result<Vec<Token>, String> {
    let bytes = input.as_bytes();
    let mut tokens = Vec::new();
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];

        if b.is_ascii_whitespace() {
            i += 1;
            continue;
        }

        // Single-character punctuation.
        let punct = match b {
            b'(' => Some(TokenKind::LParen),
            b')' => Some(TokenKind::RParen),
            b',' => Some(TokenKind::Comma),
            _ => None,
        };
        if let Some(kind) = punct {
            tokens.push(Token {
                text: (b as char).to_string(),
                kind,
            });
            i += 1;
            continue;
        }

        // Two-character operators. Compare raw bytes: slicing the `str` at
        // `i + 2` would panic when that offset falls inside a multi-byte
        // character.
        let two_char_op = match bytes.get(i..i + 2) {
            Some([b'<', b'=']) => Some("<="),
            Some([b'>', b'=']) => Some(">="),
            Some([b'!', b'=']) => Some("!="),
            Some([b'<', b'>']) => Some("<>"),
            Some([b'|', b'|']) => Some("||"),
            _ => None,
        };
        if let Some(op) = two_char_op {
            tokens.push(Token {
                text: op.into(),
                kind: TokenKind::Op,
            });
            i += 2;
            continue;
        }

        // Single-character operators.
        if matches!(b, b'+' | b'-' | b'*' | b'/' | b'%' | b'=' | b'<' | b'>') {
            tokens.push(Token {
                text: (b as char).to_string(),
                kind: TokenKind::Op,
            });
            i += 1;
            continue;
        }

        // String literal. The content is copied as `str` segments between
        // quotes so multi-byte UTF-8 characters survive intact. Quotes are
        // ASCII, so every segment boundary is a valid char boundary.
        if b == b'\'' {
            let mut s = String::new();
            i += 1;
            let mut segment_start = i;
            let mut closed = false;
            while i < bytes.len() {
                if bytes[i] == b'\'' {
                    s.push_str(&input[segment_start..i]);
                    if bytes.get(i + 1) == Some(&b'\'') {
                        // `''` is an escaped single quote.
                        s.push('\'');
                        i += 2;
                        segment_start = i;
                        continue;
                    }
                    i += 1;
                    closed = true;
                    break;
                }
                i += 1;
            }
            if !closed {
                return Err("unterminated string literal".into());
            }
            tokens.push(Token {
                text: s,
                kind: TokenKind::StringLit,
            });
            continue;
        }

        // Number: digits and dots. Validity (e.g. `1.2.3`) is checked when the
        // token text is later parsed as a number.
        if b.is_ascii_digit() || (b == b'.' && bytes.get(i + 1).map_or(false, u8::is_ascii_digit))
        {
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_digit() || bytes[i] == b'.') {
                i += 1;
            }
            tokens.push(Token {
                text: input[start..i].to_string(),
                kind: TokenKind::Number,
            });
            continue;
        }

        // Identifier or keyword.
        if b.is_ascii_alphabetic() || b == b'_' {
            let start = i;
            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
                i += 1;
            }
            tokens.push(Token {
                text: input[start..i].to_string(),
                kind: TokenKind::Ident,
            });
            continue;
        }

        // Report the full character rather than its first byte so non-ASCII
        // input shows up correctly in the message.
        let ch = input
            .get(i..)
            .and_then(|rest| rest.chars().next())
            .unwrap_or(b as char);
        return Err(format!("unexpected character: '{ch}'"));
    }

    Ok(tokens)
}
195
196fn parse_expr(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
200 parse_or(tokens, pos)
201}
202
203fn parse_or(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
204 let mut left = parse_and(tokens, pos)?;
205 while peek_keyword(tokens, *pos, "OR") {
206 *pos += 1;
207 let right = parse_and(tokens, pos)?;
208 left = SqlExpr::BinaryOp {
209 left: Box::new(left),
210 op: BinaryOp::Or,
211 right: Box::new(right),
212 };
213 }
214 Ok(left)
215}
216
217fn parse_and(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
218 let mut left = parse_comparison(tokens, pos)?;
219 while peek_keyword(tokens, *pos, "AND") {
220 *pos += 1;
221 let right = parse_comparison(tokens, pos)?;
222 left = SqlExpr::BinaryOp {
223 left: Box::new(left),
224 op: BinaryOp::And,
225 right: Box::new(right),
226 };
227 }
228 Ok(left)
229}
230
231fn parse_comparison(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
232 let left = parse_additive(tokens, pos)?;
233 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
234 let op = match tokens[*pos].text.as_str() {
235 "=" => BinaryOp::Eq,
236 "!=" | "<>" => BinaryOp::NotEq,
237 "<" => BinaryOp::Lt,
238 "<=" => BinaryOp::LtEq,
239 ">" => BinaryOp::Gt,
240 ">=" => BinaryOp::GtEq,
241 _ => return Ok(left),
242 };
243 *pos += 1;
244 let right = parse_additive(tokens, pos)?;
245 return Ok(SqlExpr::BinaryOp {
246 left: Box::new(left),
247 op,
248 right: Box::new(right),
249 });
250 }
251 Ok(left)
252}
253
254fn parse_additive(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
255 let mut left = parse_multiplicative(tokens, pos)?;
256 while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
257 let op = match tokens[*pos].text.as_str() {
258 "+" => BinaryOp::Add,
259 "-" => BinaryOp::Sub,
260 "||" => BinaryOp::Concat,
261 _ => break,
262 };
263 *pos += 1;
264 let right = parse_multiplicative(tokens, pos)?;
265 left = SqlExpr::BinaryOp {
266 left: Box::new(left),
267 op,
268 right: Box::new(right),
269 };
270 }
271 Ok(left)
272}
273
274fn parse_multiplicative(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
275 let mut left = parse_unary(tokens, pos)?;
276 while *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op {
277 let op = match tokens[*pos].text.as_str() {
278 "*" => BinaryOp::Mul,
279 "/" => BinaryOp::Div,
280 "%" => BinaryOp::Mod,
281 _ => break,
282 };
283 *pos += 1;
284 let right = parse_unary(tokens, pos)?;
285 left = SqlExpr::BinaryOp {
286 left: Box::new(left),
287 op,
288 right: Box::new(right),
289 };
290 }
291 Ok(left)
292}
293
294fn parse_unary(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
295 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Op && tokens[*pos].text == "-" {
297 *pos += 1;
298 let expr = parse_primary(tokens, pos)?;
299 return Ok(SqlExpr::Negate(Box::new(expr)));
300 }
301 if peek_keyword(tokens, *pos, "NOT") {
303 *pos += 1;
304 let expr = parse_primary(tokens, pos)?;
305 return Ok(SqlExpr::Negate(Box::new(expr)));
306 }
307 parse_primary(tokens, pos)
308}
309
310fn parse_primary(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
311 if *pos >= tokens.len() {
312 return Err("unexpected end of expression".into());
313 }
314
315 let token = &tokens[*pos];
316
317 match token.kind {
318 TokenKind::LParen => {
320 *pos += 1;
321 let expr = parse_expr(tokens, pos)?;
322 expect_token(tokens, pos, TokenKind::RParen, ")")?;
323 Ok(expr)
324 }
325
326 TokenKind::Number => {
328 *pos += 1;
329 if let Ok(i) = token.text.parse::<i64>() {
330 Ok(SqlExpr::Literal(Value::Integer(i)))
331 } else if let Ok(f) = token.text.parse::<f64>() {
332 Ok(SqlExpr::Literal(Value::Float(f)))
333 } else {
334 Err(format!("invalid number: '{}'", token.text))
335 }
336 }
337
338 TokenKind::StringLit => {
340 *pos += 1;
341 Ok(SqlExpr::Literal(Value::String(token.text.clone())))
342 }
343
344 TokenKind::Ident => {
346 let name = token.text.clone();
347 let upper = name.to_uppercase();
348 *pos += 1;
349
350 match upper.as_str() {
351 "NULL" => Ok(SqlExpr::Literal(Value::Null)),
352 "TRUE" => Ok(SqlExpr::Literal(Value::Bool(true))),
353 "FALSE" => Ok(SqlExpr::Literal(Value::Bool(false))),
354 "CASE" => parse_case(tokens, pos),
355 "COALESCE" => {
356 let args = parse_arg_list(tokens, pos)?;
357 Ok(SqlExpr::Coalesce(args))
358 }
359 _ => {
360 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::LParen {
362 let args = parse_arg_list(tokens, pos)?;
363 Ok(SqlExpr::Function {
364 name: name.to_lowercase(),
365 args,
366 })
367 } else {
368 Ok(SqlExpr::Column(name.to_lowercase()))
370 }
371 }
372 }
373 }
374
375 _ => Err(format!("unexpected token: '{}'", token.text)),
376 }
377}
378
379fn parse_case(tokens: &[Token], pos: &mut usize) -> Result<SqlExpr, String> {
381 let mut when_thens = Vec::new();
382 let mut else_expr = None;
383
384 loop {
385 if peek_keyword(tokens, *pos, "WHEN") {
386 *pos += 1;
387 let cond = parse_expr(tokens, pos)?;
388 expect_keyword(tokens, pos, "THEN")?;
389 let then = parse_expr(tokens, pos)?;
390 when_thens.push((cond, then));
391 } else if peek_keyword(tokens, *pos, "ELSE") {
392 *pos += 1;
393 else_expr = Some(Box::new(parse_expr(tokens, pos)?));
394 } else if peek_keyword(tokens, *pos, "END") {
395 *pos += 1;
396 break;
397 } else {
398 return Err("expected WHEN, ELSE, or END in CASE expression".into());
399 }
400 }
401
402 if when_thens.is_empty() {
403 return Err("CASE requires at least one WHEN clause".into());
404 }
405
406 Ok(SqlExpr::Case {
407 operand: None,
408 when_thens,
409 else_expr,
410 })
411}
412
413fn parse_arg_list(tokens: &[Token], pos: &mut usize) -> Result<Vec<SqlExpr>, String> {
415 expect_token(tokens, pos, TokenKind::LParen, "(")?;
416 let mut args = Vec::new();
417 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::RParen {
418 *pos += 1;
419 return Ok(args);
420 }
421 loop {
422 args.push(parse_expr(tokens, pos)?);
423 if *pos < tokens.len() && tokens[*pos].kind == TokenKind::Comma {
424 *pos += 1;
425 } else {
426 break;
427 }
428 }
429 expect_token(tokens, pos, TokenKind::RParen, ")")?;
430 Ok(args)
431}
432
433fn peek_keyword(tokens: &[Token], pos: usize, keyword: &str) -> bool {
436 pos < tokens.len()
437 && tokens[pos].kind == TokenKind::Ident
438 && tokens[pos].text.eq_ignore_ascii_case(keyword)
439}
440
441fn expect_keyword(tokens: &[Token], pos: &mut usize, keyword: &str) -> Result<(), String> {
442 if peek_keyword(tokens, *pos, keyword) {
443 *pos += 1;
444 Ok(())
445 } else {
446 let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
447 Err(format!("expected '{keyword}', got '{got}'"))
448 }
449}
450
451fn expect_token(
452 tokens: &[Token],
453 pos: &mut usize,
454 kind: TokenKind,
455 expected: &str,
456) -> Result<(), String> {
457 if *pos < tokens.len() && tokens[*pos].kind == kind {
458 *pos += 1;
459 Ok(())
460 } else {
461 let got = tokens.get(*pos).map_or("EOF", |t| &t.text);
462 Err(format!("expected '{expected}', got '{got}'"))
463 }
464}
465
/// Lowercase names of functions rejected by `validate_deterministic`.
///
/// Lookups are exact string matches, so entries must be lowercase to match the
/// lowercased function names produced by the parser.
const NON_DETERMINISTIC: &[&str] = &[
    // Clock, randomness and sequences.
    "now",
    "current_timestamp",
    "random",
    "nextval",
    // Identifier generators.
    "uuid",
    "uuid_v4",
    "uuid_v7",
    "gen_random_uuid",
    "ulid",
    "cuid2",
    "nanoid",
];
482
483fn validate_deterministic(expr: &SqlExpr) -> Result<(), String> {
484 match expr {
485 SqlExpr::Function { name, args } => {
486 if NON_DETERMINISTIC.contains(&name.as_str()) {
487 return Err(format!(
488 "non-deterministic function '{name}()' not allowed in GENERATED ALWAYS AS"
489 ));
490 }
491 for arg in args {
492 validate_deterministic(arg)?;
493 }
494 Ok(())
495 }
496 SqlExpr::BinaryOp { left, right, .. } => {
497 validate_deterministic(left)?;
498 validate_deterministic(right)
499 }
500 SqlExpr::Negate(inner) => validate_deterministic(inner),
501 SqlExpr::Coalesce(args) => {
502 for arg in args {
503 validate_deterministic(arg)?;
504 }
505 Ok(())
506 }
507 SqlExpr::Case {
508 operand,
509 when_thens,
510 else_expr,
511 } => {
512 if let Some(op) = operand {
513 validate_deterministic(op)?;
514 }
515 for (cond, then) in when_thens {
516 validate_deterministic(cond)?;
517 validate_deterministic(then)?;
518 }
519 if let Some(e) = else_expr {
520 validate_deterministic(e)?;
521 }
522 Ok(())
523 }
524 SqlExpr::Cast { expr, .. } => validate_deterministic(expr),
525 SqlExpr::NullIf(a, b) => {
526 validate_deterministic(a)?;
527 validate_deterministic(b)
528 }
529 SqlExpr::IsNull { expr, .. } => validate_deterministic(expr),
530 SqlExpr::Column(_) | SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => Ok(()),
531 }
532}
533
534fn collect_columns(expr: &SqlExpr, deps: &mut Vec<String>) {
535 match expr {
536 SqlExpr::Column(name) => deps.push(name.clone()),
537 SqlExpr::BinaryOp { left, right, .. } => {
538 collect_columns(left, deps);
539 collect_columns(right, deps);
540 }
541 SqlExpr::Negate(inner) => collect_columns(inner, deps),
542 SqlExpr::Function { args, .. } => {
543 for arg in args {
544 collect_columns(arg, deps);
545 }
546 }
547 SqlExpr::Coalesce(args) => {
548 for arg in args {
549 collect_columns(arg, deps);
550 }
551 }
552 SqlExpr::Case {
553 operand,
554 when_thens,
555 else_expr,
556 } => {
557 if let Some(op) = operand {
558 collect_columns(op, deps);
559 }
560 for (cond, then) in when_thens {
561 collect_columns(cond, deps);
562 collect_columns(then, deps);
563 }
564 if let Some(e) = else_expr {
565 collect_columns(e, deps);
566 }
567 }
568 SqlExpr::Cast { expr, .. } => collect_columns(expr, deps),
569 SqlExpr::NullIf(a, b) => {
570 collect_columns(a, deps);
571 collect_columns(b, deps);
572 }
573 SqlExpr::IsNull { expr, .. } => collect_columns(expr, deps),
574 SqlExpr::Literal(_) | SqlExpr::OldColumn(_) => {}
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use nodedb_types::Value;

    /// Parses `text`, panicking if the expression is rejected.
    fn parse_ok(text: &str) -> (SqlExpr, Vec<String>) {
        parse_generated_expr(text).unwrap()
    }

    /// Builds the document an expression is evaluated against.
    fn doc(json: serde_json::Value) -> Value {
        Value::from(json)
    }

    #[test]
    fn simple_arithmetic() {
        let (expr, deps) = parse_ok("price * (1 + tax_rate)");
        assert_eq!(deps, vec!["price", "tax_rate"]);

        let row = doc(serde_json::json!({"price": 100.0, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&row).as_f64(), Some(108.0));
    }

    #[test]
    fn round_function() {
        let (expr, deps) = parse_ok("ROUND(price * (1 + tax_rate), 2)");
        assert_eq!(deps, vec!["price", "tax_rate"]);

        let row = doc(serde_json::json!({"price": 99.99, "tax_rate": 0.08}));
        assert_eq!(expr.eval(&row), Value::Float(107.99));
    }

    #[test]
    fn concat_function() {
        let (expr, deps) = parse_ok("CONCAT(name, ' ', brand)");
        // Dependencies come back sorted, not in source order.
        assert_eq!(deps, vec!["brand", "name"]);

        let row = doc(serde_json::json!({"name": "Shoe", "brand": "Nike"}));
        assert_eq!(expr.eval(&row), Value::String("Shoe Nike".into()));
    }

    #[test]
    fn coalesce() {
        let (expr, _) = parse_ok("COALESCE(description, '')");
        let row = doc(serde_json::json!({"description": null}));
        assert_eq!(expr.eval(&row), Value::String("".into()));
    }

    #[test]
    fn case_when() {
        let (expr, deps) =
            parse_ok("CASE WHEN discount > 0 THEN price * (1 - discount) ELSE price END");
        assert!(deps.contains(&"discount".to_string()));
        assert!(deps.contains(&"price".to_string()));

        // The WHEN branch is taken when there is a discount.
        let discounted = doc(serde_json::json!({"price": 100.0, "discount": 0.2}));
        assert_eq!(expr.eval(&discounted).as_f64(), Some(80.0));

        // The ELSE branch is taken when there is none.
        let full_price = doc(serde_json::json!({"price": 100.0, "discount": 0}));
        assert_eq!(expr.eval(&full_price).as_f64(), Some(100.0));
    }

    #[test]
    fn rejects_now() {
        assert!(parse_generated_expr("NOW()").is_err());
    }

    #[test]
    fn rejects_random() {
        assert!(parse_generated_expr("RANDOM()").is_err());
    }

    #[test]
    fn rejects_uuid() {
        assert!(parse_generated_expr("UUID()").is_err());
    }

    #[test]
    fn string_literal() {
        let (expr, _) = parse_ok("CONCAT(name, ' - ', 'default')");
        let row = doc(serde_json::json!({"name": "Product"}));
        assert_eq!(expr.eval(&row), Value::String("Product - default".into()));
    }

    #[test]
    fn null_literal() {
        let (expr, _) = parse_ok("COALESCE(x, NULL, 0)");
        let row = doc(serde_json::json!({"x": null}));
        assert_eq!(expr.eval(&row), Value::Integer(0));
    }

    #[test]
    fn nested_functions() {
        let (expr, _) = parse_ok("ROUND(price * (1 - COALESCE(discount, 0)), 2)");
        // `discount` is absent from the document, so COALESCE supplies 0.
        let row = doc(serde_json::json!({"price": 49.99}));
        assert_eq!(expr.eval(&row), Value::Float(49.99));
    }
}