1use std::fmt::Debug;
14
15pub trait AccelDtype: Copy + Send + Sync + 'static + Debug {
23 type Scalar: Copy + Send + Sync + 'static + Debug;
28
29 const KIND: DType;
31
32 const SIZE: usize;
34
35 const NAME: &'static str;
37
38 fn zero() -> Self;
39 fn one() -> Self;
40
41 fn nan() -> Option<Self>;
43}
44
/// Runtime tag for every element type known to this module.
///
/// Variant order is significant for the derived `Hash` and for `as` casts, so
/// new variants should be appended.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum DType {
    /// IEEE 754 single precision.
    F32,
    /// IEEE 754 double precision.
    F64,
    /// IEEE 754 half precision.
    F16,
    /// bfloat16 (8-bit exponent, 7-bit mantissa).
    Bf16,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// 8-bit float: 4 exponent bits, 3 mantissa bits.
    F8E4m3,
    /// 8-bit float: 5 exponent bits, 2 mantissa bits.
    F8E5m2,
    /// 4-bit float: 2 exponent bits, 1 mantissa bit.
    F4E2m1,
}
68
69impl DType {
70 pub const fn size_bytes(self) -> usize {
71 match self {
72 DType::F32 | DType::I32 | DType::U32 => 4,
73 DType::F64 | DType::I64 | DType::U64 => 8,
74 DType::F16 | DType::Bf16 | DType::I16 | DType::U16 => 2,
75 DType::I8 | DType::U8 | DType::F8E4m3 | DType::F8E5m2 | DType::F4E2m1 => 1,
76 }
77 }
78
79 pub const fn name(self) -> &'static str {
80 match self {
81 DType::F32 => "f32",
82 DType::F64 => "f64",
83 DType::F16 => "f16",
84 DType::Bf16 => "bf16",
85 DType::I8 => "i8",
86 DType::I16 => "i16",
87 DType::I32 => "i32",
88 DType::I64 => "i64",
89 DType::U8 => "u8",
90 DType::U16 => "u16",
91 DType::U32 => "u32",
92 DType::U64 => "u64",
93 DType::F8E4m3 => "f8_e4m3",
94 DType::F8E5m2 => "f8_e5m2",
95 DType::F4E2m1 => "f4_e2m1",
96 }
97 }
98
99 pub const fn is_float(self) -> bool {
100 matches!(
101 self,
102 DType::F32
103 | DType::F64
104 | DType::F16
105 | DType::Bf16
106 | DType::F8E4m3
107 | DType::F8E5m2
108 | DType::F4E2m1
109 )
110 }
111
112 pub const fn is_integer(self) -> bool {
113 matches!(
114 self,
115 DType::I8
116 | DType::I16
117 | DType::I32
118 | DType::I64
119 | DType::U8
120 | DType::U16
121 | DType::U32
122 | DType::U64
123 )
124 }
125
126 pub const fn is_signed(self) -> bool {
127 matches!(
128 self,
129 DType::I8
130 | DType::I16
131 | DType::I32
132 | DType::I64
133 | DType::F32
134 | DType::F64
135 | DType::F16
136 | DType::Bf16
137 | DType::F8E4m3
138 | DType::F8E5m2
139 | DType::F4E2m1
140 )
141 }
142}
143
/// Implements [`AccelDtype`] for a primitive integer type.
///
/// Arguments: the Rust type, its [`DType`] tag, and its short name. Integer
/// types have no NaN, so `nan()` yields `None`.
macro_rules! impl_accel_dtype_int {
    ($t:ty, $kind:expr, $name:literal) => {
        impl AccelDtype for $t {
            type Scalar = $t;
            const KIND: DType = $kind;
            const SIZE: usize = ::std::mem::size_of::<$t>();
            const NAME: &'static str = $name;

            #[inline]
            fn zero() -> $t {
                0
            }

            #[inline]
            fn one() -> $t {
                1
            }

            #[inline]
            fn nan() -> Option<$t> {
                Option::None
            }
        }
    };
}
167
/// Implements [`AccelDtype`] for a primitive floating-point type.
///
/// Arguments: the Rust type, its [`DType`] tag, and its short name. `nan()`
/// returns the type's own `NAN` constant.
macro_rules! impl_accel_dtype_float {
    ($t:ty, $kind:expr, $name:literal) => {
        impl AccelDtype for $t {
            type Scalar = $t;
            const KIND: DType = $kind;
            const SIZE: usize = ::std::mem::size_of::<$t>();
            const NAME: &'static str = $name;

            #[inline]
            fn zero() -> $t {
                0.0
            }

            #[inline]
            fn one() -> $t {
                1.0
            }

            #[inline]
            fn nan() -> Option<$t> {
                Option::Some(<$t>::NAN)
            }
        }
    };
}
191
192impl_accel_dtype_float!(f32, DType::F32, "f32");
193impl_accel_dtype_float!(f64, DType::F64, "f64");
194impl_accel_dtype_int!(i8, DType::I8, "i8");
195impl_accel_dtype_int!(i16, DType::I16, "i16");
196impl_accel_dtype_int!(i32, DType::I32, "i32");
197impl_accel_dtype_int!(i64, DType::I64, "i64");
198impl_accel_dtype_int!(u8, DType::U8, "u8");
199impl_accel_dtype_int!(u16, DType::U16, "u16");
200impl_accel_dtype_int!(u32, DType::U32, "u32");
201impl_accel_dtype_int!(u64, DType::U64, "u64");
202
/// IEEE half precision, backed by the `half` crate.
#[cfg(feature = "f16")]
impl AccelDtype for half::f16 {
    type Scalar = half::f16;
    const KIND: DType = DType::F16;
    const SIZE: usize = ::std::mem::size_of::<half::f16>();
    const NAME: &'static str = "f16";

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }
}
222
/// bfloat16, backed by the `half` crate.
#[cfg(feature = "f16")]
impl AccelDtype for half::bf16 {
    type Scalar = half::bf16;
    const KIND: DType = DType::Bf16;
    const SIZE: usize = ::std::mem::size_of::<half::bf16>();
    const NAME: &'static str = "bf16";

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }
}
242
/// 8-bit float with 1 sign, 4 exponent and 3 mantissa bits (E4M3), stored as raw bits.
///
/// `Default` is the all-zero bit pattern, i.e. +0.0.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct F8E4m3(pub u8);
249
#[cfg(feature = "f8")]
impl F8E4m3 {
    /// Positive zero (`0x00`).
    pub const ZERO: Self = F8E4m3(0x00);
    /// One: exponent field 7 (equal to the bias), mantissa 0.
    pub const ONE: Self = F8E4m3(0x38);
    /// NaN (`S.1111.111` with S = 0). The format has no infinities.
    pub const NAN: Self = F8E4m3(0x7f);

