1use std::sync::LazyLock;
21
22use nodedb_types::Value;
23
24use crate::functions::registry::{FunctionCategory, FunctionRegistry};
25use crate::types::{BinaryOp, SqlExpr, SqlValue, UnaryOp};
26
27static DEFAULT_REGISTRY: LazyLock<FunctionRegistry> = LazyLock::new(FunctionRegistry::new);
30
31pub fn default_registry() -> &'static FunctionRegistry {
33 &DEFAULT_REGISTRY
34}
35
36pub fn fold_constant_default(expr: &SqlExpr) -> Option<SqlValue> {
38 fold_constant(expr, default_registry())
39}
40
41pub fn fold_constant(expr: &SqlExpr, registry: &FunctionRegistry) -> Option<SqlValue> {
45 match expr {
46 SqlExpr::Literal(v) => Some(v.clone()),
47 SqlExpr::UnaryOp {
48 op: UnaryOp::Neg,
49 expr,
50 } => match fold_constant(expr, registry)? {
51 SqlValue::Int(i) => Some(SqlValue::Int(-i)),
52 SqlValue::Float(f) => Some(SqlValue::Float(-f)),
53 _ => None,
54 },
55 SqlExpr::BinaryOp { left, op, right } => {
56 let l = fold_constant(left, registry)?;
57 let r = fold_constant(right, registry)?;
58 fold_binary(l, *op, r)
59 }
60 SqlExpr::Function { name, args, .. } => fold_function_call(name, args, registry),
61 _ => None,
62 }
63}
64
65fn fold_binary(l: SqlValue, op: BinaryOp, r: SqlValue) -> Option<SqlValue> {
66 Some(match (l, op, r) {
67 (SqlValue::Int(a), BinaryOp::Add, SqlValue::Int(b)) => SqlValue::Int(a.checked_add(b)?),
68 (SqlValue::Int(a), BinaryOp::Sub, SqlValue::Int(b)) => SqlValue::Int(a.checked_sub(b)?),
69 (SqlValue::Int(a), BinaryOp::Mul, SqlValue::Int(b)) => SqlValue::Int(a.checked_mul(b)?),
70 (SqlValue::Float(a), BinaryOp::Add, SqlValue::Float(b)) => SqlValue::Float(a + b),
71 (SqlValue::Float(a), BinaryOp::Sub, SqlValue::Float(b)) => SqlValue::Float(a - b),
72 (SqlValue::Float(a), BinaryOp::Mul, SqlValue::Float(b)) => SqlValue::Float(a * b),
73 (SqlValue::String(a), BinaryOp::Concat, SqlValue::String(b)) => {
74 SqlValue::String(format!("{a}{b}"))
75 }
76 _ => return None,
77 })
78}
79
80pub fn fold_function_call(
86 name: &str,
87 args: &[SqlExpr],
88 registry: &FunctionRegistry,
89) -> Option<SqlValue> {
90 let meta = registry.lookup(name)?;
94 if matches!(
95 meta.category,
96 FunctionCategory::Aggregate | FunctionCategory::Window
97 ) {
98 return None;
99 }
100
101 let folded_args: Vec<Value> = args
102 .iter()
103 .map(|a| fold_constant(a, registry).map(sql_to_ndb_value))
104 .collect::<Option<_>>()?;
105
106 let result = nodedb_query::functions::eval_function(name, &folded_args);
107 Some(ndb_to_sql_value(result))
108}
109
110fn sql_to_ndb_value(v: SqlValue) -> Value {
111 match v {
112 SqlValue::Null => Value::Null,
113 SqlValue::Bool(b) => Value::Bool(b),
114 SqlValue::Int(i) => Value::Integer(i),
115 SqlValue::Float(f) => Value::Float(f),
116 SqlValue::String(s) => Value::String(s),
117 SqlValue::Bytes(b) => Value::Bytes(b),
118 SqlValue::Array(a) => Value::Array(a.into_iter().map(sql_to_ndb_value).collect()),
119 }
120}
121
122fn ndb_to_sql_value(v: Value) -> SqlValue {
123 match v {
124 Value::Null => SqlValue::Null,
125 Value::Bool(b) => SqlValue::Bool(b),
126 Value::Integer(i) => SqlValue::Int(i),
127 Value::Float(f) => SqlValue::Float(f),
128 Value::String(s) => SqlValue::String(s),
129 Value::Bytes(b) => SqlValue::Bytes(b),
130 Value::Array(a) => SqlValue::Array(a.into_iter().map(ndb_to_sql_value).collect()),
131 Value::DateTime(dt) => SqlValue::String(dt.to_iso8601()),
132 Value::Uuid(s) | Value::Ulid(s) | Value::Regex(s) => SqlValue::String(s),
133 Value::Duration(d) => SqlValue::String(d.to_human()),
134 Value::Decimal(d) => SqlValue::String(d.to_string()),
135 Value::Object(_)
138 | Value::Geometry(_)
139 | Value::Set(_)
140 | Value::Range { .. }
141 | Value::Record { .. } => SqlValue::Null,
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-argument, non-DISTINCT call expression for `name`.
    fn call(name: &str) -> SqlExpr {
        SqlExpr::Function {
            name: name.into(),
            args: vec![],
            distinct: false,
        }
    }

    /// Builds `left op right` from two literal operands.
    fn binary(left: SqlValue, op: BinaryOp, right: SqlValue) -> SqlExpr {
        SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Literal(left)),
            op,
            right: Box::new(SqlExpr::Literal(right)),
        }
    }

    #[test]
    fn fold_now_produces_non_epoch_string() {
        let registry = FunctionRegistry::new();
        let val = fold_constant(&call("now"), &registry).expect("now() should fold");
        match val {
            SqlValue::String(s) => {
                assert!(!s.starts_with("1970"), "got {s}");
                assert!(s.contains('T'), "not ISO-8601: {s}");
            }
            other => panic!("expected string, got {other:?}"),
        }
    }

    #[test]
    fn fold_current_timestamp() {
        let registry = FunctionRegistry::new();
        assert!(matches!(
            fold_constant(&call("current_timestamp"), &registry),
            Some(SqlValue::String(_))
        ));
    }

    #[test]
    fn fold_unknown_function_returns_none() {
        let registry = FunctionRegistry::new();
        assert!(fold_constant(&call("definitely_not_a_real_function"), &registry).is_none());
    }

    #[test]
    fn fold_literal_arithmetic_still_works() {
        let registry = FunctionRegistry::new();
        let expr = binary(SqlValue::Int(2), BinaryOp::Add, SqlValue::Int(3));
        assert_eq!(fold_constant(&expr, &registry), Some(SqlValue::Int(5)));
    }

    #[test]
    fn fold_integer_overflow_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = binary(SqlValue::Int(i64::MAX), BinaryOp::Add, SqlValue::Int(1));
        assert!(fold_constant(&expr, &registry).is_none());
    }

    #[test]
    fn fold_string_concat() {
        let registry = FunctionRegistry::new();
        let expr = binary(
            SqlValue::String("foo".into()),
            BinaryOp::Concat,
            SqlValue::String("bar".into()),
        );
        assert_eq!(
            fold_constant(&expr, &registry),
            Some(SqlValue::String("foobar".into()))
        );
    }

    #[test]
    fn fold_column_ref_returns_none() {
        let registry = FunctionRegistry::new();
        let expr = SqlExpr::Column {
            table: None,
            name: "name".into(),
        };
        assert!(fold_constant(&expr, &registry).is_none());
    }
}