use crate::cast_catalog::can_implicit_cast;
use crate::types::{DataType, TypeCategory};
/// Polymorphic placeholder types that may appear in a function signature.
///
/// Each one is bound to a concrete [`DataType`] by [`resolve`] from the types of the
/// call's arguments.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PseudoType {
    /// Bound to the exact type of the first argument in such a slot; every other
    /// `AnyElement` argument must be that same type.
    AnyElement,
    /// Like `AnyElement`, but the argument must be in the array type category.
    AnyArray,
    /// Like `AnyElement`, but the argument must not be in the array type category.
    AnyNonArray,
    /// Resolved by widening across the matching arguments using the cast catalog's
    /// implicit casts, rather than requiring identical types.
    AnyCompatible,
}
/// One parameter position in a function signature.
#[derive(Clone, Copy, Debug)]
pub enum ArgSlot {
    /// A fixed parameter type. [`resolve`] accepts an argument of exactly this type or
    /// one that implicitly casts to it.
    Concrete(DataType),
    /// A polymorphic parameter whose type is determined from the call's arguments.
    Poly(PseudoType),
}
/// Concrete types bound to each polymorphic pseudo-type.
///
/// When produced by [`resolve`], a field is `None` if the signature had no slot of
/// that pseudo-type.
#[derive(Clone, Debug, Default)]
pub struct Substitution {
    /// Type bound to [`PseudoType::AnyElement`].
    pub any_element: Option<DataType>,
    /// Type bound to [`PseudoType::AnyArray`].
    pub any_array: Option<DataType>,
    /// Type bound to [`PseudoType::AnyNonArray`].
    pub any_nonarray: Option<DataType>,
    /// Type bound to [`PseudoType::AnyCompatible`].
    pub any_compatible: Option<DataType>,
}

impl Substitution {
    /// Maps a signature slot to a concrete type.
    ///
    /// A concrete slot yields its own type unchanged. A polymorphic slot yields the
    /// type bound to its pseudo-type, or `None` if that pseudo-type is unbound.
    pub fn apply(&self, slot: ArgSlot) -> Option<DataType> {
        let pseudo = match slot {
            ArgSlot::Concrete(concrete) => return Some(concrete),
            ArgSlot::Poly(pseudo) => pseudo,
        };
        match pseudo {
            PseudoType::AnyElement => self.any_element,
            PseudoType::AnyArray => self.any_array,
            PseudoType::AnyNonArray => self.any_nonarray,
            PseudoType::AnyCompatible => self.any_compatible,
        }
    }
}
/// Reasons [`resolve`] can reject a call against a polymorphic signature.
#[derive(Debug, Clone)]
pub enum ResolveError {
    /// A slot required `first`, but an argument of type `other` was seen.
    ///
    /// Used for inconsistent pseudo-type bindings, for `AnyCompatible` arguments with
    /// no implicit cast in either direction, and also for a concrete slot whose argument
    /// neither matches nor implicitly casts (in that case `pseudo` is a placeholder).
    Conflict {
        pseudo: PseudoType,
        first: DataType,
        other: DataType,
    },
    /// An `AnyNonArray` slot received an argument in the array category.
    NonArrayGotArray,
    /// An `AnyArray` slot received an argument outside the array category.
    ArrayGotScalar,
    /// The signature and the call have different numbers of arguments.
    ArityMismatch { expected: usize, got: usize },
}

impl std::fmt::Display for ResolveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Conflict {
                pseudo,
                first,
                other,
            } => write!(
                f,
                "polymorphic `{:?}` bound to `{:?}` but later seen as `{:?}`",
                pseudo, first, other
            ),
            Self::NonArrayGotArray => f.write_str("AnyNonArray position got an array argument"),
            Self::ArrayGotScalar => f.write_str("AnyArray position got a non-array argument"),
            Self::ArityMismatch { expected, got } => write!(
                f,
                "polymorphic signature expects {} args, got {}",
                expected, got
            ),
        }
    }
}

