Skip to main content

nodedb_sql/planner/
const_fold.rs

1//! Plan-time constant folding for `SqlExpr`.
2//!
3//! Evaluates literal expressions and registered zero-or-few-arg scalar
4//! functions (e.g. `now()`, `current_timestamp`, `date_add(now(), '1h')`)
5//! at plan time via the shared `nodedb_query::functions::eval_function`
6//! evaluator.
7//!
8//! This keeps the bare-`SELECT` projection path, the `INSERT`/`UPSERT`
9//! `VALUES` path, and any future default-expression paths from drifting
10//! apart — they all reach the same evaluator that the Data Plane uses
11//! for column-reference evaluation.
12//!
//! Semantics: Postgres / SQL-standard compatible. `now()` and
//! `current_timestamp` are snapshotted once per statement: the SQL
//! standard requires `CURRENT_TIMESTAMP` to return the same value for
//! every row of a single statement, and folding once at plan time
//! satisfies that contract while being cheaper than per-row runtime
//! dispatch. Postgres goes further and returns the same value for the
//! whole transaction. Plan-time folding alone does not provide that
//! per-transaction guarantee, because every fold evaluates `now()` afresh.
19
20use std::sync::LazyLock;
21
22use nodedb_types::Value;
23
24use crate::functions::registry::{FunctionCategory, FunctionRegistry};
25use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};
26
27/// Process-wide default registry. Used by call sites that don't already
28/// thread a `FunctionRegistry` through (e.g. the DML `VALUES` path).
29static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);
30
31/// Access the shared default registry.
32pub fn default_registry() -> &'static FunctionRegistry {
33    &DEFAULT_REGISTRY
34}
35
36/// Convenience wrapper around [`fold_constant`] using the default registry.
37pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
38    fold_constant(expr, default_registry())
39}
40
41/// Fold a `SqlExpr` to a literal `SqlValue` at plan time, or return
42/// `None` if the expression depends on row/runtime state (column refs,
43/// subqueries, unknown functions, etc.).
44pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
45    match expr {
46        SqlExpr::Literal(v) => Some(v.clone()),
47        SqlExpr::UnaryOp {
48            op: UnaryOp::Neg,
49            expr,
50        } => match fold_constant(expr, registry)? {
51            SqlValue::Int(i) => Some(SqlValue::Int(-i)),
52            SqlValue::Float(f) => Some(SqlValue::Float(-f)),
53            _ => None,
54        },
55        SqlExpr::BinaryOp { left, op, right } => {
56            let l = fold_constant(left, registry)?;
57            let r = fold_constant(right, registry)?;
58            fold_binary(l, *op, r)
59        }
60        SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
61        _ => None,
62    }
63}
64
65fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
66    Some(match (l, op, r) {
67        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
68        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
69        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
70        (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
71        (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
72        (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
73        (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
74            SqlValue::String(format!("{a}{b}"))
75        }
76        _ => return None,
77    })
78}
79
80/// Fold a function call by recursively folding its arguments, dispatching
81/// through the shared scalar evaluator, and converting the result back to
82/// `SqlValue`. Only folds functions that are present in `registry`, so
83/// callers can distinguish "unknown function" from "known function, all
84/// args folded".
85pub fn fold_function_call(
86    name: &str,
87    args: &[SqlExpr],
88    registry: &FunctionRegistry,
89) -> Option<SqlValue> {
90    // Gate on registry so unknown-function paths keep their existing
91    // fallbacks instead of collapsing to SqlValue::Null. Aggregates and
92    // window functions aren't foldable — they need a row stream.
93    let meta = registry.lookup(name)?;
94    if matches!(
95        meta.category,
96        FunctionCategory::Aggregate | FunctionCategory::Window
97    ) {
98        return None;
99    }
100
101    let folded_args: Vec<Value> = args
102        .iter()
103        .map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
104        .collect::<Option<_>>()?;
105
106    let result = nodedb_query::functions::eval_function(name, &folded_args);
107    Some(ndb_to_sql_value(result))
108}
109
110fn sql_to_ndb_value(v: SqlValue) -> Value {
111    match v {
112        SqlValue::Null => Value::Null,
113        SqlValue::Bool(b) => Value::Bool(b),
114        SqlValue::Int(i) => Value::Integer(i),
115        SqlValue::Float(f) => Value::Float(f),
116        SqlValue::String(s) => Value::String(s),
117        SqlValue::Bytes(b) => Value::Bytes(b),
118        SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
119    }
120}
121
122fn ndb_to_sql_value(v: Value) -> SqlValue {
123    match v {
124        Value::Null => SqlValue::Null,
125        Value::Bool(b) => SqlValue::Bool(b),
126        Value::Integer(i) => SqlValue::Int(i),
127        Value::Float(f) => SqlValue::Float(f),
128        Value::String(s) => SqlValue::String(s),
129        Value::Bytes(b) => SqlValue::Bytes(b),
130        Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
131        Value::DateTime(dt) => SqlValue::String(dt.to_iso8601()),
132        Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
133        Value::Duration(d) => SqlValue::String(d.to_human()),
134        Value::Decimal(d) => SqlValue::String(d.to_string()),
135        // Structured and opaque types collapse to Null — callers that
136        // need these go through the runtime expression path, not folding.
137        Value::Object(_)
138        | Value::Geometry(_)
139        | Value::Set(_)
140        | Value::Range { .. }
141        | Value::Record { .. } => SqlValue::Null,
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-DISTINCT, zero-argument call to `name`.
    fn call(name: &str) -> SqlExpr {
        SqlExpr::Function {
            name: name.into(),
            args: Vec::new(),
            distinct: false,
        }
    }

    #[test]
    fn fold_now_produces_non_epoch_string() {
        let registry = FunctionRegistry::new();
        let folded = fold_constant(&call("now"), &registry).expect("now() should fold");
        match folded {
            SqlValue::String(s) => {
                assert!(!s.starts_with("1970"), "got {s}");
                assert!(s.contains('T'), "not ISO-8601: {s}");
            }
            other => panic!("expected string, got {other:?}"),
        }
    }

    #[test]
    fn fold_current_timestamp() {
        let registry = FunctionRegistry::new();
        let folded = fold_constant(&call("current_timestamp"), &registry);
        assert!(matches!(folded, Some(SqlValue::String(_))));
    }

    #[test]
    fn fold_unknown_function_returns_none() {
        let registry = FunctionRegistry::new();
        let unknown = call("definitely_not_a_real_function");
        assert!(fold_constant(&unknown, &registry).is_none());
    }

    #[test]
    fn fold_literal_arithmetic_still_works() {
        let registry = FunctionRegistry::new();
        let two_plus_three = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Int(2))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(3))),
        };
        assert_eq!(
            fold_constant(&two_plus_three, &registry),
            Some(SqlValue::Int(5))
        );
    }

    #[test]
    fn fold_column_ref_returns_none() {
        let registry = FunctionRegistry::new();
        let column = SqlExpr::Column {
            table: None,
            name: "name".into(),
        };
        assert!(fold_constant(&column, &registry).is_none());
    }
}