    /// Largest finite magnitude: exponent field 15, mantissa `110` = 1.75 * 2^8.
    const MAX_FINITE: f32 = 448.0;

    /// Converts an `f32` to E4M3, rounding to nearest with ties to even.
    ///
    /// Layout: bit 7 sign, bits 6..3 exponent (bias 7), bits 2..0 mantissa.
    /// Exponent 15 holds ordinary values except for mantissa `111`, which is NaN.
    /// NaN maps to [`F8E4m3::NAN`]. Because there is no infinity, magnitudes above
    /// 448 (including ±infinity) saturate to ±448 (`0x7e` / `0xfe`). The sign of
    /// zero is preserved.
    pub fn from_f32(x: f32) -> Self {
        if x.is_nan() {
            return Self::NAN;
        }
        let sign: u8 = if x.is_sign_negative() { 0x80 } else { 0x00 };
        let abs = x.abs().min(Self::MAX_FINITE);
        if abs == 0.0 {
            return F8E4m3(sign);
        }

        let bits = abs.to_bits();
        let f32_exp = ((bits >> 23) & 0xff) as i32 - 127;
        // 24-bit significand including the implicit leading one.
        let significand = (bits & 0x007f_ffff) | 0x0080_0000;
        let e4_exp = f32_exp + 7;

        if e4_exp <= 0 {
            // Subnormal target: value = m * 2^-9, so m = round(significand >> (21 - e4_exp)).
            let shift = (21 - e4_exp) as u32;
            if shift > 24 {
                // Below half the smallest subnormal (2^-10): rounds to zero. This also
                // covers f32 subnormal inputs and keeps the shift below 32.
                return F8E4m3(sign);
            }
            // A carry up to 8 lands exactly on the smallest normal (exponent 1, mantissa 0).
            let m = Self::round_shift_right(significand, shift);
            return F8E4m3(sign | m as u8);
        }

        // Normal target: keep the leading one plus 3 mantissa bits (a value in 8..=16).
        let kept = Self::round_shift_right(significand, 20);
        let (e, m) = if kept == 0x10 {
            (e4_exp + 1, 0)
        } else {
            (e4_exp, kept - 0x08)
        };
        // Exponent 15 is a valid binade; only mantissa 7 there is NaN. Anything at or
        // beyond that encoding saturates to the largest finite value.
        if e > 0x0f || (e == 0x0f && m == 0x07) {
            return F8E4m3(sign | 0x7e);
        }
        F8E4m3(sign | ((e as u8) << 3) | m as u8)
    }

