1#![allow(clippy::redundant_closure_call)]
3use crate::backend::BackendStorage;
4use crate::{CpuStorage, CpuStorageRef, Error, Result};
5
/// The element types known to this module.
///
/// Every variant has a canonical lowercase name (returned by `DType::as_str`),
/// and the `FromStr` implementation for `DType` accepts those same names.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum DType {
    /// Unsigned 8-bit integer (Rust `u8`).
    U8,
    /// Unsigned 32-bit integer (Rust `u32`).
    U32,
    /// Signed 16-bit integer (Rust `i16`).
    I16,
    /// Signed 32-bit integer (Rust `i32`).
    I32,
    /// Signed 64-bit integer (Rust `i64`).
    I64,
    /// 16-bit float backed by `half::bf16`.
    BF16,
    /// 16-bit float backed by `half::f16`.
    F16,
    /// 32-bit float (Rust `f32`).
    F32,
    /// 64-bit float (Rust `f64`).
    F64,
    /// 8-bit float backed by `float8::F8E4M3`.
    F8E4M3,
    /// Sub-byte float; the name suggests 6 bits with 2 exponent and 3 mantissa bits.
    F6E2M3,
    /// Sub-byte float; the name suggests 6 bits with 3 exponent and 2 mantissa bits.
    F6E3M2,
    /// Sub-byte float; the name suggests 4 bits.
    F4,
    /// One-byte float; the name suggests 8 exponent bits and no mantissa bits.
    F8E8M0,
}
39
/// Error returned when a string does not name any `DType`.
///
/// The wrapped `String` is the input that was rejected; the `Display`
/// implementation echoes it back inside the error message.
#[derive(Debug, Eq, PartialEq)]
pub struct DTypeParseError(String);

impl std::error::Error for DTypeParseError {}

