Skip to main content

atomr_accel_cutlass/
gemm.rs

1//! GEMM template request and dispatch surface.
2//!
3//! A [`GemmRequest<T>`] is a typed, host-side description of a CUTLASS
4//! `gemm_universal<...>` instantiation. The actor lifts it to a
5//! `Box<dyn CutlassGemmDispatch>` so the per-dtype monomorphisations
6//! can share a single mailbox.
7
8use core::marker::PhantomData;
9
10use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
11use crate::kernels;
12use crate::plan_cache::PlanKey;
13
/// Memory-order tag for a GEMM operand, as exposed at the API surface.
///
/// Each variant corresponds one-to-one to a CUTLASS layout type
/// (`cutlass::layout::RowMajor` / `cutlass::layout::ColumnMajor`), which the
/// emitter uses directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GemmLayout {
    RowMajor,
    ColMajor,
}

impl GemmLayout {
    /// Fully qualified CUTLASS layout type name for this tag.
    pub fn cutlass_layout(self) -> &'static str {
        let (cutlass_name, _) = self.spellings();
        cutlass_name
    }

    /// Compact two-letter tag (`"rm"` / `"cm"`).
    pub fn short_name(self) -> &'static str {
        let (_, short) = self.spellings();
        short
    }

    /// `(cutlass type name, short name)` for this variant. This is the only place
    /// a variant's spellings are written down, so the two accessors cannot drift apart.
    fn spellings(self) -> (&'static str, &'static str) {
        match self {
            GemmLayout::RowMajor => ("cutlass::layout::RowMajor", "rm"),
            GemmLayout::ColMajor => ("cutlass::layout::ColumnMajor", "cm"),
        }
    }
}
39
/// Problem size of a single GEMM, given as the `(M, N, K)` extents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct GemmShape {
    /// `M` extent.
    pub m: u32,
    /// `N` extent.
    pub n: u32,
    /// `K` extent.
    pub k: u32,
}

impl GemmShape {
    /// Builds a shape from its extents, taken in `(m, n, k)` order.
    pub fn new(m: u32, n: u32, k: u32) -> Self {
        GemmShape { m, n, k }
    }
}
53
/// Epilogue selector.
///
/// * `Linear { alpha, beta }` is the default `D = alpha * A @ B + beta * C`
///   epilogue.
/// * `LinearReLU` and `LinearGelu` are the same linear combination with the
///   most common fused activations.
///
/// The richer epilogue surface (multi-output, quantize, reduce) lives in
/// [`crate::evt`] behind the `evt` cargo feature.
///
/// # Equality and hashing
///
/// Values are compared *and* hashed on the variant plus the raw bit patterns
/// of `alpha` / `beta`. This makes `Eq` a genuine equivalence relation that
/// agrees with `Hash`, which the plan cache needs when it dedupes on runtime
/// parameter values. As a consequence, `0.0` and `-0.0` are distinct keys, and
/// an epilogue holding a NaN compares equal to an otherwise identical value
/// with the same NaN bits.
#[derive(Debug, Clone, Copy)]
pub enum GemmEpilogue {
    Linear { alpha: f32, beta: f32 },
    LinearReLU { alpha: f32, beta: f32 },
    LinearGelu { alpha: f32, beta: f32 },
}

impl Default for GemmEpilogue {
    /// Identity epilogue: `alpha = 1.0`, `beta = 0.0`.
    fn default() -> Self {
        GemmEpilogue::Linear {
            alpha: 1.0,
            beta: 0.0,
        }
    }
}

impl PartialEq for GemmEpilogue {
    /// Bit-pattern equality. A derived `PartialEq` would compare the `f32`s
    /// numerically, treating `0.0 == -0.0` (which hash differently) and
    /// `NaN != NaN` (which breaks `Eq` reflexivity).
    fn eq(&self, other: &Self) -> bool {
        self.key_bits() == other.key_bits()
    }
}

// Sound because `eq` compares integer bit patterns, which is reflexive,
// symmetric and transitive.
impl Eq for GemmEpilogue {}

impl core::hash::Hash for GemmEpilogue {
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        // The CUTLASS-side template id depends only on the discriminant, but
        // the plan cache also dedupes on the runtime parameter values, so
        // alpha/beta are hashed too. Writing through the `Hasher` methods
        // produces the same byte stream as `u8::hash` / `u32::hash`.
        let (tag, alpha_bits, beta_bits) = self.key_bits();
        state.write_u8(tag);
        state.write_u32(alpha_bits);
        state.write_u32(beta_bits);
    }
}

