1use crate::error::{QueryError, Result};
40use crate::schema::{ColumnName, DataType};
41use crate::value::Value;
42use kimberlite_types::{DateField, SubstringRange};
43
/// A scalar (per-row) expression tree, evaluated by [`evaluate`] against one
/// row at a time.
///
/// Child expressions are boxed (or held in a `Vec`) so that expressions can
/// nest arbitrarily.
#[derive(Debug, Clone)]
pub enum ScalarExpr {
    /// A constant value; evaluates to a clone of itself.
    Literal(Value),
    /// A reference to a column, resolved by name in the [`EvalContext`].
    Column(ColumnName),

    /// `UPPER(text)`: uppercases a text value; NULL in, NULL out.
    Upper(Box<ScalarExpr>),
    /// `LOWER(text)`: lowercases a text value; NULL in, NULL out.
    Lower(Box<ScalarExpr>),
    /// `LENGTH(text)`: number of `char`s (not bytes) as a `BigInt`.
    Length(Box<ScalarExpr>),
    /// `TRIM(text)`: strips leading and trailing whitespace.
    Trim(Box<ScalarExpr>),
    /// `CONCAT(a, b, ...)`: text concatenation; any NULL argument makes the
    /// whole result NULL. The evaluator asserts the list is non-empty.
    Concat(Vec<ScalarExpr>),

    /// `ABS(x)`: absolute value, keeping the numeric subtype of the input.
    Abs(Box<ScalarExpr>),
    /// `ROUND(x)`: rounds to zero fractional digits; integers pass through.
    Round(Box<ScalarExpr>),
    /// `ROUND(x, scale)`: rounds to `scale` fractional digits. The evaluator
    /// asserts `0 <= scale < 255`.
    RoundScale(Box<ScalarExpr>, i32),
    /// `CEIL(x)`: smallest integral value not less than `x`.
    Ceil(Box<ScalarExpr>),
    /// `FLOOR(x)`: largest integral value not greater than `x`.
    Floor(Box<ScalarExpr>),

    /// `COALESCE(a, b, ...)`: the first non-NULL argument, or NULL if all are
    /// NULL. The evaluator asserts the list is non-empty.
    Coalesce(Vec<ScalarExpr>),
    /// `NULLIF(a, b)`: NULL when `a == b`, otherwise `a`.
    Nullif(Box<ScalarExpr>, Box<ScalarExpr>),

    /// `CAST(expr AS type)`: converts the operand to the given data type.
    Cast(Box<ScalarExpr>, DataType),

