Skip to main content

atomr_accel_cutlass/
dtype.rs

//! Local dtype enum used by CUTLASS template messages.
//!
//! `atomr-accel-cuda` is expected to expose a richer `CudaDtype` (with
//! 16 capability markers and fp8 / fp4 wrappers); on this branch the
//! cutlass crate ships a minimal mirror that covers the surface needed
//! for template instantiation. Once the upstream `CudaDtype` lands, the
//! local definition here can be replaced with a `pub use
//! atomr_accel_cuda::dtype::CudaDtype as CutlassDtype` re-export without
//! changing the public API of this crate, provided the upstream variants
//! and methods match the ones defined below.
10
11use core::fmt;
12
/// CUTLASS-side dtype tag. Each variant maps 1-to-1 onto a concrete
/// CUTLASS C++ scalar type via [`CutlassDtype::as_cutlass_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CutlassDtype {
    /// `float` / `cutlass::float32_t`.
    F32,
    /// `double` / `cutlass::float64_t`.
    F64,
    /// `cutlass::half_t`.
    F16,
    /// `cutlass::bfloat16_t`.
    Bf16,
    /// `cutlass::float_e4m3_t` (Hopper / Blackwell fp8).
    F8E4m3,
    /// `cutlass::float_e5m2_t` (Hopper / Blackwell fp8).
    F8E5m2,
    /// `cutlass::float_e2m1_t` (Blackwell fp4).
    F4E2m1,
    /// `int8_t` (CUTLASS quantized GEMM lane).
    I8,
    /// `int32_t` (accumulator).
    I32,
    /// `uint8_t`.
    U8,
}

impl CutlassDtype {
    /// CUTLASS C++ type spelling for use inside a generated template
    /// instantiation. Used by the `.cu` source emitter.
    ///
    /// This is a `const fn`, so the spelling for a compile-time tag such as
    /// [`GemmSupported::DTYPE`] can itself be computed at compile time.
    #[must_use]
    pub const fn as_cutlass_type(self) -> &'static str {
        match self {
            CutlassDtype::F32 => "float",
            CutlassDtype::F64 => "double",
            CutlassDtype::F16 => "cutlass::half_t",
            CutlassDtype::Bf16 => "cutlass::bfloat16_t",
            CutlassDtype::F8E4m3 => "cutlass::float_e4m3_t",
            CutlassDtype::F8E5m2 => "cutlass::float_e5m2_t",
            CutlassDtype::F4E2m1 => "cutlass::float_e2m1_t",
            CutlassDtype::I8 => "int8_t",
            CutlassDtype::I32 => "int32_t",
            CutlassDtype::U8 => "uint8_t",
        }
    }

    /// Stable short name used in plan-cache keys and log output.
    ///
    /// This is also the [`fmt::Display`] rendering of the tag.
    #[must_use]
    pub const fn short_name(self) -> &'static str {
        match self {
            CutlassDtype::F32 => "f32",
            CutlassDtype::F64 => "f64",
            CutlassDtype::F16 => "f16",
            CutlassDtype::Bf16 => "bf16",
            CutlassDtype::F8E4m3 => "f8e4m3",
            CutlassDtype::F8E5m2 => "f8e5m2",
            CutlassDtype::F4E2m1 => "f4e2m1",
            CutlassDtype::I8 => "i8",
            CutlassDtype::I32 => "i32",
            CutlassDtype::U8 => "u8",
        }
    }

    /// Element size in bits. fp4 is the only sub-byte dtype.
    #[must_use]
    pub const fn size_bits(self) -> u32 {
        match self {
            CutlassDtype::F64 => 64,
            CutlassDtype::F32 | CutlassDtype::I32 => 32,
            CutlassDtype::F16 | CutlassDtype::Bf16 => 16,
            CutlassDtype::F8E4m3 | CutlassDtype::F8E5m2 | CutlassDtype::I8 | CutlassDtype::U8 => 8,
            CutlassDtype::F4E2m1 => 4,
        }
    }
}