    /// Converts to `f32`. Every finite E4M3 value is exactly representable.
    /// Both NaN encodings (`0x7f`, `0xff`) decode to `f32::NAN`.
    pub fn to_f32(self) -> f32 {
        let negative = (self.0 & 0x80) != 0;
        let exp = (self.0 >> 3) & 0x0f;
        let mant = self.0 & 0x07;
        if exp == 0x0f && mant == 0x07 {
            return f32::NAN;
        }
        let magnitude = if exp == 0 {
            // Zero or subnormal: 0.mmm * 2^-6 == mant * 2^-9 (exact in f32).
            f32::from(mant) / 512.0
        } else {
            // Normal: rebias the exponent from 7 to 127 and left-align the 3 mantissa bits.
            f32::from_bits(((u32::from(exp) + 120) << 23) | (u32::from(mant) << 20))
        };
        if negative {
            -magnitude
        } else {
            magnitude
        }
    }

    /// Shifts `value` right by `shift` bits (1..=31), rounding to nearest with ties to even.
    fn round_shift_right(value: u32, shift: u32) -> u32 {
        debug_assert!((1..32).contains(&shift));
        let quotient = value >> shift;
        let remainder = value & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if remainder > half || (remainder == half && (quotient & 1) == 1) {
            quotient + 1
        } else {
            quotient
        }
    }
}

#[cfg(all(test, feature = "f8"))]
mod f8e4m3_conversion_tests {
    use super::F8E4m3;

    #[test]
    fn every_finite_code_round_trips() {
        for byte in 0u8..=255 {
            let code = F8E4m3(byte);
            let value = code.to_f32();
            if value.is_nan() {
                continue;
            }
            assert_eq!(F8E4m3::from_f32(value), code, "code {:#04x} -> {}", byte, value);
        }
    }

    #[test]
    fn subnormals_and_top_binade_follow_the_format() {
        assert_eq!(F8E4m3(0x01).to_f32(), 1.0 / 512.0);
        assert_eq!(F8E4m3(0x04).to_f32(), 1.0 / 128.0);
        assert_eq!(F8E4m3(0x08).to_f32(), 1.0 / 64.0);
        assert_eq!(F8E4m3::from_f32(256.0), F8E4m3(0x78));
        assert_eq!(F8E4m3::from_f32(256.0).to_f32(), 256.0);
        assert_eq!(F8E4m3::from_f32(1.0e6), F8E4m3(0x7e));
        assert_eq!(F8E4m3::from_f32(f32::NEG_INFINITY), F8E4m3(0xfe));
    }

