1use kyu_common::{KyuError, KyuResult};
7use kyu_parser::ast::{BinaryOp, ComparisonOp, StringOp, UnaryOp};
8use kyu_types::{LogicalType, TypedValue};
9
10use crate::bound_expr::BoundExpression;
11
12pub trait Tuple {
17 fn value_at(&self, idx: usize) -> Option<TypedValue>;
18}
19
20impl Tuple for [TypedValue] {
21 #[inline]
22 fn value_at(&self, idx: usize) -> Option<TypedValue> {
23 self.get(idx).cloned()
24 }
25}
26
27impl Tuple for Vec<TypedValue> {
28 #[inline]
29 fn value_at(&self, idx: usize) -> Option<TypedValue> {
30 self.get(idx).cloned()
31 }
32}
33
34#[inline]
40pub fn evaluate<T: Tuple + ?Sized>(expr: &BoundExpression, tuple: &T) -> KyuResult<TypedValue> {
41 match expr {
42 BoundExpression::Literal { value, .. } => Ok(value.clone()),
43
44 BoundExpression::Variable { index, .. } => tuple
45 .value_at(*index as usize)
46 .ok_or_else(|| KyuError::Runtime(format!("variable index {index} out of range"))),
47
48 BoundExpression::Property { object, .. } => {
49 evaluate(object, tuple)
52 }
53
54 BoundExpression::Parameter { index, .. } => tuple
55 .value_at(*index as usize)
56 .ok_or_else(|| KyuError::Runtime(format!("parameter index {index} out of range"))),
57
58 BoundExpression::UnaryOp { op, operand, .. } => {
59 let val = evaluate(operand, tuple)?;
60 eval_unary(*op, &val)
61 }
62
63 BoundExpression::BinaryOp {
64 op, left, right, ..
65 } => {
66 let lv = evaluate(left, tuple)?;
67 let rv = evaluate(right, tuple)?;
68 eval_binary(*op, &lv, &rv)
69 }
70
71 BoundExpression::Comparison { op, left, right } => {
72 let lv = evaluate(left, tuple)?;
73 let rv = evaluate(right, tuple)?;
74 eval_comparison(*op, &lv, &rv)
75 }
76
77 BoundExpression::IsNull { expr, negated } => {
78 let val = evaluate(expr, tuple)?;
79 let is_null = val.is_null();
80 Ok(TypedValue::Bool(if *negated { !is_null } else { is_null }))
81 }
82
83 BoundExpression::InList {
84 expr,
85 list,
86 negated,
87 } => {
88 let val = evaluate(expr, tuple)?;
89 if val.is_null() {
90 return Ok(TypedValue::Null);
91 }
92 let mut found = false;
93 let mut has_null = false;
94 for item in list {
95 let item_val = evaluate(item, tuple)?;
96 if item_val.is_null() {
97 has_null = true;
98 continue;
99 }
100 if val == item_val {
101 found = true;
102 break;
103 }
104 }
105 if found {
106 Ok(TypedValue::Bool(!*negated))
107 } else if has_null {
108 Ok(TypedValue::Null)
109 } else {
110 Ok(TypedValue::Bool(*negated))
111 }
112 }
113
114 BoundExpression::FunctionCall {
115 function_name,
116 args,
117 ..
118 } => {
119 let evaluated: Vec<TypedValue> = args
120 .iter()
121 .map(|a| evaluate(a, tuple))
122 .collect::<KyuResult<_>>()?;
123 eval_scalar_function(function_name, &evaluated)
124 }
125
126 BoundExpression::CountStar => Err(KyuError::NotImplemented(
127 "COUNT(*) in scalar evaluator".into(),
128 )),
129
130 BoundExpression::Case {
131 operand,
132 whens,
133 else_expr,
134 ..
135 } => {
136 if let Some(op) = operand {
137 let op_val = evaluate(op, tuple)?;
139 for (when_expr, then_expr) in whens {
140 let when_val = evaluate(when_expr, tuple)?;
141 if op_val == when_val {
142 return evaluate(then_expr, tuple);
143 }
144 }
145 } else {
146 for (when_expr, then_expr) in whens {
148 let cond = evaluate(when_expr, tuple)?;
149 if cond == TypedValue::Bool(true) {
150 return evaluate(then_expr, tuple);
151 }
152 }
153 }
154 if let Some(else_e) = else_expr {
155 evaluate(else_e, tuple)
156 } else {
157 Ok(TypedValue::Null)
158 }
159 }
160
161 BoundExpression::ListLiteral { elements, .. } => {
162 let values: Vec<TypedValue> = elements
163 .iter()
164 .map(|e| evaluate(e, tuple))
165 .collect::<KyuResult<_>>()?;
166 Ok(TypedValue::List(values))
167 }
168
169 BoundExpression::MapLiteral { entries, .. } => {
170 let values: Vec<(TypedValue, TypedValue)> = entries
171 .iter()
172 .map(|(k, v)| Ok((evaluate(k, tuple)?, evaluate(v, tuple)?)))
173 .collect::<KyuResult<_>>()?;
174 Ok(TypedValue::Map(values))
175 }
176
177 BoundExpression::Subscript { expr, index, .. } => {
178 let val = evaluate(expr, tuple)?;
179 let idx = evaluate(index, tuple)?;
180 match (&val, &idx) {
181 (TypedValue::Null, _) | (_, TypedValue::Null) => Ok(TypedValue::Null),
182 (TypedValue::List(list), TypedValue::Int64(i)) => {
183 let i = *i as usize;
185 if i == 0 || i > list.len() {
186 Ok(TypedValue::Null)
187 } else {
188 Ok(list[i - 1].clone())
189 }
190 }
191 _ => Err(KyuError::Runtime(
192 "subscript requires list and integer".into(),
193 )),
194 }
195 }
196
197 BoundExpression::Slice { expr, from, to, .. } => {
198 let val = evaluate(expr, tuple)?;
199 match val {
200 TypedValue::Null => Ok(TypedValue::Null),
201 TypedValue::List(list) => {
202 let len = list.len() as i64;
203 let start = match from {
204 Some(e) => match evaluate(e, tuple)? {
205 TypedValue::Int64(v) => (v.max(1) - 1) as usize,
206 TypedValue::Null => return Ok(TypedValue::Null),
207 _ => {
208 return Err(KyuError::Runtime(
209 "slice index must be integer".into(),
210 ));
211 }
212 },
213 None => 0,
214 };
215 let end = match to {
216 Some(e) => match evaluate(e, tuple)? {
217 TypedValue::Int64(v) => v.min(len) as usize,
218 TypedValue::Null => return Ok(TypedValue::Null),
219 _ => {
220 return Err(KyuError::Runtime(
221 "slice index must be integer".into(),
222 ));
223 }
224 },
225 None => list.len(),
226 };
227 if start >= end || start >= list.len() {
228 Ok(TypedValue::List(Vec::new()))
229 } else {
230 Ok(TypedValue::List(list[start..end].to_vec()))
231 }
232 }
233 _ => Err(KyuError::Runtime("slice requires list".into())),
234 }
235 }
236
237 BoundExpression::StringOp { op, left, right } => {
238 let lv = evaluate(left, tuple)?;
239 let rv = evaluate(right, tuple)?;
240 eval_string_op(*op, &lv, &rv)
241 }
242
243 BoundExpression::Cast { expr, target_type } => {
244 let val = evaluate(expr, tuple)?;
245 eval_cast(&val, target_type)
246 }
247
248 BoundExpression::HasLabel { .. } => {
249 Err(KyuError::NotImplemented(
251 "HasLabel in scalar evaluator".into(),
252 ))
253 }
254 }
255}
256
257pub fn evaluate_constant(expr: &BoundExpression) -> KyuResult<TypedValue> {
259 let empty: &[TypedValue] = &[];
260 evaluate(expr, empty)
261}
262
263fn eval_unary(op: UnaryOp, val: &TypedValue) -> KyuResult<TypedValue> {
264 if val.is_null() {
265 return Ok(TypedValue::Null);
266 }
267 match op {
268 UnaryOp::Not => match val {
269 TypedValue::Bool(b) => Ok(TypedValue::Bool(!b)),
270 _ => Err(KyuError::Runtime("NOT requires boolean".into())),
271 },
272 UnaryOp::Minus => match val {
273 TypedValue::Int8(v) => Ok(TypedValue::Int8(-v)),
274 TypedValue::Int16(v) => Ok(TypedValue::Int16(-v)),
275 TypedValue::Int32(v) => Ok(TypedValue::Int32(-v)),
276 TypedValue::Int64(v) => Ok(TypedValue::Int64(-v)),
277 TypedValue::Float(v) => Ok(TypedValue::Float(-v)),
278 TypedValue::Double(v) => Ok(TypedValue::Double(-v)),
279 _ => Err(KyuError::Runtime("unary minus requires numeric".into())),
280 },
281 UnaryOp::BitwiseNot => match val {
282 TypedValue::Int64(v) => Ok(TypedValue::Int64(!v)),
283 _ => Err(KyuError::Runtime("bitwise NOT requires integer".into())),
284 },
285 }
286}
287
288fn eval_binary(op: BinaryOp, left: &TypedValue, right: &TypedValue) -> KyuResult<TypedValue> {
289 if left.is_null() || right.is_null() {
291 return match op {
293 BinaryOp::And => eval_and(left, right),
294 BinaryOp::Or => eval_or(left, right),
295 _ => Ok(TypedValue::Null),
296 };
297 }
298
299 match op {
300 BinaryOp::Add => numeric_binop(left, right, |a, b| a + b, |a, b| a + b),
301 BinaryOp::Sub => numeric_binop(left, right, |a, b| a - b, |a, b| a - b),
302 BinaryOp::Mul => numeric_binop(left, right, |a, b| a * b, |a, b| a * b),
303 BinaryOp::Div => {
304 match right {
306 TypedValue::Int64(0) | TypedValue::Int32(0) => {
307 return Err(KyuError::Runtime("division by zero".into()));
308 }
309 TypedValue::Double(v) if *v == 0.0 => {
310 return Err(KyuError::Runtime("division by zero".into()));
311 }
312 _ => {}
313 }
314 numeric_binop(left, right, |a, b| a / b, |a, b| a / b)
315 }
316 BinaryOp::Mod => {
317 match right {
318 TypedValue::Int64(0) | TypedValue::Int32(0) => {
319 return Err(KyuError::Runtime("modulo by zero".into()));
320 }
321 _ => {}
322 }
323 numeric_binop(left, right, |a, b| a % b, |a, b| a % b)
324 }
325 BinaryOp::Pow => match (left, right) {
326 (TypedValue::Int64(a), TypedValue::Int64(b)) => {
327 Ok(TypedValue::Double((*a as f64).powf(*b as f64)))
328 }
329 (TypedValue::Double(a), TypedValue::Double(b)) => Ok(TypedValue::Double(a.powf(*b))),
330 _ => Err(KyuError::Runtime("pow requires numeric".into())),
331 },
332 BinaryOp::And => eval_and(left, right),
333 BinaryOp::Or => eval_or(left, right),
334 BinaryOp::Xor => match (left, right) {
335 (TypedValue::Bool(a), TypedValue::Bool(b)) => Ok(TypedValue::Bool(a ^ b)),
336 _ => Err(KyuError::Runtime("XOR requires boolean".into())),
337 },
338 BinaryOp::Concat => match (left, right) {
339 (TypedValue::String(a), TypedValue::String(b)) => {
340 Ok(TypedValue::String(SmolStr::new(format!("{a}{b}"))))
341 }
342 _ => Err(KyuError::Runtime("concat requires strings".into())),
343 },
344 BinaryOp::BitwiseAnd => match (left, right) {
345 (TypedValue::Int64(a), TypedValue::Int64(b)) => Ok(TypedValue::Int64(a & b)),
346 _ => Err(KyuError::Runtime("bitwise AND requires integers".into())),
347 },
348 BinaryOp::BitwiseOr => match (left, right) {
349 (TypedValue::Int64(a), TypedValue::Int64(b)) => Ok(TypedValue::Int64(a | b)),
350 _ => Err(KyuError::Runtime("bitwise OR requires integers".into())),
351 },
352 BinaryOp::ShiftLeft => match (left, right) {
353 (TypedValue::Int64(a), TypedValue::Int64(b)) => Ok(TypedValue::Int64(a << b)),
354 _ => Err(KyuError::Runtime("shift requires integers".into())),
355 },
356 BinaryOp::ShiftRight => match (left, right) {
357 (TypedValue::Int64(a), TypedValue::Int64(b)) => Ok(TypedValue::Int64(a >> b)),
358 _ => Err(KyuError::Runtime("shift requires integers".into())),
359 },
360 }
361}
362
363use smol_str::SmolStr;
364
365fn eval_and(left: &TypedValue, right: &TypedValue) -> KyuResult<TypedValue> {
366 match (left, right) {
368 (TypedValue::Bool(false), _) | (_, TypedValue::Bool(false)) => Ok(TypedValue::Bool(false)),
369 (TypedValue::Bool(true), TypedValue::Bool(true)) => Ok(TypedValue::Bool(true)),
370 _ => Ok(TypedValue::Null), }
372}
373
374fn eval_or(left: &TypedValue, right: &TypedValue) -> KyuResult<TypedValue> {
375 match (left, right) {
377 (TypedValue::Bool(true), _) | (_, TypedValue::Bool(true)) => Ok(TypedValue::Bool(true)),
378 (TypedValue::Bool(false), TypedValue::Bool(false)) => Ok(TypedValue::Bool(false)),
379 _ => Ok(TypedValue::Null), }
381}
382
383fn numeric_binop(
384 left: &TypedValue,
385 right: &TypedValue,
386 int_op: impl Fn(i64, i64) -> i64,
387 float_op: impl Fn(f64, f64) -> f64,
388) -> KyuResult<TypedValue> {
389 match (left, right) {
390 (TypedValue::Int64(a), TypedValue::Int64(b)) => Ok(TypedValue::Int64(int_op(*a, *b))),
391 (TypedValue::Int32(a), TypedValue::Int32(b)) => {
392 Ok(TypedValue::Int32(int_op(*a as i64, *b as i64) as i32))
393 }
394 (TypedValue::Double(a), TypedValue::Double(b)) => Ok(TypedValue::Double(float_op(*a, *b))),
395 (TypedValue::Float(a), TypedValue::Float(b)) => {
396 Ok(TypedValue::Float(float_op(*a as f64, *b as f64) as f32))
397 }
398 _ => Err(KyuError::Runtime(format!(
399 "arithmetic not defined for {:?} and {:?}",
400 left, right
401 ))),
402 }
403}
404
405fn eval_comparison(
406 op: ComparisonOp,
407 left: &TypedValue,
408 right: &TypedValue,
409) -> KyuResult<TypedValue> {
410 if left.is_null() || right.is_null() {
411 return Ok(TypedValue::Null);
412 }
413
414 let ord = compare_values(left, right)?;
415
416 let result = match op {
417 ComparisonOp::Eq => ord == std::cmp::Ordering::Equal,
418 ComparisonOp::Neq => ord != std::cmp::Ordering::Equal,
419 ComparisonOp::Lt => ord == std::cmp::Ordering::Less,
420 ComparisonOp::Le => ord != std::cmp::Ordering::Greater,
421 ComparisonOp::Gt => ord == std::cmp::Ordering::Greater,
422 ComparisonOp::Ge => ord != std::cmp::Ordering::Less,
423 ComparisonOp::RegexMatch => {
424 ord == std::cmp::Ordering::Equal
427 }
428 };
429
430 Ok(TypedValue::Bool(result))
431}
432
433fn compare_values(left: &TypedValue, right: &TypedValue) -> KyuResult<std::cmp::Ordering> {
434 match (left, right) {
435 (TypedValue::Int64(a), TypedValue::Int64(b)) => Ok(a.cmp(b)),
436 (TypedValue::Int32(a), TypedValue::Int32(b)) => Ok(a.cmp(b)),
437 (TypedValue::Double(a), TypedValue::Double(b)) => {
438 Ok(a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
439 }
440 (TypedValue::Float(a), TypedValue::Float(b)) => {
441 Ok(a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
442 }
443 (TypedValue::String(a), TypedValue::String(b)) => Ok(a.cmp(b)),
444 (TypedValue::Bool(a), TypedValue::Bool(b)) => Ok(a.cmp(b)),
445 _ => Err(KyuError::Runtime(format!(
446 "cannot compare {:?} and {:?}",
447 left, right
448 ))),
449 }
450}
451
452fn eval_string_op(op: StringOp, left: &TypedValue, right: &TypedValue) -> KyuResult<TypedValue> {
453 if left.is_null() || right.is_null() {
454 return Ok(TypedValue::Null);
455 }
456
457 match (left, right) {
458 (TypedValue::String(a), TypedValue::String(b)) => {
459 let result = match op {
460 StringOp::StartsWith => a.starts_with(b.as_str()),
461 StringOp::EndsWith => a.ends_with(b.as_str()),
462 StringOp::Contains => a.contains(b.as_str()),
463 };
464 Ok(TypedValue::Bool(result))
465 }
466 _ => Err(KyuError::Runtime(
467 "string operations require strings".into(),
468 )),
469 }
470}
471
472fn eval_cast(val: &TypedValue, target: &LogicalType) -> KyuResult<TypedValue> {
473 if val.is_null() {
474 return Ok(TypedValue::Null);
475 }
476
477 match target {
478 LogicalType::Int64 => match val {
479 TypedValue::Int8(v) => Ok(TypedValue::Int64(*v as i64)),
480 TypedValue::Int16(v) => Ok(TypedValue::Int64(*v as i64)),
481 TypedValue::Int32(v) => Ok(TypedValue::Int64(*v as i64)),
482 TypedValue::Int64(_) => Ok(val.clone()),
483 TypedValue::UInt8(v) => Ok(TypedValue::Int64(*v as i64)),
484 TypedValue::UInt16(v) => Ok(TypedValue::Int64(*v as i64)),
485 TypedValue::UInt32(v) => Ok(TypedValue::Int64(*v as i64)),
486 TypedValue::Float(v) => Ok(TypedValue::Int64(*v as i64)),
487 TypedValue::Double(v) => Ok(TypedValue::Int64(*v as i64)),
488 TypedValue::String(s) => s
489 .parse::<i64>()
490 .map(TypedValue::Int64)
491 .map_err(|_| KyuError::Runtime(format!("cannot cast '{s}' to INT64"))),
492 TypedValue::Bool(b) => Ok(TypedValue::Int64(if *b { 1 } else { 0 })),
493 _ => Err(KyuError::Runtime(format!("cannot cast {val:?} to INT64"))),
494 },
495 LogicalType::Double => match val {
496 TypedValue::Int8(v) => Ok(TypedValue::Double(*v as f64)),
497 TypedValue::Int16(v) => Ok(TypedValue::Double(*v as f64)),
498 TypedValue::Int32(v) => Ok(TypedValue::Double(*v as f64)),
499 TypedValue::Int64(v) => Ok(TypedValue::Double(*v as f64)),
500 TypedValue::UInt8(v) => Ok(TypedValue::Double(*v as f64)),
501 TypedValue::UInt16(v) => Ok(TypedValue::Double(*v as f64)),
502 TypedValue::UInt32(v) => Ok(TypedValue::Double(*v as f64)),
503 TypedValue::UInt64(v) => Ok(TypedValue::Double(*v as f64)),
504 TypedValue::Float(v) => Ok(TypedValue::Double(*v as f64)),
505 TypedValue::Double(_) => Ok(val.clone()),
506 TypedValue::String(s) => s
507 .parse::<f64>()
508 .map(TypedValue::Double)
509 .map_err(|_| KyuError::Runtime(format!("cannot cast '{s}' to DOUBLE"))),
510 _ => Err(KyuError::Runtime(format!("cannot cast {val:?} to DOUBLE"))),
511 },
512 LogicalType::Float => match val {
513 TypedValue::Int8(v) => Ok(TypedValue::Float(*v as f32)),
514 TypedValue::Int16(v) => Ok(TypedValue::Float(*v as f32)),
515 TypedValue::Int32(v) => Ok(TypedValue::Float(*v as f32)),
516 TypedValue::Int64(v) => Ok(TypedValue::Float(*v as f32)),
517 TypedValue::Float(_) => Ok(val.clone()),
518 TypedValue::Double(v) => Ok(TypedValue::Float(*v as f32)),
519 _ => Err(KyuError::Runtime(format!("cannot cast {val:?} to FLOAT"))),
520 },
521 LogicalType::String => Ok(TypedValue::String(SmolStr::new(format!("{val}")))),
522 LogicalType::Bool => match val {
523 TypedValue::Bool(_) => Ok(val.clone()),
524 TypedValue::Int64(v) => Ok(TypedValue::Bool(*v != 0)),
525 TypedValue::String(s) => match s.to_lowercase().as_str() {
526 "true" => Ok(TypedValue::Bool(true)),
527 "false" => Ok(TypedValue::Bool(false)),
528 _ => Err(KyuError::Runtime(format!("cannot cast '{s}' to BOOL"))),
529 },
530 _ => Err(KyuError::Runtime(format!("cannot cast {val:?} to BOOL"))),
531 },
532 LogicalType::Int32 => match val {
533 TypedValue::Int8(v) => Ok(TypedValue::Int32(*v as i32)),
534 TypedValue::Int16(v) => Ok(TypedValue::Int32(*v as i32)),
535 TypedValue::Int32(_) => Ok(val.clone()),
536 _ => Err(KyuError::Runtime(format!("cannot cast {val:?} to INT32"))),
537 },
538 LogicalType::Int16 => match val {
539 TypedValue::Int8(v) => Ok(TypedValue::Int16(*v as i16)),
540 TypedValue::Int16(_) => Ok(val.clone()),
541 _ => Err(KyuError::Runtime(format!("cannot cast {val:?} to INT16"))),
542 },
543 _ => Err(KyuError::Runtime(format!(
544 "cast to {} not supported in scalar evaluator",
545 target.type_name()
546 ))),
547 }
548}
549
550fn eval_scalar_function(name: &str, args: &[TypedValue]) -> KyuResult<TypedValue> {
551 use std::hash::{DefaultHasher, Hash, Hasher};
552
553 let null_transparent = !matches!(
556 name.to_lowercase().as_str(),
557 "coalesce" | "greatest" | "least" | "typeof"
558 );
559 if null_transparent && args.iter().any(|a| a.is_null()) {
560 return Ok(TypedValue::Null);
561 }
562
563 match name.to_lowercase().as_str() {
564 "abs" => match &args[0] {
566 TypedValue::Int64(v) => Ok(TypedValue::Int64(v.abs())),
567 TypedValue::Double(v) => Ok(TypedValue::Double(v.abs())),
568 TypedValue::Int32(v) => Ok(TypedValue::Int32(v.abs())),
569 _ => Err(KyuError::Runtime("abs requires numeric".into())),
570 },
571 "floor" => as_f64(&args[0]).map(|v| TypedValue::Double(v.floor())),
572 "ceil" => as_f64(&args[0]).map(|v| TypedValue::Double(v.ceil())),
573 "round" => as_f64(&args[0]).map(|v| TypedValue::Double(v.round())),
574 "sqrt" => as_f64(&args[0]).map(|v| TypedValue::Double(v.sqrt())),
575 "log" => as_f64(&args[0]).map(|v| TypedValue::Double(v.ln())),
576 "log2" => as_f64(&args[0]).map(|v| TypedValue::Double(v.log2())),
577 "log10" => as_f64(&args[0]).map(|v| TypedValue::Double(v.log10())),
578 "sin" => as_f64(&args[0]).map(|v| TypedValue::Double(v.sin())),
579 "cos" => as_f64(&args[0]).map(|v| TypedValue::Double(v.cos())),
580 "tan" => as_f64(&args[0]).map(|v| TypedValue::Double(v.tan())),
581 "sign" => match &args[0] {
582 TypedValue::Int64(v) => Ok(TypedValue::Int64(v.signum())),
583 TypedValue::Double(v) => Ok(TypedValue::Int64(if *v > 0.0 {
584 1
585 } else if *v < 0.0 {
586 -1
587 } else {
588 0
589 })),
590 _ => Err(KyuError::Runtime("sign requires numeric".into())),
591 },
592
593 "lower" => as_str(&args[0]).map(|s| TypedValue::String(SmolStr::new(s.to_lowercase()))),
595 "upper" => as_str(&args[0]).map(|s| TypedValue::String(SmolStr::new(s.to_uppercase()))),
596 "length" | "size" => match &args[0] {
597 TypedValue::String(s) => Ok(TypedValue::Int64(s.len() as i64)),
598 TypedValue::List(l) => Ok(TypedValue::Int64(l.len() as i64)),
599 _ => Err(KyuError::Runtime("length requires string or list".into())),
600 },
601 "trim" => as_str(&args[0]).map(|s| TypedValue::String(SmolStr::new(s.trim()))),
602 "ltrim" => as_str(&args[0]).map(|s| TypedValue::String(SmolStr::new(s.trim_start()))),
603 "rtrim" => as_str(&args[0]).map(|s| TypedValue::String(SmolStr::new(s.trim_end()))),
604 "reverse" => as_str(&args[0]).map(|s| {
605 TypedValue::String(SmolStr::new(
606 s.chars().rev().collect::<std::string::String>(),
607 ))
608 }),
609 "substring" => {
610 let s = as_str(&args[0])?;
611 let start = as_i64(&args[1])? as usize;
612 let len = as_i64(&args[2])? as usize;
613 let start = if start > 0 { start - 1 } else { 0 }; let result: std::string::String = s.chars().skip(start).take(len).collect();
615 Ok(TypedValue::String(SmolStr::new(result)))
616 }
617 "left" => {
618 let s = as_str(&args[0])?;
619 let n = as_i64(&args[1])? as usize;
620 let result: std::string::String = s.chars().take(n).collect();
621 Ok(TypedValue::String(SmolStr::new(result)))
622 }
623 "right" => {
624 let s = as_str(&args[0])?;
625 let n = as_i64(&args[1])? as usize;
626 let chars: Vec<char> = s.chars().collect();
627 let start = chars.len().saturating_sub(n);
628 let result: std::string::String = chars[start..].iter().collect();
629 Ok(TypedValue::String(SmolStr::new(result)))
630 }
631 "replace" => {
632 let s = as_str(&args[0])?;
633 let from = as_str(&args[1])?;
634 let to = as_str(&args[2])?;
635 Ok(TypedValue::String(SmolStr::new(s.replace(from, to))))
636 }
637 "concat" => {
638 let mut result = std::string::String::new();
639 for arg in args {
640 match arg {
641 TypedValue::String(s) => result.push_str(s),
642 TypedValue::Null => {} other => result.push_str(&format!("{other}")),
644 }
645 }
646 Ok(TypedValue::String(SmolStr::new(result)))
647 }
648 "lpad" => {
649 let s = as_str(&args[0])?;
650 let target_len = as_i64(&args[1])? as usize;
651 let pad = as_str(&args[2])?;
652 let mut result = std::string::String::from(s);
653 while result.len() < target_len {
654 let remaining = target_len - result.len();
655 let take: std::string::String = pad.chars().take(remaining).collect();
656 result = format!("{take}{result}");
657 }
658 Ok(TypedValue::String(SmolStr::new(result)))
659 }
660 "rpad" => {
661 let s = as_str(&args[0])?;
662 let target_len = as_i64(&args[1])? as usize;
663 let pad = as_str(&args[2])?;
664 let mut result = std::string::String::from(s);
665 while result.len() < target_len {
666 let remaining = target_len - result.len();
667 let take: std::string::String = pad.chars().take(remaining).collect();
668 result.push_str(&take);
669 }
670 Ok(TypedValue::String(SmolStr::new(result)))
671 }
672
673 "tostring" => Ok(TypedValue::String(SmolStr::new(format!("{}", args[0])))),
675 "tointeger" => match &args[0] {
676 TypedValue::Int64(_) => Ok(args[0].clone()),
677 TypedValue::Int32(v) => Ok(TypedValue::Int64(*v as i64)),
678 TypedValue::Double(v) => Ok(TypedValue::Int64(*v as i64)),
679 TypedValue::Float(v) => Ok(TypedValue::Int64(*v as i64)),
680 TypedValue::String(s) => s
681 .parse::<i64>()
682 .map(TypedValue::Int64)
683 .map_err(|_| KyuError::Runtime(format!("cannot convert '{s}' to integer"))),
684 TypedValue::Bool(b) => Ok(TypedValue::Int64(if *b { 1 } else { 0 })),
685 _ => Err(KyuError::Runtime("tointeger: unsupported type".into())),
686 },
687 "tofloat" => match &args[0] {
688 TypedValue::Double(_) => Ok(args[0].clone()),
689 TypedValue::Int64(v) => Ok(TypedValue::Double(*v as f64)),
690 TypedValue::Int32(v) => Ok(TypedValue::Double(*v as f64)),
691 TypedValue::Float(v) => Ok(TypedValue::Double(*v as f64)),
692 TypedValue::String(s) => s
693 .parse::<f64>()
694 .map(TypedValue::Double)
695 .map_err(|_| KyuError::Runtime(format!("cannot convert '{s}' to float"))),
696 _ => Err(KyuError::Runtime("tofloat: unsupported type".into())),
697 },
698 "toboolean" => match &args[0] {
699 TypedValue::Bool(_) => Ok(args[0].clone()),
700 TypedValue::Int64(v) => Ok(TypedValue::Bool(*v != 0)),
701 TypedValue::String(s) => match s.to_lowercase().as_str() {
702 "true" => Ok(TypedValue::Bool(true)),
703 "false" => Ok(TypedValue::Bool(false)),
704 _ => Err(KyuError::Runtime(format!(
705 "cannot convert '{s}' to boolean"
706 ))),
707 },
708 _ => Err(KyuError::Runtime("toboolean: unsupported type".into())),
709 },
710
711 "coalesce" => {
713 for arg in args {
714 if !arg.is_null() {
715 return Ok(arg.clone());
716 }
717 }
718 Ok(TypedValue::Null)
719 }
720 "typeof" => {
721 let type_name = match &args[0] {
722 TypedValue::Null => "NULL",
723 TypedValue::Bool(_) => "BOOL",
724 TypedValue::Int8(_) => "INT8",
725 TypedValue::Int16(_) => "INT16",
726 TypedValue::Int32(_) => "INT32",
727 TypedValue::Int64(_) => "INT64",
728 TypedValue::UInt8(_) => "UINT8",
729 TypedValue::UInt16(_) => "UINT16",
730 TypedValue::UInt32(_) => "UINT32",
731 TypedValue::UInt64(_) => "UINT64",
732 TypedValue::Float(_) => "FLOAT",
733 TypedValue::Double(_) => "DOUBLE",
734 TypedValue::String(_) => "STRING",
735 TypedValue::List(_) => "LIST",
736 TypedValue::Map(_) => "MAP",
737 TypedValue::Interval(_) => "INTERVAL",
738 TypedValue::InternalId(_) => "INTERNAL_ID",
739 _ => "UNKNOWN",
740 };
741 Ok(TypedValue::String(SmolStr::new(type_name)))
742 }
743 "hash" => {
744 let mut hasher = DefaultHasher::new();
745 format!("{:?}", args[0]).hash(&mut hasher);
748 Ok(TypedValue::Int64(hasher.finish() as i64))
749 }
750 "greatest" => {
751 let mut best: Option<&TypedValue> = None;
752 for arg in args {
753 if arg.is_null() {
754 continue;
755 }
756 match best {
757 None => best = Some(arg),
758 Some(b) => {
759 if let Ok(std::cmp::Ordering::Greater) = compare_values(arg, b) {
760 best = Some(arg);
761 }
762 }
763 }
764 }
765 Ok(best.cloned().unwrap_or(TypedValue::Null))
766 }
767 "least" => {
768 let mut best: Option<&TypedValue> = None;
769 for arg in args {
770 if arg.is_null() {
771 continue;
772 }
773 match best {
774 None => best = Some(arg),
775 Some(b) => {
776 if let Ok(std::cmp::Ordering::Less) = compare_values(arg, b) {
777 best = Some(arg);
778 }
779 }
780 }
781 }
782 Ok(best.cloned().unwrap_or(TypedValue::Null))
783 }
784
785 "json_extract" => {
787 let s = as_str(&args[0])?;
788 let path = as_str(&args[1])?;
789 json_extract(s, path)
790 }
791 "json_valid" => {
792 let s = as_str(&args[0])?;
793 let mut bytes = s.as_bytes().to_vec();
794 Ok(TypedValue::Bool(
795 simd_json::to_borrowed_value(&mut bytes).is_ok(),
796 ))
797 }
798 "json_type" => {
799 let s = as_str(&args[0])?;
800 json_type_fn(s)
801 }
802 "json_keys" => {
803 let s = as_str(&args[0])?;
804 json_keys(s)
805 }
806 "json_array_length" => {
807 let s = as_str(&args[0])?;
808 json_array_length(s)
809 }
810 "json_contains" => {
811 let s = as_str(&args[0])?;
812 let path = as_str(&args[1])?;
813 json_contains(s, path)
814 }
815 "json_set" => {
816 let s = as_str(&args[0])?;
817 let path = as_str(&args[1])?;
818 let value_str = as_str(&args[2])?;
819 json_set(s, path, value_str)
820 }
821
822 "range" => {
824 let start = as_i64(&args[0])?;
825 let end = as_i64(&args[1])?;
826 let list: Vec<TypedValue> = (start..=end).map(TypedValue::Int64).collect();
827 Ok(TypedValue::List(list))
828 }
829
830 _ => Err(KyuError::NotImplemented(format!(
831 "function '{name}' not implemented"
832 ))),
833 }
834}
835
836fn as_f64(val: &TypedValue) -> KyuResult<f64> {
838 match val {
839 TypedValue::Int64(v) => Ok(*v as f64),
840 TypedValue::Int32(v) => Ok(*v as f64),
841 TypedValue::Double(v) => Ok(*v),
842 TypedValue::Float(v) => Ok(*v as f64),
843 _ => Err(KyuError::Runtime("expected numeric value".into())),
844 }
845}
846
847fn as_i64(val: &TypedValue) -> KyuResult<i64> {
849 match val {
850 TypedValue::Int64(v) => Ok(*v),
851 TypedValue::Int32(v) => Ok(*v as i64),
852 _ => Err(KyuError::Runtime("expected integer value".into())),
853 }
854}
855
856fn as_str(val: &TypedValue) -> KyuResult<&str> {
858 match val {
859 TypedValue::String(s) => Ok(s.as_str()),
860 _ => Err(KyuError::Runtime("expected string value".into())),
861 }
862}
863
864use simd_json::prelude::{
868 TypedScalarValue, ValueAsArray, ValueAsMutArray, ValueAsMutObject, ValueAsObject,
869 ValueAsScalar, ValueObjectAccess, Writable,
870};
871
872fn json_navigate<'a>(
875 val: &'a simd_json::BorrowedValue<'a>,
876 path: &str,
877) -> Option<&'a simd_json::BorrowedValue<'a>> {
878 let path = path
879 .strip_prefix("$.")
880 .unwrap_or(path.strip_prefix('$').unwrap_or(path));
881 if path.is_empty() {
882 return Some(val);
883 }
884
885 let mut current = val;
886 for key in path.split('.') {
887 let key = key.trim_matches(|c| c == '[' || c == ']');
889 if let Ok(idx) = key.parse::<usize>() {
890 current = current.as_array()?.get(idx)?;
891 } else {
892 current = ValueObjectAccess::get(current, key)?;
893 }
894 }
895 Some(current)
896}
897
898fn json_value_to_typed(val: &simd_json::BorrowedValue<'_>) -> TypedValue {
900 if val.is_null() {
901 TypedValue::Null
902 } else if let Some(b) = val.as_bool() {
903 TypedValue::Bool(b)
904 } else if let Some(n) = val.as_i64() {
905 TypedValue::Int64(n)
906 } else if let Some(n) = val.as_f64() {
907 TypedValue::Double(n)
908 } else if let Some(s) = val.as_str() {
909 TypedValue::String(SmolStr::new(s))
910 } else {
911 TypedValue::String(SmolStr::new(json_serialize(val)))
913 }
914}
915
916fn json_serialize(val: &impl Writable) -> std::string::String {
918 let mut buf = Vec::new();
919 let _ = val.write(&mut buf);
920 unsafe { std::string::String::from_utf8_unchecked(buf) }
922}
923
924fn json_extract(json_str: &str, path: &str) -> KyuResult<TypedValue> {
925 let mut bytes = json_str.as_bytes().to_vec();
926 let parsed = simd_json::to_borrowed_value(&mut bytes)
927 .map_err(|e| KyuError::Runtime(format!("invalid JSON: {e}")))?;
928
929 match json_navigate(&parsed, path) {
930 Some(val) => Ok(json_value_to_typed(val)),
931 None => Ok(TypedValue::Null),
932 }
933}
934
935fn json_type_fn(json_str: &str) -> KyuResult<TypedValue> {
936 let mut bytes = json_str.as_bytes().to_vec();
937 let parsed = simd_json::to_borrowed_value(&mut bytes)
938 .map_err(|e| KyuError::Runtime(format!("invalid JSON: {e}")))?;
939
940 let type_name = if parsed.is_null() {
941 "null"
942 } else if parsed.as_bool().is_some() {
943 "boolean"
944 } else if parsed.as_i64().is_some() || parsed.as_f64().is_some() {
945 "number"
946 } else if parsed.as_str().is_some() {
947 "string"
948 } else if parsed.as_array().is_some() {
949 "array"
950 } else if parsed.as_object().is_some() {
951 "object"
952 } else {
953 "unknown"
954 };
955 Ok(TypedValue::String(SmolStr::new(type_name)))
956}
957
958fn json_keys(json_str: &str) -> KyuResult<TypedValue> {
959 let mut bytes = json_str.as_bytes().to_vec();
960 let parsed = simd_json::to_borrowed_value(&mut bytes)
961 .map_err(|e| KyuError::Runtime(format!("invalid JSON: {e}")))?;
962
963 match parsed.as_object() {
964 Some(obj) => {
965 let keys: Vec<TypedValue> = obj
966 .keys()
967 .map(|k| TypedValue::String(SmolStr::new(k.as_ref())))
968 .collect();
969 Ok(TypedValue::List(keys))
970 }
971 None => Ok(TypedValue::Null),
972 }
973}
974
975fn json_array_length(json_str: &str) -> KyuResult<TypedValue> {
976 let mut bytes = json_str.as_bytes().to_vec();
977 let parsed = simd_json::to_borrowed_value(&mut bytes)
978 .map_err(|e| KyuError::Runtime(format!("invalid JSON: {e}")))?;
979
980 match parsed.as_array() {
981 Some(arr) => Ok(TypedValue::Int64(arr.len() as i64)),
982 None => Ok(TypedValue::Null),
983 }
984}
985
986fn json_contains(json_str: &str, path: &str) -> KyuResult<TypedValue> {
987 let mut bytes = json_str.as_bytes().to_vec();
988 let parsed = simd_json::to_borrowed_value(&mut bytes)
989 .map_err(|e| KyuError::Runtime(format!("invalid JSON: {e}")))?;
990
991 Ok(TypedValue::Bool(json_navigate(&parsed, path).is_some()))
992}
993
994fn json_set(json_str: &str, path: &str, value_str: &str) -> KyuResult<TypedValue> {
995 let mut bytes = json_str.as_bytes().to_vec();
996 let mut doc = simd_json::to_owned_value(&mut bytes)
997 .map_err(|e| KyuError::Runtime(format!("invalid JSON: {e}")))?;
998
999 let mut val_bytes = value_str.as_bytes().to_vec();
1001 let new_val = simd_json::to_owned_value(&mut val_bytes).unwrap_or_else(|_| {
1002 simd_json::OwnedValue::from(value_str.to_string())
1004 });
1005
1006 let path = path
1007 .strip_prefix("$.")
1008 .unwrap_or(path.strip_prefix('$').unwrap_or(path));
1009
1010 if path.is_empty() {
1011 return Ok(TypedValue::String(SmolStr::new(json_serialize(&new_val))));
1012 }
1013
1014 let keys: Vec<&str> = path.split('.').collect();
1015
1016 let mut current = &mut doc;
1018 for &key in &keys[..keys.len() - 1] {
1019 let key = key.trim_matches(|c| c == '[' || c == ']');
1020 if let Ok(idx) = key.parse::<usize>() {
1021 current = current
1022 .as_array_mut()
1023 .and_then(|a| a.get_mut(idx))
1024 .ok_or_else(|| KyuError::Runtime(format!("path not found: {path}")))?;
1025 } else {
1026 current = current
1027 .as_object_mut()
1028 .and_then(|o| o.get_mut(key))
1029 .ok_or_else(|| KyuError::Runtime(format!("path not found: {path}")))?;
1030 }
1031 }
1032
1033 let last_key = keys[keys.len() - 1].trim_matches(|c| c == '[' || c == ']');
1034 if let Some(obj) = current.as_object_mut() {
1035 obj.insert(last_key.to_string(), new_val);
1036 } else if let Ok(idx) = last_key.parse::<usize>() {
1037 if let Some(arr) = current.as_array_mut() {
1038 if idx < arr.len() {
1039 arr[idx] = new_val;
1040 } else {
1041 return Err(KyuError::Runtime(format!(
1042 "array index {idx} out of bounds"
1043 )));
1044 }
1045 } else {
1046 return Err(KyuError::Runtime(
1047 "cannot set on non-object/non-array".into(),
1048 ));
1049 }
1050 } else {
1051 return Err(KyuError::Runtime("cannot set on non-object".into()));
1052 }
1053
1054 Ok(TypedValue::String(SmolStr::new(json_serialize(&doc))))
1055}
1056
#[cfg(test)]
mod tests {
    //! Unit tests for expression evaluation and the built-in scalar and JSON functions.

    use super::*;
    use crate::bound_expr::FunctionId;
    use kyu_parser::ast::ComparisonOp;

    /// Builds a literal expression with an explicit result type.
    fn lit(value: TypedValue, result_type: LogicalType) -> BoundExpression {
        BoundExpression::Literal { value, result_type }
    }

    fn lit_int(v: i64) -> BoundExpression {
        lit(TypedValue::Int64(v), LogicalType::Int64)
    }

    fn lit_f64(v: f64) -> BoundExpression {
        lit(TypedValue::Double(v), LogicalType::Double)
    }

    fn lit_str(s: &str) -> BoundExpression {
        lit(TypedValue::String(SmolStr::new(s)), LogicalType::String)
    }

    fn lit_bool(v: bool) -> BoundExpression {
        lit(TypedValue::Bool(v), LogicalType::Bool)
    }

    fn lit_null() -> BoundExpression {
        lit(TypedValue::Null, LogicalType::Any)
    }

    /// Builds a call to the built-in function `name`.
    ///
    /// Every call uses `FunctionId(0)`; these tests rely on evaluation selecting
    /// the implementation by `function_name`.
    fn func(name: &str, args: Vec<BoundExpression>, ret: LogicalType) -> BoundExpression {
        BoundExpression::FunctionCall {
            function_id: FunctionId(0),
            function_name: SmolStr::new(name),
            args,
            distinct: false,
            result_type: ret,
        }
    }

    #[test]
    fn evaluate_literal() {
        assert_eq!(
            evaluate_constant(&lit_int(42)).unwrap(),
            TypedValue::Int64(42)
        );
    }

    #[test]
    fn evaluate_variable() {
        let expr = BoundExpression::Variable {
            index: 1,
            result_type: LogicalType::String,
        };
        let tuple = vec![
            TypedValue::Int64(1),
            TypedValue::String(SmolStr::new("hello")),
        ];
        assert_eq!(
            evaluate(&expr, &tuple).unwrap(),
            TypedValue::String(SmolStr::new("hello"))
        );
    }

    #[test]
    fn evaluate_add() {
        let expr = BoundExpression::BinaryOp {
            op: BinaryOp::Add,
            left: Box::new(lit_int(10)),
            right: Box::new(lit_int(32)),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(42));
    }

    #[test]
    fn evaluate_sub_mul() {
        let sub = BoundExpression::BinaryOp {
            op: BinaryOp::Sub,
            left: Box::new(lit_int(50)),
            right: Box::new(lit_int(8)),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&sub).unwrap(), TypedValue::Int64(42));

        let mul = BoundExpression::BinaryOp {
            op: BinaryOp::Mul,
            left: Box::new(lit_int(6)),
            right: Box::new(lit_int(7)),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&mul).unwrap(), TypedValue::Int64(42));
    }

    #[test]
    fn evaluate_division_by_zero() {
        let expr = BoundExpression::BinaryOp {
            op: BinaryOp::Div,
            left: Box::new(lit_int(10)),
            right: Box::new(lit_int(0)),
            result_type: LogicalType::Int64,
        };
        assert!(evaluate_constant(&expr).is_err());
    }

    #[test]
    fn evaluate_null_propagation() {
        let expr = BoundExpression::BinaryOp {
            op: BinaryOp::Add,
            left: Box::new(lit_int(10)),
            right: Box::new(lit_null()),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Null);
    }

    #[test]
    fn evaluate_three_valued_and() {
        // FALSE AND NULL = FALSE
        let expr = BoundExpression::BinaryOp {
            op: BinaryOp::And,
            left: Box::new(lit_bool(false)),
            right: Box::new(lit_null()),
            result_type: LogicalType::Bool,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(false));

        // TRUE AND NULL = NULL
        let expr2 = BoundExpression::BinaryOp {
            op: BinaryOp::And,
            left: Box::new(lit_bool(true)),
            right: Box::new(lit_null()),
            result_type: LogicalType::Bool,
        };
        assert_eq!(evaluate_constant(&expr2).unwrap(), TypedValue::Null);
    }

    #[test]
    fn evaluate_three_valued_or() {
        // TRUE OR NULL = TRUE
        let expr = BoundExpression::BinaryOp {
            op: BinaryOp::Or,
            left: Box::new(lit_bool(true)),
            right: Box::new(lit_null()),
            result_type: LogicalType::Bool,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(true));

        // FALSE OR NULL = NULL
        let expr2 = BoundExpression::BinaryOp {
            op: BinaryOp::Or,
            left: Box::new(lit_bool(false)),
            right: Box::new(lit_null()),
            result_type: LogicalType::Bool,
        };
        assert_eq!(evaluate_constant(&expr2).unwrap(), TypedValue::Null);
    }

    #[test]
    fn evaluate_comparison() {
        let expr = BoundExpression::Comparison {
            op: ComparisonOp::Gt,
            left: Box::new(lit_int(10)),
            right: Box::new(lit_int(5)),
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(true));

        let expr2 = BoundExpression::Comparison {
            op: ComparisonOp::Eq,
            left: Box::new(lit_str("a")),
            right: Box::new(lit_str("a")),
        };
        assert_eq!(evaluate_constant(&expr2).unwrap(), TypedValue::Bool(true));
    }

    #[test]
    fn evaluate_comparison_null() {
        let expr = BoundExpression::Comparison {
            op: ComparisonOp::Eq,
            left: Box::new(lit_int(1)),
            right: Box::new(lit_null()),
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Null);
    }

    #[test]
    fn evaluate_is_null() {
        let expr = BoundExpression::IsNull {
            expr: Box::new(lit_null()),
            negated: false,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(true));

        let expr2 = BoundExpression::IsNull {
            expr: Box::new(lit_int(1)),
            negated: true,
        };
        assert_eq!(evaluate_constant(&expr2).unwrap(), TypedValue::Bool(true));
    }

    #[test]
    fn evaluate_string_ops() {
        let sw = BoundExpression::StringOp {
            op: StringOp::StartsWith,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str("hel")),
        };
        assert_eq!(evaluate_constant(&sw).unwrap(), TypedValue::Bool(true));

        let ew = BoundExpression::StringOp {
            op: StringOp::EndsWith,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str("lo")),
        };
        assert_eq!(evaluate_constant(&ew).unwrap(), TypedValue::Bool(true));

        let ct = BoundExpression::StringOp {
            op: StringOp::Contains,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str("ell")),
        };
        assert_eq!(evaluate_constant(&ct).unwrap(), TypedValue::Bool(true));
    }

    #[test]
    fn evaluate_concat() {
        let expr = BoundExpression::BinaryOp {
            op: BinaryOp::Concat,
            left: Box::new(lit_str("hello")),
            right: Box::new(lit_str(" world")),
            result_type: LogicalType::String,
        };
        assert_eq!(
            evaluate_constant(&expr).unwrap(),
            TypedValue::String(SmolStr::new("hello world"))
        );
    }

    #[test]
    fn evaluate_not() {
        let expr = BoundExpression::UnaryOp {
            op: UnaryOp::Not,
            operand: Box::new(lit_bool(true)),
            result_type: LogicalType::Bool,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(false));
    }

    #[test]
    fn evaluate_unary_minus() {
        let expr = BoundExpression::UnaryOp {
            op: UnaryOp::Minus,
            operand: Box::new(lit_int(42)),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(-42));
    }

    #[test]
    fn evaluate_case_searched() {
        let expr = BoundExpression::Case {
            operand: None,
            whens: vec![(lit_bool(false), lit_int(1)), (lit_bool(true), lit_int(2))],
            else_expr: Some(Box::new(lit_int(3))),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(2));
    }

    #[test]
    fn evaluate_case_else() {
        let expr = BoundExpression::Case {
            operand: None,
            whens: vec![(lit_bool(false), lit_int(1))],
            else_expr: Some(Box::new(lit_int(99))),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(99));
    }

    #[test]
    fn evaluate_list_literal() {
        let expr = BoundExpression::ListLiteral {
            elements: vec![lit_int(1), lit_int(2), lit_int(3)],
            result_type: LogicalType::List(Box::new(LogicalType::Int64)),
        };
        assert_eq!(
            evaluate_constant(&expr).unwrap(),
            TypedValue::List(vec![
                TypedValue::Int64(1),
                TypedValue::Int64(2),
                TypedValue::Int64(3)
            ])
        );
    }

    #[test]
    fn evaluate_subscript() {
        // Subscripts are 1-based: index 2 selects the second element.
        let expr = BoundExpression::Subscript {
            expr: Box::new(BoundExpression::ListLiteral {
                elements: vec![lit_int(10), lit_int(20), lit_int(30)],
                result_type: LogicalType::List(Box::new(LogicalType::Int64)),
            }),
            index: Box::new(lit_int(2)),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(20));
    }

    #[test]
    fn evaluate_cast_int_to_double() {
        let expr = BoundExpression::Cast {
            expr: Box::new(lit_int(42)),
            target_type: LogicalType::Double,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Double(42.0));
    }

    #[test]
    fn evaluate_cast_to_string() {
        let expr = BoundExpression::Cast {
            expr: Box::new(lit_int(42)),
            target_type: LogicalType::String,
        };
        assert_eq!(
            evaluate_constant(&expr).unwrap(),
            TypedValue::String(SmolStr::new("42"))
        );
    }

    #[test]
    fn evaluate_in_list() {
        let expr = BoundExpression::InList {
            expr: Box::new(lit_int(2)),
            list: vec![lit_int(1), lit_int(2), lit_int(3)],
            negated: false,
        };
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(true));

        let expr_not = BoundExpression::InList {
            expr: Box::new(lit_int(5)),
            list: vec![lit_int(1), lit_int(2), lit_int(3)],
            negated: false,
        };
        assert_eq!(
            evaluate_constant(&expr_not).unwrap(),
            TypedValue::Bool(false)
        );
    }

    #[test]
    fn evaluate_nested_expression() {
        // (1 + 2) * 3
        let add = BoundExpression::BinaryOp {
            op: BinaryOp::Add,
            left: Box::new(lit_int(1)),
            right: Box::new(lit_int(2)),
            result_type: LogicalType::Int64,
        };
        let mul = BoundExpression::BinaryOp {
            op: BinaryOp::Mul,
            left: Box::new(add),
            right: Box::new(lit_int(3)),
            result_type: LogicalType::Int64,
        };
        assert_eq!(evaluate_constant(&mul).unwrap(), TypedValue::Int64(9));
    }

    #[test]
    fn func_abs_int() {
        let expr = func("abs", vec![lit_int(-42)], LogicalType::Int64);
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(42));
    }

    #[test]
    fn func_abs_double() {
        // 2.5 is exactly representable and avoids Clippy's `approx_constant` (3.14 ~ PI).
        let expr = func("abs", vec![lit_f64(-2.5)], LogicalType::Double);
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Double(2.5));
    }

    #[test]
    fn func_floor_ceil_round() {
        let floor = func("floor", vec![lit_f64(3.7)], LogicalType::Double);
        assert_eq!(evaluate_constant(&floor).unwrap(), TypedValue::Double(3.0));

        let ceil = func("ceil", vec![lit_f64(3.2)], LogicalType::Double);
        assert_eq!(evaluate_constant(&ceil).unwrap(), TypedValue::Double(4.0));

        let round = func("round", vec![lit_f64(3.5)], LogicalType::Double);
        assert_eq!(evaluate_constant(&round).unwrap(), TypedValue::Double(4.0));
    }

    #[test]
    fn func_sqrt() {
        let expr = func("sqrt", vec![lit_f64(9.0)], LogicalType::Double);
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Double(3.0));
    }

    #[test]
    fn func_log_family() {
        let ln = func(
            "log",
            vec![lit_f64(std::f64::consts::E)],
            LogicalType::Double,
        );
        // A non-Double result must fail the test rather than be skipped.
        match evaluate_constant(&ln).unwrap() {
            TypedValue::Double(v) => assert!((v - 1.0).abs() < 1e-10),
            other => panic!("expected Double, got {other:?}"),
        }

        let log2 = func("log2", vec![lit_f64(8.0)], LogicalType::Double);
        assert_eq!(evaluate_constant(&log2).unwrap(), TypedValue::Double(3.0));

        let log10 = func("log10", vec![lit_f64(100.0)], LogicalType::Double);
        assert_eq!(evaluate_constant(&log10).unwrap(), TypedValue::Double(2.0));
    }

    #[test]
    fn func_trig() {
        let sin = func("sin", vec![lit_f64(0.0)], LogicalType::Double);
        assert_eq!(evaluate_constant(&sin).unwrap(), TypedValue::Double(0.0));

        let cos = func("cos", vec![lit_f64(0.0)], LogicalType::Double);
        assert_eq!(evaluate_constant(&cos).unwrap(), TypedValue::Double(1.0));

        let tan = func("tan", vec![lit_f64(0.0)], LogicalType::Double);
        assert_eq!(evaluate_constant(&tan).unwrap(), TypedValue::Double(0.0));
    }

    #[test]
    fn func_sign() {
        let pos = func("sign", vec![lit_int(42)], LogicalType::Int64);
        assert_eq!(evaluate_constant(&pos).unwrap(), TypedValue::Int64(1));

        let neg = func("sign", vec![lit_int(-5)], LogicalType::Int64);
        assert_eq!(evaluate_constant(&neg).unwrap(), TypedValue::Int64(-1));

        let zero = func("sign", vec![lit_int(0)], LogicalType::Int64);
        assert_eq!(evaluate_constant(&zero).unwrap(), TypedValue::Int64(0));

        let neg_dbl = func("sign", vec![lit_f64(-2.5)], LogicalType::Int64);
        assert_eq!(evaluate_constant(&neg_dbl).unwrap(), TypedValue::Int64(-1));
    }

    #[test]
    fn func_lower_upper() {
        let lower = func("lower", vec![lit_str("Hello")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&lower).unwrap(),
            TypedValue::String(SmolStr::new("hello"))
        );

        let upper = func("upper", vec![lit_str("Hello")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&upper).unwrap(),
            TypedValue::String(SmolStr::new("HELLO"))
        );
    }

    #[test]
    fn func_length_size() {
        let len_str = func("length", vec![lit_str("hello")], LogicalType::Int64);
        assert_eq!(evaluate_constant(&len_str).unwrap(), TypedValue::Int64(5));

        let size_list = func(
            "size",
            vec![BoundExpression::ListLiteral {
                elements: vec![lit_int(1), lit_int(2), lit_int(3)],
                result_type: LogicalType::List(Box::new(LogicalType::Int64)),
            }],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&size_list).unwrap(), TypedValue::Int64(3));
    }

    #[test]
    fn func_trim() {
        let trim = func("trim", vec![lit_str(" hello ")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&trim).unwrap(),
            TypedValue::String(SmolStr::new("hello"))
        );

        let left_trimmed = func("ltrim", vec![lit_str(" hello")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&left_trimmed).unwrap(),
            TypedValue::String(SmolStr::new("hello"))
        );

        let right_trimmed = func("rtrim", vec![lit_str("hello ")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&right_trimmed).unwrap(),
            TypedValue::String(SmolStr::new("hello"))
        );
    }

    #[test]
    fn func_reverse() {
        let rev = func("reverse", vec![lit_str("hello")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&rev).unwrap(),
            TypedValue::String(SmolStr::new("olleh"))
        );
    }

    #[test]
    fn func_substring() {
        // substring(s, start, len) with a 1-based start position.
        let sub = func(
            "substring",
            vec![lit_str("hello world"), lit_int(1), lit_int(5)],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&sub).unwrap(),
            TypedValue::String(SmolStr::new("hello"))
        );

        let sub2 = func(
            "substring",
            vec![lit_str("hello"), lit_int(3), lit_int(2)],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&sub2).unwrap(),
            TypedValue::String(SmolStr::new("ll"))
        );
    }

    #[test]
    fn func_left_right() {
        let left = func(
            "left",
            vec![lit_str("hello"), lit_int(3)],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&left).unwrap(),
            TypedValue::String(SmolStr::new("hel"))
        );

        let right = func(
            "right",
            vec![lit_str("hello"), lit_int(3)],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&right).unwrap(),
            TypedValue::String(SmolStr::new("llo"))
        );
    }

    #[test]
    fn func_replace() {
        let rep = func(
            "replace",
            vec![lit_str("hello world"), lit_str("world"), lit_str("rust")],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&rep).unwrap(),
            TypedValue::String(SmolStr::new("hello rust"))
        );
    }

    #[test]
    fn func_concat() {
        let cat = func(
            "concat",
            vec![lit_str("hello"), lit_str(" "), lit_str("world")],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&cat).unwrap(),
            TypedValue::String(SmolStr::new("hello world"))
        );
    }

    #[test]
    fn func_lpad_rpad() {
        let lpad = func(
            "lpad",
            vec![lit_str("hi"), lit_int(5), lit_str("x")],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&lpad).unwrap(),
            TypedValue::String(SmolStr::new("xxxhi"))
        );

        let rpad = func(
            "rpad",
            vec![lit_str("hi"), lit_int(5), lit_str("x")],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&rpad).unwrap(),
            TypedValue::String(SmolStr::new("hixxx"))
        );
    }

    #[test]
    fn func_tostring() {
        let ts = func("tostring", vec![lit_int(42)], LogicalType::String);
        assert_eq!(
            evaluate_constant(&ts).unwrap(),
            TypedValue::String(SmolStr::new("42"))
        );
    }

    #[test]
    fn func_tointeger() {
        let ti = func("tointeger", vec![lit_f64(3.7)], LogicalType::Int64);
        assert_eq!(evaluate_constant(&ti).unwrap(), TypedValue::Int64(3));

        let ti2 = func("tointeger", vec![lit_str("42")], LogicalType::Int64);
        assert_eq!(evaluate_constant(&ti2).unwrap(), TypedValue::Int64(42));
    }

    #[test]
    fn func_tofloat() {
        let tf = func("tofloat", vec![lit_int(42)], LogicalType::Double);
        assert_eq!(evaluate_constant(&tf).unwrap(), TypedValue::Double(42.0));

        let tf2 = func("tofloat", vec![lit_str("2.5")], LogicalType::Double);
        assert_eq!(evaluate_constant(&tf2).unwrap(), TypedValue::Double(2.5));
    }

    #[test]
    fn func_toboolean() {
        let tb = func("toboolean", vec![lit_str("true")], LogicalType::Bool);
        assert_eq!(evaluate_constant(&tb).unwrap(), TypedValue::Bool(true));

        let tb2 = func("toboolean", vec![lit_int(0)], LogicalType::Bool);
        assert_eq!(evaluate_constant(&tb2).unwrap(), TypedValue::Bool(false));
    }

    #[test]
    fn func_coalesce() {
        let co = func(
            "coalesce",
            vec![lit_null(), lit_null(), lit_int(42)],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&co).unwrap(), TypedValue::Int64(42));

        let co2 = func("coalesce", vec![lit_null(), lit_null()], LogicalType::Any);
        assert_eq!(evaluate_constant(&co2).unwrap(), TypedValue::Null);
    }

    #[test]
    fn func_typeof() {
        let ty = func("typeof", vec![lit_int(42)], LogicalType::String);
        assert_eq!(
            evaluate_constant(&ty).unwrap(),
            TypedValue::String(SmolStr::new("INT64"))
        );

        let ty2 = func("typeof", vec![lit_str("hello")], LogicalType::String);
        assert_eq!(
            evaluate_constant(&ty2).unwrap(),
            TypedValue::String(SmolStr::new("STRING"))
        );

        let ty3 = func("typeof", vec![lit_null()], LogicalType::String);
        assert_eq!(
            evaluate_constant(&ty3).unwrap(),
            TypedValue::String(SmolStr::new("NULL"))
        );
    }

    #[test]
    fn func_hash() {
        let h1 = func("hash", vec![lit_int(42)], LogicalType::Int64);
        let result = evaluate_constant(&h1).unwrap();
        assert!(matches!(result, TypedValue::Int64(_)));

        // Equal inputs must hash to the same value.
        let h2 = func("hash", vec![lit_int(42)], LogicalType::Int64);
        assert_eq!(
            evaluate_constant(&h1).unwrap(),
            evaluate_constant(&h2).unwrap()
        );
    }

    #[test]
    fn func_greatest_least() {
        let g = func(
            "greatest",
            vec![lit_int(1), lit_int(5), lit_int(3)],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&g).unwrap(), TypedValue::Int64(5));

        let l = func(
            "least",
            vec![lit_int(1), lit_int(5), lit_int(3)],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&l).unwrap(), TypedValue::Int64(1));

        // NULL arguments are skipped.
        let gn = func(
            "greatest",
            vec![lit_null(), lit_int(3), lit_null()],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&gn).unwrap(), TypedValue::Int64(3));
    }

    #[test]
    fn func_range() {
        // range(start, end) includes both endpoints.
        let r = func(
            "range",
            vec![lit_int(1), lit_int(4)],
            LogicalType::List(Box::new(LogicalType::Int64)),
        );
        assert_eq!(
            evaluate_constant(&r).unwrap(),
            TypedValue::List(vec![
                TypedValue::Int64(1),
                TypedValue::Int64(2),
                TypedValue::Int64(3),
                TypedValue::Int64(4)
            ])
        );
    }

    #[test]
    fn func_null_propagation() {
        let abs_null = func("abs", vec![lit_null()], LogicalType::Int64);
        assert_eq!(evaluate_constant(&abs_null).unwrap(), TypedValue::Null);

        let lower_null = func("lower", vec![lit_null()], LogicalType::String);
        assert_eq!(evaluate_constant(&lower_null).unwrap(), TypedValue::Null);
    }

    #[test]
    fn json_extract_nested_path() {
        let expr = func(
            "json_extract",
            vec![
                lit_str(r#"{"address":{"city":"Tokyo","zip":"100"}}"#),
                lit_str("$.address.city"),
            ],
            LogicalType::String,
        );
        assert_eq!(
            evaluate_constant(&expr).unwrap(),
            TypedValue::String(SmolStr::new("Tokyo"))
        );
    }

    #[test]
    fn json_extract_missing_path() {
        let expr = func(
            "json_extract",
            vec![lit_str(r#"{"a":1}"#), lit_str("$.b")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Null);
    }

    #[test]
    fn json_extract_array_index() {
        let expr = func(
            "json_extract",
            vec![lit_str(r#"{"items":[10,20,30]}"#), lit_str("$.items.1")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(20));
    }

    #[test]
    fn json_extract_scalar_values() {
        // Paths without the `$.` prefix are accepted too.
        let b = func(
            "json_extract",
            vec![lit_str(r#"{"ok":true}"#), lit_str("ok")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&b).unwrap(), TypedValue::Bool(true));

        let i = func(
            "json_extract",
            vec![lit_str(r#"{"n":42}"#), lit_str("n")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&i).unwrap(), TypedValue::Int64(42));

        let d = func(
            "json_extract",
            vec![lit_str(r#"{"ratio":2.5}"#), lit_str("ratio")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&d).unwrap(), TypedValue::Double(2.5));

        let n = func(
            "json_extract",
            vec![lit_str(r#"{"x":null}"#), lit_str("x")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&n).unwrap(), TypedValue::Null);
    }

    #[test]
    fn json_extract_root() {
        // `$` alone yields the whole document serialized back to a string.
        let expr = func(
            "json_extract",
            vec![lit_str(r#"{"a":1}"#), lit_str("$")],
            LogicalType::String,
        );
        let result = evaluate_constant(&expr).unwrap();
        if let TypedValue::String(s) = &result {
            assert!(s.contains("\"a\""));
        } else {
            panic!("expected String, got {result:?}");
        }
    }

    #[test]
    fn json_valid_ok() {
        let valid_obj = func("json_valid", vec![lit_str(r#"{"a":1}"#)], LogicalType::Bool);
        assert_eq!(
            evaluate_constant(&valid_obj).unwrap(),
            TypedValue::Bool(true)
        );

        let valid_arr = func("json_valid", vec![lit_str("[1,2,3]")], LogicalType::Bool);
        assert_eq!(
            evaluate_constant(&valid_arr).unwrap(),
            TypedValue::Bool(true)
        );
    }

    #[test]
    fn json_valid_invalid() {
        let invalid = func("json_valid", vec![lit_str("{not json}")], LogicalType::Bool);
        assert_eq!(
            evaluate_constant(&invalid).unwrap(),
            TypedValue::Bool(false)
        );

        let empty = func("json_valid", vec![lit_str("")], LogicalType::Bool);
        assert_eq!(evaluate_constant(&empty).unwrap(), TypedValue::Bool(false));
    }

    #[test]
    fn json_type_variants() {
        let cases = vec![
            (r#"{"a":1}"#, "object"),
            ("[1,2]", "array"),
            (r#""hello""#, "string"),
            ("42", "number"),
            ("3.14", "number"),
            ("true", "boolean"),
            ("null", "null"),
        ];
        for (input, expected) in cases {
            let expr = func("json_type", vec![lit_str(input)], LogicalType::String);
            assert_eq!(
                evaluate_constant(&expr).unwrap(),
                TypedValue::String(SmolStr::new(expected)),
                "json_type({input}) should be {expected}"
            );
        }
    }

    #[test]
    fn json_keys_object() {
        let expr = func(
            "json_keys",
            vec![lit_str(r#"{"b":2,"a":1}"#)],
            LogicalType::List(Box::new(LogicalType::String)),
        );
        let result = evaluate_constant(&expr).unwrap();
        if let TypedValue::List(keys) = result {
            // Key order is not asserted; only membership and count.
            let key_strs: Vec<&str> = keys
                .iter()
                .map(|k| match k {
                    TypedValue::String(s) => s.as_str(),
                    _ => panic!("expected string key"),
                })
                .collect();
            assert!(key_strs.contains(&"a"));
            assert!(key_strs.contains(&"b"));
            assert_eq!(key_strs.len(), 2);
        } else {
            panic!("expected List, got {result:?}");
        }
    }

    #[test]
    fn json_keys_non_object() {
        let expr = func(
            "json_keys",
            vec![lit_str("[1,2,3]")],
            LogicalType::List(Box::new(LogicalType::String)),
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Null);
    }

    #[test]
    fn json_array_length_ok() {
        let expr = func(
            "json_array_length",
            vec![lit_str("[1,2,3,4,5]")],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Int64(5));
    }

    #[test]
    fn json_array_length_non_array() {
        let expr = func(
            "json_array_length",
            vec![lit_str(r#"{"a":1}"#)],
            LogicalType::Int64,
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Null);
    }

    #[test]
    fn json_contains_existing() {
        let expr = func(
            "json_contains",
            vec![lit_str(r#"{"a":{"b":1}}"#), lit_str("$.a.b")],
            LogicalType::Bool,
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(true));
    }

    #[test]
    fn json_contains_missing() {
        let expr = func(
            "json_contains",
            vec![lit_str(r#"{"a":1}"#), lit_str("$.x.y")],
            LogicalType::Bool,
        );
        assert_eq!(evaluate_constant(&expr).unwrap(), TypedValue::Bool(false));
    }

    #[test]
    fn json_set_existing_key() {
        let expr = func(
            "json_set",
            vec![lit_str(r#"{"a":1,"b":2}"#), lit_str("$.a"), lit_str("99")],
            LogicalType::String,
        );
        let result = evaluate_constant(&expr).unwrap();
        if let TypedValue::String(s) = &result {
            // Read the updated value back through json_extract.
            let check = func(
                "json_extract",
                vec![lit_str(s.as_str()), lit_str("a")],
                LogicalType::String,
            );
            assert_eq!(evaluate_constant(&check).unwrap(), TypedValue::Int64(99));
        } else {
            panic!("expected String, got {result:?}");
        }
    }

    #[test]
    fn json_set_nested() {
        let expr = func(
            "json_set",
            vec![lit_str(r#"{"a":{"x":1}}"#), lit_str("$.a.x"), lit_str("42")],
            LogicalType::String,
        );
        let result = evaluate_constant(&expr).unwrap();
        if let TypedValue::String(s) = &result {
            let check = func(
                "json_extract",
                vec![lit_str(s.as_str()), lit_str("a.x")],
                LogicalType::String,
            );
            assert_eq!(evaluate_constant(&check).unwrap(), TypedValue::Int64(42));
        } else {
            panic!("expected String, got {result:?}");
        }
    }

    #[test]
    fn json_null_propagation() {
        let extract_null = func(
            "json_extract",
            vec![lit_null(), lit_str("$.a")],
            LogicalType::String,
        );
        assert_eq!(evaluate_constant(&extract_null).unwrap(), TypedValue::Null);

        let valid_null = func("json_valid", vec![lit_null()], LogicalType::Bool);
        assert_eq!(evaluate_constant(&valid_null).unwrap(), TypedValue::Null);
    }
}