impl fmt::Display for CutlassDtype {
    /// Writes [`CutlassDtype::short_name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) honours width / alignment flags such as
        // `{:>8}`, so dtype columns in log output line up.
        f.pad(self.short_name())
    }
}
90
/// Compute architectures the GEMM template emitter knows how to target.
///
/// Mirrors the per-arch toolchain keys used by `NvrtcActor`. The enum is
/// not `#[non_exhaustive]`, so adding a variant is a breaking change for
/// downstream code that matches on it exhaustively. Prefer the capability
/// predicates ([`SmArch::supports_fp8`], [`SmArch::supports_fp4`],
/// [`SmArch::supports_persistent_kernels`]) over matching on concrete
/// variants, so that new architectures are picked up without code changes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    /// `sm_80` (`compute_80`).
    Sm80,
    /// `sm_86` (`compute_86`).
    Sm86,
    /// `sm_89` (`compute_89`).
    Sm89,
    /// `sm_90` (`compute_90`).
    Sm90,
    /// Architecture-specific `sm_90a` (`compute_90a`).
    Sm90a,
    /// `sm_100` (`compute_100`).
    Sm100,
    /// `sm_120` (`compute_120`).
    Sm120,
}

impl SmArch {
    /// NVRTC `--gpu-architecture` option that selects this target.
    #[must_use]
    pub const fn nvrtc_flag(self) -> &'static str {
        match self {
            SmArch::Sm80 => "--gpu-architecture=compute_80",
            SmArch::Sm86 => "--gpu-architecture=compute_86",
            SmArch::Sm89 => "--gpu-architecture=compute_89",
            SmArch::Sm90 => "--gpu-architecture=compute_90",
            SmArch::Sm90a => "--gpu-architecture=compute_90a",
            SmArch::Sm100 => "--gpu-architecture=compute_100",
            SmArch::Sm120 => "--gpu-architecture=compute_120",
        }
    }

    /// Short `sm_XX` name. This is also the [`fmt::Display`] rendering.
    #[must_use]
    pub const fn short_name(self) -> &'static str {
        match self {
            SmArch::Sm80 => "sm_80",
            SmArch::Sm86 => "sm_86",
            SmArch::Sm89 => "sm_89",
            SmArch::Sm90 => "sm_90",
            SmArch::Sm90a => "sm_90a",
            SmArch::Sm100 => "sm_100",
            SmArch::Sm120 => "sm_120",
        }
    }

    /// True for sm_89 and newer, the targets on which the fp8 dtypes
    /// (`F8E4m3`, `F8E5m2`) may be instantiated.
    #[must_use]
    pub const fn supports_fp8(self) -> bool {
        matches!(
            self,
            SmArch::Sm89 | SmArch::Sm90 | SmArch::Sm90a | SmArch::Sm100 | SmArch::Sm120
        )
    }

    /// True for sm_100 and newer, the targets on which the fp4 dtype
    /// (`F4E2m1`) may be instantiated.
    #[must_use]
    pub const fn supports_fp4(self) -> bool {
        matches!(self, SmArch::Sm100 | SmArch::Sm120)
    }

    /// True for sm_90 and newer.
    ///
    /// NOTE(review): CUTLASS's SM90 TMA / warp-specialized kernels are
    /// typically guarded on the architecture-specific `sm_90a` feature set.
    /// Confirm that plain `Sm90` (`compute_90`) should report `true` here.
    #[must_use]
    pub const fn supports_persistent_kernels(self) -> bool {
        matches!(
            self,
            SmArch::Sm90 | SmArch::Sm90a | SmArch::Sm100 | SmArch::Sm120
        )
    }
}

