Skip to main content

candle_core/
dtype.rs

1//! Types for elements that can be stored and manipulated using tensors.
2#![allow(clippy::redundant_closure_call)]
3use crate::backend::BackendStorage;
4use crate::{CpuStorage, CpuStorageRef, Error, Result};
5
/// The different types of elements allowed in tensors.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DType {
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Brain floating-point using half precision (16 bits).
    BF16,
    /// Floating-point using half precision (16 bits).
    F16,
    /// Floating-point using single precision (32 bits).
    F32,
    /// Floating-point using double precision (64 bits).
    F64,
    /// 8-bit floating point with 4-bit exponent and 3-bit mantissa.
    F8E4M3,
    /// 6-bit float with 2 exponent bits and 3 mantissa bits (MX6 format).
    F6E2M3,
    /// 6-bit float with 3 exponent bits and 2 mantissa bits (MX6 format).
    F6E3M2,
    /// 4-bit float (MX4 format).
    F4,
    /// 8-bit float with 8 exponent bits and 0 mantissa bits.
    F8E8M0,
}
39
/// Error produced when a string does not name any known dtype.
///
/// The rejected input is kept so that it can be echoed back in the message.
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String);

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(f, "cannot parse '{}' as a dtype", input)
    }
}

// Marker impl so the error can be boxed as `dyn std::error::Error`.
impl std::error::Error for DTypeParseError {}
50
51impl std::str::FromStr for DType {
52    type Err = DTypeParseError;
53    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
54        match s {
55            "u8" => Ok(Self::U8),
56            "u32" => Ok(Self::U32),
57            "i16" => Ok(Self::I16),
58            "i32" => Ok(Self::I32),
59            "i64" => Ok(Self::I64),
60            "bf16" => Ok(Self::BF16),
61            "f16" => Ok(Self::F16),
62            "f32" => Ok(Self::F32),
63            "f64" => Ok(Self::F64),
64            "f8e4m3" => Ok(Self::F8E4M3),
65            "f6e2m3" => Ok(Self::F6E2M3),
66            "f6e3m2" => Ok(Self::F6E3M2),
67            "f4" => Ok(Self::F4),
68            "f8e8m0" => Ok(Self::F8E8M0),
69            _ => Err(DTypeParseError(s.to_string())),
70        }
71    }
72}
73
74impl DType {
75    /// String representation for dtypes.
76    pub fn as_str(&self) -> &'static str {
77        match self {
78            Self::U8 => "u8",
79            Self::U32 => "u32",
80            Self::I16 => "i16",
81            Self::I32 => "i32",
82            Self::I64 => "i64",
83            Self::BF16 => "bf16",
84            Self::F16 => "f16",
85            Self::F32 => "f32",
86            Self::F64 => "f64",
87            Self::F8E4M3 => "f8e4m3",
88            Self::F6E2M3 => "f6e2m3",
89            Self::F6E3M2 => "f6e3m2",
90            Self::F4 => "f4",
91            Self::F8E8M0 => "f8e8m0",
92        }
93    }
94
95    /// The size used by each element in bytes, i.e. 1 for `U8`, 4 for `F32`.
96    pub fn size_in_bytes(&self) -> usize {
97        match self {
98            Self::U8 => 1,
99            Self::U32 => 4,
100            Self::I16 => 2,
101            Self::I32 => 4,
102            Self::I64 => 8,
103            Self::BF16 => 2,
104            Self::F16 => 2,
105            Self::F32 => 4,
106            Self::F64 => 8,
107            Self::F8E4M3 => 1,
108            Self::F6E2M3 => 0, // 6 bits
109            Self::F6E3M2 => 0, // 6 bits
110            Self::F4 => 0,     // 4 bits
111            Self::F8E8M0 => 1,
112        }
113    }
114
115    pub fn is_int(&self) -> bool {
116        match self {
117            Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true,
118            Self::BF16
119            | Self::F16
120            | Self::F32
121            | Self::F64
122            | Self::F8E4M3
123            | Self::F6E2M3
124            | Self::F6E3M2
125            | Self::F4
126            | Self::F8E8M0 => false,
127        }
128    }
129
130    pub fn is_float(&self) -> bool {
131        match self {
132            Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false,
133            Self::BF16
134            | Self::F16
135            | Self::F32
136            | Self::F64
137            | Self::F8E4M3
138            | Self::F6E2M3
139            | Self::F6E3M2
140            | Self::F4
141            | Self::F8E8M0 => true,
142        }
143    }
144}
145
146pub trait WithDType:
147    Sized
148    + Copy
149    + num_traits::NumAssign
150    + std::cmp::PartialOrd
151    + std::fmt::Display
152    + 'static
153    + Send
154    + Sync
155    + std::any::Any
156    + crate::cpu::kernels::VecOps
157{
158    const DTYPE: DType;
159
160    fn from_f64(v: f64) -> Self;
161    fn to_f64(self) -> f64;
162    fn to_scalar(self) -> crate::scalar::Scalar;
163    fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
164    fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
165
166    fn to_cpu_storage(data: &[Self]) -> CpuStorage {
167        Self::to_cpu_storage_owned(data.to_vec())
168    }
169
170    fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
171    fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
172}
173
// Implements `WithDType` for one concrete Rust scalar type.
//
// * `$ty`       - the Rust element type.
// * `$dtype`    - the variant name shared by `DType`, `Scalar`, `CpuStorage`
//                 and `CpuStorageRef` for this element type.
// * `$from_f64` - callable converting an `f64` into `$ty`.
// * `$to_f64`   - callable converting a `$ty` into an `f64`.
macro_rules! with_dtype {
    ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
        impl WithDType for $ty {
            const DTYPE: DType = DType::$dtype;

            fn from_f64(v: f64) -> Self {
                $from_f64(v)
            }

            fn to_f64(self) -> f64 {
                $to_f64(self)
            }

            fn to_scalar(self) -> crate::scalar::Scalar {
                crate::scalar::Scalar::$dtype(self)
            }

            fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
                CpuStorageRef::$dtype(data)
            }

            fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                CpuStorage::$dtype(data)
            }

            // Borrowing accessor: fails when the storage holds another dtype.
            fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
                if let CpuStorage::$dtype(data) = s {
                    Ok(data)
                } else {
                    Err(Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: s.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt())
                }
            }

            // Consuming accessor: the `else` branch never binds, so `s` is
            // still available there to report the actual dtype.
            fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
                if let CpuStorage::$dtype(data) = s {
                    Ok(data)
                } else {
                    Err(Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: s.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt())
                }
            }
        }
    };
}
225use float8::F8E4M3 as f8e4m3;
226use half::{bf16, f16};
227
228with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
229with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
230with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64);
231with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64);
232with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
233with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
234with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
235with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
236with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
237with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64());
238
239pub trait IntDType: WithDType + num_traits::Bounded {
240    fn is_true(&self) -> bool;
241    fn as_usize(&self) -> usize;
242}
243
244impl IntDType for i64 {
245    fn is_true(&self) -> bool {
246        *self != 0
247    }
248    fn as_usize(&self) -> usize {
249        *self as usize
250    }
251}
252
253impl IntDType for u32 {
254    fn is_true(&self) -> bool {
255        *self != 0
256    }
257    fn as_usize(&self) -> usize {
258        *self as usize
259    }
260}
261
262impl IntDType for u8 {
263    fn is_true(&self) -> bool {
264        *self != 0
265    }
266    fn as_usize(&self) -> usize {
267        *self as usize
268    }
269}
270
271impl IntDType for i16 {
272    fn is_true(&self) -> bool {
273        *self != 0
274    }
275    fn as_usize(&self) -> usize {
276        *self as usize
277    }
278}
279
280impl IntDType for i32 {
281    fn is_true(&self) -> bool {
282        *self != 0
283    }
284    fn as_usize(&self) -> usize {
285        *self as usize
286    }
287}
288
289pub trait FloatDType: WithDType {}
290
291impl FloatDType for f16 {}
292impl FloatDType for bf16 {}
293impl FloatDType for f32 {}
294impl FloatDType for f64 {}
295impl FloatDType for f8e4m3 {}