impl GemmEpilogue {
    /// Stable short name used in plan-cache keys and template ids.
    pub fn short_name(self) -> &'static str {
        match self {
            GemmEpilogue::Linear { .. } => "linear",
            GemmEpilogue::LinearReLU { .. } => "linear_relu",
            GemmEpilogue::LinearGelu { .. } => "linear_gelu",
        }
    }

    /// `(variant tag, alpha bits, beta bits)`. This is the single source of truth
    /// for both `PartialEq` and `Hash`, so the two can never disagree.
    fn key_bits(self) -> (u8, u32, u32) {
        match self {
            GemmEpilogue::Linear { alpha, beta } => (0, alpha.to_bits(), beta.to_bits()),
            GemmEpilogue::LinearReLU { alpha, beta } => (1, alpha.to_bits(), beta.to_bits()),
            GemmEpilogue::LinearGelu { alpha, beta } => (2, alpha.to_bits(), beta.to_bits()),
        }
    }
}
113
114/// Typed GEMM request. `T` is the element type for `A` and `B`; the
115/// accumulator and output dtypes are derived from `T` by CUTLASS
116/// (configurable via `accum_dtype` / `output_dtype`).
117#[derive(Debug, Clone)]
118pub struct GemmRequest<T: GemmSupported> {
119    pub shape: GemmShape,
120    pub layout_a: GemmLayout,
121    pub layout_b: GemmLayout,
122    pub layout_c: GemmLayout,
123    pub epilogue: GemmEpilogue,
124    /// Override the accumulator dtype. Defaults to fp32.
125    pub accum_dtype: CutlassDtype,
126    /// Override the output dtype. Defaults to `T::DTYPE`.
127    pub output_dtype: CutlassDtype,
128    /// Target compute architecture.
129    pub arch: SmArch,
130    /// Use a CUTLASS persistent (Hopper / Blackwell) kernel.
131    pub persistent: bool,
132    _t: PhantomData<fn() -> T>,
133}
134
135impl<T: GemmSupported> GemmRequest<T> {
136    /// Canonical constructor.
137    pub fn new(shape: GemmShape, arch: SmArch) -> Self {
138        Self {
139            shape,
140            layout_a: GemmLayout::RowMajor,
141            layout_b: GemmLayout::RowMajor,
142            layout_c: GemmLayout::RowMajor,
143            epilogue: GemmEpilogue::default(),
144            accum_dtype: CutlassDtype::F32,
145            output_dtype: T::DTYPE,
146            arch,
147            persistent: arch.supports_persistent_kernels(),
148            _t: PhantomData,
149        }
150    }
151
152    /// Deprecated 5-argument constructor. Pre-Phase-6 callers passed
153    /// `(m, n, k, layout, alpha)`; we keep this path so out-of-tree
154    /// downstreams compile against the 0.3.0 API surface.
155    #[deprecated(note = "use `GemmRequest::new(shape, arch)` plus the builder methods instead")]
156    pub fn legacy(m: u32, n: u32, k: u32, layout: GemmLayout, alpha: f32) -> Self {
157        let mut req = Self::new(GemmShape::new(m, n, k), SmArch::Sm80);
158        req.layout_a = layout;
159        req.layout_b = layout;
160        req.layout_c = layout;
161        req.epilogue = GemmEpilogue::Linear { alpha, beta: 0.0 };
162        req
163    }
164
165    pub fn with_layouts(mut self, a: GemmLayout, b: GemmLayout, c: GemmLayout) -> Self {
166        self.layout_a = a;
167        self.layout_b = b;
168        self.layout_c = c;
169        self
170    }
171
172    pub fn with_epilogue(mut self, ep: GemmEpilogue) -> Self {
173        self.epilogue = ep;
174        self
175    }
176
177    pub fn with_accum_dtype(mut self, dt: CutlassDtype) -> Self {
178        self.accum_dtype = dt;
179        self
180    }
181
182    pub fn with_output_dtype(mut self, dt: CutlassDtype) -> Self {
183        self.output_dtype = dt;
184        self
185    }
186
187    pub fn with_persistent(mut self, persistent: bool) -> Self {
188        self.persistent = persistent;
189        self
190    }
191
192    /// Stable plan-cache key for this request. Used by the actor to
193    /// dedupe NVRTC compilations.
194    pub fn plan_key(&self) -> PlanKey {
195        PlanKey::gemm::<T>(
196            self.shape,
197            self.layout_a,
198            self.layout_b,
199            self.layout_c,
200            self.epilogue,
201            self.accum_dtype,
202            self.output_dtype,
203            self.arch,
204            self.persistent,
205        )
206    }
207
208    /// Render the `.cu` source for this template. Returns the source
209    /// plus the lowered kernel name to look up after NVRTC compile.
210    pub fn render_cu(&self) -> (String, String) {
211        kernels::render_gemm::<T>(self)
212    }
213}
214
215/// Erased dispatch surface so the actor mailbox is `Sized`. Each
216/// `GemmRequest<T>` boxes itself as `Box<dyn CutlassGemmDispatch>`.
217pub trait CutlassGemmDispatch: Send + 'static {
218    fn plan_key(&self) -> PlanKey;
219    fn render_cu(&self) -> (String, String);
220    fn dtype(&self) -> CutlassDtype;
221    fn arch(&self) -> SmArch;
222    fn shape(&self) -> GemmShape;
223}
224
225impl<T: GemmSupported> CutlassGemmDispatch for GemmRequest<T> {
226    fn plan_key(&self) -> PlanKey {
227        GemmRequest::plan_key(self)
228    }
229
230    fn render_cu(&self) -> (String, String) {
231        GemmRequest::render_cu(self)
232    }
233
234    fn dtype(&self) -> CutlassDtype {
235        T::DTYPE
236    }
237
238    fn arch(&self) -> SmArch {
239        self.arch
240    }
241
242    fn shape(&self) -> GemmShape {
243        self.shape
244    }
245}
246
/// Payload for [`crate::actor::CutlassMsg::Refit`]. It replaces a previously
/// compiled plan's weight buffer in place, without recompiling the kernel.
///
/// NOTE(review): earlier docs called this a "reply payload", but the struct
/// carries new weights *to* the actor. Confirm the direction against
/// `CutlassMsg::Refit`.
#[derive(Debug)]
pub struct RefitMsg {
    /// Key of the previously compiled plan whose weights are being replaced.
    pub plan_key: PlanKey,
    /// Opaque weight bytes; the actor forwards them to the kernel's
    /// allocated workspace via the existing memory actors.
    pub weights: Vec<u8>,
}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::{Bf16, F4E2m1, F8E4m3, F8E5m2, F16};

    /// 64×64×64 problem shared by the construction-only smoke checks.
    fn cube64() -> GemmShape {
        GemmShape::new(64, 64, 64)
    }

    #[test]
    fn gemm_request_round_trip_for_every_dtype() {
        // f32 is available on every arch. It also exercises the erased
        // accessors and the rendered source.
        let fp32 = GemmRequest::<f32>::new(GemmShape::new(128, 256, 64), SmArch::Sm80);
        assert_eq!(fp32.dtype(), CutlassDtype::F32);
        assert_eq!(fp32.shape().m, 128);
        let (cu_source, kernel_name) = fp32.render_cu();
        assert!(cu_source.contains("cutlass::gemm::device::GemmUniversal"));
        assert!(kernel_name.starts_with("atomr_cutlass_gemm_"));

        // f64
        let fp64 = GemmRequest::<f64>::new(cube64(), SmArch::Sm80);
        assert_eq!(fp64.dtype(), CutlassDtype::F64);

        // f16: a non-default layout must produce a different plan key.
        let col_major_a = GemmRequest::<F16>::new(cube64(), SmArch::Sm80).with_layouts(
            GemmLayout::ColMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
        );
        let col_major_key = col_major_a.plan_key();
        let all_row_major = GemmRequest::<F16>::new(cube64(), SmArch::Sm80);
        assert_ne!(col_major_key, all_row_major.plan_key());

        // bf16
        let _ = GemmRequest::<Bf16>::new(cube64(), SmArch::Sm80);

        // fp8 e4m3 / e5m2 on Hopper: persistent kernels are the default.
        let e4m3 = GemmRequest::<F8E4m3>::new(GemmShape::new(128, 128, 128), SmArch::Sm90a)
            .with_epilogue(GemmEpilogue::LinearReLU {
                alpha: 1.0,
                beta: 0.0,
            });
        assert_eq!(e4m3.dtype(), CutlassDtype::F8E4m3);
        assert!(e4m3.persistent);
        let _ = GemmRequest::<F8E5m2>::new(cube64(), SmArch::Sm90a);

        // fp4 on Blackwell.
        let e2m1 = GemmRequest::<F4E2m1>::new(cube64(), SmArch::Sm100);
        assert_eq!(e2m1.dtype(), CutlassDtype::F4E2m1);

        // Integer element types.
        let _ = GemmRequest::<i8>::new(cube64(), SmArch::Sm80);
        let _ = GemmRequest::<i32>::new(cube64(), SmArch::Sm80);
        let _ = GemmRequest::<u8>::new(cube64(), SmArch::Sm80);
    }

    #[test]
    fn deprecated_constructor_paths_compile() {
        // The legacy 5-argument form must keep building for pre-Phase-6
        // callers. The deprecation warning is silenced on purpose.
        #[allow(deprecated)]
        let req = GemmRequest::<f32>::legacy(64, 64, 64, GemmLayout::RowMajor, 1.0);
        assert_eq!(req.shape, cube64());
        if let GemmEpilogue::Linear { alpha, beta } = req.epilogue {
            assert_eq!(alpha, 1.0);
            assert_eq!(beta, 0.0);
        } else {
            panic!("legacy constructor should produce Linear epilogue");
        }
    }

    #[test]
    fn persistent_default_tracks_arch() {
        let unit = GemmShape::new(1, 1, 1);
        let expectations = [
            (SmArch::Sm80, false),
            (SmArch::Sm90a, true),
            (SmArch::Sm100, true),
        ];
        for (arch, expect_persistent) in expectations {
            assert_eq!(GemmRequest::<f32>::new(unit, arch).persistent, expect_persistent);
        }
    }
}