use super::cast_catalog::can_implicit_cast;
use super::types::{DataType, TypeCategory};
/// Polymorphic placeholder kinds that a signature slot may declare in place
/// of a concrete type.
///
/// Each kind owns one binding in [`Substitution`]; [`resolve`] fills those
/// bindings from the types of the actual call arguments.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PseudoType {
    /// Accepts any argument; every argument in such a slot must be the same type.
    AnyElement,
    /// Accepts only arguments whose category is `TypeCategory::Array`.
    AnyArray,
    /// Rejects arguments whose category is `TypeCategory::Array`.
    AnyNonArray,
    /// Accepts arguments that can be unified through implicit casts.
    AnyCompatible,
}
/// One parameter position in a function signature.
#[derive(Debug, Clone, Copy)]
pub enum ArgSlot {
    /// The parameter is declared with a fixed, concrete type.
    Concrete(DataType),
    /// The parameter is declared with a polymorphic placeholder.
    Poly(PseudoType),
}
/// The concrete type bound to each pseudo-type after resolving a call.
///
/// A field stays `None` until an argument has been bound to the matching
/// pseudo-type; `Default` yields a substitution with nothing bound.
#[derive(Debug, Clone, Default)]
pub struct Substitution {
    /// Binding for `PseudoType::AnyElement`.
    pub any_element: Option<DataType>,
    /// Binding for `PseudoType::AnyArray`.
    pub any_array: Option<DataType>,
    /// Binding for `PseudoType::AnyNonArray`.
    pub any_nonarray: Option<DataType>,
    /// Binding for `PseudoType::AnyCompatible`.
    pub any_compatible: Option<DataType>,
}
impl Substitution {
    /// Returns the concrete type a signature slot stands for under this
    /// substitution.
    ///
    /// A concrete slot always yields its own type. A polymorphic slot yields
    /// the binding recorded for its pseudo-type, or `None` when nothing has
    /// been bound to it.
    pub fn apply(&self, slot: ArgSlot) -> Option<DataType> {
        let pseudo = match slot {
            ArgSlot::Concrete(fixed) => return Some(fixed),
            ArgSlot::Poly(pseudo) => pseudo,
        };
        match pseudo {
            PseudoType::AnyElement => self.any_element,
            PseudoType::AnyArray => self.any_array,
            PseudoType::AnyNonArray => self.any_nonarray,
            PseudoType::AnyCompatible => self.any_compatible,
        }
    }
}
/// Reasons [`resolve`] can reject a call against a signature.
#[derive(Debug, Clone)]
pub enum ResolveError {
    /// An argument's type disagrees with what was already established.
    ///
    /// `first` is the type seen (or declared) earlier and `other` is the
    /// argument type that clashed with it. [`resolve`] also uses this variant,
    /// with `pseudo` set to `PseudoType::AnyElement`, when an argument does
    /// not fit a concrete slot.
    Conflict {
        pseudo: PseudoType,
        first: DataType,
        other: DataType,
    },
    /// An `AnyNonArray` slot received an argument of array category.
    NonArrayGotArray,
    /// An `AnyArray` slot received an argument that is not of array category.
    ArrayGotScalar,
    /// The call supplied `got` arguments but the signature has `expected` slots.
    ArityMismatch { expected: usize, got: usize },
}
impl std::fmt::Display for ResolveError {
    /// Writes a one-line, human-readable description of the error.
    ///
    /// Types and pseudo-types inside a `Conflict` are rendered with their
    /// `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages need no formatting machinery.
            Self::NonArrayGotArray => f.write_str("AnyNonArray position got an array argument"),
            Self::ArrayGotScalar => f.write_str("AnyArray position got a non-array argument"),
            Self::ArityMismatch { expected, got } => write!(
                f,
                "polymorphic signature expects {expected} args, got {got}"
            ),
            Self::Conflict {
                pseudo,
                first,
                other,
            } => write!(
                f,
                "polymorphic `{pseudo:?}` bound to `{first:?}` but later seen as `{other:?}`"
            ),
        }
    }
}
// Marker impl: makes `ResolveError` usable as a `std::error::Error`, relying on
// its `Debug` derive and `Display` impl. No trait methods are overridden.
impl std::error::Error for ResolveError {}
/// Matches the argument types of a call against a signature and returns the
/// type bound to each pseudo-type.
///
/// Slots and arguments are paired positionally and processed left to right:
///
/// * A concrete slot accepts an argument of the same type, or one for which
///   `can_implicit_cast(arg, declared)` holds. It binds nothing.
/// * `AnyElement` binds on first use; later arguments must be the same type.
/// * `AnyArray` additionally requires the argument's category to be
///   `TypeCategory::Array`.
/// * `AnyNonArray` additionally requires the category not to be
///   `TypeCategory::Array`.
/// * `AnyCompatible` keeps the current binding when the new argument equals it
///   or can be implicitly cast to it, and replaces the binding with the new
///   argument's type when only the reverse cast exists.
///
/// # Errors
///
/// * [`ResolveError::ArityMismatch`] when the slice lengths differ.
/// * [`ResolveError::ArrayGotScalar`] / [`ResolveError::NonArrayGotArray`]
///   when an argument's category violates its slot.
/// * [`ResolveError::Conflict`] when an argument cannot be reconciled with an
///   earlier binding or with a concrete slot's declared type.
pub fn resolve(
    signature: &[ArgSlot],
    call_args: &[DataType],
) -> Result<Substitution, ResolveError> {
    let (expected, got) = (signature.len(), call_args.len());
    if expected != got {
        return Err(ResolveError::ArityMismatch { expected, got });
    }

    let mut bindings = Substitution::default();
    for (slot, actual) in signature.iter().zip(call_args.iter().copied()) {
        // Concrete slots are only validated; polymorphic ones fall through to
        // the binding logic below.
        let pseudo = match *slot {
            ArgSlot::Concrete(declared) => {
                if declared != actual && !can_implicit_cast(actual, declared) {
                    // NOTE(review): the mismatch is reported as an `AnyElement`
                    // conflict even though no pseudo-type is involved; a
                    // dedicated error variant would be clearer.
                    return Err(ResolveError::Conflict {
                        pseudo: PseudoType::AnyElement,
                        first: declared,
                        other: actual,
                    });
                }
                continue;
            }
            ArgSlot::Poly(pseudo) => pseudo,
        };

        match pseudo {
            PseudoType::AnyElement => bind(&mut bindings.any_element, actual, pseudo)?,
            PseudoType::AnyArray => {
                if actual.category() != TypeCategory::Array {
                    return Err(ResolveError::ArrayGotScalar);
                }
                bind(&mut bindings.any_array, actual, pseudo)?;
            }
            PseudoType::AnyNonArray => {
                if actual.category() == TypeCategory::Array {
                    return Err(ResolveError::NonArrayGotArray);
                }
                bind(&mut bindings.any_nonarray, actual, pseudo)?;
            }
            PseudoType::AnyCompatible => match bindings.any_compatible {
                None => bindings.any_compatible = Some(actual),
                // Same type, or the new argument casts to the current binding:
                // the binding already covers it.
                Some(current) if current == actual || can_implicit_cast(actual, current) => {}
                // Only the reverse cast exists: widen the binding.
                Some(current) if can_implicit_cast(current, actual) => {
                    bindings.any_compatible = Some(actual);
                }
                Some(current) => {
                    return Err(ResolveError::Conflict {
                        pseudo,
                        first: current,
                        other: actual,
                    });
                }
            },
        }
    }
    Ok(bindings)
}
/// Records `arg` as the binding in `slot`, or checks it against the binding
/// that is already there.
///
/// An empty slot is filled with `arg`. A filled slot is left untouched and
/// must hold exactly `arg`; otherwise a [`ResolveError::Conflict`] tagged with
/// `pseudo` is returned, carrying the existing binding as `first` and `arg`
/// as `other`.
fn bind(
    slot: &mut Option<DataType>,
    arg: DataType,
    pseudo: PseudoType,
) -> Result<(), ResolveError> {
    if let Some(bound) = *slot {
        if bound == arg {
            return Ok(());
        }
        return Err(ResolveError::Conflict {
            pseudo,
            first: bound,
            other: arg,
        });
    }
    *slot = Some(arg);
    Ok(())
}