Skip to main content

atomr_accel/
dtype.rs

1//! `AccelDtype` — backend-agnostic numeric data-type trait.
2//!
3//! Every numeric kernel actor in atomr-accel works over `T: AccelDtype`
4//! so allocation, copy, and op messages can be dtype-generic without
5//! exploding into one variant per (op, dtype) pair.
6//!
7//! The trait captures only what every backend agrees on (size, identity
8//! values, NaN, a discriminant for runtime dispatch). Backend-specific
9//! traits like `CudaDtype` layer on top with the binding-specific
10//! mappings (cudarc enums for CUDA, Metal `MTLDataType` for Metal,
11//! `hipblasDatatype_t` for ROCm, …).
12
13use std::fmt::Debug;
14
15/// Marker for any numeric type that can be a typed device buffer
16/// element across atomr-accel backends.
17///
18/// `AccelDtype` is intentionally *narrower* than what individual
19/// backends can support. cuBLAS f64 GEMM is gated by a `GemmSupported`
20/// marker on the CUDA side; this trait says only that the type is a
21/// recognised dtype.
22pub trait AccelDtype: Copy + Send + Sync + 'static + Debug {
23    /// Companion scalar type for host-side parameters (alpha/beta,
24    /// mean/std, scaling factors). `Self` for full-precision dtypes;
25    /// `f32` for fp8 / fp4 wrappers because the upstream APIs accept
26    /// f32 scales.
27    type Scalar: Copy + Send + Sync + 'static + Debug;
28
29    /// Runtime discriminant.
30    const KIND: DType;
31
32    /// Bytes per element including representation padding.
33    const SIZE: usize;
34
35    /// Human-readable name used in tracing and error messages.
36    const NAME: &'static str;
37
38    fn zero() -> Self;
39    fn one() -> Self;
40
41    /// `Some(NaN)` for floats, `None` for integers.
42    fn nan() -> Option<Self>;
43}
44
/// Runtime tag naming an element type; the value stored in
/// [`AccelDtype::KIND`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float (`f32`).
    F32,
    /// 64-bit float (`f64`).
    F64,
    /// 16-bit IEEE half float (`half::f16`).
    F16,
    /// 16-bit bfloat16 (`half::bf16`).
    Bf16,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 16-bit integer.
    U16,
    /// Unsigned 32-bit integer.
    U32,
    /// Unsigned 64-bit integer.
    U64,
    /// 8-bit float, E4M3 (sign 1, exp 4, mant 3). Hopper+ fp8 GEMM, FlashAttention v3.
    F8E4m3,
    /// 8-bit float, E5M2 (sign 1, exp 5, mant 2).
    F8E5m2,
    /// 4-bit float, E2M1. Blackwell fp4 inference.
    F4E2m1,
}

