use crate::dtype::DType;
/// Returns the common dtype that both `a` and `b` promote to.
///
/// The rules are modelled on NumPy's `promote_types`:
///
/// * Identical dtypes promote to themselves.
/// * Two `Str` (or two `Bytes`) dtypes promote to the one with the larger
///   payload.
/// * Any other combination involving `Str`, `Bytes` or `Object` promotes to
///   `Object`.
/// * If either side is complex, the result is `C64` when every operand's real
///   width is at most 4 bytes, and `C128` otherwise. Integer and bool operands
///   count as the float width they need (see `effective_float_width`). So
///   `C64` with `I32` gives `C128`, just as `F32` with `I32` gives `F64`.
/// * If either side is a real float, the result is the narrowest of
///   `F16`/`F32`/`F64` that covers both operands, with integers counted as
///   above.
/// * Two integers of the same signedness promote to the wider one.
/// * A signed/unsigned mix promotes to the narrowest signed integer that can
///   hold both ranges. When the unsigned side is 64 bits wide, no integer
///   fits and the result is `F64`.
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
#[inline]
pub fn promote(a: DType, b: DType) -> DType {
    if a == b {
        return a;
    }

    match (a, b) {
        (DType::Str(x), DType::Str(y)) => return DType::Str(x.max(y)),
        (DType::Bytes(x), DType::Bytes(y)) => return DType::Bytes(x.max(y)),
        _ => {}
    }

    if matches!(a, DType::Str(_) | DType::Bytes(_) | DType::Object)
        || matches!(b, DType::Str(_) | DType::Bytes(_) | DType::Object)
    {
        return DType::Object;
    }

    if a.is_complex() || b.is_complex() {
        // Byte width of the real component needed to represent `d`. Non-float
        // operands go through `effective_float_width` so that integers widen
        // the result exactly as they do in the real-float branch below.
        let real_width = |d: DType| -> u32 {
            if d.is_complex() {
                complex_real_width(d)
            } else if d.is_float() {
                float_width(d)
            } else {
                effective_float_width(d)
            }
        };
        return match real_width(a).max(real_width(b)) {
            2 | 4 => DType::C64,
            _ => DType::C128,
        };
    }

    if a.is_float() || b.is_float() {
        return match effective_float_width(a).max(effective_float_width(b)) {
            2 => DType::F16,
            4 => DType::F32,
            _ => DType::F64,
        };
    }

    let (a_signed, a_bits) = int_class(a);
    let (b_signed, b_bits) = int_class(b);
    if a_signed == b_signed {
        return int_dtype(a_signed, a_bits.max(b_bits));
    }

    let (signed_bits, unsigned_bits) = if a_signed {
        (a_bits, b_bits)
    } else {
        (b_bits, a_bits)
    };
    // No signed integer can hold every 64-bit unsigned value; fall back to F64.
    if unsigned_bits >= 64 {
        return DType::F64;
    }
    // The unsigned operand needs one extra bit to make room for the sign.
    int_dtype(true, signed_bits.max(unsigned_bits + 1))
}
#[inline]
pub fn promote_many(types: &[DType]) -> DType {
types
.iter()
.copied()
.reduce(promote)
.unwrap_or(DType::F64)
}
/// Splits an integer-like dtype into `(is_signed, bit_width)`.
///
/// `Bool` is treated as a 1-bit unsigned integer. Every non-integer dtype
/// yields `(false, 0)`.
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
#[inline]
fn int_class(d: DType) -> (bool, u32) {
    let is_signed = matches!(d, DType::I8 | DType::I16 | DType::I32 | DType::I64);
    let width = match d {
        DType::Bool => 1,
        DType::I8 | DType::U8 => 8,
        DType::I16 | DType::U16 => 16,
        DType::I32 | DType::U32 => 32,
        DType::I64 | DType::U64 => 64,
        _ => 0,
    };
    (is_signed, width)
}
/// Returns the narrowest integer dtype of the requested signedness that has at
/// least `bits` bits.
///
/// Requests of 8 bits or fewer (including 0 and 1) give the 8-bit type.
/// Anything above 32 bits saturates at the 64-bit type.
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
#[inline]
fn int_dtype(signed: bool, bits: u32) -> DType {
    match (signed, bits) {
        (true, 0..=8) => DType::I8,
        (true, 9..=16) => DType::I16,
        (true, 17..=32) => DType::I32,
        (true, _) => DType::I64,
        (false, 0..=8) => DType::U8,
        (false, 9..=16) => DType::U16,
        (false, 17..=32) => DType::U32,
        (false, _) => DType::U64,
    }
}
/// Byte width of a real floating-point dtype.
///
/// Returns `F64` → 8, `F32` → 4, `F16` → 2. Every other dtype, including the
/// complex ones, returns 0.
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
#[inline]
fn float_width(d: DType) -> u32 {
    match d {
        DType::F64 => 8,
        DType::F32 => 4,
        DType::F16 => 2,
        _ => 0,
    }
}
/// Byte width of the real component of a complex dtype.
///
/// Returns `C128` → 8 and `C64` → 4. Any non-complex dtype returns 0.
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
#[inline]
fn complex_real_width(d: DType) -> u32 {
    match d {
        DType::C128 => 8,
        DType::C64 => 4,
        _ => 0,
    }
}
/// Byte width of the narrowest real float that dtype `d` promotes to when it
/// is combined with a floating-point operand.
///
/// * `Bool`, `I8` and `U8` map to 2 (`F16`). Their values fit exactly in
///   `F16`'s 11-bit significand.
/// * `I16` and `U16` map to 4 (`F32`). `F16` cannot represent every 16-bit
///   integer (2049 is the first one it rounds), which is why NumPy promotes
///   `float16` + `int16` to `float32`.
/// * 32- and 64-bit integers map to 8 (`F64`).
/// * Real floats map to their own width.
/// * Any other dtype falls back to 8.
#[cfg_attr(feature = "no-panic", no_panic::no_panic)]
#[inline]
fn effective_float_width(d: DType) -> u32 {
    match d {
        DType::Bool | DType::U8 | DType::I8 => 2,
        // Needs F32's 24-bit significand; F16 would round values above 2048.
        DType::U16 | DType::I16 => 4,
        DType::U32 | DType::I32 | DType::U64 | DType::I64 => 8,
        DType::F16 => 2,
        DType::F32 => 4,
        DType::F64 => 8,
        _ => 8,
    }
}