1use core::marker::PhantomData;
9
10use crate::dtype::{CutlassDtype, GemmSupported, SmArch};
11use crate::kernels;
12use crate::plan_cache::PlanKey;
13
/// Memory layout of one GEMM operand matrix.
///
/// Each variant corresponds to a CUTLASS layout tag (see `GemmLayout::cutlass_layout`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum GemmLayout {
    /// Row-major storage (`cutlass::layout::RowMajor`).
    RowMajor,
    /// Column-major storage (`cutlass::layout::ColumnMajor`).
    ColMajor,
}
23
24impl GemmLayout {
25 pub fn cutlass_layout(self) -> &'static str {
26 match self {
27 GemmLayout::RowMajor => "cutlass::layout::RowMajor",
28 GemmLayout::ColMajor => "cutlass::layout::ColumnMajor",
29 }
30 }
31
32 pub fn short_name(self) -> &'static str {
33 match self {
34 GemmLayout::RowMajor => "rm",
35 GemmLayout::ColMajor => "cm",
36 }
37 }
38}
39
/// Problem dimensions of a GEMM.
///
/// Conventionally A is `m x k`, B is `k x n` and C is `m x n`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct GemmShape {
    /// `M` extent of the problem.
    pub m: u32,
    /// `N` extent of the problem.
    pub n: u32,
    /// `K` (reduction) extent of the problem.
    pub k: u32,
}
47
48impl GemmShape {
49 pub fn new(m: u32, n: u32, k: u32) -> Self {
50 Self { m, n, k }
51 }
52}
53
/// Epilogue selected for a GEMM.
///
/// Every variant carries `alpha` / `beta` scaling factors. Presumably these follow CUTLASS's
/// `LinearCombination` (`alpha * accum + beta * C`); confirm against the kernel renderer.
/// The `LinearReLU` / `LinearGelu` variants additionally select the activation their
/// names indicate.
///
/// Equality compares the `f32` factors by bit pattern (`f32::to_bits`), not by IEEE float
/// comparison. That keeps `PartialEq` consistent with the manual `Eq` and `Hash` impls
/// (which hash `to_bits()`), so values can safely live inside hashed keys:
/// - a NaN factor equals itself, so `Eq` is reflexive;
/// - `0.0` and `-0.0` compare unequal, just as they hash differently.
#[derive(Debug, Clone, Copy)]
pub enum GemmEpilogue {
    /// Linear scaling only.
    Linear { alpha: f32, beta: f32 },
    /// Linear scaling with a ReLU activation.
    LinearReLU { alpha: f32, beta: f32 },
    /// Linear scaling with a GELU activation.
    LinearGelu { alpha: f32, beta: f32 },
}

