Skip to main content

nodedb_query/expr/
eval.rs

//! Row-scope evaluator for [`SqlExpr`].
//!
//! `eval()` resolves column references against a single document. `eval_with_old()`
//! resolves `Column(..)` against the post-update ("new") document and `OldColumn(..)`
//! against the pre-update ("old") document — this is the path used by TRANSITION
//! CHECK and similar old/new diff predicates. `eval_with_excluded()` resolves
//! `Column(..)` against the existing row and `ExcludedColumn(..)` against the
//! incoming `EXCLUDED.*` row of `INSERT ... ON CONFLICT DO UPDATE`.
7
8use nodedb_types::Value;
9
10use crate::value_ops::{coerced_eq, is_truthy, to_value_number, value_to_f64};
11
12use super::binary::eval_binary_op;
13use super::types::SqlExpr;
14
15/// Row scope for `SqlExpr::eval_scope`: how `Column(..)` and `OldColumn(..)`
16/// resolve to `Value`s. The shared evaluator walks the AST once and calls
17/// into this scope for every leaf column reference — both `eval()` and
18/// `eval_with_old()` delegate here instead of duplicating the walk.
19struct RowScope<'a> {
20    new_doc: &'a Value,
21    /// Pre-update row, if this is an old/new evaluation (TRANSITION CHECK).
22    /// `None` means `OldColumn(..)` resolves to `Null`, matching plain `eval`.
23    old_doc: Option<&'a Value>,
24    /// Incoming `EXCLUDED.*` row for
25    /// `INSERT ... ON CONFLICT DO UPDATE SET col = EXCLUDED.col`. `None`
26    /// means `ExcludedColumn(..)` resolves to `Null`, matching plain `eval`.
27    excluded_doc: Option<&'a Value>,
28}
29
30impl<'a> RowScope<'a> {
31    fn column(&self, name: &str) -> Value {
32        self.new_doc.get(name).cloned().unwrap_or(Value::Null)
33    }
34
35    fn old_column(&self, name: &str) -> Value {
36        match self.old_doc {
37            Some(old) => old.get(name).cloned().unwrap_or(Value::Null),
38            None => Value::Null,
39        }
40    }
41
42    fn excluded_column(&self, name: &str) -> Value {
43        match self.excluded_doc {
44            Some(excluded) => excluded.get(name).cloned().unwrap_or(Value::Null),
45            None => Value::Null,
46        }
47    }
48}
49
50impl SqlExpr {
51    /// Evaluate this expression against a document.
52    ///
53    /// Column references look up fields in the document. Missing fields
54    /// return `Null`. Arithmetic on non-numeric values returns `Null`.
55    /// `OldColumn(..)` resolves to `Null` (use `eval_with_old` for the
56    /// TRANSITION CHECK path).
57    pub fn eval(&self, doc: &Value) -> Value {
58        self.eval_scope(&RowScope {
59            new_doc: doc,
60            old_doc: None,
61            excluded_doc: None,
62        })
63    }
64
65    /// Evaluate with access to both NEW and OLD documents, used by
66    /// TRANSITION CHECK predicates. `Column(name)` resolves against
67    /// `new_doc`; `OldColumn(name)` resolves against `old_doc`.
68    pub fn eval_with_old(&self, new_doc: &Value, old_doc: &Value) -> Value {
69        self.eval_scope(&RowScope {
70            new_doc,
71            old_doc: Some(old_doc),
72            excluded_doc: None,
73        })
74    }
75
76    /// Evaluate with access to the incoming `EXCLUDED.*` row, used by
77    /// `INSERT ... ON CONFLICT DO UPDATE`. `Column(name)` resolves
78    /// against the existing (current) row `doc`; `ExcludedColumn(name)`
79    /// resolves against `excluded`.
80    pub fn eval_with_excluded(&self, doc: &Value, excluded: &Value) -> Value {
81        self.eval_scope(&RowScope {
82            new_doc: doc,
83            old_doc: None,
84            excluded_doc: Some(excluded),
85        })
86    }
87
88    /// Shared walker: one match, one recursion scheme, parameterised by the
89    /// row-scope so `eval` and `eval_with_old` can't drift out of sync.
90    fn eval_scope(&self, scope: &RowScope<'_>) -> Value {
91        match self {
92            SqlExpr::Column(name) => scope.column(name),
93            SqlExpr::OldColumn(name) => scope.old_column(name),
94            SqlExpr::ExcludedColumn(name) => scope.excluded_column(name),
95
96            SqlExpr::Literal(v) => v.clone(),
97
98            SqlExpr::BinaryOp { left, op, right } => {
99                let l = left.eval_scope(scope);
100                let r = right.eval_scope(scope);
101                eval_binary_op(&l, *op, &r)
102            }
103
104            SqlExpr::Negate(inner) => {
105                let v = inner.eval_scope(scope);
106                if let Some(b) = v.as_bool() {
107                    Value::Bool(!b)
108                } else {
109                    match value_to_f64(&v, false) {
110                        Some(n) => to_value_number(-n),
111                        None => Value::Null,
112                    }
113                }
114            }
115
116            SqlExpr::Function { name, args } => {
117                let evaluated: Vec<Value> = args.iter().map(|a| a.eval_scope(scope)).collect();
118                crate::functions::eval_function(name, &evaluated)
119            }
120
121            SqlExpr::Cast { expr, to_type } => {
122                let v = expr.eval_scope(scope);
123                crate::cast::eval_cast(&v, to_type)
124            }
125
126            SqlExpr::Case {
127                operand,
128                when_thens,
129                else_expr,
130            } => {
131                let op_val = operand.as_ref().map(|e| e.eval_scope(scope));
132                for (when_expr, then_expr) in when_thens {
133                    let when_val = when_expr.eval_scope(scope);
134                    let matches = match &op_val {
135                        Some(ov) => coerced_eq(ov, &when_val),
136                        None => is_truthy(&when_val),
137                    };
138                    if matches {
139                        return then_expr.eval_scope(scope);
140                    }
141                }
142                match else_expr {
143                    Some(e) => e.eval_scope(scope),
144                    None => Value::Null,
145                }
146            }
147
148            SqlExpr::Coalesce(exprs) => {
149                for expr in exprs {
150                    let v = expr.eval_scope(scope);
151                    if !v.is_null() {
152                        return v;
153                    }
154                }
155                Value::Null
156            }
157
158            SqlExpr::NullIf(a, b) => {
159                let va = a.eval_scope(scope);
160                let vb = b.eval_scope(scope);
161                if coerced_eq(&va, &vb) {
162                    Value::Null
163                } else {
164                    va
165                }
166            }
167
168            SqlExpr::IsNull { expr, negated } => {
169                let v = expr.eval_scope(scope);
170                let is_null = v.is_null();
171                Value::Bool(if *negated { !is_null } else { is_null })
172            }
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::super::types::BinaryOp;
    use super::*;

    /// Fixture row shared by the single-document tests.
    fn doc() -> Value {
        Value::Object(
            [
                ("name".to_string(), Value::String("Alice".into())),
                ("age".to_string(), Value::Integer(30)),
                ("price".to_string(), Value::Float(10.5)),
                ("qty".to_string(), Value::Integer(4)),
                ("active".to_string(), Value::Bool(true)),
                ("email".to_string(), Value::Null),
            ]
            .into_iter()
            .collect(),
        )
    }

    /// One-field row, used to tell the new / old / excluded rows apart.
    fn single(field: &str, value: Value) -> Value {
        Value::Object([(field.to_string(), value)].into_iter().collect())
    }

    /// Shorthand for `SqlExpr::Column`.
    fn col(name: &'static str) -> SqlExpr {
        SqlExpr::Column(name.into())
    }

    /// Shorthand for `SqlExpr::Literal`.
    fn lit(value: Value) -> SqlExpr {
        SqlExpr::Literal(value)
    }

    /// Shorthand for a boxed `SqlExpr::BinaryOp`.
    fn binary(left: SqlExpr, op: BinaryOp, right: SqlExpr) -> SqlExpr {
        SqlExpr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    #[test]
    fn column_ref() {
        assert_eq!(col("name").eval(&doc()), Value::String("Alice".into()));
    }

    #[test]
    fn missing_column() {
        assert_eq!(col("missing").eval(&doc()), Value::Null);
    }

    #[test]
    fn literal() {
        assert_eq!(lit(Value::Integer(42)).eval(&doc()), Value::Integer(42));
    }

    #[test]
    fn add() {
        let expr = binary(col("price"), BinaryOp::Add, lit(Value::Float(1.5)));
        assert_eq!(expr.eval(&doc()), Value::Integer(12));
    }

    #[test]
    fn multiply() {
        let expr = binary(col("price"), BinaryOp::Mul, col("qty"));
        assert_eq!(expr.eval(&doc()), Value::Integer(42));
    }

    #[test]
    fn case_when() {
        let expr = SqlExpr::Case {
            operand: None,
            when_thens: vec![(
                binary(col("age"), BinaryOp::GtEq, lit(Value::Integer(18))),
                lit(Value::String("adult".into())),
            )],
            else_expr: Some(Box::new(lit(Value::String("minor".into())))),
        };
        assert_eq!(expr.eval(&doc()), Value::String("adult".into()));
    }

    #[test]
    fn coalesce() {
        let expr = SqlExpr::Coalesce(vec![
            col("email"),
            lit(Value::String("default@example.com".into())),
        ]);
        assert_eq!(
            expr.eval(&doc()),
            Value::String("default@example.com".into())
        );
    }

    #[test]
    fn is_null() {
        let expr = SqlExpr::IsNull {
            expr: Box::new(col("email")),
            negated: false,
        };
        assert_eq!(expr.eval(&doc()), Value::Bool(true));
    }

    #[test]
    fn is_not_null() {
        // `negated: true` flips the result of the `is_null` case above.
        let expr = SqlExpr::IsNull {
            expr: Box::new(col("email")),
            negated: true,
        };
        assert_eq!(expr.eval(&doc()), Value::Bool(false));
    }

    #[test]
    fn old_and_excluded_are_null_in_plain_eval() {
        // Plain `eval` supplies neither an old nor an excluded row, even
        // though the field exists in the document being evaluated.
        assert_eq!(SqlExpr::OldColumn("age".into()).eval(&doc()), Value::Null);
        assert_eq!(
            SqlExpr::ExcludedColumn("age".into()).eval(&doc()),
            Value::Null
        );
    }

    #[test]
    fn eval_with_old_resolves_new_and_old_rows() {
        let new_row = single("age", Value::Integer(31));
        let old_row = single("age", Value::Integer(30));

        assert_eq!(
            col("age").eval_with_old(&new_row, &old_row),
            Value::Integer(31)
        );
        assert_eq!(
            SqlExpr::OldColumn("age".into()).eval_with_old(&new_row, &old_row),
            Value::Integer(30)
        );
        // No excluded row on this path.
        assert_eq!(
            SqlExpr::ExcludedColumn("age".into()).eval_with_old(&new_row, &old_row),
            Value::Null
        );
    }

    #[test]
    fn eval_with_excluded_resolves_current_and_excluded_rows() {
        let current = single("qty", Value::Integer(4));
        let incoming = single("qty", Value::Integer(9));

        assert_eq!(
            col("qty").eval_with_excluded(&current, &incoming),
            Value::Integer(4)
        );
        assert_eq!(
            SqlExpr::ExcludedColumn("qty".into()).eval_with_excluded(&current, &incoming),
            Value::Integer(9)
        );
        // No old row on this path.
        assert_eq!(
            SqlExpr::OldColumn("qty".into()).eval_with_excluded(&current, &incoming),
            Value::Null
        );
    }
}