impl std::fmt::Display for DTypeParseError {
    /// Writes `cannot parse '<input>' as a dtype`, where `<input>` is the rejected string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        write!(f, "cannot parse '{}' as a dtype", rejected)
    }
}
50
51impl std::str::FromStr for DType {
52 type Err = DTypeParseError;
53 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
54 match s {
55 "u8" => Ok(Self::U8),
56 "u32" => Ok(Self::U32),
57 "i16" => Ok(Self::I16),
58 "i32" => Ok(Self::I32),
59 "i64" => Ok(Self::I64),
60 "bf16" => Ok(Self::BF16),
61 "f16" => Ok(Self::F16),
62 "f32" => Ok(Self::F32),
63 "f64" => Ok(Self::F64),
64 "f8e4m3" => Ok(Self::F8E4M3),
65 "f6e2m3" => Ok(Self::F6E2M3),
66 "f6e3m2" => Ok(Self::F6E3M2),
67 "f4" => Ok(Self::F4),
68 "f8e8m0" => Ok(Self::F8E8M0),
69 _ => Err(DTypeParseError(s.to_string())),
70 }
71 }
72}
73
74impl DType {
75 pub fn as_str(&self) -> &'static str {
77 match self {
78 Self::U8 => "u8",
79 Self::U32 => "u32",
80 Self::I16 => "i16",
81 Self::I32 => "i32",
82 Self::I64 => "i64",
83 Self::BF16 => "bf16",
84 Self::F16 => "f16",
85 Self::F32 => "f32",
86 Self::F64 => "f64",
87 Self::F8E4M3 => "f8e4m3",
88 Self::F6E2M3 => "f6e2m3",
89 Self::F6E3M2 => "f6e3m2",
90 Self::F4 => "f4",
91 Self::F8E8M0 => "f8e8m0",
92 }
93 }
94
95 pub fn size_in_bytes(&self) -> usize {
97 match self {
98 Self::U8 => 1,
99 Self::U32 => 4,
100 Self::I16 => 2,
101 Self::I32 => 4,
102 Self::I64 => 8,
103 Self::BF16 => 2,
104 Self::F16 => 2,
105 Self::F32 => 4,
106 Self::F64 => 8,
107 Self::F8E4M3 => 1,
108 Self::F6E2M3 => 0, Self::F6E3M2 => 0, Self::F4 => 0, Self::F8E8M0 => 1,
112 }
113 }
114
115 pub fn is_int(&self) -> bool {
116 match self {
117 Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => true,
118 Self::BF16
119 | Self::F16
120 | Self::F32
121 | Self::F64
122 | Self::F8E4M3
123 | Self::F6E2M3
124 | Self::F6E3M2
125 | Self::F4
126 | Self::F8E8M0 => false,
127 }
128 }
129
130 pub fn is_float(&self) -> bool {
131 match self {
132 Self::U8 | Self::U32 | Self::I16 | Self::I32 | Self::I64 => false,
133 Self::BF16
134 | Self::F16
135 | Self::F32
136 | Self::F64
137 | Self::F8E4M3
138 | Self::F6E2M3
139 | Self::F6E3M2
140 | Self::F4
141 | Self::F8E8M0 => true,
142 }
143 }
144}
145
146pub trait WithDType:
147 Sized
148 + Copy
149 + num_traits::NumAssign
150 + std::cmp::PartialOrd
151 + std::fmt::Display
152 + 'static
153 + Send
154 + Sync
155 + std::any::Any
156 + crate::cpu::kernels::VecOps
157{
158 const DTYPE: DType;
159
160 fn from_f64(v: f64) -> Self;
161 fn to_f64(self) -> f64;
162 fn to_scalar(self) -> crate::scalar::Scalar;
163 fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
164 fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
165
166 fn to_cpu_storage(data: &[Self]) -> CpuStorage {
167 Self::to_cpu_storage_owned(data.to_vec())
168 }
169
170 fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>;
171 fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>>;
172}
173
// Implements `WithDType` for one Rust scalar type.
//
// Arguments, in order:
//   1. the Rust type to implement the trait for,
//   2. the variant name shared by `DType`, `Scalar`, `CpuStorage` and `CpuStorageRef`,
//   3. a callable expression converting `f64` into the type,
//   4. a callable expression converting the type into `f64`.
//
// The two conversion expressions are invoked directly (`$expr(value)`), so
// closures as well as function paths can be passed.
macro_rules! with_dtype {
    ($scalar:ty, $variant:ident, $convert_from:expr, $convert_to:expr) => {
        impl WithDType for $scalar {
            const DTYPE: DType = DType::$variant;

            fn from_f64(v: f64) -> Self {
                $convert_from(v)
            }

            fn to_f64(self) -> f64 {
                $convert_to(self)
            }

            fn to_scalar(self) -> crate::scalar::Scalar {
                crate::scalar::Scalar::$variant(self)
            }

            fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
                CpuStorageRef::$variant(data)
            }

            fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage {
                CpuStorage::$variant(data)
            }

            // Consumes the storage; any variant other than the expected one
            // is reported as a dtype mismatch.
            fn cpu_storage_data(s: CpuStorage) -> Result<Vec<Self>> {
                match s {
                    CpuStorage::$variant(data) => Ok(data),
                    other => {
                        let mismatch = Error::UnexpectedDType {
                            expected: Self::DTYPE,
                            got: other.dtype(),
                            msg: "unexpected dtype",
                        };
                        Err(mismatch.bt())
                    }
                }
            }

            // Borrowing counterpart of `cpu_storage_data`, with the same
            // mismatch error.
            fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]> {
                match s {
                    CpuStorage::$variant(data) => Ok(data),
                    other => {
                        let mismatch = Error::UnexpectedDType {
                            expected: Self::DTYPE,
                            got: other.dtype(),
                            msg: "unexpected dtype",
                        };
                        Err(mismatch.bt())
                    }
                }
            }
        }
    };
}
225use float8::F8E4M3 as f8e4m3;
226use half::{bf16, f16};
227
228with_dtype!(u8, U8, |v: f64| v as u8, |v: u8| v as f64);
229with_dtype!(u32, U32, |v: f64| v as u32, |v: u32| v as f64);
230with_dtype!(i16, I16, |v: f64| v as i16, |v: i16| v as f64);
231with_dtype!(i32, I32, |v: f64| v as i32, |v: i32| v as f64);
232with_dtype!(i64, I64, |v: f64| v as i64, |v: i64| v as f64);
233with_dtype!(f16, F16, f16::from_f64, f16::to_f64);
234with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
235with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
236with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
237with_dtype!(f8e4m3, F8E4M3, f8e4m3::from_f64, |v: f8e4m3| v.to_f64());
238
239pub trait IntDType: WithDType + num_traits::Bounded {
240 fn is_true(&self) -> bool;
241 fn as_usize(&self) -> usize;
242}
243
244impl IntDType for i64 {
245 fn is_true(&self) -> bool {
246 *self != 0
247 }
248 fn as_usize(&self) -> usize {
249 *self as usize
250 }
251}
252
253impl IntDType for u32 {
254 fn is_true(&self) -> bool {
255 *self != 0
256 }
257 fn as_usize(&self) -> usize {
258 *self as usize
259 }
260}
261
262impl IntDType for u8 {
263 fn is_true(&self) -> bool {
264 *self != 0
265 }
266 fn as_usize(&self) -> usize {
267 *self as usize
268 }
269}
270
271impl IntDType for i16 {
272 fn is_true(&self) -> bool {
273 *self != 0
274 }
275 fn as_usize(&self) -> usize {
276 *self as usize
277 }
278}
279
280impl IntDType for i32 {
281 fn is_true(&self) -> bool {
282 *self != 0
283 }
284 fn as_usize(&self) -> usize {
285 *self as usize
286 }
287}
288
289pub trait FloatDType: WithDType {}
290
291impl FloatDType for f16 {}
292impl FloatDType for bf16 {}
293impl FloatDType for f32 {}
294impl FloatDType for f64 {}
295impl FloatDType for f8e4m3 {}