use super::cast_catalog::{find_cast, CastContext, CastEntry};
use super::function_catalog::{FunctionEntry, FUNCTION_CATALOG};
use super::operator_catalog::{OperatorEntry, OperatorKind, OPERATOR_CATALOG};
use super::types::DataType;
use crate::storage::query::ast::BinOp;
/// Casts chosen for each operand of a resolved operator or function call.
///
/// Slot `i` corresponds to operand `i`: `Some(target)` means the operand must
/// be cast to `target` before the call, `None` means it is used unchanged.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct OperandCoercions {
    /// One entry per operand, in call order.
    pub casts: Vec<Option<DataType>>,
}

impl OperandCoercions {
    /// Builds coercions for `arity` operands, none of which needs a cast.
    pub fn identity(arity: usize) -> Self {
        let casts = (0..arity).map(|_| None).collect();
        Self { casts }
    }

    /// `true` when no operand requires a cast (trivially `true` for zero operands).
    pub fn is_identity(&self) -> bool {
        !self.casts.iter().any(Option::is_some)
    }

    /// Cast target for operand `idx`.
    ///
    /// Returns `None` both when that operand needs no cast and when `idx` is
    /// out of range.
    pub fn at(&self, idx: usize) -> Option<DataType> {
        self.casts.get(idx).and_then(|slot| *slot)
    }
}
/// Type-resolution hooks that pick casts, infix operator overloads and
/// function overloads for given operand types.
pub trait CoercionSpine {
    /// Returns the catalog cast entry to use when converting `from` into `to`,
    /// or `None` if this spine does not permit one.
    fn resolve_cast(&self, from: DataType, to: DataType) -> Option<&'static CastEntry>;
    /// Picks the operator overload for `lhs op rhs`, together with the casts
    /// (one slot per operand; `None` = use as-is) needed to call it.
    fn resolve_binop(
        &self,
        op: BinOp,
        lhs: DataType,
        rhs: DataType,
    ) -> Option<(&'static OperatorEntry, OperandCoercions)>;
    /// Picks the overload of function `name` for argument types `args`,
    /// together with the per-argument casts needed to call it.
    fn resolve_function(
        &self,
        name: &str,
        args: &[DataType],
    ) -> Option<(&'static FunctionEntry, OperandCoercions)>;
}
/// Stateless [`CoercionSpine`] that resolves against the built-in static
/// catalogs (`CAST_CATALOG`, `OPERATOR_CATALOG`, `FUNCTION_CATALOG`).
#[derive(Debug, Default, Clone, Copy)]
pub struct BuiltinSpine;
impl CoercionSpine for BuiltinSpine {
    /// Finds a cast from `from` to `to` that the cast catalog permits in an
    /// implicit context.
    ///
    /// Returns `None` when `from == to` (no cast is needed) as well as when no
    /// suitable catalog entry exists; callers that must tell those cases apart
    /// should compare the types first.
    fn resolve_cast(&self, from: DataType, to: DataType) -> Option<&'static CastEntry> {
        if from == to {
            return None;
        }
        // NOTE(review): this filters on `entry.context.allows(Implicit)` directly,
        // whereas the operator/function resolvers call
        // `find_cast(.., CastContext::Implicit)`. The unit tests show the two
        // disagree for Integer -> Text; confirm which rule is intended.
        super::cast_catalog::CAST_CATALOG
            .iter()
            .find(|e| e.src == from && e.target == to && e.context.allows(CastContext::Implicit))
    }

    /// Resolves `lhs op rhs` to an infix operator overload.
    ///
    /// An overload declared for exactly `(lhs, rhs)` is returned immediately
    /// with identity coercions. Otherwise every overload whose operand types
    /// are reachable through implicit casts is ranked with `outranks`. The
    /// returned coercions always have two slots: lhs, then rhs.
    fn resolve_binop(
        &self,
        op: BinOp,
        lhs: DataType,
        rhs: DataType,
    ) -> Option<(&'static OperatorEntry, OperandCoercions)> {
        let symbol = binop_symbol(op);
        if let Some(entry) =
            infix_candidates(symbol).find(|e| e.lhs_type == lhs && e.rhs_type == rhs)
        {
            return Some((entry, OperandCoercions::identity(2)));
        }
        let operands = [lhs, rhs];
        let mut best: Option<(usize, &'static OperatorEntry, OperandCoercions)> = None;
        for entry in infix_candidates(symbol) {
            let targets = [entry.lhs_type, entry.rhs_type];
            if let Some((coercions, score)) = coerce_operands(targets.iter().copied(), &operands)
            {
                let replace = match &best {
                    None => true,
                    Some((prev_score, prev_entry, _)) => outranks(
                        score,
                        entry.return_type,
                        *prev_score,
                        prev_entry.return_type,
                    ),
                };
                if replace {
                    best = Some((score, entry, coercions));
                }
            }
        }
        best.map(|(_, e, c)| (e, c))
    }

    /// Resolves a call to `name` (matched case-insensitively) with argument
    /// types `args`.
    ///
    /// Each catalog overload is bound with `match_function_args`; the
    /// candidates are ranked with `outranks`. The returned coercions have one
    /// slot per argument.
    fn resolve_function(
        &self,
        name: &str,
        args: &[DataType],
    ) -> Option<(&'static FunctionEntry, OperandCoercions)> {
        let mut best: Option<(usize, &'static FunctionEntry, OperandCoercions)> = None;
        for entry in FUNCTION_CATALOG
            .iter()
            .filter(|e| e.name.eq_ignore_ascii_case(name))
        {
            if let Some((coercions, score)) = match_function_args(entry, args) {
                let replace = match &best {
                    None => true,
                    Some((prev_score, prev_entry, _)) => outranks(
                        score,
                        entry.return_type,
                        *prev_score,
                        prev_entry.return_type,
                    ),
                };
                if replace {
                    best = Some((score, entry, coercions));
                }
            }
        }
        best.map(|(_, e, c)| (e, c))
    }
}

/// Operator catalog entries registered as infix operators under `symbol`, in
/// catalog order.
fn infix_candidates(symbol: &'static str) -> impl Iterator<Item = &'static OperatorEntry> {
    OPERATOR_CATALOG
        .iter()
        .filter(move |e| e.name == symbol && e.kind == OperatorKind::Infix)
}

/// Binds `args` positionally to `targets`, allowing implicit casts.
///
/// Returns the per-argument coercions plus the number of arguments that
/// matched their target exactly. Returns `None` as soon as one argument neither
/// equals its target nor has an implicit cast to it. Pairing stops at the
/// shorter of the two sequences, so callers supply at least `args.len()`
/// targets.
fn coerce_operands<I>(targets: I, args: &[DataType]) -> Option<(OperandCoercions, usize)>
where
    I: IntoIterator<Item = DataType>,
{
    let mut casts = Vec::with_capacity(args.len());
    let mut exact = 0usize;
    for (target, &arg) in targets.into_iter().zip(args) {
        if arg == target {
            casts.push(None);
            exact += 1;
        } else if find_cast(arg, target, CastContext::Implicit).is_some() {
            casts.push(Some(target));
        } else {
            return None;
        }
    }
    Some((OperandCoercions { casts }, exact))
}

/// Tries to bind `args` to a single function overload, returning the
/// coercions and exact-match score on success.
fn match_function_args(
    entry: &FunctionEntry,
    args: &[DataType],
) -> Option<(OperandCoercions, usize)> {
    if !entry.variadic {
        if entry.arg_types.len() != args.len() {
            return None;
        }
        return coerce_operands(entry.arg_types.iter().copied(), args);
    }
    if args.is_empty() {
        return None;
    }
    // CONCAT / CONCAT_WS take their arguments as-is, and every argument counts
    // as an exact match.
    if entry.name.eq_ignore_ascii_case("CONCAT") || entry.name.eq_ignore_ascii_case("CONCAT_WS") {
        return Some((OperandCoercions::identity(args.len()), args.len()));
    }
    // Other variadics coerce every argument to the first declared type. An
    // entry that declares no type cannot be matched; it previously panicked
    // here on `arg_types[0]`.
    let target = *entry.arg_types.first()?;
    coerce_operands(std::iter::repeat(target), args)
}

/// Whether a candidate (`score`, `ret`) should replace the current best
/// (`prev_score`, `prev_ret`).
///
/// More exact matches win. On a tie, a preferred return type beats a
/// non-preferred one. Otherwise the earlier catalog entry is kept.
fn outranks(score: usize, ret: DataType, prev_score: usize, prev_ret: DataType) -> bool {
    score > prev_score || (score == prev_score && ret.is_preferred() && !prev_ret.is_preferred())
}
/// Maps an AST binary operator to the symbol under which its overloads are
/// registered in `OPERATOR_CATALOG`.
fn binop_symbol(op: BinOp) -> &'static str {
    match op {
        // Comparisons.
        BinOp::Eq => "=",
        BinOp::Ne => "<>",
        BinOp::Lt => "<",
        BinOp::Le => "<=",
        BinOp::Gt => ">",
        BinOp::Ge => ">=",
        // Arithmetic and string concatenation.
        BinOp::Add => "+",
        BinOp::Sub => "-",
        BinOp::Mul => "*",
        BinOp::Div => "/",
        BinOp::Mod => "%",
        BinOp::Concat => "||",
        // Boolean connectives, registered under keyword names.
        BinOp::And => "AND",
        BinOp::Or => "OR",
    }
}
pub fn resolve_cast(from: DataType, to: DataType) -> Option<&'static CastEntry> {
BuiltinSpine.resolve_cast(from, to)
}
pub fn resolve_binop(
op: BinOp,
lhs: DataType,
rhs: DataType,
) -> Option<(&'static OperatorEntry, OperandCoercions)> {
BuiltinSpine.resolve_binop(op, lhs, rhs)
}
/// Resolves a call to function `name` with argument types `args` using the
/// built-in catalogs.
///
/// Convenience wrapper over [`BuiltinSpine`]'s [`CoercionSpine::resolve_function`].
pub fn resolve_function(
    name: &str,
    args: &[DataType],
) -> Option<(&'static FunctionEntry, OperandCoercions)> {
    let spine = BuiltinSpine;
    spine.resolve_function(name, args)
}
// Unit tests pinning how the built-in catalogs currently resolve casts,
// operators and function overloads. Several assertions depend on the exact
// contents of CAST_CATALOG / OPERATOR_CATALOG / FUNCTION_CATALOG, so a catalog
// change may legitimately require updating them.
#[cfg(test)]
mod tests {
    use super::*;
    // An overload declared for exactly (Integer, Integer) is taken with no casts.
    #[test]
    fn binop_exact_match_emits_identity_coercions() {
        let (entry, coercions) = resolve_binop(BinOp::Add, DataType::Integer, DataType::Integer)
            .expect("int + int must resolve");
        assert_eq!(entry.name, "+");
        assert_eq!(entry.lhs_type, DataType::Integer);
        assert_eq!(entry.rhs_type, DataType::Integer);
        assert_eq!(entry.return_type, DataType::Integer);
        assert!(coercions.is_identity());
    }
    // Mixed Integer/Float addition has its own catalog overload, so no cast is emitted.
    #[test]
    fn binop_int_plus_float_resolves_exact() {
        let (entry, coercions) = resolve_binop(BinOp::Add, DataType::Integer, DataType::Float)
            .expect("int + float must resolve");
        assert_eq!(entry.lhs_type, DataType::Integer);
        assert_eq!(entry.rhs_type, DataType::Float);
        assert_eq!(entry.return_type, DataType::Float);
        assert!(coercions.is_identity());
    }
    // With no exact (Integer, BigInt) overload, several single-cast candidates
    // tie on score; the tie-break on a preferred return type is expected to pick
    // the Float-returning overload (presumably Float is the preferred numeric
    // type in the catalog), casting only the rhs.
    #[test]
    fn binop_int_plus_bigint_widens_to_preferred_float() {
        let (entry, coercions) = resolve_binop(BinOp::Add, DataType::Integer, DataType::BigInt)
            .expect("int + bigint must resolve via widening");
        assert_eq!(entry.return_type, DataType::Float);
        assert_eq!(coercions.at(0), None);
        assert_eq!(coercions.at(1), Some(DataType::Float));
    }
    #[test]
    fn function_exact_match_emits_identity() {
        let (entry, coercions) =
            resolve_function("LENGTH", &[DataType::Text]).expect("LENGTH(text) must resolve");
        assert_eq!(entry.name, "LENGTH");
        assert!(coercions.is_identity());
    }
    // NOTE(review): resolve_function checks casts with
    // `find_cast(.., CastContext::Implicit)`, which admits Integer -> Text here,
    // while `resolve_cast(Integer, Text)` returns None (see
    // integer_to_text_implicit_cast_rejected). The two lookups disagree; this
    // test pins the current function-resolution behavior.
    #[test]
    fn function_int_to_text_widening_resolves_with_explicit_cast() {
        let (entry, coercions) = resolve_function("LENGTH", &[DataType::Integer])
            .expect("LENGTH(int) currently resolves via Integer->Text widening");
        assert_eq!(entry.arg_types, &[DataType::Text]);
        assert_eq!(coercions.at(0), Some(DataType::Text));
    }
    // Overlaps with the first half of function_overload_selects_exact_over_coercion.
    #[test]
    fn function_picks_exact_overload_over_cast_overload() {
        let (entry, coercions) =
            resolve_function("ABS", &[DataType::Integer]).expect("ABS(int) must resolve");
        assert_eq!(entry.return_type, DataType::Integer);
        assert!(coercions.is_identity());
    }
    #[test]
    fn cast_int_to_float_is_implicit() {
        let entry = resolve_cast(DataType::Integer, DataType::Float)
            .expect("int -> float must be implicit");
        assert_eq!(entry.src, DataType::Integer);
        assert_eq!(entry.target, DataType::Float);
        assert!(!entry.lossy);
    }
    // NOTE(review): pins current behavior where a lossy Float -> Integer entry
    // satisfies `context.allows(CastContext::Implicit)`; confirm this is intended.
    #[test]
    fn cast_float_to_int_currently_resolves_via_assignment_entry() {
        let entry = resolve_cast(DataType::Float, DataType::Integer)
            .expect("Float -> Integer resolves under current allows() rule");
        assert!(entry.lossy);
    }
    // Every edge of the numeric widening ladder must be implicit and lossless.
    #[test]
    fn numeric_promotion_ladder_all_implicit_edges() {
        let pairs = [
            (DataType::Integer, DataType::BigInt),
            (DataType::Integer, DataType::Float),
            (DataType::Integer, DataType::Decimal),
            (DataType::BigInt, DataType::Float),
            (DataType::UnsignedInteger, DataType::Integer),
            (DataType::UnsignedInteger, DataType::Float),
        ];
        for (src, tgt) in pairs {
            let entry = resolve_cast(src, tgt)
                .unwrap_or_else(|| panic!("{:?} → {:?} must be implicit", src, tgt));
            assert!(!entry.lossy, "{:?} → {:?} should be lossless", src, tgt);
        }
    }
    #[test]
    fn integer_to_text_implicit_cast_rejected() {
        assert!(
            resolve_cast(DataType::Integer, DataType::Text).is_none(),
            "Integer→Text must not be implicit; it is Explicit-only"
        );
    }
    #[test]
    fn text_to_integer_cast_rejected_by_spine() {
        assert!(
            resolve_cast(DataType::Text, DataType::Integer).is_none(),
            "Text→Integer has no catalog entry"
        );
    }
    // An Unknown-typed operand (e.g. an untyped NULL) must not match any overload.
    #[test]
    fn operator_with_unknown_null_type_returns_none() {
        assert!(
            resolve_binop(BinOp::Add, DataType::Unknown, DataType::Integer).is_none(),
            "Unknown+Integer must not resolve"
        );
        assert!(
            resolve_binop(BinOp::Eq, DataType::Integer, DataType::Unknown).is_none(),
            "Integer=Unknown must not resolve"
        );
    }
    #[test]
    fn text_arithmetic_not_resolvable() {
        assert!(
            resolve_binop(BinOp::Add, DataType::Text, DataType::Text).is_none(),
            "Text+Text must not resolve"
        );
        assert!(
            resolve_binop(BinOp::Add, DataType::Text, DataType::Integer).is_none(),
            "Text+Integer must not resolve"
        );
    }
    // Each ABS overload is selected by exact argument type, without casts.
    #[test]
    fn function_overload_selects_exact_over_coercion() {
        let (int_entry, int_coercions) =
            resolve_function("ABS", &[DataType::Integer]).expect("ABS(int) must resolve");
        assert_eq!(int_entry.return_type, DataType::Integer);
        assert!(int_coercions.is_identity());
        let (float_entry, float_coercions) =
            resolve_function("ABS", &[DataType::Float]).expect("ABS(float) must resolve");
        assert_eq!(float_entry.return_type, DataType::Float);
        assert!(float_coercions.is_identity());
        assert_ne!(int_entry.return_type, float_entry.return_type);
    }
}