impl DType {
    /// Storage bytes per element. The fp4 format reports 1 because each
    /// element occupies its own byte in this crate's representation.
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::I8 | Self::U8 | Self::F8E4m3 | Self::F8E5m2 | Self::F4E2m1 => 1,
            Self::I16 | Self::U16 | Self::F16 | Self::Bf16 => 2,
            Self::I32 | Self::U32 | Self::F32 => 4,
            Self::I64 | Self::U64 | Self::F64 => 8,
        }
    }

    /// Short lowercase name, identical to the matching `AccelDtype::NAME`.
    pub const fn name(self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F64 => "f64",
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::F8E4m3 => "f8_e4m3",
            Self::F8E5m2 => "f8_e5m2",
            Self::F4E2m1 => "f4_e2m1",
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::U8 => "u8",
            Self::U16 => "u16",
            Self::U32 => "u32",
            Self::U64 => "u64",
        }
    }

    /// `true` for every floating-point format, including fp8 and fp4.
    pub const fn is_float(self) -> bool {
        // Exhaustive on purpose: a new variant must be classified here.
        match self {
            Self::F32
            | Self::F64
            | Self::F16
            | Self::Bf16
            | Self::F8E4m3
            | Self::F8E5m2
            | Self::F4E2m1 => true,
            Self::I8
            | Self::I16
            | Self::I32
            | Self::I64
            | Self::U8
            | Self::U16
            | Self::U32
            | Self::U64 => false,
        }
    }

    /// `true` for every signed or unsigned integer type.
    pub const fn is_integer(self) -> bool {
        match self {
            Self::I8
            | Self::I16
            | Self::I32
            | Self::I64
            | Self::U8
            | Self::U16
            | Self::U32
            | Self::U64 => true,
            Self::F32
            | Self::F64
            | Self::F16
            | Self::Bf16
            | Self::F8E4m3
            | Self::F8E5m2
            | Self::F4E2m1 => false,
        }
    }

    /// `true` when the format can represent negative values: the signed
    /// integers and every float format (all of which carry a sign bit).
    pub const fn is_signed(self) -> bool {
        match self {
            Self::U8 | Self::U16 | Self::U32 | Self::U64 => false,
            Self::I8
            | Self::I16
            | Self::I32
            | Self::I64
            | Self::F32
            | Self::F64
            | Self::F16
            | Self::Bf16
            | Self::F8E4m3
            | Self::F8E5m2
            | Self::F4E2m1 => true,
        }
    }
}
143
// Implements `AccelDtype` for a primitive integer type: the scalar type is
// the type itself, the identities are the literals `0` / `1`, and there is
// no NaN.
macro_rules! impl_accel_dtype_int {
    ($ty:ty, $variant:expr, $label:literal) => {
        impl AccelDtype for $ty {
            type Scalar = Self;

            const NAME: &'static str = $label;
            const KIND: DType = $variant;
            const SIZE: usize = ::std::mem::size_of::<$ty>();

            #[inline]
            fn nan() -> Option<Self> {
                // Integers have no NaN encoding.
                None
            }

            #[inline]
            fn one() -> Self {
                1
            }

            #[inline]
            fn zero() -> Self {
                0
            }
        }
    };
}
167
// Implements `AccelDtype` for a primitive float type: the scalar type is
// the type itself, the identities are the literals `0.0` / `1.0`, and NaN
// comes from the type's own `NAN` constant.
macro_rules! impl_accel_dtype_float {
    ($ty:ty, $variant:expr, $label:literal) => {
        impl AccelDtype for $ty {
            type Scalar = Self;

            const NAME: &'static str = $label;
            const KIND: DType = $variant;
            const SIZE: usize = ::std::mem::size_of::<$ty>();

            #[inline]
            fn nan() -> Option<Self> {
                Some(Self::NAN)
            }

            #[inline]
            fn one() -> Self {
                1.0
            }

            #[inline]
            fn zero() -> Self {
                0.0
            }
        }
    };
}
191
192impl_accel_dtype_float!(f32, DType::F32, "f32");
193impl_accel_dtype_float!(f64, DType::F64, "f64");
194impl_accel_dtype_int!(i8, DType::I8, "i8");
195impl_accel_dtype_int!(i16, DType::I16, "i16");
196impl_accel_dtype_int!(i32, DType::I32, "i32");
197impl_accel_dtype_int!(i64, DType::I64, "i64");
198impl_accel_dtype_int!(u8, DType::U8, "u8");
199impl_accel_dtype_int!(u16, DType::U16, "u16");
200impl_accel_dtype_int!(u32, DType::U32, "u32");
201impl_accel_dtype_int!(u64, DType::U64, "u64");
202
// IEEE half precision from the `half` crate; constants come from the type
// itself rather than from float literals.
#[cfg(feature = "f16")]
impl AccelDtype for half::f16 {
    type Scalar = Self;

    const NAME: &'static str = "f16";
    const KIND: DType = DType::F16;
    const SIZE: usize = std::mem::size_of::<half::f16>();

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }
}
222
// bfloat16 from the `half` crate. It is gated on the same `f16` feature as
// `half::f16` because both types come from that crate.
#[cfg(feature = "f16")]
impl AccelDtype for half::bf16 {
    type Scalar = Self;

