Skip to main content

nodedb_query/
expr.rs

1//! SqlExpr AST definition and core evaluation.
2
3use crate::value_ops::{
4    coerced_eq, compare_values, is_truthy, to_value_number, value_to_display_string, value_to_f64,
5};
6use nodedb_types::Value;
7
/// A serializable SQL expression that can be evaluated against a document.
///
/// Evaluate with [`SqlExpr::eval`] (single document) or
/// [`SqlExpr::eval_with_old`] (NEW/OLD document pair, for TRANSITION CHECK).
/// Besides the serde derives, this type has hand-written `zerompk`
/// MessagePack impls in this module (each variant encoded as `[tag, fields...]`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SqlExpr {
    /// Column reference: extract field value from the document.
    /// A field missing from the document evaluates to `Null`.
    Column(String),
    /// Literal value; evaluation returns a clone of it.
    Literal(Value),
    /// Binary operation: left op right.
    /// Both operands are always evaluated (no short-circuiting, even for
    /// `And`/`Or`) before `op` is applied.
    BinaryOp {
        left: Box<SqlExpr>,
        op: BinaryOp,
        right: Box<SqlExpr>,
    },
    /// Unary negation: -expr or NOT expr.
    /// A boolean operand is logically inverted; any other operand is coerced
    /// with `value_to_f64(_, false)` and arithmetically negated, or yields
    /// `Null` if it is not numeric.
    Negate(Box<SqlExpr>),
    /// Scalar function call. All arguments are evaluated first, then the call
    /// is dispatched by `name` to `crate::functions::eval_function`.
    Function { name: String, args: Vec<SqlExpr> },
    /// CAST(expr AS type), delegated to `crate::cast::eval_cast`.
    Cast {
        expr: Box<SqlExpr>,
        to_type: CastType,
    },
    /// CASE WHEN cond1 THEN val1 ... ELSE default END.
    /// With an `operand` (simple CASE), each WHEN value is compared to it via
    /// `coerced_eq`; without one (searched CASE), each WHEN is tested with
    /// `is_truthy`. The first match wins; if none match, `else_expr` is used,
    /// or `Null` when it is absent.
    Case {
        operand: Option<Box<SqlExpr>>,
        when_thens: Vec<(SqlExpr, SqlExpr)>,
        else_expr: Option<Box<SqlExpr>>,
    },
    /// COALESCE(expr1, expr2, ...): first non-null value, evaluated left to
    /// right and stopping at the first non-null; `Null` if all are null.
    Coalesce(Vec<SqlExpr>),
    /// NULLIF(expr1, expr2): returns NULL if expr1 = expr2, else expr1.
    /// Equality is checked with `coerced_eq`.
    NullIf(Box<SqlExpr>, Box<SqlExpr>),
    /// IS NULL / IS NOT NULL (`negated = true` means IS NOT NULL).
    /// Always evaluates to a `Bool`.
    IsNull { expr: Box<SqlExpr>, negated: bool },
    /// OLD column reference: extract field value from the pre-update document.
    /// Used in TRANSITION CHECK predicates. Resolves against the OLD row
    /// when evaluated via `eval_with_old()`. Returns NULL in normal `eval()`.
    OldColumn(String),
}
47
48/// Binary operators.
49#[derive(
50    Debug,
51    Clone,
52    Copy,
53    serde::Serialize,
54    serde::Deserialize,
55    zerompk::ToMessagePack,
56    zerompk::FromMessagePack,
57)]
58#[msgpack(c_enum)]
59pub enum BinaryOp {
60    Add,
61    Sub,
62    Mul,
63    Div,
64    Mod,
65    Eq,
66    NotEq,
67    Gt,
68    GtEq,
69    Lt,
70    LtEq,
71    And,
72    Or,
73    Concat,
74}
75
76/// Target types for CAST.
77#[derive(
78    Debug,
79    Clone,
80    serde::Serialize,
81    serde::Deserialize,
82    zerompk::ToMessagePack,
83    zerompk::FromMessagePack,
84)]
85#[msgpack(c_enum)]
86pub enum CastType {
87    Int,
88    Float,
89    String,
90    Bool,
91}
92
/// A computed projection column: alias + expression.
///
/// Pairs the output column name with the expression that produces its value.
/// The derived `zerompk` impls rely on the hand-written `SqlExpr` MessagePack
/// impls in this module for the `expr` field.
#[derive(
    Debug,
    Clone,
    serde::Serialize,
    serde::Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct ComputedColumn {
    /// Name under which the computed value is exposed.
    pub alias: String,
    /// Expression whose evaluation yields the column's value.
    pub expr: SqlExpr,
}
106
107// ─── Manual zerompk impls for SqlExpr ────────────────────────────────────────
108//
109// SqlExpr contains `nodedb_types::Value` (in the Literal variant) which implements
110// `zerompk::ToMessagePack` and `zerompk::FromMessagePack` natively.
111//
112// Encoding format: each variant is an array `[tag_u8, field1, field2, ...]`.
113// Tags: Column=0, Literal=1, BinaryOp=2, Negate=3, Function=4, Cast=5,
114//       Case=6, Coalesce=7, NullIf=8, IsNull=9, OldColumn=10.
115
116impl zerompk::ToMessagePack for SqlExpr {
117    fn write<W: zerompk::Write>(&self, writer: &mut W) -> zerompk::Result<()> {
118        match self {
119            SqlExpr::Column(s) => {
120                writer.write_array_len(2)?;
121                writer.write_u8(0)?;
122                writer.write_string(s)
123            }
124            SqlExpr::Literal(v) => {
125                writer.write_array_len(2)?;
126                writer.write_u8(1)?;
127                v.write(writer)
128            }
129            SqlExpr::BinaryOp { left, op, right } => {
130                writer.write_array_len(4)?;
131                writer.write_u8(2)?;
132                left.write(writer)?;
133                op.write(writer)?;
134                right.write(writer)
135            }
136            SqlExpr::Negate(inner) => {
137                writer.write_array_len(2)?;
138                writer.write_u8(3)?;
139                inner.write(writer)
140            }
141            SqlExpr::Function { name, args } => {
142                writer.write_array_len(3)?;
143                writer.write_u8(4)?;
144                writer.write_string(name)?;
145                args.write(writer)
146            }
147            SqlExpr::Cast { expr, to_type } => {
148                writer.write_array_len(3)?;
149                writer.write_u8(5)?;
150                expr.write(writer)?;
151                to_type.write(writer)
152            }
153            SqlExpr::Case {
154                operand,
155                when_thens,
156                else_expr,
157            } => {
158                writer.write_array_len(4)?;
159                writer.write_u8(6)?;
160                operand.write(writer)?;
161                // Encode when_thens as array of 2-element arrays.
162                writer.write_array_len(when_thens.len())?;
163                for (cond, val) in when_thens {
164                    writer.write_array_len(2)?;
165                    cond.write(writer)?;
166                    val.write(writer)?;
167                }
168                else_expr.write(writer)
169            }
170            SqlExpr::Coalesce(exprs) => {
171                writer.write_array_len(2)?;
172                writer.write_u8(7)?;
173                exprs.write(writer)
174            }
175            SqlExpr::NullIf(e1, e2) => {
176                writer.write_array_len(3)?;
177                writer.write_u8(8)?;
178                e1.write(writer)?;
179                e2.write(writer)
180            }
181            SqlExpr::IsNull { expr, negated } => {
182                writer.write_array_len(3)?;
183                writer.write_u8(9)?;
184                expr.write(writer)?;
185                writer.write_boolean(*negated)
186            }
187            SqlExpr::OldColumn(s) => {
188                writer.write_array_len(2)?;
189                writer.write_u8(10)?;
190                writer.write_string(s)
191            }
192        }
193    }
194}
195
196impl<'a> zerompk::FromMessagePack<'a> for SqlExpr {
197    fn read<R: zerompk::Read<'a>>(reader: &mut R) -> zerompk::Result<Self> {
198        let len = reader.read_array_len()?;
199        if len == 0 {
200            return Err(zerompk::Error::ArrayLengthMismatch {
201                expected: 1,
202                actual: 0,
203            });
204        }
205        let tag = reader.read_u8()?;
206        match tag {
207            0 => {
208                // Column(String)
209                Ok(SqlExpr::Column(reader.read_string()?.into_owned()))
210            }
211            1 => {
212                // Literal(Value)
213                let v = Value::read(reader)?;
214                Ok(SqlExpr::Literal(v))
215            }
216            2 => {
217                // BinaryOp { left, op, right }
218                let left = SqlExpr::read(reader)?;
219                let op = BinaryOp::read(reader)?;
220                let right = SqlExpr::read(reader)?;
221                Ok(SqlExpr::BinaryOp {
222                    left: Box::new(left),
223                    op,
224                    right: Box::new(right),
225                })
226            }
227            3 => {
228                // Negate(Box<SqlExpr>)
229                let inner = SqlExpr::read(reader)?;
230                Ok(SqlExpr::Negate(Box::new(inner)))
231            }
232            4 => {
233                // Function { name, args }
234                let name = reader.read_string()?.into_owned();
235                let args = Vec::<SqlExpr>::read(reader)?;
236                Ok(SqlExpr::Function { name, args })
237            }
238            5 => {
239                // Cast { expr, to_type }
240                let expr = SqlExpr::read(reader)?;
241                let to_type = CastType::read(reader)?;
242                Ok(SqlExpr::Cast {
243                    expr: Box::new(expr),
244                    to_type,
245                })
246            }
247            6 => {
248                // Case { operand, when_thens, else_expr }
249                let operand = Option::<Box<SqlExpr>>::read(reader)?;
250                let wt_len = reader.read_array_len()?;
251                let mut when_thens = Vec::with_capacity(wt_len);
252                for _ in 0..wt_len {
253                    let pair_len = reader.read_array_len()?;
254                    if pair_len != 2 {
255                        return Err(zerompk::Error::ArrayLengthMismatch {
256                            expected: 2,
257                            actual: pair_len,
258                        });
259                    }
260                    let cond = SqlExpr::read(reader)?;
261                    let val = SqlExpr::read(reader)?;
262                    when_thens.push((cond, val));
263                }
264                let else_expr = Option::<Box<SqlExpr>>::read(reader)?;
265                Ok(SqlExpr::Case {
266                    operand,
267                    when_thens,
268                    else_expr,
269                })
270            }
271            7 => {
272                // Coalesce(Vec<SqlExpr>)
273                let exprs = Vec::<SqlExpr>::read(reader)?;
274                Ok(SqlExpr::Coalesce(exprs))
275            }
276            8 => {
277                // NullIf(Box<SqlExpr>, Box<SqlExpr>)
278                let e1 = SqlExpr::read(reader)?;
279                let e2 = SqlExpr::read(reader)?;
280                Ok(SqlExpr::NullIf(Box::new(e1), Box::new(e2)))
281            }
282            9 => {
283                // IsNull { expr, negated }
284                let expr = SqlExpr::read(reader)?;
285                let negated = reader.read_boolean()?;
286                Ok(SqlExpr::IsNull {
287                    expr: Box::new(expr),
288                    negated,
289                })
290            }
291            10 => {
292                // OldColumn(String)
293                Ok(SqlExpr::OldColumn(reader.read_string()?.into_owned()))
294            }
295            _ => Err(zerompk::Error::InvalidMarker(tag)),
296        }
297    }
298}
299
300impl SqlExpr {
301    /// Evaluate this expression against a document.
302    ///
303    /// Returns a `Value`. Column references look up fields in the document.
304    /// Missing fields return `Null`. Arithmetic on non-numeric values returns `Null`.
305    pub fn eval(&self, doc: &Value) -> Value {
306        match self {
307            SqlExpr::Column(name) => doc.get(name).cloned().unwrap_or(Value::Null),
308
309            SqlExpr::Literal(v) => v.clone(),
310
311            SqlExpr::BinaryOp { left, op, right } => {
312                let l = left.eval(doc);
313                let r = right.eval(doc);
314                eval_binary_op(&l, *op, &r)
315            }
316
317            SqlExpr::Negate(inner) => {
318                let v = inner.eval(doc);
319                if let Some(b) = v.as_bool() {
320                    Value::Bool(!b)
321                } else {
322                    match value_to_f64(&v, false) {
323                        Some(n) => to_value_number(-n),
324                        None => Value::Null,
325                    }
326                }
327            }
328
329            SqlExpr::Function { name, args } => {
330                let evaluated: Vec<Value> = args.iter().map(|a| a.eval(doc)).collect();
331                crate::functions::eval_function(name, &evaluated)
332            }
333
334            SqlExpr::Cast { expr, to_type } => {
335                let v = expr.eval(doc);
336                crate::cast::eval_cast(&v, to_type)
337            }
338
339            SqlExpr::Case {
340                operand,
341                when_thens,
342                else_expr,
343            } => {
344                let op_val = operand.as_ref().map(|e| e.eval(doc));
345                for (when_expr, then_expr) in when_thens {
346                    let when_val = when_expr.eval(doc);
347                    let matches = match &op_val {
348                        Some(ov) => coerced_eq(ov, &when_val),
349                        None => is_truthy(&when_val),
350                    };
351                    if matches {
352                        return then_expr.eval(doc);
353                    }
354                }
355                match else_expr {
356                    Some(e) => e.eval(doc),
357                    None => Value::Null,
358                }
359            }
360
361            SqlExpr::Coalesce(exprs) => {
362                for expr in exprs {
363                    let v = expr.eval(doc);
364                    if !v.is_null() {
365                        return v;
366                    }
367                }
368                Value::Null
369            }
370
371            SqlExpr::NullIf(a, b) => {
372                let va = a.eval(doc);
373                let vb = b.eval(doc);
374                if coerced_eq(&va, &vb) {
375                    Value::Null
376                } else {
377                    va
378                }
379            }
380
381            SqlExpr::IsNull { expr, negated } => {
382                let v = expr.eval(doc);
383                let is_null = v.is_null();
384                Value::Bool(if *negated { !is_null } else { is_null })
385            }
386
387            SqlExpr::OldColumn(_) => Value::Null,
388        }
389    }
390
391    /// Evaluate with access to both NEW and OLD documents (for TRANSITION CHECK).
392    ///
393    /// `Column(name)` resolves against `new_doc`.
394    /// `OldColumn(name)` resolves against `old_doc`.
395    pub fn eval_with_old(&self, new_doc: &Value, old_doc: &Value) -> Value {
396        match self {
397            SqlExpr::Column(name) => new_doc.get(name).cloned().unwrap_or(Value::Null),
398            SqlExpr::OldColumn(name) => old_doc.get(name).cloned().unwrap_or(Value::Null),
399            SqlExpr::Literal(v) => v.clone(),
400            SqlExpr::BinaryOp { left, op, right } => {
401                let l = left.eval_with_old(new_doc, old_doc);
402                let r = right.eval_with_old(new_doc, old_doc);
403                eval_binary_op(&l, *op, &r)
404            }
405            SqlExpr::Negate(inner) => {
406                let v = inner.eval_with_old(new_doc, old_doc);
407                if let Some(b) = v.as_bool() {
408                    Value::Bool(!b)
409                } else {
410                    match value_to_f64(&v, false) {
411                        Some(n) => to_value_number(-n),
412                        None => Value::Null,
413                    }
414                }
415            }
416            SqlExpr::Function { name, args } => {
417                let evaluated: Vec<Value> = args
418                    .iter()
419                    .map(|a| a.eval_with_old(new_doc, old_doc))
420                    .collect();
421                crate::functions::eval_function(name, &evaluated)
422            }
423            SqlExpr::Cast { expr, to_type } => {
424                let v = expr.eval_with_old(new_doc, old_doc);
425                crate::cast::eval_cast(&v, to_type)
426            }
427            SqlExpr::Case {
428                operand,
429                when_thens,
430                else_expr,
431            } => {
432                let op_val = operand.as_ref().map(|e| e.eval_with_old(new_doc, old_doc));
433                for (when_expr, then_expr) in when_thens {
434                    let when_val = when_expr.eval_with_old(new_doc, old_doc);
435                    let matches = match &op_val {
436                        Some(ov) => coerced_eq(ov, &when_val),
437                        None => is_truthy(&when_val),
438                    };
439                    if matches {
440                        return then_expr.eval_with_old(new_doc, old_doc);
441                    }
442                }
443                match else_expr {
444                    Some(e) => e.eval_with_old(new_doc, old_doc),
445                    None => Value::Null,
446                }
447            }
448            SqlExpr::Coalesce(exprs) => {
449                for expr in exprs {
450                    let v = expr.eval_with_old(new_doc, old_doc);
451                    if !v.is_null() {
452                        return v;
453                    }
454                }
455                Value::Null
456            }
457            SqlExpr::NullIf(a, b) => {
458                let va = a.eval_with_old(new_doc, old_doc);
459                let vb = b.eval_with_old(new_doc, old_doc);
460                if coerced_eq(&va, &vb) {
461                    Value::Null
462                } else {
463                    va
464                }
465            }
466            SqlExpr::IsNull { expr, negated } => {
467                let v = expr.eval_with_old(new_doc, old_doc);
468                let is_null = v.is_null();
469                Value::Bool(if *negated { !is_null } else { is_null })
470            }
471        }
472    }
473}
474
475fn eval_binary_op(left: &Value, op: BinaryOp, right: &Value) -> Value {
476    match op {
477        BinaryOp::Add => match (value_to_f64(left, true), value_to_f64(right, true)) {
478            (Some(a), Some(b)) => to_value_number(a + b),
479            _ => Value::Null,
480        },
481        BinaryOp::Sub => match (value_to_f64(left, true), value_to_f64(right, true)) {
482            (Some(a), Some(b)) => to_value_number(a - b),
483            _ => Value::Null,
484        },
485        BinaryOp::Mul => match (value_to_f64(left, true), value_to_f64(right, true)) {
486            (Some(a), Some(b)) => to_value_number(a * b),
487            _ => Value::Null,
488        },
489        BinaryOp::Div => match (value_to_f64(left, true), value_to_f64(right, true)) {
490            (Some(a), Some(b)) => {
491                if b == 0.0 {
492                    Value::Null
493                } else {
494                    to_value_number(a / b)
495                }
496            }
497            _ => Value::Null,
498        },
499        BinaryOp::Mod => match (value_to_f64(left, true), value_to_f64(right, true)) {
500            (Some(a), Some(b)) => {
501                if b == 0.0 {
502                    Value::Null
503                } else {
504                    to_value_number(a % b)
505                }
506            }
507            _ => Value::Null,
508        },
509        BinaryOp::Concat => {
510            let ls = value_to_display_string(left);
511            let rs = value_to_display_string(right);
512            Value::String(format!("{ls}{rs}"))
513        }
514        BinaryOp::Eq => Value::Bool(coerced_eq(left, right)),
515        BinaryOp::NotEq => Value::Bool(!coerced_eq(left, right)),
516        BinaryOp::Gt => Value::Bool(compare_values(left, right) == std::cmp::Ordering::Greater),
517        BinaryOp::GtEq => {
518            let c = compare_values(left, right);
519            Value::Bool(c == std::cmp::Ordering::Greater || c == std::cmp::Ordering::Equal)
520        }
521        BinaryOp::Lt => Value::Bool(compare_values(left, right) == std::cmp::Ordering::Less),
522        BinaryOp::LtEq => {
523            let c = compare_values(left, right);
524            Value::Bool(c == std::cmp::Ordering::Less || c == std::cmp::Ordering::Equal)
525        }
526        BinaryOp::And => Value::Bool(is_truthy(left) && is_truthy(right)),
527        BinaryOp::Or => Value::Bool(is_truthy(left) || is_truthy(right)),
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture document shared by the tests below.
    fn doc() -> Value {
        Value::Object(
            [
                ("name".to_string(), Value::String("Alice".into())),
                ("age".to_string(), Value::Integer(30)),
                ("price".to_string(), Value::Float(10.5)),
                ("qty".to_string(), Value::Integer(4)),
                ("active".to_string(), Value::Bool(true)),
                ("email".to_string(), Value::Null),
            ]
            .into_iter()
            .collect(),
        )
    }

    /// Pre-update ("OLD") document used by the TRANSITION CHECK tests.
    fn old_doc() -> Value {
        Value::Object(
            [("age".to_string(), Value::Integer(29))]
                .into_iter()
                .collect(),
        )
    }

    #[test]
    fn column_ref() {
        let expr = SqlExpr::Column("name".into());
        assert_eq!(expr.eval(&doc()), Value::String("Alice".into()));
    }

    #[test]
    fn missing_column() {
        let expr = SqlExpr::Column("missing".into());
        assert_eq!(expr.eval(&doc()), Value::Null);
    }

    #[test]
    fn literal() {
        let expr = SqlExpr::Literal(Value::Integer(42));
        assert_eq!(expr.eval(&doc()), Value::Integer(42));
    }

    #[test]
    fn add() {
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("price".into())),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(Value::Float(1.5))),
        };
        assert_eq!(expr.eval(&doc()), Value::Integer(12));
    }

    #[test]
    fn multiply() {
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("price".into())),
            op: BinaryOp::Mul,
            right: Box::new(SqlExpr::Column("qty".into())),
        };
        assert_eq!(expr.eval(&doc()), Value::Integer(42));
    }

    #[test]
    fn divide_by_zero_is_null() {
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("qty".into())),
            op: BinaryOp::Div,
            right: Box::new(SqlExpr::Literal(Value::Integer(0))),
        };
        assert_eq!(expr.eval(&doc()), Value::Null);
    }

    #[test]
    fn negate_bool_is_logical_not() {
        let expr = SqlExpr::Negate(Box::new(SqlExpr::Column("active".into())));
        assert_eq!(expr.eval(&doc()), Value::Bool(false));
    }

    #[test]
    fn case_when() {
        let expr = SqlExpr::Case {
            operand: None,
            when_thens: vec![(
                SqlExpr::BinaryOp {
                    left: Box::new(SqlExpr::Column("age".into())),
                    op: BinaryOp::GtEq,
                    right: Box::new(SqlExpr::Literal(Value::Integer(18))),
                },
                SqlExpr::Literal(Value::String("adult".into())),
            )],
            else_expr: Some(Box::new(SqlExpr::Literal(Value::String("minor".into())))),
        };
        assert_eq!(expr.eval(&doc()), Value::String("adult".into()));
    }

    #[test]
    fn coalesce() {
        let expr = SqlExpr::Coalesce(vec![
            SqlExpr::Column("email".into()),
            SqlExpr::Literal(Value::String("default@example.com".into())),
        ]);
        assert_eq!(
            expr.eval(&doc()),
            Value::String("default@example.com".into())
        );
    }

    #[test]
    fn null_if() {
        let equal = SqlExpr::NullIf(
            Box::new(SqlExpr::Column("age".into())),
            Box::new(SqlExpr::Literal(Value::Integer(30))),
        );
        assert_eq!(equal.eval(&doc()), Value::Null);

        let different = SqlExpr::NullIf(
            Box::new(SqlExpr::Column("age".into())),
            Box::new(SqlExpr::Literal(Value::Integer(31))),
        );
        assert_eq!(different.eval(&doc()), Value::Integer(30));
    }

    #[test]
    fn is_null() {
        let expr = SqlExpr::IsNull {
            expr: Box::new(SqlExpr::Column("email".into())),
            negated: false,
        };
        assert_eq!(expr.eval(&doc()), Value::Bool(true));
    }

    #[test]
    fn is_not_null() {
        let expr = SqlExpr::IsNull {
            expr: Box::new(SqlExpr::Column("email".into())),
            negated: true,
        };
        assert_eq!(expr.eval(&doc()), Value::Bool(false));
    }

    #[test]
    fn old_column_is_null_in_plain_eval() {
        let expr = SqlExpr::OldColumn("age".into());
        assert_eq!(expr.eval(&doc()), Value::Null);
    }

    #[test]
    fn eval_with_old_resolves_new_and_old_rows() {
        let old = old_doc();
        assert_eq!(
            SqlExpr::OldColumn("age".into()).eval_with_old(&doc(), &old),
            Value::Integer(29)
        );
        assert_eq!(
            SqlExpr::Column("age".into()).eval_with_old(&doc(), &old),
            Value::Integer(30)
        );
        assert_eq!(
            SqlExpr::OldColumn("missing".into()).eval_with_old(&doc(), &old),
            Value::Null
        );
    }

    #[test]
    fn transition_check_compares_new_against_old() {
        // age must not decrease: NEW.age >= OLD.age
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("age".into())),
            op: BinaryOp::GtEq,
            right: Box::new(SqlExpr::OldColumn("age".into())),
        };
        assert_eq!(expr.eval_with_old(&doc(), &old_doc()), Value::Bool(true));
    }
}