use alloc::format;
use alloc::vec::Vec;
use spg_sql::ast::{BinOp, UnOp};
use spg_storage::{DataType, Value};
use super::{
EvalError, add_months_to_civil, civil_from_days, days_from_civil, inet_op_bool_result,
parse_date_literal, parse_timestamp_literal, ts_match, tsvector_concat, value_to_text,
};
/// Evaluates a unary operator (`-`, `~`, `NOT`) on an already-evaluated operand.
///
/// NULL propagates through every operator. Integer negation is overflow-checked,
/// so negating `i32::MIN` / `i64::MIN` is an error rather than wrapping.
/// `SMALLINT` operands are widened to `INT` for both `-` and `~`. `NUMERIC`
/// values are negated exactly on their scaled integer representation, and the
/// scale is kept.
///
/// # Errors
/// [`EvalError::TypeMismatch`] when the operator is not defined for the
/// operand's type, or when negation overflows.
pub(super) fn apply_unary(op: UnOp, v: Value) -> Result<Value, EvalError> {
    match (op, v) {
        (_, Value::Null) => Ok(Value::Null),
        // Widen to INT, as `~` does below, so SMALLINT behaves like INT.
        (UnOp::Neg, Value::SmallInt(n)) => i32::from(n)
            .checked_neg()
            .map(Value::Int)
            .ok_or(EvalError::TypeMismatch {
                detail: "integer overflow on unary -".into(),
            }),
        (UnOp::Neg, Value::Int(n)) => {
            n.checked_neg()
                .map(Value::Int)
                .ok_or(EvalError::TypeMismatch {
                    detail: "integer overflow on unary -".into(),
                })
        }
        (UnOp::Neg, Value::BigInt(n)) => {
            n.checked_neg()
                .map(Value::BigInt)
                .ok_or(EvalError::TypeMismatch {
                    detail: "bigint overflow on unary -".into(),
                })
        }
        // Negating the scaled integer negates the value; the scale is unchanged.
        (UnOp::Neg, Value::Numeric { scaled, scale }) => scaled
            .checked_neg()
            .map(|scaled| Value::Numeric { scaled, scale })
            .ok_or(EvalError::TypeMismatch {
                detail: "numeric overflow on unary -".into(),
            }),
        (UnOp::Neg, Value::Float(x)) => Ok(Value::Float(-x)),
        (UnOp::Neg, other) => Err(EvalError::TypeMismatch {
            detail: format!("unary - applied to {:?}", other.data_type()),
        }),
        (UnOp::BitNot, Value::SmallInt(n)) => Ok(Value::Int(!i32::from(n))),
        (UnOp::BitNot, Value::Int(n)) => Ok(Value::Int(!n)),
        (UnOp::BitNot, Value::BigInt(n)) => Ok(Value::BigInt(!n)),
        (UnOp::BitNot, other) => Err(EvalError::TypeMismatch {
            detail: format!("cannot apply ~ to {other:?}"),
        }),
        (UnOp::Not, Value::Bool(b)) => Ok(Value::Bool(!b)),
        (UnOp::Not, other) => Err(EvalError::TypeMismatch {
            detail: format!("NOT applied to {:?}", other.data_type()),
        }),
    }
}
/// Null-safe equality backing `IS [NOT] DISTINCT FROM`.
///
/// Two NULLs are "not distinct", and a NULL paired with a non-NULL is
/// distinct. Otherwise the two values are compared with `==`.
fn values_not_distinct(l: &Value, r: &Value) -> bool {
    let left_null = matches!(l, Value::Null);
    let right_null = matches!(r, Value::Null);
    if left_null || right_null {
        // Equal only when both sides are NULL.
        return left_null && right_null;
    }
    l == r
}
pub(super) fn apply_binary(op: BinOp, l: Value, r: Value) -> Result<Value, EvalError> {
if let BinOp::And = op {
return and_3vl(l, r);
}
if let BinOp::Or = op {
return or_3vl(l, r);
}
if let BinOp::IsNotDistinctFrom = op {
return Ok(Value::Bool(values_not_distinct(&l, &r)));
}
if let BinOp::IsDistinctFrom = op {
return Ok(Value::Bool(!values_not_distinct(&l, &r)));
}
if l.is_null() || r.is_null() {
return Ok(Value::Null);
}
if matches!(l, Value::Numeric { .. }) || matches!(r, Value::Numeric { .. }) {
return apply_binary_numeric(op, l, r);
}
if let Some(result) = apply_binary_calendar(op, &l, &r)? {
return Ok(result);
}
match op {
BinOp::Add => arith(l, r, i64::checked_add, |a, b| a + b, "+"),
BinOp::Sub => arith(l, r, i64::checked_sub, |a, b| a - b, "-"),
BinOp::Mul => arith(l, r, i64::checked_mul, |a, b| a * b, "*"),
BinOp::Div => div_op(l, r),
BinOp::L2Distance => l2_distance(l, r),
BinOp::InnerProduct => inner_product(l, r),
BinOp::CosineDistance => cosine_distance(l, r),
BinOp::Concat => Ok(text_concat(&l, &r)),
BinOp::BitOr => bitop(l, r, |a, b| a | b, "|"),
BinOp::BitAnd => bitop(l, r, |a, b| a & b, "&"),
BinOp::JsonGet => crate::json::path_get(&l, &r, false),
BinOp::JsonGetText => crate::json::path_get(&l, &r, true),
BinOp::JsonGetPath => crate::json::path_walk(&l, &r, false),
BinOp::JsonGetPathText => crate::json::path_walk(&l, &r, true),
BinOp::JsonContains => crate::json::contains(&l, &r),
BinOp::TsMatch => ts_match(l, r),
BinOp::InetContainedBy
| BinOp::InetContainedByEq
| BinOp::InetContains
| BinOp::InetContainsEq
| BinOp::InetOverlap => inet_op_bool_result(op, &l, &r),
BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
compare(op, &l, &r)
}
BinOp::And | BinOp::Or | BinOp::IsDistinctFrom | BinOp::IsNotDistinctFrom => {
unreachable!("handled above")
}
}
}
/// Date/time arithmetic that the generic numeric operators do not cover.
///
/// Rules are tried in this order:
/// * `DATE - DATE` gives a `BIGINT` day difference;
/// * `TIMESTAMP - TIMESTAMP` gives a `BIGINT` microsecond difference;
/// * anything `apply_binary_interval` accepts;
/// * `DATE + integer`, `integer + DATE` and `DATE - integer` give a `DATE`
///   shifted by whole days.
///
/// Returns `Ok(None)` when no rule applies, so the caller can fall through.
///
/// # Errors
/// `TypeMismatch` when a result leaves the `i64` microsecond or `i32` day
/// range, plus whatever `apply_binary_interval` reports.
fn apply_binary_calendar(op: BinOp, l: &Value, r: &Value) -> Result<Option<Value>, EvalError> {
    if op == BinOp::Sub {
        match (l, r) {
            (Value::Date(a), Value::Date(b)) => {
                // Both operands fit in i32, so the i64 difference cannot overflow.
                return Ok(Some(Value::BigInt(i64::from(*a) - i64::from(*b))));
            }
            (Value::Timestamp(a), Value::Timestamp(b)) => {
                let Some(delta) = a.checked_sub(*b) else {
                    return Err(EvalError::TypeMismatch {
                        detail: "TIMESTAMP - TIMESTAMP overflows i64 microseconds".into(),
                    });
                };
                return Ok(Some(Value::BigInt(delta)));
            }
            _ => {}
        }
    }
    if let Some(out) = apply_binary_interval(op, l, r)? {
        return Ok(Some(out));
    }
    // Any integer flavour, widened to i64, counts as a number of days.
    let day_count = |v: &Value| -> Option<i64> {
        match v {
            Value::SmallInt(n) => Some(i64::from(*n)),
            Value::Int(n) => Some(i64::from(*n)),
            Value::BigInt(n) => Some(*n),
            _ => None,
        }
    };
    // Narrow a day number back into the DATE representation.
    let to_date = |days: i64, overflow_msg: &'static str| -> Result<Option<Value>, EvalError> {
        i32::try_from(days)
            .map(|d| Some(Value::Date(d)))
            .map_err(|_| EvalError::TypeMismatch {
                detail: overflow_msg.into(),
            })
    };
    match op {
        BinOp::Add => {
            if let (Value::Date(d), Some(n)) = (l, day_count(r)) {
                return to_date(
                    i64::from(*d).saturating_add(n),
                    "DATE + integer overflows DATE range",
                );
            }
            if let (Some(n), Value::Date(d)) = (day_count(l), r) {
                return to_date(
                    i64::from(*d).saturating_add(n),
                    "integer + DATE overflows DATE range",
                );
            }
        }
        BinOp::Sub => {
            if let (Value::Date(d), Some(n)) = (l, day_count(r)) {
                return to_date(
                    i64::from(*d).saturating_sub(n),
                    "DATE - integer overflows DATE range",
                );
            }
        }
        _ => {}
    }
    Ok(None)
}
/// Applies `+` / `-` when an `INTERVAL` appears in a supported position.
///
/// The supported shapes are `INTERVAL + x`, `x + INTERVAL` and `x - INTERVAL`.
/// Any other combination returns `Ok(None)` so the caller can try other rules.
/// When a shape matches, the interval is applied to the other operand:
///
/// * `TIMESTAMP ± INTERVAL` gives a `TIMESTAMP` (months applied first, then
///   microseconds; see `add_interval_to_micros`).
/// * `DATE ± INTERVAL` gives a `DATE` when the microsecond part is a whole
///   number of days. Otherwise the date is lifted to a start-of-day `TIMESTAMP`
///   and the result is a `TIMESTAMP`.
/// * `INTERVAL ± INTERVAL` gives an `INTERVAL`, computed component-wise.
///
/// # Errors
/// `TypeMismatch` when any component overflows, or when the non-interval
/// operand is not a `DATE`, `TIMESTAMP` or `INTERVAL`.
pub(crate) fn apply_binary_interval(
    op: BinOp,
    l: &Value,
    r: &Value,
) -> Result<Option<Value>, EvalError> {
    // Normalise so that `rhs` is always the interval. Addition commutes, so
    // `INTERVAL + x` is swapped into `x + INTERVAL`. Subtraction is encoded
    // as `sign = -1` applied to the interval's components.
    let (lhs, rhs, sign): (&Value, &Value, i64) = match (l, r, op) {
        (Value::Interval { .. }, _, BinOp::Add) => (r, l, 1),
        (_, Value::Interval { .. }, BinOp::Add) => (l, r, 1),
        (_, Value::Interval { .. }, BinOp::Sub) => (l, r, -1),
        _ => return Ok(None),
    };
    let Value::Interval {
        months: rhs_months,
        micros: rhs_us,
    } = rhs
    else {
        unreachable!("rhs guaranteed to be Interval by the match above");
    };
    // Widening months to i64 keeps the sign flip overflow-free. The micros
    // flip is checked because negating i64::MIN overflows.
    let signed_months = i64::from(*rhs_months) * sign;
    let signed_micros = rhs_us.checked_mul(sign).ok_or(EvalError::TypeMismatch {
        detail: "INTERVAL micros overflows on negation".into(),
    })?;
    match lhs {
        Value::Timestamp(t) => Ok(Some(Value::Timestamp(add_interval_to_micros(
            *t,
            signed_months,
            signed_micros,
        )?))),
        Value::Date(d) => {
            // Whole-day microsecond parts keep the result a DATE.
            let day_aligned = signed_micros.rem_euclid(86_400_000_000) == 0;
            if day_aligned {
                let micros_per_day = 86_400_000_000_i64;
                // Exact division: day_aligned guarantees no remainder.
                let days_delta = signed_micros / micros_per_day;
                let shifted = shift_date_by_months(*d, signed_months)?;
                let new_days =
                    i64::from(shifted)
                        .checked_add(days_delta)
                        .ok_or(EvalError::TypeMismatch {
                            detail: "DATE ± INTERVAL overflows DATE range".into(),
                        })?;
                let days32 = i32::try_from(new_days).map_err(|_| EvalError::TypeMismatch {
                    detail: "DATE ± INTERVAL overflows DATE range".into(),
                })?;
                Ok(Some(Value::Date(days32)))
            } else {
                // Sub-day component: lift the date to the start of its day in
                // microseconds and do TIMESTAMP math instead.
                let base =
                    i64::from(*d)
                        .checked_mul(86_400_000_000)
                        .ok_or(EvalError::TypeMismatch {
                            detail: "DATE → TIMESTAMP lift overflows for INTERVAL math".into(),
                        })?;
                Ok(Some(Value::Timestamp(add_interval_to_micros(
                    base,
                    signed_months,
                    signed_micros,
                )?)))
            }
        }
        Value::Interval {
            months: lhs_months,
            micros: lhs_us,
        } => {
            // Component-wise add; months must still fit the i32 field.
            let new_months = i64::from(*lhs_months)
                .checked_add(signed_months)
                .and_then(|n| i32::try_from(n).ok())
                .ok_or(EvalError::TypeMismatch {
                    detail: "INTERVAL ± INTERVAL months overflows i32".into(),
                })?;
            let new_micros = lhs_us
                .checked_add(signed_micros)
                .ok_or(EvalError::TypeMismatch {
                    detail: "INTERVAL ± INTERVAL micros overflows i64".into(),
                })?;
            Ok(Some(Value::Interval {
                months: new_months,
                micros: new_micros,
            }))
        }
        // An interval is involved but the other side is not date/time-like:
        // report an error instead of falling through to other rules.
        _ => Err(EvalError::TypeMismatch {
            detail: format!(
                "operator {op:?} not defined for {:?} and INTERVAL",
                lhs.data_type()
            ),
        }),
    }
}
/// Moves day number `d` by `months` calendar months.
///
/// `d` is a day number as understood by `civil_from_days` /
/// `days_from_civil`. The day-of-month is adjusted by `add_months_to_civil`;
/// the unit tests show Jan 31 + 1 month landing on the last day of February.
///
/// # Errors
/// `TypeMismatch` when `months` does not fit in an `i32`.
fn shift_date_by_months(d: i32, months: i64) -> Result<i32, EvalError> {
    let (year, month, day) = civil_from_days(d);
    let delta = match i32::try_from(months) {
        Ok(m) => m,
        Err(_) => {
            return Err(EvalError::TypeMismatch {
                detail: "INTERVAL months delta out of i32 range".into(),
            })
        }
    };
    let (year, month, day) = add_months_to_civil(year, month, day, delta);
    Ok(days_from_civil(year, month, day))
}
/// Adds an interval (`months` calendar months plus `micros` microseconds) to
/// the microsecond timestamp `t`.
///
/// Months are applied first, by shifting the calendar date and keeping the
/// time of day. The microsecond part is added afterwards.
///
/// # Errors
/// `TypeMismatch` when the day number does not fit in `i32`, when the month
/// shift is out of range, or when any step overflows `i64`.
fn add_interval_to_micros(t: i64, months: i64, micros: i64) -> Result<i64, EvalError> {
    const MICROS_PER_DAY: i64 = 86_400_000_000;
    let month_shifted = if months == 0 {
        t
    } else {
        // Euclidean split keeps the time-of-day non-negative for pre-epoch values.
        let day_number = t.div_euclid(MICROS_PER_DAY);
        let time_of_day = t.rem_euclid(MICROS_PER_DAY);
        let day_number = i32::try_from(day_number).map_err(|_| EvalError::TypeMismatch {
            detail: "TIMESTAMP day component out of i32 range for INTERVAL months math".into(),
        })?;
        let shifted_day = i64::from(shift_date_by_months(day_number, months)?);
        shifted_day
            .checked_mul(MICROS_PER_DAY)
            .and_then(|start_of_day| start_of_day.checked_add(time_of_day))
            .ok_or(EvalError::TypeMismatch {
                detail: "TIMESTAMP ± INTERVAL months overflows i64 microseconds".into(),
            })?
    };
    month_shifted
        .checked_add(micros)
        .ok_or(EvalError::TypeMismatch {
            detail: "TIMESTAMP ± INTERVAL micros overflows i64".into(),
        })
}
/// Evaluates a binary operator where at least one operand is `NUMERIC`.
///
/// If either side is `Float`, both are converted to `f64` and the result is
/// a `Float`. Otherwise integers are widened to scale-0 `NUMERIC` and the
/// operation runs on the scaled `i128` representation:
///
/// * `+` / `-` rescale both operands to the larger scale;
/// * `*` multiplies the scaled values and adds the scales;
/// * `/` yields the larger input scale, rounded half away from zero;
/// * comparisons rescale both sides and compare exactly;
/// * `||` falls back to `text_concat`.
///
/// # Errors
/// * `DivisionByZero` for a zero divisor.
/// * `TypeMismatch` for non-numeric operands, unsupported operators, NaN
///   comparisons, and `i128` overflow.
// Operands are taken by value to match the other `apply_*` helpers.
#[allow(clippy::needless_pass_by_value)]
fn apply_binary_numeric(op: BinOp, l: Value, r: Value) -> Result<Value, EvalError> {
    let float_path = matches!(l, Value::Float(_)) || matches!(r, Value::Float(_));
    if float_path {
        let af = as_f64(&l)?;
        let bf = as_f64(&r)?;
        return match op {
            BinOp::Add => Ok(Value::Float(af + bf)),
            BinOp::Sub => Ok(Value::Float(af - bf)),
            BinOp::Mul => Ok(Value::Float(af * bf)),
            BinOp::Div => {
                if bf == 0.0 {
                    Err(EvalError::DivisionByZero)
                } else {
                    Ok(Value::Float(af / bf))
                }
            }
            BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
                let ord = af.partial_cmp(&bf).ok_or(EvalError::TypeMismatch {
                    detail: "NaN in NUMERIC/Float comparison".into(),
                })?;
                Ok(Value::Bool(cmp_to_bool(op, ord)))
            }
            BinOp::Concat => Ok(text_concat(&l, &r)),
            other => Err(EvalError::TypeMismatch {
                detail: format!("operator {other:?} not defined for NUMERIC and Float"),
            }),
        };
    }
    let (a, sa) = numeric_or_widen(&l).ok_or_else(|| EvalError::TypeMismatch {
        detail: format!("NUMERIC op against non-numeric {:?}", l.data_type()),
    })?;
    let (b, sb) = numeric_or_widen(&r).ok_or_else(|| EvalError::TypeMismatch {
        detail: format!("NUMERIC op against non-numeric {:?}", r.data_type()),
    })?;
    match op {
        BinOp::Add | BinOp::Sub => {
            let target_scale = sa.max(sb);
            let lhs = rescale(a, sa, target_scale).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on rescale".into(),
            })?;
            let rhs = rescale(b, sb, target_scale).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on rescale".into(),
            })?;
            let r = match op {
                BinOp::Add => lhs.checked_add(rhs),
                BinOp::Sub => lhs.checked_sub(rhs),
                _ => unreachable!(),
            }
            .ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on +/-".into(),
            })?;
            Ok(Value::Numeric {
                scaled: r,
                scale: target_scale,
            })
        }
        BinOp::Mul => {
            let scaled = a.checked_mul(b).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on *".into(),
            })?;
            Ok(Value::Numeric {
                scaled,
                scale: sa.saturating_add(sb),
            })
        }
        BinOp::Div => {
            if b == 0 {
                return Err(EvalError::DivisionByZero);
            }
            let target_scale = sa.max(sb);
            // Pre-scale the dividend so that (sa + bump) - sb == target_scale.
            let bump = pow10_i128(target_scale.saturating_add(sb).saturating_sub(sa));
            let num = a.checked_mul(bump).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on / scaling".into(),
            })?;
            // Only i128::MIN / -1 can fail here.
            let quotient = num.checked_div(b).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on /".into(),
            })?;
            // Round half away from zero. Move the truncated quotient one step
            // further from zero when |remainder| >= |divisor| / 2. Using
            // magnitudes makes this correct for negative divisors and avoids
            // overflow. `num % b` cannot overflow because checked_div succeeded.
            let rem = (num % b).unsigned_abs();
            let divisor = b.unsigned_abs();
            let scaled = if rem >= divisor - rem {
                // Here rem > 0, so |b| >= 2 and the +-1 step cannot overflow.
                if (num < 0) == (b < 0) {
                    quotient + 1
                } else {
                    quotient - 1
                }
            } else {
                quotient
            };
            Ok(Value::Numeric {
                scaled,
                scale: target_scale,
            })
        }
        BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
            let target_scale = sa.max(sb);
            let lhs = rescale(a, sa, target_scale).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on rescale".into(),
            })?;
            let rhs = rescale(b, sb, target_scale).ok_or(EvalError::TypeMismatch {
                detail: "NUMERIC overflow on rescale".into(),
            })?;
            Ok(Value::Bool(cmp_to_bool(op, lhs.cmp(&rhs))))
        }
        BinOp::Concat => Ok(text_concat(&l, &r)),
        other => Err(EvalError::TypeMismatch {
            detail: format!("operator {other:?} not defined for NUMERIC"),
        }),
    }
}
/// Views a value as fixed-point `(scaled, scale)`.
///
/// `NUMERIC` is returned as is. Integer types become scale 0. Anything else
/// yields `None`.
fn numeric_or_widen(v: &Value) -> Option<(i128, u8)> {
    let whole = match v {
        Value::Numeric { scaled, scale } => return Some((*scaled, *scale)),
        Value::SmallInt(n) => i128::from(*n),
        Value::Int(n) => i128::from(*n),
        Value::BigInt(n) => i128::from(*n),
        _ => return None,
    };
    Some((whole, 0))
}
/// Re-expresses the fixed-point value `scaled × 10^-src` at scale `dst`.
///
/// Raising the scale multiplies by `10^(dst - src)` and returns `None` if the
/// result does not fit in `i128`. Lowering the scale divides and rounds half
/// away from zero; that direction never overflows. If more than 38 digits are
/// dropped (more than any `i128` power of ten), every value rounds to zero.
fn rescale(scaled: i128, src: u8, dst: u8) -> Option<i128> {
    if src == dst || scaled == 0 {
        return Some(scaled);
    }
    if dst > src {
        return 10_i128
            .checked_pow(u32::from(dst - src))
            .and_then(|factor| scaled.checked_mul(factor));
    }
    let Some(divisor) = 10_i128.checked_pow(u32::from(src - dst)) else {
        // 10^39 > 2 * |i128::MIN|, so every representable value is below half.
        return Some(0);
    };
    let truncated = scaled / divisor;
    let rem = (scaled % divisor).unsigned_abs();
    // divisor is a power of ten >= 10 (hence even), so `rem >= divisor - rem`
    // means exactly "remainder is at least half the divisor". The form avoids
    // the overflow that `scaled + divisor / 2` hits near i128::MAX.
    if rem >= divisor.unsigned_abs() - rem {
        // rem > 0 here, so `scaled` is non-zero and |truncated| <= |scaled| / 10.
        Some(truncated + scaled.signum())
    } else {
        Some(truncated)
    }
}
/// Returns `10^p` as an `i128`, usable in const contexts.
///
/// Exponents up to 38 are representable. Larger ones overflow the same way
/// repeated multiplication would: a panic in debug builds, wrapping in release.
pub(super) const fn pow10_i128(p: u8) -> i128 {
    // `u32::from` is not callable in a const fn; the u8 -> u32 cast is lossless.
    i128::pow(10, p as u32)
}
/// Converts a three-way [`core::cmp::Ordering`] into the truth value of the
/// comparison operator `op`. Operators that are not comparisons yield `false`.
const fn cmp_to_bool(op: BinOp, ord: core::cmp::Ordering) -> bool {
    match op {
        BinOp::Eq => ord.is_eq(),
        BinOp::NotEq => ord.is_ne(),
        BinOp::Lt => ord.is_lt(),
        BinOp::LtEq => ord.is_le(),
        BinOp::Gt => ord.is_gt(),
        BinOp::GtEq => ord.is_ge(),
        _ => false,
    }
}
fn text_concat(l: &Value, r: &Value) -> Value {
if let (Value::TsVector(a), Value::TsVector(b)) = (l, r) {
return tsvector_concat(a, b);
}
match (l, r) {
(Value::Null, _) | (_, Value::Null) => {
if matches!(
l,
Value::TextArray(_) | Value::IntArray(_) | Value::BigIntArray(_) | Value::Bytes(_)
) || matches!(
r,
Value::TextArray(_) | Value::IntArray(_) | Value::BigIntArray(_) | Value::Bytes(_)
) {
return Value::Null;
}
}
(Value::TextArray(a), Value::TextArray(b)) => {
let mut out = a.clone();
out.extend(b.iter().cloned());
return Value::TextArray(out);
}
(Value::TextArray(a), Value::Text(s)) => {
let mut out = a.clone();
out.push(Some(s.clone()));
return Value::TextArray(out);
}
(Value::Text(s), Value::TextArray(b)) => {
let mut out: alloc::vec::Vec<Option<alloc::string::String>> =
alloc::vec::Vec::with_capacity(1 + b.len());
out.push(Some(s.clone()));
out.extend(b.iter().cloned());
return Value::TextArray(out);
}
(Value::IntArray(a), Value::IntArray(b)) => {
let mut out = a.clone();
out.extend(b.iter().copied());
return Value::IntArray(out);
}
(Value::IntArray(a), Value::Int(n)) => {
let mut out = a.clone();
out.push(Some(*n));
return Value::IntArray(out);
}
(Value::IntArray(a), Value::SmallInt(n)) => {
let mut out = a.clone();
out.push(Some(i32::from(*n)));
return Value::IntArray(out);
}
(Value::Int(n), Value::IntArray(b)) => {
let mut out: alloc::vec::Vec<Option<i32>> = alloc::vec::Vec::with_capacity(1 + b.len());
out.push(Some(*n));
out.extend(b.iter().copied());
return Value::IntArray(out);
}
(Value::SmallInt(n), Value::IntArray(b)) => {
let mut out: alloc::vec::Vec<Option<i32>> = alloc::vec::Vec::with_capacity(1 + b.len());
out.push(Some(i32::from(*n)));
out.extend(b.iter().copied());
return Value::IntArray(out);
}
(Value::BigIntArray(a), Value::BigIntArray(b)) => {
let mut out = a.clone();
out.extend(b.iter().copied());
return Value::BigIntArray(out);
}
(Value::BigIntArray(a), Value::IntArray(b)) => {
let mut out = a.clone();
out.extend(b.iter().map(|o| o.map(i64::from)));
return Value::BigIntArray(out);
}
(Value::IntArray(a), Value::BigIntArray(b)) => {
let mut out: alloc::vec::Vec<Option<i64>> =
a.iter().map(|o| o.map(i64::from)).collect();
out.extend(b.iter().copied());
return Value::BigIntArray(out);
}
(Value::BigIntArray(a), Value::BigInt(n)) => {
let mut out = a.clone();
out.push(Some(*n));
return Value::BigIntArray(out);
}
(Value::BigIntArray(a), Value::Int(n)) => {
let mut out = a.clone();
out.push(Some(i64::from(*n)));
return Value::BigIntArray(out);
}
(Value::BigIntArray(a), Value::SmallInt(n)) => {
let mut out = a.clone();
out.push(Some(i64::from(*n)));
return Value::BigIntArray(out);
}
(Value::BigInt(n), Value::BigIntArray(b)) => {
let mut out: alloc::vec::Vec<Option<i64>> = alloc::vec::Vec::with_capacity(1 + b.len());
out.push(Some(*n));
out.extend(b.iter().copied());
return Value::BigIntArray(out);
}
(Value::Int(n), Value::BigIntArray(b)) => {
let mut out: alloc::vec::Vec<Option<i64>> = alloc::vec::Vec::with_capacity(1 + b.len());
out.push(Some(i64::from(*n)));
out.extend(b.iter().copied());
return Value::BigIntArray(out);
}
(Value::SmallInt(n), Value::BigIntArray(b)) => {
let mut out: alloc::vec::Vec<Option<i64>> = alloc::vec::Vec::with_capacity(1 + b.len());
out.push(Some(i64::from(*n)));
out.extend(b.iter().copied());
return Value::BigIntArray(out);
}
(Value::Bytes(a), Value::Bytes(b)) => {
let mut out = a.clone();
out.extend_from_slice(b);
return Value::Bytes(out);
}
_ => {}
}
let a = value_to_text(l);
let b = value_to_text(r);
Value::Text(a + &b)
}
/// `<#>` operator: the negated dot product of two equal-length vectors, so
/// that a larger dot product gives a smaller result.
fn inner_product(l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = unwrap_vec_pair(l, r, "<#>")?;
    // Explicit fold from +0.0 keeps the empty-vector result identical to a
    // plain accumulator loop.
    let dot = lhs
        .iter()
        .zip(&rhs)
        .fold(0.0_f64, |acc, (&x, &y)| acc + f64::from(x) * f64::from(y));
    Ok(Value::Float(-dot))
}
/// `<=>` operator: cosine distance `1 - (a·b) / (|a|·|b|)`.
///
/// Returns NaN when either vector has zero magnitude.
fn cosine_distance(l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = unwrap_vec_pair(l, r, "<=>")?;
    let (dot, norm_sq_l, norm_sq_r) = lhs.iter().zip(&rhs).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, nl, nr), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, nl + x * x, nr + y * y)
        },
    );
    let denom = sqrt_newton(norm_sq_l) * sqrt_newton(norm_sq_r);
    let distance = if denom == 0.0 {
        f64::NAN
    } else {
        1.0 - dot / denom
    };
    Ok(Value::Float(distance))
}
/// Decodes both operands of a vector operator (`op` is used in error messages)
/// into `f32` vectors of equal length.
///
/// Accepts `Vector`, `Sq8Vector` (dequantized) and `HalfVector` (widened).
///
/// # Errors
/// `TypeMismatch` if either operand is not a vector, or if the lengths differ.
fn unwrap_vec_pair(l: Value, r: Value, op: &str) -> Result<(Vec<f32>, Vec<f32>), EvalError> {
    fn into_f32s(v: Value) -> Option<Vec<f32>> {
        match v {
            Value::Vector(a) => Some(a),
            Value::Sq8Vector(q) => Some(spg_storage::quantize::dequantize(&q)),
            Value::HalfVector(h) => Some(h.to_f32_vec()),
            _ => None,
        }
    }
    // Capture the types for the error message before the operands are consumed.
    let (l_ty, r_ty) = (l.data_type(), r.data_type());
    let (Some(a), Some(b)) = (into_f32s(l), into_f32s(r)) else {
        return Err(EvalError::TypeMismatch {
            detail: format!("{op} requires two vectors, got {l_ty:?} and {r_ty:?}"),
        });
    };
    if a.len() == b.len() {
        Ok((a, b))
    } else {
        Err(EvalError::TypeMismatch {
            detail: format!("vector dim mismatch in {op}: {} vs {}", a.len(), b.len()),
        })
    }
}
/// Applies a bitwise callback (`|`, `&`) to integer operands.
///
/// `SMALLINT` is widened to `INT`. `INT op INT` stays `INT`; any pairing with
/// `BIGINT` gives `BIGINT`. For mixed INT/BIGINT pairs the callback receives
/// the INT value first regardless of operand order. That is harmless for the
/// commutative `|` / `&` passed by `apply_binary`.
///
/// # Errors
/// `TypeMismatch` when either widened operand is not an integer.
fn bitop(
    l: Value,
    r: Value,
    f: impl Fn(i64, i64) -> i64,
    op_name: &str,
) -> Result<Value, EvalError> {
    fn promote_small(v: Value) -> Value {
        if let Value::SmallInt(n) = v {
            Value::Int(i32::from(n))
        } else {
            v
        }
    }
    let lhs = promote_small(l);
    let rhs = promote_small(r);
    match (lhs, rhs) {
        (Value::BigInt(a), Value::BigInt(b)) => Ok(Value::BigInt(f(a, b))),
        (Value::Int(narrow), Value::BigInt(wide)) | (Value::BigInt(wide), Value::Int(narrow)) => {
            Ok(Value::BigInt(f(i64::from(narrow), wide)))
        }
        (Value::Int(a), Value::Int(b)) => {
            // For `|` / `&` on two sign-extended i32 values the result already
            // fits in i32, so this narrowing cast loses nothing.
            let wide = f(i64::from(a), i64::from(b));
            Ok(Value::Int(wide as i32))
        }
        (a, b) => Err(EvalError::TypeMismatch {
            detail: format!("cannot apply {op_name} to {a:?} and {b:?}"),
        }),
    }
}
/// Evaluates an arithmetic operator on integer or float operands.
///
/// `SMALLINT` is widened to `INT`. `INT op INT` is computed in `i64` and
/// narrowed back to `INT` when the result fits, otherwise it becomes
/// `BIGINT`. Any pairing involving `BIGINT` gives `BIGINT`. If either operand
/// is `FLOAT`, both are converted with `as_f64` and `float_op` is applied.
///
/// `int_op` is always called as `int_op(left, right)`, so non-commutative
/// operators such as `-` and `/` see their operands in SQL order. Previously
/// `BIGINT op INT` swapped them.
///
/// # Errors
/// `TypeMismatch` when `int_op` returns `None` (overflow) or when the
/// operands are not numeric.
fn arith(
    l: Value,
    r: Value,
    int_op: impl Fn(i64, i64) -> Option<i64>,
    float_op: impl Fn(f64, f64) -> f64,
    op_name: &str,
) -> Result<Value, EvalError> {
    let widen = |v: Value| -> Value {
        match v {
            Value::SmallInt(n) => Value::Int(i32::from(n)),
            other => other,
        }
    };
    let l = widen(l);
    let r = widen(r);
    // Integer view of an operand: its i64 value and whether it is an INT
    // (as opposed to a BIGINT).
    let as_integer = |v: &Value| -> Option<(i64, bool)> {
        match v {
            Value::Int(n) => Some((i64::from(*n), true)),
            Value::BigInt(n) => Some((*n, false)),
            _ => None,
        }
    };
    if let (Some((a, l_is_int)), Some((b, r_is_int))) = (as_integer(&l), as_integer(&r)) {
        let both_int = l_is_int && r_is_int;
        let result = int_op(a, b).ok_or_else(|| EvalError::TypeMismatch {
            detail: if both_int {
                format!("integer overflow on {op_name}")
            } else {
                format!("bigint overflow on {op_name}")
            },
        })?;
        return Ok(match i32::try_from(result) {
            Ok(small) if both_int => Value::Int(small),
            _ => Value::BigInt(result),
        });
    }
    let is_float = |v: &Value| v.data_type() == Some(DataType::Float);
    if is_float(&l) || is_float(&r) {
        let af = as_f64(&l)?;
        let bf = as_f64(&r)?;
        return Ok(Value::Float(float_op(af, bf)));
    }
    Err(EvalError::TypeMismatch {
        detail: format!(
            "{op_name} applied to non-numeric: {:?} vs {:?}",
            l.data_type(),
            r.data_type()
        ),
    })
}
/// `<->` operator: Euclidean (L2) distance between two equal-length vectors.
fn l2_distance(l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = unwrap_vec_pair(l, r, "<->")?;
    let sum_sq = lhs.iter().zip(&rhs).fold(0.0_f64, |acc, (&x, &y)| {
        let diff = f64::from(x) - f64::from(y);
        acc + diff * diff
    });
    Ok(Value::Float(sqrt_newton(sum_sq)))
}
/// Square root without `std`.
///
/// Presumably hand-rolled because this file builds on `alloc`/`core`, where
/// `f64::sqrt` is not available.
///
/// Zero and negative inputs return `0.0`. NaN and `+∞` are returned
/// unchanged. Other inputs start from an estimate obtained by halving the
/// IEEE-754 exponent (within a few percent of the root for normal numbers).
/// Newton–Raphson then refines it until it stops changing. The old version ran
/// a fixed 10 steps from `g = x`, which left large and small inputs far from
/// the root: `sqrt(1e6)` came out near 1295.
fn sqrt_newton(x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if x.is_nan() || x.is_infinite() {
        return x;
    }
    // Shifting the bit pattern right halves the biased exponent. Adding half
    // the bias of 1.0 re-centres it, giving roughly 2^(e/2).
    let mut g = f64::from_bits((x.to_bits() >> 1) + (1.0_f64.to_bits() >> 1));
    // Normal inputs converge in about 5 steps. Subnormals start further off:
    // each early step only halves the error, so the cap leaves room for them.
    // Stopping on a fixed point also covers the final one-ulp wobble.
    for _ in 0..64 {
        let next = 0.5 * (g + x / g);
        if next == g {
            break;
        }
        g = next;
    }
    g
}
/// Implements `/` for non-NUMERIC operands.
///
/// If either operand is `FLOAT`, both are converted with `as_f64` and divided
/// as `f64`. Otherwise the operands go through `arith`, which does truncating
/// integer division; `INT / INT` narrows back to `INT` when the quotient fits.
///
/// # Errors
/// * `DivisionByZero` for a zero divisor (float `0.0` or integer `0`).
/// * `TypeMismatch` for non-numeric operands, or for `BIGINT` overflow
///   (`i64::MIN / -1`).
fn div_op(l: Value, r: Value) -> Result<Value, EvalError> {
    let any_float = matches!(l.data_type(), Some(DataType::Float))
        || matches!(r.data_type(), Some(DataType::Float));
    if any_float {
        let a = as_f64(&l)?;
        let b = as_f64(&r)?;
        if b == 0.0 {
            return Err(EvalError::DivisionByZero);
        }
        return Ok(Value::Float(a / b));
    }
    // Detect a zero integer divisor up front. Inferring it from the error text
    // does not work: every `arith` error for "/" contains '/', so type errors
    // and overflow were also reported as division by zero.
    let is_integer = |v: &Value| matches!(v, Value::SmallInt(_) | Value::Int(_) | Value::BigInt(_));
    let zero_divisor = matches!(r, Value::SmallInt(0) | Value::Int(0) | Value::BigInt(0));
    if zero_divisor && is_integer(&l) {
        return Err(EvalError::DivisionByZero);
    }
    // `checked_div` reports i64::MIN / -1 as overflow; plain `/` would panic.
    arith(l, r, i64::checked_div, |a, b| a / b, "/")
}
/// Converts a numeric value (`SMALLINT`, `INT`, `BIGINT`, `FLOAT`, `NUMERIC`)
/// to `f64`, possibly losing precision for `BIGINT` and `NUMERIC`.
///
/// # Errors
/// `TypeMismatch` for any non-numeric value.
fn as_f64(v: &Value) -> Result<f64, EvalError> {
    let converted = match v {
        Value::Float(x) => *x,
        Value::SmallInt(n) => f64::from(*n),
        Value::Int(n) => f64::from(*n),
        #[allow(clippy::cast_precision_loss)]
        Value::BigInt(n) => *n as f64,
        #[allow(clippy::cast_precision_loss)]
        Value::Numeric { scaled, scale } => {
            // 10^scale built by repeated multiplication.
            let divisor = (0..*scale).fold(1.0_f64, |acc, _| acc * 10.0);
            *scaled as f64 / divisor
        }
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!("cannot convert {:?} to FLOAT", other.data_type()),
            })
        }
    };
    Ok(converted)
}
/// Compares two non-NULL values three-way and applies a comparison operator.
///
/// Supported pairings:
/// * any two of `SMALLINT` / `INT` / `BIGINT`, compared as `i64`;
/// * a numeric value paired with `FLOAT`, both converted via `as_f64` (NaN is
///   an error);
/// * `TEXT`, `BOOL`, `DATE`, `TIMESTAMP` and `UUID` against the same type;
/// * `DATE` against `TIMESTAMP`, taking the date as the start of its day;
/// * `DATE` / `TIMESTAMP` / `UUID` against a `TEXT` literal, which is parsed.
///
/// # Errors
/// `TypeMismatch` for unsupported pairings, unparsable literals, or NaN.
///
/// # Panics
/// If `op` is not one of `=`, `<>`, `<`, `<=`, `>`, `>=`.
pub(super) fn compare(op: BinOp, l: &Value, r: &Value) -> Result<Value, EvalError> {
    // Days -> microseconds is done in i128 so extreme DATE values cannot
    // overflow (i32::MAX days * 8.64e10 exceeds i64).
    const MICROS_PER_DAY: i128 = 86_400_000_000;
    let as_integer = |v: &Value| -> Option<i64> {
        match v {
            Value::SmallInt(n) => Some(i64::from(*n)),
            Value::Int(n) => Some(i64::from(*n)),
            Value::BigInt(n) => Some(*n),
            _ => None,
        }
    };
    let ord = if let (Some(a), Some(b)) = (as_integer(l), as_integer(r)) {
        a.cmp(&b)
    } else {
        match (l, r) {
            (a, b)
                if matches!(a.data_type(), Some(DataType::Float))
                    || matches!(b.data_type(), Some(DataType::Float)) =>
            {
                let af = as_f64(a)?;
                let bf = as_f64(b)?;
                af.partial_cmp(&bf).ok_or(EvalError::TypeMismatch {
                    detail: "NaN in comparison".into(),
                })?
            }
            (Value::Text(a), Value::Text(b)) => a.cmp(b),
            (Value::Bool(a), Value::Bool(b)) => a.cmp(b),
            (Value::Date(a), Value::Date(b)) => a.cmp(b),
            (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
            (Value::Date(a), Value::Timestamp(b)) => {
                (i128::from(*a) * MICROS_PER_DAY).cmp(&i128::from(*b))
            }
            (Value::Timestamp(a), Value::Date(b)) => {
                i128::from(*a).cmp(&(i128::from(*b) * MICROS_PER_DAY))
            }
            (Value::Date(a), Value::Text(b)) => {
                let bd = parse_date_literal(b).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("cannot parse {b:?} as DATE for comparison"),
                })?;
                a.cmp(&bd)
            }
            (Value::Text(a), Value::Date(b)) => {
                let ad = parse_date_literal(a).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("cannot parse {a:?} as DATE for comparison"),
                })?;
                ad.cmp(b)
            }
            (Value::Timestamp(a), Value::Text(b)) => {
                let bt = parse_timestamp_literal(b).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("cannot parse {b:?} as TIMESTAMP for comparison"),
                })?;
                a.cmp(&bt)
            }
            (Value::Text(a), Value::Timestamp(b)) => {
                let at = parse_timestamp_literal(a).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("cannot parse {a:?} as TIMESTAMP for comparison"),
                })?;
                at.cmp(b)
            }
            (Value::Uuid(a), Value::Uuid(b)) => a.cmp(b),
            (Value::Uuid(a), Value::Text(b)) => {
                let bu = spg_storage::parse_uuid_str(b).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("invalid input syntax for type uuid: {b:?}"),
                })?;
                a.cmp(&bu)
            }
            (Value::Text(a), Value::Uuid(b)) => {
                let au = spg_storage::parse_uuid_str(a).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("invalid input syntax for type uuid: {a:?}"),
                })?;
                au.cmp(b)
            }
            (a, b) => {
                return Err(EvalError::TypeMismatch {
                    detail: format!(
                        "comparison between {:?} and {:?}",
                        a.data_type(),
                        b.data_type()
                    ),
                });
            }
        }
    };
    // Exhaustive on purpose: adding a BinOp variant forces a decision here.
    let result = match op {
        BinOp::Eq => ord.is_eq(),
        BinOp::NotEq => !ord.is_eq(),
        BinOp::Lt => ord.is_lt(),
        BinOp::LtEq => ord.is_le(),
        BinOp::Gt => ord.is_gt(),
        BinOp::GtEq => ord.is_ge(),
        BinOp::And
        | BinOp::Or
        | BinOp::BitOr
        | BinOp::BitAnd
        | BinOp::Add
        | BinOp::Sub
        | BinOp::Mul
        | BinOp::Div
        | BinOp::L2Distance
        | BinOp::InnerProduct
        | BinOp::CosineDistance
        | BinOp::Concat
        | BinOp::JsonGet
        | BinOp::JsonGetText
        | BinOp::JsonGetPath
        | BinOp::JsonGetPathText
        | BinOp::JsonContains
        | BinOp::TsMatch
        | BinOp::IsDistinctFrom
        | BinOp::IsNotDistinctFrom
        | BinOp::InetContainedBy
        | BinOp::InetContainedByEq
        | BinOp::InetContains
        | BinOp::InetContainsEq
        | BinOp::InetOverlap => {
            unreachable!("compare() only called with comparison ops")
        }
    };
    Ok(Value::Bool(result))
}
/// Three-valued `AND`.
///
/// `FALSE` on either side wins, even against NULL or a non-boolean.
/// `TRUE AND TRUE` is `TRUE`. Any remaining case with a NULL is NULL.
///
/// # Errors
/// `TypeMismatch` for other non-boolean operands.
pub(crate) fn and_3vl(l: Value, r: Value) -> Result<Value, EvalError> {
    if matches!(l, Value::Bool(false)) || matches!(r, Value::Bool(false)) {
        return Ok(Value::Bool(false));
    }
    match (l, r) {
        (Value::Bool(true), Value::Bool(true)) => Ok(Value::Bool(true)),
        (Value::Null, _) | (_, Value::Null) => Ok(Value::Null),
        (lhs, rhs) => Err(EvalError::TypeMismatch {
            detail: format!(
                "AND on non-boolean: {:?} and {:?}",
                lhs.data_type(),
                rhs.data_type()
            ),
        }),
    }
}
/// Three-valued `OR`.
///
/// `TRUE` on either side wins, even against NULL or a non-boolean.
/// `FALSE OR FALSE` is `FALSE`. Any remaining case with a NULL is NULL.
///
/// # Errors
/// `TypeMismatch` for other non-boolean operands.
fn or_3vl(l: Value, r: Value) -> Result<Value, EvalError> {
    if matches!(l, Value::Bool(true)) || matches!(r, Value::Bool(true)) {
        return Ok(Value::Bool(true));
    }
    match (l, r) {
        (Value::Bool(false), Value::Bool(false)) => Ok(Value::Bool(false)),
        (Value::Null, _) | (_, Value::Null) => Ok(Value::Null),
        (lhs, rhs) => Err(EvalError::TypeMismatch {
            detail: format!(
                "OR on non-boolean: {:?} and {:?}",
                lhs.data_type(),
                rhs.data_type()
            ),
        }),
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the calendar / INTERVAL helpers in this module.
    use super::*;
    /// An interval with no month part is plain microsecond addition.
    #[test]
    fn interval_add_to_timestamp_micros_part() {
        let ts = i64::from(days_from_civil(2024, 1, 1)) * 86_400_000_000;
        let r = add_interval_to_micros(ts, 0, 3_600_000_000).unwrap();
        let expected = ts + 3_600_000_000;
        assert_eq!(r, expected);
    }
    /// Month shifts clamp to the last day of the target month, in both
    /// directions and across leap and non-leap years.
    #[test]
    fn interval_clamp_month_end() {
        let d = days_from_civil(2024, 1, 31);
        let shifted = shift_date_by_months(d, 1).unwrap();
        let (y, m, day) = civil_from_days(shifted);
        assert_eq!((y, m, day), (2024, 2, 29));
        let d = days_from_civil(2023, 1, 31);
        let shifted = shift_date_by_months(d, 1).unwrap();
        let (y, m, day) = civil_from_days(shifted);
        assert_eq!((y, m, day), (2023, 2, 28));
        let d = days_from_civil(2024, 3, 31);
        let shifted = shift_date_by_months(d, -1).unwrap();
        let (y, m, day) = civil_from_days(shifted);
        assert_eq!((y, m, day), (2024, 2, 29));
    }
    /// DATE + a whole-day INTERVAL stays a DATE.
    #[test]
    fn interval_date_plus_pure_days_stays_date() {
        let d = days_from_civil(2024, 6, 1);
        let lhs = Value::Date(d);
        let rhs = Value::Interval {
            months: 0,
            micros: 7 * 86_400_000_000,
        };
        let v = apply_binary_interval(BinOp::Add, &lhs, &rhs)
            .unwrap()
            .unwrap();
        let expected = days_from_civil(2024, 6, 8);
        assert_eq!(v, Value::Date(expected));
    }
    /// DATE + a sub-day INTERVAL is lifted to a TIMESTAMP that starts from the
    /// beginning of that day.
    #[test]
    fn interval_date_plus_sub_day_lifts_to_timestamp() {
        let d = days_from_civil(2024, 6, 1);
        let lhs = Value::Date(d);
        let rhs = Value::Interval {
            months: 0,
            micros: 3_600_000_000,
        };
        let v = apply_binary_interval(BinOp::Add, &lhs, &rhs)
            .unwrap()
            .unwrap();
        let expected = i64::from(d) * 86_400_000_000 + 3_600_000_000;
        assert_eq!(v, Value::Timestamp(expected));
    }
}