impl PartialEq for GemmEpilogue {
    fn eq(&self, other: &Self) -> bool {
        /// Variant tag plus raw bits of `alpha` and `beta`: the identity used for equality.
        fn bits(ep: &GemmEpilogue) -> (u8, u32, u32) {
            match *ep {
                GemmEpilogue::Linear { alpha, beta } => (0, alpha.to_bits(), beta.to_bits()),
                GemmEpilogue::LinearReLU { alpha, beta } => (1, alpha.to_bits(), beta.to_bits()),
                GemmEpilogue::LinearGelu { alpha, beta } => (2, alpha.to_bits(), beta.to_bits()),
            }
        }
        bits(self) == bits(other)
    }
}
65
66impl Default for GemmEpilogue {
67 fn default() -> Self {
68 GemmEpilogue::Linear {
69 alpha: 1.0,
70 beta: 0.0,
71 }
72 }
73}
74
75impl Eq for GemmEpilogue {}
76
77impl core::hash::Hash for GemmEpilogue {
78 fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
79 match *self {
84 GemmEpilogue::Linear { alpha, beta } => {
85 0u8.hash(state);
86 alpha.to_bits().hash(state);
87 beta.to_bits().hash(state);
88 }
89 GemmEpilogue::LinearReLU { alpha, beta } => {
90 1u8.hash(state);
91 alpha.to_bits().hash(state);
92 beta.to_bits().hash(state);
93 }
94 GemmEpilogue::LinearGelu { alpha, beta } => {
95 2u8.hash(state);
96 alpha.to_bits().hash(state);
97 beta.to_bits().hash(state);
98 }
99 }
100 }
101}
102
103impl GemmEpilogue {
104 pub fn short_name(self) -> &'static str {
106 match self {
107 GemmEpilogue::Linear { .. } => "linear",
108 GemmEpilogue::LinearReLU { .. } => "linear_relu",
109 GemmEpilogue::LinearGelu { .. } => "linear_gelu",
110 }
111 }
112}
113
114#[derive(Debug, Clone)]
118pub struct GemmRequest<T: GemmSupported> {
119 pub shape: GemmShape,
120 pub layout_a: GemmLayout,
121 pub layout_b: GemmLayout,
122 pub layout_c: GemmLayout,
123 pub epilogue: GemmEpilogue,
124 pub accum_dtype: CutlassDtype,
126 pub output_dtype: CutlassDtype,
128 pub arch: SmArch,
130 pub persistent: bool,
132 _t: PhantomData<fn() -> T>,
133}
134
135impl<T: GemmSupported> GemmRequest<T> {
136 pub fn new(shape: GemmShape, arch: SmArch) -> Self {
138 Self {
139 shape,
140 layout_a: GemmLayout::RowMajor,
141 layout_b: GemmLayout::RowMajor,
142 layout_c: GemmLayout::RowMajor,
143 epilogue: GemmEpilogue::default(),
144 accum_dtype: CutlassDtype::F32,
145 output_dtype: T::DTYPE,
146 arch,
147 persistent: arch.supports_persistent_kernels(),
148 _t: PhantomData,
149 }
150 }
151
152 #[deprecated(note = "use `GemmRequest::new(shape, arch)` plus the builder methods instead")]
156 pub fn legacy(m: u32, n: u32, k: u32, layout: GemmLayout, alpha: f32) -> Self {
157 let mut req = Self::new(GemmShape::new(m, n, k), SmArch::Sm80);
158 req.layout_a = layout;
159 req.layout_b = layout;
160 req.layout_c = layout;
161 req.epilogue = GemmEpilogue::Linear { alpha, beta: 0.0 };
162 req
163 }
164
165 pub fn with_layouts(mut self, a: GemmLayout, b: GemmLayout, c: GemmLayout) -> Self {
166 self.layout_a = a;
167 self.layout_b = b;
168 self.layout_c = c;
169 self
170 }
171
172 pub fn with_epilogue(mut self, ep: GemmEpilogue) -> Self {
173 self.epilogue = ep;
174 self
175 }
176
177 pub fn with_accum_dtype(mut self, dt: CutlassDtype) -> Self {
178 self.accum_dtype = dt;
179 self
180 }
181
182 pub fn with_output_dtype(mut self, dt: CutlassDtype) -> Self {
183 self.output_dtype = dt;
184 self
185 }
186
187 pub fn with_persistent(mut self, persistent: bool) -> Self {
188 self.persistent = persistent;
189 self
190 }
191
192 pub fn plan_key(&self) -> PlanKey {
195 PlanKey::gemm::<T>(
196 self.shape,
197 self.layout_a,
198 self.layout_b,
199 self.layout_c,
200 self.epilogue,
201 self.accum_dtype,
202 self.output_dtype,
203 self.arch,
204 self.persistent,
205 )
206 }
207
208 pub fn render_cu(&self) -> (String, String) {
211 kernels::render_gemm::<T>(self)
212 }
213}
214
215pub trait CutlassGemmDispatch: Send + 'static {
218 fn plan_key(&self) -> PlanKey;
219 fn render_cu(&self) -> (String, String);
220 fn dtype(&self) -> CutlassDtype;
221 fn arch(&self) -> SmArch;
222 fn shape(&self) -> GemmShape;
223}
224
225impl<T: GemmSupported> CutlassGemmDispatch for GemmRequest<T> {
226 fn plan_key(&self) -> PlanKey {
227 GemmRequest::plan_key(self)
228 }
229
230 fn render_cu(&self) -> (String, String) {
231 GemmRequest::render_cu(self)
232 }
233
234 fn dtype(&self) -> CutlassDtype {
235 T::DTYPE
236 }
237
238 fn arch(&self) -> SmArch {
239 self.arch
240 }
241
242 fn shape(&self) -> GemmShape {
243 self.shape
244 }
245}
246
247#[derive(Debug)]
251pub struct RefitMsg {
252 pub plan_key: PlanKey,
253 pub weights: Vec<u8>,
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::{Bf16, F4E2m1, F8E4m3, F8E5m2, F16};

    /// The 64 x 64 x 64 problem reused by most construction checks.
    const SQUARE_64: GemmShape = GemmShape { m: 64, n: 64, k: 64 };

    #[test]
    fn gemm_request_round_trip_for_every_dtype() {
        // f32: exercise the dispatch accessors and the rendered source.
        let f32_req = GemmRequest::<f32>::new(GemmShape::new(128, 256, 64), SmArch::Sm80);
        assert_eq!(f32_req.dtype(), CutlassDtype::F32);
        assert_eq!(f32_req.shape().m, 128);
        let (source, kernel_name) = f32_req.render_cu();
        assert!(source.contains("cutlass::gemm::device::GemmUniversal"));
        assert!(kernel_name.starts_with("atomr_cutlass_gemm_"));

        let f64_req = GemmRequest::<f64>::new(SQUARE_64, SmArch::Sm80);
        assert_eq!(f64_req.dtype(), CutlassDtype::F64);

        // Operand layouts must feed into the plan key.
        let col_a = GemmRequest::<F16>::new(SQUARE_64, SmArch::Sm80).with_layouts(
            GemmLayout::ColMajor,
            GemmLayout::RowMajor,
            GemmLayout::RowMajor,
        );
        let col_a_key = col_a.plan_key();
        let all_row = GemmRequest::<F16>::new(SQUARE_64, SmArch::Sm80);
        assert_ne!(col_a_key, all_row.plan_key());

        let _ = GemmRequest::<Bf16>::new(SQUARE_64, SmArch::Sm80);

        let fp8_req = GemmRequest::<F8E4m3>::new(GemmShape::new(128, 128, 128), SmArch::Sm90a)
            .with_epilogue(GemmEpilogue::LinearReLU {
                alpha: 1.0,
                beta: 0.0,
            });
        assert_eq!(fp8_req.dtype(), CutlassDtype::F8E4m3);
        assert!(fp8_req.persistent);
        let _ = GemmRequest::<F8E5m2>::new(SQUARE_64, SmArch::Sm90a);

        let fp4_req = GemmRequest::<F4E2m1>::new(SQUARE_64, SmArch::Sm100);
        assert_eq!(fp4_req.dtype(), CutlassDtype::F4E2m1);

        // Integer element types only need to construct.
        let _ = GemmRequest::<i8>::new(SQUARE_64, SmArch::Sm80);
        let _ = GemmRequest::<i32>::new(SQUARE_64, SmArch::Sm80);
        let _ = GemmRequest::<u8>::new(SQUARE_64, SmArch::Sm80);
    }

    #[test]
    fn deprecated_constructor_paths_compile() {
        #[allow(deprecated)] // deliberately exercising the deprecated constructor
        let req = GemmRequest::<f32>::legacy(64, 64, 64, GemmLayout::RowMajor, 1.0);
        assert_eq!(req.shape, SQUARE_64);
        if let GemmEpilogue::Linear { alpha, beta } = req.epilogue {
            assert_eq!(alpha, 1.0);
            assert_eq!(beta, 0.0);
        } else {
            panic!("legacy constructor should produce Linear epilogue");
        }
    }

    #[test]
    fn persistent_default_tracks_arch() {
        let unit = GemmShape::new(1, 1, 1);
        let persistent_on = |arch| GemmRequest::<f32>::new(unit, arch).persistent;
        assert!(!persistent_on(SmArch::Sm80));
        assert!(persistent_on(SmArch::Sm90a));
        assert!(persistent_on(SmArch::Sm100));
    }
}