1#![allow(clippy::redundant_closure_call)]
3use crate::backend::BackendStorage;
4use crate::{CpuStorage, CpuStorageRef, Error, Result};
5
/// Element type tag for tensor data.
///
/// Keep the variant order stable: the derived `Hash` is computed from the
/// discriminants, so reordering variants changes hash values.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DType {
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 32-bit integer.
    U32,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// 16-bit brain float.
    BF16,
    /// 16-bit half-precision float.
    F16,
    /// 32-bit float.
    F32,
    /// 64-bit float.
    F64,
    /// 8-bit float, E4M3 layout.
    F8E4M3,
    /// 6-bit float (E2M3 layout by its name); `size_in_bytes` reports 0.
    F6E2M3,
    /// 6-bit float (E3M2 layout by its name); `size_in_bytes` reports 0.
    F6E3M2,
    /// 4-bit float; `size_in_bytes` reports 0.
    F4,
    /// 8-bit E8M0 value (by its name an exponent-only format — confirm); 1 byte.
    F8E8M0,
}

/// Error produced when a string is not one of the names accepted by
/// `str::parse::<DType>()`. Carries the rejected input unchanged.
#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String);

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(f, "cannot parse '{input}' as a dtype")
    }
}

impl std::error::Error for DTypeParseError {}

/// Parses the lowercase names returned by [`DType::as_str`]. Matching is exact
/// and case-sensitive; anything else yields a [`DTypeParseError`].
impl std::str::FromStr for DType {
    type Err = DTypeParseError;

    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let dtype = match s {
            "u8" => DType::U8,
            "u32" => DType::U32,
            "i16" => DType::I16,
            "i32" => DType::I32,
            "i64" => DType::I64,
            "bf16" => DType::BF16,
            "f16" => DType::F16,
            "f32" => DType::F32,
            "f64" => DType::F64,
            "f8e4m3" => DType::F8E4M3,
            "f6e2m3" => DType::F6E2M3,
            "f6e3m2" => DType::F6E3M2,
            "f4" => DType::F4,
            "f8e8m0" => DType::F8E8M0,
            unknown => return Err(DTypeParseError(unknown.to_owned())),
        };
        Ok(dtype)
    }
}