    const NAME: &'static str = "bf16";
    const KIND: DType = DType::Bf16;
    const SIZE: usize = std::mem::size_of::<half::bf16>();

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }
}
242
/// 8-bit float in the E4M3 layout: 1 sign bit, 4 exponent bits (bias 7),
/// 3 mantissa bits, stored as one raw byte.
///
/// This is the finite-only variant: there is no infinity, `0x7f`/`0xff`
/// are NaN, and `from_f32` saturates out-of-range magnitudes to ±448.
/// `Default` is `+0`. The derived `PartialEq`/`Eq`/`Hash` compare raw
/// bits, so `+0` and `-0` are unequal while a NaN byte equals itself.
#[cfg(feature = "f8")]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
#[repr(transparent)]
pub struct F8E4m3(pub u8);
249
#[cfg(feature = "f8")]
impl F8E4m3 {
    /// Positive zero.
    pub const ZERO: Self = F8E4m3(0x00);
    /// One: exponent field 7 (equal to the bias), mantissa 0.
    pub const ONE: Self = F8E4m3(0x38);
    /// Canonical NaN (`0x7f`). E4M3 has no infinities.
    pub const NAN: Self = F8E4m3(0x7f);

    /// Converts an `f32`, rounding to nearest with ties to even.
    ///
    /// - NaN maps to [`F8E4m3::NAN`]; the input's sign is not kept.
    /// - Magnitudes of 448 or more, including ±infinity, saturate to
    ///   ±448, the largest finite E4M3 value.
    /// - Small magnitudes round into the subnormal range (multiples of
    ///   2^-9) or to a signed zero. `f32` subnormals always become zero.
    ///   No input panics.
    pub fn from_f32(x: f32) -> Self {
        // Shifts `v` right by `shift` bits (1..=24), rounding to nearest
        // with ties to even.
        fn shr_round_even(v: u32, shift: u32) -> u32 {
            let kept = v >> shift;
            let dropped = v & ((1u32 << shift) - 1);
            let half = 1u32 << (shift - 1);
            if dropped > half || (dropped == half && (kept & 1) == 1) {
                kept + 1
            } else {
                kept
            }
        }

        if x.is_nan() {
            return Self::NAN;
        }
        let bits = x.to_bits();
        let sign = ((bits >> 31) as u8) << 7;
        if x.abs() >= 448.0 {
            return F8E4m3(sign | 0x7e);
        }
        let f32_biased_exp = ((bits >> 23) & 0xff) as i32;
        if f32_biased_exp == 0 {
            // ±0 or an f32 subnormal (< 2^-126): far below half of the
            // smallest E4M3 subnormal.
            return F8E4m3(sign);
        }
        let f32_mant = bits & 0x007f_ffff;
        // Re-bias the exponent from f32 (127) to E4M3 (7).
        let e4_exp = f32_biased_exp - 127 + 7;
        let magnitude = if e4_exp >= 1 {
            // Normal target. The exponent sits directly above the mantissa,
            // so a rounding carry out of the mantissa bumps the exponent.
            // Exponent 15 is valid here (256..=448).
            shr_round_even(((e4_exp as u32) << 23) | f32_mant, 20)
        } else {
            // Subnormal target. Restore the implicit leading one and shift
            // one extra bit per step the exponent falls below 1. A result
            // of 0x08 is the smallest normal, which is the correct encoding
            // when rounding crosses that boundary.
            let shift = 20 + (1 - e4_exp) as u32;
            if shift > 24 {
                // |x| < 2^-10, i.e. below half of 2^-9: rounds to zero.
                return F8E4m3(sign);
            }
            shr_round_even(f32_mant | 0x0080_0000, shift)
        };
        // |x| < 448 keeps the result at or below 0x7e; the clamp only
        // guarantees the NaN pattern 0x7f is never produced.
        F8E4m3(sign | (magnitude.min(0x7e) as u8))
    }