    /// `MOD(a, b)`: integer remainder; a zero divisor yields NULL.
    Mod(Box<ScalarExpr>, Box<ScalarExpr>),
    /// `POWER(base, exp)`: computed in `f64` and returned as `Real`.
    Power(Box<ScalarExpr>, Box<ScalarExpr>),
    /// `SQRT(x)`: computed in `f64`; a negative input is an error.
    Sqrt(Box<ScalarExpr>),
    /// `SUBSTRING(text ...)`: character-based slice described by a 1-based
    /// [`SubstringRange`].
    Substring(Box<ScalarExpr>, SubstringRange),
    /// `EXTRACT(field FROM expr)`: one calendar/clock field of a Date or
    /// Timestamp.
    Extract(DateField, Box<ScalarExpr>),
    /// `DATE_TRUNC(field, expr)`: truncates a Date or Timestamp down to the
    /// given field.
    DateTrunc(DateField, Box<ScalarExpr>),
    /// `NOW()`. Must be replaced by the `fold_time_constants` planner pass
    /// before evaluation; the evaluator panics if it meets this variant.
    Now,
    /// `CURRENT_TIMESTAMP`. Same folding requirement as [`ScalarExpr::Now`].
    CurrentTimestamp,
    /// `CURRENT_DATE`. Same folding requirement as [`ScalarExpr::Now`].
    CurrentDate,
}
137
/// The row a scalar expression is evaluated against: a borrowed list of
/// column names paired positionally with a borrowed list of values.
pub struct EvalContext<'a> {
    /// Column names; `columns[i]` names the value stored at `row[i]`.
    pub columns: &'a [ColumnName],
    /// The row's values, in the same order as `columns`.
    pub row: &'a [Value],
}
144
145impl<'a> EvalContext<'a> {
146 pub fn new(columns: &'a [ColumnName], row: &'a [Value]) -> Self {
147 assert!(
148 columns.len() == row.len(),
149 "EvalContext precondition: columns and row must have equal length",
150 );
151 Self { columns, row }
152 }
153
154 fn lookup(&self, name: &ColumnName) -> Result<&Value> {
155 self.columns
156 .iter()
157 .position(|c| c == name)
158 .and_then(|idx| self.row.get(idx))
159 .ok_or_else(|| QueryError::ColumnNotFound {
160 table: String::new(),
161 column: name.to_string(),
162 })
163 }
164}
165
166pub fn evaluate(expr: &ScalarExpr, ctx: &EvalContext<'_>) -> Result<Value> {
173 match expr {
174 ScalarExpr::Literal(v) => Ok(v.clone()),
175 ScalarExpr::Column(name) => Ok(ctx.lookup(name)?.clone()),
176
177 ScalarExpr::Upper(inner) => match evaluate(inner, ctx)? {
179 Value::Null => Ok(Value::Null),
180 Value::Text(s) => Ok(Value::Text(s.to_uppercase())),
181 other => Err(type_error("UPPER", "Text", &other)),
182 },
183 ScalarExpr::Lower(inner) => match evaluate(inner, ctx)? {
184 Value::Null => Ok(Value::Null),
185 Value::Text(s) => Ok(Value::Text(s.to_lowercase())),
186 other => Err(type_error("LOWER", "Text", &other)),
187 },
188 ScalarExpr::Length(inner) => match evaluate(inner, ctx)? {
189 Value::Null => Ok(Value::Null),
190 Value::Text(s) => {
191 let chars = s.chars().count();
193 debug_assert_eq!(chars, s.chars().count());
196 Ok(Value::BigInt(chars as i64))
197 }
198 other => Err(type_error("LENGTH", "Text", &other)),
199 },
200 ScalarExpr::Trim(inner) => match evaluate(inner, ctx)? {
201 Value::Null => Ok(Value::Null),
202 Value::Text(s) => Ok(Value::Text(s.trim().to_string())),
203 other => Err(type_error("TRIM", "Text", &other)),
204 },
205 ScalarExpr::Concat(parts) => {
206 assert!(
207 !parts.is_empty(),
208 "CONCAT precondition: at least one argument"
209 );
210 let mut out = String::new();
211 for p in parts {
212 match evaluate(p, ctx)? {
213 Value::Null => return Ok(Value::Null),
214 Value::Text(s) => out.push_str(&s),
215 other => return Err(type_error("CONCAT", "Text", &other)),
216 }
217 }
218 Ok(Value::Text(out))
219 }
220
221 ScalarExpr::Abs(inner) => match evaluate(inner, ctx)? {
223 Value::Null => Ok(Value::Null),
224 Value::TinyInt(n) => Ok(Value::TinyInt(n.saturating_abs())),
225 Value::SmallInt(n) => Ok(Value::SmallInt(n.saturating_abs())),
226 Value::Integer(n) => Ok(Value::Integer(n.saturating_abs())),
227 Value::BigInt(n) => Ok(Value::BigInt(n.saturating_abs())),
228 Value::Real(n) => Ok(Value::Real(n.abs())),
229 Value::Decimal(val, scale) => Ok(Value::Decimal(val.saturating_abs(), scale)),
230 other => Err(type_error("ABS", "Numeric", &other)),
231 },
232 ScalarExpr::Round(inner) => match evaluate(inner, ctx)? {
233 Value::Null => Ok(Value::Null),
234 v @ (Value::TinyInt(_) | Value::SmallInt(_) | Value::Integer(_) | Value::BigInt(_)) => {
236 Ok(v)
237 }
238 Value::Real(x) => Ok(Value::Real(x.round())),
239 Value::Decimal(val, scale) => Ok(decimal_round_to_scale(val, scale, 0)),
240 other => Err(type_error("ROUND", "Numeric", &other)),
241 },
242 ScalarExpr::RoundScale(inner, target_scale) => {
243 assert!(
244 *target_scale >= 0 && *target_scale < i32::from(u8::MAX),
245 "ROUND scale must fit in a non-negative u8",
246 );
247 let target = u8::try_from(*target_scale).unwrap_or(0);
248 match evaluate(inner, ctx)? {
249 Value::Null => Ok(Value::Null),
250 v @ (Value::TinyInt(_)
251 | Value::SmallInt(_)
252 | Value::Integer(_)
253 | Value::BigInt(_)) => Ok(v),
254 Value::Real(x) => {
255 let factor = 10f64.powi(i32::from(target));
258 Ok(Value::Real((x * factor).round() / factor))
259 }
260 Value::Decimal(val, scale) => Ok(decimal_round_to_scale(val, scale, target)),
261 other => Err(type_error("ROUND", "Numeric", &other)),
262 }
263 }
264 ScalarExpr::Ceil(inner) => match evaluate(inner, ctx)? {
265 Value::Null => Ok(Value::Null),
266 v @ (Value::TinyInt(_) | Value::SmallInt(_) | Value::Integer(_) | Value::BigInt(_)) => {
267 Ok(v)
268 }
269 Value::Real(x) => Ok(Value::Real(x.ceil())),
270 Value::Decimal(val, scale) => {
271 if scale == 0 {
272 Ok(Value::Decimal(val, 0))
273 } else {
274 Ok(decimal_ceil(val, scale))
275 }
276 }
277 other => Err(type_error("CEIL", "Numeric", &other)),
278 },
279 ScalarExpr::Floor(inner) => match evaluate(inner, ctx)? {
280 Value::Null => Ok(Value::Null),
281 v @ (Value::TinyInt(_) | Value::SmallInt(_) | Value::Integer(_) | Value::BigInt(_)) => {
282 Ok(v)
283 }
284 Value::Real(x) => Ok(Value::Real(x.floor())),
285 Value::Decimal(val, scale) => {
286 if scale == 0 {
287 Ok(Value::Decimal(val, 0))
288 } else {
289 Ok(decimal_floor(val, scale))
290 }
291 }
292 other => Err(type_error("FLOOR", "Numeric", &other)),
293 },
294
295 ScalarExpr::Coalesce(exprs) => {
297 assert!(
298 !exprs.is_empty(),
299 "COALESCE precondition: at least one argument"
300 );
301 for e in exprs {
302 let v = evaluate(e, ctx)?;
303 if !matches!(v, Value::Null) {
304 return Ok(v);
305 }
306 }
307 Ok(Value::Null)
308 }
309 ScalarExpr::Nullif(a, b) => {
310 let av = evaluate(a, ctx)?;
311 let bv = evaluate(b, ctx)?;
312 if av == bv { Ok(Value::Null) } else { Ok(av) }
313 }
314
315 ScalarExpr::Cast(inner, target) => cast_value(evaluate(inner, ctx)?, *target),
317
318 ScalarExpr::Mod(a, b) => eval_mod(evaluate(a, ctx)?, evaluate(b, ctx)?),
320 ScalarExpr::Power(base, exp) => eval_power(evaluate(base, ctx)?, evaluate(exp, ctx)?),
321 ScalarExpr::Sqrt(inner) => eval_sqrt(evaluate(inner, ctx)?),
322
323 ScalarExpr::Substring(inner, range) => eval_substring(evaluate(inner, ctx)?, *range),
325
326 ScalarExpr::Extract(field, inner) => eval_extract(*field, evaluate(inner, ctx)?),
328 ScalarExpr::DateTrunc(field, inner) => eval_date_trunc(*field, evaluate(inner, ctx)?),
329
330 ScalarExpr::Now | ScalarExpr::CurrentTimestamp | ScalarExpr::CurrentDate => {
338 panic!(
339 "scalar evaluator received raw NOW/CURRENT_TIMESTAMP/CURRENT_DATE \
340 — fold_time_constants planner pass must run first \
341 (AUDIT-2026-05 S3.7)"
342 )
343 }
344 }
345}
346
347fn eval_mod(a: Value, b: Value) -> Result<Value> {
352 if matches!(a, Value::Null) || matches!(b, Value::Null) {
354 return Ok(Value::Null);
355 }
356 let a64 = numeric_as_i64(&a, "MOD")?;
360 let b64 = numeric_as_i64(&b, "MOD")?;
361 if b64 == 0 {
363 return Ok(Value::Null);
364 }
365 let result = a64.wrapping_rem(b64);
368 debug_assert!(
369 result.wrapping_abs() < b64.wrapping_abs() || b64 == i64::MIN,
370 "MOD postcondition violated: |{result}| >= |{b64}|"
371 );
372 Ok(match a {
375 Value::TinyInt(_) => i8::try_from(result)
376 .map(Value::TinyInt)
377 .unwrap_or(Value::BigInt(result)),
378 Value::SmallInt(_) => i16::try_from(result)
379 .map(Value::SmallInt)
380 .unwrap_or(Value::BigInt(result)),
381 Value::Integer(_) => i32::try_from(result)
382 .map(Value::Integer)
383 .unwrap_or(Value::BigInt(result)),
384 _ => Value::BigInt(result),
385 })
386}
387
388fn eval_power(base: Value, exp: Value) -> Result<Value> {
389 if matches!(base, Value::Null) || matches!(exp, Value::Null) {
390 return Ok(Value::Null);
391 }
392 let base_f = numeric_as_f64(&base, "POWER")?;
397 let exp_f = numeric_as_f64(&exp, "POWER")?;
398 let result = base_f.powf(exp_f);
399 if result.is_nan() {
402 return Err(domain_error(
403 "POWER",
404 &format!("POWER({base_f}, {exp_f}) is NaN"),
405 ));
406 }
407 Ok(Value::Real(result))
408}
409
410fn eval_sqrt(value: Value) -> Result<Value> {
411 if matches!(value, Value::Null) {
412 return Ok(Value::Null);
413 }
414 let x = numeric_as_f64(&value, "SQRT")?;
415 if x < 0.0 {
416 return Err(domain_error(
417 "SQRT",
418 &format!("SQRT of negative input ({x})"),
419 ));
420 }
421 Ok(Value::Real(x.sqrt()))
422}
423
424fn eval_substring(value: Value, range: SubstringRange) -> Result<Value> {
425 if matches!(value, Value::Null) {
426 return Ok(Value::Null);
427 }
428 let Value::Text(s) = value else {
429 return Err(type_error("SUBSTRING", "Text", &value));
430 };
431 let chars: Vec<char> = s.chars().collect();
436 let total = chars.len() as i64;
437
438 let begin_inclusive_1based = range.start;
441 let begin0 = if begin_inclusive_1based < 1 {
442 0_i64
443 } else {
444 begin_inclusive_1based - 1
445 };
446
447 let end0 = match range.length {
449 Some(len) => {
450 let raw_end = begin_inclusive_1based.saturating_sub(1).saturating_add(len);
453 raw_end.min(total).max(0)
454 }
455 None => total,
456 };
457
458 let begin_clamped = begin0.max(0).min(total) as usize;
459 let end_clamped = end0.max(0).min(total) as usize;
460
461 if begin_clamped >= end_clamped {
462 return Ok(Value::Text(String::new()));
463 }
464 let out: String = chars[begin_clamped..end_clamped].iter().collect();
465 Ok(Value::Text(out))
466}
467
468fn eval_extract(field: DateField, value: Value) -> Result<Value> {
469 use chrono::{Datelike, Timelike};
470 if matches!(value, Value::Null) {
471 return Ok(Value::Null);
472 }
473 let timestamp_ns = match &value {
474 Value::Date(days) => i64::from(*days) * 86_400_000_000_000,
475 Value::Timestamp(ts) => ts.as_nanos() as i64,
476 other => return Err(type_error("EXTRACT", "Date or Timestamp", other)),
477 };
478
479 let secs = timestamp_ns.div_euclid(1_000_000_000);
483 let nsec_part = timestamp_ns.rem_euclid(1_000_000_000) as u32;
484 let dt = chrono::DateTime::<chrono::Utc>::from_timestamp(secs, nsec_part).ok_or_else(|| {
485 domain_error(
486 "EXTRACT",
487 &format!("timestamp {timestamp_ns} ns out of chrono range"),
488 )
489 })?;
490
491 let result = match field {
492 DateField::Year => Value::Integer(dt.year()),
493 DateField::Month => Value::Integer(dt.month() as i32),
494 DateField::Day => Value::Integer(dt.day() as i32),
495 DateField::Hour => Value::Integer(dt.hour() as i32),
496 DateField::Minute => Value::Integer(dt.minute() as i32),
497 DateField::Second => Value::Integer(dt.second() as i32),
498 DateField::Millisecond => Value::Integer((dt.timestamp_subsec_millis()) as i32),
499 DateField::Microsecond => Value::Integer((dt.timestamp_subsec_micros()) as i32),
500 DateField::DayOfWeek => {
501 let nfu = dt.weekday().num_days_from_sunday() as i32;
503 Value::Integer(nfu)
504 }
505 DateField::DayOfYear => Value::Integer(dt.ordinal() as i32),
506 DateField::Quarter => Value::Integer(((dt.month() - 1) / 3 + 1) as i32),
507 DateField::Week => Value::Integer(dt.iso_week().week() as i32),
508 DateField::Epoch => Value::BigInt(secs),
509 };
510 Ok(result)
511}
512
513fn eval_date_trunc(field: DateField, value: Value) -> Result<Value> {
514 use chrono::{Datelike, NaiveDate, NaiveDateTime, Timelike};
515 if matches!(value, Value::Null) {
516 return Ok(Value::Null);
517 }
518 if !field.is_truncatable() {
519 return Err(QueryError::ParseError(format!(
520 "DATE_TRUNC field {field:?} is not truncatable (use one of YEAR, MONTH, DAY, HOUR, MINUTE, SECOND)"
521 )));
522 }
523 let timestamp_ns = match &value {
524 Value::Date(days) => i64::from(*days) * 86_400_000_000_000,
525 Value::Timestamp(ts) => ts.as_nanos() as i64,
526 other => return Err(type_error("DATE_TRUNC", "Date or Timestamp", other)),
527 };
528
529 let secs = timestamp_ns.div_euclid(1_000_000_000);
530 let nsec_part = timestamp_ns.rem_euclid(1_000_000_000) as u32;
531 let dt = chrono::DateTime::<chrono::Utc>::from_timestamp(secs, nsec_part)
532 .ok_or_else(|| domain_error("DATE_TRUNC", "timestamp out of range"))?;
533 let nv = dt.naive_utc();
534
535 let truncated: NaiveDateTime = match field {
536 DateField::Year => NaiveDate::from_ymd_opt(nv.year(), 1, 1)
537 .and_then(|d| d.and_hms_opt(0, 0, 0))
538 .ok_or_else(|| domain_error("DATE_TRUNC", "year truncation"))?,
539 DateField::Month => NaiveDate::from_ymd_opt(nv.year(), nv.month(), 1)
540 .and_then(|d| d.and_hms_opt(0, 0, 0))
541 .ok_or_else(|| domain_error("DATE_TRUNC", "month truncation"))?,
542 DateField::Day => NaiveDate::from_ymd_opt(nv.year(), nv.month(), nv.day())
543 .and_then(|d| d.and_hms_opt(0, 0, 0))
544 .ok_or_else(|| domain_error("DATE_TRUNC", "day truncation"))?,
545 DateField::Hour => nv
546 .date()
547 .and_hms_opt(nv.hour(), 0, 0)
548 .ok_or_else(|| domain_error("DATE_TRUNC", "hour truncation"))?,
549 DateField::Minute => nv
550 .date()
551 .and_hms_opt(nv.hour(), nv.minute(), 0)
552 .ok_or_else(|| domain_error("DATE_TRUNC", "minute truncation"))?,
553 DateField::Second => nv
554 .date()
555 .and_hms_opt(nv.hour(), nv.minute(), nv.second())
556 .ok_or_else(|| domain_error("DATE_TRUNC", "second truncation"))?,
557 _ => unreachable!("non-truncatable field passed `is_truncatable` check"),
560 };
561
562 let truncated_ns = truncated
563 .and_utc()
564 .timestamp_nanos_opt()
565 .ok_or_else(|| domain_error("DATE_TRUNC", "truncated timestamp out of nanos range"))?;
566
567 match (&value, field) {
570 (Value::Date(_), DateField::Year | DateField::Month | DateField::Day) => Ok(Value::Date(
571 i32::try_from(truncated_ns / 86_400_000_000_000).unwrap_or(0),
572 )),
573 _ => Ok(Value::Timestamp(kimberlite_types::Timestamp::from_nanos(
574 truncated_ns.max(0) as u64,
575 ))),
576 }
577}
578
579fn numeric_as_i64(v: &Value, fn_name: &str) -> Result<i64> {
581 match v {
582 Value::TinyInt(n) => Ok(i64::from(*n)),
583 Value::SmallInt(n) => Ok(i64::from(*n)),
584 Value::Integer(n) => Ok(i64::from(*n)),
585 Value::BigInt(n) => Ok(*n),
586 other => Err(type_error(fn_name, "Integer", other)),
587 }
588}
589
590fn numeric_as_f64(v: &Value, fn_name: &str) -> Result<f64> {
592 match v {
593 Value::TinyInt(n) => Ok(f64::from(*n)),
594 Value::SmallInt(n) => Ok(f64::from(*n)),
595 Value::Integer(n) => Ok(f64::from(*n)),
596 #[allow(clippy::cast_precision_loss)]
597 Value::BigInt(n) => Ok(*n as f64),
598 Value::Real(n) => Ok(*n),
599 Value::Decimal(val, scale) => {
600 #[allow(clippy::cast_precision_loss)]
601 let f = (*val as f64) / 10f64.powi(i32::from(*scale));
602 Ok(f)
603 }
604 other => Err(type_error(fn_name, "Numeric", other)),
605 }
606}
607
608fn domain_error(fn_name: &str, detail: &str) -> QueryError {
609 QueryError::TypeMismatch {
610 expected: format!("{fn_name} domain"),
611 actual: detail.to_string(),
612 }
613}
614
615fn cast_value(value: Value, target: DataType) -> Result<Value> {
622 if matches!(value, Value::Null) {
623 return Ok(Value::Null);
624 }
625 match (value, target) {
626 (v @ Value::TinyInt(_), DataType::TinyInt)
629 | (v @ Value::SmallInt(_), DataType::SmallInt)
630 | (v @ Value::Integer(_), DataType::Integer)
631 | (v @ Value::BigInt(_), DataType::BigInt)
632 | (v @ Value::Real(_), DataType::Real)
633 | (v @ Value::Text(_), DataType::Text)
634 | (v @ Value::Bytes(_), DataType::Bytes)
635 | (v @ Value::Boolean(_), DataType::Boolean)
636 | (v @ Value::Date(_), DataType::Date)
637 | (v @ Value::Time(_), DataType::Time)
638 | (v @ Value::Timestamp(_), DataType::Timestamp)
639 | (v @ Value::Uuid(_), DataType::Uuid)
640 | (v @ Value::Json(_), DataType::Json) => Ok(v),
641
642 (Value::TinyInt(n), DataType::SmallInt) => Ok(Value::SmallInt(i16::from(n))),
644 (Value::TinyInt(n), DataType::Integer) => Ok(Value::Integer(i32::from(n))),
645 (Value::TinyInt(n), DataType::BigInt) => Ok(Value::BigInt(i64::from(n))),
646 (Value::SmallInt(n), DataType::Integer) => Ok(Value::Integer(i32::from(n))),
647 (Value::SmallInt(n), DataType::BigInt) => Ok(Value::BigInt(i64::from(n))),
648 (Value::Integer(n), DataType::BigInt) => Ok(Value::BigInt(i64::from(n))),
649
650 (Value::SmallInt(n), DataType::TinyInt) => i8::try_from(n)
652 .map(Value::TinyInt)
653 .map_err(|_| cast_error("SmallInt", "TinyInt", "overflow")),
654 (Value::Integer(n), DataType::TinyInt) => i8::try_from(n)
655 .map(Value::TinyInt)
656 .map_err(|_| cast_error("Integer", "TinyInt", "overflow")),
657 (Value::Integer(n), DataType::SmallInt) => i16::try_from(n)
658 .map(Value::SmallInt)
659 .map_err(|_| cast_error("Integer", "SmallInt", "overflow")),
660 (Value::BigInt(n), DataType::TinyInt) => i8::try_from(n)
661 .map(Value::TinyInt)
662 .map_err(|_| cast_error("BigInt", "TinyInt", "overflow")),
663 (Value::BigInt(n), DataType::SmallInt) => i16::try_from(n)
664 .map(Value::SmallInt)
665 .map_err(|_| cast_error("BigInt", "SmallInt", "overflow")),
666 (Value::BigInt(n), DataType::Integer) => i32::try_from(n)
667 .map(Value::Integer)
668 .map_err(|_| cast_error("BigInt", "Integer", "overflow")),
669
670 (Value::TinyInt(n), DataType::Real) => Ok(Value::Real(f64::from(n))),
673 (Value::SmallInt(n), DataType::Real) => Ok(Value::Real(f64::from(n))),
674 (Value::Integer(n), DataType::Real) => Ok(Value::Real(f64::from(n))),
675 #[allow(clippy::cast_precision_loss)]
676 (Value::BigInt(n), DataType::Real) => Ok(Value::Real(n as f64)),
677
678 (Value::Real(x), DataType::TinyInt) => f64_to_int::<i8>(x, "TinyInt").map(Value::TinyInt),
680 (Value::Real(x), DataType::SmallInt) => {
681 f64_to_int::<i16>(x, "SmallInt").map(Value::SmallInt)
682 }
683 (Value::Real(x), DataType::Integer) => f64_to_int::<i32>(x, "Integer").map(Value::Integer),
684 (Value::Real(x), DataType::BigInt) => f64_to_int::<i64>(x, "BigInt").map(Value::BigInt),
685
686 (Value::Text(s), DataType::TinyInt) => s
688 .trim()
689 .parse::<i8>()
690 .map(Value::TinyInt)
691 .map_err(|_| cast_error("Text", "TinyInt", &s)),
692 (Value::Text(s), DataType::SmallInt) => s
693 .trim()
694 .parse::<i16>()
695 .map(Value::SmallInt)
696 .map_err(|_| cast_error("Text", "SmallInt", &s)),
697 (Value::Text(s), DataType::Integer) => s
698 .trim()
699 .parse::<i32>()
700 .map(Value::Integer)
701 .map_err(|_| cast_error("Text", "Integer", &s)),
702 (Value::Text(s), DataType::BigInt) => s
703 .trim()
704 .parse::<i64>()
705 .map(Value::BigInt)
706 .map_err(|_| cast_error("Text", "BigInt", &s)),
707 (Value::Text(s), DataType::Real) => s
708 .trim()
709 .parse::<f64>()
710 .map(Value::Real)
711 .map_err(|_| cast_error("Text", "Real", &s)),
712 (Value::Text(s), DataType::Boolean) => match s.trim().to_ascii_lowercase().as_str() {
713 "true" | "t" | "1" => Ok(Value::Boolean(true)),
714 "false" | "f" | "0" => Ok(Value::Boolean(false)),
715 _ => Err(cast_error("Text", "Boolean", &s)),
716 },
717
718 (Value::TinyInt(n), DataType::Text) => Ok(Value::Text(n.to_string())),
720 (Value::SmallInt(n), DataType::Text) => Ok(Value::Text(n.to_string())),
721 (Value::Integer(n), DataType::Text) => Ok(Value::Text(n.to_string())),
722 (Value::BigInt(n), DataType::Text) => Ok(Value::Text(n.to_string())),
723 (Value::Real(n), DataType::Text) => Ok(Value::Text(n.to_string())),
724 (Value::Boolean(b), DataType::Text) => {
725 Ok(Value::Text(if b { "true" } else { "false" }.to_string()))
726 }
727
728 (v, t) => Err(QueryError::TypeMismatch {
730 expected: format!("CAST to {t:?}"),
731 actual: format!("{v:?}"),
732 }),
733 }
734}
735
736fn cast_error(from: &str, to: &str, detail: &str) -> QueryError {
737 QueryError::TypeMismatch {
738 expected: format!("CAST from {from} to {to}"),
739 actual: detail.to_string(),
740 }
741}
742
743fn f64_to_int<T>(x: f64, target: &str) -> Result<T>
746where
747 T: TryFrom<i64>,
748{
749 if !x.is_finite() {
750 return Err(cast_error("Real", target, &format!("{x}")));
751 }
752 let truncated = x.trunc();
754 #[allow(clippy::cast_possible_truncation)]
756 let as_i64 = if (i64::MIN as f64) <= truncated && truncated <= (i64::MAX as f64) {
757 truncated as i64
758 } else {
759 return Err(cast_error("Real", target, &format!("{x}")));
760 };
761 T::try_from(as_i64).map_err(|_| cast_error("Real", target, &format!("{x}")))
762}
763
764fn type_error(func: &str, expected: &str, got: &Value) -> QueryError {
765 QueryError::TypeMismatch {
766 expected: format!("{func} argument of type {expected}"),
767 actual: format!("{got:?}"),
768 }
769}
770
771fn decimal_round_to_scale(val: i128, from_scale: u8, to_scale: u8) -> Value {
776 if from_scale == to_scale {
777 return Value::Decimal(val, to_scale);
778 }
779 if to_scale > from_scale {
780 let diff = u32::from(to_scale - from_scale);
781 let factor = 10i128.pow(diff);
782 return Value::Decimal(val.saturating_mul(factor), to_scale);
783 }
784 let diff = u32::from(from_scale - to_scale);
786 let divisor = 10i128.pow(diff);
787 let half = divisor / 2;
788 let rounded = if val >= 0 {
789 (val + half) / divisor
790 } else {
791 (val - half) / divisor
792 };
793 Value::Decimal(rounded, to_scale)
794}
795
796fn decimal_ceil(val: i128, scale: u8) -> Value {
797 let divisor = 10i128.pow(u32::from(scale));
798 let floor_val = val / divisor;
799 let remainder = val % divisor;
800 let ceil = if remainder > 0 {
801 floor_val + 1
802 } else {
803 floor_val
804 };
805 Value::Decimal(ceil, 0)
806}
807
808fn decimal_floor(val: i128, scale: u8) -> Value {
809 let divisor = 10i128.pow(u32::from(scale));
810 let floor_val = val / divisor;
811 let remainder = val % divisor;
812 let floor = if remainder < 0 {
814 floor_val - 1
815 } else {
816 floor_val
817 };
818 Value::Decimal(floor, 0)
819}
820
// Unit tests for the scalar evaluator. Most tests evaluate a literal-only
// expression against an empty row; `column_reference_resolves` is the one
// test that builds a populated context.
#[cfg(test)]
mod tests {
    use super::*;
    use kimberlite_types::{DateField, SubstringRange};