impl fmt::Display for SmArch {
    /// Writes [`SmArch::short_name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width / alignment flags; `write_str` ignores them.
        f.pad(self.short_name())
    }
}
156
157/// Marker trait: types that a CUTLASS GEMM template can instantiate
158/// against. The trait body is empty; the dtype tag returned by
159/// [`GemmSupported::DTYPE`] drives the template emitter.
160///
161/// Implemented for `f32`, `f64`, `i8`, `i32`, `u8`, plus the fp8 / fp4 /
162/// f16 / bf16 wrapper structs in this module.
163pub trait GemmSupported: Copy + Send + Sync + 'static {
164    const DTYPE: CutlassDtype;
165}
166
167impl GemmSupported for f32 {
168    const DTYPE: CutlassDtype = CutlassDtype::F32;
169}
170impl GemmSupported for f64 {
171    const DTYPE: CutlassDtype = CutlassDtype::F64;
172}
173impl GemmSupported for i8 {
174    const DTYPE: CutlassDtype = CutlassDtype::I8;
175}
176impl GemmSupported for i32 {
177    const DTYPE: CutlassDtype = CutlassDtype::I32;
178}
179impl GemmSupported for u8 {
180    const DTYPE: CutlassDtype = CutlassDtype::U8;
181}
182
183/// Local f16 marker. Wraps a `u16` to avoid pulling `half` into the
184/// crate's dependency tree. Matches the wrapper layout used by the
185/// upstream `atomr-accel-cuda::dtype` module so users that pass through
186/// our actor surface don't observe any divergence.
187#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
188#[repr(transparent)]
189pub struct F16(pub u16);
190impl GemmSupported for F16 {
191    const DTYPE: CutlassDtype = CutlassDtype::F16;
192}
193
194/// Local bf16 marker.
195#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
196#[repr(transparent)]
197pub struct Bf16(pub u16);
198impl GemmSupported for Bf16 {
199    const DTYPE: CutlassDtype = CutlassDtype::Bf16;
200}
201
202/// Local fp8 e4m3 marker.
203#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
204#[repr(transparent)]
205pub struct F8E4m3(pub u8);
206impl GemmSupported for F8E4m3 {
207    const DTYPE: CutlassDtype = CutlassDtype::F8E4m3;
208}
209
210/// Local fp8 e5m2 marker.
211#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
212#[repr(transparent)]
213pub struct F8E5m2(pub u8);
214impl GemmSupported for F8E5m2 {
215    const DTYPE: CutlassDtype = CutlassDtype::F8E5m2;
216}
217
218/// Local fp4 e2m1 marker. Stored as `u8` because Rust has no native
219/// `u4`; the lower nibble is the value, the upper nibble is unused.
220#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
221#[repr(transparent)]
222pub struct F4E2m1(pub u8);
223impl GemmSupported for F4E2m1 {
224    const DTYPE: CutlassDtype = CutlassDtype::F4E2m1;
225}
226
227/// Returns true if `dtype` can be instantiated on `arch` by a CUTLASS
228/// GEMM template. fp8 is sm_89+; fp4 is sm_100+; everything else is
229/// supported on every modern Tensor-Core arch.
230pub fn is_supported_for(dtype: CutlassDtype, arch: SmArch) -> bool {
231    match dtype {
232        CutlassDtype::F8E4m3 | CutlassDtype::F8E5m2 => arch.supports_fp8(),
233        CutlassDtype::F4E2m1 => arch.supports_fp4(),
234        _ => true,
235    }
236}
237
238/// Convenience predicate for fp8-only callers.
239pub fn is_fp8_supported(arch: SmArch) -> bool {
240    arch.supports_fp8()
241}
242
243/// Convenience predicate for fp4-only callers.
244pub fn is_fp4_supported(arch: SmArch) -> bool {
245    arch.supports_fp4()
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `CutlassDtype` variant, for table-driven checks.
    const ALL_DTYPES: [CutlassDtype; 10] = [
        CutlassDtype::F32,
        CutlassDtype::F64,
        CutlassDtype::F16,
        CutlassDtype::Bf16,
        CutlassDtype::F8E4m3,
        CutlassDtype::F8E5m2,
        CutlassDtype::F4E2m1,
        CutlassDtype::I8,
        CutlassDtype::I32,
        CutlassDtype::U8,
    ];

    /// Every `SmArch` variant, for table-driven checks.
    const ALL_ARCHS: [SmArch; 7] = [
        SmArch::Sm80,
        SmArch::Sm86,
        SmArch::Sm89,
        SmArch::Sm90,
        SmArch::Sm90a,
        SmArch::Sm100,
        SmArch::Sm120,
    ];

    #[test]
    fn arch_capability_predicates() {
        assert!(!SmArch::Sm80.supports_fp8());
        assert!(SmArch::Sm89.supports_fp8());
        assert!(SmArch::Sm90a.supports_fp8());
        assert!(SmArch::Sm100.supports_fp4());
        assert!(!SmArch::Sm89.supports_fp4());
        assert!(SmArch::Sm90a.supports_persistent_kernels());
        assert!(!SmArch::Sm80.supports_persistent_kernels());
    }

    #[test]
    fn fp4_support_implies_fp8_support() {
        for arch in ALL_ARCHS {
            if arch.supports_fp4() {
                assert!(arch.supports_fp8(), "{:?}", arch);
            }
        }
    }

    #[test]
    fn dtype_short_names_unique() {
        let mut seen = HashSet::new();
        for dt in ALL_DTYPES {
            assert!(seen.insert(dt.short_name()), "duplicate short name: {:?}", dt);
        }
    }

    #[test]
    fn dtype_cutlass_spellings_unique() {
        // `as_cutlass_type` is documented as a 1-to-1 mapping.
        let mut seen = HashSet::new();
        for dt in ALL_DTYPES {
            assert!(
                seen.insert(dt.as_cutlass_type()),
                "duplicate CUTLASS spelling: {:?}",
                dt
            );
        }
    }

    #[test]
    fn dtype_display_matches_short_name() {
        for dt in ALL_DTYPES {
            assert_eq!(dt.to_string(), dt.short_name());
        }
    }

    #[test]
    fn only_fp4_is_sub_byte() {
        for dt in ALL_DTYPES {
            assert_eq!(dt.size_bits() < 8, dt == CutlassDtype::F4E2m1, "{:?}", dt);
        }
    }

    #[test]
    fn arch_flags_match_short_names() {
        for arch in ALL_ARCHS {
            let flag_suffix = arch
                .nvrtc_flag()
                .strip_prefix("--gpu-architecture=compute_");
            let name_suffix = arch.short_name().strip_prefix("sm_");
            assert!(flag_suffix.is_some(), "unexpected flag shape: {:?}", arch);
            assert_eq!(flag_suffix, name_suffix, "{:?}", arch);
            assert_eq!(arch.to_string(), arch.short_name());
        }
    }

    #[test]
    fn is_supported_for_matrix() {
        assert!(is_supported_for(CutlassDtype::F32, SmArch::Sm80));
        assert!(!is_supported_for(CutlassDtype::F8E4m3, SmArch::Sm80));
        assert!(is_supported_for(CutlassDtype::F8E4m3, SmArch::Sm90a));
        assert!(!is_supported_for(CutlassDtype::F4E2m1, SmArch::Sm89));
        assert!(is_supported_for(CutlassDtype::F4E2m1, SmArch::Sm100));
    }

    #[test]
    fn is_supported_for_agrees_with_predicates() {
        for arch in ALL_ARCHS {
            for dt in ALL_DTYPES {
                let expected = match dt {
                    CutlassDtype::F8E4m3 | CutlassDtype::F8E5m2 => is_fp8_supported(arch),
                    CutlassDtype::F4E2m1 => is_fp4_supported(arch),
                    _ => true,
                };
                assert_eq!(is_supported_for(dt, arch), expected, "{:?} on {:?}", dt, arch);
            }
        }
    }
}