    /// Converts to `f32`.
    ///
    /// Every E4M3 value is exactly representable in `f32`. Both NaN
    /// encodings (`0x7f`, `0xff`) decode to `f32::NAN`, and subnormals
    /// decode to `mantissa × 2^-9`.
    pub fn to_f32(self) -> f32 {
        let exp = (self.0 >> 3) & 0x0f;
        let mant = self.0 & 0x07;
        if exp == 0x0f && mant == 0x07 {
            return f32::NAN;
        }
        let magnitude = if exp == 0 {
            // Subnormal or zero: 0.mmm × 2^(1-7) = mmm × 2^-9, exact in f32.
            f32::from(mant) / 512.0
        } else {
            // Normal: re-bias the exponent from 7 to 127 and left-align the
            // 3 mantissa bits in f32's 23-bit mantissa field.
            f32::from_bits(((u32::from(exp) + 120) << 23) | (u32::from(mant) << 20))
        };
        if (self.0 & 0x80) != 0 {
            -magnitude
        } else {
            magnitude
        }
    }
}
315
#[cfg(feature = "f8")]
impl AccelDtype for F8E4m3 {
    // Host-side parameters for fp8 data are passed as f32.
    type Scalar = f32;

    const NAME: &'static str = "f8_e4m3";
    const KIND: DType = DType::F8E4m3;
    // One storage byte per element.
    const SIZE: usize = 1;

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }
}
335
/// 8-bit float in the E5M2 layout: 1 sign bit, 5 exponent bits (bias 15),
/// 2 mantissa bits, stored as one raw byte.
///
/// Unlike the E4M3 type, this layout has infinities (`0x7c`/`0xfc`), so
/// `from_f32` does *not* saturate: magnitudes that round past the largest
/// finite value (57344) become ±infinity. Exponent-31 patterns with a
/// non-zero mantissa are NaN. `Default` is `+0`. The derived
/// `PartialEq`/`Eq`/`Hash` compare raw bits, so `+0` and `-0` are unequal
/// while a NaN byte equals itself.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct F8E5m2(pub u8);
342
#[cfg(feature = "f8")]
impl F8E5m2 {
    /// Positive zero.
    pub const ZERO: Self = F8E5m2(0x00);
    /// One: exponent field 15 (equal to the bias), mantissa 0.
    pub const ONE: Self = F8E5m2(0x3c);
    /// A NaN encoding (`0x7e`). Any exponent-31 pattern with a non-zero
    /// mantissa is NaN.
    pub const NAN: Self = F8E5m2(0x7e);
    /// Positive infinity (`0x7c`).
    pub const INFINITY: Self = F8E5m2(0x7c);

    /// Converts an `f32`, rounding to nearest with ties to even.
    ///
    /// - NaN maps to [`F8E5m2::NAN`]; the input's sign is not kept.
    /// - Magnitudes of 61440 or more (the rounding midpoint above the
    ///   largest finite value, 57344), including ±infinity, become
    ///   ±infinity.
    /// - Small magnitudes round into the subnormal range (multiples of
    ///   2^-16) or to a signed zero. `f32` subnormals always become zero.
    ///   No input panics.
    pub fn from_f32(x: f32) -> Self {
        // Shifts `v` right by `shift` bits (1..=24), rounding to nearest
        // with ties to even.
        fn shr_round_even(v: u32, shift: u32) -> u32 {
            let kept = v >> shift;
            let dropped = v & ((1u32 << shift) - 1);
            let half = 1u32 << (shift - 1);
            if dropped > half || (dropped == half && (kept & 1) == 1) {
                kept + 1
            } else {
                kept
            }
        }

        if x.is_nan() {
            return Self::NAN;
        }
        let bits = x.to_bits();
        let sign = ((bits >> 31) as u8) << 7;
        let f32_biased_exp = ((bits >> 23) & 0xff) as i32;
        if f32_biased_exp == 0 {
            // ±0 or an f32 subnormal (< 2^-126): far below half of the
            // smallest E5M2 subnormal.
            return F8E5m2(sign);
        }
        let f32_mant = bits & 0x007f_ffff;
        // Re-bias the exponent from f32 (127) to E5M2 (15).
        let e5_exp = f32_biased_exp - 127 + 15;
        if e5_exp >= 0x1f {
            // |x| >= 2^16 (including f32 infinity): overflow to infinity.
            return F8E5m2(sign | 0x7c);
        }
        let magnitude = if e5_exp >= 1 {
            // Normal target. The exponent sits directly above the mantissa,
            // so a rounding carry out of the mantissa bumps the exponent.
            shr_round_even(((e5_exp as u32) << 23) | f32_mant, 21)
        } else {
            // Subnormal target. Restore the implicit leading one and shift
            // one extra bit per step the exponent falls below 1. A result
            // of 0x04 is the smallest normal, which is the correct encoding
            // when rounding crosses that boundary.
            let shift = 21 + (1 - e5_exp) as u32;
            if shift > 24 {
                // |x| < 2^-17, i.e. below half of 2^-16: rounds to zero.
                return F8E5m2(sign);
            }
            shr_round_even(f32_mant | 0x0080_0000, shift)
        };
        // A rounding carry out of exponent 30 yields exactly 0x7c
        // (infinity); the clamp only guarantees no NaN pattern is produced.
        F8E5m2(sign | (magnitude.min(0x7c) as u8))
    }

