Skip to main content

baracuda_cutlass/
plan.rs

1//! Plan-based GEMM and grouped-GEMM API.
2
3use core::ffi::c_void;
4use core::marker::PhantomData;
5
6use baracuda_driver::{Context, PinnedBuffer, Stream};
7use baracuda_kernels_types::BackendKind;
8
9use crate::error::{status_to_result, Error, Result};
10use crate::types::{
11    ArchSku, BatchedGemmArgs, BatchedGemmDescriptor, BiasElement, CutlassElement, ElementKind,
12    EpilogueKind, GemmArgs, GemmDescriptor, GemmSku, GroupedPlanPreference, GroupedProblem,
13    GroupedScheduleMode, IntElement, IntGemmArgs, IntGemmDescriptor, LayoutSku, PlanPreference,
14    PrecisionGuarantee, ScalarType, Workspace,
15};
16
17// ============================================================================
18// Internal dispatch — generic-T → element-specific extern "C" entry point
19// ============================================================================
20
21mod dispatch {
22    use super::{ElementKind, LayoutSku};
23    use core::ffi::c_void;
24
25    use super::EpilogueKind;
26
27    /// Single-GEMM dispatch on sm_80 with a bias-family epilogue (Bias,
28    /// BiasRelu, BiasGelu, BiasSilu).
29    ///
30    /// SKU coverage today: `{Rcr, Rrr} × {F16, Bf16, F32 (TF32),
31    /// F32Strict (SIMT)}` for every bias-family epilogue. F64 routes
32    /// through [`gemm_bias_sm80_run_f64`] because the FFI takes `f64`
33    /// alpha/beta; that fork is selected at the call site based on
34    /// `T::Scalar::IS_F64`. The non-bias path lives in
35    /// [`gemm_sm80_run`].
36    #[cfg(feature = "sm80")]
37    #[allow(clippy::too_many_arguments)]
38    pub(super) unsafe fn gemm_bias_sm80_run(
39        layout: LayoutSku,
40        kind: ElementKind,
41        epilogue: EpilogueKind,
42        m: i32,
43        n: i32,
44        k: i32,
45        a: *const c_void,
46        lda: i64,
47        b: *const c_void,
48        ldb: i64,
49        c: *const c_void,
50        ldc: i64,
51        d: *mut c_void,
52        ldd: i64,
53        bias: *const c_void,
54        alpha: f32,
55        beta: f32,
56        workspace: *mut c_void,
57        workspace_bytes: usize,
58        stream: *mut c_void,
59    ) -> i32 {
60        use baracuda_cutlass_kernels_sys as k_sys;
61        match (layout, kind, epilogue) {
62            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
63                k_sys::baracuda_cutlass_gemm_bias_f16_rcr_sm80_run(
64                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
65                    bias, alpha, beta, workspace, workspace_bytes, stream,
66                )
67            },
68            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
69                k_sys::baracuda_cutlass_gemm_bias_bf16_rcr_sm80_run(
70                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
71                    bias, alpha, beta, workspace, workspace_bytes, stream,
72                )
73            },
74            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
75                k_sys::baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_run(
76                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
77                    bias, alpha, beta, workspace, workspace_bytes, stream,
78                )
79            },
80            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
81                k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_run(
82                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
83                    bias, alpha, beta, workspace, workspace_bytes, stream,
84                )
85            },
86            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
87                k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_run(
88                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
89                    bias, alpha, beta, workspace, workspace_bytes, stream,
90                )
91            },
92            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
93                k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_run(
94                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
95                    bias, alpha, beta, workspace, workspace_bytes, stream,
96                )
97            },
98            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
99                k_sys::baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_run(
100                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
101                    bias, alpha, beta, workspace, workspace_bytes, stream,
102                )
103            },
104            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
105                k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_run(
106                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
107                    bias, alpha, beta, workspace, workspace_bytes, stream,
108                )
109            },
110            // ---- Rrr layout ----
111            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
112                k_sys::baracuda_cutlass_gemm_bias_f16_rrr_sm80_run(
113                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
114                    bias, alpha, beta, workspace, workspace_bytes, stream,
115                )
116            },
117            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
118                k_sys::baracuda_cutlass_gemm_bias_bf16_rrr_sm80_run(
119                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
120                    bias, alpha, beta, workspace, workspace_bytes, stream,
121                )
122            },
123            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
124                k_sys::baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_run(
125                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
126                    bias, alpha, beta, workspace, workspace_bytes, stream,
127                )
128            },
129            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
130                k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_run(
131                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
132                    bias, alpha, beta, workspace, workspace_bytes, stream,
133                )
134            },
135            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
136                k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_run(
137                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
138                    bias, alpha, beta, workspace, workspace_bytes, stream,
139                )
140            },
141            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
142                k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_run(
143                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
144                    bias, alpha, beta, workspace, workspace_bytes, stream,
145                )
146            },
147            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
148                k_sys::baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_run(
149                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
150                    bias, alpha, beta, workspace, workspace_bytes, stream,
151                )
152            },
153            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
154                k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_run(
155                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
156                    bias, alpha, beta, workspace, workspace_bytes, stream,
157                )
158            },
159            // ---- TF32 path (Rcr × F32) ----
160            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
161                k_sys::baracuda_cutlass_gemm_bias_tf32_rcr_sm80_run(
162                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
163                    bias, alpha, beta, workspace, workspace_bytes, stream,
164                )
165            },
166            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
167                k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_run(
168                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
169                    bias, alpha, beta, workspace, workspace_bytes, stream,
170                )
171            },
172            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
173                k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_run(
174                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
175                    bias, alpha, beta, workspace, workspace_bytes, stream,
176                )
177            },
178            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
179                k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_run(
180                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
181                    bias, alpha, beta, workspace, workspace_bytes, stream,
182                )
183            },
184            // ---- TF32 path (Rrr × F32) ----
185            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
186                k_sys::baracuda_cutlass_gemm_bias_tf32_rrr_sm80_run(
187                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
188                    bias, alpha, beta, workspace, workspace_bytes, stream,
189                )
190            },
191            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
192                k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_run(
193                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
194                    bias, alpha, beta, workspace, workspace_bytes, stream,
195                )
196            },
197            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
198                k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_run(
199                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
200                    bias, alpha, beta, workspace, workspace_bytes, stream,
201                )
202            },
203            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
204                k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_run(
205                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
206                    bias, alpha, beta, workspace, workspace_bytes, stream,
207                )
208            },
209            // ---- f32-SIMT path (Rcr × F32Strict) ----
210            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
211                k_sys::baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_run(
212                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
213                    bias, alpha, beta, workspace, workspace_bytes, stream,
214                )
215            },
216            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
217                k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_run(
218                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
219                    bias, alpha, beta, workspace, workspace_bytes, stream,
220                )
221            },
222            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
223                k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_run(
224                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
225                    bias, alpha, beta, workspace, workspace_bytes, stream,
226                )
227            },
228            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
229                k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_run(
230                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
231                    bias, alpha, beta, workspace, workspace_bytes, stream,
232                )
233            },
234            // ---- f32-SIMT path (Rrr × F32Strict) ----
235            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
236                k_sys::baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_run(
237                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
238                    bias, alpha, beta, workspace, workspace_bytes, stream,
239                )
240            },
241            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
242                k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_run(
243                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
244                    bias, alpha, beta, workspace, workspace_bytes, stream,
245                )
246            },
247            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
248                k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_run(
249                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
250                    bias, alpha, beta, workspace, workspace_bytes, stream,
251                )
252            },
253            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
254                k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_run(
255                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
256                    bias, alpha, beta, workspace, workspace_bytes, stream,
257                )
258            },
259            _ => 3,
260        }
261    }
262
263    #[cfg(feature = "sm80")]
264    pub(super) fn gemm_bias_sm80_workspace_size(
265        layout: LayoutSku,
266        kind: ElementKind,
267        epilogue: EpilogueKind,
268        m: i32,
269        n: i32,
270        k: i32,
271    ) -> usize {
272        use baracuda_cutlass_kernels_sys as k_sys;
273        match (layout, kind, epilogue) {
274            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
275                k_sys::baracuda_cutlass_gemm_bias_f16_rcr_sm80_workspace_size(m, n, k)
276            },
277            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
278                k_sys::baracuda_cutlass_gemm_bias_bf16_rcr_sm80_workspace_size(m, n, k)
279            },
280            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
281                k_sys::baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_workspace_size(m, n, k)
282            },
283            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
284                k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_workspace_size(m, n, k)
285            },
286            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
287                k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_workspace_size(m, n, k)
288            },
289            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
290                k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_workspace_size(m, n, k)
291            },
292            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
293                k_sys::baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_workspace_size(m, n, k)
294            },
295            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
296                k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_workspace_size(m, n, k)
297            },
298            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
299                k_sys::baracuda_cutlass_gemm_bias_f16_rrr_sm80_workspace_size(m, n, k)
300            },
301            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
302                k_sys::baracuda_cutlass_gemm_bias_bf16_rrr_sm80_workspace_size(m, n, k)
303            },
304            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
305                k_sys::baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_workspace_size(m, n, k)
306            },
307            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
308                k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_workspace_size(m, n, k)
309            },
310            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
311                k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_workspace_size(m, n, k)
312            },
313            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
314                k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_workspace_size(m, n, k)
315            },
316            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
317                k_sys::baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_workspace_size(m, n, k)
318            },
319            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
320                k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_workspace_size(m, n, k)
321            },
322            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
323                k_sys::baracuda_cutlass_gemm_bias_tf32_rcr_sm80_workspace_size(m, n, k)
324            },
325            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
326                k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_workspace_size(m, n, k)
327            },
328            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
329                k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_workspace_size(m, n, k)
330            },
331            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
332                k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_workspace_size(m, n, k)
333            },
334            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
335                k_sys::baracuda_cutlass_gemm_bias_tf32_rrr_sm80_workspace_size(m, n, k)
336            },
337            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
338                k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_workspace_size(m, n, k)
339            },
340            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
341                k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_workspace_size(m, n, k)
342            },
343            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
344                k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_workspace_size(m, n, k)
345            },
346            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
347                k_sys::baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_workspace_size(m, n, k)
348            },
349            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
350                k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_workspace_size(m, n, k)
351            },
352            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
353                k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_workspace_size(m, n, k)
354            },
355            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
356                k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_workspace_size(m, n, k)
357            },
358            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
359                k_sys::baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_workspace_size(m, n, k)
360            },
361            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
362                k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_workspace_size(m, n, k)
363            },
364            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
365                k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_workspace_size(m, n, k)
366            },
367            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
368                k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_workspace_size(m, n, k)
369            },
370            _ => 0,
371        }
372    }
373
374    #[cfg(feature = "sm80")]
375    #[allow(clippy::too_many_arguments)]
376    pub(super) unsafe fn gemm_bias_sm80_can_implement(
377        layout: LayoutSku,
378        kind: ElementKind,
379        epilogue: EpilogueKind,
380        m: i32,
381        n: i32,
382        k: i32,
383        a: *const c_void,
384        lda: i64,
385        b: *const c_void,
386        ldb: i64,
387        c: *const c_void,
388        ldc: i64,
389        d: *mut c_void,
390        ldd: i64,
391        bias: *const c_void,
392    ) -> i32 {
393        use baracuda_cutlass_kernels_sys as k_sys;
394        match (layout, kind, epilogue) {
395            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
396                k_sys::baracuda_cutlass_gemm_bias_f16_rcr_sm80_can_implement(
397                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
398                )
399            },
400            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
401                k_sys::baracuda_cutlass_gemm_bias_bf16_rcr_sm80_can_implement(
402                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
403                )
404            },
405            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
406                k_sys::baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_can_implement(
407                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
408                )
409            },
410            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
411                k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_can_implement(
412                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
413                )
414            },
415            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
416                k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_can_implement(
417                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
418                )
419            },
420            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
421                k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_can_implement(
422                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
423                )
424            },
425            (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
426                k_sys::baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_can_implement(
427                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
428                )
429            },
430            (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
431                k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_can_implement(
432                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
433                )
434            },
435            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
436                k_sys::baracuda_cutlass_gemm_bias_f16_rrr_sm80_can_implement(
437                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
438                )
439            },
440            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
441                k_sys::baracuda_cutlass_gemm_bias_bf16_rrr_sm80_can_implement(
442                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
443                )
444            },
445            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
446                k_sys::baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_can_implement(
447                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
448                )
449            },
450            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
451                k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_can_implement(
452                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
453                )
454            },
455            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
456                k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_can_implement(
457                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
458                )
459            },
460            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
461                k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_can_implement(
462                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
463                )
464            },
465            (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
466                k_sys::baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_can_implement(
467                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
468                )
469            },
470            (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
471                k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_can_implement(
472                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
473                )
474            },
475            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
476                k_sys::baracuda_cutlass_gemm_bias_tf32_rcr_sm80_can_implement(
477                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
478                )
479            },
480            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
481                k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_can_implement(
482                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
483                )
484            },
485            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
486                k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_can_implement(
487                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
488                )
489            },
490            (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
491                k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_can_implement(
492                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
493                )
494            },
495            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
496                k_sys::baracuda_cutlass_gemm_bias_tf32_rrr_sm80_can_implement(
497                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
498                )
499            },
500            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
501                k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_can_implement(
502                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
503                )
504            },
505            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
506                k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_can_implement(
507                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
508                )
509            },
510            (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
511                k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_can_implement(
512                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
513                )
514            },
515            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
516                k_sys::baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_can_implement(
517                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
518                )
519            },
520            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
521                k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_can_implement(
522                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
523                )
524            },
525            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
526                k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_can_implement(
527                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
528                )
529            },
530            (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
531                k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_can_implement(
532                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
533                )
534            },
535            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
536                k_sys::baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_can_implement(
537                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
538                )
539            },
540            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
541                k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_can_implement(
542                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
543                )
544            },
545            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
546                k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_can_implement(
547                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
548                )
549            },
550            (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
551                k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_can_implement(
552                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
553                )
554            },
555            _ => 3,
556        }
557    }
558
559    /// Single-GEMM dispatch on sm_80 — selects per `(layout, kind)`.
560    ///
561    /// SKU coverage: all `(Rcr|Rrr) × (F16|Bf16|F32-via-TF32)` are
562    /// implemented. F32 routes through Ampere TF32 tensor cores. Any
563    /// unknown pair returns status 3 (not implemented).
564    #[cfg(feature = "sm80")]
565    #[allow(clippy::too_many_arguments)]
566    pub(super) unsafe fn gemm_sm80_run(
567        layout: LayoutSku,
568        kind: ElementKind,
569        m: i32,
570        n: i32,
571        k: i32,
572        a: *const c_void,
573        lda: i64,
574        b: *const c_void,
575        ldb: i64,
576        c: *const c_void,
577        ldc: i64,
578        d: *mut c_void,
579        ldd: i64,
580        alpha: f32,
581        beta: f32,
582        workspace: *mut c_void,
583        workspace_bytes: usize,
584        stream: *mut c_void,
585    ) -> i32 {
586        use baracuda_cutlass_kernels_sys as k_sys;
587        match (layout, kind) {
588            (LayoutSku::Rcr, ElementKind::F16) => unsafe {
589                k_sys::baracuda_cutlass_gemm_f16_rcr_sm80_run(
590                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
591                    alpha, beta, workspace, workspace_bytes, stream,
592                )
593            },
594            (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
595                k_sys::baracuda_cutlass_gemm_bf16_rcr_sm80_run(
596                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
597                    alpha, beta, workspace, workspace_bytes, stream,
598                )
599            },
600            (LayoutSku::Rcr, ElementKind::F32) => unsafe {
601                k_sys::baracuda_cutlass_gemm_tf32_rcr_sm80_run(
602                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
603                    alpha, beta, workspace, workspace_bytes, stream,
604                )
605            },
606            (LayoutSku::Rrr, ElementKind::F16) => unsafe {
607                k_sys::baracuda_cutlass_gemm_f16_rrr_sm80_run(
608                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
609                    alpha, beta, workspace, workspace_bytes, stream,
610                )
611            },
612            (LayoutSku::Rrr, ElementKind::Bf16) => unsafe {
613                k_sys::baracuda_cutlass_gemm_bf16_rrr_sm80_run(
614                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
615                    alpha, beta, workspace, workspace_bytes, stream,
616                )
617            },
618            (LayoutSku::Rrr, ElementKind::F32) => unsafe {
619                k_sys::baracuda_cutlass_gemm_tf32_rrr_sm80_run(
620                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
621                    alpha, beta, workspace, workspace_bytes, stream,
622                )
623            },
624            (LayoutSku::Rcr, ElementKind::F32Strict) => unsafe {
625                k_sys::baracuda_cutlass_gemm_f32_simt_rcr_sm80_run(
626                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
627                    alpha, beta, workspace, workspace_bytes, stream,
628                )
629            },
630            (LayoutSku::Rrr, ElementKind::F32Strict) => unsafe {
631                k_sys::baracuda_cutlass_gemm_f32_simt_rrr_sm80_run(
632                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
633                    alpha, beta, workspace, workspace_bytes, stream,
634                )
635            },
636            // F64 has its own dispatcher (`gemm_sm80_run_f64`) because the
637            // FFI takes `f64` alpha/beta. The plan layer routes F64 through
638            // that path before reaching this function; this arm is
639            // defensive.
640            (LayoutSku::Rcr, ElementKind::F64)
641            | (LayoutSku::Rrr, ElementKind::F64) => 3,
642            // Integer element kinds (`S8`, `U8`) route through
643            // [`int_gemm_sm80_run`] instead — this float-family
644            // dispatcher should never see them when the plan layer is
645            // wired correctly. `I32` is an accumulator-only kind and
646            // is never a kernel input element.
647            (_, ElementKind::S8) | (_, ElementKind::U8) | (_, ElementKind::I32)
648            | (_, ElementKind::I64)
649            | (_, ElementKind::Bool)
650            | (_, ElementKind::Fp8E4M3)
651            | (_, ElementKind::Fp8E5M2)
652            | (_, ElementKind::S4)
653            | (_, ElementKind::U4)
654            | (_, ElementKind::Bin)
655            | (_, ElementKind::Complex32)
656            | (_, ElementKind::Complex64) => 3,
657        }
658    }
659
660    #[cfg(feature = "sm80")]
661    pub(super) fn gemm_sm80_workspace_size(
662        layout: LayoutSku,
663        kind: ElementKind,
664        m: i32,
665        n: i32,
666        k: i32,
667    ) -> usize {
668        use baracuda_cutlass_kernels_sys as k_sys;
669        match (layout, kind) {
670            (LayoutSku::Rcr, ElementKind::F16) => unsafe {
671                k_sys::baracuda_cutlass_gemm_f16_rcr_sm80_workspace_size(m, n, k)
672            },
673            (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
674                k_sys::baracuda_cutlass_gemm_bf16_rcr_sm80_workspace_size(m, n, k)
675            },
676            (LayoutSku::Rcr, ElementKind::F32) => unsafe {
677                k_sys::baracuda_cutlass_gemm_tf32_rcr_sm80_workspace_size(m, n, k)
678            },
679            (LayoutSku::Rrr, ElementKind::F16) => unsafe {
680                k_sys::baracuda_cutlass_gemm_f16_rrr_sm80_workspace_size(m, n, k)
681            },
682            (LayoutSku::Rrr, ElementKind::Bf16) => unsafe {
683                k_sys::baracuda_cutlass_gemm_bf16_rrr_sm80_workspace_size(m, n, k)
684            },
685            (LayoutSku::Rrr, ElementKind::F32) => unsafe {
686                k_sys::baracuda_cutlass_gemm_tf32_rrr_sm80_workspace_size(m, n, k)
687            },
688            (LayoutSku::Rcr, ElementKind::F32Strict) => unsafe {
689                k_sys::baracuda_cutlass_gemm_f32_simt_rcr_sm80_workspace_size(m, n, k)
690            },
691            (LayoutSku::Rrr, ElementKind::F32Strict) => unsafe {
692                k_sys::baracuda_cutlass_gemm_f32_simt_rrr_sm80_workspace_size(m, n, k)
693            },
694            // F64 routes through its own dispatcher.
695            (LayoutSku::Rcr, ElementKind::F64)
696            | (LayoutSku::Rrr, ElementKind::F64) => 0,
697            // Integer kinds route through `int_gemm_sm80_workspace_size`.
698            // FP8 kinds route through baracuda-kernels-sys. Defensive arms;
699            // never expected to fire here.
700            (_, ElementKind::S8)
701            | (_, ElementKind::U8)
702            | (_, ElementKind::I32)
703            | (_, ElementKind::I64)
704            | (_, ElementKind::Bool)
705            | (_, ElementKind::Fp8E4M3)
706            | (_, ElementKind::Fp8E5M2)
707            | (_, ElementKind::S4)
708            | (_, ElementKind::U4)
709            | (_, ElementKind::Bin)
710            | (_, ElementKind::Complex32)
711            | (_, ElementKind::Complex64) => 0,
712        }
713    }
714
715    #[cfg(feature = "sm80")]
716    #[allow(clippy::too_many_arguments)]
717    pub(super) unsafe fn gemm_sm80_can_implement(
718        layout: LayoutSku,
719        kind: ElementKind,
720        m: i32,
721        n: i32,
722        k: i32,
723        a: *const c_void,
724        lda: i64,
725        b: *const c_void,
726        ldb: i64,
727        c: *const c_void,
728        ldc: i64,
729        d: *mut c_void,
730        ldd: i64,
731    ) -> i32 {
732        use baracuda_cutlass_kernels_sys as k_sys;
733        match (layout, kind) {
734            (LayoutSku::Rcr, ElementKind::F16) => unsafe {
735                k_sys::baracuda_cutlass_gemm_f16_rcr_sm80_can_implement(
736                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
737                )
738            },
739            (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
740                k_sys::baracuda_cutlass_gemm_bf16_rcr_sm80_can_implement(
741                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
742                )
743            },
744            (LayoutSku::Rcr, ElementKind::F32) => unsafe {
745                k_sys::baracuda_cutlass_gemm_tf32_rcr_sm80_can_implement(
746                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
747                )
748            },
749            (LayoutSku::Rrr, ElementKind::F16) => unsafe {
750                k_sys::baracuda_cutlass_gemm_f16_rrr_sm80_can_implement(
751                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
752                )
753            },
754            (LayoutSku::Rrr, ElementKind::Bf16) => unsafe {
755                k_sys::baracuda_cutlass_gemm_bf16_rrr_sm80_can_implement(
756                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
757                )
758            },
759            (LayoutSku::Rrr, ElementKind::F32) => unsafe {
760                k_sys::baracuda_cutlass_gemm_tf32_rrr_sm80_can_implement(
761                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
762                )
763            },
764            (LayoutSku::Rcr, ElementKind::F32Strict) => unsafe {
765                k_sys::baracuda_cutlass_gemm_f32_simt_rcr_sm80_can_implement(
766                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
767                )
768            },
769            (LayoutSku::Rrr, ElementKind::F32Strict) => unsafe {
770                k_sys::baracuda_cutlass_gemm_f32_simt_rrr_sm80_can_implement(
771                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
772                )
773            },
774            // F64 routes through its own dispatcher.
775            (LayoutSku::Rcr, ElementKind::F64)
776            | (LayoutSku::Rrr, ElementKind::F64) => 3,
777            // Integer kinds route through `int_gemm_sm80_can_implement`.
778            (_, ElementKind::S8) | (_, ElementKind::U8) | (_, ElementKind::I32)
779            | (_, ElementKind::I64)
780            | (_, ElementKind::Bool)
781            | (_, ElementKind::Fp8E4M3)
782            | (_, ElementKind::Fp8E5M2)
783            | (_, ElementKind::S4)
784            | (_, ElementKind::U4)
785            | (_, ElementKind::Bin)
786            | (_, ElementKind::Complex32)
787            | (_, ElementKind::Complex64) => 3,
788        }
789    }
790
791    // ---------- f64 single-GEMM dispatch, sm_80 ----------
792    //
793    // Distinct from `gemm_sm80_run` because the FFI signature takes
794    // `f64` alpha/beta. The plan layer routes through these when
795    // `T::Scalar::IS_F64` is true.
796
797    #[cfg(feature = "sm80")]
798    #[allow(clippy::too_many_arguments)]
799    pub(super) unsafe fn gemm_sm80_run_f64(
800        layout: LayoutSku,
801        m: i32,
802        n: i32,
803        k: i32,
804        a: *const c_void,
805        lda: i64,
806        b: *const c_void,
807        ldb: i64,
808        c: *const c_void,
809        ldc: i64,
810        d: *mut c_void,
811        ldd: i64,
812        alpha: f64,
813        beta: f64,
814        workspace: *mut c_void,
815        workspace_bytes: usize,
816        stream: *mut c_void,
817    ) -> i32 {
818        use baracuda_cutlass_kernels_sys as k_sys;
819        match layout {
820            LayoutSku::Rcr => unsafe {
821                k_sys::baracuda_cutlass_gemm_f64_rcr_sm80_run(
822                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
823                    alpha, beta, workspace, workspace_bytes, stream,
824                )
825            },
826            LayoutSku::Rrr => unsafe {
827                k_sys::baracuda_cutlass_gemm_f64_rrr_sm80_run(
828                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
829                    alpha, beta, workspace, workspace_bytes, stream,
830                )
831            },
832        }
833    }
834
835    #[cfg(feature = "sm80")]
836    pub(super) fn gemm_sm80_workspace_size_f64(layout: LayoutSku, m: i32, n: i32, k: i32) -> usize {
837        use baracuda_cutlass_kernels_sys as k_sys;
838        match layout {
839            LayoutSku::Rcr => unsafe {
840                k_sys::baracuda_cutlass_gemm_f64_rcr_sm80_workspace_size(m, n, k)
841            },
842            LayoutSku::Rrr => unsafe {
843                k_sys::baracuda_cutlass_gemm_f64_rrr_sm80_workspace_size(m, n, k)
844            },
845        }
846    }
847
848    #[cfg(feature = "sm80")]
849    #[allow(clippy::too_many_arguments)]
850    pub(super) unsafe fn gemm_sm80_can_implement_f64(
851        layout: LayoutSku,
852        m: i32,
853        n: i32,
854        k: i32,
855        a: *const c_void,
856        lda: i64,
857        b: *const c_void,
858        ldb: i64,
859        c: *const c_void,
860        ldc: i64,
861        d: *mut c_void,
862        ldd: i64,
863    ) -> i32 {
864        use baracuda_cutlass_kernels_sys as k_sys;
865        match layout {
866            LayoutSku::Rcr => unsafe {
867                k_sys::baracuda_cutlass_gemm_f64_rcr_sm80_can_implement(
868                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
869                )
870            },
871            LayoutSku::Rrr => unsafe {
872                k_sys::baracuda_cutlass_gemm_f64_rrr_sm80_can_implement(
873                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
874                )
875            },
876        }
877    }
878
879    // ---------- f64 bias-fused GEMM dispatch, sm_80 ----------
880
881    #[cfg(feature = "sm80")]
882    #[allow(clippy::too_many_arguments)]
883    pub(super) unsafe fn gemm_bias_sm80_run_f64(
884        layout: LayoutSku,
885        epilogue: EpilogueKind,
886        m: i32,
887        n: i32,
888        k: i32,
889        a: *const c_void,
890        lda: i64,
891        b: *const c_void,
892        ldb: i64,
893        c: *const c_void,
894        ldc: i64,
895        d: *mut c_void,
896        ldd: i64,
897        bias: *const c_void,
898        alpha: f64,
899        beta: f64,
900        workspace: *mut c_void,
901        workspace_bytes: usize,
902        stream: *mut c_void,
903    ) -> i32 {
904        use baracuda_cutlass_kernels_sys as k_sys;
905        match (layout, epilogue) {
906            (LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
907                k_sys::baracuda_cutlass_gemm_bias_f64_rcr_sm80_run(
908                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
909                    bias, alpha, beta, workspace, workspace_bytes, stream,
910                )
911            },
912            (LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
913                k_sys::baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_run(
914                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
915                    bias, alpha, beta, workspace, workspace_bytes, stream,
916                )
917            },
918            (LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
919                k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_run(
920                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
921                    bias, alpha, beta, workspace, workspace_bytes, stream,
922                )
923            },
924            (LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
925                k_sys::baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_run(
926                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
927                    bias, alpha, beta, workspace, workspace_bytes, stream,
928                )
929            },
930            (LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
931                k_sys::baracuda_cutlass_gemm_bias_f64_rrr_sm80_run(
932                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
933                    bias, alpha, beta, workspace, workspace_bytes, stream,
934                )
935            },
936            (LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
937                k_sys::baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_run(
938                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
939                    bias, alpha, beta, workspace, workspace_bytes, stream,
940                )
941            },
942            (LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
943                k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_run(
944                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
945                    bias, alpha, beta, workspace, workspace_bytes, stream,
946                )
947            },
948            (LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
949                k_sys::baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_run(
950                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
951                    bias, alpha, beta, workspace, workspace_bytes, stream,
952                )
953            },
954            // Identity not a bias-family epilogue; never reached when
955            // `epilogue.requires_bias()` gates the call site.
956            (_, EpilogueKind::Identity) => 3,
957        }
958    }
959
960    #[cfg(feature = "sm80")]
961    pub(super) fn gemm_bias_sm80_workspace_size_f64(
962        layout: LayoutSku,
963        epilogue: EpilogueKind,
964        m: i32,
965        n: i32,
966        k: i32,
967    ) -> usize {
968        use baracuda_cutlass_kernels_sys as k_sys;
969        match (layout, epilogue) {
970            (LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
971                k_sys::baracuda_cutlass_gemm_bias_f64_rcr_sm80_workspace_size(m, n, k)
972            },
973            (LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
974                k_sys::baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_workspace_size(m, n, k)
975            },
976            (LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
977                k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_workspace_size(m, n, k)
978            },
979            (LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
980                k_sys::baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_workspace_size(m, n, k)
981            },
982            (LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
983                k_sys::baracuda_cutlass_gemm_bias_f64_rrr_sm80_workspace_size(m, n, k)
984            },
985            (LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
986                k_sys::baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_workspace_size(m, n, k)
987            },
988            (LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
989                k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_workspace_size(m, n, k)
990            },
991            (LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
992                k_sys::baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_workspace_size(m, n, k)
993            },
994            (_, EpilogueKind::Identity) => 0,
995        }
996    }
997
998    #[cfg(feature = "sm80")]
999    #[allow(clippy::too_many_arguments)]
1000    pub(super) unsafe fn gemm_bias_sm80_can_implement_f64(
1001        layout: LayoutSku,
1002        epilogue: EpilogueKind,
1003        m: i32,
1004        n: i32,
1005        k: i32,
1006        a: *const c_void,
1007        lda: i64,
1008        b: *const c_void,
1009        ldb: i64,
1010        c: *const c_void,
1011        ldc: i64,
1012        d: *mut c_void,
1013        ldd: i64,
1014        bias: *const c_void,
1015    ) -> i32 {
1016        use baracuda_cutlass_kernels_sys as k_sys;
1017        match (layout, epilogue) {
1018            (LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
1019                k_sys::baracuda_cutlass_gemm_bias_f64_rcr_sm80_can_implement(
1020                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1021                )
1022            },
1023            (LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
1024                k_sys::baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_can_implement(
1025                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1026                )
1027            },
1028            (LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
1029                k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_can_implement(
1030                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1031                )
1032            },
1033            (LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
1034                k_sys::baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_can_implement(
1035                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1036                )
1037            },
1038            (LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
1039                k_sys::baracuda_cutlass_gemm_bias_f64_rrr_sm80_can_implement(
1040                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1041                )
1042            },
1043            (LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
1044                k_sys::baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_can_implement(
1045                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1046                )
1047            },
1048            (LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
1049                k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_can_implement(
1050                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1051                )
1052            },
1053            (LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
1054                k_sys::baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_can_implement(
1055                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1056                )
1057            },
1058            (_, EpilogueKind::Identity) => 3,
1059        }
1060    }
1061
1062    // ---------- batched GEMM, sm_80 ----------
1063    //
1064    // SKU coverage (status 3 = not implemented):
1065    //   - Rcr × {F16, Bf16}            ✓
1066    //   - Rcr × F32, all Rrr           ✗
1067
1068    #[cfg(feature = "sm80")]
1069    #[allow(clippy::too_many_arguments)]
1070    pub(super) unsafe fn batched_gemm_sm80_run(
1071        layout: LayoutSku,
1072        kind: ElementKind,
1073        m: i32,
1074        n: i32,
1075        k: i32,
1076        a: *const c_void,
1077        lda: i64,
1078        stride_a: i64,
1079        b: *const c_void,
1080        ldb: i64,
1081        stride_b: i64,
1082        c: *const c_void,
1083        ldc: i64,
1084        stride_c: i64,
1085        d: *mut c_void,
1086        ldd: i64,
1087        stride_d: i64,
1088        alpha: f32,
1089        beta: f32,
1090        batch_count: i32,
1091        workspace: *mut c_void,
1092        workspace_bytes: usize,
1093        stream: *mut c_void,
1094    ) -> i32 {
1095        use baracuda_cutlass_kernels_sys as k_sys;
1096        match (layout, kind) {
1097            (LayoutSku::Rcr, ElementKind::F16) => unsafe {
1098                k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_run(
1099                    m, n, k,
1100                    a, lda, stride_a,
1101                    b, ldb, stride_b,
1102                    c, ldc, stride_c,
1103                    d, ldd, stride_d,
1104                    alpha, beta,
1105                    batch_count,
1106                    workspace, workspace_bytes,
1107                    stream,
1108                )
1109            },
1110            (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
1111                k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_run(
1112                    m, n, k,
1113                    a, lda, stride_a,
1114                    b, ldb, stride_b,
1115                    c, ldc, stride_c,
1116                    d, ldd, stride_d,
1117                    alpha, beta,
1118                    batch_count,
1119                    workspace, workspace_bytes,
1120                    stream,
1121                )
1122            },
1123            _ => 3,
1124        }
1125    }
1126
1127    #[cfg(feature = "sm80")]
1128    pub(super) fn batched_gemm_sm80_workspace_size(
1129        layout: LayoutSku,
1130        kind: ElementKind,
1131        m: i32,
1132        n: i32,
1133        k: i32,
1134        batch_count: i32,
1135    ) -> usize {
1136        use baracuda_cutlass_kernels_sys as k_sys;
1137        match (layout, kind) {
1138            (LayoutSku::Rcr, ElementKind::F16) => unsafe {
1139                k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_workspace_size(
1140                    m, n, k, batch_count,
1141                )
1142            },
1143            (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
1144                k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_workspace_size(
1145                    m, n, k, batch_count,
1146                )
1147            },
1148            _ => 0,
1149        }
1150    }
1151
1152    #[cfg(feature = "sm80")]
1153    #[allow(clippy::too_many_arguments)]
1154    pub(super) unsafe fn batched_gemm_sm80_can_implement(
1155        layout: LayoutSku,
1156        kind: ElementKind,
1157        m: i32,
1158        n: i32,
1159        k: i32,
1160        a: *const c_void,
1161        lda: i64,
1162        stride_a: i64,
1163        b: *const c_void,
1164        ldb: i64,
1165        stride_b: i64,
1166        c: *const c_void,
1167        ldc: i64,
1168        stride_c: i64,
1169        d: *mut c_void,
1170        ldd: i64,
1171        stride_d: i64,
1172        batch_count: i32,
1173    ) -> i32 {
1174        use baracuda_cutlass_kernels_sys as k_sys;
1175        match (layout, kind) {
1176            (LayoutSku::Rcr, ElementKind::F16) => unsafe {
1177                k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_can_implement(
1178                    m, n, k,
1179                    a, lda, stride_a,
1180                    b, ldb, stride_b,
1181                    c, ldc, stride_c,
1182                    d, ldd, stride_d,
1183                    batch_count,
1184                )
1185            },
1186            (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
1187                k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_can_implement(
1188                    m, n, k,
1189                    a, lda, stride_a,
1190                    b, ldb, stride_b,
1191                    c, ldc, stride_c,
1192                    d, ldd, stride_d,
1193                    batch_count,
1194                )
1195            },
1196            _ => 3,
1197        }
1198    }
1199
1200    // ---------- grouped GEMM, RCR sm_80 ----------
1201
1202    #[cfg(feature = "sm80")]
1203    pub(super) unsafe fn grouped_gemm_rcr_sm80_sufficient(
1204        kind: ElementKind,
1205        h_m: *const i32,
1206        h_n: *const i32,
1207        h_k: *const i32,
1208        group_count: i32,
1209    ) -> i32 {
1210        use baracuda_cutlass_kernels_sys as k_sys;
1211        match kind {
1212            ElementKind::F16 => unsafe {
1213                k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient(h_m, h_n, h_k, group_count)
1214            },
1215            ElementKind::Bf16 => unsafe {
1216                k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_sufficient(h_m, h_n, h_k, group_count)
1217            },
1218            ElementKind::F32
1219            | ElementKind::F32Strict
1220            | ElementKind::F64
1221            | ElementKind::S8
1222            | ElementKind::U8
1223            | ElementKind::I32
1224            | ElementKind::I64
1225            | ElementKind::Bool
1226            | ElementKind::Fp8E4M3
1227            | ElementKind::Fp8E5M2
1228            | ElementKind::S4
1229            | ElementKind::U4
1230            | ElementKind::Bin
1231            | ElementKind::Complex32
1232            | ElementKind::Complex64 => 0,
1233        }
1234    }
1235
1236    #[cfg(feature = "sm80")]
1237    pub(super) unsafe fn grouped_gemm_rcr_sm80_scratch_bytes(
1238        kind: ElementKind,
1239        h_m: *const i32,
1240        h_n: *const i32,
1241        h_k: *const i32,
1242        group_count: i32,
1243        threadblock_count: i32,
1244    ) -> usize {
1245        use baracuda_cutlass_kernels_sys as k_sys;
1246        match kind {
1247            ElementKind::F16 => unsafe {
1248                k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_scratch_bytes(
1249                    h_m, h_n, h_k, group_count, threadblock_count,
1250                )
1251            },
1252            ElementKind::Bf16 => unsafe {
1253                k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_scratch_bytes(
1254                    h_m, h_n, h_k, group_count, threadblock_count,
1255                )
1256            },
1257            ElementKind::F32
1258            | ElementKind::F32Strict
1259            | ElementKind::F64
1260            | ElementKind::S8
1261            | ElementKind::U8
1262            | ElementKind::I32
1263            | ElementKind::I64
1264            | ElementKind::Bool
1265            | ElementKind::Fp8E4M3
1266            | ElementKind::Fp8E5M2
1267            | ElementKind::S4
1268            | ElementKind::U4
1269            | ElementKind::Bin
1270            | ElementKind::Complex32
1271            | ElementKind::Complex64 => 0,
1272        }
1273    }
1274
1275    #[cfg(feature = "sm80")]
1276    pub(super) unsafe fn grouped_gemm_rcr_sm80_can_implement(
1277        kind: ElementKind,
1278        h_m: *const i32,
1279        h_n: *const i32,
1280        h_k: *const i32,
1281        group_count: i32,
1282    ) -> i32 {
1283        use baracuda_cutlass_kernels_sys as k_sys;
1284        match kind {
1285            ElementKind::F16 => unsafe {
1286                k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_can_implement(h_m, h_n, h_k, group_count)
1287            },
1288            ElementKind::Bf16 => unsafe {
1289                k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_can_implement(h_m, h_n, h_k, group_count)
1290            },
1291            ElementKind::F32
1292            | ElementKind::F32Strict
1293            | ElementKind::F64
1294            | ElementKind::S8
1295            | ElementKind::U8
1296            | ElementKind::I32
1297            | ElementKind::I64
1298            | ElementKind::Bool
1299            | ElementKind::Fp8E4M3
1300            | ElementKind::Fp8E5M2
1301            | ElementKind::S4
1302            | ElementKind::U4
1303            | ElementKind::Bin
1304            | ElementKind::Complex32
1305            | ElementKind::Complex64 => 3,
1306        }
1307    }
1308
1309    #[cfg(feature = "sm80")]
1310    #[allow(clippy::too_many_arguments)]
1311    pub(super) unsafe fn grouped_gemm_rcr_sm80_run(
1312        kind: ElementKind,
1313        group_count: i32,
1314        threadblock_count: i32,
1315        d_problem_sizes: *const c_void,
1316        d_ptr_a: *const c_void,
1317        d_ptr_b: *const c_void,
1318        d_ptr_c: *const c_void,
1319        d_ptr_d: *mut c_void,
1320        d_lda: *const c_void,
1321        d_ldb: *const c_void,
1322        d_ldc: *const c_void,
1323        d_ldd: *const c_void,
1324        h_problem_sizes: *const c_void,
1325        alpha: f32,
1326        beta: f32,
1327        scratch: *mut c_void,
1328        scratch_bytes: usize,
1329        stream: *mut c_void,
1330    ) -> i32 {
1331        use baracuda_cutlass_kernels_sys as k_sys;
1332        match kind {
1333            ElementKind::F16 => unsafe {
1334                k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_run(
1335                    group_count, threadblock_count,
1336                    d_problem_sizes,
1337                    d_ptr_a, d_ptr_b, d_ptr_c, d_ptr_d,
1338                    d_lda, d_ldb, d_ldc, d_ldd,
1339                    h_problem_sizes,
1340                    alpha, beta,
1341                    scratch, scratch_bytes,
1342                    stream,
1343                )
1344            },
1345            ElementKind::Bf16 => unsafe {
1346                k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_run(
1347                    group_count, threadblock_count,
1348                    d_problem_sizes,
1349                    d_ptr_a, d_ptr_b, d_ptr_c, d_ptr_d,
1350                    d_lda, d_ldb, d_ldc, d_ldd,
1351                    h_problem_sizes,
1352                    alpha, beta,
1353                    scratch, scratch_bytes,
1354                    stream,
1355                )
1356            },
1357            ElementKind::F32
1358            | ElementKind::F32Strict
1359            | ElementKind::F64
1360            | ElementKind::S8
1361            | ElementKind::U8
1362            | ElementKind::I32
1363            | ElementKind::I64
1364            | ElementKind::Bool
1365            | ElementKind::Fp8E4M3
1366            | ElementKind::Fp8E5M2
1367            | ElementKind::S4
1368            | ElementKind::U4
1369            | ElementKind::Bin
1370            | ElementKind::Complex32
1371            | ElementKind::Complex64 => 3,
1372        }
1373    }
1374
1375    // ============================================================================
1376    // int8 GEMM dispatch — Identity, RCR layout, sm_80
1377    // ============================================================================
1378    //
1379    // Identity int kernels: `LinearCombinationClamp` epilogue. Layout is
1380    // restricted to `Rcr` (`Rrr` is deferred — see lib.rs section header).
1381    // Element kind is restricted to S8 / U8. Alpha/beta are `f32` (CUTLASS
1382    // int8 dequant-in-epilogue convention).
1383
1384    #[cfg(feature = "sm80")]
1385    #[allow(clippy::too_many_arguments)]
1386    pub(super) unsafe fn int_gemm_rcr_sm80_run(
1387        layout: LayoutSku,
1388        kind: ElementKind,
1389        m: i32,
1390        n: i32,
1391        k: i32,
1392        a: *const c_void,
1393        lda: i64,
1394        b: *const c_void,
1395        ldb: i64,
1396        c: *const c_void,
1397        ldc: i64,
1398        d: *mut c_void,
1399        ldd: i64,
1400        alpha: f32,
1401        beta: f32,
1402        workspace: *mut c_void,
1403        workspace_bytes: usize,
1404        stream: *mut c_void,
1405    ) -> i32 {
1406        use baracuda_cutlass_kernels_sys as k_sys;
1407        match (layout, kind) {
1408            (LayoutSku::Rcr, ElementKind::S8) => unsafe {
1409                k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_run(
1410                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1411                    alpha, beta, workspace, workspace_bytes, stream,
1412                )
1413            },
1414            (LayoutSku::Rcr, ElementKind::U8) => unsafe {
1415                k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_run(
1416                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1417                    alpha, beta, workspace, workspace_bytes, stream,
1418                )
1419            },
1420            // RRR int8 is deferred — CUTLASS 4.2.0 lacks the 8-bit
1421            // `TensorOpMultiplicandCongruous` warp iterator needed for
1422            // `RowMajor × RowMajor × OpClassTensorOp` instantiation.
1423            (LayoutSku::Rrr, ElementKind::S8) | (LayoutSku::Rrr, ElementKind::U8) => 3,
1424            // Defensive: float and I32 kinds should never reach this dispatcher.
1425            _ => 3,
1426        }
1427    }
1428
1429    #[cfg(feature = "sm80")]
1430    pub(super) fn int_gemm_rcr_sm80_workspace_size(
1431        layout: LayoutSku,
1432        kind: ElementKind,
1433        m: i32,
1434        n: i32,
1435        k: i32,
1436    ) -> usize {
1437        use baracuda_cutlass_kernels_sys as k_sys;
1438        match (layout, kind) {
1439            (LayoutSku::Rcr, ElementKind::S8) => unsafe {
1440                k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_workspace_size(m, n, k)
1441            },
1442            (LayoutSku::Rcr, ElementKind::U8) => unsafe {
1443                k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_workspace_size(m, n, k)
1444            },
1445            _ => 0,
1446        }
1447    }
1448
1449    #[cfg(feature = "sm80")]
1450    #[allow(clippy::too_many_arguments)]
1451    pub(super) unsafe fn int_gemm_rcr_sm80_can_implement(
1452        layout: LayoutSku,
1453        kind: ElementKind,
1454        m: i32,
1455        n: i32,
1456        k: i32,
1457        a: *const c_void,
1458        lda: i64,
1459        b: *const c_void,
1460        ldb: i64,
1461        c: *const c_void,
1462        ldc: i64,
1463        d: *mut c_void,
1464        ldd: i64,
1465    ) -> i32 {
1466        use baracuda_cutlass_kernels_sys as k_sys;
1467        match (layout, kind) {
1468            (LayoutSku::Rcr, ElementKind::S8) => unsafe {
1469                k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_can_implement(
1470                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1471                )
1472            },
1473            (LayoutSku::Rcr, ElementKind::U8) => unsafe {
1474                k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_can_implement(
1475                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1476                )
1477            },
1478            (LayoutSku::Rrr, ElementKind::S8) | (LayoutSku::Rrr, ElementKind::U8) => 3,
1479            _ => 3,
1480        }
1481    }
1482
1483    // ============================================================================
1484    // int8 bias-fused GEMM dispatch — RCR layout, sm_80
1485    // ============================================================================
1486    //
1487    // Bias kernels: `LinearCombinationBiasElementwise<int8, int32, float,
1488    // int8, int8, EPA=16, ActOp, plus<float>, false, ElementBias>`. The
1489    // four `Bias*` epilogues differ only in `ActOp`; the bias element
1490    // is independent of the matrix element and dispatched on
1491    // `bias_kind`. The 16 reachable arms = 2 sgn × 4 epi × 2 bias-type.
1492
1493    use crate::types::BiasElementKind;
1494
1495    #[cfg(feature = "sm80")]
1496    #[allow(clippy::too_many_arguments)]
1497    pub(super) unsafe fn int_gemm_bias_rcr_sm80_run(
1498        layout: LayoutSku,
1499        kind: ElementKind,
1500        epilogue: EpilogueKind,
1501        bias_kind: BiasElementKind,
1502        m: i32,
1503        n: i32,
1504        k: i32,
1505        a: *const c_void,
1506        lda: i64,
1507        b: *const c_void,
1508        ldb: i64,
1509        c: *const c_void,
1510        ldc: i64,
1511        d: *mut c_void,
1512        ldd: i64,
1513        bias: *const c_void,
1514        alpha: f32,
1515        beta: f32,
1516        workspace: *mut c_void,
1517        workspace_bytes: usize,
1518        stream: *mut c_void,
1519    ) -> i32 {
1520        use baracuda_cutlass_kernels_sys as k_sys;
1521        // RRR int8 deferred (see Identity dispatcher).
1522        if !matches!(layout, LayoutSku::Rcr) {
1523            return 3;
1524        }
1525        match (kind, epilogue, bias_kind) {
1526            // ---- s8 × f32 bias ----
1527            (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1528                k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_run(
1529                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1530                    bias, alpha, beta, workspace, workspace_bytes, stream,
1531                )
1532            },
1533            (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1534                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_run(
1535                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1536                    bias, alpha, beta, workspace, workspace_bytes, stream,
1537                )
1538            },
1539            (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1540                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_run(
1541                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1542                    bias, alpha, beta, workspace, workspace_bytes, stream,
1543                )
1544            },
1545            (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1546                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_run(
1547                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1548                    bias, alpha, beta, workspace, workspace_bytes, stream,
1549                )
1550            },
1551            // ---- s8 × i32 bias ----
1552            (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1553                k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_run(
1554                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1555                    bias, alpha, beta, workspace, workspace_bytes, stream,
1556                )
1557            },
1558            (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1559                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_run(
1560                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1561                    bias, alpha, beta, workspace, workspace_bytes, stream,
1562                )
1563            },
1564            (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1565                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_run(
1566                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1567                    bias, alpha, beta, workspace, workspace_bytes, stream,
1568                )
1569            },
1570            (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1571                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_run(
1572                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1573                    bias, alpha, beta, workspace, workspace_bytes, stream,
1574                )
1575            },
1576            // ---- u8 × f32 bias ----
1577            (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1578                k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_run(
1579                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1580                    bias, alpha, beta, workspace, workspace_bytes, stream,
1581                )
1582            },
1583            (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1584                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_run(
1585                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1586                    bias, alpha, beta, workspace, workspace_bytes, stream,
1587                )
1588            },
1589            (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1590                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_run(
1591                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1592                    bias, alpha, beta, workspace, workspace_bytes, stream,
1593                )
1594            },
1595            (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1596                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_run(
1597                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1598                    bias, alpha, beta, workspace, workspace_bytes, stream,
1599                )
1600            },
1601            // ---- u8 × i32 bias ----
1602            (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1603                k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_run(
1604                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1605                    bias, alpha, beta, workspace, workspace_bytes, stream,
1606                )
1607            },
1608            (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1609                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_run(
1610                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1611                    bias, alpha, beta, workspace, workspace_bytes, stream,
1612                )
1613            },
1614            (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1615                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_run(
1616                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1617                    bias, alpha, beta, workspace, workspace_bytes, stream,
1618                )
1619            },
1620            (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1621                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_run(
1622                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1623                    bias, alpha, beta, workspace, workspace_bytes, stream,
1624                )
1625            },
1626            // Identity reaches here when called accidentally — bias is
1627            // required for the bias-family dispatchers. Identity routes
1628            // through `int_gemm_rcr_sm80_run` instead.
1629            (_, EpilogueKind::Identity, _) => 3,
1630            // Defensive: float / I32 kinds should never reach this dispatcher.
1631            _ => 3,
1632        }
1633    }
1634
1635    #[cfg(feature = "sm80")]
1636    pub(super) fn int_gemm_bias_rcr_sm80_workspace_size(
1637        layout: LayoutSku,
1638        kind: ElementKind,
1639        epilogue: EpilogueKind,
1640        bias_kind: BiasElementKind,
1641        m: i32,
1642        n: i32,
1643        k: i32,
1644    ) -> usize {
1645        use baracuda_cutlass_kernels_sys as k_sys;
1646        if !matches!(layout, LayoutSku::Rcr) {
1647            return 0;
1648        }
1649        match (kind, epilogue, bias_kind) {
1650            (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1651                k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1652            },
1653            (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1654                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1655            },
1656            (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1657                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1658            },
1659            (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1660                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1661            },
1662            (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1663                k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1664            },
1665            (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1666                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1667            },
1668            (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1669                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1670            },
1671            (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1672                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1673            },
1674            (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1675                k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1676            },
1677            (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1678                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1679            },
1680            (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1681                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1682            },
1683            (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1684                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1685            },
1686            (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1687                k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1688            },
1689            (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1690                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1691            },
1692            (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1693                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1694            },
1695            (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1696                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1697            },
1698            _ => 0,
1699        }
1700    }
1701
1702    #[cfg(feature = "sm80")]
1703    #[allow(clippy::too_many_arguments)]
1704    pub(super) unsafe fn int_gemm_bias_rcr_sm80_can_implement(
1705        layout: LayoutSku,
1706        kind: ElementKind,
1707        epilogue: EpilogueKind,
1708        bias_kind: BiasElementKind,
1709        m: i32,
1710        n: i32,
1711        k: i32,
1712        a: *const c_void,
1713        lda: i64,
1714        b: *const c_void,
1715        ldb: i64,
1716        c: *const c_void,
1717        ldc: i64,
1718        d: *mut c_void,
1719        ldd: i64,
1720        bias: *const c_void,
1721    ) -> i32 {
1722        use baracuda_cutlass_kernels_sys as k_sys;
1723        if !matches!(layout, LayoutSku::Rcr) {
1724            return 3;
1725        }
1726        match (kind, epilogue, bias_kind) {
1727            (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1728                k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_can_implement(
1729                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1730                )
1731            },
1732            (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1733                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_can_implement(
1734                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1735                )
1736            },
1737            (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1738                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_can_implement(
1739                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1740                )
1741            },
1742            (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1743                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_can_implement(
1744                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1745                )
1746            },
1747            (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1748                k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_can_implement(
1749                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1750                )
1751            },
1752            (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1753                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_can_implement(
1754                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1755                )
1756            },
1757            (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1758                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_can_implement(
1759                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1760                )
1761            },
1762            (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1763                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_can_implement(
1764                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1765                )
1766            },
1767            (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1768                k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_can_implement(
1769                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1770                )
1771            },
1772            (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1773                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_can_implement(
1774                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1775                )
1776            },
1777            (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1778                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_can_implement(
1779                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1780                )
1781            },
1782            (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1783                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_can_implement(
1784                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1785                )
1786            },
1787            (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1788                k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_can_implement(
1789                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1790                )
1791            },
1792            (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1793                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_can_implement(
1794                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1795                )
1796            },
1797            (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1798                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_can_implement(
1799                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1800                )
1801            },
1802            (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1803                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_can_implement(
1804                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1805                )
1806            },
1807            _ => 3,
1808        }
1809    }
1810}
1811
1812// ============================================================================
1813// Host-side validation helpers
1814// ============================================================================
1815
/// Minimum element count required to back a `(rows, cols, ld)` matrix
/// stored row-major. Returns `None` when the count overflows `i64` or
/// does not fit in `usize` (which covers any negative result).
///
/// Intended for `rows >= 1` and `cols >= 1` (callers go through
/// [`check_descriptor`] which rejects non-positive dimensions first).
/// All arithmetic is done in `i64`, so even an out-of-contract `rows`
/// such as `i32::MIN` cannot overflow the `rows - 1` step and panic.
fn min_elements_row_major(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // [rows, cols] row-major: element [i, j] at offset i*ld + j.
    // Maximum addressable index = (rows - 1) * ld + (cols - 1), so the
    // buffer must hold (rows - 1) * ld + cols elements. Accepts padded
    // leading dimensions (ld > cols) without rejecting valid slabs.
    //
    // Widen before subtracting: `rows - 1` in `i32` would overflow for
    // `i32::MIN` (panic with overflow checks on, wrap-around otherwise).
    let full_rows = i64::from(rows) - 1;
    let needed = full_rows.checked_mul(ld)?.checked_add(i64::from(cols))?;
    usize::try_from(needed).ok()
}
1830
/// Minimum element count required to back a `(rows, cols, ld)` matrix
/// stored column-major. Returns `None` when the count overflows `i64` or
/// does not fit in `usize` (which covers any negative result).
///
/// Intended for `rows >= 1` and `cols >= 1`. All arithmetic is done in
/// `i64`, so even an out-of-contract `cols` such as `i32::MIN` cannot
/// overflow the `cols - 1` step and panic.
fn min_elements_col_major(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // [rows, cols] column-major: element [i, j] at offset j*ld + i.
    // Maximum addressable index = (cols - 1) * ld + (rows - 1), so the
    // buffer must hold (cols - 1) * ld + rows elements.
    //
    // Widen before subtracting: `cols - 1` in `i32` would overflow for
    // `i32::MIN` (panic with overflow checks on, wrap-around otherwise).
    let full_cols = i64::from(cols) - 1;
    let needed = full_cols.checked_mul(ld)?.checked_add(i64::from(rows))?;
    usize::try_from(needed).ok()
}
1838
1839// Compatibility shims for the buffer-size tests below — they document
1840// the per-operand layout (A row-major, B column-major for Rcr, C/D
1841// row-major) so we keep separate names for clarity.
/// Test-only alias for operand A under Rcr: A is row-major, so this
/// forwards unchanged to `min_elements_row_major`.
#[cfg(test)]
fn min_elements_rcr_a(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    min_elements_row_major(rows, cols, ld)
}
/// Test-only alias for operand B under Rcr: B is column-major, so this
/// forwards unchanged to `min_elements_col_major`.
#[cfg(test)]
fn min_elements_rcr_b(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    min_elements_col_major(rows, cols, ld)
}
/// Test-only alias for operands C and D under Rcr: both are row-major,
/// so this forwards unchanged to `min_elements_row_major`.
#[cfg(test)]
fn min_elements_rcr_cd(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    min_elements_row_major(rows, cols, ld)
}
1854
1855fn check_descriptor(desc: &GemmDescriptor) -> Result<()> {
1856    if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
1857        return Err(Error::InvalidProblem("M, N, K must all be positive"));
1858    }
1859    // All shipped layouts (Rcr, Rrr) and epilogues (Identity) are
1860    // accepted by select; per-(layout, kind) implementability is
1861    // dispatched at run time to the kernel's own can_implement.
1862    Ok(())
1863}
1864
1865fn check_args<T: CutlassElement>(desc: &GemmDescriptor, args: &GemmArgs<'_, T>) -> Result<()> {
1866    // Epilogue / bias must agree: any Bias* variant needs bias = Some,
1867    // Identity needs bias = None.
1868    match (desc.epilogue.requires_bias(), &args.bias) {
1869        (false, Some(_)) => {
1870            return Err(Error::InvalidProblem(
1871                "args.bias must be None when descriptor.epilogue is Identity",
1872            ));
1873        }
1874        (true, None) => {
1875            return Err(Error::InvalidProblem(
1876                "args.bias is required when descriptor.epilogue is in the Bias family \
1877                 (Bias / BiasRelu / BiasGelu / BiasSilu)",
1878            ));
1879        }
1880        (false, None) | (true, Some(_)) => {}
1881    }
1882    if let Some(bias) = &args.bias {
1883        if bias.len != desc.n {
1884            return Err(Error::InvalidProblem(
1885                "bias vector length must equal N",
1886            ));
1887        }
1888        if bias.stride != 1 {
1889            return Err(Error::Unsupported(
1890                "bias vector must be contiguous (stride 1) — strided bias not supported",
1891            ));
1892        }
1893        if bias.data.len() < desc.n as usize {
1894            return Err(Error::BufferTooSmall {
1895                needed: desc.n as usize,
1896                got: bias.data.len(),
1897            });
1898        }
1899    }
1900    if args.a.rows != desc.m || args.a.cols != desc.k {
1901        return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
1902    }
1903    if args.b.rows != desc.k || args.b.cols != desc.n {
1904        return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
1905    }
1906    if args.d.rows != desc.m || args.d.cols != desc.n {
1907        return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
1908    }
1909    if let Some(c) = &args.c {
1910        if c.rows != desc.m || c.cols != desc.n {
1911            return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
1912        }
1913    }
1914    // A is row-major in both Rcr and Rrr — leading dim along the K axis.
1915    if args.a.ld < desc.k as i64 {
1916        return Err(Error::InvalidProblem("A leading dimension must be >= K"));
1917    }
1918    // B's leading-dim minor depends on layout:
1919    //   Rcr: column-major [K, N], ld is along the K axis (rows).
1920    //   Rrr: row-major    [K, N], ld is along the N axis (cols).
1921    let b_min_ld = match desc.layout {
1922        LayoutSku::Rcr => desc.k as i64,
1923        LayoutSku::Rrr => desc.n as i64,
1924    };
1925    if args.b.ld < b_min_ld {
1926        return Err(Error::InvalidProblem(match desc.layout {
1927            LayoutSku::Rcr => "B leading dimension must be >= K (column-major Rcr layout)",
1928            LayoutSku::Rrr => "B leading dimension must be >= N (row-major Rrr layout)",
1929        }));
1930    }
1931    if args.d.ld < desc.n as i64 {
1932        return Err(Error::InvalidProblem("D leading dimension must be >= N"));
1933    }
1934    if let Some(c) = &args.c {
1935        if c.ld < desc.n as i64 {
1936            return Err(Error::InvalidProblem("C leading dimension must be >= N"));
1937        }
1938    }
1939
1940    let need_a = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
1941        .ok_or(Error::InvalidProblem("A storage size overflow"))?;
1942    if args.a.data.len() < need_a {
1943        return Err(Error::BufferTooSmall {
1944            needed: need_a,
1945            got: args.a.data.len(),
1946        });
1947    }
1948    let need_b = match desc.layout {
1949        LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
1950        LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
1951    }
1952    .ok_or(Error::InvalidProblem("B storage size overflow"))?;
1953    if args.b.data.len() < need_b {
1954        return Err(Error::BufferTooSmall {
1955            needed: need_b,
1956            got: args.b.data.len(),
1957        });
1958    }
1959    let need_d = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
1960        .ok_or(Error::InvalidProblem("D storage size overflow"))?;
1961    if args.d.data.len() < need_d {
1962        return Err(Error::BufferTooSmall {
1963            needed: need_d,
1964            got: args.d.data.len(),
1965        });
1966    }
1967    if let Some(c) = &args.c {
1968        let need_c = min_elements_row_major(c.rows, c.cols, c.ld)
1969            .ok_or(Error::InvalidProblem("C storage size overflow"))?;
1970        if c.data.len() < need_c {
1971            return Err(Error::BufferTooSmall {
1972                needed: need_c,
1973                got: c.data.len(),
1974            });
1975        }
1976    }
1977    Ok(())
1978}
1979
1980// ============================================================================
1981// cuBLAS backend — Phase 30 fast-path for f16/bf16 decode-regime FP GEMM
1982// ============================================================================
1983//
1984// Background (from `crates/baracuda-kernels-bench/BENCHMARKS.md`, Phase 29):
1985// baracuda's CUTLASS RCR plan emits `_sm80_` kernels (Ampere-tuned) for
1986// f16/bf16. On Ada (sm_89, RTX 4070) cuBLAS auto-dispatches to the
1987// sm_89 tensor-core path with an f32 accumulator and wins 2–4× at
1988// low-M (M=1 / M=32) decode shapes, narrowing to ~parity at M=128.
1989// This module hosts the cuBLAS Gemm fallback path that the
1990// [`GemmPlan::select`] heuristic dispatches to for those shapes.
1991//
1992// Coverage:
1993//   - f16, bf16, f32: `cublasGemmEx` with `CUBLAS_COMPUTE_32F`
1994//     (CUBLAS_GEMM_DEFAULT_TENSOR_OP algo, value 99).
1995//   - f64: `cublasDgemm` (no GemmEx needed — DGEMM is plenty fast).
1996//   - F32Strict: not supported (cuBLAS doesn't expose a strict-IEEE
1997//     SIMT path the way CUTLASS does; F32Strict callers stay on the
1998//     CUTLASS SIMT kernels for bit-stability).
1999//   - Bias / Bias* epilogues: not supported (cuBLAS-classic has no
2000//     fused-bias-activation; cuBLASLt does but is a separate API
2001//     surface). Forcing Cublas backend on a Bias* epilogue returns
2002//     `Error::Unsupported` from `select`.
2003//
2004// Layout mapping (baracuda Row/Col/Row → cuBLAS column-major):
2005// baracuda's RCR means A row-major [M, K], B column-major [K, N], D
2006// row-major [M, N]. cuBLAS is column-major, so we use the standard
2007// "C^T = B^T · A^T" trick: pass B as the first operand, A as the
2008// second, with `transa = Op::T` on both, and the resulting D in
2009// col-major IS our row-major D. For RRR the mapping changes the
2010// second `transa` to `Op::N`.
2011
2012mod cublas_backend {
2013    use core::cell::RefCell;
2014
2015    use baracuda_cublas::Handle as CublasHandle;
2016    use baracuda_driver::Stream;
2017
2018    // Per-thread cache of cuBLAS handles, keyed by the raw context
2019    // pointer the handle was created against.
2020    //
2021    // cuBLAS handles are tied to the CUDA context current at the time
2022    // of `cublasCreate`. Caching by (thread, ctx-raw-ptr) means the
2023    // first call into the cuBLAS backend on a given thread pays the
2024    // (~100µs–ms) `cublasCreate` cost; subsequent calls reuse the
2025    // handle. Per-thread storage sidesteps the cuBLAS-handle `!Sync`
2026    // constraint (NVIDIA documents that a handle should only be
2027    // touched from one host thread at a time) and keeps
2028    // `super::GemmPlan` `Send + Sync` (the plan stores no handle —
2029    // it looks one up out of the thread-local on each `run`).
2030    //
2031    // A `Vec` keyed by raw context pointer is fine here: production
2032    // callers have one context per process, and even
2033    // multi-context callers rarely exceed a handful.
2034    thread_local! {
2035        static HANDLE_CACHE: RefCell<Vec<(usize, CublasHandle)>> =
2036            const { RefCell::new(Vec::new()) };
2037    }
2038
2039    /// Fetch (or lazily create) the cuBLAS handle bound to `stream`'s
2040    /// context, with the handle's stream binding set to `stream`.
2041    ///
2042    /// The returned handle is `Clone`able (the inner state is `Arc`'d
2043    /// inside [`baracuda_cublas::Handle`]) so the caller can move it
2044    /// into the closure that calls `gemm_ex` / `gemm` without holding
2045    /// the thread-local `RefCell` borrow across the launch.
2046    pub(super) fn handle_for(stream: &Stream) -> crate::Result<CublasHandle> {
2047        // Use the context's raw pointer as the cache key. baracuda's
2048        // `Context` doesn't expose a stable id, but the inner CUDA
2049        // context pointer is unique per process and stable for the
2050        // context's lifetime.
2051        let ctx_key = stream.context().as_raw() as usize;
2052        let handle = HANDLE_CACHE.with(|cache| -> crate::Result<CublasHandle> {
2053            let mut cache = cache.borrow_mut();
2054            if let Some((_, h)) = cache.iter().find(|(k, _)| *k == ctx_key) {
2055                return Ok(h.clone());
2056            }
2057            // First touch on this thread for this context — create.
2058            // The current context must be the stream's context for
2059            // `cublasCreate` to bind correctly. baracuda's Context API
2060            // is push-based; we call `set_current` (idempotent if
2061            // already current) before the cuBLAS create.
2062            stream
2063                .context()
2064                .set_current()
2065                .map_err(crate::Error::Driver)?;
2066            // Retry handle creation under transient failure. Observed
2067            // empirically: when many test PROCESSES start in parallel
2068            // (the `cargo test --workspace` harness launches dozens of
2069            // binaries concurrently), the cuBLAS library's first-call
2070            // initialization races on a shared driver resource (DLL
2071            // loader / cuBLAS-side global lock) and `cublasCreate_v2`
2072            // intermittently returns CUBLAS_STATUS_ALLOC_FAILED or
2073            // CUBLAS_STATUS_NOT_INITIALIZED. The error is genuinely
2074            // transient — a 50ms backoff + retry clears it 100% of the
2075            // time on RTX 4070 / cuBLAS 12. Five retries with linear
2076            // backoff (50ms, 100ms, ..., 250ms) bounds the worst case
2077            // at ~750ms per first-touch, which is fine for what is
2078            // already a one-time-per-thread operation.
2079            let h = {
2080                let mut last_err = None;
2081                let mut handle: Option<CublasHandle> = None;
2082                for attempt in 0..5 {
2083                    match CublasHandle::new() {
2084                        Ok(h) => { handle = Some(h); break }
2085                        Err(e) => {
2086                            last_err = Some(e);
2087                            // Linear backoff: 50ms * (attempt + 1).
2088                            std::thread::sleep(std::time::Duration::from_millis(
2089                                50 * (attempt as u64 + 1),
2090                            ));
2091                        }
2092                    }
2093                }
2094                match handle {
2095                    Some(h) => h,
2096                    None => {
2097                        // Surface the underlying error code in the
2098                        // message for diagnostics. The retry loop is
2099                        // transparent on success — only the final
2100                        // failure path mentions exhaustion.
2101                        let _ = last_err; // dropped; cublas Error not in scope here
2102                        return Err(crate::Error::Unsupported(
2103                            "cuBLAS handle creation failed after 5 retries \
2104                             (library missing, device unavailable, or \
2105                             persistent driver-init contention)",
2106                        ));
2107                    }
2108                }
2109            };
2110            cache.push((ctx_key, h.clone()));
2111            Ok(h)
2112        })?;
2113        // Bind to the launch stream so the GEMM enqueues on the
2114        // caller-supplied stream rather than the default stream. The
2115        // bind is per-handle, sticky across calls — but a previously
2116        // cached handle could have been bound to a *different* stream
2117        // last time. Always re-bind.
2118        handle
2119            .set_stream(stream)
2120            .map_err(|_| crate::Error::Unsupported(
2121                "cuBLAS set_stream failed",
2122            ))?;
2123        Ok(handle)
2124    }
2125}
2126
/// Algo selector passed to `cublasGemmEx`.
///
/// `-1` corresponds to `CUBLAS_GEMM_DEFAULT` / `CUBLAS_GEMM_DFALT` — the
/// "let cuBLAS pick" heuristic. In modern cuBLAS (CUDA 12+) when
/// `compute_type = Compute32F` the algo selector is largely advisory:
/// cuBLAS picks the kernel from its internal heuristic based on
/// (M, N, K, dtype) regardless of the value here. The Phase-29 bench's
/// reference cuBLAS-direct path passes `0` (`CUBLAS_GEMM_ALGO0`), which
/// is also routed through the same heuristic; both produce the same
/// kernel selection on the f16/bf16/f32 GemmEx paths.
///
/// Why not `CUBLAS_GEMM_DEFAULT_TENSOR_OP` (= 99): it would force a
/// tensor-op kernel even at very low M where cuBLAS's heuristic prefers
/// a CUDA-core kernel; benchmarking showed it was ~3× slower than
/// `DEFAULT` at M=1 on RTX 4070. Stick with `DEFAULT` (`-1`).
const CUBLAS_GEMM_ALGO: i32 = -1;
2143
/// Backend the plan picked for this descriptor. Stored privately on
/// the plan; surfaced through [`GemmPlan::backend`] after conversion to
/// the public [`BackendKind`].
///
/// Differs from `GemmSku::arch` in that it answers a *higher-level*
/// question: "which kernel library does this plan call into?" The arch
/// SKU only meaningfully describes the CUTLASS variant — CUTLASS itself
/// is one of several backends now that the Phase-30 cuBLAS fast-path
/// joined the dispatch surface.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum BackendChoice {
    /// Compiled CUTLASS template instantiation in
    /// `baracuda-cutlass-kernels-sys`. `arch` records which CUTLASS
    /// SKU (sm_80 today).
    Cutlass { arch: ArchSku },
    /// `cublasGemmEx` (f16/bf16/f32) or `cublasDgemm` (f64) via the
    /// `baracuda-cublas` wrapper. Used as the Phase-30 fast path for
    /// f16/bf16 low-M decode shapes on sm_89 hardware.
    Cublas,
    /// Phase 44: vendored ozIMMU (Ozaki-scheme FP64 GEMM that
    /// synthesizes a DGEMM from `S²` int8 tensor-core matmuls). Only
    /// valid for f64 / RCR / RRR / Identity epilogue. `slices` is
    /// carried over verbatim from the public
    /// `BackendKind::Ozaki { slices }` request. As checked by
    /// `validate_ozaki_request`, the low 5 bits hold the slice count
    /// (0 = auto, 3..=18 = fixed) and the high 3 bits hold the variant
    /// (0 = Base, 1 = EF, 2 = RN, 3 = H).
    Ozaki { slices: u8 },
}
2169
2170impl BackendChoice {
2171    fn as_public(self) -> BackendKind {
2172        match self {
2173            BackendChoice::Cutlass { .. } => BackendKind::Cutlass,
2174            BackendChoice::Cublas => BackendKind::Cublas,
2175            BackendChoice::Ozaki { slices } => BackendKind::Ozaki { slices },
2176        }
2177    }
2178}
2179
2180/// Heuristic that decides whether to route an FP-family GEMM through
2181/// cuBLAS or CUTLASS. Caller `pref.prefer_backend` overrides whatever
2182/// the heuristic would have picked (subject to availability checks).
2183///
2184/// Today's heuristic (Phase 30, validated on RTX 4070):
2185///
2186/// - f16/bf16 with `2 ≤ M < 128`: cuBLAS. The decode-batch regime
2187///   where cuBLAS's sm_89 tensor-core kernels beat baracuda's
2188///   CUTLASS sm_80 RCR plan by 2–3×.
2189/// - f16/bf16 with `M == 1` (pure GEMV-shape): CUTLASS. Counter-
2190///   intuitive given Phase-29's M=1 perf gap, but the layout
2191///   `transa=T` we must use to bridge baracuda's row-major contract
2192///   to cuBLAS's col-major API forces cuBLAS to materialize a B^T
2193///   transpose, which is slower than the unmaterialized
2194///   CUTLASS-RCR-sm_80 GEMV-tile at the K=N=4096 shape (185µs vs
2195///   108µs measured). For pure single-token decode, callers wanting
2196///   the bench's "all-1's mathematically-wrong" speed can force
2197///   `prefer_backend = Some(Cublas)` and accept the routing.
2198/// - f16/bf16 with `M >= 128`: CUTLASS. Tie zone; the existing
2199///   sm_80 RCR is known-stable.
2200/// - f32: CUTLASS. Phase-29 bench shows CUTLASS f32 RCR beats cuBLAS
2201///   `sgemm` at M=128; ~parity at low-M; not worth the routing churn.
2202/// - f64: CUTLASS (back-compat). Callers can force cuBLAS via
2203///   `prefer_backend` for decode-shaped f64 workloads.
2204/// - F32Strict: CUTLASS (cuBLAS has no strict-IEEE SIMT path).
2205///
2206/// Bias* epilogues always route through CUTLASS regardless of
2207/// (M, dtype) — cuBLAS-classic has no fused-bias-activation.
2208fn should_use_cublas_for_fp(
2209    desc: &GemmDescriptor,
2210    element: ElementKind,
2211) -> bool {
2212    // Bias / Bias* epilogues — cuBLAS-classic has no fused path.
2213    if desc.epilogue.requires_bias() {
2214        return false;
2215    }
2216    match element {
2217        // The 2 ≤ M < 128 window is where cuBLAS wins. M=1 is excluded
2218        // because the `transa=T` mapping costs more than cuBLAS-direct
2219        // saves over CUTLASS at K=N=4096-ish shapes. M=128+ is excluded
2220        // because CUTLASS sm_80 catches up at prefill scale.
2221        ElementKind::F16 | ElementKind::Bf16 => desc.m >= 2 && desc.m < 128,
2222        // f32 stays on CUTLASS by default — Phase 29 bench shows
2223        // CUTLASS f32 RCR beats cuBLAS sgemm at the M=128 prefill
2224        // shape. Callers can force via PlanPreference if needed.
2225        ElementKind::F32 => false,
2226        // F32Strict: CUTLASS SIMT path is the only bit-stable option.
2227        ElementKind::F32Strict => false,
2228        // f64: keep CUTLASS by default (back-compat). Callers can
2229        // force-prefer cuBLAS if their f64 workload is decode-shaped.
2230        ElementKind::F64 => false,
2231        // Any other element type isn't reachable through `GemmPlan<T>`
2232        // (the `Element` bound rejects non-FP types), but the match
2233        // arm needs to be exhaustive.
2234        _ => false,
2235    }
2236}
2237
2238/// Phase 44 — validate a `BackendKind::Ozaki { slices }` request
2239/// against this descriptor.
2240///
2241/// ozIMMU only supports FP64 GEMM with the Identity epilogue (no
2242/// fused bias / activation chain) and the `RCR` / `RRR` layouts that
2243/// baracuda's `GemmPlan` already exposes. The slice count must be
2244/// `0` (= auto) or `3..=18` (= fixed). Anything else returns
2245/// `Error::Unsupported` so callers see the rejection at plan-select
2246/// time rather than as a deep status code at launch.
2247///
2248/// When the `ozimmu` cargo feature is off, every request is rejected
2249/// with a message pointing at the gate.
2250#[cfg_attr(not(feature = "ozimmu"), allow(unused_variables))]
2251fn validate_ozaki_request(
2252    desc: &GemmDescriptor,
2253    element: ElementKind,
2254    slices: u8,
2255) -> Result<()> {
2256    #[cfg(not(feature = "ozimmu"))]
2257    {
2258        return Err(Error::Unsupported(
2259            "PlanPreference::prefer_backend = Some(Ozaki {..}) requires the \
2260             `ozimmu` cargo feature on baracuda-cutlass (off by default — \
2261             enable on baracuda-kernels too if going through the kernels facade)",
2262        ));
2263    }
2264    #[cfg(feature = "ozimmu")]
2265    {
2266        if element != ElementKind::F64 {
2267            return Err(Error::Unsupported(
2268                "BackendKind::Ozaki is FP64-only (Ozaki-scheme synthesizes \
2269                 DGEMM from int8; f16/bf16/f32/F32Strict have no Ozaki path)",
2270            ));
2271        }
2272        if desc.epilogue != EpilogueKind::Identity {
2273            return Err(Error::Unsupported(
2274                "BackendKind::Ozaki only supports the Identity epilogue \
2275                 (no fused bias / activation chain on the Ozaki path)",
2276            ));
2277        }
2278        // Layout is already constrained to Rcr / Rrr by the descriptor;
2279        // both are supported by `mtk::ozimmu::gemm`.
2280        //
2281        // Phase 44c extends the encoding to include variant flags.
2282        // The low 5 bits are the slice count (0 = auto, 3..=18 = fixed);
2283        // the high 3 bits are the variant (0 = Base, 1 = EF, 2 = RN,
2284        // 3 = H). Reject anything else.
2285        let s = slices & 0x1F; // low 5 bits = slice count
2286        let v = slices >> 5;   // high 3 bits = variant
2287        if s != 0 && !(3..=18).contains(&s) {
2288            return Err(Error::Unsupported(
2289                "BackendKind::Ozaki slice count (low 5 bits) must be 0 \
2290                 (auto) or 3..=18",
2291            ));
2292        }
2293        if v > 3 {
2294            return Err(Error::Unsupported(
2295                "BackendKind::Ozaki variant (high 3 bits) must be 0 (Base), \
2296                 1 (EF), 2 (RN), or 3 (H)",
2297            ));
2298        }
2299        Ok(())
2300    }
2301}
2302
2303/// Map a baracuda [`ElementKind`] to the cuBLAS [`cudaDataType_t`].
2304/// Returns `None` for element kinds that don't have a cuBLAS GemmEx
2305/// dtype (F32Strict, integer, FP8, …). Callers should fall back to
2306/// CUTLASS in the `None` case.
2307fn cublas_dtype_for(kind: ElementKind) -> Option<baracuda_cublas_sys::functions::cudaDataType_t> {
2308    use baracuda_cublas_sys::functions::cudaDataType_t;
2309    match kind {
2310        ElementKind::F16 => Some(cudaDataType_t::R_16F),
2311        ElementKind::Bf16 => Some(cudaDataType_t::R_16BF),
2312        ElementKind::F32 => Some(cudaDataType_t::R_32F),
2313        ElementKind::F64 => Some(cudaDataType_t::R_64F),
2314        _ => None,
2315    }
2316}
2317
2318// ============================================================================
2319// GemmPlan
2320// ============================================================================
2321
/// Selected GEMM kernel and the host-side metadata needed to launch it.
///
/// Plans are cheap to construct, hold no device memory, and are
/// `Send + Sync` for the same reason — they're pure host data. The
/// Phase-30 cuBLAS fast-path adds no per-plan state: cuBLAS handles
/// live in a thread-local cache so the plan itself stays trivially
/// thread-safe.
///
/// See the crate root for usage; key methods:
/// - [`select`](Self::select) — pick a kernel for a problem shape.
/// - [`can_implement`](Self::can_implement) — host-side validation.
/// - [`workspace_size`](Self::workspace_size) — bytes of scratch needed.
/// - [`run`](Self::run) — launch on a stream.
/// - [`sku`](Self::sku) — identity of the chosen kernel.
/// - [`backend`](Self::backend) — which backend (CUTLASS / cuBLAS /
///   Ozaki) was picked. Phase 30 added the cuBLAS fast-path for
///   f16/bf16 low-M decode shapes; the heuristic is documented on the
///   private `should_use_cublas_for_fp` helper in this module.
#[derive(Debug)]
pub struct GemmPlan<T: CutlassElement> {
    // Copy of the problem descriptor the plan was selected for.
    desc: GemmDescriptor,
    // Kernel identity (arch / layout / epilogue / element) reported by `sku()`.
    sku: GemmSku,
    // Which kernel library `run` dispatches into.
    backend: BackendChoice,
    // Ties the plan to its element type without storing a value of it.
    _element: PhantomData<T>,
}
2347
2348impl<T: CutlassElement> GemmPlan<T> {
2349    /// Pick a kernel for `desc`.
2350    ///
2351    /// Queries the stream's device for its compute capability and selects
2352    /// between the CUTLASS-sm_80 (forward-compatible across Ampere /
2353    /// Ada / Hopper), CUTLASS-sm_90a (Hopper-specialized, when feature-
2354    /// enabled and the device actually is Hopper), and the Phase-30
2355    /// cuBLAS fast-path. Build features filter what kernels are
2356    /// *available*; the device cap and the f16/bf16-low-M heuristic
2357    /// decide what to *use*. See [`should_use_cublas_for_fp`] for the
2358    /// dispatch rules. Override the heuristic via
2359    /// [`PlanPreference::prefer_backend`].
2360    pub fn select(stream: &Stream, desc: &GemmDescriptor, pref: PlanPreference) -> Result<Self> {
2361        check_descriptor(desc)?;
2362        // Bias-family kernels currently cover every `(Layout, ElementKind)`
2363        // pair: F16/Bf16 → 8 elements per epilogue access; F32 (TF32)
2364        // and F64 → 4 elements per access; F32Strict (SIMT) → 1. The
2365        // F32Strict bias path uses a vendored SIMT broadcast-epilogue
2366        // partial specialization (see
2367        // `crates/baracuda-cutlass-kernels-sys/kernels/include/
2368        // baracuda_simt_broadcast_epilogue.h`) because CUTLASS doesn't
2369        // wire one by default. F64 routes through the f64-scalar
2370        // dispatcher (`gemm_bias_sm80_run_f64`) because alpha/beta are
2371        // `f64` at the FFI. When new element kinds land that don't yet
2372        // have bias instantiations, reinstate a gate here that returns
2373        // `Error::Unsupported` so callers don't get a runtime status 3
2374        // deep inside the launch path.
2375
2376        // Phase 30: pick the backend (CUTLASS vs cuBLAS) before the
2377        // arch. Caller may force a backend via `pref.prefer_backend`;
2378        // otherwise the f16/bf16-low-M heuristic chooses.
2379        // Phase 44: `BackendKind::Ozaki { slices }` joins the dispatch
2380        // surface for FP64-only / Identity-only / RCR-or-RRR shapes
2381        // (gated behind the `ozimmu` feature on this crate).
2382        let element = T::KIND;
2383
2384        // Phase 44 ozIMMU dispatch — handled before the cuBLAS gate so
2385        // its preconditions (f64 / Identity / RCR|RRR / feature gate)
2386        // are validated up front.
2387        if let Some(BackendKind::Ozaki { slices }) = pref.prefer_backend {
2388            validate_ozaki_request(desc, element, slices)?;
2389            let arch_for_sku = pick_arch(stream, desc, pref)?;
2390            let backend = BackendChoice::Ozaki { slices };
2391            let sku = GemmSku {
2392                arch: arch_for_sku,
2393                layout: desc.layout,
2394                epilogue: desc.epilogue,
2395                element,
2396                bias_element: None,
2397            };
2398            return Ok(Self {
2399                desc: *desc,
2400                sku,
2401                backend,
2402                _element: PhantomData,
2403            });
2404        }
2405
2406        let use_cublas = match pref.prefer_backend {
2407            Some(BackendKind::Cublas) => {
2408                // Force-Cublas: validate that cuBLAS actually supports
2409                // this (layout, epilogue, element) triple. Bias*
2410                // epilogues are not supported (cuBLAS-classic has no
2411                // fused-bias-activation). F32Strict has no cuBLAS dtype.
2412                if desc.epilogue.requires_bias() {
2413                    return Err(Error::Unsupported(
2414                        "cuBLAS backend doesn't fuse bias activations \
2415                         (use Cutlass backend for Bias* epilogues)",
2416                    ));
2417                }
2418                if cublas_dtype_for(element).is_none() {
2419                    return Err(Error::Unsupported(
2420                        "cuBLAS backend has no GemmEx dtype for this element \
2421                         (F32Strict / integer / FP8 stay on Cutlass)",
2422                    ));
2423                }
2424                true
2425            }
2426            Some(BackendKind::Cutlass) => false,
2427            Some(BackendKind::Ozaki { .. }) => {
2428                // Unreachable — handled by the early return above.
2429                false
2430            }
2431            Some(_) => {
2432                // Other backend hints aren't meaningful for GEMM; treat
2433                // as "let the heuristic decide".
2434                should_use_cublas_for_fp(desc, element)
2435                    && cublas_dtype_for(element).is_some()
2436            }
2437            None => {
2438                should_use_cublas_for_fp(desc, element)
2439                    && cublas_dtype_for(element).is_some()
2440            }
2441        };
2442
2443        let (backend, sku_arch) = if use_cublas {
2444            // Capture device arch for the SKU record (so telemetry /
2445            // autotuner caches still see the real silicon), but the
2446            // backend choice routes through cuBLAS.
2447            let arch_for_sku = pick_arch(stream, desc, pref)?;
2448            (BackendChoice::Cublas, arch_for_sku)
2449        } else {
2450            let arch = pick_arch(stream, desc, pref)?;
2451            (BackendChoice::Cutlass { arch }, arch)
2452        };
2453
2454        let sku = GemmSku {
2455            arch: sku_arch,
2456            layout: desc.layout,
2457            epilogue: desc.epilogue,
2458            element,
2459            // Float-family bias kernels imply bias element = element type;
2460            // `None` distinguishes them from the int-family bias kernels
2461            // (which encode the bias element explicitly because it's
2462            // independent of the matrix dtype).
2463            bias_element: None,
2464        };
2465        Ok(Self {
2466            desc: *desc,
2467            sku,
2468            backend,
2469            _element: PhantomData,
2470        })
2471    }
2472
2473    /// Which backend this plan picked.
2474    ///
2475    /// CUTLASS (the default path) or cuBLAS (the Phase-30 fast-path
2476    /// for f16/bf16 low-M decode shapes). The dispatch heuristic is
2477    /// documented on [`should_use_cublas_for_fp`].
2478    pub fn backend(&self) -> BackendKind {
2479        self.backend.as_public()
2480    }
2481
2482    /// Validate that this plan can actually launch with `args`.
2483    ///
2484    /// Two-stage check:
2485    /// 1. **Host-side**: shape/stride/buffer-size validation in pure Rust.
2486    /// 2. **Kernel-side**: calls CUTLASS's `Gemm::can_implement` host
2487    ///    adapter via a no-launch FFI symbol to catch alignment and
2488    ///    kernel-support issues that the host can't see (e.g., the
2489    ///    selected tile's element-per-access requirement on `lda`/`ldb`).
2490    ///
2491    /// Returns without launching a kernel and without touching the device.
2492    /// Use this as a clean prelaunch branch point: if it returns `Ok`, the
2493    /// `run` call will succeed barring runtime CUDA errors.
2494    pub fn can_implement(&self, args: &GemmArgs<'_, T>) -> Result<()> {
2495        check_args(&self.desc, args)?;
2496
2497        // Phase 30: even when the plan picked the cuBLAS backend, we
2498        // still run the CUTLASS-side `can_implement` FFI check.
2499        // Rationale: the cuBLAS path falls back to CUTLASS under graph
2500        // capture (cuBLAS-classic calls aren't capture-safe), so the
2501        // CUTLASS launch path may execute regardless of which backend
2502        // the plan nominally targets. Better to validate the CUTLASS
2503        // path up front than discover an alignment problem deep in the
2504        // capture replay.
2505
2506        let a_ptr = args.a.data.as_raw().0 as *const c_void;
2507        let b_ptr = args.b.data.as_raw().0 as *const c_void;
2508        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
2509        let (c_ptr, ldc) = match &args.c {
2510            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
2511            None => (core::ptr::null(), 0i64),
2512        };
2513        let bias_ptr = args
2514            .bias
2515            .as_ref()
2516            .map(|b| b.data.as_raw().0 as *const c_void)
2517            .unwrap_or(core::ptr::null());
2518
2519        let bias_family = self.sku.epilogue.requires_bias();
2520        let status = match (self.sku.arch, bias_family) {
2521            #[cfg(feature = "sm80")]
2522            (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
2523                dispatch::gemm_sm80_can_implement_f64(
2524                    self.sku.layout,
2525                    self.desc.m, self.desc.n, self.desc.k,
2526                    a_ptr, args.a.ld,
2527                    b_ptr, args.b.ld,
2528                    c_ptr, ldc,
2529                    d_ptr, args.d.ld,
2530                )
2531            },
2532            #[cfg(feature = "sm80")]
2533            (ArchSku::Sm80, false) => unsafe {
2534                dispatch::gemm_sm80_can_implement(
2535                    self.sku.layout,
2536                    T::KIND,
2537                    self.desc.m, self.desc.n, self.desc.k,
2538                    a_ptr, args.a.ld,
2539                    b_ptr, args.b.ld,
2540                    c_ptr, ldc,
2541                    d_ptr, args.d.ld,
2542                )
2543            },
2544            #[cfg(feature = "sm80")]
2545            (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
2546                dispatch::gemm_bias_sm80_can_implement_f64(
2547                    self.sku.layout,
2548                    self.sku.epilogue,
2549                    self.desc.m, self.desc.n, self.desc.k,
2550                    a_ptr, args.a.ld,
2551                    b_ptr, args.b.ld,
2552                    c_ptr, ldc,
2553                    d_ptr, args.d.ld,
2554                    bias_ptr,
2555                )
2556            },
2557            #[cfg(feature = "sm80")]
2558            (ArchSku::Sm80, true) => unsafe {
2559                dispatch::gemm_bias_sm80_can_implement(
2560                    self.sku.layout,
2561                    T::KIND,
2562                    self.sku.epilogue,
2563                    self.desc.m, self.desc.n, self.desc.k,
2564                    a_ptr, args.a.ld,
2565                    b_ptr, args.b.ld,
2566                    c_ptr, ldc,
2567                    d_ptr, args.d.ld,
2568                    bias_ptr,
2569                )
2570            },
2571            #[cfg(not(feature = "sm80"))]
2572            (ArchSku::Sm80, _) => {
2573                return Err(Error::Unsupported(
2574                    "sm80 selected but the `sm80` feature isn't enabled",
2575                ));
2576            }
2577            (ArchSku::Sm90a, _) => {
2578                return Err(Error::Unsupported(
2579                    "sm90a kernels not yet shipped (deferred until Hopper hardware available for validation)",
2580                ));
2581            }
2582            (ArchSku::Sm89, _) => {
2583                return Err(Error::Unsupported(
2584                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
2585                ));
2586            }
2587        };
2588
2589        status_to_result(status)
2590    }
2591
2592    /// Bytes of device scratch this plan needs at `run` time.
2593    ///
2594    /// Returns 0 when the kernel's launch is workspace-free; pass
2595    /// [`Workspace::None`] in that case.
2596    ///
2597    /// Phase 30 note: even when the plan picked the cuBLAS backend
2598    /// (which manages its own scratch internally), this method
2599    /// reports the **CUTLASS-side** workspace requirement. The cuBLAS
2600    /// path falls back to CUTLASS under graph capture (cuBLAS-classic
2601    /// calls aren't capture-safe), so the caller must size the
2602    /// workspace for the CUTLASS path in case the fallback triggers.
2603    /// In practice CUTLASS Identity-epilogue GEMM on sm_80 reports
2604    /// 0 bytes for most (M, N, K), so this conservative reporting
2605    /// rarely costs anything.
2606    pub fn workspace_size(&self) -> usize {
2607        let bias_family = self.sku.epilogue.requires_bias();
2608        match (self.sku.arch, bias_family) {
2609            #[cfg(feature = "sm80")]
2610            (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => {
2611                dispatch::gemm_sm80_workspace_size_f64(
2612                    self.sku.layout,
2613                    self.desc.m, self.desc.n, self.desc.k,
2614                )
2615            }
2616            #[cfg(feature = "sm80")]
2617            (ArchSku::Sm80, false) => dispatch::gemm_sm80_workspace_size(
2618                self.sku.layout,
2619                T::KIND,
2620                self.desc.m, self.desc.n, self.desc.k,
2621            ),
2622            #[cfg(feature = "sm80")]
2623            (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => {
2624                dispatch::gemm_bias_sm80_workspace_size_f64(
2625                    self.sku.layout,
2626                    self.sku.epilogue,
2627                    self.desc.m, self.desc.n, self.desc.k,
2628                )
2629            }
2630            #[cfg(feature = "sm80")]
2631            (ArchSku::Sm80, true) => dispatch::gemm_bias_sm80_workspace_size(
2632                self.sku.layout,
2633                T::KIND,
2634                self.sku.epilogue,
2635                self.desc.m, self.desc.n, self.desc.k,
2636            ),
2637            #[cfg(not(feature = "sm80"))]
2638            (ArchSku::Sm80, _) => 0,
2639            (ArchSku::Sm90a, _) => 0,
2640            (ArchSku::Sm89, _) => 0,
2641        }
2642    }
2643
    /// Identity of the kernel this plan chose.
    ///
    /// Returns a copy of the [`GemmSku`] assembled by
    /// [`select`](Self::select): arch, layout, epilogue, element kind
    /// and bias element.
    pub fn sku(&self) -> GemmSku {
        self.sku
    }
2648
2649    /// Numerical guarantees this plan's kernel provides.
2650    ///
2651    /// Convenience for [`GemmSku::precision_guarantee`] applied to this
2652    /// plan's SKU. Useful for callers that maintain a per-decision-point
2653    /// alternatives table (e.g. picking between cuBLAS and CUTLASS for a
2654    /// given precision contract) without having to re-derive the
2655    /// guarantees from per-kernel documentation.
2656    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
2657        self.sku.precision_guarantee()
2658    }
2659
2660    /// Launch the kernel.
2661    ///
2662    /// `workspace` must be at least [`workspace_size`](Self::workspace_size)
2663    /// bytes when non-zero, or [`Workspace::None`] when zero. The stream
2664    /// must be in the same context as the device buffers in `args`.
2665    pub fn run(
2666        &self,
2667        stream: &Stream,
2668        workspace: Workspace<'_>,
2669        args: GemmArgs<'_, T>,
2670    ) -> Result<()> {
2671        self.can_implement(&args)?;
2672
2673        let needed = self.workspace_size();
2674        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
2675            Workspace::None => {
2676                if needed != 0 {
2677                    return Err(Error::WorkspaceTooSmall {
2678                        needed,
2679                        got: 0,
2680                    });
2681                }
2682                (core::ptr::null_mut(), 0)
2683            }
2684            Workspace::Borrowed(slice) => {
2685                if slice.len() < needed {
2686                    return Err(Error::WorkspaceTooSmall {
2687                        needed,
2688                        got: slice.len(),
2689                    });
2690                }
2691                (slice.as_raw().0 as *mut c_void, slice.len())
2692            }
2693        };
2694
2695        let a_ptr = args.a.data.as_raw().0 as *const c_void;
2696        let b_ptr = args.b.data.as_raw().0 as *const c_void;
2697        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
2698        let (c_ptr, ldc) = match &args.c {
2699            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
2700            None => (core::ptr::null(), 0i64),
2701        };
2702        let bias_ptr = args
2703            .bias
2704            .as_ref()
2705            .map(|b| b.data.as_raw().0 as *const c_void)
2706            .unwrap_or(core::ptr::null());
2707        // When the caller passes c = None, force beta = 0 at the safe
2708        // layer. The kernel internally substitutes D for the C operand to
2709        // satisfy CUTLASS's non-null pointer contract, so a non-zero beta
2710        // here would silently fold the previous D contents into the
2711        // result (D += alpha*AB instead of D = alpha*AB).
2712        let beta_eff = if args.c.is_some() { args.beta } else { <T::Scalar as Default>::default() };
2713        let stream_raw = stream.as_raw();
2714
2715        // Phase 30: cuBLAS fast-path. Dispatches to `cublasGemmEx`
2716        // (f16/bf16/f32) or `cublasDgemm` (f64). cuBLAS is column-major;
2717        // we apply the row-major-from-col-major trick: compute
2718        // `D^T = (op_b B)^T · (op_a A)^T` in cuBLAS terms, which
2719        // lets us pass A/B straight through and read D row-major.
2720        //
2721        // For RCR (A row-major, B col-major, D row-major):
2722        //   cuBLAS sees A as col-major A^T [K, M] (lda = K),
2723        //   cuBLAS sees B as col-major B [K, N] (ldb = K),
2724        //   D^T col-major [N, M] (ldd = N) IS D row-major [M, N].
2725        //   So we call gemmEx(transa = T, transb = N, m=N, n=M, k=K,
2726        //                     B, ldb=K, A, lda=K, D, ldd=N).
2727        // For RRR (A row-major, B row-major, D row-major):
2728        //   B viewed as col-major is B^T [N, K] (ldb = N),
2729        //   so transb stays N (no transpose) but the operands are
2730        //   B^T·A^T = (A·B)^T → identical call shape with transa=N.
2731        //
2732        // Capture-mode guard: cuBLAS-classic calls aren't capture-safe
2733        // (they perform host-side handle-state mutations and may issue
2734        // host allocations / branches that break graph capture). Fall
2735        // back to CUTLASS when the stream is in capture mode. The
2736        // `is_capturing` query itself is graph-safe (it just reads
2737        // driver state).
2738        if matches!(self.backend, BackendChoice::Cublas) {
2739            let capturing = stream.is_capturing().unwrap_or(false);
2740            if !capturing {
2741                return self.run_cublas(stream, args, beta_eff);
2742            }
2743            // Fall through to the CUTLASS dispatch below. The
2744            // BackendChoice::Cublas flag stays set on the plan for
2745            // telemetry / sku() consistency; only this launch falls
2746            // back. Subsequent launches outside capture will route
2747            // back to cuBLAS.
2748        }
2749
2750        // Phase 44: ozIMMU dispatch. Same capture-safety story as the
2751        // cuBLAS path — ozIMMU calls into cuBLAS internally to drive
2752        // the int8 tensor-core matmuls + accumulate stage, so it
2753        // inherits cuBLAS's "not capture-safe" property. Fall back to
2754        // the CUTLASS DGEMM path under capture; the SKU stays Ozaki
2755        // for telemetry consistency.
2756        #[cfg(feature = "ozimmu")]
2757        if let BackendChoice::Ozaki { slices } = self.backend {
2758            let capturing = stream.is_capturing().unwrap_or(false);
2759            if !capturing {
2760                return self.run_ozaki(stream, args, beta_eff, slices);
2761            }
2762        }
2763        #[cfg(not(feature = "ozimmu"))]
2764        if matches!(self.backend, BackendChoice::Ozaki { .. }) {
2765            // Should be unreachable — `select` rejects Ozaki when the
2766            // feature is off. Belt-and-suspenders.
2767            return Err(Error::Unsupported(
2768                "BackendChoice::Ozaki selected without `ozimmu` cargo feature",
2769            ));
2770        }
2771
2772        let bias_family = self.sku.epilogue.requires_bias();
2773        let status = match (self.sku.arch, bias_family) {
2774            // Fork on T::Scalar::IS_F64 to pick the matching FFI dispatcher.
2775            // The two paths only differ in the alpha/beta types passed to
2776            // CUTLASS; the `to_f32` / `to_f64` calls are identity on the
2777            // matching impl and never narrow at runtime because the gate
2778            // statically picks the impl that matches `T::Scalar`.
2779            #[cfg(feature = "sm80")]
2780            (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
2781                dispatch::gemm_sm80_run_f64(
2782                    self.sku.layout,
2783                    self.desc.m, self.desc.n, self.desc.k,
2784                    a_ptr, args.a.ld,
2785                    b_ptr, args.b.ld,
2786                    c_ptr, ldc,
2787                    d_ptr, args.d.ld,
2788                    args.alpha.to_f64(),
2789                    beta_eff.to_f64(),
2790                    ws_ptr, ws_bytes, stream_raw,
2791                )
2792            },
2793            #[cfg(feature = "sm80")]
2794            (ArchSku::Sm80, false) => unsafe {
2795                dispatch::gemm_sm80_run(
2796                    self.sku.layout,
2797                    T::KIND,
2798                    self.desc.m, self.desc.n, self.desc.k,
2799                    a_ptr, args.a.ld,
2800                    b_ptr, args.b.ld,
2801                    c_ptr, ldc,
2802                    d_ptr, args.d.ld,
2803                    args.alpha.to_f32(),
2804                    beta_eff.to_f32(),
2805                    ws_ptr, ws_bytes, stream_raw,
2806                )
2807            },
2808            #[cfg(feature = "sm80")]
2809            (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
2810                dispatch::gemm_bias_sm80_run_f64(
2811                    self.sku.layout,
2812                    self.sku.epilogue,
2813                    self.desc.m, self.desc.n, self.desc.k,
2814                    a_ptr, args.a.ld,
2815                    b_ptr, args.b.ld,
2816                    c_ptr, ldc,
2817                    d_ptr, args.d.ld,
2818                    bias_ptr,
2819                    args.alpha.to_f64(),
2820                    beta_eff.to_f64(),
2821                    ws_ptr, ws_bytes, stream_raw,
2822                )
2823            },
2824            #[cfg(feature = "sm80")]
2825            (ArchSku::Sm80, true) => unsafe {
2826                dispatch::gemm_bias_sm80_run(
2827                    self.sku.layout,
2828                    T::KIND,
2829                    self.sku.epilogue,
2830                    self.desc.m, self.desc.n, self.desc.k,
2831                    a_ptr, args.a.ld,
2832                    b_ptr, args.b.ld,
2833                    c_ptr, ldc,
2834                    d_ptr, args.d.ld,
2835                    bias_ptr,
2836                    args.alpha.to_f32(),
2837                    beta_eff.to_f32(),
2838                    ws_ptr, ws_bytes, stream_raw,
2839                )
2840            },
2841            #[cfg(not(feature = "sm80"))]
2842            (ArchSku::Sm80, _) => {
2843                return Err(Error::Unsupported(
2844                    "sm80 selected but the `sm80` feature isn't enabled",
2845                ));
2846            }
2847            (ArchSku::Sm90a, _) => {
2848                return Err(Error::Unsupported(
2849                    "sm90a kernels not yet implemented (Phase 4c)",
2850                ));
2851            }
2852            (ArchSku::Sm89, _) => {
2853                return Err(Error::Unsupported(
2854                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
2855                ));
2856            }
2857        };
2858
2859        status_to_result(status)
2860    }
2861
2862    /// Phase-30 cuBLAS fast-path launch.
2863    ///
2864    /// Caller already validated host-side shape / strides via
2865    /// [`Self::can_implement`] and confirmed the chosen backend is
2866    /// [`BackendChoice::Cublas`]. This method translates baracuda's
2867    /// row-major operands into cuBLAS's column-major convention via
2868    /// the standard "compute D^T in col-major" trick, then dispatches
2869    /// to `cublasGemmEx` (f16/bf16/f32) or `cublasDgemm` (f64).
2870    fn run_cublas(
2871        &self,
2872        stream: &Stream,
2873        args: GemmArgs<'_, T>,
2874        beta_eff: T::Scalar,
2875    ) -> Result<()> {
2876        use baracuda_cublas::Op as CublasOp;
2877        use baracuda_cublas_sys::functions::cublasComputeType_t;
2878
2879        // Bias / Bias* epilogues aren't supported on the cuBLAS path —
2880        // `select` already rejected them. Belt-and-suspenders here.
2881        if self.sku.epilogue.requires_bias() {
2882            return Err(Error::Unsupported(
2883                "cuBLAS backend doesn't fuse bias activations \
2884                 (caller forced a Bias* epilogue onto the cuBLAS path)",
2885            ));
2886        }
2887
2888        let handle = cublas_backend::handle_for(stream)?;
2889
2890        let m = self.desc.m;
2891        let n = self.desc.n;
2892        let k = self.desc.k;
2893        let a_ptr = args.a.data.as_raw().0 as *const c_void;
2894        let b_ptr = args.b.data.as_raw().0 as *const c_void;
2895        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
2896        // C must equal D for cuBLAS's `C = α·op(A)·op(B) + β·C` shape
2897        // when the safe-layer caller passed C=None (treated as β=0).
2898        // When C is Some, we route it through. cuBLAS reads C with the
2899        // same layout/ldc as it writes D, so the same row-major-from-
2900        // col-major mapping applies.
2901        let (c_ptr, ldc_arg) = match &args.c {
2902            // cuBLAS expects β*C added in; pass the caller's C directly.
2903            Some(c) => (c.data.as_raw().0 as *mut c_void, c.ld as i32),
2904            None => (d_ptr, args.d.ld as i32),
2905        };
2906
2907        // Row-major-from-col-major mapping for baracuda's Rcr/Rrr:
2908        //
2909        // baracuda Rcr: A row-major [M, K] (lda = K),
2910        //               B col-major [K, N] (ldb = K),
2911        //               D row-major [M, N] (ldd = N).
2912        //
2913        // cuBLAS sees the same memory column-major:
2914        //   - A_raw[i*K + kk] col-major (lda=K) → A_blas[kk, i] = A[i, kk]
2915        //     ⇒ cuBLAS view of A is A^T, shape [K, M], lda=K.
2916        //   - B_raw[j*K + kk] col-major (ldb=K) → B_blas[kk, j] = B[kk, j]
2917        //     ⇒ cuBLAS view of B IS B, shape [K, N], ldb=K.
2918        //   - D_raw[i*N + j] col-major (ldd=N) → D_blas[j, i] = D[i, j]
2919        //     ⇒ cuBLAS view of D is D^T, shape [N, M], ldd=N.
2920        //
2921        // The target identity is D[i,j] = Σ_kk A[i,kk]·B[kk,j], which in
2922        // cuBLAS view is D_blas[j,i] = Σ_kk B_blas[kk,j] · A_blas[kk,i].
2923        // Expressed as a cuBLAS GEMM `D_blas = op(X)·op(Y)` with
2924        // D_blas shape (M_blas, N_blas) = (N, M):
2925        //   - op(X) shape (M_blas, K) = (N, K) ⇐ B_blas with Op::T
2926        //     (B_blas is [K, N], B_blas^T = [N, K]).
2927        //   - op(Y) shape (K, N_blas) = (K, M) ⇐ A_blas with Op::N
2928        //     (A_blas is already [K, M]).
2929        // So pass B as the first operand to gemmEx with Op::T and A as
2930        // the second operand with Op::N.
2931        //
2932        // For Rrr (A row-major [M,K], B row-major [K,N], D row-major
2933        // [M,N]):
2934        //   - B_raw[kk*N + j] col-major (ldb=N) → B_blas[j, kk] = B[kk, j]
2935        //     ⇒ cuBLAS view of B is B^T, shape [N, K], ldb=N.
2936        //   - First operand needs (M_blas, K) = (N, K) ⇐ B^T-as-stored
2937        //     with Op::N (no further transpose).
2938        //   - Second operand same as Rcr: A with Op::N, lda=K.
2939        //
2940        // `cublas_lda` and `cublas_ldb` below name the ld values *as
2941        // gemmEx expects them*: cublas_lda is the leading dim of the
2942        // first operand (which in our call is baracuda's B), and
2943        // cublas_ldb is the leading dim of the second operand (which
2944        // is baracuda's A). For both Rcr and Rrr, cublas_lda = args.b.ld
2945        // and cublas_ldb = args.a.ld; the only thing that changes is
2946        // the op on the first operand.
2947        let (transa, transb) = match self.desc.layout {
2948            LayoutSku::Rcr => (CublasOp::T, CublasOp::N),
2949            LayoutSku::Rrr => (CublasOp::N, CublasOp::N),
2950        };
2951        let cublas_lda = args.b.ld as i32; // first operand (B) ld
2952        let cublas_ldb = args.a.ld as i32; // second operand (A) ld
2953        let ldd_arg = args.d.ld as i32;
2954
2955        // Pick compute path by `T::Scalar` (the FP family bound on
2956        // CutlassElement guarantees f32 or f64). f64 routes through
2957        // `cublasDgemm`; everything else uses `cublasGemmEx` with
2958        // `CUBLAS_COMPUTE_32F` (f32 accumulator regardless of operand
2959        // dtype, matching cuBLAS's tensor-core path).
2960        if <T::Scalar as ScalarType>::IS_F64 {
2961            // f64 path: cublasDgemm. No GemmEx needed; α/β are f64.
2962            use baracuda_cublas_sys::cublasOperation_t;
2963
2964            // Translate cublasOperation_t for the raw helper.
2965            let to_raw = |op: CublasOp| match op {
2966                CublasOp::N => cublasOperation_t::N,
2967                CublasOp::T => cublasOperation_t::T,
2968                CublasOp::C => cublasOperation_t::C,
2969            };
2970            // cublasDgemm signature: (handle, transa, transb, m, n, k,
2971            //   alpha, A, lda, B, ldb, beta, C, ldc).
2972            // Layout-mapped operand order: B first, A second, output D.
2973            let alpha_f64 = args.alpha.to_f64();
2974            let beta_f64 = beta_eff.to_f64();
2975            let c_api = baracuda_cublas_sys::cublas()
2976                .map_err(|_| Error::Unsupported("cuBLAS library unavailable"))?;
2977            let dgemm = c_api
2978                .cublas_dgemm()
2979                .map_err(|_| Error::Unsupported("cublasDgemm symbol unavailable"))?;
2980            // SAFETY: every pointer is from a live DeviceBuffer with
2981            // the correct dtype; sizes were validated via check_args;
2982            // the handle is bound to `stream`'s context.
2983            let status = unsafe {
2984                dgemm(
2985                    handle.as_raw(),
2986                    to_raw(transa),
2987                    to_raw(transb),
2988                    n,
2989                    m,
2990                    k,
2991                    &alpha_f64,
2992                    b_ptr as *const f64,
2993                    cublas_lda,
2994                    a_ptr as *const f64,
2995                    cublas_ldb,
2996                    &beta_f64,
2997                    // cublasDgemm uses C as both input and output. Pass
2998                    // C-or-D depending on whether caller supplied C.
2999                    if args.c.is_some() {
3000                        // Caller passed a separate C. cuBLAS writes
3001                        // back into the same buffer it reads from, so
3002                        // we'd need a copy step to materialize into D.
3003                        // For now, restrict the cuBLAS f64 path to
3004                        // C=None (the common decode case).
3005                        return Err(Error::Unsupported(
3006                            "cuBLAS f64 GEMM with explicit C operand is not yet wired \
3007                             (D and C alias differently than cuBLAS expects); \
3008                             use Cutlass backend or set c = None",
3009                        ));
3010                    } else {
3011                        d_ptr as *mut f64
3012                    },
3013                    ldd_arg,
3014                )
3015            };
3016            return match status {
3017                baracuda_cublas_sys::cublasStatus_t::SUCCESS => Ok(()),
3018                _ => Err(Error::CutlassInternal(status.0)),
3019            };
3020        }
3021
3022        // f16/bf16/f32 path: cublasGemmEx with Compute32F accumulator
3023        // and the tensor-op default algo selector.
3024        let dtype = cublas_dtype_for(self.sku.element).ok_or(Error::Unsupported(
3025            "cuBLAS backend selected for element kind without a cuBLAS dtype mapping",
3026        ))?;
3027        // For the col-major-mapped problem the operands swap order; the
3028        // dtype tag is the same for A, B, and D since all three carry
3029        // the same `T`.
3030        let a_type = dtype;
3031        let b_type = dtype;
3032        let c_type = dtype;
3033        // Alpha/beta as f32 (the host scalars cuBLAS expects when
3034        // compute_type = Compute32F).
3035        let alpha_f32 = args.alpha.to_f32();
3036        let beta_f32 = beta_eff.to_f32();
3037
3038        // Caller passed an explicit C: cuBLAS gemmEx writes C in-place
3039        // (treats C as the output too). When `args.c == Some` and the
3040        // caller wants D != C, we'd need a memcpy step. For simplicity
3041        // (and matching the bench's no-C decode case) reject the
3042        // explicit-C path on the cuBLAS branch and tell the caller to
3043        // use Cutlass backend. The Identity / no-bias path is the
3044        // common decode shape we're optimizing.
3045        if args.c.is_some() {
3046            return Err(Error::Unsupported(
3047                "cuBLAS GemmPlan path requires c = None \
3048                 (cublasGemmEx writes the output in-place into the C operand; \
3049                 explicit-C with D ≠ C requires an extra copy step — \
3050                 force Cutlass backend if you need it)",
3051            ));
3052        }
3053        let _ = (c_ptr, ldc_arg); // suppress unused-let warning
3054
3055        // SAFETY: every pointer is from a live DeviceBuffer with the
3056        // correct dtype; sizes were validated via check_args; the
3057        // handle is bound to `stream`'s context.
3058        unsafe {
3059            baracuda_cublas::gemm_ex(
3060                &handle,
3061                transa,
3062                transb,
3063                n,
3064                m,
3065                k,
3066                &alpha_f32 as *const f32 as *const c_void,
3067                b_ptr,
3068                b_type,
3069                cublas_lda,
3070                a_ptr,
3071                a_type,
3072                cublas_ldb,
3073                &beta_f32 as *const f32 as *const c_void,
3074                d_ptr,
3075                c_type,
3076                ldd_arg,
3077                cublasComputeType_t::Compute32F,
3078                CUBLAS_GEMM_ALGO,
3079            )
3080            .map_err(|_| Error::CutlassInternal(-1))
3081        }
3082    }
3083
3084    /// Phase-44 ozIMMU dispatch launch.
3085    ///
3086    /// Caller already validated host-side shape / strides via
3087    /// [`Self::can_implement`], confirmed the chosen backend is
3088    /// [`BackendChoice::Ozaki`], and verified the stream is not in
3089    /// graph-capture mode (Ozaki is not capture-safe — ozIMMU runs
3090    /// cuBLAS internally on the int8 accumulate stage). This method
3091    /// translates baracuda's row-major operands into ozIMMU's
3092    /// cuBLAS-compatible column-major convention (same trick as the
3093    /// cuBLAS f64 path) and dispatches to `mtk::ozimmu::gemm`.
3094    ///
3095    /// Restrictions for the Phase 44 alpha:
3096    /// - `args.c` must be `None` (caller passes a C operand only when
3097    ///   they want `D = alpha*AB + beta*C` with `D != C`; ozIMMU's
3098    ///   GEMM uses C as the output operand, same as cublasDgemm, so
3099    ///   we'd need an extra copy to materialize into D — defer).
3100    /// - Element must be F64 — guarded at `select` already; this
3101    ///   method takes the safe-typed `T` so the compiler can't see
3102    ///   that, but the `Scalar::IS_F64` check below makes it
3103    ///   defense-in-depth.
3104    #[cfg(feature = "ozimmu")]
3105    fn run_ozaki(
3106        &self,
3107        stream: &Stream,
3108        args: GemmArgs<'_, T>,
3109        beta_eff: T::Scalar,
3110        slices: u8,
3111    ) -> Result<()> {
3112        use baracuda_ozimmu::{Op as OzakiOp, OzakiSlices, OzakiVariant};
3113
3114        if !<T::Scalar as ScalarType>::IS_F64 {
3115            return Err(Error::Unsupported(
3116                "BackendChoice::Ozaki reached on non-f64 element \
3117                 (select() guard should have rejected this)",
3118            ));
3119        }
3120        if args.c.is_some() {
3121            return Err(Error::Unsupported(
3122                "ozIMMU GemmPlan path requires c = None \
3123                 (the Ozaki path writes its output in-place into the C \
3124                 operand of the underlying cuBLAS GEMM — explicit-C with \
3125                 D ≠ C requires an extra copy step that the Phase 44 \
3126                 alpha does not yet wire; force Cutlass backend if needed)",
3127            ));
3128        }
3129
3130        // Phase 44c — decode the discriminant. Low 5 bits = slice
3131        // count (0 = auto, 3..=18 = fixed); high 3 bits = variant
3132        // (0 = Base, 1 = EF, 2 = RN, 3 = H). The pure-slices-only
3133        // values 0..=18 used by Phase 44b decode as Base, preserving
3134        // source-compat for existing callers.
3135        let s = slices & 0x1F;
3136        let v = slices >> 5;
3137        let slice_choice = match s {
3138            0 => OzakiSlices::Auto,
3139            3 => OzakiSlices::S3,
3140            4 => OzakiSlices::S4,
3141            5 => OzakiSlices::S5,
3142            6 => OzakiSlices::S6,
3143            7 => OzakiSlices::S7,
3144            8 => OzakiSlices::S8,
3145            9 => OzakiSlices::S9,
3146            10 => OzakiSlices::S10,
3147            11 => OzakiSlices::S11,
3148            12 => OzakiSlices::S12,
3149            13 => OzakiSlices::S13,
3150            14 => OzakiSlices::S14,
3151            15 => OzakiSlices::S15,
3152            16 => OzakiSlices::S16,
3153            17 => OzakiSlices::S17,
3154            18 => OzakiSlices::S18,
3155            _ => {
3156                return Err(Error::Unsupported(
3157                    "ozIMMU slice count out of range (validated at select; \
3158                     this is unreachable)",
3159                ));
3160            }
3161        };
3162        let variant_choice = match v {
3163            0 => OzakiVariant::Base,
3164            1 => OzakiVariant::EF,
3165            2 => OzakiVariant::RN,
3166            3 => OzakiVariant::H,
3167            _ => {
3168                return Err(Error::Unsupported(
3169                    "ozIMMU variant out of range (validated at select; \
3170                     this is unreachable)",
3171                ));
3172            }
3173        };
3174
3175        let handle = ozimmu_backend::handle_for(stream)?;
3176
3177        // Row-major → ozIMMU (cuBLAS-compatible col-major) mapping.
3178        // Same algebra as the cuBLAS f64 path: compute `D^T` in
3179        // col-major terms, swap the operand order, set transa = T for
3180        // RCR (because B in col-major IS B^T-transposed-from-row-major)
3181        // or transa = N for RRR (because B in col-major IS B from row
3182        // major).
3183        //
3184        // Pass B as the first operand, A as the second.
3185        let (transa, transb) = match self.desc.layout {
3186            LayoutSku::Rcr => (OzakiOp::T, OzakiOp::N),
3187            LayoutSku::Rrr => (OzakiOp::N, OzakiOp::N),
3188        };
3189        let m = self.desc.m as usize;
3190        let n = self.desc.n as usize;
3191        let k = self.desc.k as usize;
3192        let lda = args.b.ld as usize; // first operand (B) ld
3193        let ldb = args.a.ld as usize; // second operand (A) ld
3194        let ldc = args.d.ld as usize;
3195
3196        let a_ptr = args.a.data.as_raw().0 as *const f64;
3197        let b_ptr = args.b.data.as_raw().0 as *const f64;
3198        let d_ptr = args.d.data.as_raw().0 as *mut f64;
3199        let alpha = args.alpha.to_f64();
3200        let beta = beta_eff.to_f64();
3201
3202        // SAFETY: the descriptor / args were validated by
3203        // can_implement(); pointers are live DeviceBuffer<f64> views.
3204        //
3205        // Phase 44c — always go through dgemm_with_variant. Base
3206        // variant is bit-identical to the pre-44c dgemm path; the
3207        // variant dispatch happens inside the C++ shim with no
3208        // measurable per-call host overhead.
3209        unsafe {
3210            handle.dgemm_with_variant(
3211                transa, transb,
3212                // ozIMMU sees us computing D^T col-major: shape (n, m).
3213                n, m, k,
3214                alpha,
3215                b_ptr, lda,
3216                a_ptr, ldb,
3217                beta,
3218                d_ptr, ldc,
3219                slice_choice,
3220                variant_choice,
3221            )
3222            .map_err(|e| {
3223                use baracuda_ozimmu::Error as OzErr;
3224                match e {
3225                    OzErr::DgemmFailed(s) => Error::CutlassInternal(s),
3226                    _ => Error::Unsupported(
3227                        "ozIMMU dgemm rejected the request (see logs)",
3228                    ),
3229                }
3230            })
3231        }
3232    }
3233}
3234
3235// ============================================================================
3236// Phase 44 ozIMMU backend — handle cache + dispatch helpers
3237// ============================================================================
3238//
3239// Mirror of the Phase-30 `cublas_backend` module: thread-local cache of
3240// ozIMMU handles keyed by raw context pointer. ozIMMU handles are
3241// expensive to construct (they spin up a cuBLAS handle internally + do
3242// some env-var probing); cache + re-bind keeps the steady-state launch
// cost down to one `set_cuda_stream` call. Returns an `Rc`-cloned
// safe wrapper so the cache can hold the canonical handle while the
// caller drives a launch without holding the thread-local borrow.
3246
#[cfg(feature = "ozimmu")]
mod ozimmu_backend {
    use core::cell::RefCell;
    use std::rc::Rc;
    use std::time::Duration;

