use std::fmt;
/// Scalar data-type tag covering integer, floating-point, boolean and complex kinds.
///
/// The derived `Default` is `DType::Float32` (see the `#[default]` marker below).
/// Byte widths quoted on the variants are the values reported by `DType::size`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DType {
/// Signed integer, 1 byte.
Int8,
/// Unsigned integer, 1 byte.
UInt8,
/// Signed integer, 2 bytes.
Int16,
/// Unsigned integer, 2 bytes.
UInt16,
/// Signed integer, 4 bytes.
Int32,
/// Unsigned integer, 4 bytes.
UInt32,
/// Signed integer, 8 bytes.
Int64,
/// Unsigned integer, 8 bytes.
UInt64,
/// Half-precision float, 2 bytes.
Float16,
/// bfloat16 float, 2 bytes.
BFloat16,
/// Single-precision float, 4 bytes. This is the default variant.
#[default]
Float32,
/// Double-precision float, 8 bytes.
Float64,
/// Boolean, 1 byte.
Bool,
/// Complex number, 8 bytes in total.
Complex64,
/// Complex number, 16 bytes in total.
Complex128,
}
impl DType {
    /// Largest finite bfloat16 value, widened losslessly to `f64`.
    ///
    /// bfloat16 is the upper 16 bits of an IEEE-754 `f32`, so its largest finite
    /// value is the `f32` with bit pattern `0x7F7F_0000`
    /// (`(2 - 2^-7) * 2^127`, roughly `3.3895e38`).
    fn bfloat16_max() -> f64 {
        f64::from(f32::from_bits(0x7F7F_0000))
    }

    /// Returns the width of one element of this type in bytes.
    pub fn size(&self) -> usize {
        match self {
            DType::Int8 | DType::UInt8 | DType::Bool => 1,
            DType::Int16 | DType::UInt16 | DType::Float16 | DType::BFloat16 => 2,
            DType::Int32 | DType::UInt32 | DType::Float32 => 4,
            DType::Int64 | DType::UInt64 | DType::Float64 | DType::Complex64 => 8,
            DType::Complex128 => 16,
        }
    }

    /// Returns `true` for the real floating-point types
    /// (`Float16`, `BFloat16`, `Float32`, `Float64`). Complex types are excluded.
    pub fn is_float(&self) -> bool {
        matches!(
            self,
            DType::Float16 | DType::BFloat16 | DType::Float32 | DType::Float64
        )
    }

    /// Returns `true` for every signed or unsigned integer type.
    /// `Bool` is not counted as an integer.
    pub fn is_int(&self) -> bool {
        matches!(
            self,
            DType::Int8
                | DType::UInt8
                | DType::Int16
                | DType::UInt16
                | DType::Int32
                | DType::UInt32
                | DType::Int64
                | DType::UInt64
        )
    }

    /// Returns `true` for `Int8`, `Int16`, `Int32` and `Int64`.
    pub fn is_signed_int(&self) -> bool {
        matches!(
            self,
            DType::Int8 | DType::Int16 | DType::Int32 | DType::Int64
        )
    }

    /// Returns `true` for `UInt8`, `UInt16`, `UInt32` and `UInt64`.
    pub fn is_unsigned_int(&self) -> bool {
        matches!(
            self,
            DType::UInt8 | DType::UInt16 | DType::UInt32 | DType::UInt64
        )
    }

    /// Returns `true` for `Complex64` and `Complex128`.
    pub fn is_complex(&self) -> bool {
        matches!(self, DType::Complex64 | DType::Complex128)
    }

    /// Returns `true` only for `Bool`.
    pub fn is_bool(&self) -> bool {
        matches!(self, DType::Bool)
    }

    /// Maps this type to a floating (or complex) counterpart chosen by byte width.
    ///
    /// Every non-complex type of at most 4 bytes becomes `Float32` (this includes
    /// `Float16` and `BFloat16`, which are widened), every 8-byte non-complex type
    /// becomes `Float64`, and complex types are returned unchanged.
    pub fn to_float(&self) -> DType {
        match self {
            DType::Int8 | DType::UInt8 | DType::Bool => DType::Float32,
            DType::Int16 | DType::UInt16 | DType::Float16 | DType::BFloat16 => DType::Float32,
            DType::Int32 | DType::UInt32 | DType::Float32 => DType::Float32,
            DType::Int64 | DType::UInt64 | DType::Float64 => DType::Float64,
            DType::Complex64 => DType::Complex64,
            DType::Complex128 => DType::Complex128,
        }
    }

