Skip to main content

nodedb_sql/planner/
const_fold.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Plan-time constant folding for `SqlExpr`.
4//!
5//! Evaluates literal expressions and registered zero-or-few-arg scalar
6//! functions (e.g. `now()`, `current_timestamp`, `date_add(now(), '1h')`)
7//! at plan time via the shared `nodedb_query::functions::eval_function`
8//! evaluator.
9//!
10//! This keeps the bare-`SELECT` projection path, the `INSERT`/`UPSERT`
11//! `VALUES` path, and any future default-expression paths from drifting
12//! apart — they all reach the same evaluator that the Data Plane uses
13//! for column-reference evaluation.
14//!
15//! Semantics: Postgres / SQL-standard compatible. `now()` and
16//! `current_timestamp` snapshot once per statement — `CURRENT_TIMESTAMP`
17//! is defined to return the same value for every row of a single
18//! statement, and Postgres goes further (same value for the whole
19//! transaction). Folding at plan time satisfies both contracts and is
20//! cheaper than per-row runtime dispatch.
21
22use std::sync::LazyLock;
23
24use nodedb_types::Value;
25use sonic_rs;
26
27use crate::functions::registry::{FunctionCategory, FunctionRegistry};
28use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};
29
30/// Process-wide default registry. Used by call sites that don't already
31/// thread a `FunctionRegistry` through (e.g. the DML `VALUES` path).
32static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);
33
34/// Access the shared default registry.
35pub fn default_registry() -> &'static FunctionRegistry {
36    &DEFAULT_REGISTRY
37}
38
39/// Convenience wrapper around [`fold_constant`] using the default registry.
40pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
41    fold_constant(expr, default_registry())
42}
43
44/// Fold a `SqlExpr` to a literal `SqlValue` at plan time, or return
45/// `None` if the expression depends on row/runtime state (column refs,
46/// subqueries, unknown functions, etc.).
47pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
48    match expr {
49        SqlExpr::Literal(v) => Some(v.clone()),
50        SqlExpr::UnaryOp {
51            op: UnaryOp::Neg,
52            expr,
53        } => match fold_constant(expr, registry)? {
54            SqlValue::Int(i) => Some(SqlValue::Int(-i)),
55            SqlValue::Float(f) => Some(SqlValue::Float(-f)),
56            SqlValue::Decimal(d) => Some(SqlValue::Decimal(-d)),
57            _ => None,
58        },
59        SqlExpr::BinaryOp { left, op, right } => {
60            let l = fold_constant(left, registry)?;
61            let r = fold_constant(right, registry)?;
62            fold_binary(l, *op, r)
63        }
64        SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
65        SqlExpr::Cast { expr, to_type } => fold_cast(fold_constant(expr, registry)?, to_type),
66        _ => None,
67    }
68}
69
70/// Fold a CAST at plan time. Only applies when the inner expression is already
71/// a constant. The `to_type` string comes from sqlparser's `format!("{data_type}")`
72/// output, so parameterised types like `NUMERIC(5,1)` must be matched by prefix.
73fn fold_cast(inner: SqlValue, to_type: &str) -> Option<SqlValue> {
74    let upper = to_type.to_uppercase();
75    // Strip any precision/scale suffix: "NUMERIC(5,1)" → "NUMERIC".
76    let base = upper
77        .split('(')
78        .next()
79        .map(str::trim)
80        .unwrap_or(upper.as_str());
81
82    match base {
83        "NUMERIC" | "DECIMAL" => match inner {
84            SqlValue::Decimal(d) => Some(SqlValue::Decimal(d)),
85            SqlValue::Int(i) => Some(SqlValue::Decimal(rust_decimal::Decimal::from(i))),
86            SqlValue::Float(f) => rust_decimal::Decimal::try_from(f)
87                .ok()
88                .map(SqlValue::Decimal),
89            SqlValue::String(s) => rust_decimal::Decimal::from_str_exact(&s)
90                .ok()
91                .map(SqlValue::Decimal),
92            _ => None,
93        },
94        "INTEGER" | "INT" | "BIGINT" | "SMALLINT" | "INT2" | "INT4" | "INT8" => match inner {
95            SqlValue::Int(i) => Some(SqlValue::Int(i)),
96            SqlValue::Decimal(d) => {
97                rust_decimal::prelude::ToPrimitive::to_i64(&d).map(SqlValue::Int)
98            }
99            SqlValue::Float(f) => {
100                if f.is_finite() {
101                    Some(SqlValue::Int(f as i64))
102                } else {
103                    None
104                }
105            }
106            SqlValue::String(s) => s.parse::<i64>().ok().map(SqlValue::Int),
107            _ => None,
108        },
109        "FLOAT" | "DOUBLE" | "REAL" | "FLOAT4" | "FLOAT8" | "DOUBLE PRECISION" => match inner {
110            SqlValue::Float(f) => Some(SqlValue::Float(f)),
111            SqlValue::Int(i) => Some(SqlValue::Float(i as f64)),
112            SqlValue::Decimal(d) => {
113                rust_decimal::prelude::ToPrimitive::to_f64(&d).map(SqlValue::Float)
114            }
115            SqlValue::String(s) => s.parse::<f64>().ok().map(SqlValue::Float),
116            _ => None,
117        },
118        "TEXT" | "VARCHAR" | "CHAR" | "CHARACTER VARYING" | "CHARACTER" | "BPCHAR" => match inner {
119            SqlValue::String(s) => Some(SqlValue::String(s)),
120            SqlValue::Int(i) => Some(SqlValue::String(i.to_string())),
121            SqlValue::Float(f) => Some(SqlValue::String(f.to_string())),
122            SqlValue::Decimal(d) => Some(SqlValue::String(d.to_string())),
123            SqlValue::Bool(b) => Some(SqlValue::String(b.to_string())),
124            _ => None,
125        },
126        "BOOL" | "BOOLEAN" => match inner {
127            SqlValue::Bool(b) => Some(SqlValue::Bool(b)),
128            SqlValue::Int(i) => Some(SqlValue::Bool(i != 0)),
129            SqlValue::String(s) => match s.to_lowercase().as_str() {
130                "true" | "t" | "yes" | "1" | "on" => Some(SqlValue::Bool(true)),
131                "false" | "f" | "no" | "0" | "off" => Some(SqlValue::Bool(false)),
132                _ => None,
133            },
134            _ => None,
135        },
136        _ => None,
137    }
138}
139
140fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
141    Some(match (l, op, r) {
142        // Int × Int arithmetic.
143        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
144        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
145        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
146        // Float × Float arithmetic.
147        (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
148        (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
149        (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
150        // Decimal × Decimal arithmetic.
151        (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
152            SqlValue::Decimal(a.checked_add(b)?)
153        }
154        (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
155            SqlValue::Decimal(a.checked_sub(b)?)
156        }
157        (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
158            SqlValue::Decimal(a.checked_mul(b)?)
159        }
160        (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
161            SqlValue::Decimal(a.checked_div(b)?)
162        }
163        // Decimal × Int widening (Int promotes to Decimal).
164        (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Int(b)) => {
165            SqlValue::Decimal(a.checked_add(rust_decimal::Decimal::from(b))?)
166        }
167        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
168            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_add(b)?)
169        }
170        (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Int(b)) => {
171            SqlValue::Decimal(a.checked_sub(rust_decimal::Decimal::from(b))?)
172        }
173        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
174            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_sub(b)?)
175        }
176        (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Int(b)) => {
177            SqlValue::Decimal(a.checked_mul(rust_decimal::Decimal::from(b))?)
178        }
179        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
180            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_mul(b)?)
181        }
182        (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Int(b)) => {
183            SqlValue::Decimal(a.checked_div(rust_decimal::Decimal::from(b))?)
184        }
185        (SqlValue::Int(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
186            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_div(b)?)
187        }
188        // String concat.
189        (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
190            SqlValue::String(format!("{a}{b}"))
191        }
192        _ => return None,
193    })
194}
195
196/// Fold a function call by recursively folding its arguments, dispatching
197/// through the shared scalar evaluator, and converting the result back to
198/// `SqlValue`. Only folds functions that are present in `registry`, so
199/// callers can distinguish "unknown function" from "known function, all
200/// args folded".
201pub fn fold_function_call(
202    name: &str,
203    args: &[SqlExpr],
204    registry: &FunctionRegistry,
205) -> Option<SqlValue> {
206    // Gate on registry so unknown-function paths keep their existing
207    // fallbacks instead of collapsing to SqlValue::Null. Aggregates and
208    // window functions aren't foldable — they need a row stream.
209    let meta = registry.lookup(name)?;
210    if matches!(
211        meta.category,
212        FunctionCategory::Aggregate | FunctionCategory::Window
213    ) {
214        return None;
215    }
216
217    let folded_args: Vec<Value> = args
218        .iter()
219        .map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
220        .collect::<Option<_>>()?;
221
222    let result = nodedb_query::functions::eval_function(&name.to_lowercase(), &folded_args);
223    Some(ndb_to_sql_value(result))
224}
225
226fn sql_to_ndb_value(v: SqlValue) -> Value {
227    match v {
228        SqlValue::Null => Value::Null,
229        SqlValue::Bool(b) => Value::Bool(b),
230        SqlValue::Int(i) => Value::Integer(i),
231        SqlValue::Float(f) => Value::Float(f),
232        SqlValue::Decimal(d) => Value::Decimal(d),
233        SqlValue::String(s) => Value::String(s),
234        SqlValue::Bytes(b) => Value::Bytes(b),
235        SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
236        SqlValue::Timestamp(dt) => Value::NaiveDateTime(dt),
237        SqlValue::Timestamptz(dt) => Value::DateTime(dt),
238    }
239}
240
241fn ndb_to_sql_value(v: Value) -> SqlValue {
242    match v {
243        Value::Null => SqlValue::Null,
244        Value::Bool(b) => SqlValue::Bool(b),
245        Value::Integer(i) => SqlValue::Int(i),
246        Value::Float(f) => SqlValue::Float(f),
247        Value::String(s) => SqlValue::String(s),
248        Value::Bytes(b) => SqlValue::Bytes(b),
249        Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
250        // TZ-aware DateTime → Timestamptz; naive → Timestamp.
251        Value::DateTime(dt) => SqlValue::Timestamptz(dt),
252        Value::NaiveDateTime(dt) => SqlValue::Timestamp(dt),
253        Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
254        Value::Duration(d) => SqlValue::String(d.to_human()),
255        Value::Decimal(d) => SqlValue::Decimal(d),
256        // Geometry and Object values are serialized to JSON strings so that
257        // nested function calls like ST_Distance(ST_Point(...), ST_Point(...))
258        // survive the SqlValue round-trip. The geo evaluator's geom_arg helper
259        // recovers Geometry from a GeoJSON string; Object results (e.g. from
260        // ST_GeoHashDecode) reach the client as a JSON-encoded string column.
261        Value::Geometry(g) => sonic_rs::to_string(&g)
262            .map(SqlValue::String)
263            .unwrap_or(SqlValue::Null),
264        Value::Object(map) => sonic_rs::to_string(&map)
265            .map(SqlValue::String)
266            .unwrap_or(SqlValue::Null),
267        // Structured and opaque types collapse to Null — callers that
268        // need these go through the runtime expression path, not folding.
269        Value::Set(_) | Value::Range { .. } | Value::Record { .. } | Value::ArrayCell(_) => {
270            SqlValue::Null
271        }
272        // Value is #[non_exhaustive]; future variants collapse to Null in the
273        // constant-folding path — runtime expression evaluation handles them.
274        _ => SqlValue::Null,
275    }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    /// Non-DISTINCT call expression `name(args...)`.
    fn call(name: &'static str, args: Vec<SqlExpr>) -> SqlExpr {
        SqlExpr::Function {
            name: name.into(),
            args,
            distinct: false,
        }
    }

    /// Literal expression wrapping `v`.
    fn lit(v: SqlValue) -> SqlExpr {
        SqlExpr::Literal(v)
    }

    /// Unary minus applied to a literal.
    fn neg(v: SqlValue) -> SqlExpr {
        SqlExpr::UnaryOp {
            op: UnaryOp::Neg,
            expr: Box::new(lit(v)),
        }
    }

    /// `left op right` over two literals.
    fn binary(left: SqlValue, op: BinaryOp, right: SqlValue) -> SqlExpr {
        SqlExpr::BinaryOp {
            left: Box::new(lit(left)),
            op,
            right: Box::new(lit(right)),
        }
    }

    /// Fold `expr` against a freshly built registry.
    fn fold_expr(expr: &SqlExpr) -> Option<SqlValue> {
        fold_constant(expr, &FunctionRegistry::new())
    }

    #[test]
    fn fold_now_produces_timestamptz() {
        let folded = fold_expr(&call("now", vec![])).expect("now() should fold");
        match folded {
            // Sanity: a real clock reading is strictly after the epoch.
            SqlValue::Timestamptz(dt) => {
                assert!(dt.micros > 0, "expected post-epoch timestamp, got micros=0")
            }
            other => panic!("expected SqlValue::Timestamptz, got {other:?}"),
        }
    }

    #[test]
    fn fold_current_timestamp_produces_timestamptz() {
        let folded = fold_expr(&call("current_timestamp", vec![]));
        assert!(matches!(folded, Some(SqlValue::Timestamptz(_))));
    }

    #[test]
    fn fold_unknown_function_returns_none() {
        let unknown = call("definitely_not_a_real_function", vec![]);
        assert!(fold_expr(&unknown).is_none());
    }

    #[test]
    fn fold_literal_arithmetic_still_works() {
        let sum = binary(SqlValue::Int(2), BinaryOp::Add, SqlValue::Int(3));
        assert_eq!(fold_expr(&sum), Some(SqlValue::Int(5)));
    }

    #[test]
    fn fold_column_ref_returns_none() {
        let column = SqlExpr::Column {
            table: None,
            name: "name".into(),
        };
        assert!(fold_expr(&column).is_none());
    }

    #[test]
    fn fold_decimal_literal() {
        let price = Decimal::new(12345, 2); // 123.45
        assert_eq!(
            fold_expr(&lit(SqlValue::Decimal(price))),
            Some(SqlValue::Decimal(price))
        );
    }

    #[test]
    fn fold_decimal_addition() {
        let sum = binary(
            SqlValue::Decimal(Decimal::new(12345, 2)), // 123.45
            BinaryOp::Add,
            SqlValue::Decimal(Decimal::new(45678, 2)), // 456.78
        );
        let expected = Decimal::new(58023, 2); // 580.23
        assert_eq!(fold_expr(&sum), Some(SqlValue::Decimal(expected)));
    }

    #[test]
    fn fold_decimal_negation() {
        let hundred = Decimal::new(100, 0);
        assert_eq!(
            fold_expr(&neg(SqlValue::Decimal(hundred))),
            Some(SqlValue::Decimal(-hundred))
        );
    }

    #[test]
    fn fold_st_geohash() {
        let expr = call(
            "st_geohash",
            vec![
                neg(SqlValue::Float(122.4)),
                lit(SqlValue::Float(37.8)),
                lit(SqlValue::Int(6)),
            ],
        );
        match fold_expr(&expr) {
            Some(SqlValue::String(ref s)) if !s.is_empty() => {}
            other => panic!("expected non-empty SqlValue::String, got {other:?}"),
        }
    }

    #[test]
    fn fold_st_distance_nested_st_point() {
        let point = |lng: f64, lat: f64| {
            call(
                "st_point",
                vec![lit(SqlValue::Float(lng)), lit(SqlValue::Float(lat))],
            )
        };
        let expr = call(
            "st_distance",
            vec![point(-122.4, 37.8), point(-87.6, 41.8)],
        );
        match fold_expr(&expr) {
            Some(SqlValue::Float(d)) => {
                assert!(d > 0.0, "distance should be positive, got {d}");
            }
            other => panic!("expected SqlValue::Float, got {other:?}"),
        }
    }

    #[test]
    fn fold_decimal_int_widening() {
        let sum = binary(
            SqlValue::Decimal(Decimal::new(100, 0)), // 100
            BinaryOp::Add,
            SqlValue::Int(50),
        );
        assert_eq!(
            fold_expr(&sum),
            Some(SqlValue::Decimal(Decimal::new(150, 0)))
        );
    }
}