    use baracuda_driver::Stream;
    use baracuda_ozimmu::Handle as OzimmuHandle;

    /// How many times `OzimmuHandle::new()` is tried before giving up.
    const CREATE_ATTEMPTS: u64 = 5;

    /// Backoff step in milliseconds: the wait after the i-th failed
    /// attempt (1-based) is `i * BACKOFF_STEP_MS`.
    const BACKOFF_STEP_MS: u64 = 50;

    thread_local! {
        /// Per-thread list of `(raw context pointer as usize, handle)`
        /// pairs, searched linearly.
        static HANDLE_CACHE: RefCell<Vec<(usize, Rc<OzimmuHandle>)>> =
            const { RefCell::new(Vec::new()) };
    }

    /// Fetch (or lazily create) an ozIMMU handle bound to `stream`'s
    /// context, with the handle's stream binding set to `stream`.
    ///
    /// Handles are cached per thread, keyed by the raw context
    /// pointer. On a cache miss the stream's context is made current
    /// and creation is attempted up to [`CREATE_ATTEMPTS`] times with
    /// linear backoff between attempts.
    ///
    /// # Errors
    ///
    /// - `Error::Driver` when the stream's context cannot be made
    ///   current.
    /// - `Error::Unsupported` when every creation attempt fails.
    pub(super) fn handle_for(stream: &Stream) -> crate::Result<Rc<OzimmuHandle>> {
        let ctx_key = stream.context().as_raw() as usize;
        let handle = HANDLE_CACHE.with(|cache| -> crate::Result<Rc<OzimmuHandle>> {
            let mut cache = cache.borrow_mut();
            if let Some((_, h)) = cache.iter().find(|(k, _)| *k == ctx_key) {
                return Ok(Rc::clone(h));
            }
            // Make sure the stream's context is current — ozIMMU's
            // create() calls into cuBLAS create() which needs the
            // current context to bind correctly.
            stream
                .context()
                .set_current()
                .map_err(crate::Error::Driver)?;
            // Retry with linear backoff under transient init failures
            // (mirrors the cuBLAS Phase-35 retry pattern; ozIMMU
            // wraps cuBLAS so it inherits the same parallel-init
            // race window).
            let mut created: Option<OzimmuHandle> = None;
            for attempt in 1..=CREATE_ATTEMPTS {
                if let Ok(h) = OzimmuHandle::new() {
                    created = Some(h);
                    break;
                }
                // Only back off when another attempt follows; sleeping
                // after the final failure would just delay the error.
                if attempt < CREATE_ATTEMPTS {
                    std::thread::sleep(Duration::from_millis(BACKOFF_STEP_MS * attempt));
                }
            }
            let h = created.ok_or(crate::Error::Unsupported(
                "ozIMMU handle creation failed after 5 retries \
                 (library missing, device unavailable, or persistent \
                 init contention)",
            ))?;
            let rc = Rc::new(h);
            cache.push((ctx_key, Rc::clone(&rc)));
            Ok(rc)
        })?;
        // Re-bind on every call: a cached handle may have last been
        // used with a different stream of the same context.
        handle.set_stream(stream);
        Ok(handle)
    }
}
3314
3315// ============================================================================
3316// BatchedGemmPlan — uniform-shape batched GEMM
3317// ============================================================================
3318//
3319// All batches share `(M, N, K)`; per-batch operands are addressed by
3320// adding `i * stride_*` (in elements) to the base pointer. For
3321// variable-shape grouped problems use `GroupedGemmPlan` instead.
3322//
3323// v1 coverage: Rcr layout, F16 / Bf16 elements, sm_80, Identity epilogue.
3324
3325fn check_batched_descriptor(desc: &BatchedGemmDescriptor) -> Result<()> {
3326    if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
3327        return Err(Error::InvalidProblem("M, N, K must all be positive"));
3328    }
3329    if desc.batch_count <= 0 {
3330        return Err(Error::InvalidProblem("batch_count must be positive"));
3331    }
3332    if desc.epilogue != EpilogueKind::Identity {
3333        return Err(Error::Unsupported(
3334            "BatchedGemmPlan v1 supports only EpilogueKind::Identity",
3335        ));
3336    }
3337    Ok(())
3338}
3339
3340fn check_batched_args<T: CutlassElement>(
3341    desc: &BatchedGemmDescriptor,
3342    args: &BatchedGemmArgs<'_, T>,
3343) -> Result<()> {
3344    // Per-batch shape validation matches the single-GEMM path; strides
3345    // are validated by checking that the last batch's max-addressable
3346    // element fits within each base buffer.
3347    if args.a.rows != desc.m || args.a.cols != desc.k {
3348        return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
3349    }
3350    if args.b.rows != desc.k || args.b.cols != desc.n {
3351        return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
3352    }
3353    if args.d.rows != desc.m || args.d.cols != desc.n {
3354        return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
3355    }
3356    if let Some(c) = &args.c {
3357        if c.rows != desc.m || c.cols != desc.n {
3358            return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
3359        }
3360    }
3361    if args.a.ld < desc.k as i64 {
3362        return Err(Error::InvalidProblem("A leading dimension must be >= K"));
3363    }
3364    let b_min_ld = match desc.layout {
3365        LayoutSku::Rcr => desc.k as i64,
3366        LayoutSku::Rrr => desc.n as i64,
3367    };
3368    if args.b.ld < b_min_ld {
3369        return Err(Error::InvalidProblem("B leading dimension too small for layout"));
3370    }
3371    if args.d.ld < desc.n as i64 {
3372        return Err(Error::InvalidProblem("D leading dimension must be >= N"));
3373    }
3374    if let Some(c) = &args.c {
3375        if c.ld < desc.n as i64 {
3376            return Err(Error::InvalidProblem("C leading dimension must be >= N"));
3377        }
3378    }
3379
3380    // Per-batch element footprint = single-batch min + (batch - 1) * stride.
3381    // Stride 0 means "broadcast same matrix across all batches" — the
3382    // single-batch min is the only constraint there.
3383    fn need_for_batches(
3384        per_batch_min: usize,
3385        stride: i64,
3386        batch_count: i32,
3387    ) -> Option<usize> {
3388        if batch_count <= 1 || stride == 0 {
3389            return Some(per_batch_min);
3390        }
3391        let extra = stride.checked_mul((batch_count - 1) as i64)?;
3392        let extra = usize::try_from(extra).ok()?;
3393        per_batch_min.checked_add(extra)
3394    }
3395
3396    let a_per = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
3397        .ok_or(Error::InvalidProblem("A storage size overflow"))?;
3398    let need_a = need_for_batches(a_per, args.stride_a, desc.batch_count)
3399        .ok_or(Error::InvalidProblem("A batched storage size overflow"))?;
3400    if args.a.data.len() < need_a {
3401        return Err(Error::BufferTooSmall {
3402            needed: need_a,
3403            got: args.a.data.len(),
3404        });
3405    }
3406
3407    let b_per = match desc.layout {
3408        LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
3409        LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
3410    }
3411    .ok_or(Error::InvalidProblem("B storage size overflow"))?;
3412    let need_b = need_for_batches(b_per, args.stride_b, desc.batch_count)
3413        .ok_or(Error::InvalidProblem("B batched storage size overflow"))?;
3414    if args.b.data.len() < need_b {
3415        return Err(Error::BufferTooSmall {
3416            needed: need_b,
3417            got: args.b.data.len(),
3418        });
3419    }
3420
3421    let d_per = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
3422        .ok_or(Error::InvalidProblem("D storage size overflow"))?;
3423    let need_d = need_for_batches(d_per, args.stride_d, desc.batch_count)
3424        .ok_or(Error::InvalidProblem("D batched storage size overflow"))?;
3425    if args.d.data.len() < need_d {
3426        return Err(Error::BufferTooSmall {
3427            needed: need_d,
3428            got: args.d.data.len(),
3429        });
3430    }
3431
3432    if let Some(c) = &args.c {
3433        let c_per = min_elements_row_major(c.rows, c.cols, c.ld)
3434            .ok_or(Error::InvalidProblem("C storage size overflow"))?;
3435        let need_c = need_for_batches(c_per, args.stride_c, desc.batch_count)
3436            .ok_or(Error::InvalidProblem("C batched storage size overflow"))?;
3437        if c.data.len() < need_c {
3438            return Err(Error::BufferTooSmall {
3439                needed: need_c,
3440                got: c.data.len(),
3441            });
3442        }
3443    }
3444    Ok(())
3445}
3446
/// Plan for a uniform-shape batched GEMM launch.
///
/// All batches share `(M, N, K)`. Plans hold host-side selection
/// metadata only — no device memory, cheap to clone, `Send + Sync`.
///
/// A plan is created with [`BatchedGemmPlan::select`], which records the
/// descriptor it was given together with the kernel SKU it picked.
///
/// See [`GroupedGemmPlan`] for the variable-shape (per-group) case.
// NOTE(review): only `Debug` is derived on this type; confirm a `Clone`
// impl exists elsewhere before relying on the "cheap to clone" claim.
#[derive(Debug)]
pub struct BatchedGemmPlan<T: CutlassElement> {
    /// Copy of the descriptor passed to `select` (shape, layout,
    /// epilogue, batch count).
    desc: BatchedGemmDescriptor,
    /// Identity of the kernel chosen at selection time.
    sku: GemmSku,
    /// Ties the plan to its element type without storing a `T`.
    _element: PhantomData<T>,
}
3459
3460impl<T: CutlassElement> BatchedGemmPlan<T> {
3461    /// Pick a kernel for `desc`.
3462    ///
3463    /// Returns [`Error::Unsupported`] when the requested
3464    /// `(layout, element)` combination has no shipped batched kernel
3465    /// (today: anything other than `Rcr × {F16, Bf16}`).
3466    pub fn select(
3467        stream: &Stream,
3468        desc: &BatchedGemmDescriptor,
3469        pref: PlanPreference,
3470    ) -> Result<Self> {
3471        check_batched_descriptor(desc)?;
3472        let one_off_desc = GemmDescriptor {
3473            m: desc.m,
3474            n: desc.n,
3475            k: desc.k,
3476            layout: desc.layout,
3477            epilogue: desc.epilogue,
3478        };
3479        let arch = pick_arch(stream, &one_off_desc, pref)?;
3480        // v1 coverage gate: only Rcr × {F16, Bf16} have batched kernels.
3481        match (desc.layout, T::KIND) {
3482            (LayoutSku::Rcr, ElementKind::F16) | (LayoutSku::Rcr, ElementKind::Bf16) => {}
3483            _ => {
3484                return Err(Error::Unsupported(
3485                    "BatchedGemmPlan v1 only ships Rcr × {F16, Bf16} on sm_80",
3486                ));
3487            }
3488        }
3489        let sku = GemmSku {
3490            arch,
3491            layout: desc.layout,
3492            epilogue: desc.epilogue,
3493            element: T::KIND,
3494            // Float-family bias kernels imply bias element = element type;
3495            // `None` distinguishes them from the int-family bias kernels
3496            // (which encode the bias element explicitly because it's
3497            // independent of the matrix dtype).
3498            bias_element: None,
3499        };
3500        Ok(Self {
3501            desc: *desc,
3502            sku,
3503            _element: PhantomData,
3504        })
3505    }
3506
3507    /// Validate that this plan can launch with `args`. See
3508    /// [`GemmPlan::can_implement`] for the two-stage host + kernel check.
3509    pub fn can_implement(&self, args: &BatchedGemmArgs<'_, T>) -> Result<()> {
3510        check_batched_args(&self.desc, args)?;
3511
3512        let a_ptr = args.a.data.as_raw().0 as *const c_void;
3513        let b_ptr = args.b.data.as_raw().0 as *const c_void;
3514        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
3515        let (c_ptr, ldc, stride_c) = match &args.c {
3516            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld, args.stride_c),
3517            None => (core::ptr::null(), 0i64, 0i64),
3518        };
3519
3520        let status = match self.sku.arch {
3521            #[cfg(feature = "sm80")]
3522            ArchSku::Sm80 => unsafe {
3523                dispatch::batched_gemm_sm80_can_implement(
3524                    self.sku.layout,
3525                    T::KIND,
3526                    self.desc.m,
3527                    self.desc.n,
3528                    self.desc.k,
3529                    a_ptr,
3530                    args.a.ld,
3531                    args.stride_a,
3532                    b_ptr,
3533                    args.b.ld,
3534                    args.stride_b,
3535                    c_ptr,
3536                    ldc,
3537                    stride_c,
3538                    d_ptr,
3539                    args.d.ld,
3540                    args.stride_d,
3541                    self.desc.batch_count,
3542                )
3543            },
3544            #[cfg(not(feature = "sm80"))]
3545            ArchSku::Sm80 => {
3546                return Err(Error::Unsupported(
3547                    "sm80 selected but the `sm80` feature isn't enabled",
3548                ));
3549            }
3550            ArchSku::Sm90a => {
3551                return Err(Error::Unsupported(
3552                    "sm90a batched kernels not yet shipped",
3553                ));
3554            }
3555            ArchSku::Sm89 => {
3556                return Err(Error::Unsupported(
3557                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
3558                ));
3559            }
3560        };
3561
3562        status_to_result(status)
3563    }
3564
3565    /// Bytes of device scratch this plan needs at `run` time.
3566    pub fn workspace_size(&self) -> usize {
3567        match self.sku.arch {
3568            #[cfg(feature = "sm80")]
3569            ArchSku::Sm80 => dispatch::batched_gemm_sm80_workspace_size(
3570                self.sku.layout,
3571                T::KIND,
3572                self.desc.m,
3573                self.desc.n,
3574                self.desc.k,
3575                self.desc.batch_count,
3576            ),
3577            #[cfg(not(feature = "sm80"))]
3578            ArchSku::Sm80 => 0,
3579            ArchSku::Sm90a => 0,
3580            ArchSku::Sm89 => 0,
3581        }
3582    }
3583
3584    /// Identity of the kernel this plan chose.
3585    pub fn sku(&self) -> GemmSku {
3586        self.sku
3587    }
3588
3589    /// Numerical guarantees this plan's kernel provides.
3590    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
3591        self.sku.precision_guarantee()
3592    }
3593
3594    /// Launch the batched kernel.
3595    pub fn run(
3596        &self,
3597        stream: &Stream,
3598        workspace: Workspace<'_>,
3599        args: BatchedGemmArgs<'_, T>,
3600    ) -> Result<()> {
3601        self.can_implement(&args)?;
3602
3603        let needed = self.workspace_size();
3604        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
3605            Workspace::None => {
3606                if needed != 0 {
3607                    return Err(Error::WorkspaceTooSmall { needed, got: 0 });
3608                }
3609                (core::ptr::null_mut(), 0)
3610            }
3611            Workspace::Borrowed(slice) => {
3612                if slice.len() < needed {
3613                    return Err(Error::WorkspaceTooSmall {
3614                        needed,
3615                        got: slice.len(),
3616                    });
3617                }
3618                (slice.as_raw().0 as *mut c_void, slice.len())
3619            }
3620        };
3621
3622        let a_ptr = args.a.data.as_raw().0 as *const c_void;
3623        let b_ptr = args.b.data.as_raw().0 as *const c_void;
3624        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
3625        let (c_ptr, ldc, stride_c) = match &args.c {
3626            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld, args.stride_c),
3627            None => (core::ptr::null(), 0i64, 0i64),
3628        };
3629        let beta_eff = if args.c.is_some() { args.beta } else { <T::Scalar as Default>::default() };
3630        let stream_raw = stream.as_raw();
3631
3632        let status = match self.sku.arch {
3633            #[cfg(feature = "sm80")]
3634            ArchSku::Sm80 => unsafe {
3635                // Batched GEMM v1 ships only Rcr × {F16, Bf16} — both
3636                // f32-scalar — so the `to_f32` conversion is identity.
3637                // The select gate rejects F64/F32Strict before we reach here.
3638                dispatch::batched_gemm_sm80_run(
3639                    self.sku.layout,
3640                    T::KIND,
3641                    self.desc.m,
3642                    self.desc.n,
3643                    self.desc.k,
3644                    a_ptr,
3645                    args.a.ld,
3646                    args.stride_a,
3647                    b_ptr,
3648                    args.b.ld,
3649                    args.stride_b,
3650                    c_ptr,
3651                    ldc,
3652                    stride_c,
3653                    d_ptr,
3654                    args.d.ld,
3655                    args.stride_d,
3656                    args.alpha.to_f32(),
3657                    beta_eff.to_f32(),
3658                    self.desc.batch_count,
3659                    ws_ptr,
3660                    ws_bytes,
3661                    stream_raw,
3662                )
3663            },
3664            #[cfg(not(feature = "sm80"))]
3665            ArchSku::Sm80 => {
3666                return Err(Error::Unsupported(
3667                    "sm80 selected but the `sm80` feature isn't enabled",
3668                ));
3669            }
3670            ArchSku::Sm90a => {
3671                return Err(Error::Unsupported(
3672                    "sm90a batched kernels not yet shipped",
3673                ));
3674            }
3675            ArchSku::Sm89 => {
3676                return Err(Error::Unsupported(
3677                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
3678                ));
3679            }
3680        };
3681
3682        status_to_result(status)
3683    }
3684}
3685
3686fn pick_arch(
3687    stream: &Stream,
3688    _desc: &GemmDescriptor,
3689    pref: PlanPreference,
3690) -> Result<ArchSku> {
3691    // Selection policy:
3692    //   1. Query the stream's device for its compute capability.
3693    //   2. Prefer sm90a when (a) the caller didn't disable it via
3694    //      `pref.allow_sm90a == false`, (b) the `sm90a` feature is on,
3695    //      and (c) the device is actually Hopper (cap >= 9.0).
3696    //   3. Otherwise fall back to sm80, which runs forward-compatibly on
3697    //      Ampere (sm_80), Ada (sm_89), and Hopper (sm_90+) at lower peak
3698    //      perf than sm90a.
3699    //
3700    // Build features control what kernels are *available*; this function
3701    // controls what's *picked* given the actual device.
3702    let (major, _minor) = stream.context().device().compute_capability()?;
3703
3704    if pref.allow_sm90a && cfg!(feature = "sm90a") && major >= 9 {
3705        return Ok(ArchSku::Sm90a);
3706    }
3707
3708    if cfg!(feature = "sm80") {
3709        // sm80 kernels are PTX-forward-compatible to anything sm_80+.
3710        if major >= 8 {
3711            return Ok(ArchSku::Sm80);
3712        }
3713        return Err(Error::Unsupported(
3714            "device compute capability < 8.0; sm_80 kernels won't run here",
3715        ));
3716    }
3717
3718    Err(Error::Unsupported(
3719        "no arch features enabled — build with --features sm80",
3720    ))
3721}
3722
3723// ============================================================================
3724// Grouped GEMM
3725// ============================================================================
3726//
3727// Architecture (per Fuel team's design review):
3728//   - `GroupedGemmPlan` holds host-only selection metadata (SKU, schedule
3729//     mode, epilogue kind). Cheap to clone, no device allocations.
3730//   - `prepare()` packs per-group host arrays (problem_sizes, ptr arrays,
3731//     ld arrays) and computes the threadblock count + scratch size for the
3732//     specific problem set. Returns a `PreparedGroupedGemm`.
3733//   - `PreparedGroupedGemm::workspace_size()` reports total bytes needed
3734//     (metadata layout + CUTLASS internal scratch, with alignment padding).
3735//   - `PreparedGroupedGemm::run(stream, workspace)` uploads the host
3736//     metadata to the start of the workspace via async H2D, computes
3737//     workspace pointer offsets, and launches the kernel.
3738//
3739// Workspace layout (caller-supplied; baracuda-cutlass owns this):
3740//
3741//     0                              metadata_end                  total
3742//     |  problem_sizes               |    pad   | CUTLASS scratch  |
3743//     |  ptr_a, ptr_b, ptr_c, ptr_d  |          |                  |
3744//     |  lda,   ldb,   ldc,   ldd    |          |                  |
3745//
3746// All v0 limitations:
3747//   - All groups must share the same `(alpha, beta)` epilogue params.
3748//   - All groups must consistently have `c = None` or `c = Some(_)`.
3749//   - Identity epilogue only (Bias deferred per Fuel team roadmap).
3750
/// Bytes per packed problem-size entry: one `[i32; 3]` holding `(m, n, k)`.
const COORD_BYTES: usize = 12;
/// Bytes per packed device pointer (stored as a `u64`).
const PTR_BYTES: usize = 8;
/// Bytes per packed leading dimension (stored as an `i64`).
const LD_BYTES: usize = 8;
/// Alignment of the scratch region that follows the packed metadata.
/// CUTLASS internal scratch wants ≥128B; 256 is safe.
const SCRATCH_ALIGN: usize = 256;
3755
/// Round `x` up to the next multiple of `align`.
///
/// Uses the bit-mask form, so `align` is expected to be a power of two.
#[inline]
fn align_up(x: usize, align: usize) -> usize {
    let bumped = x + align - 1;
    let keep_mask = !(align - 1);
    bumped & keep_mask
}
3760
/// Byte offsets for each metadata array within the caller-supplied workspace.
///
/// The arrays are laid out back to back in the order of the fields
/// below, followed by alignment padding and the scratch region. Every
/// offset is relative to the start of the workspace.
#[derive(Copy, Clone, Debug)]
struct MetadataLayout {
    /// Start of the packed `(m, n, k)` problem sizes (always 0).
    problem_sizes_offset: usize,
    /// Start of the per-group A pointer array.
    ptr_a_offset: usize,
    /// Start of the per-group B pointer array.
    ptr_b_offset: usize,
    /// Start of the per-group C pointer array.
    ptr_c_offset: usize,
    /// Start of the per-group D pointer array.
    ptr_d_offset: usize,
    /// Start of the per-group A leading-dimension array.
    lda_offset: usize,
    /// Start of the per-group B leading-dimension array.
    ldb_offset: usize,
    /// Start of the per-group C leading-dimension array.
    ldc_offset: usize,
    /// Start of the per-group D leading-dimension array.
    ldd_offset: usize,
    /// First byte past the packed metadata, before scratch alignment.
    metadata_end: usize,
    /// Aligned start of CUTLASS internal scratch.
    scratch_offset: usize,
    /// Total workspace bytes needed.
    total_workspace_bytes: usize,
}
3780
3781impl MetadataLayout {
3782    fn compute(group_count: usize, scratch_bytes: usize) -> Self {
3783        let mut off = 0usize;
3784        let problem_sizes_offset = off;
3785        off += COORD_BYTES * group_count;
3786        off = align_up(off, 8);
3787
3788        let ptr_a_offset = off;
3789        off += PTR_BYTES * group_count;
3790        let ptr_b_offset = off;
3791        off += PTR_BYTES * group_count;
3792        let ptr_c_offset = off;
3793        off += PTR_BYTES * group_count;
3794        let ptr_d_offset = off;
3795        off += PTR_BYTES * group_count;
3796        let lda_offset = off;
3797        off += LD_BYTES * group_count;
3798        let ldb_offset = off;
3799        off += LD_BYTES * group_count;
3800        let ldc_offset = off;
3801        off += LD_BYTES * group_count;
3802        let ldd_offset = off;
3803        off += LD_BYTES * group_count;
3804        let metadata_end = off;
3805
3806        let scratch_offset = align_up(metadata_end, SCRATCH_ALIGN);
3807        let total_workspace_bytes = scratch_offset + scratch_bytes;
3808
3809        Self {
3810            problem_sizes_offset,
3811            ptr_a_offset,
3812            ptr_b_offset,
3813            ptr_c_offset,
3814            ptr_d_offset,
3815            lda_offset,
3816            ldb_offset,
3817            ldc_offset,
3818            ldd_offset,
3819            metadata_end,
3820            scratch_offset,
3821            total_workspace_bytes,
3822        }
3823    }
3824}
3825
/// Plan for a grouped (per-problem variable shape) GEMM launch.
///
/// Cheap host-only struct — selection metadata + a cloned [`Context`]
/// handle so [`prepare`](Self::prepare) can allocate pinned host memory
/// without a stream argument. The cloned context is `Arc`-backed in
/// baracuda-driver, so cloning is cheap and the plan stays `Send + Sync`.
///
/// Use [`prepare`](Self::prepare) to bind a concrete slice of
/// [`GroupedProblem`]s to this plan and produce a [`PreparedGroupedGemm`]
/// that owns pinned host scratch and can launch capture-safely.
#[derive(Debug)]
pub struct GroupedGemmPlan<T: CutlassElement> {
    /// Identity of the kernel chosen at selection time.
    sku: GemmSku,
    /// Schedule mode taken from the caller's `GroupedPlanPreference`.
    schedule: GroupedScheduleMode,
    /// Context handle cloned from the stream passed to `select`; used
    /// by `prepare` to allocate the pinned metadata buffer.
    context: Context,
    /// Ties the plan to its element type without storing a `T`.
    _element: PhantomData<T>,
}
3843
3844impl<T: CutlassElement> GroupedGemmPlan<T> {
3845    /// Pick a grouped-GEMM kernel for the given epilogue and preferences.
3846    ///
3847    /// v0 supports only [`EpilogueKind::Identity`]. Selection arch follows
3848    /// the same device-cap-aware logic as [`GemmPlan::select`].
3849    pub fn select(
3850        stream: &Stream,
3851        epilogue: EpilogueKind,
3852        pref: GroupedPlanPreference,
3853    ) -> Result<Self> {
3854        if epilogue != EpilogueKind::Identity {
3855            return Err(Error::Unsupported(
3856                "v0 grouped GEMM supports only EpilogueKind::Identity",
3857            ));
3858        }
3859
3860        let dummy_desc = GemmDescriptor {
3861            m: 1,
3862            n: 1,
3863            k: 1,
3864            layout: LayoutSku::Rcr,
3865            epilogue,
3866        };
3867        let arch = pick_arch(stream, &dummy_desc, pref.base)?;
3868        let sku = GemmSku {
3869            arch,
3870            layout: LayoutSku::Rcr,
3871            epilogue,
3872            element: T::KIND,
3873            bias_element: None,
3874        };
3875        Ok(Self {
3876            sku,
3877            schedule: pref.schedule,
3878            context: stream.context().clone(),
3879            _element: PhantomData,
3880        })
3881    }
3882
3883    /// Identity of the kernel this plan chose.
3884    pub fn sku(&self) -> GemmSku {
3885        self.sku
3886    }
3887
3888    /// Numerical guarantees this plan's kernel provides.
3889    ///
3890    /// See [`GemmPlan::precision_guarantee`] for usage.
3891    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
3892        self.sku.precision_guarantee()
3893    }
3894
3895    /// Schedule mode this plan was selected with.
3896    pub fn schedule(&self) -> GroupedScheduleMode {
3897        self.schedule
3898    }
3899
3900    /// Bind a concrete set of per-group problems to this plan.
3901    ///
3902    /// Performs host-side validation, queries CUTLASS for the threadblock
3903    /// count and scratch-bytes requirement, and packs the per-group host
3904    /// arrays. The returned [`PreparedGroupedGemm`] holds host-side
3905    /// metadata only — no device allocations, and crucially **no Rust
3906    /// borrow on the input `groups` slice**: device pointers are extracted
3907    /// into pinned memory during this call. The caller is free to drop
3908    /// `groups` immediately after; the underlying device buffers must be
3909    /// kept alive for as long as the prepared plan (or any captured graph
3910    /// referencing it) is in use.
3911    pub fn prepare<'a, 'g>(
3912        &'a self,
3913        groups: &'g [GroupedProblem<'g, T>],
3914    ) -> Result<PreparedGroupedGemm<'a, T>> {
3915        if groups.is_empty() {
3916            return Err(Error::InvalidProblem("grouped GEMM requires at least one group"));
3917        }
3918
3919        // v0 invariants enforced here, before we touch CUTLASS:
3920        //   - All groups share the same (alpha, beta).
3921        //   - C presence (`Some` vs `None`) is consistent across groups.
3922        //   - Each group's shapes / strides are individually valid.
3923        let first_alpha = groups[0].alpha;
3924        let first_beta = groups[0].beta;
3925        let first_has_c = groups[0].c.is_some();
3926        for g in groups {
3927            if g.m <= 0 || g.n <= 0 || g.k <= 0 {
3928                return Err(Error::InvalidProblem("group M, N, K must all be positive"));
3929            }
3930            if g.a.rows != g.m || g.a.cols != g.k {
3931                return Err(Error::InvalidProblem("group A shape doesn't match (M, K)"));
3932            }
3933            if g.b.rows != g.k || g.b.cols != g.n {
3934                return Err(Error::InvalidProblem("group B shape doesn't match (K, N)"));
3935            }
3936            if g.d.rows != g.m || g.d.cols != g.n {
3937                return Err(Error::InvalidProblem("group D shape doesn't match (M, N)"));
3938            }
3939            if let Some(c) = &g.c {
3940                if c.rows != g.m || c.cols != g.n {
3941                    return Err(Error::InvalidProblem("group C shape doesn't match (M, N)"));
3942                }
3943            }
3944            if g.a.ld < g.k as i64 || g.b.ld < g.k as i64 || g.d.ld < g.n as i64 {
3945                return Err(Error::InvalidProblem("group leading dimension too small"));
3946            }
3947            if g.alpha != first_alpha {
3948                return Err(Error::Unsupported(
3949                    "v0 grouped GEMM requires all groups to share alpha",
3950                ));
3951            }
3952            if g.beta != first_beta {
3953                return Err(Error::Unsupported(
3954                    "v0 grouped GEMM requires all groups to share beta",
3955                ));
3956            }
3957            if g.c.is_some() != first_has_c {
3958                return Err(Error::Unsupported(
3959                    "v0 grouped GEMM requires all groups to consistently have c=None or c=Some",
3960                ));
3961            }
3962        }
3963
3964        // Pack host problem_sizes as [m,n,k, m,n,k, ...] for the C ABI's
3965        // sufficient / scratch_bytes / can_implement queries.
3966        let group_count = groups.len();
3967        let mut h_m: Vec<i32> = Vec::with_capacity(group_count);
3968        let mut h_n: Vec<i32> = Vec::with_capacity(group_count);
3969        let mut h_k: Vec<i32> = Vec::with_capacity(group_count);
3970        for g in groups {
3971            h_m.push(g.m);
3972            h_n.push(g.n);
3973            h_k.push(g.k);
3974        }
3975
3976        let kind = T::KIND;
3977        let group_count_i32 = group_count as i32;
3978
3979        // CUTLASS-level can_implement (per-group alignment / shape).
3980        let ci_status = match self.sku.arch {
3981            #[cfg(feature = "sm80")]
3982            ArchSku::Sm80 => unsafe {
3983                dispatch::grouped_gemm_rcr_sm80_can_implement(
3984                    kind,
3985                    h_m.as_ptr(),
3986                    h_n.as_ptr(),
3987                    h_k.as_ptr(),
3988                    group_count_i32,
3989                )
3990            },
3991            #[cfg(not(feature = "sm80"))]
3992            ArchSku::Sm80 => {
3993                return Err(Error::Unsupported(
3994                    "sm80 selected but the `sm80` feature isn't enabled",
3995                ));
3996            }
3997            ArchSku::Sm90a => {
3998                return Err(Error::Unsupported(
3999                    "sm90a grouped kernels not yet shipped (deferred until Hopper hardware available)",
4000                ));
4001            }
4002            ArchSku::Sm89 => {
4003                return Err(Error::Unsupported(
4004                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4005                ));
4006            }
4007        };
4008        status_to_result(ci_status)?;
4009
4010        // Threadblock count + CUTLASS scratch bytes.
4011        let threadblock_count = match self.sku.arch {
4012            #[cfg(feature = "sm80")]
4013            ArchSku::Sm80 => unsafe {
4014                dispatch::grouped_gemm_rcr_sm80_sufficient(
4015                    kind,
4016                    h_m.as_ptr(),
4017                    h_n.as_ptr(),
4018                    h_k.as_ptr(),
4019                    group_count_i32,
4020                )
4021            },
4022            #[cfg(not(feature = "sm80"))]
4023            ArchSku::Sm80 => 0,
4024            ArchSku::Sm90a => 0,
4025            ArchSku::Sm89 => 0,
4026        };
4027        if threadblock_count <= 0 {
4028            return Err(Error::CutlassInternal(threadblock_count));
4029        }
4030
4031        let scratch_bytes = match self.sku.arch {
4032            #[cfg(feature = "sm80")]
4033            ArchSku::Sm80 => unsafe {
4034                dispatch::grouped_gemm_rcr_sm80_scratch_bytes(
4035                    kind,
4036                    h_m.as_ptr(),
4037                    h_n.as_ptr(),
4038                    h_k.as_ptr(),
4039                    group_count_i32,
4040                    threadblock_count,
4041                )
4042            },
4043            #[cfg(not(feature = "sm80"))]
4044            ArchSku::Sm80 => 0,
4045            ArchSku::Sm90a => 0,
4046            ArchSku::Sm89 => 0,
4047        };
4048
4049        let layout = MetadataLayout::compute(group_count, scratch_bytes);
4050
4051        // Pack host metadata into a PINNED host buffer so the H2D in
4052        // `run()` is truly async (and therefore stream-capture-safe).
4053        // From pageable host memory, cuMemcpyHtoDAsync is implicitly
4054        // synchronizing and not capturable; pinned memory is required.
4055        let mut pinned: PinnedBuffer<u8> = PinnedBuffer::new(&self.context, layout.metadata_end)?;
4056
4057        // Collect device pointers and leading dimensions before borrowing
4058        // `pinned` mutably — keeps the mutable-borrow scope tight so we
4059        // can move `pinned` into the returned struct below.
4060        let ptr_a: Vec<u64> = groups.iter().map(|g| g.a.data.as_raw().0).collect();
4061        let ptr_b: Vec<u64> = groups.iter().map(|g| g.b.data.as_raw().0).collect();
4062        let ptr_d: Vec<u64> = groups.iter().map(|g| g.d.data.as_raw().0).collect();
4063        // For c = None, point ptr_c at the group's D buffer (kernel reads
4064        // a valid pointer but multiplies by beta = 0; same trick as
4065        // single-GEMM null-C handling).
4066        let ptr_c: Vec<u64> = groups
4067            .iter()
4068            .map(|g| {
4069                g.c.as_ref()
4070                    .map(|c| c.data.as_raw().0)
4071                    .unwrap_or_else(|| g.d.data.as_raw().0)
4072            })
4073            .collect();
4074        let lda: Vec<i64> = groups.iter().map(|g| g.a.ld).collect();
4075        let ldb: Vec<i64> = groups.iter().map(|g| g.b.ld).collect();
4076        let ldd: Vec<i64> = groups.iter().map(|g| g.d.ld).collect();
4077        let ldc: Vec<i64> = groups
4078            .iter()
4079            .map(|g| g.c.as_ref().map(|c| c.ld).unwrap_or(g.d.ld))
4080            .collect();
4081
4082        // Now write all metadata into the pinned slab. The borrow ends
4083        // when `host_packed` falls out of scope at the end of the block.
4084        {
4085            let host_packed: &mut [u8] = &mut pinned;
4086
4087            let mut p = layout.problem_sizes_offset;
4088            for g in groups {
4089                host_packed[p..p + 4].copy_from_slice(&g.m.to_ne_bytes());
4090                host_packed[p + 4..p + 8].copy_from_slice(&g.n.to_ne_bytes());
4091                host_packed[p + 8..p + 12].copy_from_slice(&g.k.to_ne_bytes());
4092                p += COORD_BYTES;
4093            }
4094
4095            let pack_ptrs = |dst: &mut [u8], offset: usize, ptrs: &[u64]| {
4096                let mut p = offset;
4097                for &val in ptrs {
4098                    dst[p..p + 8].copy_from_slice(&val.to_ne_bytes());
4099                    p += PTR_BYTES;
4100                }
4101            };
4102            pack_ptrs(host_packed, layout.ptr_a_offset, &ptr_a);
4103            pack_ptrs(host_packed, layout.ptr_b_offset, &ptr_b);
4104            pack_ptrs(host_packed, layout.ptr_c_offset, &ptr_c);
4105            pack_ptrs(host_packed, layout.ptr_d_offset, &ptr_d);
4106
4107            let pack_lds = |dst: &mut [u8], offset: usize, lds: &[i64]| {
4108                let mut p = offset;
4109                for &val in lds {
4110                    dst[p..p + 8].copy_from_slice(&val.to_ne_bytes());
4111                    p += LD_BYTES;
4112                }
4113            };
4114            pack_lds(host_packed, layout.lda_offset, &lda);
4115            pack_lds(host_packed, layout.ldb_offset, &ldb);
4116            pack_lds(host_packed, layout.ldc_offset, &ldc);
4117            pack_lds(host_packed, layout.ldd_offset, &ldd);
4118        }
4119
4120        // Host-side problem_sizes copy CUTLASS reads at run time (the
4121        // device copy from `pinned` is what the kernel actually
4122        // dereferences; this Vec just gives CUTLASS a stable host pointer
4123        // for its own internal tile-schedule math).
4124        let mut host_problem_sizes: Vec<i32> = Vec::with_capacity(group_count * 3);
4125        for g in groups {
4126            host_problem_sizes.push(g.m);
4127            host_problem_sizes.push(g.n);
4128            host_problem_sizes.push(g.k);
4129        }
4130
4131        let beta_eff = if first_has_c { first_beta } else { <T::Scalar as Default>::default() };
4132
4133        // Grouped GEMM v0 ships only Rcr × {F16, Bf16} (both f32-scalar);
4134        // the to_f32 conversion is identity. The select gate rejects
4135        // F64/F32Strict before we reach here.
4136        Ok(PreparedGroupedGemm {
4137            plan: self,
4138            pinned,
4139            host_problem_sizes,
4140            layout,
4141            threadblock_count,
4142            alpha: first_alpha.to_f32(),
4143            beta: beta_eff.to_f32(),
4144            _element: PhantomData,
4145        })
4146    }
4147}
4148
/// A [`GroupedGemmPlan`] bound to a concrete set of per-group problems.
///
/// Owns a [`PinnedBuffer<u8>`] holding the packed metadata (problem
/// sizes, pointer arrays, leading dimensions). Pinned host memory is
/// what makes the H2D inside [`run`](Self::run) truly async — and
/// therefore safely capturable into a CUDA graph. Owns no device memory;
/// the caller supplies that via [`Workspace::Borrowed`] at run time.
///
/// # Lifetime contract
///
/// `PreparedGroupedGemm` extracts raw device pointers from the input
/// [`GroupedProblem`] slice during [`prepare`](GroupedGemmPlan::prepare)
/// and stores them in pinned memory — it does **not** hold a Rust borrow
/// on the input buffers afterwards. This is required for stream capture:
/// the captured graph references the pinned buffer (for the metadata
/// H2D) and the device buffers (via the pointer arrays) by raw address,
/// not by Rust lifetime. The caller must therefore keep both this
/// `PreparedGroupedGemm` and the underlying device buffers alive for as
/// long as any captured graph that references them is in use.
///
/// In practice the pattern is: build groups, call `prepare`, capture
/// into a graph, then keep `PreparedGroupedGemm` plus the input/output
/// device buffers alive for the lifetime of the captured graph.
#[derive(Debug)]
pub struct PreparedGroupedGemm<'a, T: CutlassElement> {
    /// The plan this prepared launch was created from.
    plan: &'a GroupedGemmPlan<T>,
    /// Pinned host scratch holding all packed metadata. The H2D in `run`
    /// reads from this buffer; pinned memory means the copy is truly
    /// async + capturable on the user's stream.
    pinned: PinnedBuffer<u8>,
    /// Host-side copy of the problem sizes, flattened as
    /// `[m, n, k, m, n, k, ...]` with one triple per group.
    host_problem_sizes: Vec<i32>,
    /// Byte offsets of the metadata arrays and scratch region within
    /// the device workspace.
    layout: MetadataLayout,
    /// Threadblock count returned by the kernel's `sufficient` query
    /// during `prepare` (always positive).
    threadblock_count: i32,
    /// Epilogue `alpha` shared by all groups, converted to `f32`.
    alpha: f32,
    /// Epilogue `beta` shared by all groups, converted to `f32`; the
    /// scalar type's default value when the groups carry no `C` operand.
    beta: f32,
    /// Ties the prepared launch to its element type without storing a `T`.
    _element: PhantomData<T>,
}
4186
4187impl<'a, T: CutlassElement> PreparedGroupedGemm<'a, T> {
    /// Total bytes of device workspace this plan needs at `run` time.
    ///
    /// Includes both the packed metadata layout and CUTLASS's internal
    /// scratch tail with alignment padding between them.
    ///
    /// [`run`](Self::run) rejects any workspace shorter than this value
    /// with [`Error::WorkspaceTooSmall`].
    pub fn workspace_size(&self) -> usize {
        self.layout.total_workspace_bytes
    }
4195
    /// Identity of the kernel this plan chose (forwarded from the parent
    /// [`GroupedGemmPlan`]).
    ///
    /// Returned by value: this simply reads the `sku` field recorded on
    /// the parent plan.
    pub fn sku(&self) -> GemmSku {
        self.plan.sku
    }
4201
    /// Group count this plan was prepared for.
    ///
    /// Derived from the host-side problem-size array, which holds three
    /// `i32` entries (`m`, `n`, `k`) per group — hence the division by 3.
    pub fn group_count(&self) -> usize {
        self.host_problem_sizes.len() / 3
    }
4206
4207    /// Launch the grouped GEMM.
4208    ///
4209    /// Uploads the packed metadata to the start of `workspace` via async
4210    /// H2D on `stream`, then enqueues the grouped kernel using the
4211    /// remainder of the workspace as CUTLASS internal scratch.
4212    pub fn run(&self, stream: &Stream, workspace: Workspace<'_>) -> Result<()> {
4213        let needed = self.workspace_size();
4214        let workspace_slice = match workspace {
4215            Workspace::None => {
4216                return Err(Error::WorkspaceTooSmall { needed, got: 0 });
4217            }
4218            Workspace::Borrowed(slice) => {
4219                if slice.len() < needed {
4220                    return Err(Error::WorkspaceTooSmall {
4221                        needed,
4222                        got: slice.len(),
4223                    });
4224                }
4225                slice
4226            }
4227        };
4228
4229        let workspace_base = workspace_slice.as_raw().0;
4230
4231        // Single async H2D from pinned host memory into the workspace
4232        // prefix. Because the source is pinned (allocated in `prepare`),
4233        // this copy is truly async and capture-safe — wrapping the
4234        // surrounding launch in `cuStreamBeginCapture` / `cuStreamEndCapture`
4235        // produces a graph that replays correctly.
4236        {
4237            let mut workspace_for_meta = workspace_slice;
4238            let metadata_dst = workspace_for_meta.slice_mut(0..self.layout.metadata_end);
4239            metadata_dst.copy_from_host_async(&self.pinned, stream)?;
4240        }
4241
4242        // Compute device pointers via base + offset arithmetic. The C
4243        // function dereferences these as pointer arrays.
4244        let off = |o: usize| (workspace_base + o as u64) as *const c_void;
4245        let off_mut = |o: usize| (workspace_base + o as u64) as *mut c_void;
4246        let d_problem_sizes = off(self.layout.problem_sizes_offset);
4247        let d_ptr_a = off(self.layout.ptr_a_offset);
4248        let d_ptr_b = off(self.layout.ptr_b_offset);
4249        let d_ptr_c = off(self.layout.ptr_c_offset);
4250        let d_ptr_d = off_mut(self.layout.ptr_d_offset);
4251        let d_lda = off(self.layout.lda_offset);
4252        let d_ldb = off(self.layout.ldb_offset);
4253        let d_ldc = off(self.layout.ldc_offset);
4254        let d_ldd = off(self.layout.ldd_offset);
4255        let scratch_ptr = off_mut(self.layout.scratch_offset);
4256        let scratch_bytes = self.layout.total_workspace_bytes - self.layout.scratch_offset;
4257
4258        let h_problem_sizes = self.host_problem_sizes.as_ptr() as *const c_void;
4259        let stream_raw = stream.as_raw();
4260        let group_count = self.group_count() as i32;
4261
4262        let status = match self.plan.sku.arch {
4263            #[cfg(feature = "sm80")]
4264            ArchSku::Sm80 => unsafe {
4265                dispatch::grouped_gemm_rcr_sm80_run(
4266                    T::KIND,
4267                    group_count,
4268                    self.threadblock_count,
4269                    d_problem_sizes,
4270                    d_ptr_a,
4271                    d_ptr_b,
4272                    d_ptr_c,
4273                    d_ptr_d,
4274                    d_lda,
4275                    d_ldb,
4276                    d_ldc,
4277                    d_ldd,
4278                    h_problem_sizes,
4279                    self.alpha,
4280                    self.beta,
4281                    scratch_ptr,
4282                    scratch_bytes,
4283                    stream_raw,
4284                )
4285            },
4286            #[cfg(not(feature = "sm80"))]
4287            ArchSku::Sm80 => {
4288                return Err(Error::Unsupported(
4289                    "sm80 selected but the `sm80` feature isn't enabled",
4290                ));
4291            }
4292            ArchSku::Sm90a => {
4293                return Err(Error::Unsupported(
4294                    "sm90a grouped kernels not yet shipped",
4295                ));
4296            }
4297            ArchSku::Sm89 => {
4298                return Err(Error::Unsupported(
4299                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4300                ));
4301            }
4302        };
4303
4304        status_to_result(status)
4305    }
4306}
4307
4308// ============================================================================
4309// IntGemmPlan — int8 GEMM (Phase 2)
4310// ============================================================================
4311//
4312// Sibling to [`GemmPlan`] for the integer GEMM family. The split exists
4313// because the kernel-level dispatch, accumulator (int32), and epilogue
4314// templates (`LinearCombinationClamp` / `LinearCombinationBiasElementwise`
4315// with `ElementCompute = float`) differ enough from the float family
4316// that mixing them through a single generic plan would smear the
4317// argument types of every helper.
4318//
4319// API surface mirrors [`GemmPlan`]: `select`, `can_implement`,
4320// `workspace_size`, `run`, `sku`, `precision_guarantee`. Trait bounds:
4321// - `T: IntElement` constrains the matrix element to a kernel-supported
4322//   int dtype (today [`S8`](crate::S8) or [`U8`](crate::U8)).
4323// - `BT: BiasElement` constrains the bias element to a kernel-supported
4324//   bias dtype (today `f32` or `i32`). Defaults to `f32` since most
4325//   callers using a `Bias*` epilogue start there; `IntGemmPlan::<S8>`
4326//   shorthand picks the f32-bias path. Override explicitly with
4327//   `IntGemmPlan::<S8, i32>` for the int32-bias path.
4328//
4329// Layout coverage: today's int8 kernels are RCR-only. Selecting
4330// [`LayoutSku::Rrr`] returns [`Error::Unsupported`] at `select` time —
4331// CUTLASS 4.2.0 has no warp-level iterator for the 8-bit `Congruous`
4332// shared-memory layout the row-major B operand would need. A follow-up
4333// release will vendor the missing specialization.
4334
/// Plan handle for an int8 GEMM kernel selection.
///
/// See the comment block preceding this definition in the source and
/// [`IntGemmDescriptor`] / [`IntGemmArgs`] for the full API contract.
///
/// # Type parameters
///
/// - `T`: matrix element type, constrained by [`IntElement`].
/// - `BT`: bias element type, constrained by [`BiasElement`]. Defaults
///   to `f32`, so `IntGemmPlan::<S8>` picks the f32-bias path.
#[derive(Debug)]
pub struct IntGemmPlan<T: IntElement, BT: BiasElement = f32> {
    /// Copy of the descriptor handed to `select`.
    desc: IntGemmDescriptor,
    /// Kernel identity assembled by `select`.
    sku: GemmSku,
    /// Marker tying the plan to the matrix element type `T`.
    _element: PhantomData<T>,
    /// Marker tying the plan to the bias element type `BT`.
    _bias_element: PhantomData<BT>,
}
4346
4347impl<T: IntElement, BT: BiasElement> IntGemmPlan<T, BT> {
4348    /// Pick an int-GEMM kernel for `desc`.
4349    ///
4350    /// Returns [`Error::Unsupported`] for `LayoutSku::Rrr` — see the
4351    /// module-level docs above for the CUTLASS upstream limitation.
4352    /// `BT` is only meaningful for the `Bias*` epilogue variants;
4353    /// `EpilogueKind::Identity` ignores it.
4354    pub fn select(
4355        stream: &Stream,
4356        desc: &IntGemmDescriptor,
4357        pref: PlanPreference,
4358    ) -> Result<Self> {
4359        check_int_descriptor(desc)?;
4360        let arch = pick_int_arch(stream, pref)?;
4361        // RCR-only on int8 today. The descriptor-level check rejects
4362        // any non-RCR layout up front so callers see the error at
4363        // plan-creation time rather than at launch.
4364        if !matches!(desc.layout, LayoutSku::Rcr) {
4365            return Err(Error::Unsupported(
4366                "int8 GEMM kernels are RCR-only in this release \
4367                 (CUTLASS 4.2.0 lacks 8-bit `TensorOpMultiplicandCongruous` \
4368                 warp iterators for RRR / row-major-B layout)",
4369            ));
4370        }
4371        // `bias_element` is meaningful only for the bias-family
4372        // epilogues; Identity int kernels carry `None` to match the
4373        // float-family convention.
4374        let bias_element = if desc.epilogue.requires_bias() {
4375            Some(BT::KIND)
4376        } else {
4377            None
4378        };
4379        let sku = GemmSku {
4380            arch,
4381            layout: desc.layout,
4382            epilogue: desc.epilogue,
4383            element: T::KIND,
4384            bias_element,
4385        };
4386        Ok(Self {
4387            desc: *desc,
4388            sku,
4389            _element: PhantomData,
4390            _bias_element: PhantomData,
4391        })
4392    }
4393
4394    /// Validate that this plan can launch with `args`.
4395    ///
4396    /// Two-stage check identical in shape to [`GemmPlan::can_implement`].
4397    pub fn can_implement(&self, args: &IntGemmArgs<'_, T, BT>) -> Result<()> {
4398        check_int_args(&self.desc, args)?;
4399
4400        let a_ptr = args.a.data.as_raw().0 as *const c_void;
4401        let b_ptr = args.b.data.as_raw().0 as *const c_void;
4402        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
4403        let (c_ptr, ldc) = match &args.c {
4404            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
4405            None => (core::ptr::null(), 0i64),
4406        };
4407        let bias_ptr = args
4408            .bias
4409            .as_ref()
4410            .map(|b| b.data.as_raw().0 as *const c_void)
4411            .unwrap_or(core::ptr::null());
4412
4413        let bias_family = self.sku.epilogue.requires_bias();
4414        let status = match (self.sku.arch, bias_family) {
4415            #[cfg(feature = "sm80")]
4416            (ArchSku::Sm80, false) => unsafe {
4417                dispatch::int_gemm_rcr_sm80_can_implement(
4418                    self.sku.layout,
4419                    T::KIND,
4420                    self.desc.m, self.desc.n, self.desc.k,
4421                    a_ptr, args.a.ld,
4422                    b_ptr, args.b.ld,
4423                    c_ptr, ldc,
4424                    d_ptr, args.d.ld,
4425                )
4426            },
4427            #[cfg(feature = "sm80")]
4428            (ArchSku::Sm80, true) => unsafe {
4429                dispatch::int_gemm_bias_rcr_sm80_can_implement(
4430                    self.sku.layout,
4431                    T::KIND,
4432                    self.sku.epilogue,
4433                    BT::KIND,
4434                    self.desc.m, self.desc.n, self.desc.k,
4435                    a_ptr, args.a.ld,
4436                    b_ptr, args.b.ld,
4437                    c_ptr, ldc,
4438                    d_ptr, args.d.ld,
4439                    bias_ptr,
4440                )
4441            },
4442            #[cfg(not(feature = "sm80"))]
4443            (ArchSku::Sm80, _) => {
4444                return Err(Error::Unsupported(
4445                    "sm80 selected but the `sm80` feature isn't enabled",
4446                ));
4447            }
4448            (ArchSku::Sm90a, _) => {
4449                return Err(Error::Unsupported(
4450                    "sm90a int8 kernels not yet shipped",
4451                ));
4452            }
4453            (ArchSku::Sm89, _) => {
4454                return Err(Error::Unsupported(
4455                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4456                ));
4457            }
4458        };
4459
4460        status_to_result(status)
4461    }
4462
4463    /// Bytes of device scratch this plan needs at `run` time.
4464    pub fn workspace_size(&self) -> usize {
4465        let bias_family = self.sku.epilogue.requires_bias();
4466        match (self.sku.arch, bias_family) {
4467            #[cfg(feature = "sm80")]
4468            (ArchSku::Sm80, false) => dispatch::int_gemm_rcr_sm80_workspace_size(
4469                self.sku.layout,
4470                T::KIND,
4471                self.desc.m, self.desc.n, self.desc.k,
4472            ),
4473            #[cfg(feature = "sm80")]
4474            (ArchSku::Sm80, true) => dispatch::int_gemm_bias_rcr_sm80_workspace_size(
4475                self.sku.layout,
4476                T::KIND,
4477                self.sku.epilogue,
4478                BT::KIND,
4479                self.desc.m, self.desc.n, self.desc.k,
4480            ),
4481            #[cfg(not(feature = "sm80"))]
4482            (ArchSku::Sm80, _) => 0,
4483            (ArchSku::Sm90a, _) => 0,
4484            (ArchSku::Sm89, _) => 0,
4485        }
4486    }
4487
    /// Identity of the kernel this plan chose.
    ///
    /// Returns, by value, the [`GemmSku`] stored on the plan.
    pub fn sku(&self) -> GemmSku {
        self.sku
    }
4492
    /// Numerical guarantees this plan's kernel provides.
    ///
    /// Forwards to `precision_guarantee()` on the plan's [`GemmSku`].
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee()
    }
