Skip to main content

nodedb_sql/planner/
const_fold.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Plan-time constant folding for `SqlExpr`.
4//!
//! Evaluates literal expressions and registered scalar functions whose
//! arguments are themselves constant (e.g. `now()`, `current_timestamp`,
//! `date_add(now(), '1h')`) at plan time. Function calls are dispatched
//! through the shared `nodedb_query::functions::eval_function` evaluator.
9//!
10//! This keeps the bare-`SELECT` projection path, the `INSERT`/`UPSERT`
11//! `VALUES` path, and any future default-expression paths from drifting
12//! apart — they all reach the same evaluator that the Data Plane uses
13//! for column-reference evaluation.
14//!
//! Semantics: `now()` and `current_timestamp` snapshot once per statement.
//! The SQL standard requires `CURRENT_TIMESTAMP` to return the same value
//! for every row of a single statement; folding at plan time satisfies that
//! contract and is cheaper than per-row runtime dispatch. Postgres goes
//! further and returns the same value for the whole transaction — folding
//! alone does not guarantee that, because statements folded at different
//! times observe different clock readings.
21
22use std::sync::LazyLock;
23
24use nodedb_types::Value;
25use sonic_rs;
26
27use crate::functions::registry::{FunctionCategory, FunctionRegistry};
28use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};
29
30/// Process-wide default registry. Used by call sites that don't already
31/// thread a `FunctionRegistry` through (e.g. the DML `VALUES` path).
32static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);
33
34/// Access the shared default registry.
35pub fn default_registry() -> &'static FunctionRegistry {
36    &DEFAULT_REGISTRY
37}
38
39/// Convenience wrapper around [`fold_constant`] using the default registry.
40pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
41    fold_constant(expr, default_registry())
42}
43
44/// Fold a `SqlExpr` to a literal `SqlValue` at plan time, or return
45/// `None` if the expression depends on row/runtime state (column refs,
46/// subqueries, unknown functions, etc.).
47pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
48    match expr {
49        SqlExpr::Literal(v) => Some(v.clone()),
50        SqlExpr::UnaryOp {
51            op: UnaryOp::Neg,
52            expr,
53        } => match fold_constant(expr, registry)? {
54            SqlValue::Int(i) => Some(SqlValue::Int(-i)),
55            SqlValue::Float(f) => Some(SqlValue::Float(-f)),
56            SqlValue::Decimal(d) => Some(SqlValue::Decimal(-d)),
57            _ => None,
58        },
59        SqlExpr::BinaryOp { left, op, right } => {
60            let l = fold_constant(left, registry)?;
61            let r = fold_constant(right, registry)?;
62            fold_binary(l, *op, r)
63        }
64        SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
65        SqlExpr::Cast { expr, to_type } => fold_cast(fold_constant(expr, registry)?, to_type),
66        _ => None,
67    }
68}
69
70/// Fold a CAST at plan time. Only applies when the inner expression is already
71/// a constant. The `to_type` string comes from sqlparser's `format!("{data_type}")`
72/// output, so parameterised types like `NUMERIC(5,1)` must be matched by prefix.
73fn fold_cast(inner: SqlValue, to_type: &str) -> Option<SqlValue> {
74    let upper = to_type.to_uppercase();
75    // Strip any precision/scale suffix: "NUMERIC(5,1)" → "NUMERIC".
76    let base = upper
77        .split('(')
78        .next()
79        .map(str::trim)
80        .unwrap_or(upper.as_str());
81
82    match base {
83        "NUMERIC" | "DECIMAL" => match inner {
84            SqlValue::Decimal(d) => Some(SqlValue::Decimal(d)),
85            SqlValue::Int(i) => Some(SqlValue::Decimal(rust_decimal::Decimal::from(i))),
86            SqlValue::Float(f) => rust_decimal::Decimal::try_from(f)
87                .ok()
88                .map(SqlValue::Decimal),
89            SqlValue::String(s) => rust_decimal::Decimal::from_str_exact(&s)
90                .ok()
91                .map(SqlValue::Decimal),
92            _ => None,
93        },
94        "INTEGER" | "INT" | "BIGINT" | "SMALLINT" | "INT2" | "INT4" | "INT8" => match inner {
95            SqlValue::Int(i) => Some(SqlValue::Int(i)),
96            SqlValue::Decimal(d) => {
97                rust_decimal::prelude::ToPrimitive::to_i64(&d).map(SqlValue::Int)
98            }
99            SqlValue::Float(f) => {
100                if f.is_finite() {
101                    Some(SqlValue::Int(f as i64))
102                } else {
103                    None
104                }
105            }
106            SqlValue::String(s) => s.parse::<i64>().ok().map(SqlValue::Int),
107            _ => None,
108        },
109        "FLOAT" | "DOUBLE" | "REAL" | "FLOAT4" | "FLOAT8" | "DOUBLE PRECISION" => match inner {
110            SqlValue::Float(f) => Some(SqlValue::Float(f)),
111            SqlValue::Int(i) => Some(SqlValue::Float(i as f64)),
112            SqlValue::Decimal(d) => {
113                rust_decimal::prelude::ToPrimitive::to_f64(&d).map(SqlValue::Float)
114            }
115            SqlValue::String(s) => s.parse::<f64>().ok().map(SqlValue::Float),
116            _ => None,
117        },
118        "TEXT" | "VARCHAR" | "CHAR" | "CHARACTER VARYING" | "CHARACTER" | "BPCHAR" => match inner {
119            SqlValue::String(s) => Some(SqlValue::String(s)),
120            SqlValue::Int(i) => Some(SqlValue::String(i.to_string())),
121            SqlValue::Float(f) => Some(SqlValue::String(f.to_string())),
122            SqlValue::Decimal(d) => Some(SqlValue::String(d.to_string())),
123            SqlValue::Bool(b) => Some(SqlValue::String(b.to_string())),
124            _ => None,
125        },
126        "BOOL" | "BOOLEAN" => match inner {
127            SqlValue::Bool(b) => Some(SqlValue::Bool(b)),
128            SqlValue::Int(i) => Some(SqlValue::Bool(i != 0)),
129            SqlValue::String(s) => match s.to_lowercase().as_str() {
130                "true" | "t" | "yes" | "1" | "on" => Some(SqlValue::Bool(true)),
131                "false" | "f" | "no" | "0" | "off" => Some(SqlValue::Bool(false)),
132                _ => None,
133            },
134            _ => None,
135        },
136        // `JSON` / `JSONB` — JSON values live internally as their text form
137        // in `SqlValue::String`; the write path parses JSON-looking strings
138        // into document structure. The cast elides to the inner value's JSON
139        // text (mirrors the `::tsvector` / `::tsquery` elision in the resolver).
140        "JSON" | "JSONB" => match inner {
141            SqlValue::String(s) => Some(SqlValue::String(s)),
142            SqlValue::Int(i) => Some(SqlValue::String(i.to_string())),
143            SqlValue::Float(f) => Some(SqlValue::String(f.to_string())),
144            SqlValue::Decimal(d) => Some(SqlValue::String(d.to_string())),
145            SqlValue::Bool(b) => Some(SqlValue::String(b.to_string())),
146            SqlValue::Null => Some(SqlValue::Null),
147            _ => None,
148        },
149        _ => None,
150    }
151}
152
153fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
154    Some(match (l, op, r) {
155        // Int × Int arithmetic.
156        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
157        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
158        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
159        // Float × Float arithmetic.
160        (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
161        (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
162        (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
163        // Decimal × Decimal arithmetic.
164        (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
165            SqlValue::Decimal(a.checked_add(b)?)
166        }
167        (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
168            SqlValue::Decimal(a.checked_sub(b)?)
169        }
170        (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
171            SqlValue::Decimal(a.checked_mul(b)?)
172        }
173        (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
174            SqlValue::Decimal(a.checked_div(b)?)
175        }
176        // Decimal × Int widening (Int promotes to Decimal).
177        (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Int(b)) => {
178            SqlValue::Decimal(a.checked_add(rust_decimal::Decimal::from(b))?)
179        }
180        (SqlValue::Int(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
181            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_add(b)?)
182        }
183        (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Int(b)) => {
184            SqlValue::Decimal(a.checked_sub(rust_decimal::Decimal::from(b))?)
185        }
186        (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
187            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_sub(b)?)
188        }
189        (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Int(b)) => {
190            SqlValue::Decimal(a.checked_mul(rust_decimal::Decimal::from(b))?)
191        }
192        (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
193            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_mul(b)?)
194        }
195        (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Int(b)) => {
196            SqlValue::Decimal(a.checked_div(rust_decimal::Decimal::from(b))?)
197        }
198        (SqlValue::Int(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
199            SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_div(b)?)
200        }
201        // String concat.
202        (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
203            SqlValue::String(format!("{a}{b}"))
204        }
205        _ => return None,
206    })
207}
208
209/// Fold a function call by recursively folding its arguments, dispatching
210/// through the shared scalar evaluator, and converting the result back to
211/// `SqlValue`. Only folds functions that are present in `registry`, so
212/// callers can distinguish "unknown function" from "known function, all
213/// args folded".
214pub fn fold_function_call(
215    name: &str,
216    args: &[SqlExpr],
217    registry: &FunctionRegistry,
218) -> Option<SqlValue> {
219    // Gate on registry so unknown-function paths keep their existing
220    // fallbacks instead of collapsing to SqlValue::Null. Aggregates and
221    // window functions aren't foldable — they need a row stream.
222    let meta = registry.lookup(name)?;
223    if matches!(
224        meta.category,
225        FunctionCategory::Aggregate | FunctionCategory::Window
226    ) {
227        return None;
228    }
229
230    let folded_args: Vec<Value> = args
231        .iter()
232        .map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
233        .collect::<Option<_>>()?;
234
235    let result = nodedb_query::functions::eval_function(&name.to_lowercase(), &folded_args);
236    Some(ndb_to_sql_value(result))
237}
238
239fn sql_to_ndb_value(v: SqlValue) -> Value {
240    match v {
241        SqlValue::Null => Value::Null,
242        SqlValue::Bool(b) => Value::Bool(b),
243        SqlValue::Int(i) => Value::Integer(i),
244        SqlValue::Float(f) => Value::Float(f),
245        SqlValue::Decimal(d) => Value::Decimal(d),
246        SqlValue::String(s) => Value::String(s),
247        SqlValue::Bytes(b) => Value::Bytes(b),
248        SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
249        SqlValue::Timestamp(dt) => Value::NaiveDateTime(dt),
250        SqlValue::Timestamptz(dt) => Value::DateTime(dt),
251    }
252}
253
254fn ndb_to_sql_value(v: Value) -> SqlValue {
255    match v {
256        Value::Null => SqlValue::Null,
257        Value::Bool(b) => SqlValue::Bool(b),
258        Value::Integer(i) => SqlValue::Int(i),
259        Value::Float(f) => SqlValue::Float(f),
260        Value::String(s) => SqlValue::String(s),
261        Value::Bytes(b) => SqlValue::Bytes(b),
262        Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
263        // TZ-aware DateTime → Timestamptz; naive → Timestamp.
264        Value::DateTime(dt) => SqlValue::Timestamptz(dt),
265        Value::NaiveDateTime(dt) => SqlValue::Timestamp(dt),
266        Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
267        Value::Duration(d) => SqlValue::String(d.to_human()),
268        Value::Decimal(d) => SqlValue::Decimal(d),
269        // Geometry and Object values are serialized to JSON strings so that
270        // nested function calls like ST_Distance(ST_Point(...), ST_Point(...))
271        // survive the SqlValue round-trip. The geo evaluator's geom_arg helper
272        // recovers Geometry from a GeoJSON string; Object results (e.g. from
273        // ST_GeoHashDecode) reach the client as a JSON-encoded string column.
274        Value::Geometry(g) => sonic_rs::to_string(&g)
275            .map(SqlValue::String)
276            .unwrap_or(SqlValue::Null),
277        Value::Object(map) => sonic_rs::to_string(&map)
278            .map(SqlValue::String)
279            .unwrap_or(SqlValue::Null),
280        // Structured and opaque types collapse to Null — callers that
281        // need these go through the runtime expression path, not folding.
282        Value::Set(_) | Value::Range { .. } | Value::Record { .. } | Value::ArrayCell(_) => {
283            SqlValue::Null
284        }
285        // Value is #[non_exhaustive]; future variants collapse to Null in the
286        // constant-folding path — runtime expression evaluation handles them.
287        _ => SqlValue::Null,
288    }
289}
290
#[cfg(test)]
mod tests {
    //! Unit tests for plan-time constant folding. Each test builds a fresh
    //! `FunctionRegistry` and calls `fold_constant` directly.
    use super::*;

