hanzo_ml/
dtype.rs

1//! Types for elements that can be stored and manipulated using tensors.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::BackendStorage;
4use crate::{CpuStorage, CpuStorageRef, Error, Result};
5
/// The different types of elements allowed in tensors.
///
/// The sub-byte formats (`F6E2M3`, `F6E3M2`, `F4`) report a `size_in_bytes` of 0.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
    /// Unsigned 8 bits integer.
    U8,
    /// Unsigned 32 bits integer.
    U32,
    /// Signed 16 bits integer.
    I16,
    /// Signed 32 bits integer.
    I32,
    /// Signed 64 bits integer.
    I64,
    /// Brain floating-point using half precision (16 bits).
    BF16,
    /// Floating-point using half precision (16 bits).
    F16,
    /// Floating-point using single precision (32 bits).
    F32,
    /// Floating-point using double precision (64 bits).
    F64,
    /// 8-bit floating point with 4-bit exponent and 3-bit mantissa.
    F8E4M3,
    /// 6-bit float with 2 exponent bits and 3 mantissa bits (MX6 format).
    F6E2M3,
    /// 6-bit float with 3 exponent bits and 2 mantissa bits (MX6 format).
    F6E3M2,
    /// 4-bit float (MX4 format).
    F4,
    /// 8-bit float with 8 exponent bits and 0 mantissa bits.
    F8E8M0,
}
38
/// Error produced when a string does not name a known dtype.
///
/// Holds the rejected input so it can be echoed back in the error message.
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String);

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(f, "cannot parse '{}' as a dtype", input)
    }
}