4497
4498    /// Launch the kernel.
4499    pub fn run(
4500        &self,
4501        stream: &Stream,
4502        workspace: Workspace<'_>,
4503        args: IntGemmArgs<'_, T, BT>,
4504    ) -> Result<()> {
4505        self.can_implement(&args)?;
4506
4507        let needed = self.workspace_size();
4508        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
4509            Workspace::None => {
4510                if needed != 0 {
4511                    return Err(Error::WorkspaceTooSmall { needed, got: 0 });
4512                }
4513                (core::ptr::null_mut(), 0)
4514            }
4515            Workspace::Borrowed(slice) => {
4516                if slice.len() < needed {
4517                    return Err(Error::WorkspaceTooSmall {
4518                        needed,
4519                        got: slice.len(),
4520                    });
4521                }
4522                (slice.as_raw().0 as *mut c_void, slice.len())
4523            }
4524        };
4525
4526        let a_ptr = args.a.data.as_raw().0 as *const c_void;
4527        let b_ptr = args.b.data.as_raw().0 as *const c_void;
4528        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
4529        let (c_ptr, ldc) = match &args.c {
4530            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
4531            None => (core::ptr::null(), 0i64),
4532        };
4533        let bias_ptr = args
4534            .bias
4535            .as_ref()
4536            .map(|b| b.data.as_raw().0 as *const c_void)
4537            .unwrap_or(core::ptr::null());
4538        // Match GemmPlan: when c = None, force beta = 0 at the safe layer
4539        // so the C-substituted-as-D fallback inside the kernel doesn't
4540        // fold the previous D contents into the result.
4541        let beta_eff: f32 = if args.c.is_some() { args.beta } else { 0.0 };
4542        let stream_raw = stream.as_raw();
4543
4544        let bias_family = self.sku.epilogue.requires_bias();
4545        let status = match (self.sku.arch, bias_family) {
4546            #[cfg(feature = "sm80")]
4547            (ArchSku::Sm80, false) => unsafe {
4548                dispatch::int_gemm_rcr_sm80_run(
4549                    self.sku.layout,
4550                    T::KIND,
4551                    self.desc.m, self.desc.n, self.desc.k,
4552                    a_ptr, args.a.ld,
4553                    b_ptr, args.b.ld,
4554                    c_ptr, ldc,
4555                    d_ptr, args.d.ld,
4556                    args.alpha,
4557                    beta_eff,
4558                    ws_ptr, ws_bytes, stream_raw,
4559                )
4560            },
4561            #[cfg(feature = "sm80")]
4562            (ArchSku::Sm80, true) => unsafe {
4563                dispatch::int_gemm_bias_rcr_sm80_run(
4564                    self.sku.layout,
4565                    T::KIND,
4566                    self.sku.epilogue,
4567                    BT::KIND,
4568                    self.desc.m, self.desc.n, self.desc.k,
4569                    a_ptr, args.a.ld,
4570                    b_ptr, args.b.ld,
4571                    c_ptr, ldc,
4572                    d_ptr, args.d.ld,
4573                    bias_ptr,
4574                    args.alpha,
4575                    beta_eff,
4576                    ws_ptr, ws_bytes, stream_raw,
4577                )
4578            },
4579            #[cfg(not(feature = "sm80"))]
4580            (ArchSku::Sm80, _) => {
4581                return Err(Error::Unsupported(
4582                    "sm80 selected but the `sm80` feature isn't enabled",
4583                ));
4584            }
4585            (ArchSku::Sm90a, _) => {
4586                return Err(Error::Unsupported("sm90a int8 kernels not yet shipped"));
4587            }
4588            (ArchSku::Sm89, _) => {
4589                return Err(Error::Unsupported(
4590                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4591                ));
4592            }
4593        };
4594
4595        status_to_result(status)
4596    }
4597}
4598
4599// Internal helpers for int-GEMM validation. Structure mirrors
4600// `check_descriptor` / `check_args` for the float family.
4601
4602fn check_int_descriptor(desc: &IntGemmDescriptor) -> Result<()> {
4603    if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
4604        return Err(Error::InvalidProblem("M, N, K must all be positive"));
4605    }
4606    Ok(())
4607}
4608
4609fn check_int_args<T: IntElement, BT: BiasElement>(
4610    desc: &IntGemmDescriptor,
4611    args: &IntGemmArgs<'_, T, BT>,
4612) -> Result<()> {
4613    // Epilogue / bias must agree.
4614    match (desc.epilogue.requires_bias(), &args.bias) {
4615        (false, Some(_)) => {
4616            return Err(Error::InvalidProblem(
4617                "args.bias must be None when descriptor.epilogue is Identity",
4618            ));
4619        }
4620        (true, None) => {
4621            return Err(Error::InvalidProblem(
4622                "args.bias is required when descriptor.epilogue is in the Bias family \
4623                 (Bias / BiasRelu / BiasGelu / BiasSilu)",
4624            ));
4625        }
4626        (false, None) | (true, Some(_)) => {}
4627    }
4628    if let Some(bias) = &args.bias {
4629        if bias.len != desc.n {
4630            return Err(Error::InvalidProblem("bias vector length must equal N"));
4631        }
4632        if bias.stride != 1 {
4633            return Err(Error::Unsupported(
4634                "bias vector must be contiguous (stride 1) — strided bias not supported",
4635            ));
4636        }
4637        if bias.data.len() < desc.n as usize {
4638            return Err(Error::BufferTooSmall {
4639                needed: desc.n as usize,
4640                got: bias.data.len(),
4641            });
4642        }
4643    }
4644    if args.a.rows != desc.m || args.a.cols != desc.k {
4645        return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
4646    }
4647    if args.b.rows != desc.k || args.b.cols != desc.n {
4648        return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
4649    }
4650    if args.d.rows != desc.m || args.d.cols != desc.n {
4651        return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
4652    }
4653    if let Some(c) = &args.c {
4654        if c.rows != desc.m || c.cols != desc.n {
4655            return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
4656        }
4657    }
4658    if args.a.ld < desc.k as i64 {
4659        return Err(Error::InvalidProblem("A leading dimension must be >= K"));
4660    }
4661    let b_min_ld = match desc.layout {
4662        LayoutSku::Rcr => desc.k as i64,
4663        LayoutSku::Rrr => desc.n as i64,
4664    };
4665    if args.b.ld < b_min_ld {
4666        return Err(Error::InvalidProblem(match desc.layout {
4667            LayoutSku::Rcr => "B leading dimension must be >= K (column-major Rcr layout)",
4668            LayoutSku::Rrr => "B leading dimension must be >= N (row-major Rrr layout)",
4669        }));
4670    }
4671    if args.d.ld < desc.n as i64 {
4672        return Err(Error::InvalidProblem("D leading dimension must be >= N"));
4673    }
4674    if let Some(c) = &args.c {
4675        if c.ld < desc.n as i64 {
4676            return Err(Error::InvalidProblem("C leading dimension must be >= N"));
4677        }
4678    }
4679    let need_a = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
4680        .ok_or(Error::InvalidProblem("A storage size overflow"))?;
4681    if args.a.data.len() < need_a {
4682        return Err(Error::BufferTooSmall {
4683            needed: need_a,
4684            got: args.a.data.len(),
4685        });
4686    }
4687    let need_b = match desc.layout {
4688        LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
4689        LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
4690    }
4691    .ok_or(Error::InvalidProblem("B storage size overflow"))?;
4692    if args.b.data.len() < need_b {
4693        return Err(Error::BufferTooSmall {
4694            needed: need_b,
4695            got: args.b.data.len(),
4696        });
4697    }
4698    let need_d = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
4699        .ok_or(Error::InvalidProblem("D storage size overflow"))?;
4700    if args.d.data.len() < need_d {
4701        return Err(Error::BufferTooSmall {
4702            needed: need_d,
4703            got: args.d.data.len(),
4704        });
4705    }
4706    if let Some(c) = &args.c {
4707        let need_c = min_elements_row_major(c.rows, c.cols, c.ld)
4708            .ok_or(Error::InvalidProblem("C storage size overflow"))?;
4709        if c.data.len() < need_c {
4710            return Err(Error::BufferTooSmall {
4711                needed: need_c,
4712                got: c.data.len(),
4713            });
4714        }
4715    }
4716    Ok(())
4717}
4718
4719fn pick_int_arch(stream: &Stream, pref: PlanPreference) -> Result<ArchSku> {
4720    // Int kernels currently only ship for sm_80. The selection logic is a
4721    // simpler shape of `pick_arch` — no sm_90a path because there are no
4722    // sm_90a int kernels.
4723    let (major, _minor) = stream.context().device().compute_capability()?;
4724    if pref.allow_sm90a && cfg!(feature = "sm90a") && major >= 9 {
4725        // Allowed but not yet shipped — fall through to sm_80.
4726    }
4727    if cfg!(feature = "sm80") {
4728        if major >= 8 {
4729            return Ok(ArchSku::Sm80);
4730        }
4731        return Err(Error::Unsupported(
4732            "device compute capability < 8.0; sm_80 int8 kernels won't run here",
4733        ));
4734    }
4735    Err(Error::Unsupported(
4736        "no arch features enabled — build with --features sm80",
4737    ))
4738}
4739
#[cfg(test)]
mod buffer_size_tests {
    //! Regression coverage for the buffer-size formulas corrected after
    //! the Fuel team review.
    //!
    //! The helpers used to compute `rows * ld` / `cols * ld`, which
    //! rejected valid padded slabs. The corrected minimum is
    //! `(major - 1) * ld + minor`; every test below pins that formula.

