1use std::sync::LazyLock;
23
24use nodedb_types::Value;
25use sonic_rs;
26
27use crate::functions::registry::{FunctionCategory, FunctionRegistry};
28use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};
29
30static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);
33
34pub fn default_registry() -> &'static FunctionRegistry {
36 &DEFAULT_REGISTRY
37}
38
39pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
41 fold_constant(expr, default_registry())
42}
43
44pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
48 match expr {
49 SqlExpr::Literal(v) => Some(v.clone()),
50 SqlExpr::UnaryOp {
51 op: UnaryOp::Neg,
52 expr,
53 } => match fold_constant(expr, registry)? {
54 SqlValue::Int(i) => Some(SqlValue::Int(-i)),
55 SqlValue::Float(f) => Some(SqlValue::Float(-f)),
56 SqlValue::Decimal(d) => Some(SqlValue::Decimal(-d)),
57 _ => None,
58 },
59 SqlExpr::BinaryOp { left, op, right } => {
60 let l = fold_constant(left, registry)?;
61 let r = fold_constant(right, registry)?;
62 fold_binary(l, *op, r)
63 }
64 SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
65 SqlExpr::Cast { expr, to_type } => fold_cast(fold_constant(expr, registry)?, to_type),
66 _ => None,
67 }
68}
69
70fn fold_cast(inner: SqlValue, to_type: &str) -> Option<SqlValue> {
74 let upper = to_type.to_uppercase();
75 let base = upper
77 .split('(')
78 .next()
79 .map(str::trim)
80 .unwrap_or(upper.as_str());
81
82 match base {
83 "NUMERIC" | "DECIMAL" => match inner {
84 SqlValue::Decimal(d) => Some(SqlValue::Decimal(d)),
85 SqlValue::Int(i) => Some(SqlValue::Decimal(rust_decimal::Decimal::from(i))),
86 SqlValue::Float(f) => rust_decimal::Decimal::try_from(f)
87 .ok()
88 .map(SqlValue::Decimal),
89 SqlValue::String(s) => rust_decimal::Decimal::from_str_exact(&s)
90 .ok()
91 .map(SqlValue::Decimal),
92 _ => None,
93 },
94 "INTEGER" | "INT" | "BIGINT" | "SMALLINT" | "INT2" | "INT4" | "INT8" => match inner {
95 SqlValue::Int(i) => Some(SqlValue::Int(i)),
96 SqlValue::Decimal(d) => {
97 rust_decimal::prelude::ToPrimitive::to_i64(&d).map(SqlValue::Int)
98 }
99 SqlValue::Float(f) => {
100 if f.is_finite() {
101 Some(SqlValue::Int(f as i64))
102 } else {
103 None
104 }
105 }
106 SqlValue::String(s) => s.parse::<i64>().ok().map(SqlValue::Int),
107 _ => None,
108 },
109 "FLOAT" | "DOUBLE" | "REAL" | "FLOAT4" | "FLOAT8" | "DOUBLE PRECISION" => match inner {
110 SqlValue::Float(f) => Some(SqlValue::Float(f)),
111 SqlValue::Int(i) => Some(SqlValue::Float(i as f64)),
112 SqlValue::Decimal(d) => {
113 rust_decimal::prelude::ToPrimitive::to_f64(&d).map(SqlValue::Float)
114 }
115 SqlValue::String(s) => s.parse::<f64>().ok().map(SqlValue::Float),
116 _ => None,
117 },
118 "TEXT" | "VARCHAR" | "CHAR" | "CHARACTER VARYING" | "CHARACTER" | "BPCHAR" => match inner {
119 SqlValue::String(s) => Some(SqlValue::String(s)),
120 SqlValue::Int(i) => Some(SqlValue::String(i.to_string())),
121 SqlValue::Float(f) => Some(SqlValue::String(f.to_string())),
122 SqlValue::Decimal(d) => Some(SqlValue::String(d.to_string())),
123 SqlValue::Bool(b) => Some(SqlValue::String(b.to_string())),
124 _ => None,
125 },
126 "BOOL" | "BOOLEAN" => match inner {
127 SqlValue::Bool(b) => Some(SqlValue::Bool(b)),
128 SqlValue::Int(i) => Some(SqlValue::Bool(i != 0)),
129 SqlValue::String(s) => match s.to_lowercase().as_str() {
130 "true" | "t" | "yes" | "1" | "on" => Some(SqlValue::Bool(true)),
131 "false" | "f" | "no" | "0" | "off" => Some(SqlValue::Bool(false)),
132 _ => None,
133 },
134 _ => None,
135 },
136 _ => None,
137 }
138}
139
140fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
141 Some(match (l, op, r) {
142 (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
144 (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
145 (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
146 (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
148 (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
149 (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
150 (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
152 SqlValue::Decimal(a.checked_add(b)?)
153 }
154 (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
155 SqlValue::Decimal(a.checked_sub(b)?)
156 }
157 (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
158 SqlValue::Decimal(a.checked_mul(b)?)
159 }
160 (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
161 SqlValue::Decimal(a.checked_div(b)?)
162 }
163 (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Int(b)) => {
165 SqlValue::Decimal(a.checked_add(rust_decimal::Decimal::from(b))?)
166 }
167 (SqlValue::Int(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
168 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_add(b)?)
169 }
170 (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Int(b)) => {
171 SqlValue::Decimal(a.checked_sub(rust_decimal::Decimal::from(b))?)
172 }
173 (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
174 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_sub(b)?)
175 }
176 (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Int(b)) => {
177 SqlValue::Decimal(a.checked_mul(rust_decimal::Decimal::from(b))?)
178 }
179 (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
180 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_mul(b)?)
181 }
182 (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Int(b)) => {
183 SqlValue::Decimal(a.checked_div(rust_decimal::Decimal::from(b))?)
184 }
185 (SqlValue::Int(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
186 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_div(b)?)
187 }
188 (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
190 SqlValue::String(format!("{a}{b}"))
191 }
192 _ => return None,
193 })
194}
195
196pub fn fold_function_call(
202 name: &str,
203 args: &[SqlExpr],
204 registry: &FunctionRegistry,
205) -> Option<SqlValue> {
206 let meta = registry.lookup(name)?;
210 if matches!(
211 meta.category,
212 FunctionCategory::Aggregate | FunctionCategory::Window
213 ) {
214 return None;
215 }
216
217 let folded_args: Vec<Value> = args
218 .iter()
219 .map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
220 .collect::<Option<_>>()?;
221
222 let result = nodedb_query::functions::eval_function(&name.to_lowercase(), &folded_args);
223 Some(ndb_to_sql_value(result))
224}
225
226fn sql_to_ndb_value(v: SqlValue) -> Value {
227 match v {
228 SqlValue::Null => Value::Null,
229 SqlValue::Bool(b) => Value::Bool(b),
230 SqlValue::Int(i) => Value::Integer(i),
231 SqlValue::Float(f) => Value::Float(f),
232 SqlValue::Decimal(d) => Value::Decimal(d),
233 SqlValue::String(s) => Value::String(s),
234 SqlValue::Bytes(b) => Value::Bytes(b),
235 SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
236 SqlValue::Timestamp(dt) => Value::NaiveDateTime(dt),
237 SqlValue::Timestamptz(dt) => Value::DateTime(dt),
238 }
239}
240
241fn ndb_to_sql_value(v: Value) -> SqlValue {
242 match v {
243 Value::Null => SqlValue::Null,
244 Value::Bool(b) => SqlValue::Bool(b),
245 Value::Integer(i) => SqlValue::Int(i),
246 Value::Float(f) => SqlValue::Float(f),
247 Value::String(s) => SqlValue::String(s),
248 Value::Bytes(b) => SqlValue::Bytes(b),
249 Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
250 Value::DateTime(dt) => SqlValue::Timestamptz(dt),
252 Value::NaiveDateTime(dt) => SqlValue::Timestamp(dt),
253 Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
254 Value::Duration(d) => SqlValue::String(d.to_human()),
255 Value::Decimal(d) => SqlValue::Decimal(d),
256 Value::Geometry(g) => sonic_rs::to_string(&g)
262 .map(SqlValue::String)
263 .unwrap_or(SqlValue::Null),
264 Value::Object(map) => sonic_rs::to_string(&map)
265 .map(SqlValue::String)
266 .unwrap_or(SqlValue::Null),
267 Value::Set(_) | Value::Range { .. } | Value::Record { .. } | Value::ArrayCell(_) => {
270 SqlValue::Null
271 }
272 _ => SqlValue::Null,
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn fold_now_produces_timestamptz() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "now".into(),
            args: vec![],
            distinct: false,
        };
        let val = fold_constant(&expr, &registry).expect("now() should fold");
        match val {
            SqlValue::Timestamptz(dt) => {
                assert!(
                    dt.micros > 0,
                    "expected post-epoch timestamp, got micros={}",
                    dt.micros
                );
            }
            other => panic!("expected SqlValue::Timestamptz, got {other:?}"),
        }
    }

    #[test]
    fn fold_current_timestamp_produces_timestamptz() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "current_timestamp".into(),
            args: vec![],
            distinct: false,
        };
        assert!(matches!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Timestamptz(_))
        ));
    }

    #[test]
    fn fold_unknown_function_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "definitely_not_a_real_function".into(),
            args: vec![],
            distinct: false,
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_literal_arithmetic_still_works() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Int(2))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(3))),
        };
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Int(5)));
    }

    #[test]
    fn fold_int_overflow_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Int(i64::MAX))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(1))),
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_string_concat() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::String("foo".into()))),
            op: BinaryOp::Concat,
            right: Box::new(SqlExpr::Literal(SqlValue::String("bar".into()))),
        };
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::String("foobar".into()))
        );
    }

    #[test]
    fn fold_column_ref_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Column {
            table: None,
            name: "name".into(),
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_decimal_literal() {
        let registry = FunctionRegistry::new();
        // 123.45
        let d = rust_decimal::Decimal::new(12345, 2);
        let expr = SqlExpr::Literal(SqlValue::Decimal(d));
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Decimal(d)));
    }

    #[test]
    fn fold_decimal_addition() {
        use rust_decimal::Decimal;
        let registry = FunctionRegistry::new();
        // 123.45 + 456.78
        let a = Decimal::new(12345, 2);
        let b = Decimal::new(45678, 2);
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(a))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Decimal(b))),
        };
        // 580.23
        let expected = Decimal::new(58023, 2);
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Decimal(expected))
        );
    }

    #[test]
    fn fold_decimal_negation() {
        use rust_decimal::Decimal;
        let registry = FunctionRegistry::new();
        let d = Decimal::new(100, 0);
        let expr = SqlExpr::UnaryOp {
            op: crate::types::UnaryOp::Neg,
            expr: Box::new(SqlExpr::Literal(SqlValue::Decimal(d))),
        };
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Decimal(-d)));
    }

    #[test]
    fn fold_st_geohash() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "st_geohash".into(),
            args: vec![
                SqlExpr::UnaryOp {
                    op: UnaryOp::Neg,
                    expr: Box::new(SqlExpr::Literal(SqlValue::Float(122.4))),
                },
                SqlExpr::Literal(SqlValue::Float(37.8)),
                SqlExpr::Literal(SqlValue::Int(6)),
            ],
            distinct: false,
        };
        let v = fold_constant(&expr, &registry);
        match v {
            Some(SqlValue::String(ref s)) if !s.is_empty() => {}
            other => panic!("expected non-empty SqlValue::String, got {other:?}"),
        }
    }

    #[test]
    fn fold_st_distance_nested_st_point() {
        let registry = FunctionRegistry::new();
        let make_point = |lng: f64, lat: f64| SqlExpr::Function {
            name: "st_point".into(),
            args: vec![
                SqlExpr::Literal(SqlValue::Float(lng)),
                SqlExpr::Literal(SqlValue::Float(lat)),
            ],
            distinct: false,
        };
        let expr = SqlExpr::Function {
            name: "st_distance".into(),
            args: vec![make_point(-122.4, 37.8), make_point(-87.6, 41.8)],
            distinct: false,
        };
        let v = fold_constant(&expr, &registry);
        match v {
            Some(SqlValue::Float(d)) => {
                assert!(d > 0.0, "distance should be positive, got {d}");
            }
            other => panic!("expected SqlValue::Float, got {other:?}"),
        }
    }

    #[test]
    fn fold_decimal_int_widening() {
        use rust_decimal::Decimal;
        let registry = FunctionRegistry::new();
        // Decimal 100 + Int 50 widens the Int and yields Decimal 150.
        let d = Decimal::new(100, 0);
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(d))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(50))),
        };
        let expected = Decimal::new(150, 0);
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Decimal(expected))
        );
    }
}