    /// Converts to `f32`.
    ///
    /// Every finite E5M2 value is exactly representable in `f32`.
    /// Exponent-31 patterns decode to ±infinity (mantissa 0) or `f32::NAN`,
    /// and subnormals decode to `mantissa × 2^-16`.
    pub fn to_f32(self) -> f32 {
        let exp = (self.0 >> 2) & 0x1f;
        let mant = self.0 & 0x03;
        let magnitude = match exp {
            0x1f if mant != 0 => return f32::NAN,
            0x1f => f32::INFINITY,
            // Subnormal or zero: 0.mm × 2^(1-15) = mm × 2^-16, exact in f32.
            0 => f32::from(mant) / 65536.0,
            // Normal: re-bias the exponent from 15 to 127 and left-align the
            // 2 mantissa bits in f32's 23-bit mantissa field.
            _ => f32::from_bits(((u32::from(exp) + 112) << 23) | (u32::from(mant) << 21)),
        };
        if (self.0 & 0x80) != 0 {
            -magnitude
        } else {
            magnitude
        }
    }
}
416
#[cfg(feature = "f8")]
impl AccelDtype for F8E5m2 {
    // Host-side parameters for fp8 data are passed as f32.
    type Scalar = f32;

    const NAME: &'static str = "f8_e5m2";
    const KIND: DType = DType::F8E5m2;
    // One storage byte per element.
    const SIZE: usize = 1;

    #[inline]
    fn nan() -> Option<Self> {
        Some(Self::NAN)
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }
}
436
/// 4-bit float in the E2M1 layout: 1 sign bit, 2 exponent bits (bias 1),
/// 1 mantissa bit, with no infinity or NaN encodings.
///
/// One element lives in the low nibble of the byte, and the upper nibble
/// is expected to be zero (`to_f32` ignores it). The kernel layer commonly
/// packs two values per byte; that packing is outside the scope of this
/// newtype.
#[cfg(feature = "f4")]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
#[repr(transparent)]
pub struct F4E2m1(pub u8);
445
#[cfg(feature = "f4")]
impl F4E2m1 {
    /// Positive zero.
    pub const ZERO: Self = F4E2m1(0x0);
    /// One: sign 0, exponent field `01` (equal to the bias), mantissa 0.
    // Previously 0x4, which has exponent field `10` and decodes to 2.0.
    pub const ONE: Self = F4E2m1(0x2);