    use super::{min_elements_rcr_a, min_elements_rcr_b, min_elements_rcr_cd};

    #[test]
    fn rcr_a_tight_layout() {
        // A is [M=4, K=8] row-major with lda == K == 8 (dense):
        // (4 - 1) * 8 + 8 = 32 elements.
        let needed = min_elements_rcr_a(4, 8, 8);
        assert_eq!(needed, Some(32));
    }

    #[test]
    fn rcr_a_padded_layout_accepts_smaller_count() {
        // A is [M=4, K=8] row-major with lda padded to 16:
        // (4 - 1) * 16 + 8 = 56 elements.
        // The old rows*ld formula demanded 4*16 = 64, i.e. 8 too many.
        let needed = min_elements_rcr_a(4, 8, 16);
        assert_eq!(needed, Some(56));
    }

    #[test]
    fn rcr_b_tight_layout() {
        // B is [K=8, N=4] column-major with ldb == K == 8 (dense):
        // (4 - 1) * 8 + 8 = 32 elements.
        let needed = min_elements_rcr_b(8, 4, 8);
        assert_eq!(needed, Some(32));
    }

    #[test]
    fn rcr_b_padded_layout_accepts_smaller_count() {
        // B is [K=8, N=4] column-major with ldb padded to 16:
        // (4 - 1) * 16 + 8 = 56 elements (the old formula gave 4*16 = 64).
        let needed = min_elements_rcr_b(8, 4, 16);
        assert_eq!(needed, Some(56));
    }

