1use std::sync::LazyLock;
23
24use nodedb_types::Value;
25use sonic_rs;
26
27use crate::functions::registry::{FunctionCategory, FunctionRegistry};
28use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};
29
30static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);
33
34pub fn default_registry() -> &'static FunctionRegistry {
36 &DEFAULT_REGISTRY
37}
38
39pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
41 fold_constant(expr, default_registry())
42}
43
44pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
48 match expr {
49 SqlExpr::Literal(v) => Some(v.clone()),
50 SqlExpr::UnaryOp {
51 op: UnaryOp::Neg,
52 expr,
53 } => match fold_constant(expr, registry)? {
54 SqlValue::Int(i) => Some(SqlValue::Int(-i)),
55 SqlValue::Float(f) => Some(SqlValue::Float(-f)),
56 SqlValue::Decimal(d) => Some(SqlValue::Decimal(-d)),
57 _ => None,
58 },
59 SqlExpr::BinaryOp { left, op, right } => {
60 let l = fold_constant(left, registry)?;
61 let r = fold_constant(right, registry)?;
62 fold_binary(l, *op, r)
63 }
64 SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
65 SqlExpr::Cast { expr, to_type } => fold_cast(fold_constant(expr, registry)?, to_type),
66 _ => None,
67 }
68}
69
70fn fold_cast(inner: SqlValue, to_type: &str) -> Option<SqlValue> {
74 let upper = to_type.to_uppercase();
75 let base = upper
77 .split('(')
78 .next()
79 .map(str::trim)
80 .unwrap_or(upper.as_str());
81
82 match base {
83 "NUMERIC" | "DECIMAL" => match inner {
84 SqlValue::Decimal(d) => Some(SqlValue::Decimal(d)),
85 SqlValue::Int(i) => Some(SqlValue::Decimal(rust_decimal::Decimal::from(i))),
86 SqlValue::Float(f) => rust_decimal::Decimal::try_from(f)
87 .ok()
88 .map(SqlValue::Decimal),
89 SqlValue::String(s) => rust_decimal::Decimal::from_str_exact(&s)
90 .ok()
91 .map(SqlValue::Decimal),
92 _ => None,
93 },
94 "INTEGER" | "INT" | "BIGINT" | "SMALLINT" | "INT2" | "INT4" | "INT8" => match inner {
95 SqlValue::Int(i) => Some(SqlValue::Int(i)),
96 SqlValue::Decimal(d) => {
97 rust_decimal::prelude::ToPrimitive::to_i64(&d).map(SqlValue::Int)
98 }
99 SqlValue::Float(f) => {
100 if f.is_finite() {
101 Some(SqlValue::Int(f as i64))
102 } else {
103 None
104 }
105 }
106 SqlValue::String(s) => s.parse::<i64>().ok().map(SqlValue::Int),
107 _ => None,
108 },
109 "FLOAT" | "DOUBLE" | "REAL" | "FLOAT4" | "FLOAT8" | "DOUBLE PRECISION" => match inner {
110 SqlValue::Float(f) => Some(SqlValue::Float(f)),
111 SqlValue::Int(i) => Some(SqlValue::Float(i as f64)),
112 SqlValue::Decimal(d) => {
113 rust_decimal::prelude::ToPrimitive::to_f64(&d).map(SqlValue::Float)
114 }
115 SqlValue::String(s) => s.parse::<f64>().ok().map(SqlValue::Float),
116 _ => None,
117 },
118 "TEXT" | "VARCHAR" | "CHAR" | "CHARACTER VARYING" | "CHARACTER" | "BPCHAR" => match inner {
119 SqlValue::String(s) => Some(SqlValue::String(s)),
120 SqlValue::Int(i) => Some(SqlValue::String(i.to_string())),
121 SqlValue::Float(f) => Some(SqlValue::String(f.to_string())),
122 SqlValue::Decimal(d) => Some(SqlValue::String(d.to_string())),
123 SqlValue::Bool(b) => Some(SqlValue::String(b.to_string())),
124 _ => None,
125 },
126 "BOOL" | "BOOLEAN" => match inner {
127 SqlValue::Bool(b) => Some(SqlValue::Bool(b)),
128 SqlValue::Int(i) => Some(SqlValue::Bool(i != 0)),
129 SqlValue::String(s) => match s.to_lowercase().as_str() {
130 "true" | "t" | "yes" | "1" | "on" => Some(SqlValue::Bool(true)),
131 "false" | "f" | "no" | "0" | "off" => Some(SqlValue::Bool(false)),
132 _ => None,
133 },
134 _ => None,
135 },
136 "JSON" | "JSONB" => match inner {
141 SqlValue::String(s) => Some(SqlValue::String(s)),
142 SqlValue::Int(i) => Some(SqlValue::String(i.to_string())),
143 SqlValue::Float(f) => Some(SqlValue::String(f.to_string())),
144 SqlValue::Decimal(d) => Some(SqlValue::String(d.to_string())),
145 SqlValue::Bool(b) => Some(SqlValue::String(b.to_string())),
146 SqlValue::Null => Some(SqlValue::Null),
147 _ => None,
148 },
149 _ => None,
150 }
151}
152
153fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
154 Some(match (l, op, r) {
155 (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
157 (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
158 (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
159 (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
161 (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
162 (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
163 (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
165 SqlValue::Decimal(a.checked_add(b)?)
166 }
167 (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
168 SqlValue::Decimal(a.checked_sub(b)?)
169 }
170 (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
171 SqlValue::Decimal(a.checked_mul(b)?)
172 }
173 (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
174 SqlValue::Decimal(a.checked_div(b)?)
175 }
176 (SqlValue::Decimal(a), BinaryOp::Add, SqlValue::Int(b)) => {
178 SqlValue::Decimal(a.checked_add(rust_decimal::Decimal::from(b))?)
179 }
180 (SqlValue::Int(a), BinaryOp::Add, SqlValue::Decimal(b)) => {
181 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_add(b)?)
182 }
183 (SqlValue::Decimal(a), BinaryOp::Sub, SqlValue::Int(b)) => {
184 SqlValue::Decimal(a.checked_sub(rust_decimal::Decimal::from(b))?)
185 }
186 (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Decimal(b)) => {
187 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_sub(b)?)
188 }
189 (SqlValue::Decimal(a), BinaryOp::Mul, SqlValue::Int(b)) => {
190 SqlValue::Decimal(a.checked_mul(rust_decimal::Decimal::from(b))?)
191 }
192 (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Decimal(b)) => {
193 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_mul(b)?)
194 }
195 (SqlValue::Decimal(a), BinaryOp::Div, SqlValue::Int(b)) => {
196 SqlValue::Decimal(a.checked_div(rust_decimal::Decimal::from(b))?)
197 }
198 (SqlValue::Int(a), BinaryOp::Div, SqlValue::Decimal(b)) => {
199 SqlValue::Decimal(rust_decimal::Decimal::from(a).checked_div(b)?)
200 }
201 (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
203 SqlValue::String(format!("{a}{b}"))
204 }
205 _ => return None,
206 })
207}
208
209pub fn fold_function_call(
215 name: &str,
216 args: &[SqlExpr],
217 registry: &FunctionRegistry,
218) -> Option<SqlValue> {
219 let meta = registry.lookup(name)?;
223 if matches!(
224 meta.category,
225 FunctionCategory::Aggregate | FunctionCategory::Window
226 ) {
227 return None;
228 }
229
230 let folded_args: Vec<Value> = args
231 .iter()
232 .map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
233 .collect::<Option<_>>()?;
234
235 let result = nodedb_query::functions::eval_function(&name.to_lowercase(), &folded_args);
236 Some(ndb_to_sql_value(result))
237}
238
239fn sql_to_ndb_value(v: SqlValue) -> Value {
240 match v {
241 SqlValue::Null => Value::Null,
242 SqlValue::Bool(b) => Value::Bool(b),
243 SqlValue::Int(i) => Value::Integer(i),
244 SqlValue::Float(f) => Value::Float(f),
245 SqlValue::Decimal(d) => Value::Decimal(d),
246 SqlValue::String(s) => Value::String(s),
247 SqlValue::Bytes(b) => Value::Bytes(b),
248 SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
249 SqlValue::Timestamp(dt) => Value::NaiveDateTime(dt),
250 SqlValue::Timestamptz(dt) => Value::DateTime(dt),
251 }
252}
253
254fn ndb_to_sql_value(v: Value) -> SqlValue {
255 match v {
256 Value::Null => SqlValue::Null,
257 Value::Bool(b) => SqlValue::Bool(b),
258 Value::Integer(i) => SqlValue::Int(i),
259 Value::Float(f) => SqlValue::Float(f),
260 Value::String(s) => SqlValue::String(s),
261 Value::Bytes(b) => SqlValue::Bytes(b),
262 Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
263 Value::DateTime(dt) => SqlValue::Timestamptz(dt),
265 Value::NaiveDateTime(dt) => SqlValue::Timestamp(dt),
266 Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
267 Value::Duration(d) => SqlValue::String(d.to_human()),
268 Value::Decimal(d) => SqlValue::Decimal(d),
269 Value::Geometry(g) => sonic_rs::to_string(&g)
275 .map(SqlValue::String)
276 .unwrap_or(SqlValue::Null),
277 Value::Object(map) => sonic_rs::to_string(&map)
278 .map(SqlValue::String)
279 .unwrap_or(SqlValue::Null),
280 Value::Set(_) | Value::Range { .. } | Value::Record { .. } | Value::ArrayCell(_) => {
283 SqlValue::Null
284 }
285 _ => SqlValue::Null,
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::Decimal;

    #[test]
    fn fold_now_produces_timestamptz() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "now".into(),
            args: vec![],
            distinct: false,
        };
        let val = fold_constant(&expr, &registry).expect("now() should fold");
        match val {
            SqlValue::Timestamptz(dt) => {
                assert!(
                    dt.micros > 0,
                    "expected post-epoch timestamp, got micros={}",
                    dt.micros
                );
            }
            other => panic!("expected SqlValue::Timestamptz, got {other:?}"),
        }
    }

    #[test]
    fn fold_current_timestamp_produces_timestamptz() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "current_timestamp".into(),
            args: vec![],
            distinct: false,
        };
        assert!(matches!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Timestamptz(_))
        ));
    }

    #[test]
    fn fold_unknown_function_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "definitely_not_a_real_function".into(),
            args: vec![],
            distinct: false,
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_literal_arithmetic_still_works() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Int(2))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(3))),
        };
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Int(5)));
    }

    #[test]
    fn fold_int_overflow_is_not_folded() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Int(i64::MAX))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(1))),
        };
        // Checked arithmetic: overflow leaves the expression for runtime.
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_column_ref_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Column {
            table: None,
            name: "name".into(),
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_decimal_literal() {
        let registry = FunctionRegistry::new();
        let d = Decimal::new(12345, 2); // 123.45
        let expr = SqlExpr::Literal(SqlValue::Decimal(d));
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Decimal(d)));
    }

    #[test]
    fn fold_decimal_addition() {
        let registry = FunctionRegistry::new();
        let a = Decimal::new(12345, 2); // 123.45
        let b = Decimal::new(45678, 2); // 456.78
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(a))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Decimal(b))),
        };
        let expected = Decimal::new(58023, 2); // 580.23
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Decimal(expected))
        );
    }

    #[test]
    fn fold_decimal_division_by_zero_is_not_folded() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(Decimal::new(1, 0)))),
            op: BinaryOp::Div,
            right: Box::new(SqlExpr::Literal(SqlValue::Decimal(Decimal::new(0, 0)))),
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_decimal_negation() {
        let registry = FunctionRegistry::new();
        let d = Decimal::new(100, 0);
        let expr = SqlExpr::UnaryOp {
            op: UnaryOp::Neg,
            expr: Box::new(SqlExpr::Literal(SqlValue::Decimal(d))),
        };
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Decimal(-d)));
    }

    #[test]
    fn fold_st_geohash() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Function {
            name: "st_geohash".into(),
            args: vec![
                SqlExpr::UnaryOp {
                    op: UnaryOp::Neg,
                    expr: Box::new(SqlExpr::Literal(SqlValue::Float(122.4))),
                },
                SqlExpr::Literal(SqlValue::Float(37.8)),
                SqlExpr::Literal(SqlValue::Int(6)),
            ],
            distinct: false,
        };
        let v = fold_constant(&expr, &registry);
        match v {
            Some(SqlValue::String(ref s)) if !s.is_empty() => {}
            other => panic!("expected non-empty SqlValue::String, got {other:?}"),
        }
    }

    #[test]
    fn fold_st_distance_nested_st_point() {
        let registry = FunctionRegistry::new();
        let make_point = |lng: f64, lat: f64| SqlExpr::Function {
            name: "st_point".into(),
            args: vec![
                SqlExpr::Literal(SqlValue::Float(lng)),
                SqlExpr::Literal(SqlValue::Float(lat)),
            ],
            distinct: false,
        };
        let expr = SqlExpr::Function {
            name: "st_distance".into(),
            args: vec![make_point(-122.4, 37.8), make_point(-87.6, 41.8)],
            distinct: false,
        };
        let v = fold_constant(&expr, &registry);
        match v {
            Some(SqlValue::Float(d)) => {
                assert!(d > 0.0, "distance should be positive, got {d}");
            }
            other => panic!("expected SqlValue::Float, got {other:?}"),
        }
    }

    #[test]
    fn fold_decimal_int_widening() {
        let registry = FunctionRegistry::new();
        let d = Decimal::new(100, 0);
        let expr = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(SqlValue::Decimal(d))),
            op: BinaryOp::Add,
            right: Box::new(SqlExpr::Literal(SqlValue::Int(50))),
        };
        let expected = Decimal::new(150, 0);
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::Decimal(expected))
        );
    }
}
460}