    /// Decodes the low nibble to `f32`; the upper nibble is ignored.
    ///
    /// Every E2M1 value is exactly representable in `f32`. The encodable
    /// magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6, and bit 3 is the sign
    /// (so `0x8` decodes to `-0.0`).
    pub fn to_f32(self) -> f32 {
        // Indexed by the exponent/mantissa bits `eem`. Exponent 0 is the
        // subnormal range (0 and 0.5); otherwise the value is
        // (1 + m/2) × 2^(e-1).
        const MAGNITUDES: [f32; 8] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0];
        let magnitude = MAGNITUDES[usize::from(self.0 & 0x07)];
        if (self.0 & 0x08) != 0 {
            -magnitude
        } else {
            magnitude
        }
    }
}
467
#[cfg(feature = "f4")]
impl AccelDtype for F4E2m1 {
    // Host-side parameters for fp4 data are passed as f32.
    type Scalar = f32;

    const NAME: &'static str = "f4_e2m1";
    const KIND: DType = DType::F4E2m1;
    // One storage byte per element (value in the low nibble).
    const SIZE: usize = 1;

    #[inline]
    fn nan() -> Option<Self> {
        // E2M1 has no NaN encoding.
        None
    }

    #[inline]
    fn one() -> Self {
        Self::ONE
    }

    #[inline]
    fn zero() -> Self {
        Self::ZERO
    }
}
487
#[cfg(test)]
mod tests {
    use super::*;

    // Every `DType` variant, in declaration order. Keep this in sync when
    // the enum grows so the classifier tests below stay exhaustive.
    const ALL_DTYPES: [DType; 15] = [
        DType::F32,
        DType::F64,
        DType::F16,
        DType::Bf16,
        DType::I8,
        DType::I16,
        DType::I32,
        DType::I64,
        DType::U8,
        DType::U16,
        DType::U32,
        DType::U64,
        DType::F8E4m3,
        DType::F8E5m2,
        DType::F4E2m1,
    ];

    // Checks that an impl's `SIZE` and `NAME` agree with its `KIND`.
    fn assert_kind_consistent<T: AccelDtype>() {
        assert_eq!(
            T::SIZE,
            T::KIND.size_bytes(),
            "SIZE disagrees with KIND for {}",
            T::NAME
        );
        assert_eq!(T::NAME, T::KIND.name());
    }

    #[test]
    fn dtype_size_matches_trait() {
        assert_eq!(<f32 as AccelDtype>::SIZE, DType::F32.size_bytes());
        assert_eq!(<f64 as AccelDtype>::SIZE, DType::F64.size_bytes());
        assert_eq!(<i8 as AccelDtype>::SIZE, DType::I8.size_bytes());
        assert_eq!(<i32 as AccelDtype>::SIZE, DType::I32.size_bytes());
        assert_eq!(<u32 as AccelDtype>::SIZE, DType::U32.size_bytes());
        assert_eq!(<u64 as AccelDtype>::SIZE, DType::U64.size_bytes());
    }

    #[test]
    fn dtype_classifiers() {
        assert!(DType::F32.is_float());
        assert!(!DType::I32.is_float());
        assert!(DType::I32.is_integer());
        assert!(DType::I32.is_signed());
        assert!(!DType::U32.is_signed());
        assert!(DType::F32.is_signed());
    }

    #[test]
    fn dtype_names_match() {
        assert_eq!(DType::F32.name(), <f32 as AccelDtype>::NAME);
        assert_eq!(DType::F64.name(), <f64 as AccelDtype>::NAME);
        assert_eq!(DType::U8.name(), <u8 as AccelDtype>::NAME);
    }

    #[test]
    fn float_nan_is_some() {
        assert!(<f32 as AccelDtype>::nan().is_some());
        assert!(<f64 as AccelDtype>::nan().is_some());
    }

    #[test]
    fn integer_nan_is_none() {
        assert!(<i32 as AccelDtype>::nan().is_none());
        assert!(<u64 as AccelDtype>::nan().is_none());
    }

    #[test]
    fn zero_one_round_trip() {
        assert_eq!(<f32 as AccelDtype>::zero(), 0.0);
        assert_eq!(<f32 as AccelDtype>::one(), 1.0);
        assert_eq!(<i32 as AccelDtype>::zero(), 0);
        assert_eq!(<i32 as AccelDtype>::one(), 1);
    }