    #[test]
    fn tiny_inputs_round_to_zero_without_panicking() {
        assert_eq!(F8E4m3::from_f32(1.0e-10), F8E4m3(0x00));
        assert_eq!(F8E4m3::from_f32(-1.0e-30), F8E4m3(0x80));
        assert_eq!(F8E4m3::from_f32(f32::MIN_POSITIVE / 2.0), F8E4m3(0x00));
    }
}
315
/// E4M3 values are packed bit patterns whose associated scalar is `f32`.
#[cfg(feature = "f8")]
impl AccelDtype for F8E4m3 {
    type Scalar = f32;
    const KIND: DType = DType::F8E4m3;
    const NAME: &'static str = "f8_e4m3";
    const SIZE: usize = 1;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }
}
335
/// 8-bit float with 1 sign, 5 exponent and 2 mantissa bits (E5M2), stored as raw bits.
///
/// `Default` is the all-zero bit pattern, i.e. +0.0.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct F8E5m2(pub u8);
342
#[cfg(feature = "f8")]
impl F8E5m2 {
    /// Positive zero (`0x00`).
    pub const ZERO: Self = F8E5m2(0x00);
    /// One: exponent field 15 (equal to the bias), mantissa 0.
    pub const ONE: Self = F8E5m2(0x3c);
    /// NaN (exponent all ones, non-zero mantissa).
    pub const NAN: Self = F8E5m2(0x7e);
    /// Positive infinity (exponent all ones, zero mantissa).
    pub const INFINITY: Self = F8E5m2(0x7c);

    /// Converts an `f32` to E5M2, rounding to nearest with ties to even.
    ///
    /// Layout: bit 7 sign, bits 6..2 exponent (bias 15), bits 1..0 mantissa.
    /// NaN maps to [`F8E5m2::NAN`]. ±infinity, and magnitudes that round past the
    /// largest finite value 57344 (i.e. 61440 and above), become ±infinity. The sign
    /// of zero is preserved.
    pub fn from_f32(x: f32) -> Self {
        if x.is_nan() {
            return Self::NAN;
        }
        let sign: u8 = if x.is_sign_negative() { 0x80 } else { 0x00 };
        let bits = x.to_bits() & 0x7fff_ffff;
        if bits == 0 {
            return F8E5m2(sign);
        }

        let f32_exp = (bits >> 23) as i32 - 127;
        // 24-bit significand including the implicit leading one.
        let significand = (bits & 0x007f_ffff) | 0x0080_0000;
        let e5_exp = f32_exp + 15;
        if e5_exp >= 0x1f {
            // Too large for any finite encoding (also catches f32 infinity).
            return F8E5m2(sign | 0x7c);
        }

        if e5_exp <= 0 {
            // Subnormal target: value = m * 2^-16, so m = round(significand >> (22 - e5_exp)).
            let shift = (22 - e5_exp) as u32;
            if shift > 24 {
                // Below half the smallest subnormal (2^-17): rounds to zero. This also
                // covers f32 subnormal inputs and keeps the shift below 32.
                return F8E5m2(sign);
            }
            // A carry up to 4 lands exactly on the smallest normal (exponent 1, mantissa 0).
            let m = Self::round_shift_right(significand, shift);
            return F8E5m2(sign | m as u8);
        }

        // Normal target: keep the leading one plus 2 mantissa bits (a value in 4..=8).
        let kept = Self::round_shift_right(significand, 21);
        let (e, m) = if kept == 0x08 {
            (e5_exp + 1, 0)
        } else {
            (e5_exp, kept - 0x04)
        };
        if e >= 0x1f {
            // Rounding carried into the all-ones exponent: overflow to infinity.
            return F8E5m2(sign | 0x7c);
        }
        F8E5m2(sign | ((e as u8) << 2) | m as u8)
    }

