Skip to main content

nodedb_query/expr/
eval.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Row-scope evaluator for [`SqlExpr`].
4//!
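//! `eval()` resolves column references against a single document. `eval_with_old()`
//! resolves `Column(..)` against the post-update ("new") document and `OldColumn(..)`
//! against the pre-update ("old") document — this is the path used by TRANSITION
//! CHECK and similar old/new diff predicates. `eval_with_excluded()` resolves
//! `Column(..)` against the existing row and `ExcludedColumn(..)` against the
//! incoming row of `INSERT ... ON CONFLICT DO UPDATE`.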
9
10use nodedb_types::Value;
11
12use crate::value_ops::{coerced_eq, is_truthy, to_value_number, value_to_f64};
13
14use super::binary::eval_binary_op;
15use super::types::SqlExpr;
16
17/// Row scope for `SqlExpr::eval_scope`: how `Column(..)` and `OldColumn(..)`
18/// resolve to `Value`s. The shared evaluator walks the AST once and calls
19/// into this scope for every leaf column reference — both `eval()` and
20/// `eval_with_old()` delegate here instead of duplicating the walk.
21struct RowScope<'a> {
22    new_doc: &'a Value,
23    /// Pre-update row, if this is an old/new evaluation (TRANSITION CHECK).
24    /// `None` means `OldColumn(..)` resolves to `Null`, matching plain `eval`.
25    old_doc: Option<&'a Value>,
26    /// Incoming `EXCLUDED.*` row for
27    /// `INSERT ... ON CONFLICT DO UPDATE SET col = EXCLUDED.col`. `None`
28    /// means `ExcludedColumn(..)` resolves to `Null`, matching plain `eval`.
29    excluded_doc: Option<&'a Value>,
30}
31
32impl<'a> RowScope<'a> {
33    fn column(&self, name: &str) -> Value {
34        self.new_doc.get(name).cloned().unwrap_or(Value::Null)
35    }
36
37    fn old_column(&self, name: &str) -> Value {
38        match self.old_doc {
39            Some(old) => old.get(name).cloned().unwrap_or(Value::Null),
40            None => Value::Null,
41        }
42    }
43
44    fn excluded_column(&self, name: &str) -> Value {
45        match self.excluded_doc {
46            Some(excluded) => excluded.get(name).cloned().unwrap_or(Value::Null),
47            None => Value::Null,
48        }
49    }
50}
51
52impl SqlExpr {
53    /// Evaluate this expression against a document.
54    ///
55    /// Column references look up fields in the document. Missing fields
56    /// return `Null`. Arithmetic on non-numeric values returns `Null`.
57    /// `OldColumn(..)` resolves to `Null` (use `eval_with_old` for the
58    /// TRANSITION CHECK path).
59    pub fn eval(&self, doc: &Value) -> Value {
60        self.eval_scope(&RowScope {
61            new_doc: doc,
62            old_doc: None,
63            excluded_doc: None,
64        })
65    }
66
67    /// Evaluate with access to both NEW and OLD documents, used by
68    /// TRANSITION CHECK predicates. `Column(name)` resolves against
69    /// `new_doc`; `OldColumn(name)` resolves against `old_doc`.
70    pub fn eval_with_old(&self, new_doc: &Value, old_doc: &Value) -> Value {
71        self.eval_scope(&RowScope {
72            new_doc,
73            old_doc: Some(old_doc),
74            excluded_doc: None,
75        })
76    }
77
78    /// Evaluate with access to the incoming `EXCLUDED.*` row, used by
79    /// `INSERT ... ON CONFLICT DO UPDATE`. `Column(name)` resolves
80    /// against the existing (current) row `doc`; `ExcludedColumn(name)`
81    /// resolves against `excluded`.
82    pub fn eval_with_excluded(&self, doc: &Value, excluded: &Value) -> Value {
83        self.eval_scope(&RowScope {
84            new_doc: doc,
85            old_doc: None,
86            excluded_doc: Some(excluded),
87        })
88    }
89
90    /// Shared walker: one match, one recursion scheme, parameterised by the
91    /// row-scope so `eval` and `eval_with_old` can't drift out of sync.
92    fn eval_scope(&self, scope: &RowScope<'_>) -> Value {
93        match self {
94            SqlExpr::Column(name) => scope.column(name),
95            SqlExpr::OldColumn(name) => scope.old_column(name),
96            SqlExpr::ExcludedColumn(name) => scope.excluded_column(name),
97
98            SqlExpr::Literal(v) => v.clone(),
99
100            SqlExpr::BinaryOp { left, op, right } => {
101                let l = left.eval_scope(scope);
102                let r = right.eval_scope(scope);
103                eval_binary_op(&l, *op, &r)
104            }
105
106            SqlExpr::Negate(inner) => {
107                let v = inner.eval_scope(scope);
108                if let Some(b) = v.as_bool() {
109                    Value::Bool(!b)
110                } else {
111                    match value_to_f64(&v, false) {
112                        Some(n) => to_value_number(-n),
113                        None => Value::Null,
114                    }
115                }
116            }
117
118            SqlExpr::Function { name, args } => {
119                let evaluated: Vec<Value> = args.iter().map(|a| a.eval_scope(scope)).collect();
120                crate::functions::eval_function(name, &evaluated)
121            }
122
123            SqlExpr::Cast { expr, to_type } => {
124                let v = expr.eval_scope(scope);
125                crate::cast::eval_cast(&v, to_type)
126            }
127
128            SqlExpr::Case {
129                operand,
130                when_thens,
131                else_expr,
132            } => {
133                let op_val = operand.as_ref().map(|e| e.eval_scope(scope));
134                for (when_expr, then_expr) in when_thens {
135                    let when_val = when_expr.eval_scope(scope);
136                    let matches = match &op_val {
137                        Some(ov) => coerced_eq(ov, &when_val),
138                        None => is_truthy(&when_val),
139                    };
140                    if matches {
141                        return then_expr.eval_scope(scope);
142                    }
143                }
144                match else_expr {
145                    Some(e) => e.eval_scope(scope),
146                    None => Value::Null,
147                }
148            }
149
150            SqlExpr::Coalesce(exprs) => {
151                for expr in exprs {
152                    let v = expr.eval_scope(scope);
153                    if !v.is_null() {
154                        return v;
155                    }
156                }
157                Value::Null
158            }
159
160            SqlExpr::NullIf(a, b) => {
161                let va = a.eval_scope(scope);
162                let vb = b.eval_scope(scope);
163                if coerced_eq(&va, &vb) {
164                    Value::Null
165                } else {
166                    va
167                }
168            }
169
170            SqlExpr::IsNull { expr, negated } => {
171                let v = expr.eval_scope(scope);
172                let is_null = v.is_null();
173                Value::Bool(if *negated { !is_null } else { is_null })
174            }
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::super::types::BinaryOp;
    use super::*;

    fn doc() -> Value {
        Value::Object(
            [
                ("name".to_string(), Value::String("Alice".into())),
                ("age".to_string(), Value::Integer(30)),
                ("price".to_string(), Value::Float(10.5)),
                ("qty".to_string(), Value::Integer(4)),
                ("active".to_string(), Value::Bool(true)),
                ("email".to_string(), Value::Null),
            ]
            .into_iter()
            .collect(),
        )
    }

    /// Build an object document holding exactly one field.
    fn single(field: &str, value: Value) -> Value {
        Value::Object([(field.to_string(), value)].into_iter().collect())
    }

    #[test]
    fn column_ref() {
        let expr = SqlExpr::Column("name".into());
        assert_eq!(expr.eval(&doc()), Value::String("Alice".into()));
    }

    #[test]
    fn missing_column() {
        let expr = SqlExpr::Column("missing".into());
        assert_eq!(expr.eval(&doc()), Value::Null);
    }

    #[test]
    fn literal() {
        let expr = SqlExpr::Literal(Value::Integer(42));
        assert_eq!(expr.eval(&doc()), Value::Integer(42));
    }

    #[test]
    fn add() {
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("price".into())),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(Value::Float(1.5))),
        };
        assert_eq!(expr.eval(&doc()), Value::Integer(12));
    }

    #[test]
    fn multiply() {
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("price".into())),
            op: BinaryOp::Mul,
            right: Box::new(SqlExpr::Column("qty".into())),
        };
        assert_eq!(expr.eval(&doc()), Value::Integer(42));
    }

    #[test]
    fn case_when() {
        let expr = SqlExpr::Case {
            operand: None,
            when_thens: vec![(
                SqlExpr::BinaryOp {
                    left: Box::new(SqlExpr::Column("age".into())),
                    op: BinaryOp::GtEq,
                    right: Box::new(SqlExpr::Literal(Value::Integer(18))),
                },
                SqlExpr::Literal(Value::String("adult".into())),
            )],
            else_expr: Some(Box::new(SqlExpr::Literal(Value::String("minor".into())))),
        };
        assert_eq!(expr.eval(&doc()), Value::String("adult".into()));
    }

    #[test]
    fn simple_case_with_operand() {
        let expr = SqlExpr::Case {
            operand: Some(Box::new(SqlExpr::Column("name".into()))),
            when_thens: vec![
                (
                    SqlExpr::Literal(Value::String("Bob".into())),
                    SqlExpr::Literal(Value::Integer(1)),
                ),
                (
                    SqlExpr::Literal(Value::String("Alice".into())),
                    SqlExpr::Literal(Value::Integer(2)),
                ),
            ],
            else_expr: None,
        };
        assert_eq!(expr.eval(&doc()), Value::Integer(2));
    }

    #[test]
    fn case_without_match_or_else_is_null() {
        let expr = SqlExpr::Case {
            operand: None,
            when_thens: vec![(
                SqlExpr::Literal(Value::Bool(false)),
                SqlExpr::Literal(Value::Integer(1)),
            )],
            else_expr: None,
        };
        assert_eq!(expr.eval(&doc()), Value::Null);
    }

    #[test]
    fn coalesce() {
        let expr = SqlExpr::Coalesce(vec![
            SqlExpr::Column("email".into()),
            SqlExpr::Literal(Value::String("default@example.com".into())),
        ]);
        assert_eq!(
            expr.eval(&doc()),
            Value::String("default@example.com".into())
        );
    }

    #[test]
    fn is_null() {
        let expr = SqlExpr::IsNull {
            expr: Box::new(SqlExpr::Column("email".into())),
            negated: false,
        };
        assert_eq!(expr.eval(&doc()), Value::Bool(true));
    }

    #[test]
    fn is_not_null() {
        let expr = SqlExpr::IsNull {
            expr: Box::new(SqlExpr::Column("name".into())),
            negated: true,
        };
        assert_eq!(expr.eval(&doc()), Value::Bool(true));
    }

    #[test]
    fn negate_bool_is_logical_not() {
        let expr = SqlExpr::Negate(Box::new(SqlExpr::Column("active".into())));
        assert_eq!(expr.eval(&doc()), Value::Bool(false));
    }

    #[test]
    fn nullif() {
        let equal = SqlExpr::NullIf(
            Box::new(SqlExpr::Column("age".into())),
            Box::new(SqlExpr::Literal(Value::Integer(30))),
        );
        assert_eq!(equal.eval(&doc()), Value::Null);

        let different = SqlExpr::NullIf(
            Box::new(SqlExpr::Column("name".into())),
            Box::new(SqlExpr::Literal(Value::String("Bob".into()))),
        );
        assert_eq!(different.eval(&doc()), Value::String("Alice".into()));
    }

    #[test]
    fn old_and_new_columns() {
        let new_doc = doc();
        let old_doc = single("age", Value::Integer(29));

        assert_eq!(
            SqlExpr::Column("age".into()).eval_with_old(&new_doc, &old_doc),
            Value::Integer(30)
        );
        assert_eq!(
            SqlExpr::OldColumn("age".into()).eval_with_old(&new_doc, &old_doc),
            Value::Integer(29)
        );
        assert_eq!(
            SqlExpr::OldColumn("missing".into()).eval_with_old(&new_doc, &old_doc),
            Value::Null
        );
        // Without an old row in scope, OldColumn resolves to Null.
        assert_eq!(SqlExpr::OldColumn("age".into()).eval(&new_doc), Value::Null);
        // ExcludedColumn has no row on the TRANSITION CHECK path.
        assert_eq!(
            SqlExpr::ExcludedColumn("age".into()).eval_with_old(&new_doc, &old_doc),
            Value::Null
        );
    }

    #[test]
    fn excluded_columns() {
        let current = doc();
        let excluded = single("qty", Value::Integer(9));

        assert_eq!(
            SqlExpr::ExcludedColumn("qty".into()).eval_with_excluded(&current, &excluded),
            Value::Integer(9)
        );
        assert_eq!(
            SqlExpr::Column("qty".into()).eval_with_excluded(&current, &excluded),
            Value::Integer(4)
        );
        // Without an EXCLUDED row in scope, ExcludedColumn resolves to Null.
        assert_eq!(SqlExpr::ExcludedColumn("qty".into()).eval(&current), Value::Null);
        // OldColumn has no row on the ON CONFLICT path.
        assert_eq!(
            SqlExpr::OldColumn("qty".into()).eval_with_excluded(&current, &excluded),
            Value::Null
        );
    }
}