    #[test]
    fn rcr_cd_tight_layout() {
        // C/D is [M=4, N=8] row-major with ld == N == 8 (dense):
        // (4 - 1) * 8 + 8 = 32 elements.
        let needed = min_elements_rcr_cd(4, 8, 8);
        assert_eq!(needed, Some(32));
    }

    #[test]
    fn rcr_cd_padded_layout_accepts_smaller_count() {
        // C/D is [M=4, N=8] row-major with ld padded to 16:
        // (4 - 1) * 16 + 8 = 56 elements (the old formula gave 4*16 = 64).
        let needed = min_elements_rcr_cd(4, 8, 16);
        assert_eq!(needed, Some(56));
    }

    #[test]
    fn single_row_matrix_does_not_underflow() {
        // With M=1 the (major - 1) term vanishes, so exactly K elements
        // are needed whatever the leading dimension is.
        let dense = min_elements_rcr_a(1, 8, 8);
        let heavily_padded = min_elements_rcr_a(1, 8, 256);
        assert_eq!(dense, Some(8));
        assert_eq!(heavily_padded, Some(8));
    }

    #[test]
    fn overflow_returns_none() {
        // Inputs chosen so the checked i64 multiplication overflows.
        let needed = min_elements_rcr_a(i32::MAX, 1, i64::MAX);
        assert_eq!(needed, None);
    }
}