use crate::dtype::DType;
use crate::dtype::PType;
use crate::dtype::decimal::DecimalDType;
impl PType {
    /// Returns the narrowest primitive type that can represent every value of both `self` and
    /// `other`, or `None` if no primitive type can.
    ///
    /// Rules:
    /// - identical types map to themselves;
    /// - two unsigned (or two signed) integers widen to the wider of the two;
    /// - two floats widen to the wider float;
    /// - an unsigned and a signed integer map to the narrowest signed integer covering both
    ///   ranges (`None` when the unsigned side is 64 bits wide, since no wider signed type
    ///   exists here);
    /// - an integer and a float map to the narrowest float whose significand holds every value
    ///   of the integer, widened to the given float if that is already wider (`None` for 64-bit
    ///   integers).
    pub fn least_supertype(self, other: PType) -> Option<PType> {
        if self == other {
            return Some(self);
        }
        if self.is_unsigned_int() && other.is_unsigned_int() {
            return Some(self.max_unsigned_ptype(other));
        }
        if self.is_signed_int() && other.is_signed_int() {
            return Some(self.max_signed_ptype(other));
        }
        if self.is_float() && other.is_float() {
            return Some(if self.byte_width() >= other.byte_width() {
                self
            } else {
                other
            });
        }
        if self.is_unsigned_int() && other.is_signed_int() {
            return Self::unsigned_signed_supertype(self, other);
        }
        if self.is_signed_int() && other.is_unsigned_int() {
            return Self::unsigned_signed_supertype(other, self);
        }
        // Only mixed integer/float pairs remain at this point.
        let (int, float) = if self.is_float() {
            (other, self)
        } else {
            (self, other)
        };
        Self::int_float_supertype(int, float)
    }

    /// Narrowest signed integer type holding every value of both `unsigned` and `signed`.
    ///
    /// A signed type strictly wider than the unsigned one already covers the unsigned range, so
    /// it is returned as-is (e.g. `U8` + `I64` -> `I64`). Otherwise a signed type of twice the
    /// unsigned width is required (e.g. `U16` + `I8` -> `I32`). Returns `None` when that would
    /// need more than 64 bits, i.e. when `unsigned` is 8 bytes wide.
    fn unsigned_signed_supertype(unsigned: PType, signed: PType) -> Option<PType> {
        use PType::*;
        if signed.byte_width() > unsigned.byte_width() {
            return Some(signed);
        }
        match unsigned.byte_width() {
            1 => Some(I16),
            2 => Some(I32),
            4 => Some(I64),
            _ => None,
        }
    }