    /// Maps this type to a signed integer counterpart chosen by byte width.
    ///
    /// Every non-complex type of at most 4 bytes becomes `Int32`; 8-byte types and
    /// both complex types become `Int64`. Unsigned inputs therefore map to a
    /// signed result.
    pub fn to_int(&self) -> DType {
        match self {
            DType::Int8 | DType::UInt8 | DType::Bool => DType::Int32,
            DType::Int16 | DType::UInt16 | DType::Float16 | DType::BFloat16 => DType::Int32,
            DType::Int32 | DType::UInt32 | DType::Float32 => DType::Int32,
            DType::Int64 | DType::UInt64 | DType::Float64 => DType::Int64,
            DType::Complex64 | DType::Complex128 => DType::Int64,
        }
    }

    /// Returns the most negative finite value of this type as an `f64`.
    ///
    /// Unsigned integers and `Bool` report `0.0`. Complex types have no ordering
    /// and yield `None`.
    pub fn min_value(&self) -> Option<f64> {
        match self {
            DType::Int8 => Some(f64::from(i8::MIN)),
            DType::Int16 => Some(f64::from(i16::MIN)),
            DType::Int32 => Some(f64::from(i32::MIN)),
            // -2^63 is a power of two, so this cast is exact.
            DType::Int64 => Some(i64::MIN as f64),
            DType::UInt8 | DType::UInt16 | DType::UInt32 | DType::UInt64 | DType::Bool => {
                Some(0.0)
            }
            DType::Float16 => Some(-65504.0),
            DType::BFloat16 => Some(-Self::bfloat16_max()),
            DType::Float32 => Some(f64::from(f32::MIN)),
            DType::Float64 => Some(f64::MIN),
            DType::Complex64 | DType::Complex128 => None,
        }
    }

    /// Returns the largest finite value of this type as an `f64`.
    ///
    /// `Bool` reports `1.0`. Complex types yield `None`. The 64-bit integer maxima
    /// are not representable in an `f64`; the cast rounds them up to `2^63` and
    /// `2^64` respectively.
    pub fn max_value(&self) -> Option<f64> {
        match self {
            DType::Int8 => Some(f64::from(i8::MAX)),
            DType::UInt8 => Some(f64::from(u8::MAX)),
            DType::Int16 => Some(f64::from(i16::MAX)),
            DType::UInt16 => Some(f64::from(u16::MAX)),
            DType::Int32 => Some(f64::from(i32::MAX)),
            DType::UInt32 => Some(f64::from(u32::MAX)),
            // Lossy by necessity: see the note in the doc comment.
            DType::Int64 => Some(i64::MAX as f64),
            DType::UInt64 => Some(u64::MAX as f64),
            DType::Float16 => Some(65504.0),
            DType::BFloat16 => Some(Self::bfloat16_max()),
            DType::Float32 => Some(f64::from(f32::MAX)),
            DType::Float64 => Some(f64::MAX),
            DType::Bool => Some(1.0),
            DType::Complex64 | DType::Complex128 => None,
        }
    }

    /// Reports whether values of `self` and `other` may be combined.
    ///
    /// Integers, floats and `Bool` are all mutually compatible; complex types are
    /// compatible only with other complex types.
    pub fn is_compatible_with(&self, other: &DType) -> bool {
        // Every variant is exactly one of int / float / bool / complex, so the
        // rule collapses to "both complex or both non-complex".
        self.is_complex() == other.is_complex()
    }

    /// Returns the type both operands should be converted to before combining them.
    ///
    /// The categories are ordered `Bool` < integer < float < complex and the
    /// higher category wins:
    ///
    /// * complex: `Complex128` if either side is `Complex128` or `Float64`
    ///   (an 8-byte `Complex64` cannot carry a `Float64` without loss),
    ///   otherwise `Complex64`;
    /// * float: the wider float; `Float16` mixed with `BFloat16` widens to
    ///   `Float32` because neither 2-byte format can represent the other;
    ///   integers and `Bool` never widen a float;
    /// * `Bool` with an integer: the integer type;
    /// * two integers of the same signedness: the wider one;
    /// * a signed and an unsigned integer: the signed type if it is strictly
    ///   wider, otherwise the next wider signed type. `UInt64` mixed with any
    ///   signed integer has no integer container here and yields `Float64`.
    ///
    /// The operation is symmetric, and `x.common_dtype(&x) == x`.
    pub fn common_dtype(&self, other: &DType) -> DType {
        let (a, b) = (*self, *other);
        if a == b {
            return a;
        }

        if a.is_complex() || b.is_complex() {
            let forces_wide = |d: DType| matches!(d, DType::Complex128 | DType::Float64);
            return if forces_wide(a) || forces_wide(b) {
                DType::Complex128
            } else {
                DType::Complex64
            };
        }

        if a.is_float() || b.is_float() {
            return match (a, b) {
                (DType::Float64, _) | (_, DType::Float64) => DType::Float64,
                (DType::Float32, _) | (_, DType::Float32) => DType::Float32,
                (DType::Float16, DType::BFloat16) | (DType::BFloat16, DType::Float16) => {
                    DType::Float32
                }
                (DType::BFloat16, _) | (_, DType::BFloat16) => DType::BFloat16,
                // At least one side is a float and every other float was ruled
                // out above, so this is Float16 paired with an integer or Bool.
                _ => DType::Float16,
            };
        }

        // Only Bool and integers remain; Bool adopts the other operand's type.
        if a.is_bool() {
            return b;
        }
        if b.is_bool() {
            return a;
        }

        // Two distinct integer types.
        if a.is_signed_int() == b.is_signed_int() {
            return if a.size() > b.size() { a } else { b };
        }
        let (signed, unsigned) = if a.is_signed_int() { (a, b) } else { (b, a) };
        if signed.size() > unsigned.size() {
            return signed;
        }
        match unsigned {
            DType::UInt8 => DType::Int16,
            DType::UInt16 => DType::Int32,
            DType::UInt32 => DType::Int64,
            // UInt64: no wider signed integer exists in this enum.
            _ => DType::Float64,
        }
    }

