use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use spg_sql::ast::{BinOp, CastTarget, ColumnName, Expr, Literal, UnOp};
use spg_storage::{ColumnSchema, DataType, Row, Value};
/// Read-only context passed to [`eval_expr`] together with each row.
///
/// All fields are borrowed, so the context is cheap to clone and copy around.
#[derive(Debug, Clone)]
pub struct EvalContext<'a> {
/// Schemas of the columns in the row being evaluated.
/// NOTE(review): presumably consumed by column resolution — that helper is
/// not visible in this part of the file.
pub columns: &'a [ColumnSchema],
/// Optional table alias for the row source.
/// NOTE(review): presumably checked against qualified column names — confirm
/// against the column resolver.
pub table_alias: Option<&'a str>,
/// Values bound to `$n` placeholders; `$n` reads `params[n - 1]`.
pub params: &'a [Value],
}
impl<'a> EvalContext<'a> {
    /// Creates a context over `columns` with an optional `table_alias` and no
    /// bound placeholder parameters.
    pub const fn new(columns: &'a [ColumnSchema], table_alias: Option<&'a str>) -> Self {
        Self {
            params: &[],
            table_alias,
            columns,
        }
    }

    /// Returns this context with `params` installed as the bound placeholder
    /// values, leaving the other fields unchanged.
    #[must_use]
    pub const fn with_params(self, params: &'a [Value]) -> Self {
        Self {
            columns: self.columns,
            table_alias: self.table_alias,
            params,
        }
    }
}
/// Errors produced while evaluating an expression against a row.
#[derive(Debug, Clone, PartialEq)]
pub enum EvalError {
/// A referenced column name could not be found.
ColumnNotFound { name: String },
/// A column reference used a table qualifier that is not known.
UnknownQualifier { qualifier: String },
/// Division by zero.
DivisionByZero,
/// Catch-all for operand/argument problems; `detail` carries the
/// human-readable explanation. Within this file it is also used for
/// unknown functions, parse failures and out-of-range conversions.
TypeMismatch { detail: String },
/// Placeholder `$n` was referenced but only `bound` parameters were supplied.
PlaceholderOutOfRange { n: u16, bound: u16 },
}
impl core::fmt::Display for EvalError {
    /// Renders the user-facing message for each error variant.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::DivisionByZero => f.write_str("division by zero"),
            Self::ColumnNotFound { name } => {
                f.write_fmt(format_args!("column not found: {name}"))
            }
            Self::UnknownQualifier { qualifier } => {
                f.write_fmt(format_args!("unknown table qualifier: {qualifier}"))
            }
            Self::TypeMismatch { detail } => f.write_fmt(format_args!("type mismatch: {detail}")),
            Self::PlaceholderOutOfRange { n, bound } => f.write_fmt(format_args!(
                "parameter ${n} referenced but only {bound} bound by client"
            )),
        }
    }
}
pub fn eval_expr(expr: &Expr, row: &Row, ctx: &EvalContext<'_>) -> Result<Value, EvalError> {
match expr {
Expr::Literal(l) => Ok(literal_to_value(l)),
Expr::Column(c) => resolve_column(c, row, ctx),
Expr::Placeholder(n) => {
let idx = usize::from(*n).saturating_sub(1);
ctx.params
.get(idx)
.cloned()
.ok_or_else(|| EvalError::PlaceholderOutOfRange {
n: *n,
bound: u16::try_from(ctx.params.len()).unwrap_or(u16::MAX),
})
}
Expr::Unary { op, expr } => {
let v = eval_expr(expr, row, ctx)?;
apply_unary(*op, v)
}
Expr::Binary { lhs, op, rhs } => {
let l = eval_expr(lhs, row, ctx)?;
let r = eval_expr(rhs, row, ctx)?;
apply_binary(*op, l, r)
}
Expr::Cast { expr, target } => {
let v = eval_expr(expr, row, ctx)?;
cast_value(v, *target)
}
Expr::IsNull { expr, negated } => {
let v = eval_expr(expr, row, ctx)?;
let is_null = matches!(v, Value::Null);
Ok(Value::Bool(if *negated { !is_null } else { is_null }))
}
Expr::FunctionCall { name, args } => {
let evaluated: Result<Vec<Value>, _> =
args.iter().map(|a| eval_expr(a, row, ctx)).collect();
apply_function(name, &evaluated?)
}
Expr::Like {
expr,
pattern,
negated,
} => {
let v = eval_expr(expr, row, ctx)?;
let p = eval_expr(pattern, row, ctx)?;
let (text, pat) = match (v, p) {
(Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
(Value::Text(a), Value::Text(b)) => (a, b),
(Value::Text(_), other) | (other, _) => {
return Err(EvalError::TypeMismatch {
detail: format!("LIKE requires text operands, got {:?}", other.data_type()),
});
}
};
let m = like_match(&text, &pat);
Ok(Value::Bool(if *negated { !m } else { m }))
}
Expr::Extract { field, source } => {
let v = eval_expr(source, row, ctx)?;
extract_field(*field, &v)
}
Expr::ScalarSubquery(_) | Expr::Exists { .. } | Expr::InSubquery { .. } => {
Err(EvalError::TypeMismatch {
detail: "subquery reached row eval — engine resolver bug".into(),
})
}
Expr::WindowFunction { .. } => Err(EvalError::TypeMismatch {
detail: "window function reached row eval — engine rewrite bug".into(),
}),
Expr::Array(items) => {
let mut out: Vec<Option<String>> = Vec::with_capacity(items.len());
for elem in items {
match eval_expr(elem, row, ctx)? {
Value::Null => out.push(None),
Value::Text(s) => out.push(Some(s)),
other => out.push(Some(value_to_text_for_array(&other))),
}
}
Ok(Value::TextArray(out))
}
Expr::ArraySubscript { target, index } => {
let target_v = eval_expr(target, row, ctx)?;
let idx_v = eval_expr(index, row, ctx)?;
if matches!(target_v, Value::Null) || matches!(idx_v, Value::Null) {
return Ok(Value::Null);
}
let Value::TextArray(items) = target_v else {
return Err(EvalError::TypeMismatch {
detail: format!(
"subscript target must be an array, got {:?}",
target_v.data_type()
),
});
};
let i: i64 = match idx_v {
Value::Int(n) => i64::from(n),
Value::BigInt(n) => n,
Value::SmallInt(n) => i64::from(n),
other => {
return Err(EvalError::TypeMismatch {
detail: format!("array subscript must be integer, got {:?}", other.data_type()),
});
}
};
if i < 1 {
return Ok(Value::Null);
}
let pos = (i - 1) as usize;
match items.get(pos) {
Some(Some(s)) => Ok(Value::Text(s.clone())),
Some(None) | None => Ok(Value::Null),
}
}
Expr::AnyAll {
expr,
op,
array,
is_any,
} => {
let lhs = eval_expr(expr, row, ctx)?;
let arr = eval_expr(array, row, ctx)?;
if matches!(arr, Value::Null) {
return Ok(Value::Null);
}
let Value::TextArray(items) = arr else {
return Err(EvalError::TypeMismatch {
detail: format!(
"ANY/ALL right-hand side must be an array, got {:?}",
arr.data_type()
),
});
};
let mut saw_null = matches!(lhs, Value::Null);
let mut saw_match = false;
let mut saw_mismatch = false;
for elem in items {
let elem_v = match elem {
Some(s) => Value::Text(s),
None => {
saw_null = true;
continue;
}
};
if matches!(lhs, Value::Null) {
saw_null = true;
continue;
}
match apply_binary(*op, lhs.clone(), elem_v) {
Ok(Value::Bool(true)) => saw_match = true,
Ok(Value::Bool(false)) => saw_mismatch = true,
Ok(Value::Null) => saw_null = true,
Ok(other) => {
return Err(EvalError::TypeMismatch {
detail: format!(
"ANY/ALL comparison didn't return Bool: {:?}",
other.data_type()
),
});
}
Err(e) => return Err(e),
}
}
let result = if *is_any {
if saw_match {
Value::Bool(true)
} else if saw_null {
Value::Null
} else {
Value::Bool(false)
}
} else if saw_mismatch {
Value::Bool(false)
} else if saw_null {
Value::Null
} else {
Value::Bool(true)
};
Ok(result)
}
}
}
/// Renders a non-NULL array element as the text stored inside a text array.
///
/// Variants without a dedicated rendering fall back to their `Debug` output.
fn value_to_text_for_array(v: &Value) -> String {
    match v {
        Value::Text(text) | Value::Json(text) => text.clone(),
        Value::SmallInt(n) => n.to_string(),
        Value::Int(n) => n.to_string(),
        Value::BigInt(n) => n.to_string(),
        Value::Float(x) => x.to_string(),
        Value::Bool(true) => String::from("true"),
        Value::Bool(false) => String::from("false"),
        Value::Numeric { scaled, scale } => format_numeric(*scaled, *scale),
        Value::Date(days) => format_date(*days),
        Value::Timestamp(micros) => format_timestamp(*micros),
        _ => format!("{v:?}"),
    }
}
/// Implements `EXTRACT(field FROM v)` for DATE, TIMESTAMP and INTERVAL values.
///
/// NULL input yields NULL; every other successful result is a `BigInt`.
/// For intervals the year/month fields come from the `months` component and
/// the remaining fields from the `micros` component. `Microsecond` is the
/// seconds-within-the-minute expressed in microseconds plus the fraction.
///
/// # Errors
///
/// `TypeMismatch` when `v` is not NULL, DATE, TIMESTAMP or INTERVAL.
fn extract_field(field: spg_sql::ast::ExtractField, v: &Value) -> Result<Value, EvalError> {
    use spg_sql::ast::ExtractField as F;
    const MICROS_PER_SEC: i64 = 1_000_000;
    const MICROS_PER_DAY: i64 = 86_400_000_000;

    // Intervals and NULL are answered directly; dates and timestamps are
    // first reduced to (day number, microseconds within that day).
    let (days, day_micros) = match v {
        Value::Null => return Ok(Value::Null),
        Value::Interval { months, micros } => {
            let (months, micros) = (*months, *micros);
            let whole_secs = micros / MICROS_PER_SEC;
            let picked = match field {
                F::Year => i64::from(months / 12),
                F::Month => i64::from(months % 12),
                F::Day => micros / MICROS_PER_DAY,
                F::Hour => (whole_secs / 3600) % 24,
                F::Minute => (whole_secs / 60) % 60,
                F::Second => whole_secs % 60,
                F::Microsecond => (whole_secs % 60) * MICROS_PER_SEC + micros % MICROS_PER_SEC,
            };
            return Ok(Value::BigInt(picked));
        }
        Value::Date(d) => (*d, 0_i64),
        Value::Timestamp(t) => {
            // Euclidean division keeps pre-epoch timestamps on the right day.
            let day_number = (*t).div_euclid(MICROS_PER_DAY);
            (
                i32::try_from(day_number).unwrap_or(i32::MAX),
                (*t).rem_euclid(MICROS_PER_DAY),
            )
        }
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!(
                    "EXTRACT requires DATE / TIMESTAMP / INTERVAL, got {:?}",
                    other.data_type()
                ),
            });
        }
    };

    let (year, month, day) = civil_components(days);
    let secs_of_day = day_micros / MICROS_PER_SEC;
    let picked = match field {
        F::Year => i64::from(year),
        F::Month => i64::from(month),
        F::Day => i64::from(day),
        F::Hour => secs_of_day / 3600,
        F::Minute => (secs_of_day / 60) % 60,
        F::Second => secs_of_day % 60,
        F::Microsecond => (secs_of_day % 60) * MICROS_PER_SEC + day_micros % MICROS_PER_SEC,
    };
    Ok(Value::BigInt(picked))
}
/// Splits a day number into `(year, month, day)`; thin alias for
/// `civil_from_days`.
fn civil_components(days: i32) -> (i32, u32, u32) {
    let (year, month, day) = civil_from_days(days);
    (year, month, day)
}
/// SQL `LIKE` matching of `text` against `pattern`.
///
/// * `%` matches any run of characters (including none).
/// * `_` matches exactly one character.
/// * `\x` matches the character `x` literally; a trailing lone `\` matches a
///   literal backslash.
///
/// Matching is per Unicode scalar value and must cover the whole of `text`.
fn like_match(text: &str, pattern: &str) -> bool {
    let text: Vec<char> = text.chars().collect();
    let pat: Vec<char> = pattern.chars().collect();
    like_match_inner(&text, 0, &pat, 0)
}