    /// Float type able to hold every value of the integer type `int`, widened to `float` when
    /// `float` is already at least that wide. Returns `None` for 8-byte integers.
    fn int_float_supertype(int: PType, float: PType) -> Option<PType> {
        use PType::*;
        // Narrowest float whose significand (11 / 24 / 53 bits) covers the integer's range.
        let min_float = match int.byte_width() {
            1 => F16,
            2 => F32,
            4 => F64,
            _ => return None,
        };
        Some(if float.byte_width() >= min_float.byte_width() {
            float
        } else {
            min_float
        })
    }
}
impl DType {
pub fn least_supertype(&self, other: &DType) -> Option<DType> {
if self.eq_ignore_nullability(other) {
return Some(self.with_nullability(self.nullability() | other.nullability()));
}
let union_null = self.nullability() | other.nullability();
if matches!(self, DType::Null) {
return Some(other.as_nullable());
}
if matches!(other, DType::Null) {
return Some(self.as_nullable());
}
if self.is_boolean() && other.is_numeric() {
return Some(other.with_nullability(union_null));
}
if other.is_boolean() && self.is_numeric() {
return Some(self.with_nullability(union_null));
}
if let (DType::Primitive(lhs_p, _), DType::Primitive(rhs_p, _)) = (self, other) {
return lhs_p
.least_supertype(*rhs_p)
.map(|p| DType::Primitive(p, union_null));
}
if let (DType::Decimal(lhs_d, _), DType::Decimal(rhs_d, _)) = (self, other) {
return decimal_least_supertype(*lhs_d, *rhs_d).map(|d| DType::Decimal(d, union_null));
}
if let (DType::Decimal(dec, _), DType::Primitive(p, _)) = (self, other)
&& p.is_int()
{
let int_dec = DecimalDType::new(integer_decimal_precision(*p), 0);
return decimal_least_supertype(*dec, int_dec).map(|d| DType::Decimal(d, union_null));
}
if let (DType::Primitive(p, _), DType::Decimal(dec, _)) = (self, other)
&& p.is_int()
{
let int_dec = DecimalDType::new(integer_decimal_precision(*p), 0);
return decimal_least_supertype(int_dec, *dec).map(|d| DType::Decimal(d, union_null));
}
if let DType::Extension(ext) = self {
return ext
.least_supertype(other)
.map(|dt| dt.with_nullability(union_null));
}
if let DType::Extension(ext) = other {
return ext
.least_supertype(self)
.map(|dt| dt.with_nullability(union_null));
}
None
}
pub fn least_supertype_of(types: &[DType]) -> Option<DType> {
types
.iter()
.skip(1)
.try_fold(types[0].clone(), |acc, t| acc.least_supertype(t))
}
pub fn can_coerce_from(&self, other: &DType) -> bool {
if self.eq_ignore_nullability(other) {
return self.is_nullable() || !other.is_nullable();
}
if matches!(other, DType::Null) {
return self.is_nullable();
}
if other.is_boolean() && self.is_numeric() {
return self.is_nullable() || !other.is_nullable();
}
if let (DType::Primitive(..), DType::Primitive(..)) = (self, other) {
return other
.least_supertype(self)
.is_some_and(|st| st.eq_ignore_nullability(self))
&& (self.is_nullable() || !other.is_nullable());
}
if let (DType::Decimal(target, _), DType::Decimal(source, _)) = (self, other) {
let target_integral = target.precision() as i16 - target.scale() as i16;
let source_integral = source.precision() as i16 - source.scale() as i16;
return target_integral >= source_integral
&& target.scale() >= source.scale()
&& (self.is_nullable() || !other.is_nullable());
}
if let (DType::Decimal(dec, _), DType::Primitive(p, _)) = (self, other)
&& p.is_int()
{
let needed = integer_decimal_precision(*p);
let integral_digits = dec.precision() as i16 - dec.scale() as i16;
return integral_digits >= needed as i16
&& (self.is_nullable() || !other.is_nullable());
}
if let DType::Extension(ext) = self {
return ext.can_coerce_from(other);
}
false
}
pub fn can_coerce_to(&self, other: &DType) -> bool {
other.can_coerce_from(self)
}
pub fn are_coercible(types: &[DType]) -> bool {
DType::least_supertype_of(types).is_some()
}
pub fn all_coercible_to(types: &[DType], target: &DType) -> bool {
types.iter().all(|t| target.can_coerce_from(t))
}
pub fn coerce_all_to(types: &[DType], target: &DType) -> Option<Vec<DType>> {
types
.iter()
.all(|t| target.can_coerce_from(t))
.then(|| vec![target.clone(); types.len()])
}
pub fn coerce_to_supertype(types: &[DType]) -> Option<Vec<DType>> {
let supertype = DType::least_supertype_of(types)?;
Some(vec![supertype; types.len()])
}
pub fn is_numeric(&self) -> bool {
matches!(self, DType::Primitive(..) | DType::Decimal(..))
}
pub fn is_temporal(&self) -> bool {
match self {
DType::Extension(ext) => {
use crate::dtype::extension::Matcher;
use crate::extension::datetime::AnyTemporal;
AnyTemporal::matches(ext)
}
_ => false,
}
}
}
/// Number of decimal digits needed to represent every value of the integer type `ptype`,
/// i.e. the digit count of its largest-magnitude value.
///
/// # Panics
///
/// Panics if `ptype` is not an integer type; every caller in this module checks
/// `PType::is_int` first.
fn integer_decimal_precision(ptype: PType) -> u8 {
    match ptype {
        PType::U8 | PType::I8 => 3,
        PType::U16 | PType::I16 => 5,
        PType::U32 | PType::I32 => 10,
        // i64 values have at most 19 digits, but u64::MAX (18_446_744_073_709_551_615) has 20.
        PType::I64 => 19,
        PType::U64 => 20,
        _ => unreachable!("integer_decimal_precision requires an integer ptype"),
    }
}
/// Smallest decimal type able to represent every value of both `a` and `b` exactly.
///
/// The result keeps the larger count of integral digits (`precision - scale`) and the larger
/// scale of the two inputs; its precision is their sum. Returns `None` when that precision does
/// not fit in a `u8` or is rejected by `DecimalDType::try_new`.
fn decimal_least_supertype(a: DecimalDType, b: DecimalDType) -> Option<DecimalDType> {
    // Digits left of the decimal point; negative when scale exceeds precision.
    let integral_digits = |dec: &DecimalDType| dec.precision() as i16 - dec.scale() as i16;

    let scale = a.scale().max(b.scale());
    let integral = integral_digits(&a).max(integral_digits(&b));

    u8::try_from(integral + scale as i16)
        .ok()
        .and_then(|precision| DecimalDType::try_new(precision, scale).ok())
}
#[cfg(test)]
mod tests {
    // Unit tests for the least-supertype and coercion rules defined in this module.
    use crate::dtype::DType;
    use crate::dtype::PType;
    use crate::dtype::decimal::DecimalDType;
    use crate::dtype::nullability::Nullability::NonNullable;
    use crate::dtype::nullability::Nullability::Nullable;