    /// Owned, empty column/row vectors from which an empty context is built.
    fn ctx_empty() -> (Vec<ColumnName>, Vec<Value>) {
        (Vec::new(), Vec::new())
    }

    /// Shorthand for wrapping a value in `ScalarExpr::Literal`.
    fn lit(v: Value) -> ScalarExpr {
        ScalarExpr::Literal(v)
    }

    /// Evaluates an expression that references no columns.
    fn eval_standalone(expr: &ScalarExpr) -> Result<Value> {
        let (cols, row) = ctx_empty();
        evaluate(expr, &EvalContext::new(&cols, &row))
    }

    /// UPPER, LOWER, LENGTH and TRIM on plain text inputs.
    #[test]
    fn upper_lower_length_trim() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Upper(Box::new(lit(Value::Text(
                "hello".into()
            )))))
            .unwrap(),
            Value::Text("HELLO".into()),
        );
        assert_eq!(
            eval_standalone(&ScalarExpr::Lower(Box::new(lit(Value::Text(
                "WORLD".into()
            )))))
            .unwrap(),
            Value::Text("world".into()),
        );
        // "café" is 4 characters but 5 UTF-8 bytes.
        assert_eq!(
            eval_standalone(&ScalarExpr::Length(Box::new(lit(Value::Text(
                "café".into()
            )))))
            .unwrap(),
            Value::BigInt(4),
            "LENGTH is char count, not byte count",
        );
        assert_eq!(
            eval_standalone(&ScalarExpr::Trim(Box::new(lit(Value::Text(
                " hi ".into(),
            )))))
            .unwrap(),
            Value::Text("hi".into()),
        );
    }

    /// A NULL argument anywhere in CONCAT makes the whole result NULL.
    #[test]
    fn concat_propagates_null_like_postgres() {
        let ex = ScalarExpr::Concat(vec![
            lit(Value::Text("a".into())),
            lit(Value::Null),
            lit(Value::Text("b".into())),
        ]);
        assert_eq!(eval_standalone(&ex).unwrap(), Value::Null);
    }

    /// ABS returns the same numeric variant it was given.
    #[test]
    fn abs_preserves_subtype() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Abs(Box::new(lit(Value::Integer(-5))))).unwrap(),
            Value::Integer(5),
        );
        assert_eq!(
            eval_standalone(&ScalarExpr::Abs(Box::new(lit(Value::Real(-1.5))))).unwrap(),
            Value::Real(1.5),
        );
    }

    /// ROUND(decimal, 1): 123.45 -> 123.5, 123.44 -> 123.4, -123.45 -> -123.5
    /// (half rounds away from zero).
    #[test]
    fn round_with_scale_rounds_decimal() {
        let rounded = eval_standalone(&ScalarExpr::RoundScale(
            Box::new(lit(Value::Decimal(12345, 2))),
            1,
        ))
        .unwrap();
        assert_eq!(rounded, Value::Decimal(1235, 1));

        let rounded = eval_standalone(&ScalarExpr::RoundScale(
            Box::new(lit(Value::Decimal(12344, 2))),
            1,
        ))
        .unwrap();
        assert_eq!(rounded, Value::Decimal(1234, 1));

        let rounded = eval_standalone(&ScalarExpr::RoundScale(
            Box::new(lit(Value::Decimal(-12345, 2))),
            1,
        ))
        .unwrap();
        assert_eq!(rounded, Value::Decimal(-1235, 1));
    }

    /// CEIL(123.45) = 124 and FLOOR(123.45) = 123, both as scale-0 decimals.
    #[test]
    fn ceil_and_floor_decimal() {
        let c =
            eval_standalone(&ScalarExpr::Ceil(Box::new(lit(Value::Decimal(12345, 2))))).unwrap();
        assert_eq!(c, Value::Decimal(124, 0));
        let f =
            eval_standalone(&ScalarExpr::Floor(Box::new(lit(Value::Decimal(12345, 2))))).unwrap();
        assert_eq!(f, Value::Decimal(123, 0));
    }

    /// COALESCE skips leading NULLs and returns the first non-NULL argument.
    #[test]
    fn coalesce_returns_first_non_null() {
        let ex = ScalarExpr::Coalesce(vec![
            lit(Value::Null),
            lit(Value::Null),
            lit(Value::BigInt(42)),
            lit(Value::BigInt(99)),
        ]);
        assert_eq!(eval_standalone(&ex).unwrap(), Value::BigInt(42));
    }

    /// NULLIF yields NULL for equal operands and the first operand otherwise.
    #[test]
    fn nullif_returns_null_when_equal() {
        let eq = ScalarExpr::Nullif(
            Box::new(lit(Value::Text("x".into()))),
            Box::new(lit(Value::Text("x".into()))),
        );
        assert_eq!(eval_standalone(&eq).unwrap(), Value::Null);
        let ne = ScalarExpr::Nullif(
            Box::new(lit(Value::Text("x".into()))),
            Box::new(lit(Value::Text("y".into()))),
        );
        assert_eq!(eval_standalone(&ne).unwrap(), Value::Text("x".into()));
    }

    /// A column reference is resolved by name against the context's row.
    #[test]
    fn column_reference_resolves() {
        let cols = vec![ColumnName::new(String::from("name"))];
        let row = vec![Value::Text("Ada".into())];
        let ctx = EvalContext::new(&cols, &row);
        let ex = ScalarExpr::Upper(Box::new(ScalarExpr::Column(ColumnName::new(String::from(
            "name",
        )))));
        assert_eq!(evaluate(&ex, &ctx).unwrap(), Value::Text("ADA".into()));
    }

    /// Each single-argument function (and CAST) maps NULL to NULL.
    #[test]
    fn null_input_propagates_through_scalar_fns() {
        for expr in [
            ScalarExpr::Upper(Box::new(lit(Value::Null))),
            ScalarExpr::Lower(Box::new(lit(Value::Null))),
            ScalarExpr::Length(Box::new(lit(Value::Null))),
            ScalarExpr::Trim(Box::new(lit(Value::Null))),
            ScalarExpr::Abs(Box::new(lit(Value::Null))),
            ScalarExpr::Round(Box::new(lit(Value::Null))),
            ScalarExpr::Ceil(Box::new(lit(Value::Null))),
            ScalarExpr::Floor(Box::new(lit(Value::Null))),
            ScalarExpr::Cast(Box::new(lit(Value::Null)), DataType::Integer),
        ] {
            assert_eq!(eval_standalone(&expr).unwrap(), Value::Null);
        }
    }

    /// Integer casts: widening succeeds, in-range narrowing succeeds, and
    /// out-of-range narrowing is an error.
    #[test]
    fn cast_integer_widening_and_narrowing() {
        let w = eval_standalone(&ScalarExpr::Cast(
            Box::new(lit(Value::TinyInt(42))),
            DataType::BigInt,
        ))
        .unwrap();
        assert_eq!(w, Value::BigInt(42));

        let ok = eval_standalone(&ScalarExpr::Cast(
            Box::new(lit(Value::BigInt(127))),
            DataType::TinyInt,
        ))
        .unwrap();
        assert_eq!(ok, Value::TinyInt(127));

        // One past i16::MAX cannot be represented as SmallInt.
        let err = eval_standalone(&ScalarExpr::Cast(
            Box::new(lit(Value::BigInt(i64::from(i16::MAX) + 1))),
            DataType::SmallInt,
        ));
        assert!(err.is_err(), "narrowing overflow must be an error");
    }

    /// Text parses into Integer and Real; unparseable text is an error.
    #[test]
    fn cast_text_to_numeric_parses() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Text("42".into()))),
                DataType::Integer,
            ))
            .unwrap(),
            Value::Integer(42),
        );
        assert_eq!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Text("1.5".into()))),
                DataType::Real,
            ))
            .unwrap(),
            Value::Real(1.5),
        );
        assert!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Text("nope".into()))),
                DataType::Integer,
            ))
            .is_err(),
            "unparseable text must error rather than coerce to 0"
        );
    }

    /// Integers and booleans render to their plain textual form.
    #[test]
    fn cast_numeric_to_text_formats_canonically() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::BigInt(99))),
                DataType::Text,
            ))
            .unwrap(),
            Value::Text("99".into()),
        );
        assert_eq!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Boolean(true))),
                DataType::Text,
            ))
            .unwrap(),
            Value::Text("true".into()),
        );
    }

    /// Real-to-integer casts drop the fractional part (1.9 -> 1, -1.9 -> -1)
    /// and reject NaN.
    #[test]
    fn cast_real_to_int_truncates_toward_zero() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Real(1.9))),
                DataType::Integer,
            ))
            .unwrap(),
            Value::Integer(1),
        );
        assert_eq!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Real(-1.9))),
                DataType::Integer,
            ))
            .unwrap(),
            Value::Integer(-1),
        );
        assert!(
            eval_standalone(&ScalarExpr::Cast(
                Box::new(lit(Value::Real(f64::NAN))),
                DataType::Integer,
            ))
            .is_err(),
            "NaN cast must error"
        );
    }

    /// Boolean text literals are matched case-insensitively.
    #[test]
    fn cast_text_to_boolean_accepts_common_literals() {
        for (s, want) in [
            ("true", true),
            ("TRUE", true),
            ("t", true),
            ("1", true),
            ("false", false),
            ("F", false),
            ("0", false),
        ] {
            assert_eq!(
                eval_standalone(&ScalarExpr::Cast(
                    Box::new(lit(Value::Text(s.into()))),
                    DataType::Boolean,
                ))
                .unwrap(),
                Value::Boolean(want),
                "cast('{s}' as boolean)",
            );
        }
    }

    /// MOD(10, 3) = 1.
    #[test]
    fn mod_basic() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Mod(
                Box::new(lit(Value::BigInt(10))),
                Box::new(lit(Value::BigInt(3)))
            ))
            .unwrap(),
            Value::BigInt(1),
        );
    }

    /// A zero divisor yields NULL instead of panicking or erroring.
    #[test]
    fn mod_by_zero_returns_null_not_panic() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Mod(
                Box::new(lit(Value::BigInt(7))),
                Box::new(lit(Value::BigInt(0)))
            ))
            .unwrap(),
            Value::Null,
        );
    }

    /// MOD with a NULL operand is NULL.
    #[test]
    fn mod_propagates_null() {
        assert_eq!(
            eval_standalone(&ScalarExpr::Mod(
                Box::new(lit(Value::Null)),
                Box::new(lit(Value::BigInt(3)))
            ))
            .unwrap(),
            Value::Null,
        );
    }

    /// POWER of two integers still returns a `Real` (2^10 = 1024.0).
    #[test]
    fn power_returns_real() {
        let r = eval_standalone(&ScalarExpr::Power(
            Box::new(lit(Value::BigInt(2))),
            Box::new(lit(Value::BigInt(10))),
        ))
        .unwrap();
        match r {
            Value::Real(x) => assert!((x - 1024.0).abs() < 1e-9),
            other => panic!("expected Real, got {other:?}"),
        }
    }

    /// SQRT(16) = 4.0 as a `Real`.
    #[test]
    fn sqrt_basic() {
        let r = eval_standalone(&ScalarExpr::Sqrt(Box::new(lit(Value::BigInt(16))))).unwrap();
        match r {
            Value::Real(x) => assert!((x - 4.0).abs() < 1e-9),
            other => panic!("expected Real, got {other:?}"),
        }
    }

    /// SQRT of a negative number is an error mentioning SQRT or "domain".
    #[test]
    fn sqrt_negative_is_domain_error() {
        let err = eval_standalone(&ScalarExpr::Sqrt(Box::new(lit(Value::BigInt(-1)))))
            .expect_err("sqrt(-1) is a domain error");
        let msg = format!("{err:?}");
        assert!(msg.contains("SQRT") || msg.to_lowercase().contains("domain"));
    }

    /// Three-argument form: start 1, length 5.
    #[test]
    fn substring_basic() {
        let r = eval_standalone(&ScalarExpr::Substring(
            Box::new(lit(Value::Text("kimberlite".into()))),
            SubstringRange::try_new(1, 5).unwrap(),
        ))
        .unwrap();
        assert_eq!(r, Value::Text("kimbe".into()));
    }

    /// Two-argument form: from position 5 to the end of the string.
    #[test]
    fn substring_two_arg_form() {
        let r = eval_standalone(&ScalarExpr::Substring(
            Box::new(lit(Value::Text("kimberlite".into()))),
            SubstringRange::from_start(5),
        ))
        .unwrap();
        assert_eq!(r, Value::Text("erlite".into()));
    }

    /// Positions count characters, so multi-byte text is never split.
    #[test]
    fn substring_unicode_char_correct() {
        let r = eval_standalone(&ScalarExpr::Substring(
            Box::new(lit(Value::Text("café".into()))),
            SubstringRange::try_new(1, 3).unwrap(),
        ))
        .unwrap();
        assert_eq!(r, Value::Text("caf".into()));
    }

    /// Start -1 with length 5 covers positions -1..=3, of which only 1..=3
    /// exist, so three characters are returned.
    #[test]
    fn substring_negative_start_clips_left() {
        let r = eval_standalone(&ScalarExpr::Substring(
            Box::new(lit(Value::Text("hello".into()))),
            SubstringRange::try_new(-1, 5).unwrap(),
        ))
        .unwrap();
        assert_eq!(r, Value::Text("hel".into()));
    }

    /// SUBSTRING of NULL is NULL.
    #[test]
    fn substring_propagates_null() {
        let r = eval_standalone(&ScalarExpr::Substring(
            Box::new(lit(Value::Null)),
            SubstringRange::from_start(1),
        ))
        .unwrap();
        assert_eq!(r, Value::Null);
    }

    /// EXTRACT(YEAR) from a timestamp 1_746_316_800 s after the epoch gives
    /// 2025.
    #[test]
    fn extract_year_from_timestamp() {
        let ts = kimberlite_types::Timestamp::from_nanos(1_746_316_800 * 1_000_000_000);
        let r = eval_standalone(&ScalarExpr::Extract(
            DateField::Year,
            Box::new(lit(Value::Timestamp(ts))),
        ))
        .unwrap();
        assert_eq!(r, Value::Integer(2025));
    }

    /// EXTRACT(MONTH) and EXTRACT(DAY) from a Date given as days since the
    /// epoch; day 20_212 is the 4th of May.
    #[test]
    fn extract_month_day_from_date() {
        let days_since_epoch = 20_212_i32;
        let r_month = eval_standalone(&ScalarExpr::Extract(
            DateField::Month,
            Box::new(lit(Value::Date(days_since_epoch))),
        ))
        .unwrap();
        let r_day = eval_standalone(&ScalarExpr::Extract(
            DateField::Day,
            Box::new(lit(Value::Date(days_since_epoch))),
        ))
        .unwrap();
        assert_eq!(r_month, Value::Integer(5));
        assert_eq!(r_day, Value::Integer(4));
    }

    /// EXTRACT(EPOCH) returns whole seconds since the epoch as a BigInt.
    #[test]
    fn extract_epoch_from_timestamp() {
        let ts = kimberlite_types::Timestamp::from_nanos(1_746_316_800 * 1_000_000_000);
        let r = eval_standalone(&ScalarExpr::Extract(
            DateField::Epoch,
            Box::new(lit(Value::Timestamp(ts))),
        ))
        .unwrap();
        assert_eq!(r, Value::BigInt(1_746_316_800));
    }

    /// EXTRACT from NULL is NULL.
    #[test]
    fn extract_propagates_null() {
        let r = eval_standalone(&ScalarExpr::Extract(
            DateField::Year,
            Box::new(lit(Value::Null)),
        ))
        .unwrap();
        assert_eq!(r, Value::Null);
    }

    /// EXTRACT on a non-temporal value is an error that names EXTRACT.
    #[test]
    fn extract_rejects_non_temporal_input() {
        let err = eval_standalone(&ScalarExpr::Extract(
            DateField::Year,
            Box::new(lit(Value::Text("not a date".into()))),
        ))
        .expect_err("EXTRACT requires Date or Timestamp");
        assert!(format!("{err:?}").contains("EXTRACT"));
    }

    /// DATE_TRUNC(YEAR) of a timestamp 1_746_362_096 s after the epoch returns
    /// the timestamp at 1_735_689_600 s, the start of that year.
    #[test]
    fn date_trunc_year_collapses_to_january_first() {
        let ts = kimberlite_types::Timestamp::from_nanos(1_746_362_096 * 1_000_000_000);
        let r = eval_standalone(&ScalarExpr::DateTrunc(
            DateField::Year,
            Box::new(lit(Value::Timestamp(ts))),
        ))
        .unwrap();
        match r {
            Value::Timestamp(out) => {
                assert_eq!(out.as_nanos() as i64, 1_735_689_600 * 1_000_000_000_i64);
            }
            other => panic!("expected Timestamp, got {other:?}"),
        }
    }

    /// DATE_TRUNC with a field that is not truncatable (here QUARTER) is an
    /// error.
    #[test]
    fn date_trunc_rejects_non_truncatable_field() {
        let ts = kimberlite_types::Timestamp::from_nanos(1_746_316_800 * 1_000_000_000);
        let err = eval_standalone(&ScalarExpr::DateTrunc(
            DateField::Quarter,
            Box::new(lit(Value::Timestamp(ts))),
        ))
        .expect_err("DATE_TRUNC rejects non-truncatable field");
        assert!(format!("{err:?}").to_lowercase().contains("trunc"));
    }

    /// DATE_TRUNC of NULL is NULL.
    #[test]
    fn date_trunc_propagates_null() {
        let r = eval_standalone(&ScalarExpr::DateTrunc(
            DateField::Year,
            Box::new(lit(Value::Null)),
        ))
        .unwrap();
        assert_eq!(r, Value::Null);
    }

    /// An unfolded NOW must panic with a message naming the planner pass.
    #[test]
    #[should_panic(expected = "fold_time_constants")]
    fn now_panics_at_evaluator_when_unfolded() {
        let _ = eval_standalone(&ScalarExpr::Now);
    }

    /// An unfolded CURRENT_TIMESTAMP must panic the same way.
    #[test]
    #[should_panic(expected = "fold_time_constants")]
    fn current_timestamp_panics_at_evaluator_when_unfolded() {
        let _ = eval_standalone(&ScalarExpr::CurrentTimestamp);
    }

    /// An unfolded CURRENT_DATE must panic the same way.
    #[test]
    #[should_panic(expected = "fold_time_constants")]
    fn current_date_panics_at_evaluator_when_unfolded() {
        let _ = eval_standalone(&ScalarExpr::CurrentDate);
    }
}