    /// Converts to `f32`. Every E5M2 value (finite, infinite or NaN) maps exactly;
    /// all NaN encodings decode to `f32::NAN`.
    pub fn to_f32(self) -> f32 {
        let negative = (self.0 & 0x80) != 0;
        let exp = (self.0 >> 2) & 0x1f;
        let mant = self.0 & 0x03;
        let magnitude = match exp {
            // Zero or subnormal: 0.mm * 2^-14 == mant * 2^-16 (exact in f32).
            0 => f32::from(mant) / 65536.0,
            0x1f if mant == 0 => f32::INFINITY,
            0x1f => return f32::NAN,
            // Normal: rebias the exponent from 15 to 127 and left-align the 2 mantissa bits.
            _ => f32::from_bits(((u32::from(exp) + 112) << 23) | (u32::from(mant) << 21)),
        };
        if negative {
            -magnitude
        } else {
            magnitude
        }
    }

    /// Shifts `value` right by `shift` bits (1..=31), rounding to nearest with ties to even.
    fn round_shift_right(value: u32, shift: u32) -> u32 {
        debug_assert!((1..32).contains(&shift));
        let quotient = value >> shift;
        let remainder = value & ((1u32 << shift) - 1);
        let half = 1u32 << (shift - 1);
        if remainder > half || (remainder == half && (quotient & 1) == 1) {
            quotient + 1
        } else {
            quotient
        }
    }
}

#[cfg(all(test, feature = "f8"))]
mod f8e5m2_conversion_tests {
    use super::F8E5m2;

    #[test]
    fn every_non_nan_code_round_trips() {
        for byte in 0u8..=255 {
            let code = F8E5m2(byte);
            let value = code.to_f32();
            if value.is_nan() {
                continue;
            }
            assert_eq!(F8E5m2::from_f32(value), code, "code {:#04x} -> {}", byte, value);
        }
    }

    #[test]
    fn subnormals_and_overflow_follow_the_format() {
        assert_eq!(F8E5m2(0x01).to_f32(), 1.0 / 65536.0);
        assert_eq!(F8E5m2(0x04).to_f32(), 1.0 / 16384.0);
        assert_eq!(F8E5m2::from_f32(60000.0), F8E5m2(0x7b));
        assert_eq!(F8E5m2::from_f32(61440.0), F8E5m2::INFINITY);
        assert_eq!(F8E5m2::from_f32(f32::NEG_INFINITY), F8E5m2(0xfc));
    }

    #[test]
    fn tiny_inputs_round_to_zero_without_panicking() {
        assert_eq!(F8E5m2::from_f32(1.0e-10), F8E5m2(0x00));
        assert_eq!(F8E5m2::from_f32(-1.0e-30), F8E5m2(0x80));
        assert_eq!(F8E5m2::from_f32(f32::MIN_POSITIVE / 2.0), F8E5m2(0x00));
    }
}
416
/// E5M2 values are packed bit patterns whose associated scalar is `f32`.
#[cfg(feature = "f8")]
impl AccelDtype for F8E5m2 {
    type Scalar = f32;
    const KIND: DType = DType::F8E5m2;
    const NAME: &'static str = "f8_e5m2";
    const SIZE: usize = 1;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }
}
436
/// 4-bit float (1 sign, 2 exponent, 1 mantissa bit) held in the low nibble of a byte.
///
/// `Default` is the all-zero bit pattern, i.e. +0.0.
#[cfg(feature = "f4")]
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub struct F4E2m1(pub u8);
445
#[cfg(feature = "f4")]
impl F4E2m1 {
    /// Positive zero (`0x0`).
    pub const ZERO: Self = F4E2m1(0x0);
    /// One: sign 0, exponent field 1 (equal to the bias), mantissa 0 -> `0b0_01_0`.
    // `0x4` would be exponent field 2, which decodes to 2.0.
    pub const ONE: Self = F4E2m1(0x2);

    /// Decodes the low nibble to `f32`. The upper four bits of the byte are ignored.
    ///
    /// Layout: bit 3 sign, bits 2..1 exponent (bias 1), bit 0 mantissa. Exponent 0
    /// is subnormal (0.m * 2^0). Every code is finite, and the eight magnitudes are
    /// 0, 0.5, 1, 1.5, 2, 3, 4 and 6. Code `0x8` decodes to -0.0.
    pub fn to_f32(self) -> f32 {
        // Indexed by the 3 exponent+mantissa bits; every entry is exact in f32.
        const MAGNITUDES: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];
        let magnitude = MAGNITUDES[usize::from(self.0 & 0x07)];
        if (self.0 & 0x08) != 0 {
            -magnitude
        } else {
            magnitude
        }
    }
}