    #[test]
    fn every_primitive_impl_matches_its_kind() {
        assert_kind_consistent::<f32>();
        assert_kind_consistent::<f64>();
        assert_kind_consistent::<i8>();
        assert_kind_consistent::<i16>();
        assert_kind_consistent::<i32>();
        assert_kind_consistent::<i64>();
        assert_kind_consistent::<u8>();
        assert_kind_consistent::<u16>();
        assert_kind_consistent::<u32>();
        assert_kind_consistent::<u64>();
    }

    #[cfg(feature = "f16")]
    #[test]
    fn half_impls_match_their_kind() {
        assert_kind_consistent::<half::f16>();
        assert_kind_consistent::<half::bf16>();
    }

    #[test]
    fn every_dtype_is_exactly_one_of_float_or_integer() {
        for dtype in ALL_DTYPES {
            assert_ne!(dtype.is_float(), dtype.is_integer(), "{:?}", dtype);
        }
    }

    #[test]
    fn only_unsigned_integers_are_unsigned() {
        for dtype in ALL_DTYPES {
            let unsigned = matches!(dtype, DType::U8 | DType::U16 | DType::U32 | DType::U64);
            assert_eq!(dtype.is_signed(), !unsigned, "{:?}", dtype);
        }
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8e4m3_round_trip_simple() {
        assert_eq!(F8E4m3::from_f32(0.0).to_f32(), 0.0);
        assert_eq!(F8E4m3::from_f32(1.0).to_f32(), 1.0);
        assert_eq!(F8E4m3::from_f32(2.0).to_f32(), 2.0);
        assert_eq!(F8E4m3::from_f32(-1.0).to_f32(), -1.0);
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8e4m3_saturates_and_maps_nan() {
        assert_kind_consistent::<F8E4m3>();
        assert_eq!(F8E4m3::from_f32(0.5).to_f32(), 0.5);
        assert_eq!(F8E4m3::from_f32(1000.0).to_f32(), 448.0);
        assert_eq!(F8E4m3::from_f32(-1000.0).to_f32(), -448.0);
        assert_eq!(F8E4m3::from_f32(f32::INFINITY).to_f32(), 448.0);
        assert!(F8E4m3::from_f32(f32::NAN).to_f32().is_nan());
        assert_eq!(F8E4m3::from_f32(-0.0), F8E4m3(0x80));
        assert_eq!(F8E4m3::from_f32(1.0), F8E4m3::ONE);
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8e5m2_round_trip_simple() {
        assert_eq!(F8E5m2::from_f32(0.0).to_f32(), 0.0);
        assert_eq!(F8E5m2::from_f32(1.0).to_f32(), 1.0);
        assert_eq!(F8E5m2::from_f32(2.0).to_f32(), 2.0);
        assert_eq!(F8E5m2::from_f32(-1.0).to_f32(), -1.0);
    }

    #[cfg(feature = "f8")]
    #[test]
    fn f8e5m2_overflows_to_infinity_and_maps_nan() {
        assert_kind_consistent::<F8E5m2>();
        assert_eq!(F8E5m2::from_f32(57344.0).to_f32(), 57344.0);
        assert_eq!(F8E5m2::from_f32(1.0e6).to_f32(), f32::INFINITY);
        assert_eq!(F8E5m2::from_f32(-1.0e6).to_f32(), f32::NEG_INFINITY);
        assert!(F8E5m2::from_f32(f32::NAN).to_f32().is_nan());
        assert_eq!(F8E5m2::from_f32(-0.0), F8E5m2(0x80));
        assert_eq!(F8E5m2::from_f32(1.0), F8E5m2::ONE);
    }

    #[cfg(feature = "f4")]
    #[test]
    fn f4e2m1_decodes_code_points() {
        assert_kind_consistent::<F4E2m1>();
        assert_eq!(F4E2m1(0x1).to_f32(), 0.5);
        assert_eq!(F4E2m1(0x2).to_f32(), 1.0);
        assert_eq!(F4E2m1(0x7).to_f32(), 6.0);
        assert_eq!(F4E2m1(0xf).to_f32(), -6.0);
    }
}