impl std::error::Error for ResolveError {}
pub fn resolve(
signature: &[ArgSlot],
call_args: &[DataType],
) -> Result<Substitution, ResolveError> {
if signature.len() != call_args.len() {
return Err(ResolveError::ArityMismatch {
expected: signature.len(),
got: call_args.len(),
});
}
let mut sub = Substitution::default();
for (slot, &arg_ty) in signature.iter().zip(call_args.iter()) {
match slot {
ArgSlot::Concrete(expected) => {
if *expected != arg_ty && !can_implicit_cast(arg_ty, *expected) {
return Err(ResolveError::Conflict {
pseudo: PseudoType::AnyElement, first: *expected,
other: arg_ty,
});
}
}
ArgSlot::Poly(PseudoType::AnyElement) => {
bind(&mut sub.any_element, arg_ty, PseudoType::AnyElement)?;
}
ArgSlot::Poly(PseudoType::AnyArray) => {
if arg_ty.category() != TypeCategory::Array {
return Err(ResolveError::ArrayGotScalar);
}
bind(&mut sub.any_array, arg_ty, PseudoType::AnyArray)?;
}
ArgSlot::Poly(PseudoType::AnyNonArray) => {
if arg_ty.category() == TypeCategory::Array {
return Err(ResolveError::NonArrayGotArray);
}
bind(&mut sub.any_nonarray, arg_ty, PseudoType::AnyNonArray)?;
}
ArgSlot::Poly(PseudoType::AnyCompatible) => {
match sub.any_compatible {
None => sub.any_compatible = Some(arg_ty),
Some(prev) if prev == arg_ty => {}
Some(prev) => {
if can_implicit_cast(arg_ty, prev) {
} else if can_implicit_cast(prev, arg_ty) {
sub.any_compatible = Some(arg_ty);
} else {
return Err(ResolveError::Conflict {
pseudo: PseudoType::AnyCompatible,
first: prev,
other: arg_ty,
});
}
}
}
}
}
}
Ok(sub)
}
/// Binds `arg` into `slot` for the pseudo-type `pseudo`, or checks it against an
/// existing binding.
///
/// An empty slot takes `arg`. A slot that already holds `arg` is left unchanged. A slot
/// holding a different type yields [`ResolveError::Conflict`] with `first` set to the
/// existing binding; the slot itself is not modified in that case.
fn bind(
    slot: &mut Option<DataType>,
    arg: DataType,
    pseudo: PseudoType,
) -> Result<(), ResolveError> {
    match *slot {
        Some(bound) if bound != arg => Err(ResolveError::Conflict {
            pseudo,
            first: bound,
            other: arg,
        }),
        Some(_) => Ok(()),
        None => {
            *slot = Some(arg);
            Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn substitution_apply_returns_bound_or_concrete_types() {
        let sub = Substitution {
            any_element: Some(DataType::Integer),
            any_array: Some(DataType::Array),
            any_nonarray: Some(DataType::Text),
            any_compatible: Some(DataType::Float),
        };
        assert_eq!(
            sub.apply(ArgSlot::Concrete(DataType::Boolean)),
            Some(DataType::Boolean)
        );
        assert_eq!(
            sub.apply(ArgSlot::Poly(PseudoType::AnyElement)),
            Some(DataType::Integer)
        );
        assert_eq!(
            sub.apply(ArgSlot::Poly(PseudoType::AnyArray)),
            Some(DataType::Array)
        );
        assert_eq!(
            sub.apply(ArgSlot::Poly(PseudoType::AnyNonArray)),
            Some(DataType::Text)
        );
        assert_eq!(
            sub.apply(ArgSlot::Poly(PseudoType::AnyCompatible)),
            Some(DataType::Float)
        );
        assert_eq!(
            Substitution::default().apply(ArgSlot::Poly(PseudoType::AnyElement)),
            None
        );
    }

    #[test]
    fn resolve_accepts_concrete_and_poly_slots() {
        let sub = resolve(
            &[
                ArgSlot::Concrete(DataType::Float),
                ArgSlot::Poly(PseudoType::AnyElement),
                ArgSlot::Poly(PseudoType::AnyArray),
                ArgSlot::Poly(PseudoType::AnyNonArray),
            ],
            &[
                DataType::Integer,
                DataType::Text,
                DataType::Array,
                DataType::Boolean,
            ],
        )
        .unwrap();
        assert_eq!(sub.any_element, Some(DataType::Text));
        assert_eq!(sub.any_array, Some(DataType::Array));
        assert_eq!(sub.any_nonarray, Some(DataType::Boolean));
        // No AnyCompatible slot in the signature, so nothing is bound for it.
        assert_eq!(sub.any_compatible, None);
    }

    #[test]
    fn resolve_reports_arity_and_kind_errors() {
        assert!(matches!(
            resolve(&[ArgSlot::Poly(PseudoType::AnyElement)], &[]),
            Err(ResolveError::ArityMismatch {
                expected: 1,
                got: 0
            })
        ));
        assert!(matches!(
            resolve(&[ArgSlot::Poly(PseudoType::AnyArray)], &[DataType::Text]),
            Err(ResolveError::ArrayGotScalar)
        ));
        assert!(matches!(
            resolve(
                &[ArgSlot::Poly(PseudoType::AnyNonArray)],
                &[DataType::Array]
            ),
            Err(ResolveError::NonArrayGotArray)
        ));
        // A concrete-slot mismatch is reported as `Conflict`, with `first` set to the
        // declared type and `other` set to the argument type.
        assert!(matches!(
            resolve(&[ArgSlot::Concrete(DataType::Boolean)], &[DataType::Text]),
            Err(ResolveError::Conflict {
                first: DataType::Boolean,
                other: DataType::Text,
                ..
            })
        ));
    }

    #[test]
    fn concrete_slot_accepts_exact_type_without_binding_anything() {
        let sub = resolve(&[ArgSlot::Concrete(DataType::Boolean)], &[DataType::Boolean]).unwrap();
        assert_eq!(sub.any_element, None);
        assert_eq!(sub.any_array, None);
        assert_eq!(sub.any_nonarray, None);
        assert_eq!(sub.any_compatible, None);
    }

    #[test]
    fn repeated_pseudo_slots_must_be_consistent() {
        let ok = resolve(
            &[
                ArgSlot::Poly(PseudoType::AnyElement),
                ArgSlot::Poly(PseudoType::AnyElement),
            ],
            &[DataType::Integer, DataType::Integer],
        )
        .unwrap();
        assert_eq!(ok.any_element, Some(DataType::Integer));
        let err = resolve(
            &[
                ArgSlot::Poly(PseudoType::AnyElement),
                ArgSlot::Poly(PseudoType::AnyElement),
            ],
            &[DataType::Integer, DataType::Text],
        )
        .unwrap_err();
        assert!(matches!(
            err,
            ResolveError::Conflict {
                pseudo: PseudoType::AnyElement,
                first: DataType::Integer,
                other: DataType::Text,
            }
        ));
        assert!(err.to_string().contains("AnyElement"));
    }

    #[test]
    fn repeated_anynonarray_slots_must_be_consistent() {
        let err = resolve(
            &[ArgSlot::Poly(PseudoType::AnyNonArray); 2],
            &[DataType::Text, DataType::Boolean],
        )
        .unwrap_err();
        assert!(matches!(
            err,
            ResolveError::Conflict {
                pseudo: PseudoType::AnyNonArray,
                first: DataType::Text,
                other: DataType::Boolean,
            }
        ));
    }

    #[test]
    fn anycompatible_uses_cast_catalog_to_resolve_binding() {
        // NOTE(review): `resolve` keeps the earlier candidate whenever the new argument
        // implicitly casts to it. The `Integer` expectation below therefore holds only if
        // the catalog allows an implicit Float -> Integer cast; confirm against
        // cast_catalog.
        let int_then_float = resolve(
            &[
                ArgSlot::Poly(PseudoType::AnyCompatible),
                ArgSlot::Poly(PseudoType::AnyCompatible),
            ],
            &[DataType::Integer, DataType::Float],
        )
        .unwrap();
        assert_eq!(int_then_float.any_compatible, Some(DataType::Integer));
        let float_then_int = resolve(
            &[
                ArgSlot::Poly(PseudoType::AnyCompatible),
                ArgSlot::Poly(PseudoType::AnyCompatible),
            ],
            &[DataType::Float, DataType::Integer],
        )
        .unwrap();
        assert_eq!(float_then_int.any_compatible, Some(DataType::Float));
        assert!(matches!(
            resolve(
                &[
                    ArgSlot::Poly(PseudoType::AnyCompatible),
                    ArgSlot::Poly(PseudoType::AnyCompatible),
                ],
                &[DataType::Boolean, DataType::Json],
            ),
            Err(ResolveError::Conflict {
                pseudo: PseudoType::AnyCompatible,
                ..
            })
        ));
    }

    #[test]
    fn error_messages_describe_each_variant() {
        assert_eq!(
            ResolveError::ArityMismatch {
                expected: 2,
                got: 3
            }
            .to_string(),
            "polymorphic signature expects 2 args, got 3"
        );
        assert_eq!(
            ResolveError::NonArrayGotArray.to_string(),
            "AnyNonArray position got an array argument"
        );
        assert_eq!(
            ResolveError::ArrayGotScalar.to_string(),
            "AnyArray position got a non-array argument"
        );
    }
}