    #[test]
    fn is_numeric() {
        assert!(DType::Primitive(PType::I32, NonNullable).is_numeric());
        assert!(DType::Primitive(PType::F64, NonNullable).is_numeric());
        assert!(DType::Decimal(DecimalDType::new(10, 2), NonNullable).is_numeric());
        assert!(!DType::Bool(NonNullable).is_numeric());
        assert!(!DType::Utf8(NonNullable).is_numeric());
        assert!(!DType::Null.is_numeric());
    }

    #[test]
    fn least_supertype_identity() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        assert_eq!(i32_nn.least_supertype(&i32_nn).unwrap(), i32_nn);
    }

    #[test]
    fn least_supertype_nullability_union() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        let i32_n = DType::Primitive(PType::I32, Nullable);
        assert_eq!(i32_nn.least_supertype(&i32_n).unwrap(), i32_n);
        assert_eq!(i32_n.least_supertype(&i32_nn).unwrap(), i32_n);
    }

    #[test]
    fn least_supertype_null_absorption() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        assert_eq!(
            DType::Null.least_supertype(&i32_nn).unwrap(),
            DType::Primitive(PType::I32, Nullable)
        );
        assert_eq!(
            i32_nn.least_supertype(&DType::Null).unwrap(),
            DType::Primitive(PType::I32, Nullable)
        );
    }

    #[test]
    fn least_supertype_unsigned_widening() {
        let u8_nn = DType::Primitive(PType::U8, NonNullable);
        let u32_nn = DType::Primitive(PType::U32, NonNullable);
        assert_eq!(u8_nn.least_supertype(&u32_nn).unwrap(), u32_nn);
    }

    #[test]
    fn least_supertype_signed_widening() {
        let i16_nn = DType::Primitive(PType::I16, NonNullable);
        let i64_nn = DType::Primitive(PType::I64, NonNullable);
        assert_eq!(i16_nn.least_supertype(&i64_nn).unwrap(), i64_nn);
    }

    #[test]
    fn least_supertype_cross_family() {
        let u8_nn = DType::Primitive(PType::U8, NonNullable);
        let i8_nn = DType::Primitive(PType::I8, NonNullable);
        assert_eq!(
            u8_nn.least_supertype(&i8_nn).unwrap(),
            DType::Primitive(PType::I16, NonNullable)
        );
    }

    #[test]
    fn least_supertype_u64_i64_none() {
        let u64_nn = DType::Primitive(PType::U64, NonNullable);
        let i64_nn = DType::Primitive(PType::I64, NonNullable);
        assert!(u64_nn.least_supertype(&i64_nn).is_none());
    }

    #[test]
    fn least_supertype_int_float_promotion() {
        let u8_nn = DType::Primitive(PType::U8, NonNullable);
        let f32_nn = DType::Primitive(PType::F32, NonNullable);
        assert_eq!(u8_nn.least_supertype(&f32_nn).unwrap(), f32_nn);
    }

    #[test]
    fn least_supertype_i32_f32_to_f64() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        let f32_nn = DType::Primitive(PType::F32, NonNullable);
        assert_eq!(
            i32_nn.least_supertype(&f32_nn).unwrap(),
            DType::Primitive(PType::F64, NonNullable)
        );
    }

    #[test]
    fn least_supertype_bool_numeric() {
        let bool_nn = DType::Bool(NonNullable);
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        assert_eq!(bool_nn.least_supertype(&i32_nn).unwrap(), i32_nn);
        assert_eq!(i32_nn.least_supertype(&bool_nn).unwrap(), i32_nn);
    }

    #[test]
    fn least_supertype_decimal_widening() {
        let d1 = DType::Decimal(DecimalDType::new(10, 2), NonNullable);
        let d2 = DType::Decimal(DecimalDType::new(15, 5), NonNullable);
        let result = d1.least_supertype(&d2).unwrap();
        assert_eq!(
            result,
            DType::Decimal(DecimalDType::new(15, 5), NonNullable)
        );
    }

    #[test]
    fn least_supertype_incompatible_none() {
        let utf8 = DType::Utf8(NonNullable);
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        assert!(utf8.least_supertype(&i32_nn).is_none());
    }

    #[test]
    fn least_supertype_of_single() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        assert_eq!(
            DType::least_supertype_of(std::slice::from_ref(&i32_nn)),
            Some(i32_nn)
        );
    }

    #[test]
    fn can_coerce_from_widening() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        let i64_nn = DType::Primitive(PType::I64, NonNullable);
        assert!(i64_nn.can_coerce_from(&i32_nn));
    }

    #[test]
    fn can_coerce_from_narrowing_rejected() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        let i64_nn = DType::Primitive(PType::I64, NonNullable);
        assert!(!i32_nn.can_coerce_from(&i64_nn));
    }

    #[test]
    fn can_coerce_from_nullability_constraints() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        let i32_n = DType::Primitive(PType::I32, Nullable);
        assert!(i32_n.can_coerce_from(&i32_nn));
        assert!(!i32_nn.can_coerce_from(&i32_n));
    }

    #[test]
    fn can_coerce_from_null() {
        let i32_n = DType::Primitive(PType::I32, Nullable);
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        assert!(i32_n.can_coerce_from(&DType::Null));
        assert!(!i32_nn.can_coerce_from(&DType::Null));
    }

    #[test]
    fn can_coerce_from_integer_to_decimal() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        // i32 values need 10 integral digits.
        assert!(DType::Decimal(DecimalDType::new(15, 5), NonNullable).can_coerce_from(&i32_nn));
        assert!(!DType::Decimal(DecimalDType::new(11, 2), NonNullable).can_coerce_from(&i32_nn));
    }

    #[test]
    fn are_coercible_mixed() {
        let types = [
            DType::Primitive(PType::I32, NonNullable),
            DType::Primitive(PType::I64, NonNullable),
        ];
        assert!(DType::are_coercible(&types));
    }

    #[test]
    fn all_coercible_to_target() {
        let types = [
            DType::Primitive(PType::I32, NonNullable),
            DType::Primitive(PType::I16, NonNullable),
        ];
        let target = DType::Primitive(PType::I64, NonNullable);
        assert!(DType::all_coercible_to(&types, &target));
    }

    #[test]
    fn coerce_to_supertype_works() {
        // U16 is at least as wide as I8, so a signed type of twice U16's width (I32) is the
        // narrowest type holding both ranges.
        let types = [
            DType::Primitive(PType::U16, NonNullable),
            DType::Primitive(PType::I8, NonNullable),
        ];
        let result = DType::coerce_to_supertype(&types).unwrap();
        assert_eq!(result, vec![DType::Primitive(PType::I32, NonNullable); 2]);
    }

    #[test]
    fn least_supertype_integer_decimal() {
        let i32_nn = DType::Primitive(PType::I32, NonNullable);
        let dec = DType::Decimal(DecimalDType::new(15, 5), NonNullable);
        let result = i32_nn.least_supertype(&dec).unwrap();
        assert_eq!(
            result,
            DType::Decimal(DecimalDType::new(15, 5), NonNullable)
        );
    }
}