impl std::error::Error for DTypeParseError {}
49
50impl std::str::FromStr for DType {
51    type Err = DTypeParseError;
52    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
53        match s {
54            "u8" => Ok(Self::U8),
55            "u32" => Ok(Self::U32),
56            "i16" => Ok(Self::I16),
57            "i32" => Ok(Self::I32),
58            "i64" => Ok(Self::I64),
59            "bf16" => Ok(Self::BF16),
60            "f16" => Ok(Self::F16),
61            "f32" => Ok(Self::F32),
62            "f64" => Ok(Self::F64),
63            "f8e4m3" => Ok(Self::F8E4M3),
64            "f6e2m3" => Ok(Self::F6E2M3),
65            "f6e3m2" => Ok(Self::F6E3M2),
66            "f4" => Ok(Self::F4),
67            "f8e8m0" => Ok(Self::F8E8M0),
68            _ => Err(DTypeParseError(s.to_string())),
69        }
70    }
71}
72
73impl DType {
74    /// String representation for dtypes.
75    pub fn as_str(&self) -> &'static str {
76        match self {
77            Self::U8 => "u8",
78            Self::U32 => "u32",
79            Self::I16 => "i16",
80            Self::I32 => "i32",
81            Self::I64 => "i64",
82            Self::BF16 => "bf16",
83            Self::F16 => "f16",
84            Self::F32 => "f32",
85            Self::F64 => "f64",
86            Self::F8E4M3 => "f8e4m3",
87            Self::F6E2M3 => "f6e2m3",
88            Self::F6E3M2 => "f6e3m2",
89            Self::F4 => "f4",
90            Self::F8E8M0 => "f8e8m0",
91        }
92    }
93
94    /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
95    pub fn size_in_bytes(&self) -> usize {
96        match self {
97            Self::U8 => 1,
98            Self::U32 => 4,
99            Self::I16 => 2,
100            Self::I32 => 4,
101            Self::I64 => 8,
102            Self::BF16 => 2,
103            Self::F16 => 2,
104            Self::F32 => 4,
105            Self::F64 => 8,
106            Self::F8E4M3 => 1,
107            Self::F6E2M3 => 0, // 6 bits
108            Self::F6E3M2 => 0, // 6 bits
109            Self::F4 => 0,     // 4 bits
110            Self::F8E8M0 => 1,
111        }
112    }
113
114    pub fn is_int(&self) -> bool {
115        match self {
116            Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true,
117            Self::BF16
118            | Self::F16
119            | Self::F32
120            | Self::F64
121            | Self::F8E4M3
122            | Self::F6E2M3
123            | Self::F6E3M2
124            | Self::F4
125            | Self::F8E8M0 => false,
126        }
127    }
128
129    pub fn is_float(&self) -> bool {
130        match self {
131            Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false,
132            Self::BF16
133            | Self::F16
134            | Self::F32
135            | Self::F64
136            | Self::F8E4M3
137            | Self::F6E2M3
138            | Self::F6E3M2
139            | Self::F4
140            | Self::F8E8M0 => true,
141        }
142    }
143}
144
145pub trait WithDType:
146    Sized
147    + Copy
148    + num_traits::NumAssign
149    + std::cmp::PartialOrd
150    + std::fmt::Display
151    + 'static
152    + Send
153    + Sync
154    + std::any::Any
155    + crate::cpu::kernels::VecOps
156{
157    const DTYPE: DType;
158
159    fn from_f64(v: f64) -> Self;
160    fn to_f64(self) -> f64;
161    fn to_scalar(self) -> crate::scalar::Scalar;
162    fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
163    fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
164
165    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
166        Self::to_cpu_storage_owned(data.to_vec())
167    }
168
169    fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
170    fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
171}
172
// Implements `WithDType` for one concrete element type.
//
// * `$ty`       – the Rust element type.
// * `$dtype`    – the variant name shared by `DType`, `Scalar`, `CpuStorage` and
//                 `CpuStorageRef` for this type.
// * `$from_f64` – a callable `f64 -> $ty`.
// * `$to_f64`   – a callable `$ty -> f64`.
//
// The conversion callables are invoked directly on the argument. Passing closures
// here is what `clippy::redundant_closure_call` is allowed for at the top of the file.
macro_rules! with_dtype {
    ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
        impl WithDType for $ty {
            const DTYPE: DType = DType::$dtype;

            fn from_f64(v: f64) -> Self {
                $from_f64(v)
            }

            fn to_f64(self) -> f64 {
                $to_f64(self)
            }

            fn to_scalar(self) -> crate::scalar::Scalar {
                crate::scalar::Scalar::$dtype(self)
            }

            fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
                CpuStorageRef::$dtype(data)
            }

            fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                CpuStorage::$dtype(data)
            }

            fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
                match s {
                    CpuStorage::$dtype(values) => Ok(values.as_slice()),
                    other => Err(Error::UnexpectedDType {
                        expected: Self::DTYPE,
                        got: other.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt()),
                }
            }

            fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
                match s {
                    CpuStorage::$dtype(values) => Ok(values),
                    other => Err(Error::UnexpectedDType {
                        expected: Self::DTYPE,
                        got: other.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt()),
                }
            }
        }
    };
}
224use float8::F8E4M3 as f8e4m3;
225use half::{bf16, f16};
226
227with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
228with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
229with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64);
230with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64);
231with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
232with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
233with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
234with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
235with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
236with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64());
237
238pub trait IntDType: WithDType + num_traits::Bounded {
239    fn is_true(&self) -> bool;
240    fn as_usize(&self) -> usize;
241}
242
243impl IntDType for i64 {
244    fn is_true(&self) -> bool {
245        *self != 0
246    }
247    fn as_usize(&self) -> usize {
248        *self as usize
249    }
250}
251
252impl IntDType for u32 {
253    fn is_true(&self) -> bool {
254        *self != 0
255    }
256    fn as_usize(&self) -> usize {
257        *self as usize
258    }
259}
260
261impl IntDType for u8 {
262    fn is_true(&self) -> bool {
263        *self != 0
264    }
265    fn as_usize(&self) -> usize {
266        *self as usize
267    }
268}
269
270impl IntDType for i16 {
271    fn is_true(&self) -> bool {
272        *self != 0
273    }
274    fn as_usize(&self) -> usize {
275        *self as usize
276    }
277}
278
279impl IntDType for i32 {
280    fn is_true(&self) -> bool {
281        *self != 0
282    }
283    fn as_usize(&self) -> usize {
284        *self as usize
285    }
286}
287
288pub trait FloatDType: WithDType {}
289
290impl FloatDType for f16 {}
291impl FloatDType for bf16 {}
292impl FloatDType for f32 {}
293impl FloatDType for f64 {}
294impl FloatDType for f8e4m3 {}