#[cfg(all(test, feature = "f4"))]
mod f4e2m1_decode_tests {
    use super::F4E2m1;

    #[test]
    fn constants_decode_to_their_values() {
        assert_eq!(F4E2m1::ZERO.to_f32(), 0.0);
        assert_eq!(F4E2m1::ONE.to_f32(), 1.0);
    }

    #[test]
    fn decodes_every_code() {
        let expected = [0.0_f32, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];
        for (code, &mag) in expected.iter().enumerate() {
            let code = code as u8;
            assert_eq!(F4E2m1(code).to_f32(), mag);
            assert_eq!(F4E2m1(code | 0x08).to_f32(), -mag);
            assert_eq!(F4E2m1(code | 0xf0).to_f32(), mag);
        }
    }
}
467
/// E2M1 values are packed bit patterns whose associated scalar is `f32`.
/// The format has no NaN encoding, so `nan()` yields `None`.
#[cfg(feature = "f4")]
impl AccelDtype for F4E2m1 {
    type Scalar = f32;
    const KIND: DType = DType::F4E2m1;
    const NAME: &'static str = "f4_e2m1";
    const SIZE: usize = 1;

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn nan() -> Option<Self> {
        Option::None
    }
}
487
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `DType` variant, used to check classifier invariants exhaustively.
    const ALL_DTYPES: [DType; 15] = [
        DType::F32,
        DType::F64,
        DType::F16,
        DType::Bf16,
        DType::I8,
        DType::I16,
        DType::I32,
        DType::I64,
        DType::U8,
        DType::U16,
        DType::U32,
        DType::U64,
        DType::F8E4m3,
        DType::F8E5m2,
        DType::F4E2m1,
    ];

    /// Asserts that a type's trait constants agree with its runtime `DType` tag.
    fn assert_consistent<T: AccelDtype>() {
        assert_eq!(T::SIZE, T::KIND.size_bytes(), "SIZE mismatch for {}", T::NAME);
        assert_eq!(T::NAME, T::KIND.name(), "NAME mismatch for {:?}", T::KIND);
    }

    #[test]
    fn dtype_size_matches_trait() {
        assert_eq!(<f32 as AccelDtype>::SIZE, DType::F32.size_bytes());
        assert_eq!(<f64 as AccelDtype>::SIZE, DType::F64.size_bytes());
        assert_eq!(<i8 as AccelDtype>::SIZE, DType::I8.size_bytes());
        assert_eq!(<i16 as AccelDtype>::SIZE, DType::I16.size_bytes());
        assert_eq!(<i32 as AccelDtype>::SIZE, DType::I32.size_bytes());
        assert_eq!(<i64 as AccelDtype>::SIZE, DType::I64.size_bytes());
        assert_eq!(<u8 as AccelDtype>::SIZE, DType::U8.size_bytes());
        assert_eq!(<u16 as AccelDtype>::SIZE, DType::U16.size_bytes());
        assert_eq!(<u32 as AccelDtype>::SIZE, DType::U32.size_bytes());
        assert_eq!(<u64 as AccelDtype>::SIZE, DType::U64.size_bytes());
    }

    #[test]
    fn native_types_are_consistent_with_their_kind() {
        assert_consistent::<f32>();
        assert_consistent::<f64>();
        assert_consistent::<i8>();
        assert_consistent::<i16>();
        assert_consistent::<i32>();
        assert_consistent::<i64>();
        assert_consistent::<u8>();
        assert_consistent::<u16>();
        assert_consistent::<u32>();
        assert_consistent::<u64>();
    }

    #[cfg(feature = "f16")]
    #[test]
    fn half_types_are_consistent_with_their_kind() {
        assert_consistent::<half::f16>();
        assert_consistent::<half::bf16>();
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8_types_are_consistent_with_their_kind() {
        assert_consistent::<F8E4m3>();
        assert_consistent::<F8E5m2>();
    }

    #[cfg(feature = "f4")]
    #[test]
    fn f4_type_is_consistent_with_its_kind() {
        assert_consistent::<F4E2m1>();
    }

    #[test]
    fn dtype_classifiers() {
        assert!(DType::F32.is_float());
        assert!(!DType::I32.is_float());
        assert!(DType::I32.is_integer());
        assert!(DType::I32.is_signed());
        assert!(!DType::U32.is_signed());
        assert!(DType::F32.is_signed());
    }

    #[test]
    fn dtype_classifiers_partition_all_variants() {
        for d in ALL_DTYPES.iter().copied() {
            // Every variant is exactly one of float / integer.
            assert_ne!(d.is_float(), d.is_integer(), "{:?}", d);
            if d.is_float() {
                assert!(d.is_signed(), "{:?}", d);
            }
            assert!(matches!(d.size_bytes(), 1 | 2 | 4 | 8), "{:?}", d);
        }
        for d in [DType::U8, DType::U16, DType::U32, DType::U64].iter().copied() {
            assert!(!d.is_signed(), "{:?}", d);
        }
        for d in [DType::I8, DType::I16, DType::I32, DType::I64].iter().copied() {
            assert!(d.is_signed(), "{:?}", d);
        }
    }

    #[test]
    fn dtype_names_match() {
        assert_eq!(DType::F32.name(), <f32 as AccelDtype>::NAME);
        assert_eq!(DType::F64.name(), <f64 as AccelDtype>::NAME);
        assert_eq!(DType::U8.name(), <u8 as AccelDtype>::NAME);
    }

    #[test]
    fn float_nan_is_some() {
        assert!(<f32 as AccelDtype>::nan().is_some());
        assert!(<f64 as AccelDtype>::nan().is_some());
    }

    #[test]
    fn integer_nan_is_none() {
        assert!(<i32 as AccelDtype>::nan().is_none());
        assert!(<u64 as AccelDtype>::nan().is_none());
    }

    #[test]
    fn zero_one_round_trip() {
        assert_eq!(<f32 as AccelDtype>::zero(), 0.0);
        assert_eq!(<f32 as AccelDtype>::one(), 1.0);
        assert_eq!(<i32 as AccelDtype>::zero(), 0);
        assert_eq!(<i32 as AccelDtype>::one(), 1);
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8e4m3_round_trip_simple() {
        assert_eq!(F8E4m3::from_f32(0.0).to_f32(), 0.0);
        assert_eq!(F8E4m3::from_f32(1.0).to_f32(), 1.0);
        assert_eq!(F8E4m3::from_f32(2.0).to_f32(), 2.0);
        assert_eq!(F8E4m3::from_f32(-1.0).to_f32(), -1.0);
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8e5m2_round_trip_simple() {
        assert_eq!(F8E5m2::from_f32(0.0).to_f32(), 0.0);
        assert_eq!(F8E5m2::from_f32(1.0).to_f32(), 1.0);
        assert_eq!(F8E5m2::from_f32(2.0).to_f32(), 2.0);
        assert_eq!(F8E5m2::from_f32(-1.0).to_f32(), -1.0);
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8_special_values() {
        assert!(F8E4m3::from_f32(f32::NAN).to_f32().is_nan());
        assert!(F8E5m2::from_f32(f32::NAN).to_f32().is_nan());
        assert_eq!(F8E4m3::from_f32(448.0).to_f32(), 448.0);
        assert_eq!(F8E4m3::from_f32(-0.0), F8E4m3(0x80));
        assert_eq!(F8E5m2::from_f32(-0.0), F8E5m2(0x80));
        assert_eq!(F8E5m2::from_f32(f32::INFINITY), F8E5m2::INFINITY);
        assert_eq!(F8E5m2::INFINITY.to_f32(), f32::INFINITY);
        assert_eq!(F8E4m3::ONE.to_f32(), 1.0);
        assert_eq!(F8E5m2::ONE.to_f32(), 1.0);
    }
}
556}