/// Matches `text[ti..]` against `pat[pi..]`.
///
/// Uses the single-backtrack wildcard scan: only the most recent `%` run is
/// remembered, and on a mismatch that run absorbs one more character. This is
/// O(text × pattern) in the worst case, where naive recursion on every `%` is
/// exponential in the number of `%` runs.
fn like_match_inner(text: &[char], mut ti: usize, pat: &[char], mut pi: usize) -> bool {
    // (pattern index just past the latest `%` run, text index up to which
    // that run has absorbed characters so far).
    let mut resume: Option<(usize, usize)> = None;

    while ti < text.len() {
        // `Some(width)` = current pattern token matches text[ti] and is
        // `width` pattern chars wide; `None` = mismatch or pattern exhausted.
        let step = match pat.get(pi) {
            Some('%') => {
                while pat.get(pi) == Some(&'%') {
                    pi += 1;
                }
                if pi == pat.len() {
                    // Trailing `%` swallows the rest of the text.
                    return true;
                }
                resume = Some((pi, ti));
                continue;
            }
            Some('_') => Some(1),
            Some('\\') if pi + 1 < pat.len() => (text[ti] == pat[pi + 1]).then_some(2),
            Some(&c) => (text[ti] == c).then_some(1),
            None => None,
        };

        match (step, resume) {
            (Some(width), _) => {
                ti += 1;
                pi += width;
            }
            (None, Some((after_wildcard, absorbed))) => {
                // Let the last `%` absorb one more character and retry.
                let absorbed = absorbed + 1;
                resume = Some((after_wildcard, absorbed));
                ti = absorbed;
                pi = after_wildcard;
            }
            (None, None) => return false,
        }
    }

    // Text is exhausted: only `%` tokens may remain in the pattern.
    pat[pi..].iter().all(|&c| c == '%')
}
/// Dispatches a scalar SQL function call by case-insensitive `name`.
///
/// Handled inline: `length`, `octet_length`, `array_length`,
/// `array_position`, `upper`, `lower`, `abs`, `coalesce`. Delegated to
/// helpers: `date_trunc`, `date_part`, `age`, `to_char`, `encode`, `decode`,
/// `error_on_null`.
///
/// Except for `coalesce`, a NULL argument yields NULL (see each arm for which
/// arguments are checked).
///
/// # Errors
///
/// `TypeMismatch` for an unknown function name, a wrong argument count, an
/// unsupported argument type, or `abs()` of the minimum integer (whose
/// absolute value is not representable).
fn apply_function(name: &str, args: &[Value]) -> Result<Value, EvalError> {
    /// Fails unless exactly `expected` arguments were supplied to `func`.
    fn expect_arity(func: &str, expected: usize, args: &[Value]) -> Result<(), EvalError> {
        if args.len() == expected {
            return Ok(());
        }
        let plural = if expected == 1 { "" } else { "s" };
        Err(EvalError::TypeMismatch {
            detail: format!("{func}() takes {expected} arg{plural}, got {}", args.len()),
        })
    }

    /// Clamps a count into the `Int` range used for length results.
    fn int_len(n: usize) -> Value {
        Value::Int(i32::try_from(n).unwrap_or(i32::MAX))
    }

    match name.to_ascii_lowercase().as_str() {
        "length" => {
            expect_arity("length", 1, args)?;
            match &args[0] {
                Value::Null => Ok(Value::Null),
                // Text length is counted in characters, bytea in bytes.
                Value::Text(s) => Ok(int_len(s.chars().count())),
                Value::Bytes(b) => Ok(int_len(b.len())),
                other => Err(EvalError::TypeMismatch {
                    detail: format!("length() needs text or bytea, got {:?}", other.data_type()),
                }),
            }
        }
        "octet_length" => {
            expect_arity("octet_length", 1, args)?;
            match &args[0] {
                Value::Null => Ok(Value::Null),
                Value::Text(s) => Ok(int_len(s.len())),
                Value::Bytes(b) => Ok(int_len(b.len())),
                other => Err(EvalError::TypeMismatch {
                    detail: format!(
                        "octet_length() needs text or bytea, got {:?}",
                        other.data_type()
                    ),
                }),
            }
        }
        "array_length" => {
            expect_arity("array_length", 2, args)?;
            if matches!(args[0], Value::Null) || matches!(args[1], Value::Null) {
                return Ok(Value::Null);
            }
            let Value::TextArray(items) = &args[0] else {
                return Err(EvalError::TypeMismatch {
                    detail: format!(
                        "array_length() first arg must be an array, got {:?}",
                        args[0].data_type()
                    ),
                });
            };
            let dim: i64 = match args[1] {
                Value::Int(n) => i64::from(n),
                Value::BigInt(n) => n,
                Value::SmallInt(n) => i64::from(n),
                _ => {
                    return Err(EvalError::TypeMismatch {
                        detail: format!(
                            "array_length() second arg must be integer, got {:?}",
                            args[1].data_type()
                        ),
                    });
                }
            };
            // Only one-dimensional arrays exist here; other dimensions are NULL.
            if dim != 1 {
                return Ok(Value::Null);
            }
            Ok(int_len(items.len()))
        }
        "array_position" => {
            expect_arity("array_position", 2, args)?;
            if matches!(args[0], Value::Null) {
                return Ok(Value::Null);
            }
            let Value::TextArray(items) = &args[0] else {
                return Err(EvalError::TypeMismatch {
                    detail: format!(
                        "array_position() first arg must be an array, got {:?}",
                        args[0].data_type()
                    ),
                });
            };
            let needle: &str = match &args[1] {
                Value::Text(s) => s.as_str(),
                Value::Null => return Ok(Value::Null),
                other => {
                    return Err(EvalError::TypeMismatch {
                        detail: format!(
                            "array_position() needle must be text, got {:?}",
                            other.data_type()
                        ),
                    });
                }
            };
            // 1-based position of the first equal element; NULL elements never
            // match, and a missing needle yields NULL.
            let found = items
                .iter()
                .position(|item| item.as_deref() == Some(needle));
            Ok(found.map_or(Value::Null, |idx| int_len(idx + 1)))
        }
        "upper" => {
            expect_arity("upper", 1, args)?;
            match &args[0] {
                Value::Null => Ok(Value::Null),
                Value::Text(s) => Ok(Value::Text(s.to_uppercase())),
                other => Err(EvalError::TypeMismatch {
                    detail: format!("upper() needs text, got {:?}", other.data_type()),
                }),
            }
        }
        "lower" => {
            expect_arity("lower", 1, args)?;
            match &args[0] {
                Value::Null => Ok(Value::Null),
                Value::Text(s) => Ok(Value::Text(s.to_lowercase())),
                other => Err(EvalError::TypeMismatch {
                    detail: format!("lower() needs text, got {:?}", other.data_type()),
                }),
            }
        }
        "abs" => {
            expect_arity("abs", 1, args)?;
            // `checked_abs`: the minimum integer has no representable absolute
            // value, and wrapping would hand back a negative "absolute" value.
            let out_of_range = || EvalError::TypeMismatch {
                detail: "abs(): integer out of range".into(),
            };
            match &args[0] {
                Value::Null => Ok(Value::Null),
                Value::Int(n) => n.checked_abs().map(Value::Int).ok_or_else(out_of_range),
                Value::BigInt(n) => n.checked_abs().map(Value::BigInt).ok_or_else(out_of_range),
                Value::Float(x) => Ok(Value::Float(x.abs())),
                other => Err(EvalError::TypeMismatch {
                    detail: format!("abs() needs numeric, got {:?}", other.data_type()),
                }),
            }
        }
        // First non-NULL argument, or NULL when there is none (any arity).
        "coalesce" => Ok(args
            .iter()
            .find(|a| !matches!(a, Value::Null))
            .cloned()
            .unwrap_or(Value::Null)),
        "date_trunc" => date_trunc(args),
        "date_part" => date_part(args),
        "age" => age(args),
        "to_char" => to_char(args),
        "encode" => encode_text(args),
        "decode" => decode_text(args),
        "error_on_null" => error_on_null(args),
        other => Err(EvalError::TypeMismatch {
            detail: format!("unknown function `{other}`"),
        }),
    }
}
/// SQL `encode(data, format)`: renders the UTF-8 bytes of a text value in the
/// named encoding (`base64`, `base64url`, `base32hex` or `hex`; the format
/// name is ASCII-case-insensitive).
///
/// Returns NULL when either argument is NULL.
///
/// # Errors
///
/// `TypeMismatch` on wrong arity, non-text arguments, or an unknown format.
fn encode_text(args: &[Value]) -> Result<Value, EvalError> {
    let [payload, format_arg] = args else {
        return Err(EvalError::TypeMismatch {
            detail: format!("encode() takes 2 args, got {}", args.len()),
        });
    };
    if matches!(payload, Value::Null) || matches!(format_arg, Value::Null) {
        return Ok(Value::Null);
    }
    let Value::Text(raw) = payload else {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "encode() expects text bytes, got {:?}",
                payload.data_type()
            ),
        });
    };
    let Value::Text(format_name) = format_arg else {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "encode() format must be text, got {:?}",
                format_arg.data_type()
            ),
        });
    };

    let data = raw.as_bytes();
    let encoded = match format_name.to_ascii_lowercase().as_str() {
        "base64" => b64_encode(data, B64_STD),
        "base64url" => b64_encode(data, B64_URL),
        "base32hex" => b32hex_encode(data),
        "hex" => hex_encode(data),
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!("encode(): unknown format `{other}`"),
            });
        }
    };
    Ok(Value::Text(encoded))
}
/// SQL `decode(text, format)`: reverses `encode` for `base64`, `base64url`,
/// `base32hex` and `hex` (format name is ASCII-case-insensitive).
///
/// The decoded bytes are returned as a text value, so they must form valid
/// UTF-8. Returns NULL when either argument is NULL.
///
/// # Errors
///
/// `TypeMismatch` on wrong arity, non-text arguments, an unknown format,
/// malformed input, or decoded bytes that are not valid UTF-8.
fn decode_text(args: &[Value]) -> Result<Value, EvalError> {
    let [encoded_arg, format_arg] = args else {
        return Err(EvalError::TypeMismatch {
            detail: format!("decode() takes 2 args, got {}", args.len()),
        });
    };
    if matches!(encoded_arg, Value::Null) || matches!(format_arg, Value::Null) {
        return Ok(Value::Null);
    }
    let Value::Text(encoded) = encoded_arg else {
        return Err(EvalError::TypeMismatch {
            detail: format!("decode() expects text, got {:?}", encoded_arg.data_type()),
        });
    };
    let Value::Text(format_name) = format_arg else {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "decode() format must be text, got {:?}",
                format_arg.data_type()
            ),
        });
    };

    let raw = match format_name.to_ascii_lowercase().as_str() {
        "base64" => b64_decode(encoded, B64_STD)?,
        "base64url" => b64_decode(encoded, B64_URL)?,
        "base32hex" => b32hex_decode(encoded)?,
        "hex" => hex_decode(encoded)?,
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!("decode(): unknown format `{other}`"),
            });
        }
    };
    match String::from_utf8(raw) {
        Ok(text) => Ok(Value::Text(text)),
        Err(_) => Err(EvalError::TypeMismatch {
            detail: "decode(): result bytes are not valid UTF-8 (SPG stores raw bytes as Text)"
                .into(),
        }),
    }
}
/// SQL `error_on_null(x)`: returns `x` unchanged, or fails when `x` is NULL.
///
/// # Errors
///
/// `TypeMismatch` when the argument is NULL or the arity is not exactly one.
fn error_on_null(args: &[Value]) -> Result<Value, EvalError> {
    match args {
        [Value::Null] => Err(EvalError::TypeMismatch {
            detail: "error_on_null(): argument is NULL".into(),
        }),
        [present] => Ok(present.clone()),
        _ => Err(EvalError::TypeMismatch {
            detail: format!("error_on_null() takes 1 arg, got {}", args.len()),
        }),
    }
}
/// Base64 alphabet ending in `+` and `/`; selected by the `base64` format.
const B64_STD: &[u8; 64] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Base64 alphabet ending in `-` and `_`; selected by the `base64url` format.
const B64_URL: &[u8; 64] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
/// Base32 "extended hex" alphabet (digits then `A`–`V`); selected by the
/// `base32hex` format.
const B32HEX_ALPHABET: &[u8; 32] = b"0123456789ABCDEFGHIJKLMNOPQRSTUV";
/// Base64-encodes `bytes` with the 64-symbol alphabet `alpha`, padding the
/// final group with `=` so the output length is a multiple of four.
fn b64_encode(bytes: &[u8], alpha: &[u8; 64]) -> String {
    // Picks the 6-bit group of `word` that starts at bit `shift`.
    let sextet = |word: u32, shift: u32| char::from(alpha[((word >> shift) & 0x3f) as usize]);

    let mut out = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for group in bytes.chunks(3) {
        // Pack up to three bytes into the top 24 bits, missing bytes as zero.
        let word = group
            .iter()
            .enumerate()
            .fold(0_u32, |acc, (k, &b)| acc | (u32::from(b) << (16 - 8 * k)));
        out.push(sextet(word, 18));
        out.push(sextet(word, 12));
        out.push(if group.len() > 1 { sextet(word, 6) } else { '=' });
        out.push(if group.len() > 2 { sextet(word, 0) } else { '=' });
    }
    out
}
/// Decodes Base64 `text` using the 64-symbol alphabet `alpha`.
///
/// Decoding stops at the first `=`; spaces, `\n` and `\r` are skipped.
/// Leftover bits that do not complete a byte are dropped.
///
/// # Errors
///
/// `TypeMismatch` on the first byte that is not part of `alpha`.
fn b64_decode(text: &str, alpha: &[u8; 64]) -> Result<Vec<u8>, EvalError> {
    // Reverse table: symbol byte -> 6-bit value, 255 marks "not a symbol".
    let mut sextet_of = [255_u8; 256];
    for (value, &symbol) in alpha.iter().enumerate() {
        sextet_of[usize::from(symbol)] = value as u8;
    }

    let symbols = text
        .bytes()
        .take_while(|&b| b != b'=')
        .filter(|b| !matches!(b, b'\n' | b'\r' | b' '));

    let mut out = Vec::with_capacity(text.len() * 3 / 4);
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;
    for symbol in symbols {
        let sextet = sextet_of[usize::from(symbol)];
        if sextet == 255 {
            return Err(EvalError::TypeMismatch {
                detail: format!("decode(base64): invalid char {:?}", symbol as char),
            });
        }
        acc = (acc << 6) | u32::from(sextet);
        pending += 6;
        if pending >= 8 {
            pending -= 8;
            out.push(((acc >> pending) & 0xff) as u8);
        }
    }
    Ok(out)
}
fn b32hex_encode(bytes: &[u8]) -> String {
let mut out = String::with_capacity((bytes.len() * 8 + 4) / 5);
let mut buf: u64 = 0;
let mut bits: u32 = 0;
for &b in bytes {
buf = (buf << 8) | b as u64;
bits += 8;
while bits >= 5 {
bits -= 5;
out.push(B32HEX_ALPHABET[((buf >> bits) & 0x1f) as usize] as char);
}
}
if bits > 0 {
out.push(B32HEX_ALPHABET[((buf << (5 - bits)) & 0x1f) as usize] as char);
}
while out.len() % 8 != 0 {
out.push('=');
}
out
}
/// Decodes `text` written in `B32HEX_ALPHABET`, accepting letters in either
/// case.
///
/// Decoding stops at the first `=`; spaces, `\n` and `\r` are skipped.
/// Leftover bits that do not complete a byte are dropped.
///
/// # Errors
///
/// `TypeMismatch` on the first byte that is not an alphabet symbol.
fn b32hex_decode(text: &str) -> Result<Vec<u8>, EvalError> {
    // Reverse table: symbol byte -> 5-bit value, 255 marks "not a symbol".
    let mut quintet_of = [255_u8; 256];
    for (value, &symbol) in B32HEX_ALPHABET.iter().enumerate() {
        let value = value as u8;
        quintet_of[usize::from(symbol)] = value;
        quintet_of[usize::from(symbol.to_ascii_lowercase())] = value;
    }

    let symbols = text
        .bytes()
        .take_while(|&b| b != b'=')
        .filter(|b| !matches!(b, b'\n' | b'\r' | b' '));

    let mut out = Vec::with_capacity(text.len() * 5 / 8);
    let mut acc: u64 = 0;
    let mut pending: u32 = 0;
    for symbol in symbols {
        let quintet = quintet_of[usize::from(symbol)];
        if quintet == 255 {
            return Err(EvalError::TypeMismatch {
                detail: format!("decode(base32hex): invalid char {:?}", symbol as char),
            });
        }
        acc = (acc << 5) | u64::from(quintet);
        pending += 5;
        if pending >= 8 {
            pending -= 8;
            out.push(((acc >> pending) & 0xff) as u8);
        }
    }
    Ok(out)
}
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    bytes
        .iter()
        .flat_map(|&byte| [byte >> 4, byte & 0xf])
        .map(|nibble| char::from(DIGITS[usize::from(nibble)]))
        .collect()
}
/// Decodes hexadecimal `text` (either case) into bytes, ignoring leading and
/// trailing whitespace.
///
/// # Errors
///
/// `TypeMismatch` when the trimmed input has an odd length, or on the first
/// byte that is not a hex digit.
fn hex_decode(text: &str) -> Result<Vec<u8>, EvalError> {
    /// Value of a single hex digit.
    fn nibble(digit: u8) -> Result<u8, EvalError> {
        match digit {
            b'0'..=b'9' => Ok(digit - b'0'),
            b'a'..=b'f' => Ok(digit - b'a' + 10),
            b'A'..=b'F' => Ok(digit - b'A' + 10),
            _ => Err(EvalError::TypeMismatch {
                detail: format!("decode(hex): invalid char {:?}", digit as char),
            }),
        }
    }

    let digits = text.trim().as_bytes();
    if digits.len() % 2 != 0 {
        return Err(EvalError::TypeMismatch {
            detail: "decode(hex): input length must be even".into(),
        });
    }
    // Each pair is (high nibble, low nibble); the first bad digit aborts.
    digits
        .chunks_exact(2)
        .map(|pair| -> Result<u8, EvalError> {
            let high = nibble(pair[0])?;
            let low = nibble(pair[1])?;
            Ok((high << 4) | low)
        })
        .collect()
}
/// SQL `date_part(field, source)`: function form of `EXTRACT`, with the field
/// given as (ASCII-case-insensitive) text.
///
/// Returns NULL when either argument is NULL; otherwise defers to
/// `extract_field`.
///
/// # Errors
///
/// `TypeMismatch` on wrong arity, a non-text field, an unknown field name, or
/// whatever `extract_field` rejects.
fn date_part(args: &[Value]) -> Result<Value, EvalError> {
    use spg_sql::ast::ExtractField as F;

    let [field_arg, source] = args else {
        return Err(EvalError::TypeMismatch {
            detail: format!("date_part() takes 2 args, got {}", args.len()),
        });
    };
    if matches!(field_arg, Value::Null) || matches!(source, Value::Null) {
        return Ok(Value::Null);
    }
    let Value::Text(field_name) = field_arg else {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "date_part() needs a text field, got {:?}",
                field_arg.data_type()
            ),
        });
    };

    let field = match field_name.to_ascii_lowercase().as_str() {
        "year" => F::Year,
        "month" => F::Month,
        "day" => F::Day,
        "hour" => F::Hour,
        "minute" => F::Minute,
        "second" => F::Second,
        "microsecond" | "microseconds" => F::Microsecond,
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!(
                    "unknown date_part field {other:?}; supported: year, month, day, hour, minute, second, microsecond"
                ),
            });
        }
    };
    extract_field(field, source)
}
/// SQL `age(a, b)`: the interval `a - b`, expressed purely in microseconds
/// (the `months` component is always 0).
///
/// Both arguments may be DATE or TIMESTAMP; dates count from midnight.
/// Returns NULL when any argument is NULL. The single-argument form is
/// rejected.
///
/// # Errors
///
/// `TypeMismatch` on wrong arity, the single-argument form, a non-date
/// argument, a DATE too far from the epoch to express in microseconds, or
/// when the subtraction overflows.
fn age(args: &[Value]) -> Result<Value, EvalError> {
    const MICROS_PER_DAY: i64 = 86_400_000_000;

    if args.is_empty() || args.len() > 2 {
        return Err(EvalError::TypeMismatch {
            detail: format!("age() takes 1 or 2 args, got {}", args.len()),
        });
    }
    if args.iter().any(|v| matches!(v, Value::Null)) {
        return Ok(Value::Null);
    }
    let to_micros = |v: &Value| -> Result<i64, EvalError> {
        match v {
            Value::Timestamp(t) => Ok(*t),
            // A DATE holds any i32 day count, so the product can exceed i64;
            // checked arithmetic turns that into an error instead of a
            // panic (debug) or a wrapped value (release).
            Value::Date(d) => {
                i64::from(*d)
                    .checked_mul(MICROS_PER_DAY)
                    .ok_or_else(|| EvalError::TypeMismatch {
                        detail: "age() date out of TIMESTAMP range".into(),
                    })
            }
            other => Err(EvalError::TypeMismatch {
                detail: format!("age() needs DATE or TIMESTAMP, got {:?}", other.data_type()),
            }),
        }
    };
    if args.len() == 1 {
        return Err(EvalError::TypeMismatch {
            detail: "single-arg age() is unsupported in v2.12 (use age(CURRENT_DATE, t) explicitly)"
                .into(),
        });
    }
    let a = to_micros(&args[0])?;
    let b = to_micros(&args[1])?;
    let delta = a.checked_sub(b).ok_or_else(|| EvalError::TypeMismatch {
        detail: "age() subtraction overflows i64 microseconds".into(),
    })?;
    Ok(Value::Interval {
        months: 0,
        micros: delta,
    })
}
/// SQL `to_char(value, format)` for DATE and TIMESTAMP values.
///
/// Recognised (case-sensitive) tokens, tried in this order at each position:
/// `YYYY`, `YY`, `Month`, `Mon`, `MM`, `DD`, `HH24`, `HH12`, `MI`, `SS`,
/// `MS` (milliseconds), `US` (microseconds), `AM`/`PM`. Any other character
/// is copied to the output unchanged, including non-ASCII text.
///
/// Returns NULL when either argument is NULL.
///
/// # Errors
///
/// `TypeMismatch` on wrong arity, a non-text format, or a value that is not
/// DATE or TIMESTAMP.
fn to_char(args: &[Value]) -> Result<Value, EvalError> {
    use core::fmt::Write as _;
    const MICROS_PER_DAY: i64 = 86_400_000_000;

    if args.len() != 2 {
        return Err(EvalError::TypeMismatch {
            detail: format!("to_char() takes 2 args, got {}", args.len()),
        });
    }
    if matches!(&args[0], Value::Null) || matches!(&args[1], Value::Null) {
        return Ok(Value::Null);
    }
    let Value::Text(fmt) = &args[1] else {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "to_char() needs a text format, got {:?}",
                args[1].data_type()
            ),
        });
    };
    let (days, day_micros) = match &args[0] {
        Value::Date(d) => (*d, 0_i64),
        Value::Timestamp(t) => {
            // Euclidean division keeps pre-epoch timestamps on the right day.
            let days = t.div_euclid(MICROS_PER_DAY);
            (
                i32::try_from(days).unwrap_or(i32::MAX),
                t.rem_euclid(MICROS_PER_DAY),
            )
        }
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!(
                    "to_char() needs DATE or TIMESTAMP, got {:?}",
                    other.data_type()
                ),
            });
        }
    };
    let (y, mo, d) = civil_from_days(days);
    let secs = day_micros / 1_000_000;
    let frac = day_micros % 1_000_000;
    let hh24 = u32::try_from(secs / 3600).unwrap_or(0);
    let mi = u32::try_from((secs / 60) % 60).unwrap_or(0);
    let ss = u32::try_from(secs % 60).unwrap_or(0);
    // 12-hour clock: hours 0 and 12 both display as 12.
    let hh12 = match hh24 % 12 {
        0 => 12,
        x => x,
    };
    let ampm = if hh24 < 12 { "AM" } else { "PM" };
    let ms = u32::try_from(frac / 1_000).unwrap_or(0);
    let us = u32::try_from(frac).unwrap_or(0);

    let mut out = String::with_capacity(fmt.len() + 8);
    // Scan as `&str`, not bytes: every token is ASCII, and literals advance by
    // one whole character, so `rest` always starts on a char boundary and
    // multi-byte literal text is copied intact. Writes to a `String` cannot
    // fail, hence the discarded `write!` results.
    let mut rest: &str = fmt.as_str();
    while let Some(literal) = rest.chars().next() {
        let consumed = if rest.starts_with("YYYY") {
            let _ = write!(out, "{y:04}");
            4
        } else if rest.starts_with("YY") {
            #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
            let yy = (y.rem_euclid(100)) as u32;
            let _ = write!(out, "{yy:02}");
            2
        } else if rest.starts_with("Month") {
            out.push_str(MONTH_FULL[(mo - 1) as usize]);
            5
        } else if rest.starts_with("Mon") {
            out.push_str(MONTH_ABBR[(mo - 1) as usize]);
            3
        } else if rest.starts_with("MM") {
            let _ = write!(out, "{mo:02}");
            2
        } else if rest.starts_with("DD") {
            let _ = write!(out, "{d:02}");
            2
        } else if rest.starts_with("HH24") {
            let _ = write!(out, "{hh24:02}");
            4
        } else if rest.starts_with("HH12") {
            let _ = write!(out, "{hh12:02}");
            4
        } else if rest.starts_with("MI") {
            let _ = write!(out, "{mi:02}");
            2
        } else if rest.starts_with("SS") {
            let _ = write!(out, "{ss:02}");
            2
        } else if rest.starts_with("MS") {
            let _ = write!(out, "{ms:03}");
            2
        } else if rest.starts_with("US") {
            let _ = write!(out, "{us:06}");
            2
        } else if rest.starts_with("AM") || rest.starts_with("PM") {
            out.push_str(ampm);
            2
        } else {
            out.push(literal);
            literal.len_utf8()
        };
        rest = &rest[consumed..];
    }
    Ok(Value::Text(out))
}
/// Full English month names, indexed by `month - 1`; used for the `Month`
/// token of `to_char`.
const MONTH_FULL: [&str; 12] = [
"January",
"February",
"March",
"April",
"May",
"June",
"July",
"August",
"September",
"October",
"November",
"December",
];
/// Three-letter English month abbreviations, indexed by `month - 1`; used for
/// the `Mon` token of `to_char`.
const MONTH_ABBR: [&str; 12] = [
"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
];
/// SQL `date_trunc(unit, value)`: rounds a DATE or TIMESTAMP down to the
/// start of the given unit and returns it as a TIMESTAMP.
///
/// Supported units (ASCII-case-insensitive): `year`, `month`, `day`, `hour`,
/// `minute`, `second`. Returns NULL when either argument is NULL.
///
/// # Errors
///
/// `TypeMismatch` on wrong arity, a non-text unit, a value that is not DATE
/// or TIMESTAMP, a DATE too far from the epoch to express in microseconds, or
/// an unknown unit.
fn date_trunc(args: &[Value]) -> Result<Value, EvalError> {
    const MICROS_PER_DAY: i64 = 86_400_000_000;

    if args.len() != 2 {
        return Err(EvalError::TypeMismatch {
            detail: format!("date_trunc() takes 2 args, got {}", args.len()),
        });
    }
    if matches!(&args[0], Value::Null) || matches!(&args[1], Value::Null) {
        return Ok(Value::Null);
    }
    let Value::Text(unit) = &args[0] else {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "date_trunc() needs a text unit, got {:?}",
                args[0].data_type()
            ),
        });
    };
    let micros = match &args[1] {
        Value::Timestamp(t) => *t,
        // A DATE holds any i32 day count, so the product can exceed i64;
        // report that instead of panicking (debug) or wrapping (release).
        Value::Date(d) => i64::from(*d)
            .checked_mul(MICROS_PER_DAY)
            .ok_or_else(|| EvalError::TypeMismatch {
                detail: "date_trunc() date out of TIMESTAMP range".into(),
            })?,
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!(
                    "date_trunc() needs DATE or TIMESTAMP, got {:?}",
                    other.data_type()
                ),
            });
        }
    };
    let unit_lc = unit.to_ascii_lowercase();
    // Euclidean split so pre-epoch timestamps truncate toward the past.
    let days = micros.div_euclid(MICROS_PER_DAY);
    let day_micros = micros.rem_euclid(MICROS_PER_DAY);
    let day_i32 = i32::try_from(days).unwrap_or(i32::MAX);
    let (y, m, _) = civil_from_days(day_i32);
    let day_start = days * MICROS_PER_DAY;
    let truncated = match unit_lc.as_str() {
        "year" => i64::from(days_from_civil(y, 1, 1)) * MICROS_PER_DAY,
        "month" => i64::from(days_from_civil(y, m, 1)) * MICROS_PER_DAY,
        "day" => day_start,
        "hour" => day_start + (day_micros / 3_600_000_000) * 3_600_000_000,
        "minute" => day_start + (day_micros / 60_000_000) * 60_000_000,
        "second" => day_start + (day_micros / 1_000_000) * 1_000_000,
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!(
                    "unknown date_trunc unit {other:?}; supported: year, month, day, hour, minute, second"
                ),
            });
        }
    };
    Ok(Value::Timestamp(truncated))
}
pub fn cast_value(v: Value, target: CastTarget) -> Result<Value, EvalError> {
if matches!(v, Value::Null) {
return Ok(Value::Null);
}
match target {
CastTarget::Vector => cast_to_vector(v),
CastTarget::Text => Ok(Value::Text(value_to_text(&v))),
CastTarget::Int => cast_numeric_to_int(v),
CastTarget::BigInt => cast_numeric_to_bigint(v),
CastTarget::Float => cast_numeric_to_float(v),
CastTarget::Bool => cast_to_bool(v),
CastTarget::Date => cast_to_date(v),
CastTarget::Timestamp | CastTarget::Timestamptz => cast_to_timestamp(v),
CastTarget::Interval => cast_to_interval(v),
CastTarget::Json | CastTarget::Jsonb => match v {
Value::Json(s) => Ok(Value::Json(s)),
Value::Text(s) => Ok(Value::Json(s)),
other => Err(EvalError::TypeMismatch {
detail: alloc::format!(
"::json / ::jsonb only accepts TEXT-shape inputs, got {:?}",
other.data_type()
),
}),
},
CastTarget::RegType | CastTarget::RegClass => Err(EvalError::TypeMismatch {
detail:
"::regtype / ::regclass not supported on SPG \
(no pg_catalog); use SHOW TABLES / spg_table_ddl instead"
.into(),
}),
CastTarget::TextArray => match v {
Value::TextArray(items) => Ok(Value::TextArray(items)),
Value::Text(s) => decode_text_array_external(&s).map(Value::TextArray),
other => Err(EvalError::TypeMismatch {
detail: alloc::format!(
"::TEXT[] only accepts TEXT / TEXT[] inputs, got {:?}",
other.data_type()
),
}),
},
}
}
/// Parses the external text form of a one-dimensional text array, e.g.
/// `{a,"b, c",NULL}`.
///
/// * The literal must be wrapped in `{` … `}` (surrounding whitespace is
///   ignored); `{}` is the empty array.
/// * Elements are separated by `,`. Spaces and tabs around an element are
///   ignored.
/// * A double-quoted element may contain commas; inside quotes `\x` stands
///   for the character `x`. Quoted elements are never NULL.
/// * An unquoted `NULL` (any ASCII case) is a NULL element; any other
///   unquoted text is trimmed and taken verbatim. A trailing comma therefore
///   yields a final empty-string element.
///
/// # Errors
///
/// `TypeMismatch` when the braces are missing, a quoted element is not
/// terminated, or something other than `,` follows an element.
fn decode_text_array_external(s: &str) -> Result<Vec<Option<String>>, EvalError> {
    let trimmed = s.trim();
    let inner = trimmed
        .strip_prefix('{')
        .and_then(|x| x.strip_suffix('}'))
        .ok_or_else(|| EvalError::TypeMismatch {
            detail: format!("TEXT[] literal {s:?} must be enclosed in '{{...}}'"),
        })?;
    let mut out: Vec<Option<String>> = Vec::new();
    if inner.trim().is_empty() {
        return Ok(out);
    }
    let bytes = inner.as_bytes();
    let is_blank = |b: u8| b == b' ' || b == b'\t';
    let mut i = 0;
    // `<=` on purpose: after a trailing comma one more (empty) element is read.
    while i <= bytes.len() {
        while i < bytes.len() && is_blank(bytes[i]) {
            i += 1;
        }
        if i < bytes.len() && bytes[i] == b'"' {
            i += 1;
            let mut buf = String::new();
            // Every delimiter we stop on (`"`, `\`) is ASCII, so `i` always
            // sits on a char boundary and whole characters are copied. Pushing
            // single bytes as `char` would corrupt non-ASCII text.
            while i < bytes.len() && bytes[i] != b'"' {
                if bytes[i] == b'\\' && i + 1 < bytes.len() {
                    // Drop the backslash; the next character is literal.
                    i += 1;
                }
                let Some(ch) = inner[i..].chars().next() else {
                    break;
                };
                buf.push(ch);
                i += ch.len_utf8();
            }
            if i >= bytes.len() {
                return Err(EvalError::TypeMismatch {
                    detail: "unterminated quoted element in TEXT[] literal".into(),
                });
            }
            i += 1;
            out.push(Some(buf));
        } else {
            let start = i;
            while i < bytes.len() && bytes[i] != b',' {
                i += 1;
            }
            let raw = inner[start..i].trim();
            if raw.eq_ignore_ascii_case("NULL") {
                out.push(None);
            } else {
                out.push(Some(raw.to_string()));
            }
        }
        while i < bytes.len() && is_blank(bytes[i]) {
            i += 1;
        }
        if i >= bytes.len() {
            break;
        }
        if bytes[i] != b',' {
            return Err(EvalError::TypeMismatch {
                detail: "expected ',' between TEXT[] elements".into(),
            });
        }
        i += 1;
    }
    Ok(out)
}
/// Casts to INTERVAL: intervals pass through, text is parsed with
/// `spg_sql::parser::parse_interval_text`.
///
/// # Errors
///
/// `TypeMismatch` when the text cannot be parsed or the input is neither
/// INTERVAL nor text.
fn cast_to_interval(v: Value) -> Result<Value, EvalError> {
    let text = match v {
        already @ Value::Interval { .. } => return Ok(already),
        Value::Text(text) => text,
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!(
                    "::INTERVAL only accepts TEXT-shape inputs, got {:?}",
                    other.data_type()
                ),
            });
        }
    };
    match spg_sql::parser::parse_interval_text(&text) {
        Some((months, micros)) => Ok(Value::Interval { months, micros }),
        None => Err(EvalError::TypeMismatch {
            detail: format!("cannot parse {text:?} as INTERVAL"),
        }),
    }
}
/// Casts to DATE.
///
/// * DATE passes through.
/// * `Int` / `BigInt` are taken as the day number itself.
/// * TIMESTAMP is floored to its day (Euclidean division, so pre-epoch
///   instants land on the earlier day).
/// * Text is parsed as `YYYY-MM-DD`.
///
/// # Errors
///
/// `TypeMismatch` when the day number does not fit the DATE range, the text
/// is not a valid date literal, or the input type is unsupported.
fn cast_to_date(v: Value) -> Result<Value, EvalError> {
    match v {
        Value::Date(d) => Ok(Value::Date(d)),
        Value::Int(n) => Ok(Value::Date(n)),
        Value::BigInt(n) => {
            i32::try_from(n)
                .map(Value::Date)
                .map_err(|_| EvalError::TypeMismatch {
                    detail: "bigint days-since-epoch out of DATE range".into(),
                })
        }
        Value::Timestamp(t) => {
            let days = t.div_euclid(86_400_000_000);
            i32::try_from(days)
                .map(Value::Date)
                .map_err(|_| EvalError::TypeMismatch {
                    detail: "timestamp out of DATE range".into(),
                })
        }
        // `ok_or_else`: only format the error message when parsing failed.
        Value::Text(s) => parse_date_literal(&s)
            .map(Value::Date)
            .ok_or_else(|| EvalError::TypeMismatch {
                detail: format!("cannot parse {s:?} as DATE (expected YYYY-MM-DD)"),
            }),
        other => Err(EvalError::TypeMismatch {
            detail: format!("cannot cast {:?} to DATE", other.data_type()),
        }),
    }
}
/// Casts to TIMESTAMP (also used for TIMESTAMPTZ).
///
/// * TIMESTAMP passes through.
/// * `Int` / `BigInt` are taken as the microsecond count itself.
/// * DATE becomes midnight of that day.
/// * Text is parsed as `YYYY-MM-DD[ HH:MM:SS[.ffffff]]`.
///
/// # Errors
///
/// `TypeMismatch` when a DATE is too far from the epoch to express in
/// microseconds, the text is not a valid timestamp literal, or the input type
/// is unsupported.
fn cast_to_timestamp(v: Value) -> Result<Value, EvalError> {
    const MICROS_PER_DAY: i64 = 86_400_000_000;
    match v {
        Value::Timestamp(t) => Ok(Value::Timestamp(t)),
        Value::Int(n) => Ok(Value::Timestamp(i64::from(n))),
        Value::BigInt(n) => Ok(Value::Timestamp(n)),
        // A DATE holds any i32 day count, so the product can exceed i64;
        // report that instead of panicking (debug) or wrapping (release).
        Value::Date(d) => i64::from(d)
            .checked_mul(MICROS_PER_DAY)
            .map(Value::Timestamp)
            .ok_or_else(|| EvalError::TypeMismatch {
                detail: "date out of TIMESTAMP range".into(),
            }),
        // `ok_or_else`: only format the error message when parsing failed.
        Value::Text(s) => parse_timestamp_literal(&s)
            .map(Value::Timestamp)
            .ok_or_else(|| EvalError::TypeMismatch {
                detail: format!(
                    "cannot parse {s:?} as TIMESTAMP (expected YYYY-MM-DD[ HH:MM:SS[.ffffff]])"
                ),
            }),
        other => Err(EvalError::TypeMismatch {
            detail: format!("cannot cast {:?} to TIMESTAMP", other.data_type()),
        }),
    }
}
/// Renders a value as the text produced by a cast to TEXT.
///
/// Vectors of every storage flavour print as `[a, b, c]`; NULL prints as
/// `NULL`; variants without a dedicated rendering fall back to `Debug`.
fn value_to_text(v: &Value) -> String {
    /// Formats cells as `[c1, c2, ...]`.
    fn bracketed<T: core::fmt::Display>(cells: impl Iterator<Item = T>) -> String {
        let rendered: Vec<String> = cells.map(|cell| cell.to_string()).collect();
        format!("[{}]", rendered.join(", "))
    }

    match v {
        Value::Null => "NULL".into(),
        Value::Bool(true) => "true".into(),
        Value::Bool(false) => "false".into(),
        Value::SmallInt(n) => n.to_string(),
        Value::Int(n) => n.to_string(),
        Value::BigInt(n) => n.to_string(),
        Value::Float(x) => x.to_string(),
        Value::Text(text) | Value::Json(text) => text.clone(),
        Value::Numeric { scaled, scale } => format_numeric(*scaled, *scale),
        Value::Date(days) => format_date(*days),
        Value::Timestamp(micros) => format_timestamp(*micros),
        Value::Interval { months, micros } => format_interval(*months, *micros),
        Value::Vector(cells) => bracketed(cells.iter()),
        // Quantised and half-precision vectors are expanded before printing.
        Value::Sq8Vector(q) => bracketed(spg_storage::quantize::dequantize(q).iter()),
        Value::HalfVector(h) => bracketed(h.to_f32_vec().iter()),
        _ => format!("{v:?}"),
    }
}
/// Formats a day number as `YYYY-MM-DD`, zero-padded.
pub fn format_date(days: i32) -> String {
    let ymd = civil_from_days(days);
    format!("{:04}-{:02}-{:02}", ymd.0, ymd.1, ymd.2)
}
/// Formats a microsecond count as `YYYY-MM-DD HH:MM:SS`, followed by a
/// fractional part with trailing zeros removed when it is non-zero.
pub fn format_timestamp(micros: i64) -> String {
    const MICROS_PER_DAY: i64 = 86_400_000_000;
    const MICROS_PER_SEC: i64 = 1_000_000;

    // Euclidean split so negative inputs still give a non-negative time of day.
    let day_number = i32::try_from(micros.div_euclid(MICROS_PER_DAY)).unwrap_or(i32::MAX);
    let within_day = micros.rem_euclid(MICROS_PER_DAY);
    let (year, month, day) = civil_from_days(day_number);

    let whole_secs = within_day / MICROS_PER_SEC;
    let fraction = within_day % MICROS_PER_SEC;
    let mut text = format!(
        "{year:04}-{month:02}-{day:02} {:02}:{:02}:{:02}",
        whole_secs / 3600,
        (whole_secs / 60) % 60,
        whole_secs % 60
    );
    if fraction != 0 {
        let digits = format!("{fraction:06}");
        text.push('.');
        text.push_str(digits.trim_end_matches('0'));
    }
    text
}
/// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple.
///
/// The computation works in 400-year eras of 146 097 days, with the internal
/// year starting on March 1 so that the leap day is the last day of the year.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
fn civil_from_days(days: i32) -> (i32, u32, u32) {
    // Shift the origin from 1970-01-01 to 0000-03-01.
    let shifted = i64::from(days) + 719_468;
    let era = shifted.div_euclid(146_097);
    let day_of_era = (shifted - era * 146_097) as u32;
    let year_of_era = (day_of_era.saturating_sub(day_of_era / 1460) + day_of_era / 36524
        - day_of_era / 146_096)
        / 365;
    let march_year = i64::from(year_of_era) + era * 400;
    let day_of_year =
        day_of_era.saturating_sub(365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0) to February (11).
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year.saturating_sub((153 * march_month + 2) / 5) + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the following calendar year.
    let year = if month <= 2 {
        march_year + 1
    } else {
        march_year
    };
    (year as i32, month, day)
}
/// Converts a proleptic Gregorian `(year, month, day)` into a day count
/// relative to 1970-01-01.
///
/// `m` is expected in `1..=12` and `d` is one-based; neither is validated
/// here. When the result does not fit in `i32` it saturates toward the side it
/// overflowed on (`i32::MIN` for dates far in the past, `i32::MAX` for dates
/// far in the future) instead of always reporting `i32::MAX`.
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
pub fn days_from_civil(y: i32, m: u32, d: u32) -> i32 {
    // The internal year starts on March 1, so Jan/Feb belong to the prior year.
    let y_adj = if m <= 2 {
        i64::from(y) - 1
    } else {
        i64::from(y)
    };
    let era = y_adj.div_euclid(400);
    let yoe = (y_adj - era * 400) as u32;
    let month_from_march = if m > 2 { m - 3 } else { m + 9 };
    let doy = (153 * month_from_march + 2) / 5 + d.saturating_sub(1);
    let doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;
    let total = era * 146_097 + i64::from(doe) - 719_468;
    // Saturate in the direction of the overflow so ordering stays sensible.
    i32::try_from(total).unwrap_or(if total < 0 { i32::MIN } else { i32::MAX })
}
pub fn parse_date_literal(s: &str) -> Option<i32> {
let bytes = s.as_bytes();
if bytes.len() != 10 || bytes[4] != b'-' || bytes[7] != b'-' {
return None;
}
let y: i32 = s[0..4].parse().ok()?;
let m: u32 = s[5..7].parse().ok()?;
let d: u32 = s[8..10].parse().ok()?;
if !(1..=12).contains(&m) || !(1..=31).contains(&d) {
return None;
}
Some(days_from_civil(y, m, d))
}
pub fn parse_timestamp_literal(s: &str) -> Option<i64> {
let trimmed = s.trim();
let (date_part, time_part) = match trimmed.find([' ', 'T']) {
Some(i) => (&trimmed[..i], Some(&trimmed[i + 1..])),
None => (trimmed, None),
};
let days = parse_date_literal(date_part)?;
let day_micros = match time_part {
None => 0,
Some(t) => parse_time_of_day_micros(t)?,
};
Some(i64::from(days) * 86_400_000_000 + day_micros)
}
/// Parses `HH:MM:SS[.fraction]` into microseconds since midnight.
///
/// Every field must consist of ASCII digits only. The fraction may have one to
/// nine digits; digits beyond the sixth are dropped (truncation, not
/// rounding). Returns `None` for any malformed or out-of-range input.
fn parse_time_of_day_micros(t: &str) -> Option<i64> {
    /// Reads a two-digit decimal field, rejecting signs and non-digits.
    fn two_digits(hi: u8, lo: u8) -> Option<i64> {
        if hi.is_ascii_digit() && lo.is_ascii_digit() {
            Some(i64::from(hi - b'0') * 10 + i64::from(lo - b'0'))
        } else {
            None
        }
    }

    let (time, frac_str) = match t.split_once('.') {
        Some((clock, fraction)) => (clock, Some(fraction)),
        None => (t, None),
    };
    let bytes = time.as_bytes();
    if bytes.len() != 8 || bytes[2] != b':' || bytes[5] != b':' {
        return None;
    }
    let hh = two_digits(bytes[0], bytes[1])?;
    let mm = two_digits(bytes[3], bytes[4])?;
    let ss = two_digits(bytes[6], bytes[7])?;
    if hh >= 24 || mm >= 60 || ss >= 60 {
        return None;
    }
    let frac_micros: i64 = match frac_str {
        None => 0,
        Some(f) => {
            // Checking for digits first also guarantees the text is ASCII, so
            // no byte offset can fall inside a multi-byte character.
            if f.is_empty() || f.len() > 9 || !f.bytes().all(|b| b.is_ascii_digit()) {
                return None;
            }
            // Weight the first six digits as 10^5 .. 10^0 microseconds;
            // missing digits contribute zero (implicit right padding).
            let mut micros = 0_i64;
            let mut weight = 100_000_i64;
            for digit in f.bytes().take(6) {
                micros += i64::from(digit - b'0') * weight;
                weight /= 10;
            }
            micros
        }
    };
    Some(((hh * 3600 + mm * 60 + ss) * 1_000_000) + frac_micros)
}
/// Formats an interval given as whole months plus microseconds.
///
/// The output lists, in order and only when non-zero: years, months (`mon` /
/// `mons`), days, and a signed `HH:MM:SS[.fraction]` clock part, separated by
/// single spaces. An all-zero interval is rendered as `"0"`.
pub fn format_interval(months: i32, micros: i64) -> String {
    const MICROS_PER_DAY: i64 = 86_400_000_000;
    const MICROS_PER_SEC: i64 = 1_000_000;

    /// Picks the singular label only for a count of exactly one.
    fn label(count: i64, singular: &'static str, plural: &'static str) -> &'static str {
        if count == 1 {
            singular
        } else {
            plural
        }
    }

    let mut pieces: Vec<String> = Vec::new();

    let (whole_years, leftover_months) = (months / 12, months % 12);
    if whole_years != 0 {
        let unit = label(i64::from(whole_years), "year", "years");
        pieces.push(format!("{whole_years} {unit}"));
    }
    if leftover_months != 0 {
        let unit = label(i64::from(leftover_months), "mon", "mons");
        pieces.push(format!("{leftover_months} {unit}"));
    }

    let (whole_days, clock) = (micros / MICROS_PER_DAY, micros % MICROS_PER_DAY);
    if whole_days != 0 {
        let unit = label(whole_days, "day", "days");
        pieces.push(format!("{whole_days} {unit}"));
    }

    if clock != 0 {
        let sign = if clock < 0 { "-" } else { "" };
        // The remainder is strictly less than a day in magnitude.
        let magnitude = clock.abs();
        let total_secs = magnitude / MICROS_PER_SEC;
        let fraction = magnitude % MICROS_PER_SEC;
        let hh = total_secs / 3600;
        let mm = total_secs % 3600 / 60;
        let ss = total_secs % 60;
        let mut text = format!("{sign}{hh:02}:{mm:02}:{ss:02}");
        if fraction != 0 {
            let digits = format!("{fraction:06}");
            text.push('.');
            text.push_str(digits.trim_end_matches('0'));
        }
        pieces.push(text);
    }

    if pieces.is_empty() {
        String::from("0")
    } else {
        pieces.join(" ")
    }
}
/// Adds a signed number of months to a calendar date.
///
/// The day of month is clamped to the length of the resulting month, so
/// January 31 plus one month yields the last day of February.
fn add_months_to_civil(y: i32, m: u32, d: u32, months: i32) -> (i32, u32, u32) {
    // Zero-based month index counted from January of year 0.
    let month_index = i64::from(y) * 12 + (i64::from(m) - 1) + i64::from(months);
    let year = i32::try_from(month_index.div_euclid(12)).unwrap_or(i32::MAX);
    // `rem_euclid(12)` is always in 0..12, so the cast cannot lose data.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let month = month_index.rem_euclid(12) as u32 + 1;
    let day = d.min(days_in_month(year, month));
    (year, month, day)
}
/// Returns the number of days in month `m` of year `y` under Gregorian leap
/// year rules.
///
/// Any `m` other than February or one of the seven 31-day months yields 30,
/// including values outside `1..=12`.
const fn days_in_month(y: i32, m: u32) -> u32 {
    // Divisible by 4, except centuries that are not divisible by 400.
    let is_leap = y.rem_euclid(4) == 0 && (y.rem_euclid(100) != 0 || y.rem_euclid(400) == 0);
    match m {
        2 if is_leap => 29,
        2 => 28,
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        _ => 30,
    }
}
/// Formats an array of optional strings as `{a,b,NULL}`.
///
/// `None` is written as an unquoted `NULL`. An element is wrapped in double
/// quotes when it is empty, spells `NULL` (case-insensitively), or contains a
/// comma, brace, double quote, backslash, space or tab; inside quotes, `"` and
/// `\` are escaped with a backslash.
pub fn format_text_array(items: &[Option<String>]) -> String {
    /// Decides whether an element would be ambiguous without quoting.
    fn must_quote(element: &str) -> bool {
        element.is_empty()
            || element.eq_ignore_ascii_case("NULL")
            || element.contains(|c: char| matches!(c, ',' | '{' | '}' | '"' | '\\' | ' ' | '\t'))
    }

    let mut rendered = String::with_capacity(2 + items.len() * 8);
    rendered.push('{');
    for (index, item) in items.iter().enumerate() {
        if index != 0 {
            rendered.push(',');
        }
        let Some(element) = item else {
            rendered.push_str("NULL");
            continue;
        };
        if !must_quote(element) {
            rendered.push_str(element);
            continue;
        }
        rendered.push('"');
        for ch in element.chars() {
            if matches!(ch, '"' | '\\') {
                rendered.push('\\');
            }
            rendered.push(ch);
        }
        rendered.push('"');
    }
    rendered.push('}');
    rendered
}
/// Formats bytes as `\x` followed by two lowercase hex digits per byte.
pub fn format_bytea_hex(b: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut hex = String::with_capacity(2 + 2 * b.len());
    hex.push_str("\\x");
    for &byte in b {
        // High nibble first, then low nibble.
        let (high, low) = (usize::from(byte >> 4), usize::from(byte & 0x0F));
        hex.push(char::from(DIGITS[high]));
        hex.push(char::from(DIGITS[low]));
    }
    hex
}
/// Formats a fixed-point number stored as `scaled / 10^scale`.
///
/// With `scale == 0` the integer is printed as is. Otherwise exactly `scale`
/// digits follow the decimal point, and values below one in magnitude get a
/// leading `0.` plus any zero padding they need.
pub fn format_numeric(scaled: i128, scale: u8) -> String {
    if scale == 0 {
        return format!("{scaled}");
    }
    // `unsigned_abs` keeps `i128::MIN` representable.
    let digits = format!("{}", scaled.unsigned_abs());
    let fraction_len = usize::from(scale);
    let mut text = String::with_capacity(digits.len() + 3);
    if scaled < 0 {
        text.push('-');
    }
    match digits.len().checked_sub(fraction_len) {
        // More digits than the scale: split into integer and fraction parts.
        Some(integer_len) if integer_len > 0 => {
            let (integer_part, fraction_part) = digits.split_at(integer_len);
            text.push_str(integer_part);
            text.push('.');
            text.push_str(fraction_part);
        }
        // Otherwise the whole magnitude is fractional; left-pad with zeros.
        _ => {
            text.push_str("0.");
            for _ in digits.len()..fraction_len {
                text.push('0');
            }
            text.push_str(&digits);
        }
    }
    text
}
/// Casts a value to `Value::Int`.
///
/// Accepts `Int`, `BigInt`, `Float`, `Text` (trimmed, then parsed as `i32`)
/// and `Bool` (`true` becomes 1, `false` becomes 0).
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when a `BigInt` or `Float` does not fit
/// in `i32` (including NaN and infinities), when text does not parse, or when
/// the source type has no conversion here.
fn cast_numeric_to_int(v: Value) -> Result<Value, EvalError> {
    match v {
        Value::Int(n) => Ok(Value::Int(n)),
        Value::BigInt(n) => i32::try_from(n)
            .map(Value::Int)
            .map_err(|_| EvalError::TypeMismatch {
                detail: format!("bigint {n} does not fit in int"),
            }),
        Value::Float(x) => {
            // The cast truncates toward zero, so it stays in range exactly when
            // -2^31 - 1 < x < 2^31. NaN fails both comparisons; without this
            // guard `as` would silently saturate (and turn NaN into 0).
            if x > -2_147_483_649.0 && x < 2_147_483_648.0 {
                #[allow(clippy::cast_possible_truncation)]
                let truncated = x as i32;
                Ok(Value::Int(truncated))
            } else {
                Err(EvalError::TypeMismatch {
                    detail: format!("float {x} does not fit in int"),
                })
            }
        }
        Value::Text(s) => {
            s.trim()
                .parse::<i32>()
                .map(Value::Int)
                .map_err(|_| EvalError::TypeMismatch {
                    detail: format!("cannot parse {s:?} as int"),
                })
        }
        Value::Bool(b) => Ok(Value::Int(i32::from(b))),
        other => Err(EvalError::TypeMismatch {
            detail: format!("cannot cast {:?} to int", other.data_type()),
        }),
    }
}
/// Casts a value to `Value::BigInt`.
///
/// Accepts `Int`, `BigInt`, `Float`, `Text` (trimmed, then parsed as `i64`)
/// and `Bool` (`true` becomes 1, `false` becomes 0).
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when a `Float` does not fit in `i64`
/// (including NaN and infinities), when text does not parse, or when the
/// source type has no conversion here.
fn cast_numeric_to_bigint(v: Value) -> Result<Value, EvalError> {
    match v {
        Value::Int(n) => Ok(Value::BigInt(i64::from(n))),
        Value::BigInt(n) => Ok(Value::BigInt(n)),
        Value::Float(x) => {
            // Both bounds are exact powers of two (-2^63 and 2^63), so they are
            // represented exactly as f64. NaN fails both comparisons; without
            // this guard `as` would silently saturate (and turn NaN into 0).
            if x >= -9_223_372_036_854_775_808.0 && x < 9_223_372_036_854_775_808.0 {
                #[allow(clippy::cast_possible_truncation)]
                let truncated = x as i64;
                Ok(Value::BigInt(truncated))
            } else {
                Err(EvalError::TypeMismatch {
                    detail: format!("float {x} does not fit in bigint"),
                })
            }
        }
        Value::Text(s) => {
            s.trim()
                .parse::<i64>()
                .map(Value::BigInt)
                .map_err(|_| EvalError::TypeMismatch {
                    detail: format!("cannot parse {s:?} as bigint"),
                })
        }
        Value::Bool(b) => Ok(Value::BigInt(i64::from(b))),
        other => Err(EvalError::TypeMismatch {
            detail: format!("cannot cast {:?} to bigint", other.data_type()),
        }),
    }
}
/// Casts a value to `Value::Float`.
///
/// Accepts `Int`, `BigInt` (possibly losing precision), `Float`, and `Text`
/// (trimmed, then parsed as `f64`).
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when text does not parse or the source
/// type has no conversion here.
fn cast_numeric_to_float(v: Value) -> Result<Value, EvalError> {
    let converted = match v {
        Value::Float(x) => x,
        Value::Int(n) => f64::from(n),
        #[allow(clippy::cast_precision_loss)]
        Value::BigInt(n) => n as f64,
        Value::Text(s) => match s.trim().parse::<f64>() {
            Ok(parsed) => parsed,
            Err(_) => {
                return Err(EvalError::TypeMismatch {
                    detail: format!("cannot parse {s:?} as float"),
                });
            }
        },
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!("cannot cast {:?} to float", other.data_type()),
            });
        }
    };
    Ok(Value::Float(converted))
}
/// Casts a value to `Value::Bool`.
///
/// Integers are true when non-zero. Text is trimmed, lowercased (ASCII only)
/// and matched against a fixed list of true and false spellings.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] for unrecognised text or for a source
/// type that has no conversion here.
fn cast_to_bool(v: Value) -> Result<Value, EvalError> {
    const TRUE_WORDS: [&str; 6] = ["true", "t", "yes", "y", "1", "on"];
    const FALSE_WORDS: [&str; 6] = ["false", "f", "no", "n", "0", "off"];

    let flag = match v {
        Value::Bool(b) => b,
        Value::Int(n) => n != 0,
        Value::BigInt(n) => n != 0,
        Value::Text(s) => {
            let word = s.trim().to_ascii_lowercase();
            if TRUE_WORDS.contains(&word.as_str()) {
                true
            } else if FALSE_WORDS.contains(&word.as_str()) {
                false
            } else {
                return Err(EvalError::TypeMismatch {
                    detail: format!("cannot parse {s:?} as bool"),
                });
            }
        }
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!("cannot cast {:?} to bool", other.data_type()),
            });
        }
    };
    Ok(Value::Bool(flag))
}
/// Casts a value to `Value::Vector`.
///
/// `Null` and existing vectors pass through unchanged; text is parsed as a
/// bracketed, comma-separated list of numbers.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when the text is not a valid vector
/// literal or when the input is any other type.
pub fn cast_to_vector(v: Value) -> Result<Value, EvalError> {
    match v {
        Value::Null => Ok(Value::Null),
        Value::Vector(cells) => Ok(Value::Vector(cells)),
        Value::Text(s) => match parse_vector_text(&s) {
            Some(cells) => Ok(Value::Vector(cells)),
            None => Err(EvalError::TypeMismatch {
                detail: format!("cannot parse {s:?} as a vector literal"),
            }),
        },
        other => Err(EvalError::TypeMismatch {
            detail: format!("::vector requires text input, got {:?}", other.data_type()),
        }),
    }
}
/// Parses a `[a, b, c]` literal into its `f32` components.
///
/// Whitespace around the brackets and around each component is ignored, and
/// `[]` yields an empty vector. Returns `None` when the brackets are missing
/// or any component fails to parse as `f32`.
fn parse_vector_text(s: &str) -> Option<Vec<f32>> {
    let body = s.trim().strip_prefix('[')?.strip_suffix(']')?.trim();
    if body.is_empty() {
        return Some(Vec::new());
    }
    // Collecting into `Option` stops at the first component that fails.
    body.split(',')
        .map(|component| component.trim().parse::<f32>().ok())
        .collect()
}
/// Converts a parsed SQL literal into a runtime [`Value`].
///
/// Integer literals become `Int` when they fit in `i32` and `BigInt`
/// otherwise; every other literal maps to the variant of the same kind.
fn literal_to_value(l: &Literal) -> Value {
    match l {
        Literal::Null => Value::Null,
        Literal::Bool(flag) => Value::Bool(*flag),
        // Prefer the narrowest integer type that can hold the literal.
        Literal::Integer(n) => i32::try_from(*n).map_or(Value::BigInt(*n), Value::Int),
        Literal::Float(x) => Value::Float(*x),
        Literal::String(text) => Value::Text(text.clone()),
        Literal::Vector(cells) => Value::Vector(cells.clone()),
        Literal::Interval { months, micros, .. } => Value::Interval {
            months: *months,
            micros: *micros,
        },
    }
}
/// Looks up the value of a column reference in `row`.
///
/// Resolution order:
/// 1. For a qualified name, a schema column literally named
///    `qualifier.name`; failing that, the qualifier must equal the context's
///    table alias.
/// 2. A schema column whose name equals the bare column name.
/// 3. A single schema column whose name ends in `.name`.
///
/// # Errors
///
/// * [`EvalError::UnknownQualifier`] when the qualifier matches neither a
///   composite column nor the table alias.
/// * [`EvalError::TypeMismatch`] when several columns match by suffix.
/// * [`EvalError::ColumnNotFound`] when nothing matches.
fn resolve_column(c: &ColumnName, row: &Row, ctx: &EvalContext<'_>) -> Result<Value, EvalError> {
    if let Some(q) = &c.qualifier {
        let composite = format!("{q}.{name}", name = c.name);
        if let Some(pos) = ctx.columns.iter().position(|col| col.name == composite) {
            return Ok(row.values[pos].clone());
        }
        // No composite column: the qualifier has to be the table's alias.
        match ctx.table_alias {
            Some(alias) if q == alias => {}
            _ => {
                return Err(EvalError::UnknownQualifier {
                    qualifier: q.clone(),
                });
            }
        }
    }

    if let Some(pos) = ctx.columns.iter().position(|col| col.name == c.name) {
        return Ok(row.values[pos].clone());
    }

    // Last resort: a unique column of the form `<anything>.name`.
    let suffix = format!(".{name}", name = c.name);
    let mut candidates = ctx
        .columns
        .iter()
        .enumerate()
        .filter(|(_, col)| col.name.ends_with(&suffix))
        .map(|(pos, _)| pos);
    match (candidates.next(), candidates.next()) {
        (Some(pos), None) => Ok(row.values[pos].clone()),
        (Some(_), Some(_)) => Err(EvalError::TypeMismatch {
            detail: format!("ambiguous column reference: {}", c.name),
        }),
        (None, _) => Err(EvalError::ColumnNotFound {
            name: c.name.clone(),
        }),
    }
}
/// Applies a unary operator to an already evaluated operand.
///
/// `NULL` propagates through every operator. Negation is defined for `Int`,
/// `BigInt` and `Float`; `NOT` is defined for `Bool`.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when integer negation overflows or the
/// operand type is not supported by the operator.
fn apply_unary(op: UnOp, v: Value) -> Result<Value, EvalError> {
    if matches!(v, Value::Null) {
        return Ok(Value::Null);
    }
    match op {
        UnOp::Neg => match v {
            Value::Int(n) => n
                .checked_neg()
                .map(Value::Int)
                .ok_or_else(|| EvalError::TypeMismatch {
                    detail: "integer overflow on unary -".into(),
                }),
            Value::BigInt(n) => {
                n.checked_neg()
                    .map(Value::BigInt)
                    .ok_or_else(|| EvalError::TypeMismatch {
                        detail: "bigint overflow on unary -".into(),
                    })
            }
            Value::Float(x) => Ok(Value::Float(-x)),
            other => Err(EvalError::TypeMismatch {
                detail: format!("unary - applied to {:?}", other.data_type()),
            }),
        },
        UnOp::Not => match v {
            Value::Bool(flag) => Ok(Value::Bool(!flag)),
            other => Err(EvalError::TypeMismatch {
                detail: format!("NOT applied to {:?}", other.data_type()),
            }),
        },
    }
}
/// Implements `IS NOT DISTINCT FROM`: two `NULL`s are equal to each other, a
/// `NULL` never equals a non-`NULL`, and everything else uses `==`.
fn values_not_distinct(l: &Value, r: &Value) -> bool {
    let left_null = matches!(l, Value::Null);
    let right_null = matches!(r, Value::Null);
    if left_null || right_null {
        // Equal only when both sides are NULL.
        left_null && right_null
    } else {
        l == r
    }
}
/// Applies a binary operator to two already evaluated operands.
///
/// Dispatch order:
/// 1. `AND`, `OR` and the `IS [NOT] DISTINCT FROM` pair, which define their
///    own `NULL` handling.
/// 2. Any remaining operator with a `NULL` operand yields `NULL`.
/// 3. Operations involving a `Numeric` operand.
/// 4. Date, timestamp and interval arithmetic.
/// 5. The generic per-operator implementations.
///
/// # Errors
///
/// Propagates whatever error the selected implementation reports.
fn apply_binary(op: BinOp, l: Value, r: Value) -> Result<Value, EvalError> {
    match op {
        BinOp::And => return and_3vl(l, r),
        BinOp::Or => return or_3vl(l, r),
        BinOp::IsNotDistinctFrom => return Ok(Value::Bool(values_not_distinct(&l, &r))),
        BinOp::IsDistinctFrom => return Ok(Value::Bool(!values_not_distinct(&l, &r))),
        _ => {}
    }

    if l.is_null() || r.is_null() {
        return Ok(Value::Null);
    }

    let has_numeric_operand =
        matches!(l, Value::Numeric { .. }) || matches!(r, Value::Numeric { .. });
    if has_numeric_operand {
        return apply_binary_numeric(op, l, r);
    }

    if let Some(calendar_result) = apply_binary_calendar(op, &l, &r)? {
        return Ok(calendar_result);
    }

    match op {
        BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
            compare(op, &l, &r)
        }
        BinOp::Add => arith(l, r, i64::checked_add, |a, b| a + b, "+"),
        BinOp::Sub => arith(l, r, i64::checked_sub, |a, b| a - b, "-"),
        BinOp::Mul => arith(l, r, i64::checked_mul, |a, b| a * b, "*"),
        BinOp::Div => div_op(l, r),
        BinOp::Concat => Ok(text_concat(&l, &r)),
        BinOp::L2Distance => l2_distance(l, r),
        BinOp::InnerProduct => inner_product(l, r),
        BinOp::CosineDistance => cosine_distance(l, r),
        BinOp::JsonGet => crate::json::path_get(&l, &r, false),
        BinOp::JsonGetText => crate::json::path_get(&l, &r, true),
        BinOp::JsonGetPath => crate::json::path_walk(&l, &r, false),
        BinOp::JsonGetPathText => crate::json::path_walk(&l, &r, true),
        BinOp::JsonContains => crate::json::contains(&l, &r),
        BinOp::And | BinOp::Or | BinOp::IsDistinctFrom | BinOp::IsNotDistinctFrom => {
            unreachable!("handled above")
        }
    }
}
/// Handles arithmetic on dates, timestamps and intervals.
///
/// Returns `Ok(Some(value))` when the operand combination is handled here and
/// `Ok(None)` when the caller should fall back to the generic operators.
///
/// Covered cases, in order: `DATE - DATE` and `TIMESTAMP - TIMESTAMP` (both
/// yield `BigInt`), anything involving an interval, then `DATE + integer`,
/// `integer + DATE` and `DATE - integer` (all yield `Date`).
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when a result leaves the representable
/// range, or whatever the interval helper reports.
fn apply_binary_calendar(op: BinOp, l: &Value, r: &Value) -> Result<Option<Value>, EvalError> {
    let int_value = |v: &Value| -> Option<i64> {
        match v {
            Value::SmallInt(n) => Some(i64::from(*n)),
            Value::Int(n) => Some(i64::from(*n)),
            Value::BigInt(n) => Some(*n),
            _ => None,
        }
    };

    // Differences between two values of the same calendar type.
    if op == BinOp::Sub {
        match (l, r) {
            (Value::Date(a), Value::Date(b)) => {
                return Ok(Some(Value::BigInt(i64::from(*a) - i64::from(*b))));
            }
            (Value::Timestamp(a), Value::Timestamp(b)) => {
                let delta = a.checked_sub(*b).ok_or_else(|| EvalError::TypeMismatch {
                    detail: "TIMESTAMP - TIMESTAMP overflows i64 microseconds".into(),
                })?;
                return Ok(Some(Value::BigInt(delta)));
            }
            _ => {}
        }
    }

    if let Some(interval_result) = apply_binary_interval(op, l, r)? {
        return Ok(Some(interval_result));
    }

    // Date shifted by a whole number of days:
    // (base date, other operand, subtract?, overflow message).
    let day_shift: Option<(i32, &Value, bool, &'static str)> = match (l, r, op) {
        (Value::Date(d), other, BinOp::Add) => {
            Some((*d, other, false, "DATE + integer overflows DATE range"))
        }
        (other, Value::Date(d), BinOp::Add) => {
            Some((*d, other, false, "integer + DATE overflows DATE range"))
        }
        (Value::Date(d), other, BinOp::Sub) => {
            Some((*d, other, true, "DATE - integer overflows DATE range"))
        }
        _ => None,
    };
    if let Some((base, operand, subtract, overflow_detail)) = day_shift {
        if let Some(n) = int_value(operand) {
            let shifted = if subtract {
                i64::from(base).saturating_sub(n)
            } else {
                i64::from(base).saturating_add(n)
            };
            let days32 = i32::try_from(shifted).map_err(|_| EvalError::TypeMismatch {
                detail: overflow_detail.into(),
            })?;
            return Ok(Some(Value::Date(days32)));
        }
    }
    Ok(None)
}
/// Handles `+` and `-` when an interval is involved.
///
/// Returns `Ok(None)` unless the operation is `INTERVAL + x`, `x + INTERVAL`
/// or `x - INTERVAL`. Otherwise the non-interval side selects the result:
///
/// * `Timestamp` yields a `Timestamp`.
/// * `Date` yields a `Date` when the interval's microseconds are a whole
///   number of days, and a `Timestamp` otherwise.
/// * `Interval` yields an `Interval` (component-wise sum or difference).
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] on overflow or when the other operand
/// is not one of the types above.
fn apply_binary_interval(op: BinOp, l: &Value, r: &Value) -> Result<Option<Value>, EvalError> {
    const MICROS_PER_DAY: i64 = 86_400_000_000;

    /// Builds the error used for every overflow in this function.
    fn overflow(detail: &str) -> EvalError {
        EvalError::TypeMismatch {
            detail: detail.into(),
        }
    }

    // Normalise to `base (+|-) interval`; addition is commutative, so an
    // interval on the left of `+` simply swaps sides.
    let (base, interval, direction): (&Value, &Value, i64) = match (l, r, op) {
        (Value::Interval { .. }, _, BinOp::Add) => (r, l, 1),
        (_, Value::Interval { .. }, BinOp::Add) => (l, r, 1),
        (_, Value::Interval { .. }, BinOp::Sub) => (l, r, -1),
        _ => return Ok(None),
    };
    let Value::Interval { months, micros } = interval else {
        unreachable!("rhs guaranteed to be Interval by the match above");
    };
    let delta_months = i64::from(*months) * direction;
    let delta_micros = micros
        .checked_mul(direction)
        .ok_or_else(|| overflow("INTERVAL micros overflows on negation"))?;

    match base {
        Value::Timestamp(t) => add_interval_to_micros(*t, delta_months, delta_micros)
            .map(|shifted| Some(Value::Timestamp(shifted))),
        Value::Date(d) => {
            if delta_micros.rem_euclid(MICROS_PER_DAY) != 0 {
                // A sub-day component forces promotion to a timestamp.
                let midnight = i64::from(*d)
                    .checked_mul(MICROS_PER_DAY)
                    .ok_or_else(|| overflow("DATE → TIMESTAMP lift overflows for INTERVAL math"))?;
                return add_interval_to_micros(midnight, delta_months, delta_micros)
                    .map(|shifted| Some(Value::Timestamp(shifted)));
            }
            let whole_days = delta_micros / MICROS_PER_DAY;
            let month_shifted = shift_date_by_months(*d, delta_months)?;
            let days32 = i64::from(month_shifted)
                .checked_add(whole_days)
                .and_then(|total| i32::try_from(total).ok())
                .ok_or_else(|| overflow("DATE ± INTERVAL overflows DATE range"))?;
            Ok(Some(Value::Date(days32)))
        }
        Value::Interval {
            months: base_months,
            micros: base_micros,
        } => {
            let months = i64::from(*base_months)
                .checked_add(delta_months)
                .and_then(|total| i32::try_from(total).ok())
                .ok_or_else(|| overflow("INTERVAL ± INTERVAL months overflows i32"))?;
            let micros = base_micros
                .checked_add(delta_micros)
                .ok_or_else(|| overflow("INTERVAL ± INTERVAL micros overflows i64"))?;
            Ok(Some(Value::Interval { months, micros }))
        }
        _ => Err(EvalError::TypeMismatch {
            detail: format!(
                "operator {op:?} not defined for {:?} and INTERVAL",
                base.data_type()
            ),
        }),
    }
}
/// Moves a day count by a number of calendar months, clamping the day of month
/// to the length of the target month.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when `months` does not fit in `i32`.
fn shift_date_by_months(d: i32, months: i64) -> Result<i32, EvalError> {
    let month_delta = i32::try_from(months).map_err(|_| EvalError::TypeMismatch {
        detail: "INTERVAL months delta out of i32 range".into(),
    })?;
    let (year, month, day) = civil_from_days(d);
    let (new_year, new_month, new_day) = add_months_to_civil(year, month, day, month_delta);
    Ok(days_from_civil(new_year, new_month, new_day))
}
/// Adds an interval (`months` plus `micros`) to a timestamp expressed in
/// microseconds.
///
/// The month component is applied first, on the calendar date with the time of
/// day preserved; the microsecond component is added afterwards.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when any intermediate or final value
/// leaves the representable range.
fn add_interval_to_micros(t: i64, months: i64, micros: i64) -> Result<i64, EvalError> {
    const MICROS_PER_DAY: i64 = 86_400_000_000;

    let after_months = if months == 0 {
        t
    } else {
        // Split into calendar day and time of day; Euclidean division keeps
        // the time of day non-negative for pre-epoch timestamps.
        let day_number = t.div_euclid(MICROS_PER_DAY);
        let time_of_day = t.rem_euclid(MICROS_PER_DAY);
        let day_i32 = i32::try_from(day_number).map_err(|_| EvalError::TypeMismatch {
            detail: "TIMESTAMP day component out of i32 range for INTERVAL months math".into(),
        })?;
        let moved_day = shift_date_by_months(day_i32, months)?;
        i64::from(moved_day)
            .checked_mul(MICROS_PER_DAY)
            .and_then(|midnight| midnight.checked_add(time_of_day))
            .ok_or_else(|| EvalError::TypeMismatch {
                detail: "TIMESTAMP ± INTERVAL months overflows i64 microseconds".into(),
            })?
    };
    after_months
        .checked_add(micros)
        .ok_or_else(|| EvalError::TypeMismatch {
            detail: "TIMESTAMP ± INTERVAL micros overflows i64".into(),
        })
}
#[allow(clippy::needless_pass_by_value)] fn apply_binary_numeric(op: BinOp, l: Value, r: Value) -> Result<Value, EvalError> {
let float_path = matches!(l, Value::Float(_)) || matches!(r, Value::Float(_));
if float_path {
let af = as_f64(&l)?;
let bf = as_f64(&r)?;
return match op {
BinOp::Add => Ok(Value::Float(af + bf)),
BinOp::Sub => Ok(Value::Float(af - bf)),
BinOp::Mul => Ok(Value::Float(af * bf)),
BinOp::Div => {
if bf == 0.0 {
Err(EvalError::DivisionByZero)
} else {
Ok(Value::Float(af / bf))
}
}
BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
let ord = af.partial_cmp(&bf).ok_or(EvalError::TypeMismatch {
detail: "NaN in NUMERIC/Float comparison".into(),
})?;
Ok(Value::Bool(cmp_to_bool(op, ord)))
}
BinOp::Concat => Ok(text_concat(&l, &r)),
other => Err(EvalError::TypeMismatch {
detail: format!("operator {other:?} not defined for NUMERIC and Float"),
}),
};
}
let (a, sa) = numeric_or_widen(&l).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("NUMERIC op against non-numeric {:?}", l.data_type()),
})?;
let (b, sb) = numeric_or_widen(&r).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("NUMERIC op against non-numeric {:?}", r.data_type()),
})?;
match op {
BinOp::Add | BinOp::Sub => {
let target_scale = sa.max(sb);
let lhs = rescale(a, sa, target_scale).ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on rescale".into(),
})?;
let rhs = rescale(b, sb, target_scale).ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on rescale".into(),
})?;
let r = match op {
BinOp::Add => lhs.checked_add(rhs),
BinOp::Sub => lhs.checked_sub(rhs),
_ => unreachable!(),
}
.ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on +/-".into(),
})?;
Ok(Value::Numeric {
scaled: r,
scale: target_scale,
})
}
BinOp::Mul => {
let scaled = a.checked_mul(b).ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on *".into(),
})?;
Ok(Value::Numeric {
scaled,
scale: sa.saturating_add(sb),
})
}
BinOp::Div => {
if b == 0 {
return Err(EvalError::DivisionByZero);
}
let target_scale = sa.max(sb);
let bump = pow10_i128(target_scale.saturating_add(sb).saturating_sub(sa));
let num = a.checked_mul(bump).ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on / scaling".into(),
})?;
let half = if b >= 0 { b / 2 } else { -(b / 2) };
let adj = if (num >= 0) == (b >= 0) {
num + half
} else {
num - half
};
Ok(Value::Numeric {
scaled: adj / b,
scale: target_scale,
})
}
BinOp::Eq | BinOp::NotEq | BinOp::Lt | BinOp::LtEq | BinOp::Gt | BinOp::GtEq => {
let target_scale = sa.max(sb);
let lhs = rescale(a, sa, target_scale).ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on rescale".into(),
})?;
let rhs = rescale(b, sb, target_scale).ok_or(EvalError::TypeMismatch {
detail: "NUMERIC overflow on rescale".into(),
})?;
Ok(Value::Bool(cmp_to_bool(op, lhs.cmp(&rhs))))
}
BinOp::Concat => Ok(text_concat(&l, &r)),
other => Err(EvalError::TypeMismatch {
detail: format!("operator {other:?} not defined for NUMERIC"),
}),
}
}
/// Views a value as a fixed-point `(scaled, scale)` pair.
///
/// `Numeric` values are returned as stored; integer types are widened to
/// `i128` with scale 0. Every other type yields `None`.
fn numeric_or_widen(v: &Value) -> Option<(i128, u8)> {
    let whole = match v {
        Value::Numeric { scaled, scale } => return Some((*scaled, *scale)),
        Value::SmallInt(n) => i128::from(*n),
        Value::Int(n) => i128::from(*n),
        Value::BigInt(n) => i128::from(*n),
        _ => return None,
    };
    Some((whole, 0))
}
/// Re-expresses the fixed-point value `scaled / 10^src` at scale `dst`.
///
/// Increasing the scale multiplies by a power of ten and returns `None` if the
/// result does not fit in `i128`. Decreasing the scale divides and rounds half
/// away from zero; that direction always succeeds.
///
/// Rounding is derived from the quotient and remainder rather than by adding
/// half the divisor to the input, so values near the `i128` limits cannot
/// overflow. Scale gaps above 38 digits, whose power of ten does not fit in
/// `i128`, are handled explicitly.
fn rescale(scaled: i128, src: u8, dst: u8) -> Option<i128> {
    if src == dst {
        return Some(scaled);
    }
    if dst > src {
        match 10_i128.checked_pow(u32::from(dst - src)) {
            Some(factor) => scaled.checked_mul(factor),
            // The factor itself overflows: only zero survives the scaling.
            None if scaled == 0 => Some(0),
            None => None,
        }
    } else {
        // A divisor beyond i128 is more than twice any representable
        // magnitude, so the quotient rounds to zero.
        let Some(divisor) = 10_i128.checked_pow(u32::from(src - dst)) else {
            return Some(0);
        };
        let half = divisor / 2;
        let quotient = scaled / divisor;
        let remainder = scaled % divisor;
        // The remainder carries the sign of `scaled`; |quotient| is at most
        // i128::MAX / 10, so the +/- 1 adjustment cannot overflow.
        Some(if remainder >= half {
            quotient + 1
        } else if remainder <= -half {
            quotient - 1
        } else {
            quotient
        })
    }
}
/// Computes `10^p` as an `i128` by repeated multiplication.
///
/// The result only fits for `p <= 38`; larger exponents overflow `i128`.
const fn pow10_i128(p: u8) -> i128 {
    let mut result: i128 = 1;
    let mut remaining = p;
    while remaining > 0 {
        result *= 10;
        remaining -= 1;
    }
    result
}
/// Translates an ordering into the boolean result of a comparison operator.
///
/// Operators that are not comparisons always yield `false`.
const fn cmp_to_bool(op: BinOp, ord: core::cmp::Ordering) -> bool {
    use core::cmp::Ordering::{Equal, Greater, Less};
    // For each ordering, list the operators it satisfies.
    match ord {
        Less => matches!(op, BinOp::Lt | BinOp::LtEq | BinOp::NotEq),
        Equal => matches!(op, BinOp::Eq | BinOp::LtEq | BinOp::GtEq),
        Greater => matches!(op, BinOp::Gt | BinOp::GtEq | BinOp::NotEq),
    }
}
/// Implements the `||` operator.
///
/// * `NULL` combined with a text array yields `NULL`.
/// * Array with array appends; array with text appends or prepends the text as
///   a new element.
/// * Every other combination concatenates the textual forms of both operands.
fn text_concat(l: &Value, r: &Value) -> Value {
    let is_null = |v: &Value| matches!(v, Value::Null);
    let is_array = |v: &Value| matches!(v, Value::TextArray(_));

    if is_null(l) || is_null(r) {
        if is_array(l) || is_array(r) {
            return Value::Null;
        }
    } else {
        match (l, r) {
            (Value::TextArray(head), Value::TextArray(tail)) => {
                let mut joined = head.clone();
                joined.extend(tail.iter().cloned());
                return Value::TextArray(joined);
            }
            (Value::TextArray(head), Value::Text(last)) => {
                let mut joined = head.clone();
                joined.push(Some(last.clone()));
                return Value::TextArray(joined);
            }
            (Value::Text(first), Value::TextArray(tail)) => {
                let joined: Vec<Option<String>> = core::iter::once(Some(first.clone()))
                    .chain(tail.iter().cloned())
                    .collect();
                return Value::TextArray(joined);
            }
            _ => {}
        }
    }

    let mut text = value_to_text(l);
    text.push_str(&value_to_text(r));
    Value::Text(text)
}
/// Implements `<#>`: the negated dot product of two equally sized vectors,
/// accumulated in `f64`.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when either operand is not a vector or
/// the dimensions differ.
fn inner_product(l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = unwrap_vec_pair(l, r, "<#>")?;
    let dot = lhs
        .iter()
        .zip(rhs.iter())
        .fold(0.0_f64, |acc, (x, y)| acc + f64::from(*x) * f64::from(*y));
    Ok(Value::Float(-dot))
}
/// Implements `<=>`: one minus the cosine similarity of two equally sized
/// vectors.
///
/// Yields NaN when the product of the two norms is zero.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when either operand is not a vector or
/// the dimensions differ.
fn cosine_distance(l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = unwrap_vec_pair(l, r, "<=>")?;
    // One pass accumulates the dot product and both squared norms.
    let (dot, norm_sq_l, norm_sq_r) = lhs.iter().zip(rhs.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, nl, nr), (x, y)| {
            let (xf, yf) = (f64::from(*x), f64::from(*y));
            (dot + xf * yf, nl + xf * xf, nr + yf * yf)
        },
    );
    let norm_product = sqrt_newton(norm_sq_l) * sqrt_newton(norm_sq_r);
    let distance = if norm_product == 0.0 {
        f64::NAN
    } else {
        1.0 - dot / norm_product
    };
    Ok(Value::Float(distance))
}
/// Extracts two operands as `f32` vectors of equal length.
///
/// Quantized and half-precision vectors are expanded to `f32`. `op` is the
/// operator symbol, used only in error messages.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when either operand is not a vector or
/// the two lengths differ.
fn unwrap_vec_pair(l: Value, r: Value, op: &str) -> Result<(Vec<f32>, Vec<f32>), EvalError> {
    /// Converts any vector flavour to `f32` components.
    fn into_components(v: Value) -> Option<Vec<f32>> {
        match v {
            Value::Vector(cells) => Some(cells),
            Value::Sq8Vector(q) => Some(spg_storage::quantize::dequantize(&q)),
            Value::HalfVector(h) => Some(h.to_f32_vec()),
            _ => None,
        }
    }

    // Capture the types first: the conversion consumes the operands.
    let (left_type, right_type) = (l.data_type(), r.data_type());
    let (Some(lhs), Some(rhs)) = (into_components(l), into_components(r)) else {
        return Err(EvalError::TypeMismatch {
            detail: format!("{op} requires two vectors, got {left_type:?} and {right_type:?}"),
        });
    };
    if lhs.len() != rhs.len() {
        return Err(EvalError::TypeMismatch {
            detail: format!(
                "vector dim mismatch in {op}: {} vs {}",
                lhs.len(),
                rhs.len()
            ),
        });
    }
    Ok((lhs, rhs))
}
/// Applies an arithmetic operator to two scalar operands.
///
/// `SmallInt` operands are widened to `Int` first. Integer pairs are computed
/// in `i64` via `int_op`; an `Int`/`Int` result is narrowed back to `Int` when
/// it fits, otherwise it is returned as `BigInt`. If either operand is a
/// `Float`, both are converted to `f64` and `float_op` is used.
///
/// Operand order is preserved for every combination, so non-commutative
/// operators behave correctly for mixed `Int` / `BigInt` inputs.
///
/// `op_name` is the operator symbol used in error messages.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when `int_op` reports overflow (returns
/// `None`) or when the operands are not numeric.
fn arith(
    l: Value,
    r: Value,
    int_op: impl Fn(i64, i64) -> Option<i64>,
    float_op: impl Fn(f64, f64) -> f64,
    op_name: &str,
) -> Result<Value, EvalError> {
    let widen = |v: Value| -> Value {
        match v {
            Value::SmallInt(n) => Value::Int(i32::from(n)),
            other => other,
        }
    };
    // Shared path for every combination that produces a BigInt.
    let bigint = |a: i64, b: i64| -> Result<Value, EvalError> {
        int_op(a, b)
            .map(Value::BigInt)
            .ok_or_else(|| EvalError::TypeMismatch {
                detail: format!("bigint overflow on {op_name}"),
            })
    };

    match (widen(l), widen(r)) {
        (Value::Int(a), Value::Int(b)) => {
            let result =
                int_op(i64::from(a), i64::from(b)).ok_or_else(|| EvalError::TypeMismatch {
                    detail: format!("integer overflow on {op_name}"),
                })?;
            // Stay in Int when the result fits, otherwise promote.
            Ok(i32::try_from(result).map_or(Value::BigInt(result), Value::Int))
        }
        // The two mixed cases are kept separate so that the left operand is
        // always passed first.
        (Value::Int(a), Value::BigInt(b)) => bigint(i64::from(a), b),
        (Value::BigInt(a), Value::Int(b)) => bigint(a, i64::from(b)),
        (Value::BigInt(a), Value::BigInt(b)) => bigint(a, b),
        (a, b)
            if a.data_type() == Some(DataType::Float) || b.data_type() == Some(DataType::Float) =>
        {
            let af = as_f64(&a)?;
            let bf = as_f64(&b)?;
            Ok(Value::Float(float_op(af, bf)))
        }
        (a, b) => Err(EvalError::TypeMismatch {
            detail: format!(
                "{op_name} applied to non-numeric: {:?} vs {:?}",
                a.data_type(),
                b.data_type()
            ),
        }),
    }
}
/// Implements `<->`: the Euclidean distance between two equally sized vectors,
/// accumulated in `f64`.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when either operand is not a vector or
/// the dimensions differ.
#[allow(clippy::many_single_char_names)]
fn l2_distance(l: Value, r: Value) -> Result<Value, EvalError> {
    let (lhs, rhs) = unwrap_vec_pair(l, r, "<->")?;
    let squared_sum = lhs.iter().zip(rhs.iter()).fold(0.0_f64, |acc, (x, y)| {
        let delta = f64::from(*x) - f64::from(*y);
        acc + delta * delta
    });
    Ok(Value::Float(sqrt_newton(squared_sum)))
}
/// Computes the square root of `x` with Newton's method.
///
/// Zero and negative inputs yield `0.0`; NaN and positive infinity are
/// returned unchanged.
///
/// The starting guess is obtained by halving the floating-point exponent, so
/// it is already close to the root for normal inputs. Iteration continues
/// until the estimate stops changing, with a fixed cap as a safety net. A
/// fixed small number of steps starting from `x` itself does not converge for
/// large or very small inputs.
fn sqrt_newton(x: f64) -> f64 {
    if x <= 0.0 {
        return 0.0;
    }
    if !x.is_finite() {
        // Covers NaN and +infinity (negative infinity was handled above).
        return x;
    }
    // Halving the bit pattern halves the exponent; the added constant restores
    // half of the exponent bias that the shift removed.
    let mut guess = f64::from_bits((x.to_bits() >> 1) + (0x3ff_u64 << 51));
    // Subnormal inputs start further from the root and need more steps; the
    // cap also guards against the estimate flipping between two neighbours.
    for _ in 0..64 {
        let next = 0.5 * (guess + x / guess);
        if next == guess {
            break;
        }
        guess = next;
    }
    guess
}
/// Implements `/` for non-`Numeric` operands.
///
/// If either operand is a `Float` the division is carried out in `f64`.
/// Otherwise it is integer division, delegated to `arith`.
///
/// # Errors
///
/// * [`EvalError::DivisionByZero`] when the divisor is zero (float, or integer
///   with an integer dividend).
/// * [`EvalError::TypeMismatch`] when the operands are not numeric, or when an
///   integer quotient overflows (`i64::MIN / -1`). Type errors are no longer
///   reported as division by zero.
fn div_op(l: Value, r: Value) -> Result<Value, EvalError> {
    let any_float = matches!(l.data_type(), Some(DataType::Float))
        || matches!(r.data_type(), Some(DataType::Float));
    if any_float {
        let a = as_f64(&l)?;
        let b = as_f64(&r)?;
        if b == 0.0 {
            return Err(EvalError::DivisionByZero);
        }
        return Ok(Value::Float(a / b));
    }

    // Detect a zero divisor up front instead of inferring it from the text of
    // an error message, which also matched unrelated type errors.
    let dividend_is_integer = matches!(
        l,
        Value::SmallInt(_) | Value::Int(_) | Value::BigInt(_)
    );
    let divisor_is_zero = matches!(r, Value::SmallInt(0) | Value::Int(0) | Value::BigInt(0));
    if dividend_is_integer && divisor_is_zero {
        return Err(EvalError::DivisionByZero);
    }

    // `checked_div` reports the one remaining failure, i64::MIN / -1, as an
    // overflow instead of panicking.
    arith(l, r, i64::checked_div, |a, b| a / b, "/")
}
/// Converts a numeric value to `f64`.
///
/// Integers are converted directly (`BigInt` may lose precision). A `Numeric`
/// is its scaled integer divided by `10^scale`, where the power of ten is
/// built up by repeated multiplication in `f64`.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] for non-numeric values.
fn as_f64(v: &Value) -> Result<f64, EvalError> {
    let converted = match v {
        Value::Float(x) => *x,
        Value::SmallInt(n) => f64::from(*n),
        Value::Int(n) => f64::from(*n),
        #[allow(clippy::cast_precision_loss)]
        Value::BigInt(n) => *n as f64,
        #[allow(clippy::cast_precision_loss)]
        Value::Numeric { scaled, scale } => {
            let divisor = (0..*scale).fold(1.0_f64, |acc, _| acc * 10.0);
            (*scaled as f64) / divisor
        }
        other => {
            return Err(EvalError::TypeMismatch {
                detail: format!("cannot convert {:?} to FLOAT", other.data_type()),
            });
        }
    };
    Ok(converted)
}
fn compare(op: BinOp, l: &Value, r: &Value) -> Result<Value, EvalError> {
let ord = match (l, r) {
(Value::Int(a), Value::Int(b)) => i64::from(*a).cmp(&i64::from(*b)),
(Value::Int(a), Value::BigInt(b)) => i64::from(*a).cmp(b),
(Value::BigInt(a), Value::Int(b)) => a.cmp(&i64::from(*b)),
(Value::BigInt(a), Value::BigInt(b)) => a.cmp(b),
(a, b)
if matches!(a.data_type(), Some(DataType::Float))
|| matches!(b.data_type(), Some(DataType::Float)) =>
{
let af = as_f64(a)?;
let bf = as_f64(b)?;
af.partial_cmp(&bf).ok_or(EvalError::TypeMismatch {
detail: "NaN in comparison".into(),
})?
}
(Value::Text(a), Value::Text(b)) => a.cmp(b),
(Value::Bool(a), Value::Bool(b)) => a.cmp(b),
(Value::Date(a), Value::Date(b)) => a.cmp(b),
(Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b),
(Value::Date(a), Value::Timestamp(b)) => (i64::from(*a) * 86_400_000_000).cmp(b),
(Value::Timestamp(a), Value::Date(b)) => a.cmp(&(i64::from(*b) * 86_400_000_000)),
(Value::Date(a), Value::Text(b)) => {
let bd = parse_date_literal(b).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("cannot parse {b:?} as DATE for comparison"),
})?;
a.cmp(&bd)
}
(Value::Text(a), Value::Date(b)) => {
let ad = parse_date_literal(a).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("cannot parse {a:?} as DATE for comparison"),
})?;
ad.cmp(b)
}
(Value::Timestamp(a), Value::Text(b)) => {
let bt = parse_timestamp_literal(b).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("cannot parse {b:?} as TIMESTAMP for comparison"),
})?;
a.cmp(&bt)
}
(Value::Text(a), Value::Timestamp(b)) => {
let at = parse_timestamp_literal(a).ok_or_else(|| EvalError::TypeMismatch {
detail: format!("cannot parse {a:?} as TIMESTAMP for comparison"),
})?;
at.cmp(b)
}
(a, b) => {
return Err(EvalError::TypeMismatch {
detail: format!(
"comparison between {:?} and {:?}",
a.data_type(),
b.data_type()
),
});
}
};
let result = match op {
BinOp::Eq => ord.is_eq(),
BinOp::NotEq => !ord.is_eq(),
BinOp::Lt => ord.is_lt(),
BinOp::LtEq => ord.is_le(),
BinOp::Gt => ord.is_gt(),
BinOp::GtEq => ord.is_ge(),
BinOp::And
| BinOp::Or
| BinOp::Add
| BinOp::Sub
| BinOp::Mul
| BinOp::Div
| BinOp::L2Distance
| BinOp::InnerProduct
| BinOp::CosineDistance
| BinOp::Concat
| BinOp::JsonGet
| BinOp::JsonGetText
| BinOp::JsonGetPath
| BinOp::JsonGetPathText
| BinOp::JsonContains
| BinOp::IsDistinctFrom
| BinOp::IsNotDistinctFrom => {
unreachable!("compare() only called with comparison ops")
}
};
Ok(Value::Bool(result))
}
/// Three-valued logical `AND`.
///
/// A `false` on either side wins regardless of the other operand; `true AND
/// true` is `true`; otherwise a `NULL` on either side yields `NULL`.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when none of the rules above apply,
/// i.e. a non-boolean operand is involved.
fn and_3vl(l: Value, r: Value) -> Result<Value, EvalError> {
    let is_false = |v: &Value| matches!(v, Value::Bool(false));
    let is_true = |v: &Value| matches!(v, Value::Bool(true));
    let is_null = |v: &Value| matches!(v, Value::Null);

    if is_false(&l) || is_false(&r) {
        return Ok(Value::Bool(false));
    }
    if is_true(&l) && is_true(&r) {
        return Ok(Value::Bool(true));
    }
    if is_null(&l) || is_null(&r) {
        return Ok(Value::Null);
    }
    Err(EvalError::TypeMismatch {
        detail: format!(
            "AND on non-boolean: {:?} and {:?}",
            l.data_type(),
            r.data_type()
        ),
    })
}
/// Three-valued logical `OR`.
///
/// A `true` on either side wins regardless of the other operand; `false OR
/// false` is `false`; otherwise a `NULL` on either side yields `NULL`.
///
/// # Errors
///
/// Returns [`EvalError::TypeMismatch`] when none of the rules above apply,
/// i.e. a non-boolean operand is involved.
fn or_3vl(l: Value, r: Value) -> Result<Value, EvalError> {
    let is_true = |v: &Value| matches!(v, Value::Bool(true));
    let is_false = |v: &Value| matches!(v, Value::Bool(false));
    let is_null = |v: &Value| matches!(v, Value::Null);

    if is_true(&l) || is_true(&r) {
        return Ok(Value::Bool(true));
    }
    if is_false(&l) && is_false(&r) {
        return Ok(Value::Bool(false));
    }
    if is_null(&l) || is_null(&r) {
        return Ok(Value::Null);
    }
    Err(EvalError::TypeMismatch {
        detail: format!(
            "OR on non-boolean: {:?} and {:?}",
            l.data_type(),
            r.data_type()
        ),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use spg_storage::{ColumnSchema, Row};

    // ---- fixture helpers -------------------------------------------------

    /// Builds a column schema entry; the trailing `true` is forwarded to
    /// `ColumnSchema::new` unchanged (presumably the nullable flag -- confirm
    /// against `spg_storage`).
    fn column(name: &str, ty: DataType) -> ColumnSchema {
        ColumnSchema::new(name, ty, true)
    }

    /// Integer literal expression.
    fn int(n: i64) -> Expr {
        Expr::Literal(Literal::Integer(n))
    }

    /// Boolean literal expression.
    fn boolean(b: bool) -> Expr {
        Expr::Literal(Literal::Bool(b))
    }

    /// String literal expression.
    fn text(s: &str) -> Expr {
        Expr::Literal(Literal::String(s.into()))
    }

    /// `NULL` literal expression.
    fn null_lit() -> Expr {
        Expr::Literal(Literal::Null)
    }

    /// Unqualified column reference.
    fn column_ref(name: &str) -> Expr {
        Expr::Column(ColumnName {
            qualifier: None,
            name: name.into(),
        })
    }

    /// Binary expression node with both operands boxed.
    fn binary(lhs: Expr, op: BinOp, rhs: Expr) -> Expr {
        Expr::Binary {
            lhs: alloc::boxed::Box::new(lhs),
            op,
            rhs: alloc::boxed::Box::new(rhs),
        }
    }

    /// Evaluates `expr` against an empty row, an empty schema and no alias.
    fn eval_const(expr: &Expr) -> Result<Value, EvalError> {
        let no_columns: [ColumnSchema; 0] = [];
        let empty_row = Row::new(vec![]);
        eval_expr(expr, &empty_row, &EvalContext::new(&no_columns, None))
    }

    /// Resolves `<qualifier>.a` against a one-column row (`a = 5`) whose
    /// table alias is `u`.
    fn lookup_a_qualified(qualifier: &str) -> Result<Value, EvalError> {
        let schema = [column("a", DataType::Int)];
        let row = Row::new(vec![Value::Int(5)]);
        let target = Expr::Column(ColumnName {
            qualifier: Some(qualifier.into()),
            name: "a".into(),
        });
        eval_expr(&target, &row, &EvalContext::new(&schema, Some("u")))
    }

    // ---- literals and column resolution ----------------------------------

    #[test]
    fn literal_evaluates_to_value() {
        assert_eq!(eval_const(&int(42)), Ok(Value::Int(42)));
        assert_eq!(
            eval_const(&Expr::Literal(Literal::Float(1.5))),
            Ok(Value::Float(1.5))
        );
        assert_eq!(eval_const(&null_lit()), Ok(Value::Null));
    }

    #[test]
    fn column_lookup_unqualified() {
        let schema = [column("a", DataType::Int), column("b", DataType::Text)];
        let row = Row::new(vec![Value::Int(7), Value::Text("hi".into())]);
        let ctx = EvalContext::new(&schema, None);
        assert_eq!(eval_expr(&column_ref("a"), &row, &ctx), Ok(Value::Int(7)));
        assert_eq!(
            eval_expr(&column_ref("b"), &row, &ctx),
            Ok(Value::Text("hi".into()))
        );
    }

    #[test]
    fn column_not_found_errors() {
        let schema = [column("a", DataType::Int)];
        let row = Row::new(vec![Value::Int(0)]);
        let ctx = EvalContext::new(&schema, None);
        assert!(matches!(
            eval_expr(&column_ref("ghost"), &row, &ctx),
            Err(EvalError::ColumnNotFound { ref name }) if name == "ghost"
        ));
    }

    #[test]
    fn qualified_column_matches_alias() {
        assert_eq!(lookup_a_qualified("u"), Ok(Value::Int(5)));
    }

    #[test]
    fn qualified_column_unknown_alias_errors() {
        assert!(matches!(
            lookup_a_qualified("x"),
            Err(EvalError::UnknownQualifier { .. })
        ));
    }

    // ---- arithmetic and comparison ----------------------------------------

    #[test]
    fn arithmetic_with_widening() {
        let sum = binary(int(2), BinOp::Add, Expr::Literal(Literal::Float(0.5)));
        assert_eq!(eval_const(&sum), Ok(Value::Float(2.5)));
    }

    #[test]
    fn division_by_zero_errors() {
        let quotient = binary(int(1), BinOp::Div, int(0));
        assert_eq!(eval_const(&quotient), Err(EvalError::DivisionByZero));
    }

    #[test]
    fn comparison_returns_bool() {
        let less = binary(int(1), BinOp::Lt, int(2));
        assert_eq!(eval_const(&less), Ok(Value::Bool(true)));
    }

    #[test]
    fn null_propagates_through_arithmetic() {
        let sum = binary(int(1), BinOp::Add, null_lit());
        assert_eq!(eval_const(&sum), Ok(Value::Null));
    }

    // ---- three-valued logic ------------------------------------------------

    #[test]
    fn and_three_valued_logic() {
        let and = |lhs: bool, rhs: Expr| eval_const(&binary(boolean(lhs), BinOp::And, rhs));
        // FALSE AND NULL, TRUE AND NULL, TRUE AND TRUE.
        assert_eq!(and(false, null_lit()), Ok(Value::Bool(false)));
        assert_eq!(and(true, null_lit()), Ok(Value::Null));
        assert_eq!(and(true, boolean(true)), Ok(Value::Bool(true)));
    }

    #[test]
    fn or_three_valued_logic() {
        let or_null = |lhs: bool| eval_const(&binary(boolean(lhs), BinOp::Or, null_lit()));
        // TRUE OR NULL, FALSE OR NULL.
        assert_eq!(or_null(true), Ok(Value::Bool(true)));
        assert_eq!(or_null(false), Ok(Value::Null));
    }

    #[test]
    fn not_on_null_is_null() {
        let negated = Expr::Unary {
            op: UnOp::Not,
            expr: alloc::boxed::Box::new(null_lit()),
        };
        assert_eq!(eval_const(&negated), Ok(Value::Null));
    }

    #[test]
    fn text_comparison_lexicographic() {
        let less = binary(text("apple"), BinOp::Lt, text("banana"));
        assert_eq!(eval_const(&less), Ok(Value::Bool(true)));
    }

    // ---- interval formatting and date arithmetic ---------------------------

    #[test]
    fn interval_format_basics() {
        // (months, micros, expected rendering)
        let cases = [
            (0, 0, "0"),
            (0, 86_400_000_000, "1 day"),
            (0, -86_400_000_000, "-1 days"),
            (0, 3_600_000_000, "01:00:00"),
            (0, 86_400_000_000 + 9_000_000, "1 day 00:00:09"),
            (14, 0, "1 year 2 mons"),
            (-1, 0, "-1 mons"),
        ];
        for &(months, micros, expected) in &cases {
            assert_eq!(format_interval(months, micros), expected);
        }
    }

    #[test]
    fn interval_add_to_timestamp_micros_part() {
        let midnight = i64::from(days_from_civil(2024, 1, 1)) * 86_400_000_000;
        let shifted = add_interval_to_micros(midnight, 0, 3_600_000_000).unwrap();
        assert_eq!(shifted, midnight + 3_600_000_000);
    }

    #[test]
    fn interval_clamp_month_end() {
        // (start date, month delta, expected date)
        let cases = [
            ((2024, 1, 31), 1, (2024, 2, 29)),
            ((2023, 1, 31), 1, (2023, 2, 28)),
            ((2024, 3, 31), -1, (2024, 2, 29)),
        ];
        for &((y, m, d), delta, expected) in &cases {
            let moved = shift_date_by_months(days_from_civil(y, m, d), delta).unwrap();
            assert_eq!(civil_from_days(moved), expected);
        }
    }

    #[test]
    fn interval_date_plus_pure_days_stays_date() {
        let start = Value::Date(days_from_civil(2024, 6, 1));
        let one_week = Value::Interval {
            months: 0,
            micros: 7 * 86_400_000_000,
        };
        let sum = apply_binary_interval(BinOp::Add, &start, &one_week)
            .unwrap()
            .unwrap();
        assert_eq!(sum, Value::Date(days_from_civil(2024, 6, 8)));
    }

    #[test]
    fn interval_date_plus_sub_day_lifts_to_timestamp() {
        let start_days = days_from_civil(2024, 6, 1);
        let one_hour = Value::Interval {
            months: 0,
            micros: 3_600_000_000,
        };
        let sum = apply_binary_interval(BinOp::Add, &Value::Date(start_days), &one_hour)
            .unwrap()
            .unwrap();
        assert_eq!(
            sum,
            Value::Timestamp(i64::from(start_days) * 86_400_000_000 + 3_600_000_000)
        );
    }
}