    /// `now()` folds to a time-zone-aware timestamp taken from the real clock.
    #[test]
    fn fold_now_produces_timestamptz() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "now".into(),
            args: vec![],
            distinct: false,
        };
        let val = fold_constant(&expr, &registry).expect("now() should fold");
        match val {
            SqlValue::Timestamptz(dt) => {
                // Sanity: a real clock reading is strictly after the Unix
                // epoch, so `micros` must be positive.
                assert!(dt.micros > 0, "expected post-epoch timestamp, got micros=0");
            }
            other => panic!("expected SqlValue::Timestamptz, got {other:?}"),
        }
    }

    /// `current_timestamp`, expressed here as a zero-arg function call,
    /// folds to `SqlValue::Timestamptz`.
    #[test]
    fn fold_current_timestamp_produces_timestamptz() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "current_timestamp".into(),
            args: vec![],
            distinct: false,
        };
        assert!(matches!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Timestamptz(_))
        ));
    }

    /// Names absent from the registry are left unfolded (`None`), not
    /// collapsed to NULL.
    #[test]
    fn fold_unknown_function_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "definitely_not_a_real_function".into(),
            args: vec![],
            distinct: false,
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    /// Plain integer literal arithmetic: `2 + 3` folds to `5`.
    #[test]
    fn fold_literal_arithmetic_still_works() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Int(2))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(3))),
        };
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Int(5)));
    }

    /// Column references depend on row state and must not fold.
    #[test]
    fn fold_column_ref_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Column {
            table: None,
            name: "name".into(),
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    /// A decimal literal passes through unchanged.
    #[test]
    fn fold_decimal_literal() {
        let registry = FunctionRegistry::new();
        let d = rust_decimal::Decimal::new(12345, 2); // 123.45
        let expr = SqlExpr::Literal(SqlValue::Decimal(d));
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Decimal(d)));
    }

    /// Decimal + Decimal stays exact: 123.45 + 456.78 = 580.23.
    #[test]
    fn fold_decimal_addition() {
        use rust_decimal::Decimal;
        let registry = FunctionRegistry::new();
        let a = Decimal::new(12345, 2); // 123.45
        let b = Decimal::new(45678, 2); // 456.78
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(a))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Decimal(b))),
        };
        let expected = Decimal::new(58023, 2); // 580.23
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Decimal(expected))
        );
    }

    /// Unary minus applied to a decimal literal.
    #[test]
    fn fold_decimal_negation() {
        use rust_decimal::Decimal;
        let registry = FunctionRegistry::new();
        let d = Decimal::new(100, 0);
        let expr = SqlExpr::UnaryOp {
            op: crate::types::UnaryOp::Neg,
            expr: Box::new(SqlExpr::Literal(SqlValue::Decimal(d))),
        };
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Decimal(-d)));
    }

    /// `st_geohash(-122.4, 37.8, 6)` folds to a non-empty string. The first
    /// argument is a negated float literal, so argument folding is covered too.
    #[test]
    fn fold_st_geohash() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "st_geohash".into(),
            args: vec![
                SqlExpr::UnaryOp {
                    op: UnaryOp::Neg,
                    expr: Box::new(SqlExpr::Literal(SqlValue::Float(122.4))),
                },
                SqlExpr::Literal(SqlValue::Float(37.8)),
                SqlExpr::Literal(SqlValue::Int(6)),
            ],
            distinct: false,
        };
        let v = fold_constant(&expr, &registry);
        match v {
            Some(SqlValue::String(ref s)) if !s.is_empty() => {}
            other => panic!("expected non-empty SqlValue::String, got {other:?}"),
        }
    }

    /// Nested geometry calls fold end to end: each inner `st_point` result
    /// passes through `ndb_to_sql_value` and back into `st_distance`, which
    /// must produce a positive distance.
    #[test]
    fn fold_st_distance_nested_st_point() {
        let registry = FunctionRegistry::new();
        let make_point = |lng: f64, lat: f64| SqlExpr::Function {
            name: "st_point".into(),
            args: vec![
                SqlExpr::Literal(SqlValue::Float(lng)),
                SqlExpr::Literal(SqlValue::Float(lat)),
            ],
            distinct: false,
        };
        let expr = SqlExpr::Function {
            name: "st_distance".into(),
            args: vec![make_point(-122.4, 37.8), make_point(-87.6, 41.8)],
            distinct: false,
        };
        let v = fold_constant(&expr, &registry);
        match v {
            Some(SqlValue::Float(d)) => {
                assert!(d > 0.0, "distance should be positive, got {d}");
            }
            other => panic!("expected SqlValue::Float, got {other:?}"),
        }
    }

    /// Decimal + Int widens the Int to Decimal: 100 + 50 = 150.
    #[test]
    fn fold_decimal_int_widening() {
        use rust_decimal::Decimal;
        let registry = FunctionRegistry::new();
        let d = Decimal::new(100, 0); // 100
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(d))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(50))),
        };
        let expected = Decimal::new(150, 0);
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Decimal(expected))
        );
    }
}