    /// Parses a type name, ignoring ASCII case.
    ///
    /// Accepts the canonical lowercase names produced by `Display`
    /// (`"int32"`, `"bfloat16"`, ...) plus short aliases such as `"i32"`,
    /// `"f32"`, `"half"`, `"double"`, `"long"` and `"boolean"`.
    ///
    /// # Errors
    ///
    /// Returns `Err("Unknown dtype: <s>")`, echoing the input as given, when the
    /// name is not recognised.
    pub fn parse_dtype(s: &str) -> Result<DType, String> {
        match s.to_ascii_lowercase().as_str() {
            "int8" | "i8" => Ok(DType::Int8),
            "uint8" | "u8" => Ok(DType::UInt8),
            "int16" | "i16" => Ok(DType::Int16),
            "uint16" | "u16" => Ok(DType::UInt16),
            "int32" | "i32" | "int" => Ok(DType::Int32),
            "uint32" | "u32" | "uint" => Ok(DType::UInt32),
            "int64" | "i64" | "long" => Ok(DType::Int64),
            "uint64" | "u64" | "ulong" => Ok(DType::UInt64),
            "float16" | "f16" | "half" => Ok(DType::Float16),
            "bfloat16" | "bf16" => Ok(DType::BFloat16),
            "float32" | "f32" | "float" => Ok(DType::Float32),
            "float64" | "f64" | "double" => Ok(DType::Float64),
            "bool" | "boolean" => Ok(DType::Bool),
            "complex64" | "c64" => Ok(DType::Complex64),
            "complex128" | "c128" => Ok(DType::Complex128),
            _ => Err(format!("Unknown dtype: {}", s)),
        }
    }
}
impl fmt::Display for DType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = match self {
DType::Int8 => "int8",
DType::UInt8 => "uint8",
DType::Int16 => "int16",
DType::UInt16 => "uint16",
DType::Int32 => "int32",
DType::UInt32 => "uint32",
DType::Int64 => "int64",
DType::UInt64 => "uint64",
DType::Float16 => "float16",
DType::BFloat16 => "bfloat16",
DType::Float32 => "float32",
DType::Float64 => "float64",
DType::Bool => "bool",
DType::Complex64 => "complex64",
DType::Complex128 => "complex128",
};
write!(f, "{}", name)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `size` reports the expected byte width for a sample of every width class.
    #[test]
    fn test_dtype_size() {
        let expected_bytes: &[(DType, usize)] = &[
            (DType::Int8, 1),
            (DType::UInt8, 1),
            (DType::Bool, 1),
            (DType::Int16, 2),
            (DType::Float16, 2),
            (DType::Int32, 4),
            (DType::Float32, 4),
            (DType::Int64, 8),
            (DType::Float64, 8),
            (DType::Complex64, 8),
            (DType::Complex128, 16),
        ];
        for &(dtype, bytes) in expected_bytes {
            assert_eq!(dtype.size(), bytes, "size of {:?}", dtype);
        }
    }

    /// Each category predicate accepts its own members and rejects an outsider.
    #[test]
    fn test_dtype_properties() {
        for &(dtype, expected) in &[
            (DType::Float32, true),
            (DType::Float64, true),
            (DType::Int32, false),
        ] {
            assert_eq!(dtype.is_float(), expected, "{:?}.is_float()", dtype);
        }
        for &(dtype, expected) in &[
            (DType::Int32, true),
            (DType::UInt32, true),
            (DType::Float32, false),
        ] {
            assert_eq!(dtype.is_int(), expected, "{:?}.is_int()", dtype);
        }
        for &(dtype, expected) in &[(DType::Int32, true), (DType::UInt32, false)] {
            assert_eq!(dtype.is_signed_int(), expected, "{:?}.is_signed_int()", dtype);
        }
        for &(dtype, expected) in &[(DType::UInt32, true), (DType::Int32, false)] {
            assert_eq!(
                dtype.is_unsigned_int(),
                expected,
                "{:?}.is_unsigned_int()",
                dtype
            );
        }
        for &(dtype, expected) in &[(DType::Complex64, true), (DType::Float32, false)] {
            assert_eq!(dtype.is_complex(), expected, "{:?}.is_complex()", dtype);
        }
        for &(dtype, expected) in &[(DType::Bool, true), (DType::Int32, false)] {
            assert_eq!(dtype.is_bool(), expected, "{:?}.is_bool()", dtype);
        }
    }

    /// `to_float` and `to_int` pick the counterpart of matching width.
    #[test]
    fn test_dtype_conversion() {
        let float_targets: &[(DType, DType)] = &[
            (DType::Int32, DType::Float32),
            (DType::Int64, DType::Float64),
            (DType::Float32, DType::Float32),
        ];
        for &(input, target) in float_targets {
            assert_eq!(input.to_float(), target, "{:?}.to_float()", input);
        }

        let int_targets: &[(DType, DType)] = &[
            (DType::Float32, DType::Int32),
            (DType::Float64, DType::Int64),
            (DType::Int32, DType::Int32),
        ];
        for &(input, target) in int_targets {
            assert_eq!(input.to_int(), target, "{:?}.to_int()", input);
        }
    }

    /// Value ranges for a signed, an unsigned, the boolean and a complex type.
    #[test]
    fn test_dtype_min_max() {
        let ranges: &[(DType, Option<f64>, Option<f64>)] = &[
            (DType::Int8, Some(i8::MIN as f64), Some(i8::MAX as f64)),
            (DType::UInt8, Some(0.0), Some(u8::MAX as f64)),
            (DType::Bool, Some(0.0), Some(1.0)),
            (DType::Complex64, None, None),
        ];
        for &(dtype, lowest, highest) in ranges {
            assert_eq!(dtype.min_value(), lowest, "{:?}.min_value()", dtype);
            assert_eq!(dtype.max_value(), highest, "{:?}.max_value()", dtype);
        }
    }

    /// Real-valued types mix freely; complex only mixes with complex.
    #[test]
    fn test_dtype_compatibility() {
        let cases: &[(DType, DType, bool)] = &[
            (DType::Int32, DType::Int32, true),
            (DType::Int32, DType::Float32, true),
            (DType::Float32, DType::Int32, true),
            (DType::Bool, DType::Int32, true),
            (DType::Complex64, DType::Complex128, true),
            (DType::Complex64, DType::Float32, false),
        ];
        for &(lhs, rhs, expected) in cases {
            assert_eq!(
                lhs.is_compatible_with(&rhs),
                expected,
                "{:?} vs {:?}",
                lhs,
                rhs
            );
        }
    }

    /// `common_dtype` promotes towards the higher-ranked operand.
    #[test]
    fn test_dtype_promotion() {
        let cases: &[(DType, DType, DType)] = &[
            (DType::Int32, DType::Int32, DType::Int32),
            (DType::Int32, DType::Float32, DType::Float32),
            (DType::Float32, DType::Float64, DType::Float64),
            (DType::Int32, DType::Complex64, DType::Complex64),
            (DType::Bool, DType::Int32, DType::Int32),
        ];
        for &(lhs, rhs, promoted) in cases {
            assert_eq!(lhs.common_dtype(&rhs), promoted, "{:?} + {:?}", lhs, rhs);
        }
    }

    /// Canonical names and short aliases parse; unknown names are rejected.
    #[test]
    fn test_dtype_from_string() {
        let accepted: &[(&str, DType)] = &[
            ("int32", DType::Int32),
            ("float32", DType::Float32),
            ("bool", DType::Bool),
            ("f32", DType::Float32),
            ("i32", DType::Int32),
        ];
        for &(text, expected) in accepted {
            assert_eq!(DType::parse_dtype(text).unwrap(), expected, "parsing {:?}", text);
        }
        assert!(DType::parse_dtype("invalid").is_err());
    }

    /// `Display` emits the canonical lowercase name.
    #[test]
    fn test_dtype_display() {
        let rendered: &[(DType, &str)] = &[
            (DType::Int32, "int32"),
            (DType::Float32, "float32"),
            (DType::Bool, "bool"),
            (DType::Complex64, "complex64"),
        ];
        for &(dtype, text) in rendered {
            assert_eq!(dtype.to_string(), text);
        }
    }

    /// The derived `Default` is `Float32`.
    #[test]
    fn test_dtype_default() {
        let fallback: DType = Default::default();
        assert_eq!(fallback, DType::Float32);
    }
}