impl DType {
    /// Canonical lowercase name of this dtype; the inverse of `str::parse::<DType>()`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            DType::U8 => "u8",
            DType::U32 => "u32",
            DType::I16 => "i16",
            DType::I32 => "i32",
            DType::I64 => "i64",
            DType::BF16 => "bf16",
            DType::F16 => "f16",
            DType::F32 => "f32",
            DType::F64 => "f64",
            DType::F8E4M3 => "f8e4m3",
            DType::F6E2M3 => "f6e2m3",
            DType::F6E3M2 => "f6e3m2",
            DType::F4 => "f4",
            DType::F8E8M0 => "f8e8m0",
        }
    }

    /// Size of a single element in bytes.
    ///
    /// `F6E2M3`, `F6E3M2` and `F4` return 0: by their names they are 6- and
    /// 4-bit formats, which have no whole-byte per-element size. Callers that
    /// size buffers from this value must special-case them.
    pub fn size_in_bytes(&self) -> usize {
        match *self {
            DType::U8 | DType::F8E4M3 | DType::F8E8M0 => 1,
            DType::I16 | DType::BF16 | DType::F16 => 2,
            DType::U32 | DType::I32 | DType::F32 => 4,
            DType::I64 | DType::F64 => 8,
            DType::F6E2M3 | DType::F6E3M2 | DType::F4 => 0,
        }
    }

    /// `true` for the signed and unsigned integer dtypes.
    pub fn is_int(&self) -> bool {
        match *self {
            DType::BF16
            | DType::F16
            | DType::F32
            | DType::F64
            | DType::F8E4M3
            | DType::F6E2M3
            | DType::F6E3M2
            | DType::F4
            | DType::F8E8M0 => false,
            DType::U8 | DType::U32 | DType::I16 | DType::I32 | DType::I64 => true,
        }
    }

    /// `true` for every floating-point dtype. Every variant is either integer
    /// or float, so this is exactly the complement of [`DType::is_int`].
    pub fn is_float(&self) -> bool {
        !self.is_int()
    }
}
144
145pub trait WithDType:
146 Sized
147 + Copy
148 + num_traits::NumAssign
149 + std::cmp::PartialOrd
150 + std::fmt::Display
151 + 'static
152 + Send
153 + Sync
154 + std::any::Any
155 + crate::cpu::kernels::VecOps
156{
157 const DTYPE: DType;
158
159 fn from_f64(v: f64) -> Self;
160 fn to_f64(self) -> f64;
161 fn to_scalar(self) -> crate::scalar::Scalar;
162 fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
163 fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
164
165 fn to_cpu_storage(data: &[Self]) -> CpuStorage {
166 Self::to_cpu_storage_owned(data.to_vec())
167 }
168
169 fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
170 fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
171}
172
/// Implements [`WithDType`] for `$ty`, tagging it as `DType::$dtype`.
///
/// `$dtype` must name the same variant in `DType`, `Scalar`, `CpuStorage` and
/// `CpuStorageRef`. `$from_f64` / `$to_f64` are callables (function paths or
/// closure literals) invoked directly on the argument. Invoking closure literals
/// in place is presumably why the file allows `clippy::redundant_closure_call`.
macro_rules! with_dtype {
    ($ty:ty, $dtype:ident, $from_f64:expr, $to_f64:expr) => {
        impl WithDType for $ty {
            const DTYPE: DType = DType::$dtype;

            fn from_f64(v: f64) -> Self {
                $from_f64(v)
            }

            fn to_f64(self) -> f64 {
                $to_f64(self)
            }

            fn to_scalar(self) -> crate::scalar::Scalar {
                crate::scalar::Scalar::$dtype(self)
            }

            fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
                CpuStorageRef::$dtype(data)
            }

            fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                CpuStorage::$dtype(data)
            }

            fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
                match s {
                    CpuStorage::$dtype(data) => Ok(data),
                    // Any other variant: report which dtype was actually stored.
                    other => Err(Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: other.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt()),
                }
            }

            fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
                match s {
                    CpuStorage::$dtype(data) => Ok(data),
                    // Any other variant: report which dtype was actually stored.
                    other => Err(Error::UnexpectedDType {
                        expected: DType::$dtype,
                        got: other.dtype(),
                        msg: "unexpected dtype",
                    }
                    .bt()),
                }
            }
        }
    };
}
224use float8::F8E4M3 as f8e4m3;
225use half::{bf16, f16};
226
227with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
228with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
229with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64);
230with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64);
231with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
232with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
233with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
234with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
235with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
236with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64());
237
238pub trait IntDType: WithDType + num_traits::Bounded {
239 fn is_true(&self) -> bool;
240 fn as_usize(&self) -> usize;
241}
242
243impl IntDType for i64 {
244 fn is_true(&self) -> bool {
245 *self != 0
246 }
247 fn as_usize(&self) -> usize {
248 *self as usize
249 }
250}
251
252impl IntDType for u32 {
253 fn is_true(&self) -> bool {
254 *self != 0
255 }
256 fn as_usize(&self) -> usize {
257 *self as usize
258 }
259}
260
261impl IntDType for u8 {
262 fn is_true(&self) -> bool {
263 *self != 0
264 }
265 fn as_usize(&self) -> usize {
266 *self as usize
267 }
268}
269
270impl IntDType for i16 {
271 fn is_true(&self) -> bool {
272 *self != 0
273 }
274 fn as_usize(&self) -> usize {
275 *self as usize
276 }
277}
278
279impl IntDType for i32 {
280 fn is_true(&self) -> bool {
281 *self != 0
282 }
283 fn as_usize(&self) -> usize {
284 *self as usize
285 }
286}
287
288pub trait FloatDType: WithDType {}
289
290impl FloatDType for f16 {}
291impl FloatDType for bf16 {}
292impl FloatDType for f32 {}
293impl FloatDType for f64 {}
294impl FloatDType for f8e4m3 {}