1use core::ffi::c_void;
4use core::marker::PhantomData;
5
6use baracuda_driver::{Context, PinnedBuffer, Stream};
7use baracuda_kernels_types::BackendKind;
8
9use crate::error::{status_to_result, Error, Result};
10use crate::types::{
11 ArchSku, BatchedGemmArgs, BatchedGemmDescriptor, BiasElement, CutlassElement, ElementKind,
12 EpilogueKind, GemmArgs, GemmDescriptor, GemmSku, GroupedPlanPreference, GroupedProblem,
13 GroupedScheduleMode, IntElement, IntGemmArgs, IntGemmDescriptor, LayoutSku, PlanPreference,
14 PrecisionGuarantee, ScalarType, Workspace,
15};
16
17mod dispatch {
22 use super::{ElementKind, LayoutSku};
23 use core::ffi::c_void;
24
25 use super::EpilogueKind;
26
27 #[cfg(feature = "sm80")]
37 #[allow(clippy::too_many_arguments)]
38 pub(super) unsafe fn gemm_bias_sm80_run(
39 layout: LayoutSku,
40 kind: ElementKind,
41 epilogue: EpilogueKind,
42 m: i32,
43 n: i32,
44 k: i32,
45 a: *const c_void,
46 lda: i64,
47 b: *const c_void,
48 ldb: i64,
49 c: *const c_void,
50 ldc: i64,
51 d: *mut c_void,
52 ldd: i64,
53 bias: *const c_void,
54 alpha: f32,
55 beta: f32,
56 workspace: *mut c_void,
57 workspace_bytes: usize,
58 stream: *mut c_void,
59 ) -> i32 {
60 use baracuda_cutlass_kernels_sys as k_sys;
61 match (layout, kind, epilogue) {
62 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
63 k_sys::baracuda_cutlass_gemm_bias_f16_rcr_sm80_run(
64 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
65 bias, alpha, beta, workspace, workspace_bytes, stream,
66 )
67 },
68 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
69 k_sys::baracuda_cutlass_gemm_bias_bf16_rcr_sm80_run(
70 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
71 bias, alpha, beta, workspace, workspace_bytes, stream,
72 )
73 },
74 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
75 k_sys::baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_run(
76 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
77 bias, alpha, beta, workspace, workspace_bytes, stream,
78 )
79 },
80 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
81 k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_run(
82 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
83 bias, alpha, beta, workspace, workspace_bytes, stream,
84 )
85 },
86 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
87 k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_run(
88 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
89 bias, alpha, beta, workspace, workspace_bytes, stream,
90 )
91 },
92 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
93 k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_run(
94 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
95 bias, alpha, beta, workspace, workspace_bytes, stream,
96 )
97 },
98 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
99 k_sys::baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_run(
100 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
101 bias, alpha, beta, workspace, workspace_bytes, stream,
102 )
103 },
104 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
105 k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_run(
106 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
107 bias, alpha, beta, workspace, workspace_bytes, stream,
108 )
109 },
110 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
112 k_sys::baracuda_cutlass_gemm_bias_f16_rrr_sm80_run(
113 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
114 bias, alpha, beta, workspace, workspace_bytes, stream,
115 )
116 },
117 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
118 k_sys::baracuda_cutlass_gemm_bias_bf16_rrr_sm80_run(
119 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
120 bias, alpha, beta, workspace, workspace_bytes, stream,
121 )
122 },
123 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
124 k_sys::baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_run(
125 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
126 bias, alpha, beta, workspace, workspace_bytes, stream,
127 )
128 },
129 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
130 k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_run(
131 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
132 bias, alpha, beta, workspace, workspace_bytes, stream,
133 )
134 },
135 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
136 k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_run(
137 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
138 bias, alpha, beta, workspace, workspace_bytes, stream,
139 )
140 },
141 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
142 k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_run(
143 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
144 bias, alpha, beta, workspace, workspace_bytes, stream,
145 )
146 },
147 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
148 k_sys::baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_run(
149 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
150 bias, alpha, beta, workspace, workspace_bytes, stream,
151 )
152 },
153 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
154 k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_run(
155 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
156 bias, alpha, beta, workspace, workspace_bytes, stream,
157 )
158 },
159 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
161 k_sys::baracuda_cutlass_gemm_bias_tf32_rcr_sm80_run(
162 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
163 bias, alpha, beta, workspace, workspace_bytes, stream,
164 )
165 },
166 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
167 k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_run(
168 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
169 bias, alpha, beta, workspace, workspace_bytes, stream,
170 )
171 },
172 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
173 k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_run(
174 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
175 bias, alpha, beta, workspace, workspace_bytes, stream,
176 )
177 },
178 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
179 k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_run(
180 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
181 bias, alpha, beta, workspace, workspace_bytes, stream,
182 )
183 },
184 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
186 k_sys::baracuda_cutlass_gemm_bias_tf32_rrr_sm80_run(
187 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
188 bias, alpha, beta, workspace, workspace_bytes, stream,
189 )
190 },
191 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
192 k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_run(
193 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
194 bias, alpha, beta, workspace, workspace_bytes, stream,
195 )
196 },
197 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
198 k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_run(
199 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
200 bias, alpha, beta, workspace, workspace_bytes, stream,
201 )
202 },
203 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
204 k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_run(
205 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
206 bias, alpha, beta, workspace, workspace_bytes, stream,
207 )
208 },
209 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
211 k_sys::baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_run(
212 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
213 bias, alpha, beta, workspace, workspace_bytes, stream,
214 )
215 },
216 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
217 k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_run(
218 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
219 bias, alpha, beta, workspace, workspace_bytes, stream,
220 )
221 },
222 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
223 k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_run(
224 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
225 bias, alpha, beta, workspace, workspace_bytes, stream,
226 )
227 },
228 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
229 k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_run(
230 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
231 bias, alpha, beta, workspace, workspace_bytes, stream,
232 )
233 },
234 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
236 k_sys::baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_run(
237 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
238 bias, alpha, beta, workspace, workspace_bytes, stream,
239 )
240 },
241 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
242 k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_run(
243 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
244 bias, alpha, beta, workspace, workspace_bytes, stream,
245 )
246 },
247 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
248 k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_run(
249 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
250 bias, alpha, beta, workspace, workspace_bytes, stream,
251 )
252 },
253 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
254 k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_run(
255 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
256 bias, alpha, beta, workspace, workspace_bytes, stream,
257 )
258 },
259 _ => 3,
260 }
261 }
262
263 #[cfg(feature = "sm80")]
264 pub(super) fn gemm_bias_sm80_workspace_size(
265 layout: LayoutSku,
266 kind: ElementKind,
267 epilogue: EpilogueKind,
268 m: i32,
269 n: i32,
270 k: i32,
271 ) -> usize {
272 use baracuda_cutlass_kernels_sys as k_sys;
273 match (layout, kind, epilogue) {
274 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
275 k_sys::baracuda_cutlass_gemm_bias_f16_rcr_sm80_workspace_size(m, n, k)
276 },
277 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
278 k_sys::baracuda_cutlass_gemm_bias_bf16_rcr_sm80_workspace_size(m, n, k)
279 },
280 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
281 k_sys::baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_workspace_size(m, n, k)
282 },
283 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
284 k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_workspace_size(m, n, k)
285 },
286 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
287 k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_workspace_size(m, n, k)
288 },
289 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
290 k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_workspace_size(m, n, k)
291 },
292 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
293 k_sys::baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_workspace_size(m, n, k)
294 },
295 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
296 k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_workspace_size(m, n, k)
297 },
298 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
299 k_sys::baracuda_cutlass_gemm_bias_f16_rrr_sm80_workspace_size(m, n, k)
300 },
301 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
302 k_sys::baracuda_cutlass_gemm_bias_bf16_rrr_sm80_workspace_size(m, n, k)
303 },
304 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
305 k_sys::baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_workspace_size(m, n, k)
306 },
307 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
308 k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_workspace_size(m, n, k)
309 },
310 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
311 k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_workspace_size(m, n, k)
312 },
313 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
314 k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_workspace_size(m, n, k)
315 },
316 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
317 k_sys::baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_workspace_size(m, n, k)
318 },
319 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
320 k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_workspace_size(m, n, k)
321 },
322 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
323 k_sys::baracuda_cutlass_gemm_bias_tf32_rcr_sm80_workspace_size(m, n, k)
324 },
325 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
326 k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_workspace_size(m, n, k)
327 },
328 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
329 k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_workspace_size(m, n, k)
330 },
331 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
332 k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_workspace_size(m, n, k)
333 },
334 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
335 k_sys::baracuda_cutlass_gemm_bias_tf32_rrr_sm80_workspace_size(m, n, k)
336 },
337 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
338 k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_workspace_size(m, n, k)
339 },
340 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
341 k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_workspace_size(m, n, k)
342 },
343 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
344 k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_workspace_size(m, n, k)
345 },
346 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
347 k_sys::baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_workspace_size(m, n, k)
348 },
349 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
350 k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_workspace_size(m, n, k)
351 },
352 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
353 k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_workspace_size(m, n, k)
354 },
355 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
356 k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_workspace_size(m, n, k)
357 },
358 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
359 k_sys::baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_workspace_size(m, n, k)
360 },
361 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
362 k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_workspace_size(m, n, k)
363 },
364 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
365 k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_workspace_size(m, n, k)
366 },
367 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
368 k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_workspace_size(m, n, k)
369 },
370 _ => 0,
371 }
372 }
373
374 #[cfg(feature = "sm80")]
375 #[allow(clippy::too_many_arguments)]
376 pub(super) unsafe fn gemm_bias_sm80_can_implement(
377 layout: LayoutSku,
378 kind: ElementKind,
379 epilogue: EpilogueKind,
380 m: i32,
381 n: i32,
382 k: i32,
383 a: *const c_void,
384 lda: i64,
385 b: *const c_void,
386 ldb: i64,
387 c: *const c_void,
388 ldc: i64,
389 d: *mut c_void,
390 ldd: i64,
391 bias: *const c_void,
392 ) -> i32 {
393 use baracuda_cutlass_kernels_sys as k_sys;
394 match (layout, kind, epilogue) {
395 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
396 k_sys::baracuda_cutlass_gemm_bias_f16_rcr_sm80_can_implement(
397 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
398 )
399 },
400 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
401 k_sys::baracuda_cutlass_gemm_bias_bf16_rcr_sm80_can_implement(
402 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
403 )
404 },
405 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
406 k_sys::baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_can_implement(
407 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
408 )
409 },
410 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
411 k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_can_implement(
412 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
413 )
414 },
415 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
416 k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_can_implement(
417 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
418 )
419 },
420 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
421 k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_can_implement(
422 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
423 )
424 },
425 (LayoutSku::Rcr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
426 k_sys::baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_can_implement(
427 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
428 )
429 },
430 (LayoutSku::Rcr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
431 k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_can_implement(
432 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
433 )
434 },
435 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::Bias) => unsafe {
436 k_sys::baracuda_cutlass_gemm_bias_f16_rrr_sm80_can_implement(
437 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
438 )
439 },
440 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::Bias) => unsafe {
441 k_sys::baracuda_cutlass_gemm_bias_bf16_rrr_sm80_can_implement(
442 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
443 )
444 },
445 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasRelu) => unsafe {
446 k_sys::baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_can_implement(
447 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
448 )
449 },
450 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasRelu) => unsafe {
451 k_sys::baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_can_implement(
452 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
453 )
454 },
455 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasGelu) => unsafe {
456 k_sys::baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_can_implement(
457 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
458 )
459 },
460 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasGelu) => unsafe {
461 k_sys::baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_can_implement(
462 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
463 )
464 },
465 (LayoutSku::Rrr, ElementKind::F16, EpilogueKind::BiasSilu) => unsafe {
466 k_sys::baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_can_implement(
467 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
468 )
469 },
470 (LayoutSku::Rrr, ElementKind::Bf16, EpilogueKind::BiasSilu) => unsafe {
471 k_sys::baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_can_implement(
472 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
473 )
474 },
475 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
476 k_sys::baracuda_cutlass_gemm_bias_tf32_rcr_sm80_can_implement(
477 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
478 )
479 },
480 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
481 k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_can_implement(
482 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
483 )
484 },
485 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
486 k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_can_implement(
487 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
488 )
489 },
490 (LayoutSku::Rcr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
491 k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_can_implement(
492 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
493 )
494 },
495 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::Bias) => unsafe {
496 k_sys::baracuda_cutlass_gemm_bias_tf32_rrr_sm80_can_implement(
497 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
498 )
499 },
500 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasRelu) => unsafe {
501 k_sys::baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_can_implement(
502 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
503 )
504 },
505 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasGelu) => unsafe {
506 k_sys::baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_can_implement(
507 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
508 )
509 },
510 (LayoutSku::Rrr, ElementKind::F32, EpilogueKind::BiasSilu) => unsafe {
511 k_sys::baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_can_implement(
512 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
513 )
514 },
515 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
516 k_sys::baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_can_implement(
517 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
518 )
519 },
520 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
521 k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_can_implement(
522 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
523 )
524 },
525 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
526 k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_can_implement(
527 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
528 )
529 },
530 (LayoutSku::Rcr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
531 k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_can_implement(
532 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
533 )
534 },
535 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::Bias) => unsafe {
536 k_sys::baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_can_implement(
537 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
538 )
539 },
540 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasRelu) => unsafe {
541 k_sys::baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_can_implement(
542 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
543 )
544 },
545 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasGelu) => unsafe {
546 k_sys::baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_can_implement(
547 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
548 )
549 },
550 (LayoutSku::Rrr, ElementKind::F32Strict, EpilogueKind::BiasSilu) => unsafe {
551 k_sys::baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_can_implement(
552 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
553 )
554 },
555 _ => 3,
556 }
557 }
558
559 #[cfg(feature = "sm80")]
565 #[allow(clippy::too_many_arguments)]
566 pub(super) unsafe fn gemm_sm80_run(
567 layout: LayoutSku,
568 kind: ElementKind,
569 m: i32,
570 n: i32,
571 k: i32,
572 a: *const c_void,
573 lda: i64,
574 b: *const c_void,
575 ldb: i64,
576 c: *const c_void,
577 ldc: i64,
578 d: *mut c_void,
579 ldd: i64,
580 alpha: f32,
581 beta: f32,
582 workspace: *mut c_void,
583 workspace_bytes: usize,
584 stream: *mut c_void,
585 ) -> i32 {
586 use baracuda_cutlass_kernels_sys as k_sys;
587 match (layout, kind) {
588 (LayoutSku::Rcr, ElementKind::F16) => unsafe {
589 k_sys::baracuda_cutlass_gemm_f16_rcr_sm80_run(
590 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
591 alpha, beta, workspace, workspace_bytes, stream,
592 )
593 },
594 (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
595 k_sys::baracuda_cutlass_gemm_bf16_rcr_sm80_run(
596 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
597 alpha, beta, workspace, workspace_bytes, stream,
598 )
599 },
600 (LayoutSku::Rcr, ElementKind::F32) => unsafe {
601 k_sys::baracuda_cutlass_gemm_tf32_rcr_sm80_run(
602 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
603 alpha, beta, workspace, workspace_bytes, stream,
604 )
605 },
606 (LayoutSku::Rrr, ElementKind::F16) => unsafe {
607 k_sys::baracuda_cutlass_gemm_f16_rrr_sm80_run(
608 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
609 alpha, beta, workspace, workspace_bytes, stream,
610 )
611 },
612 (LayoutSku::Rrr, ElementKind::Bf16) => unsafe {
613 k_sys::baracuda_cutlass_gemm_bf16_rrr_sm80_run(
614 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
615 alpha, beta, workspace, workspace_bytes, stream,
616 )
617 },
618 (LayoutSku::Rrr, ElementKind::F32) => unsafe {
619 k_sys::baracuda_cutlass_gemm_tf32_rrr_sm80_run(
620 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
621 alpha, beta, workspace, workspace_bytes, stream,
622 )
623 },
624 (LayoutSku::Rcr, ElementKind::F32Strict) => unsafe {
625 k_sys::baracuda_cutlass_gemm_f32_simt_rcr_sm80_run(
626 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
627 alpha, beta, workspace, workspace_bytes, stream,
628 )
629 },
630 (LayoutSku::Rrr, ElementKind::F32Strict) => unsafe {
631 k_sys::baracuda_cutlass_gemm_f32_simt_rrr_sm80_run(
632 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
633 alpha, beta, workspace, workspace_bytes, stream,
634 )
635 },
636 (LayoutSku::Rcr, ElementKind::F64)
641 | (LayoutSku::Rrr, ElementKind::F64) => 3,
642 (_, ElementKind::S8) | (_, ElementKind::U8) | (_, ElementKind::I32)
648 | (_, ElementKind::I64)
649 | (_, ElementKind::Bool)
650 | (_, ElementKind::Fp8E4M3)
651 | (_, ElementKind::Fp8E5M2)
652 | (_, ElementKind::S4)
653 | (_, ElementKind::U4)
654 | (_, ElementKind::Bin)
655 | (_, ElementKind::Complex32)
656 | (_, ElementKind::Complex64) => 3,
657 }
658 }
659
660 #[cfg(feature = "sm80")]
661 pub(super) fn gemm_sm80_workspace_size(
662 layout: LayoutSku,
663 kind: ElementKind,
664 m: i32,
665 n: i32,
666 k: i32,
667 ) -> usize {
668 use baracuda_cutlass_kernels_sys as k_sys;
669 match (layout, kind) {
670 (LayoutSku::Rcr, ElementKind::F16) => unsafe {
671 k_sys::baracuda_cutlass_gemm_f16_rcr_sm80_workspace_size(m, n, k)
672 },
673 (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
674 k_sys::baracuda_cutlass_gemm_bf16_rcr_sm80_workspace_size(m, n, k)
675 },
676 (LayoutSku::Rcr, ElementKind::F32) => unsafe {
677 k_sys::baracuda_cutlass_gemm_tf32_rcr_sm80_workspace_size(m, n, k)
678 },
679 (LayoutSku::Rrr, ElementKind::F16) => unsafe {
680 k_sys::baracuda_cutlass_gemm_f16_rrr_sm80_workspace_size(m, n, k)
681 },
682 (LayoutSku::Rrr, ElementKind::Bf16) => unsafe {
683 k_sys::baracuda_cutlass_gemm_bf16_rrr_sm80_workspace_size(m, n, k)
684 },
685 (LayoutSku::Rrr, ElementKind::F32) => unsafe {
686 k_sys::baracuda_cutlass_gemm_tf32_rrr_sm80_workspace_size(m, n, k)
687 },
688 (LayoutSku::Rcr, ElementKind::F32Strict) => unsafe {
689 k_sys::baracuda_cutlass_gemm_f32_simt_rcr_sm80_workspace_size(m, n, k)
690 },
691 (LayoutSku::Rrr, ElementKind::F32Strict) => unsafe {
692 k_sys::baracuda_cutlass_gemm_f32_simt_rrr_sm80_workspace_size(m, n, k)
693 },
694 (LayoutSku::Rcr, ElementKind::F64)
696 | (LayoutSku::Rrr, ElementKind::F64) => 0,
697 (_, ElementKind::S8)
701 | (_, ElementKind::U8)
702 | (_, ElementKind::I32)
703 | (_, ElementKind::I64)
704 | (_, ElementKind::Bool)
705 | (_, ElementKind::Fp8E4M3)
706 | (_, ElementKind::Fp8E5M2)
707 | (_, ElementKind::S4)
708 | (_, ElementKind::U4)
709 | (_, ElementKind::Bin)
710 | (_, ElementKind::Complex32)
711 | (_, ElementKind::Complex64) => 0,
712 }
713 }
714
715 #[cfg(feature = "sm80")]
716 #[allow(clippy::too_many_arguments)]
717 pub(super) unsafe fn gemm_sm80_can_implement(
718 layout: LayoutSku,
719 kind: ElementKind,
720 m: i32,
721 n: i32,
722 k: i32,
723 a: *const c_void,
724 lda: i64,
725 b: *const c_void,
726 ldb: i64,
727 c: *const c_void,
728 ldc: i64,
729 d: *mut c_void,
730 ldd: i64,
731 ) -> i32 {
732 use baracuda_cutlass_kernels_sys as k_sys;
733 match (layout, kind) {
734 (LayoutSku::Rcr, ElementKind::F16) => unsafe {
735 k_sys::baracuda_cutlass_gemm_f16_rcr_sm80_can_implement(
736 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
737 )
738 },
739 (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
740 k_sys::baracuda_cutlass_gemm_bf16_rcr_sm80_can_implement(
741 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
742 )
743 },
744 (LayoutSku::Rcr, ElementKind::F32) => unsafe {
745 k_sys::baracuda_cutlass_gemm_tf32_rcr_sm80_can_implement(
746 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
747 )
748 },
749 (LayoutSku::Rrr, ElementKind::F16) => unsafe {
750 k_sys::baracuda_cutlass_gemm_f16_rrr_sm80_can_implement(
751 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
752 )
753 },
754 (LayoutSku::Rrr, ElementKind::Bf16) => unsafe {
755 k_sys::baracuda_cutlass_gemm_bf16_rrr_sm80_can_implement(
756 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
757 )
758 },
759 (LayoutSku::Rrr, ElementKind::F32) => unsafe {
760 k_sys::baracuda_cutlass_gemm_tf32_rrr_sm80_can_implement(
761 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
762 )
763 },
764 (LayoutSku::Rcr, ElementKind::F32Strict) => unsafe {
765 k_sys::baracuda_cutlass_gemm_f32_simt_rcr_sm80_can_implement(
766 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
767 )
768 },
769 (LayoutSku::Rrr, ElementKind::F32Strict) => unsafe {
770 k_sys::baracuda_cutlass_gemm_f32_simt_rrr_sm80_can_implement(
771 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
772 )
773 },
774 (LayoutSku::Rcr, ElementKind::F64)
776 | (LayoutSku::Rrr, ElementKind::F64) => 3,
777 (_, ElementKind::S8) | (_, ElementKind::U8) | (_, ElementKind::I32)
779 | (_, ElementKind::I64)
780 | (_, ElementKind::Bool)
781 | (_, ElementKind::Fp8E4M3)
782 | (_, ElementKind::Fp8E5M2)
783 | (_, ElementKind::S4)
784 | (_, ElementKind::U4)
785 | (_, ElementKind::Bin)
786 | (_, ElementKind::Complex32)
787 | (_, ElementKind::Complex64) => 3,
788 }
789 }
790
791 #[cfg(feature = "sm80")]
798 #[allow(clippy::too_many_arguments)]
799 pub(super) unsafe fn gemm_sm80_run_f64(
800 layout: LayoutSku,
801 m: i32,
802 n: i32,
803 k: i32,
804 a: *const c_void,
805 lda: i64,
806 b: *const c_void,
807 ldb: i64,
808 c: *const c_void,
809 ldc: i64,
810 d: *mut c_void,
811 ldd: i64,
812 alpha: f64,
813 beta: f64,
814 workspace: *mut c_void,
815 workspace_bytes: usize,
816 stream: *mut c_void,
817 ) -> i32 {
818 use baracuda_cutlass_kernels_sys as k_sys;
819 match layout {
820 LayoutSku::Rcr => unsafe {
821 k_sys::baracuda_cutlass_gemm_f64_rcr_sm80_run(
822 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
823 alpha, beta, workspace, workspace_bytes, stream,
824 )
825 },
826 LayoutSku::Rrr => unsafe {
827 k_sys::baracuda_cutlass_gemm_f64_rrr_sm80_run(
828 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
829 alpha, beta, workspace, workspace_bytes, stream,
830 )
831 },
832 }
833 }
834
835 #[cfg(feature = "sm80")]
836 pub(super) fn gemm_sm80_workspace_size_f64(layout: LayoutSku, m: i32, n: i32, k: i32) -> usize {
837 use baracuda_cutlass_kernels_sys as k_sys;
838 match layout {
839 LayoutSku::Rcr => unsafe {
840 k_sys::baracuda_cutlass_gemm_f64_rcr_sm80_workspace_size(m, n, k)
841 },
842 LayoutSku::Rrr => unsafe {
843 k_sys::baracuda_cutlass_gemm_f64_rrr_sm80_workspace_size(m, n, k)
844 },
845 }
846 }
847
848 #[cfg(feature = "sm80")]
849 #[allow(clippy::too_many_arguments)]
850 pub(super) unsafe fn gemm_sm80_can_implement_f64(
851 layout: LayoutSku,
852 m: i32,
853 n: i32,
854 k: i32,
855 a: *const c_void,
856 lda: i64,
857 b: *const c_void,
858 ldb: i64,
859 c: *const c_void,
860 ldc: i64,
861 d: *mut c_void,
862 ldd: i64,
863 ) -> i32 {
864 use baracuda_cutlass_kernels_sys as k_sys;
865 match layout {
866 LayoutSku::Rcr => unsafe {
867 k_sys::baracuda_cutlass_gemm_f64_rcr_sm80_can_implement(
868 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
869 )
870 },
871 LayoutSku::Rrr => unsafe {
872 k_sys::baracuda_cutlass_gemm_f64_rrr_sm80_can_implement(
873 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
874 )
875 },
876 }
877 }
878
879 #[cfg(feature = "sm80")]
882 #[allow(clippy::too_many_arguments)]
883 pub(super) unsafe fn gemm_bias_sm80_run_f64(
884 layout: LayoutSku,
885 epilogue: EpilogueKind,
886 m: i32,
887 n: i32,
888 k: i32,
889 a: *const c_void,
890 lda: i64,
891 b: *const c_void,
892 ldb: i64,
893 c: *const c_void,
894 ldc: i64,
895 d: *mut c_void,
896 ldd: i64,
897 bias: *const c_void,
898 alpha: f64,
899 beta: f64,
900 workspace: *mut c_void,
901 workspace_bytes: usize,
902 stream: *mut c_void,
903 ) -> i32 {
904 use baracuda_cutlass_kernels_sys as k_sys;
905 match (layout, epilogue) {
906 (LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
907 k_sys::baracuda_cutlass_gemm_bias_f64_rcr_sm80_run(
908 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
909 bias, alpha, beta, workspace, workspace_bytes, stream,
910 )
911 },
912 (LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
913 k_sys::baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_run(
914 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
915 bias, alpha, beta, workspace, workspace_bytes, stream,
916 )
917 },
918 (LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
919 k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_run(
920 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
921 bias, alpha, beta, workspace, workspace_bytes, stream,
922 )
923 },
924 (LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
925 k_sys::baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_run(
926 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
927 bias, alpha, beta, workspace, workspace_bytes, stream,
928 )
929 },
930 (LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
931 k_sys::baracuda_cutlass_gemm_bias_f64_rrr_sm80_run(
932 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
933 bias, alpha, beta, workspace, workspace_bytes, stream,
934 )
935 },
936 (LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
937 k_sys::baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_run(
938 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
939 bias, alpha, beta, workspace, workspace_bytes, stream,
940 )
941 },
942 (LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
943 k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_run(
944 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
945 bias, alpha, beta, workspace, workspace_bytes, stream,
946 )
947 },
948 (LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
949 k_sys::baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_run(
950 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
951 bias, alpha, beta, workspace, workspace_bytes, stream,
952 )
953 },
954 (_, EpilogueKind::Identity) => 3,
957 }
958 }
959
960 #[cfg(feature = "sm80")]
961 pub(super) fn gemm_bias_sm80_workspace_size_f64(
962 layout: LayoutSku,
963 epilogue: EpilogueKind,
964 m: i32,
965 n: i32,
966 k: i32,
967 ) -> usize {
968 use baracuda_cutlass_kernels_sys as k_sys;
969 match (layout, epilogue) {
970 (LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
971 k_sys::baracuda_cutlass_gemm_bias_f64_rcr_sm80_workspace_size(m, n, k)
972 },
973 (LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
974 k_sys::baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_workspace_size(m, n, k)
975 },
976 (LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
977 k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_workspace_size(m, n, k)
978 },
979 (LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
980 k_sys::baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_workspace_size(m, n, k)
981 },
982 (LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
983 k_sys::baracuda_cutlass_gemm_bias_f64_rrr_sm80_workspace_size(m, n, k)
984 },
985 (LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
986 k_sys::baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_workspace_size(m, n, k)
987 },
988 (LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
989 k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_workspace_size(m, n, k)
990 },
991 (LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
992 k_sys::baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_workspace_size(m, n, k)
993 },
994 (_, EpilogueKind::Identity) => 0,
995 }
996 }
997
998 #[cfg(feature = "sm80")]
999 #[allow(clippy::too_many_arguments)]
1000 pub(super) unsafe fn gemm_bias_sm80_can_implement_f64(
1001 layout: LayoutSku,
1002 epilogue: EpilogueKind,
1003 m: i32,
1004 n: i32,
1005 k: i32,
1006 a: *const c_void,
1007 lda: i64,
1008 b: *const c_void,
1009 ldb: i64,
1010 c: *const c_void,
1011 ldc: i64,
1012 d: *mut c_void,
1013 ldd: i64,
1014 bias: *const c_void,
1015 ) -> i32 {
1016 use baracuda_cutlass_kernels_sys as k_sys;
1017 match (layout, epilogue) {
1018 (LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
1019 k_sys::baracuda_cutlass_gemm_bias_f64_rcr_sm80_can_implement(
1020 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1021 )
1022 },
1023 (LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
1024 k_sys::baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_can_implement(
1025 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1026 )
1027 },
1028 (LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
1029 k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_can_implement(
1030 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1031 )
1032 },
1033 (LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
1034 k_sys::baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_can_implement(
1035 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1036 )
1037 },
1038 (LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
1039 k_sys::baracuda_cutlass_gemm_bias_f64_rrr_sm80_can_implement(
1040 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1041 )
1042 },
1043 (LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
1044 k_sys::baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_can_implement(
1045 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1046 )
1047 },
1048 (LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
1049 k_sys::baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_can_implement(
1050 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1051 )
1052 },
1053 (LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
1054 k_sys::baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_can_implement(
1055 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1056 )
1057 },
1058 (_, EpilogueKind::Identity) => 3,
1059 }
1060 }
1061
1062 #[cfg(feature = "sm80")]
1069 #[allow(clippy::too_many_arguments)]
1070 pub(super) unsafe fn batched_gemm_sm80_run(
1071 layout: LayoutSku,
1072 kind: ElementKind,
1073 m: i32,
1074 n: i32,
1075 k: i32,
1076 a: *const c_void,
1077 lda: i64,
1078 stride_a: i64,
1079 b: *const c_void,
1080 ldb: i64,
1081 stride_b: i64,
1082 c: *const c_void,
1083 ldc: i64,
1084 stride_c: i64,
1085 d: *mut c_void,
1086 ldd: i64,
1087 stride_d: i64,
1088 alpha: f32,
1089 beta: f32,
1090 batch_count: i32,
1091 workspace: *mut c_void,
1092 workspace_bytes: usize,
1093 stream: *mut c_void,
1094 ) -> i32 {
1095 use baracuda_cutlass_kernels_sys as k_sys;
1096 match (layout, kind) {
1097 (LayoutSku::Rcr, ElementKind::F16) => unsafe {
1098 k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_run(
1099 m, n, k,
1100 a, lda, stride_a,
1101 b, ldb, stride_b,
1102 c, ldc, stride_c,
1103 d, ldd, stride_d,
1104 alpha, beta,
1105 batch_count,
1106 workspace, workspace_bytes,
1107 stream,
1108 )
1109 },
1110 (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
1111 k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_run(
1112 m, n, k,
1113 a, lda, stride_a,
1114 b, ldb, stride_b,
1115 c, ldc, stride_c,
1116 d, ldd, stride_d,
1117 alpha, beta,
1118 batch_count,
1119 workspace, workspace_bytes,
1120 stream,
1121 )
1122 },
1123 _ => 3,
1124 }
1125 }
1126
1127 #[cfg(feature = "sm80")]
1128 pub(super) fn batched_gemm_sm80_workspace_size(
1129 layout: LayoutSku,
1130 kind: ElementKind,
1131 m: i32,
1132 n: i32,
1133 k: i32,
1134 batch_count: i32,
1135 ) -> usize {
1136 use baracuda_cutlass_kernels_sys as k_sys;
1137 match (layout, kind) {
1138 (LayoutSku::Rcr, ElementKind::F16) => unsafe {
1139 k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_workspace_size(
1140 m, n, k, batch_count,
1141 )
1142 },
1143 (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
1144 k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_workspace_size(
1145 m, n, k, batch_count,
1146 )
1147 },
1148 _ => 0,
1149 }
1150 }
1151
1152 #[cfg(feature = "sm80")]
1153 #[allow(clippy::too_many_arguments)]
1154 pub(super) unsafe fn batched_gemm_sm80_can_implement(
1155 layout: LayoutSku,
1156 kind: ElementKind,
1157 m: i32,
1158 n: i32,
1159 k: i32,
1160 a: *const c_void,
1161 lda: i64,
1162 stride_a: i64,
1163 b: *const c_void,
1164 ldb: i64,
1165 stride_b: i64,
1166 c: *const c_void,
1167 ldc: i64,
1168 stride_c: i64,
1169 d: *mut c_void,
1170 ldd: i64,
1171 stride_d: i64,
1172 batch_count: i32,
1173 ) -> i32 {
1174 use baracuda_cutlass_kernels_sys as k_sys;
1175 match (layout, kind) {
1176 (LayoutSku::Rcr, ElementKind::F16) => unsafe {
1177 k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_can_implement(
1178 m, n, k,
1179 a, lda, stride_a,
1180 b, ldb, stride_b,
1181 c, ldc, stride_c,
1182 d, ldd, stride_d,
1183 batch_count,
1184 )
1185 },
1186 (LayoutSku::Rcr, ElementKind::Bf16) => unsafe {
1187 k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_can_implement(
1188 m, n, k,
1189 a, lda, stride_a,
1190 b, ldb, stride_b,
1191 c, ldc, stride_c,
1192 d, ldd, stride_d,
1193 batch_count,
1194 )
1195 },
1196 _ => 3,
1197 }
1198 }
1199
1200 #[cfg(feature = "sm80")]
1203 pub(super) unsafe fn grouped_gemm_rcr_sm80_sufficient(
1204 kind: ElementKind,
1205 h_m: *const i32,
1206 h_n: *const i32,
1207 h_k: *const i32,
1208 group_count: i32,
1209 ) -> i32 {
1210 use baracuda_cutlass_kernels_sys as k_sys;
1211 match kind {
1212 ElementKind::F16 => unsafe {
1213 k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient(h_m, h_n, h_k, group_count)
1214 },
1215 ElementKind::Bf16 => unsafe {
1216 k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_sufficient(h_m, h_n, h_k, group_count)
1217 },
1218 ElementKind::F32
1219 | ElementKind::F32Strict
1220 | ElementKind::F64
1221 | ElementKind::S8
1222 | ElementKind::U8
1223 | ElementKind::I32
1224 | ElementKind::I64
1225 | ElementKind::Bool
1226 | ElementKind::Fp8E4M3
1227 | ElementKind::Fp8E5M2
1228 | ElementKind::S4
1229 | ElementKind::U4
1230 | ElementKind::Bin
1231 | ElementKind::Complex32
1232 | ElementKind::Complex64 => 0,
1233 }
1234 }
1235
1236 #[cfg(feature = "sm80")]
1237 pub(super) unsafe fn grouped_gemm_rcr_sm80_scratch_bytes(
1238 kind: ElementKind,
1239 h_m: *const i32,
1240 h_n: *const i32,
1241 h_k: *const i32,
1242 group_count: i32,
1243 threadblock_count: i32,
1244 ) -> usize {
1245 use baracuda_cutlass_kernels_sys as k_sys;
1246 match kind {
1247 ElementKind::F16 => unsafe {
1248 k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_scratch_bytes(
1249 h_m, h_n, h_k, group_count, threadblock_count,
1250 )
1251 },
1252 ElementKind::Bf16 => unsafe {
1253 k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_scratch_bytes(
1254 h_m, h_n, h_k, group_count, threadblock_count,
1255 )
1256 },
1257 ElementKind::F32
1258 | ElementKind::F32Strict
1259 | ElementKind::F64
1260 | ElementKind::S8
1261 | ElementKind::U8
1262 | ElementKind::I32
1263 | ElementKind::I64
1264 | ElementKind::Bool
1265 | ElementKind::Fp8E4M3
1266 | ElementKind::Fp8E5M2
1267 | ElementKind::S4
1268 | ElementKind::U4
1269 | ElementKind::Bin
1270 | ElementKind::Complex32
1271 | ElementKind::Complex64 => 0,
1272 }
1273 }
1274
1275 #[cfg(feature = "sm80")]
1276 pub(super) unsafe fn grouped_gemm_rcr_sm80_can_implement(
1277 kind: ElementKind,
1278 h_m: *const i32,
1279 h_n: *const i32,
1280 h_k: *const i32,
1281 group_count: i32,
1282 ) -> i32 {
1283 use baracuda_cutlass_kernels_sys as k_sys;
1284 match kind {
1285 ElementKind::F16 => unsafe {
1286 k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_can_implement(h_m, h_n, h_k, group_count)
1287 },
1288 ElementKind::Bf16 => unsafe {
1289 k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_can_implement(h_m, h_n, h_k, group_count)
1290 },
1291 ElementKind::F32
1292 | ElementKind::F32Strict
1293 | ElementKind::F64
1294 | ElementKind::S8
1295 | ElementKind::U8
1296 | ElementKind::I32
1297 | ElementKind::I64
1298 | ElementKind::Bool
1299 | ElementKind::Fp8E4M3
1300 | ElementKind::Fp8E5M2
1301 | ElementKind::S4
1302 | ElementKind::U4
1303 | ElementKind::Bin
1304 | ElementKind::Complex32
1305 | ElementKind::Complex64 => 3,
1306 }
1307 }
1308
1309 #[cfg(feature = "sm80")]
1310 #[allow(clippy::too_many_arguments)]
1311 pub(super) unsafe fn grouped_gemm_rcr_sm80_run(
1312 kind: ElementKind,
1313 group_count: i32,
1314 threadblock_count: i32,
1315 d_problem_sizes: *const c_void,
1316 d_ptr_a: *const c_void,
1317 d_ptr_b: *const c_void,
1318 d_ptr_c: *const c_void,
1319 d_ptr_d: *mut c_void,
1320 d_lda: *const c_void,
1321 d_ldb: *const c_void,
1322 d_ldc: *const c_void,
1323 d_ldd: *const c_void,
1324 h_problem_sizes: *const c_void,
1325 alpha: f32,
1326 beta: f32,
1327 scratch: *mut c_void,
1328 scratch_bytes: usize,
1329 stream: *mut c_void,
1330 ) -> i32 {
1331 use baracuda_cutlass_kernels_sys as k_sys;
1332 match kind {
1333 ElementKind::F16 => unsafe {
1334 k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_run(
1335 group_count, threadblock_count,
1336 d_problem_sizes,
1337 d_ptr_a, d_ptr_b, d_ptr_c, d_ptr_d,
1338 d_lda, d_ldb, d_ldc, d_ldd,
1339 h_problem_sizes,
1340 alpha, beta,
1341 scratch, scratch_bytes,
1342 stream,
1343 )
1344 },
1345 ElementKind::Bf16 => unsafe {
1346 k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_run(
1347 group_count, threadblock_count,
1348 d_problem_sizes,
1349 d_ptr_a, d_ptr_b, d_ptr_c, d_ptr_d,
1350 d_lda, d_ldb, d_ldc, d_ldd,
1351 h_problem_sizes,
1352 alpha, beta,
1353 scratch, scratch_bytes,
1354 stream,
1355 )
1356 },
1357 ElementKind::F32
1358 | ElementKind::F32Strict
1359 | ElementKind::F64
1360 | ElementKind::S8
1361 | ElementKind::U8
1362 | ElementKind::I32
1363 | ElementKind::I64
1364 | ElementKind::Bool
1365 | ElementKind::Fp8E4M3
1366 | ElementKind::Fp8E5M2
1367 | ElementKind::S4
1368 | ElementKind::U4
1369 | ElementKind::Bin
1370 | ElementKind::Complex32
1371 | ElementKind::Complex64 => 3,
1372 }
1373 }
1374
1375 #[cfg(feature = "sm80")]
1385 #[allow(clippy::too_many_arguments)]
1386 pub(super) unsafe fn int_gemm_rcr_sm80_run(
1387 layout: LayoutSku,
1388 kind: ElementKind,
1389 m: i32,
1390 n: i32,
1391 k: i32,
1392 a: *const c_void,
1393 lda: i64,
1394 b: *const c_void,
1395 ldb: i64,
1396 c: *const c_void,
1397 ldc: i64,
1398 d: *mut c_void,
1399 ldd: i64,
1400 alpha: f32,
1401 beta: f32,
1402 workspace: *mut c_void,
1403 workspace_bytes: usize,
1404 stream: *mut c_void,
1405 ) -> i32 {
1406 use baracuda_cutlass_kernels_sys as k_sys;
1407 match (layout, kind) {
1408 (LayoutSku::Rcr, ElementKind::S8) => unsafe {
1409 k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_run(
1410 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1411 alpha, beta, workspace, workspace_bytes, stream,
1412 )
1413 },
1414 (LayoutSku::Rcr, ElementKind::U8) => unsafe {
1415 k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_run(
1416 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1417 alpha, beta, workspace, workspace_bytes, stream,
1418 )
1419 },
1420 (LayoutSku::Rrr, ElementKind::S8) | (LayoutSku::Rrr, ElementKind::U8) => 3,
1424 _ => 3,
1426 }
1427 }
1428
1429 #[cfg(feature = "sm80")]
1430 pub(super) fn int_gemm_rcr_sm80_workspace_size(
1431 layout: LayoutSku,
1432 kind: ElementKind,
1433 m: i32,
1434 n: i32,
1435 k: i32,
1436 ) -> usize {
1437 use baracuda_cutlass_kernels_sys as k_sys;
1438 match (layout, kind) {
1439 (LayoutSku::Rcr, ElementKind::S8) => unsafe {
1440 k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_workspace_size(m, n, k)
1441 },
1442 (LayoutSku::Rcr, ElementKind::U8) => unsafe {
1443 k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_workspace_size(m, n, k)
1444 },
1445 _ => 0,
1446 }
1447 }
1448
1449 #[cfg(feature = "sm80")]
1450 #[allow(clippy::too_many_arguments)]
1451 pub(super) unsafe fn int_gemm_rcr_sm80_can_implement(
1452 layout: LayoutSku,
1453 kind: ElementKind,
1454 m: i32,
1455 n: i32,
1456 k: i32,
1457 a: *const c_void,
1458 lda: i64,
1459 b: *const c_void,
1460 ldb: i64,
1461 c: *const c_void,
1462 ldc: i64,
1463 d: *mut c_void,
1464 ldd: i64,
1465 ) -> i32 {
1466 use baracuda_cutlass_kernels_sys as k_sys;
1467 match (layout, kind) {
1468 (LayoutSku::Rcr, ElementKind::S8) => unsafe {
1469 k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_can_implement(
1470 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1471 )
1472 },
1473 (LayoutSku::Rcr, ElementKind::U8) => unsafe {
1474 k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_can_implement(
1475 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1476 )
1477 },
1478 (LayoutSku::Rrr, ElementKind::S8) | (LayoutSku::Rrr, ElementKind::U8) => 3,
1479 _ => 3,
1480 }
1481 }
1482
1483 use crate::types::BiasElementKind;
1494
1495 #[cfg(feature = "sm80")]
1496 #[allow(clippy::too_many_arguments)]
1497 pub(super) unsafe fn int_gemm_bias_rcr_sm80_run(
1498 layout: LayoutSku,
1499 kind: ElementKind,
1500 epilogue: EpilogueKind,
1501 bias_kind: BiasElementKind,
1502 m: i32,
1503 n: i32,
1504 k: i32,
1505 a: *const c_void,
1506 lda: i64,
1507 b: *const c_void,
1508 ldb: i64,
1509 c: *const c_void,
1510 ldc: i64,
1511 d: *mut c_void,
1512 ldd: i64,
1513 bias: *const c_void,
1514 alpha: f32,
1515 beta: f32,
1516 workspace: *mut c_void,
1517 workspace_bytes: usize,
1518 stream: *mut c_void,
1519 ) -> i32 {
1520 use baracuda_cutlass_kernels_sys as k_sys;
1521 if !matches!(layout, LayoutSku::Rcr) {
1523 return 3;
1524 }
1525 match (kind, epilogue, bias_kind) {
1526 (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1528 k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_run(
1529 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1530 bias, alpha, beta, workspace, workspace_bytes, stream,
1531 )
1532 },
1533 (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1534 k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_run(
1535 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1536 bias, alpha, beta, workspace, workspace_bytes, stream,
1537 )
1538 },
1539 (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1540 k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_run(
1541 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1542 bias, alpha, beta, workspace, workspace_bytes, stream,
1543 )
1544 },
1545 (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1546 k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_run(
1547 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1548 bias, alpha, beta, workspace, workspace_bytes, stream,
1549 )
1550 },
1551 (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1553 k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_run(
1554 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1555 bias, alpha, beta, workspace, workspace_bytes, stream,
1556 )
1557 },
1558 (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1559 k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_run(
1560 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1561 bias, alpha, beta, workspace, workspace_bytes, stream,
1562 )
1563 },
1564 (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1565 k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_run(
1566 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1567 bias, alpha, beta, workspace, workspace_bytes, stream,
1568 )
1569 },
1570 (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1571 k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_run(
1572 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1573 bias, alpha, beta, workspace, workspace_bytes, stream,
1574 )
1575 },
1576 (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1578 k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_run(
1579 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1580 bias, alpha, beta, workspace, workspace_bytes, stream,
1581 )
1582 },
1583 (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1584 k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_run(
1585 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1586 bias, alpha, beta, workspace, workspace_bytes, stream,
1587 )
1588 },
1589 (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1590 k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_run(
1591 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1592 bias, alpha, beta, workspace, workspace_bytes, stream,
1593 )
1594 },
1595 (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1596 k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_run(
1597 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1598 bias, alpha, beta, workspace, workspace_bytes, stream,
1599 )
1600 },
1601 (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1603 k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_run(
1604 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1605 bias, alpha, beta, workspace, workspace_bytes, stream,
1606 )
1607 },
1608 (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1609 k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_run(
1610 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1611 bias, alpha, beta, workspace, workspace_bytes, stream,
1612 )
1613 },
1614 (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1615 k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_run(
1616 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1617 bias, alpha, beta, workspace, workspace_bytes, stream,
1618 )
1619 },
1620 (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1621 k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_run(
1622 m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
1623 bias, alpha, beta, workspace, workspace_bytes, stream,
1624 )
1625 },
1626 (_, EpilogueKind::Identity, _) => 3,
1630 _ => 3,
1632 }
1633 }
1634
1635 #[cfg(feature = "sm80")]
1636 pub(super) fn int_gemm_bias_rcr_sm80_workspace_size(
1637 layout: LayoutSku,
1638 kind: ElementKind,
1639 epilogue: EpilogueKind,
1640 bias_kind: BiasElementKind,
1641 m: i32,
1642 n: i32,
1643 k: i32,
1644 ) -> usize {
1645 use baracuda_cutlass_kernels_sys as k_sys;
1646 if !matches!(layout, LayoutSku::Rcr) {
1647 return 0;
1648 }
1649 match (kind, epilogue, bias_kind) {
1650 (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1651 k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1652 },
1653 (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1654 k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1655 },
1656 (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1657 k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1658 },
1659 (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1660 k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
1661 },
1662 (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1663 k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1664 },
1665 (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1666 k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1667 },
1668 (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1669 k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1670 },
1671 (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1672 k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
1673 },
1674 (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1675 k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1676 },
1677 (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1678 k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1679 },
1680 (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1681 k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1682 },
1683 (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1684 k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
1685 },
1686 (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1687 k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1688 },
1689 (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1690 k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1691 },
1692 (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1693 k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1694 },
1695 (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1696 k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
1697 },
1698 _ => 0,
1699 }
1700 }
1701
1702 #[cfg(feature = "sm80")]
1703 #[allow(clippy::too_many_arguments)]
1704 pub(super) unsafe fn int_gemm_bias_rcr_sm80_can_implement(
1705 layout: LayoutSku,
1706 kind: ElementKind,
1707 epilogue: EpilogueKind,
1708 bias_kind: BiasElementKind,
1709 m: i32,
1710 n: i32,
1711 k: i32,
1712 a: *const c_void,
1713 lda: i64,
1714 b: *const c_void,
1715 ldb: i64,
1716 c: *const c_void,
1717 ldc: i64,
1718 d: *mut c_void,
1719 ldd: i64,
1720 bias: *const c_void,
1721 ) -> i32 {
1722 use baracuda_cutlass_kernels_sys as k_sys;
1723 if !matches!(layout, LayoutSku::Rcr) {
1724 return 3;
1725 }
1726 match (kind, epilogue, bias_kind) {
1727 (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1728 k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_can_implement(
1729 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1730 )
1731 },
1732 (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1733 k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_can_implement(
1734 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1735 )
1736 },
1737 (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1738 k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_can_implement(
1739 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1740 )
1741 },
1742 (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1743 k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_can_implement(
1744 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1745 )
1746 },
1747 (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1748 k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_can_implement(
1749 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1750 )
1751 },
1752 (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1753 k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_can_implement(
1754 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1755 )
1756 },
1757 (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1758 k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_can_implement(
1759 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1760 )
1761 },
1762 (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1763 k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_can_implement(
1764 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1765 )
1766 },
1767 (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
1768 k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_can_implement(
1769 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1770 )
1771 },
1772 (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
1773 k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_can_implement(
1774 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1775 )
1776 },
1777 (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
1778 k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_can_implement(
1779 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1780 )
1781 },
1782 (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
1783 k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_can_implement(
1784 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1785 )
1786 },
1787 (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
1788 k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_can_implement(
1789 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1790 )
1791 },
1792 (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
1793 k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_can_implement(
1794 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1795 )
1796 },
1797 (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
1798 k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_can_implement(
1799 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1800 )
1801 },
1802 (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
1803 k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_can_implement(
1804 m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
1805 )
1806 },
1807 _ => 3,
1808 }
1809 }
1810}
1811
/// Smallest number of elements a row-major `rows × cols` view with leading
/// dimension `ld` spans: `(rows - 1) * ld + cols`.
///
/// The arithmetic is carried out in `i64` with overflow checks. Returns `None` when
/// it overflows, or when the result is negative or does not fit in `usize`.
fn min_elements_row_major(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // Widen before subtracting: `rows - 1` in `i32` would overflow for `i32::MIN`
    // (panic in debug builds, silent wrap to a huge positive count in release).
    let full_rows = i64::from(rows) - 1;
    let needed = full_rows.checked_mul(ld)?.checked_add(i64::from(cols))?;
    usize::try_from(needed).ok()
}
1830
/// Smallest number of elements a column-major `rows × cols` view with leading
/// dimension `ld` spans: `(cols - 1) * ld + rows`.
///
/// The arithmetic is carried out in `i64` with overflow checks. Returns `None` when
/// it overflows, or when the result is negative or does not fit in `usize`.
fn min_elements_col_major(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // Widen before subtracting: `cols - 1` in `i32` would overflow for `i32::MIN`
    // (panic in debug builds, silent wrap to a huge positive count in release).
    let full_cols = i64::from(cols) - 1;
    let needed = full_cols.checked_mul(ld)?.checked_add(i64::from(rows))?;
    usize::try_from(needed).ok()
}
1838
/// Test-only shorthand for the A operand: delegates to `min_elements_row_major`,
/// the same formula `check_args` applies to A.
#[cfg(test)]
fn min_elements_rcr_a(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // A is sized with the row-major formula.
    min_elements_row_major(rows, cols, ld)
}
/// Test-only shorthand for the B operand under the Rcr layout: delegates to
/// `min_elements_col_major`, the same formula `check_args` applies to B for `Rcr`.
#[cfg(test)]
fn min_elements_rcr_b(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // Under Rcr, B is sized with the column-major formula.
    min_elements_col_major(rows, cols, ld)
}
/// Test-only shorthand for the C and D operands: delegates to
/// `min_elements_row_major`, the same formula `check_args` applies to C and D.
#[cfg(test)]
fn min_elements_rcr_cd(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // C and D are sized with the row-major formula.
    min_elements_row_major(rows, cols, ld)
}
1854
1855fn check_descriptor(desc: &GemmDescriptor) -> Result<()> {
1856 if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
1857 return Err(Error::InvalidProblem("M, N, K must all be positive"));
1858 }
1859 Ok(())
1863}
1864
1865fn check_args<T: CutlassElement>(desc: &GemmDescriptor, args: &GemmArgs<'_, T>) -> Result<()> {
1866 match (desc.epilogue.requires_bias(), &args.bias) {
1869 (false, Some(_)) => {
1870 return Err(Error::InvalidProblem(
1871 "args.bias must be None when descriptor.epilogue is Identity",
1872 ));
1873 }
1874 (true, None) => {
1875 return Err(Error::InvalidProblem(
1876 "args.bias is required when descriptor.epilogue is in the Bias family \
1877 (Bias / BiasRelu / BiasGelu / BiasSilu)",
1878 ));
1879 }
1880 (false, None) | (true, Some(_)) => {}
1881 }
1882 if let Some(bias) = &args.bias {
1883 if bias.len != desc.n {
1884 return Err(Error::InvalidProblem(
1885 "bias vector length must equal N",
1886 ));
1887 }
1888 if bias.stride != 1 {
1889 return Err(Error::Unsupported(
1890 "bias vector must be contiguous (stride 1) — strided bias not supported",
1891 ));
1892 }
1893 if bias.data.len() < desc.n as usize {
1894 return Err(Error::BufferTooSmall {
1895 needed: desc.n as usize,
1896 got: bias.data.len(),
1897 });
1898 }
1899 }
1900 if args.a.rows != desc.m || args.a.cols != desc.k {
1901 return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
1902 }
1903 if args.b.rows != desc.k || args.b.cols != desc.n {
1904 return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
1905 }
1906 if args.d.rows != desc.m || args.d.cols != desc.n {
1907 return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
1908 }
1909 if let Some(c) = &args.c {
1910 if c.rows != desc.m || c.cols != desc.n {
1911 return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
1912 }
1913 }
1914 if args.a.ld < desc.k as i64 {
1916 return Err(Error::InvalidProblem("A leading dimension must be >= K"));
1917 }
1918 let b_min_ld = match desc.layout {
1922 LayoutSku::Rcr => desc.k as i64,
1923 LayoutSku::Rrr => desc.n as i64,
1924 };
1925 if args.b.ld < b_min_ld {
1926 return Err(Error::InvalidProblem(match desc.layout {
1927 LayoutSku::Rcr => "B leading dimension must be >= K (column-major Rcr layout)",
1928 LayoutSku::Rrr => "B leading dimension must be >= N (row-major Rrr layout)",
1929 }));
1930 }
1931 if args.d.ld < desc.n as i64 {
1932 return Err(Error::InvalidProblem("D leading dimension must be >= N"));
1933 }
1934 if let Some(c) = &args.c {
1935 if c.ld < desc.n as i64 {
1936 return Err(Error::InvalidProblem("C leading dimension must be >= N"));
1937 }
1938 }
1939
1940 let need_a = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
1941 .ok_or(Error::InvalidProblem("A storage size overflow"))?;
1942 if args.a.data.len() < need_a {
1943 return Err(Error::BufferTooSmall {
1944 needed: need_a,
1945 got: args.a.data.len(),
1946 });
1947 }
1948 let need_b = match desc.layout {
1949 LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
1950 LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
1951 }
1952 .ok_or(Error::InvalidProblem("B storage size overflow"))?;
1953 if args.b.data.len() < need_b {
1954 return Err(Error::BufferTooSmall {
1955 needed: need_b,
1956 got: args.b.data.len(),
1957 });
1958 }
1959 let need_d = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
1960 .ok_or(Error::InvalidProblem("D storage size overflow"))?;
1961 if args.d.data.len() < need_d {
1962 return Err(Error::BufferTooSmall {
1963 needed: need_d,
1964 got: args.d.data.len(),
1965 });
1966 }
1967 if let Some(c) = &args.c {
1968 let need_c = min_elements_row_major(c.rows, c.cols, c.ld)
1969 .ok_or(Error::InvalidProblem("C storage size overflow"))?;
1970 if c.data.len() < need_c {
1971 return Err(Error::BufferTooSmall {
1972 needed: need_c,
1973 got: c.data.len(),
1974 });
1975 }
1976 }
1977 Ok(())
1978}
1979
1980mod cublas_backend {
2013 use core::cell::RefCell;
2014
2015 use baracuda_cublas::Handle as CublasHandle;
2016 use baracuda_driver::Stream;
2017
2018 thread_local! {
2035 static HANDLE_CACHE: RefCell<Vec<(usize, CublasHandle)>> =
2036 const { RefCell::new(Vec::new()) };
2037 }
2038
2039 pub(super) fn handle_for(stream: &Stream) -> crate::Result<CublasHandle> {
2047 let ctx_key = stream.context().as_raw() as usize;
2052 let handle = HANDLE_CACHE.with(|cache| -> crate::Result<CublasHandle> {
2053 let mut cache = cache.borrow_mut();
2054 if let Some((_, h)) = cache.iter().find(|(k, _)| *k == ctx_key) {
2055 return Ok(h.clone());
2056 }
2057 stream
2063 .context()
2064 .set_current()
2065 .map_err(crate::Error::Driver)?;
2066 let h = {
2080 let mut last_err = None;
2081 let mut handle: Option<CublasHandle> = None;
2082 for attempt in 0..5 {
2083 match CublasHandle::new() {
2084 Ok(h) => { handle = Some(h); break }
2085 Err(e) => {
2086 last_err = Some(e);
2087 std::thread::sleep(std::time::Duration::from_millis(
2089 50 * (attempt as u64 + 1),
2090 ));
2091 }
2092 }
2093 }
2094 match handle {
2095 Some(h) => h,
2096 None => {
2097 let _ = last_err; return Err(crate::Error::Unsupported(
2103 "cuBLAS handle creation failed after 5 retries \
2104 (library missing, device unavailable, or \
2105 persistent driver-init contention)",
2106 ));
2107 }
2108 }
2109 };
2110 cache.push((ctx_key, h.clone()));
2111 Ok(h)
2112 })?;
2113 handle
2119 .set_stream(stream)
2120 .map_err(|_| crate::Error::Unsupported(
2121 "cuBLAS set_stream failed",
2122 ))?;
2123 Ok(handle)
2124 }
2125}
2126
/// Algorithm selector passed as the final argument of
/// `baracuda_cublas::gemm_ex` in `GemmPlan::run_cublas`.
// NOTE(review): -1 presumably corresponds to cuBLAS's `CUBLAS_GEMM_DEFAULT` — confirm.
const CUBLAS_GEMM_ALGO: i32 = -1;
2143
/// Backend a `GemmPlan` was planned for. Crate-internal counterpart of the
/// public `BackendKind`; `as_public` converts between the two.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum BackendChoice {
    /// CUTLASS kernels for the given architecture SKU.
    Cutlass { arch: ArchSku },
    /// cuBLAS path (`GemmPlan::run_cublas`).
    Cublas,
    /// Ozaki path (`GemmPlan::run_ozaki`). `slices` is a packed byte: the low
    /// 5 bits carry the slice count and the high 3 bits the variant, as
    /// decoded by `validate_ozaki_request` and `run_ozaki`.
    Ozaki { slices: u8 },
}
2169
2170impl BackendChoice {
2171 fn as_public(self) -> BackendKind {
2172 match self {
2173 BackendChoice::Cutlass { .. } => BackendKind::Cutlass,
2174 BackendChoice::Cublas => BackendKind::Cublas,
2175 BackendChoice::Ozaki { slices } => BackendKind::Ozaki { slices },
2176 }
2177 }
2178}
2179
2180fn should_use_cublas_for_fp(
2209 desc: &GemmDescriptor,
2210 element: ElementKind,
2211) -> bool {
2212 if desc.epilogue.requires_bias() {
2214 return false;
2215 }
2216 match element {
2217 ElementKind::F16 | ElementKind::Bf16 => desc.m >= 2 && desc.m < 128,
2222 ElementKind::F32 => false,
2226 ElementKind::F32Strict => false,
2228 ElementKind::F64 => false,
2231 _ => false,
2235 }
2236}
2237
2238#[cfg_attr(not(feature = "ozimmu"), allow(unused_variables))]
2251fn validate_ozaki_request(
2252 desc: &GemmDescriptor,
2253 element: ElementKind,
2254 slices: u8,
2255) -> Result<()> {
2256 #[cfg(not(feature = "ozimmu"))]
2257 {
2258 return Err(Error::Unsupported(
2259 "PlanPreference::prefer_backend = Some(Ozaki {..}) requires the \
2260 `ozimmu` cargo feature on baracuda-cutlass (off by default — \
2261 enable on baracuda-kernels too if going through the kernels facade)",
2262 ));
2263 }
2264 #[cfg(feature = "ozimmu")]
2265 {
2266 if element != ElementKind::F64 {
2267 return Err(Error::Unsupported(
2268 "BackendKind::Ozaki is FP64-only (Ozaki-scheme synthesizes \
2269 DGEMM from int8; f16/bf16/f32/F32Strict have no Ozaki path)",
2270 ));
2271 }
2272 if desc.epilogue != EpilogueKind::Identity {
2273 return Err(Error::Unsupported(
2274 "BackendKind::Ozaki only supports the Identity epilogue \
2275 (no fused bias / activation chain on the Ozaki path)",
2276 ));
2277 }
2278 let s = slices & 0x1F; let v = slices >> 5; if s != 0 && !(3..=18).contains(&s) {
2288 return Err(Error::Unsupported(
2289 "BackendKind::Ozaki slice count (low 5 bits) must be 0 \
2290 (auto) or 3..=18",
2291 ));
2292 }
2293 if v > 3 {
2294 return Err(Error::Unsupported(
2295 "BackendKind::Ozaki variant (high 3 bits) must be 0 (Base), \
2296 1 (EF), 2 (RN), or 3 (H)",
2297 ));
2298 }
2299 Ok(())
2300 }
2301}
2302
2303fn cublas_dtype_for(kind: ElementKind) -> Option<baracuda_cublas_sys::functions::cudaDataType_t> {
2308 use baracuda_cublas_sys::functions::cudaDataType_t;
2309 match kind {
2310 ElementKind::F16 => Some(cudaDataType_t::R_16F),
2311 ElementKind::Bf16 => Some(cudaDataType_t::R_16BF),
2312 ElementKind::F32 => Some(cudaDataType_t::R_32F),
2313 ElementKind::F64 => Some(cudaDataType_t::R_64F),
2314 _ => None,
2315 }
2316}
2317
/// A planned single GEMM over element type `T`.
///
/// Built by `GemmPlan::select`, which fixes the descriptor, the kernel SKU and
/// the backend; `can_implement`, `workspace_size` and `run` then operate on
/// that fixed choice.
#[derive(Debug)]
pub struct GemmPlan<T: CutlassElement> {
    /// Copy of the problem descriptor the plan was selected for.
    desc: GemmDescriptor,
    /// Kernel SKU (arch, layout, epilogue, element) resolved at selection time.
    sku: GemmSku,
    /// Backend chosen at selection time.
    backend: BackendChoice,
    /// Ties the plan to `T` without storing a value of that type.
    _element: PhantomData<T>,
}
2347
2348impl<T: CutlassElement> GemmPlan<T> {
2349 pub fn select(stream: &Stream, desc: &GemmDescriptor, pref: PlanPreference) -> Result<Self> {
2361 check_descriptor(desc)?;
2362 let element = T::KIND;
2383
2384 if let Some(BackendKind::Ozaki { slices }) = pref.prefer_backend {
2388 validate_ozaki_request(desc, element, slices)?;
2389 let arch_for_sku = pick_arch(stream, desc, pref)?;
2390 let backend = BackendChoice::Ozaki { slices };
2391 let sku = GemmSku {
2392 arch: arch_for_sku,
2393 layout: desc.layout,
2394 epilogue: desc.epilogue,
2395 element,
2396 bias_element: None,
2397 };
2398 return Ok(Self {
2399 desc: *desc,
2400 sku,
2401 backend,
2402 _element: PhantomData,
2403 });
2404 }
2405
2406 let use_cublas = match pref.prefer_backend {
2407 Some(BackendKind::Cublas) => {
2408 if desc.epilogue.requires_bias() {
2413 return Err(Error::Unsupported(
2414 "cuBLAS backend doesn't fuse bias activations \
2415 (use Cutlass backend for Bias* epilogues)",
2416 ));
2417 }
2418 if cublas_dtype_for(element).is_none() {
2419 return Err(Error::Unsupported(
2420 "cuBLAS backend has no GemmEx dtype for this element \
2421 (F32Strict / integer / FP8 stay on Cutlass)",
2422 ));
2423 }
2424 true
2425 }
2426 Some(BackendKind::Cutlass) => false,
2427 Some(BackendKind::Ozaki { .. }) => {
2428 false
2430 }
2431 Some(_) => {
2432 should_use_cublas_for_fp(desc, element)
2435 && cublas_dtype_for(element).is_some()
2436 }
2437 None => {
2438 should_use_cublas_for_fp(desc, element)
2439 && cublas_dtype_for(element).is_some()
2440 }
2441 };
2442
2443 let (backend, sku_arch) = if use_cublas {
2444 let arch_for_sku = pick_arch(stream, desc, pref)?;
2448 (BackendChoice::Cublas, arch_for_sku)
2449 } else {
2450 let arch = pick_arch(stream, desc, pref)?;
2451 (BackendChoice::Cutlass { arch }, arch)
2452 };
2453
2454 let sku = GemmSku {
2455 arch: sku_arch,
2456 layout: desc.layout,
2457 epilogue: desc.epilogue,
2458 element,
2459 bias_element: None,
2464 };
2465 Ok(Self {
2466 desc: *desc,
2467 sku,
2468 backend,
2469 _element: PhantomData,
2470 })
2471 }
2472
    /// Returns the public `BackendKind` corresponding to the backend this
    /// plan was selected with.
    pub fn backend(&self) -> BackendKind {
        self.backend.as_public()
    }
2481
    /// Checks whether this plan can run with `args`.
    ///
    /// Host-side validation (`check_args`) runs first; then the kernel-side
    /// `can_implement` entry point matching the plan's architecture, epilogue
    /// family (bias / no bias) and scalar type (f64 / other) is queried. The
    /// query is dispatched on `self.sku.arch` and does not look at
    /// `self.backend`.
    ///
    /// # Errors
    ///
    /// * Whatever `check_args` reports.
    /// * `Error::Unsupported` for `Sm90a`, `Sm89`, or `Sm80` without the
    ///   `sm80` feature.
    /// * The kernel status translated by `status_to_result`.
    pub fn can_implement(&self, args: &GemmArgs<'_, T>) -> Result<()> {
        check_args(&self.desc, args)?;

        // Raw pointers for the FFI query.
        let a_ptr = args.a.data.as_raw().0 as *const c_void;
        let b_ptr = args.b.data.as_raw().0 as *const c_void;
        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
        // A missing C operand is passed as a null pointer with ld 0.
        let (c_ptr, ldc) = match &args.c {
            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
            None => (core::ptr::null(), 0i64),
        };
        // A missing bias is passed as a null pointer.
        let bias_ptr = args
            .bias
            .as_ref()
            .map(|b| b.data.as_raw().0 as *const c_void)
            .unwrap_or(core::ptr::null());

        let bias_family = self.sku.epilogue.requires_bias();
        // The f64 arms carry a guard and come before the generic arm for the
        // same (arch, bias_family) pair, so they take priority.
        let status = match (self.sku.arch, bias_family) {
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
                dispatch::gemm_sm80_can_implement_f64(
                    self.sku.layout,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                )
            },
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, false) => unsafe {
                dispatch::gemm_sm80_can_implement(
                    self.sku.layout,
                    T::KIND,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                )
            },
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
                dispatch::gemm_bias_sm80_can_implement_f64(
                    self.sku.layout,
                    self.sku.epilogue,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                    bias_ptr,
                )
            },
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, true) => unsafe {
                dispatch::gemm_bias_sm80_can_implement(
                    self.sku.layout,
                    T::KIND,
                    self.sku.epilogue,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                    bias_ptr,
                )
            },
            // Without the `sm80` feature no sm80 kernels are compiled in.
            #[cfg(not(feature = "sm80"))]
            (ArchSku::Sm80, _) => {
                return Err(Error::Unsupported(
                    "sm80 selected but the `sm80` feature isn't enabled",
                ));
            }
            (ArchSku::Sm90a, _) => {
                return Err(Error::Unsupported(
                    "sm90a kernels not yet shipped (deferred until Hopper hardware available for validation)",
                ));
            }
            (ArchSku::Sm89, _) => {
                return Err(Error::Unsupported(
                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
                ));
            }
        };

        status_to_result(status)
    }
2591
    /// Returns the workspace size the plan's kernel reports for this problem.
    ///
    /// The value is what `run` compares against the borrowed workspace's
    /// `len()`. It is looked up by architecture, epilogue family and scalar
    /// type, independent of `self.backend`. Architectures without kernels in
    /// this build (`Sm90a`, `Sm89`, or `Sm80` without the `sm80` feature)
    /// report 0.
    pub fn workspace_size(&self) -> usize {
        let bias_family = self.sku.epilogue.requires_bias();
        match (self.sku.arch, bias_family) {
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => {
                dispatch::gemm_sm80_workspace_size_f64(
                    self.sku.layout,
                    self.desc.m, self.desc.n, self.desc.k,
                )
            }
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, false) => dispatch::gemm_sm80_workspace_size(
                self.sku.layout,
                T::KIND,
                self.desc.m, self.desc.n, self.desc.k,
            ),
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => {
                dispatch::gemm_bias_sm80_workspace_size_f64(
                    self.sku.layout,
                    self.sku.epilogue,
                    self.desc.m, self.desc.n, self.desc.k,
                )
            }
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, true) => dispatch::gemm_bias_sm80_workspace_size(
                self.sku.layout,
                T::KIND,
                self.sku.epilogue,
                self.desc.m, self.desc.n, self.desc.k,
            ),
            #[cfg(not(feature = "sm80"))]
            (ArchSku::Sm80, _) => 0,
            (ArchSku::Sm90a, _) => 0,
            (ArchSku::Sm89, _) => 0,
        }
    }
2643
    /// Returns a copy of the kernel SKU resolved when the plan was selected.
    pub fn sku(&self) -> GemmSku {
        self.sku
    }
2648
    /// Returns the precision guarantee of the plan's SKU (delegates to
    /// `GemmSku::precision_guarantee`).
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee()
    }
2659
    /// Executes the planned GEMM on `stream`.
    ///
    /// Steps, in order:
    /// 1. `can_implement(&args)` must succeed.
    /// 2. `workspace` must provide at least `workspace_size()`.
    /// 3. If the plan's backend is cuBLAS (or Ozaki, with the `ozimmu`
    ///    feature) and the stream is not capturing, the call is handed to
    ///    `run_cublas` / `run_ozaki`. While capturing, execution falls through
    ///    to the kernel dispatch below.
    /// 4. Otherwise the kernel for the plan's architecture, epilogue family
    ///    and scalar type is launched through the `dispatch` module.
    ///
    /// # Errors
    ///
    /// * Anything `can_implement` reports.
    /// * `Error::WorkspaceTooSmall` when the workspace is missing or too short.
    /// * `Error::Unsupported` for an Ozaki plan without the `ozimmu` feature
    ///   and for architectures without kernels in this build.
    /// * The backend / kernel status otherwise.
    pub fn run(
        &self,
        stream: &Stream,
        workspace: Workspace<'_>,
        args: GemmArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;

        let needed = self.workspace_size();
        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
            // No workspace is only acceptable when none is required.
            Workspace::None => {
                if needed != 0 {
                    return Err(Error::WorkspaceTooSmall {
                        needed,
                        got: 0,
                    });
                }
                (core::ptr::null_mut(), 0)
            }
            // The full borrowed length is forwarded, not just `needed`.
            Workspace::Borrowed(slice) => {
                if slice.len() < needed {
                    return Err(Error::WorkspaceTooSmall {
                        needed,
                        got: slice.len(),
                    });
                }
                (slice.as_raw().0 as *mut c_void, slice.len())
            }
        };

        let a_ptr = args.a.data.as_raw().0 as *const c_void;
        let b_ptr = args.b.data.as_raw().0 as *const c_void;
        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
        // A missing C operand is passed as a null pointer with ld 0.
        let (c_ptr, ldc) = match &args.c {
            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
            None => (core::ptr::null(), 0i64),
        };
        let bias_ptr = args
            .bias
            .as_ref()
            .map(|b| b.data.as_raw().0 as *const c_void)
            .unwrap_or(core::ptr::null());
        // Without a C operand, beta is replaced by the scalar's `Default`
        // value (presumably zero — confirm against the `ScalarType` impls).
        let beta_eff = if args.c.is_some() { args.beta } else { <T::Scalar as Default>::default() };
        let stream_raw = stream.as_raw();

        if matches!(self.backend, BackendChoice::Cublas) {
            // A failed capture query is treated as "not capturing".
            let capturing = stream.is_capturing().unwrap_or(false);
            if !capturing {
                return self.run_cublas(stream, args, beta_eff);
            }
        }

        #[cfg(feature = "ozimmu")]
        if let BackendChoice::Ozaki { slices } = self.backend {
            let capturing = stream.is_capturing().unwrap_or(false);
            if !capturing {
                return self.run_ozaki(stream, args, beta_eff, slices);
            }
        }
        #[cfg(not(feature = "ozimmu"))]
        if matches!(self.backend, BackendChoice::Ozaki { .. }) {
            return Err(Error::Unsupported(
                "BackendChoice::Ozaki selected without `ozimmu` cargo feature",
            ));
        }

        let bias_family = self.sku.epilogue.requires_bias();
        // The f64 arms carry a guard and come first; they pass alpha / beta as
        // f64, whereas the generic arms pass them as f32.
        let status = match (self.sku.arch, bias_family) {
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
                dispatch::gemm_sm80_run_f64(
                    self.sku.layout,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                    args.alpha.to_f64(),
                    beta_eff.to_f64(),
                    ws_ptr, ws_bytes, stream_raw,
                )
            },
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, false) => unsafe {
                dispatch::gemm_sm80_run(
                    self.sku.layout,
                    T::KIND,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                    args.alpha.to_f32(),
                    beta_eff.to_f32(),
                    ws_ptr, ws_bytes, stream_raw,
                )
            },
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
                dispatch::gemm_bias_sm80_run_f64(
                    self.sku.layout,
                    self.sku.epilogue,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                    bias_ptr,
                    args.alpha.to_f64(),
                    beta_eff.to_f64(),
                    ws_ptr, ws_bytes, stream_raw,
                )
            },
            #[cfg(feature = "sm80")]
            (ArchSku::Sm80, true) => unsafe {
                dispatch::gemm_bias_sm80_run(
                    self.sku.layout,
                    T::KIND,
                    self.sku.epilogue,
                    self.desc.m, self.desc.n, self.desc.k,
                    a_ptr, args.a.ld,
                    b_ptr, args.b.ld,
                    c_ptr, ldc,
                    d_ptr, args.d.ld,
                    bias_ptr,
                    args.alpha.to_f32(),
                    beta_eff.to_f32(),
                    ws_ptr, ws_bytes, stream_raw,
                )
            },
            #[cfg(not(feature = "sm80"))]
            (ArchSku::Sm80, _) => {
                return Err(Error::Unsupported(
                    "sm80 selected but the `sm80` feature isn't enabled",
                ));
            }
            (ArchSku::Sm90a, _) => {
                return Err(Error::Unsupported(
                    "sm90a kernels not yet implemented (Phase 4c)",
                ));
            }
            (ArchSku::Sm89, _) => {
                return Err(Error::Unsupported(
                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
                ));
            }
        };

        status_to_result(status)
    }
2861
2862 fn run_cublas(
2871 &self,
2872 stream: &Stream,
2873 args: GemmArgs<'_, T>,
2874 beta_eff: T::Scalar,
2875 ) -> Result<()> {
2876 use baracuda_cublas::Op as CublasOp;
2877 use baracuda_cublas_sys::functions::cublasComputeType_t;
2878
2879 if self.sku.epilogue.requires_bias() {
2882 return Err(Error::Unsupported(
2883 "cuBLAS backend doesn't fuse bias activations \
2884 (caller forced a Bias* epilogue onto the cuBLAS path)",
2885 ));
2886 }
2887
2888 let handle = cublas_backend::handle_for(stream)?;
2889
2890 let m = self.desc.m;
2891 let n = self.desc.n;
2892 let k = self.desc.k;
2893 let a_ptr = args.a.data.as_raw().0 as *const c_void;
2894 let b_ptr = args.b.data.as_raw().0 as *const c_void;
2895 let d_ptr = args.d.data.as_raw().0 as *mut c_void;
2896 let (c_ptr, ldc_arg) = match &args.c {
2902 Some(c) => (c.data.as_raw().0 as *mut c_void, c.ld as i32),
2904 None => (d_ptr, args.d.ld as i32),
2905 };
2906
2907 let (transa, transb) = match self.desc.layout {
2948 LayoutSku::Rcr => (CublasOp::T, CublasOp::N),
2949 LayoutSku::Rrr => (CublasOp::N, CublasOp::N),
2950 };
2951 let cublas_lda = args.b.ld as i32; let cublas_ldb = args.a.ld as i32; let ldd_arg = args.d.ld as i32;
2954
2955 if <T::Scalar as ScalarType>::IS_F64 {
2961 use baracuda_cublas_sys::cublasOperation_t;
2963
2964 let to_raw = |op: CublasOp| match op {
2966 CublasOp::N => cublasOperation_t::N,
2967 CublasOp::T => cublasOperation_t::T,
2968 CublasOp::C => cublasOperation_t::C,
2969 };
2970 let alpha_f64 = args.alpha.to_f64();
2974 let beta_f64 = beta_eff.to_f64();
2975 let c_api = baracuda_cublas_sys::cublas()
2976 .map_err(|_| Error::Unsupported("cuBLAS library unavailable"))?;
2977 let dgemm = c_api
2978 .cublas_dgemm()
2979 .map_err(|_| Error::Unsupported("cublasDgemm symbol unavailable"))?;
2980 let status = unsafe {
2984 dgemm(
2985 handle.as_raw(),
2986 to_raw(transa),
2987 to_raw(transb),
2988 n,
2989 m,
2990 k,
2991 &alpha_f64,
2992 b_ptr as *const f64,
2993 cublas_lda,
2994 a_ptr as *const f64,
2995 cublas_ldb,
2996 &beta_f64,
2997 if args.c.is_some() {
3000 return Err(Error::Unsupported(
3006 "cuBLAS f64 GEMM with explicit C operand is not yet wired \
3007 (D and C alias differently than cuBLAS expects); \
3008 use Cutlass backend or set c = None",
3009 ));
3010 } else {
3011 d_ptr as *mut f64
3012 },
3013 ldd_arg,
3014 )
3015 };
3016 return match status {
3017 baracuda_cublas_sys::cublasStatus_t::SUCCESS => Ok(()),
3018 _ => Err(Error::CutlassInternal(status.0)),
3019 };
3020 }
3021
3022 let dtype = cublas_dtype_for(self.sku.element).ok_or(Error::Unsupported(
3025 "cuBLAS backend selected for element kind without a cuBLAS dtype mapping",
3026 ))?;
3027 let a_type = dtype;
3031 let b_type = dtype;
3032 let c_type = dtype;
3033 let alpha_f32 = args.alpha.to_f32();
3036 let beta_f32 = beta_eff.to_f32();
3037
3038 if args.c.is_some() {
3046 return Err(Error::Unsupported(
3047 "cuBLAS GemmPlan path requires c = None \
3048 (cublasGemmEx writes the output in-place into the C operand; \
3049 explicit-C with D ≠ C requires an extra copy step — \
3050 force Cutlass backend if you need it)",
3051 ));
3052 }
3053 let _ = (c_ptr, ldc_arg); unsafe {
3059 baracuda_cublas::gemm_ex(
3060 &handle,
3061 transa,
3062 transb,
3063 n,
3064 m,
3065 k,
3066 &alpha_f32 as *const f32 as *const c_void,
3067 b_ptr,
3068 b_type,
3069 cublas_lda,
3070 a_ptr,
3071 a_type,
3072 cublas_ldb,
3073 &beta_f32 as *const f32 as *const c_void,
3074 d_ptr,
3075 c_type,
3076 ldd_arg,
3077 cublasComputeType_t::Compute32F,
3078 CUBLAS_GEMM_ALGO,
3079 )
3080 .map_err(|_| Error::CutlassInternal(-1))
3081 }
3082 }
3083
3084 #[cfg(feature = "ozimmu")]
3105 fn run_ozaki(
3106 &self,
3107 stream: &Stream,
3108 args: GemmArgs<'_, T>,
3109 beta_eff: T::Scalar,
3110 slices: u8,
3111 ) -> Result<()> {
3112 use baracuda_ozimmu::{Op as OzakiOp, OzakiSlices, OzakiVariant};
3113
3114 if !<T::Scalar as ScalarType>::IS_F64 {
3115 return Err(Error::Unsupported(
3116 "BackendChoice::Ozaki reached on non-f64 element \
3117 (select() guard should have rejected this)",
3118 ));
3119 }
3120 if args.c.is_some() {
3121 return Err(Error::Unsupported(
3122 "ozIMMU GemmPlan path requires c = None \
3123 (the Ozaki path writes its output in-place into the C \
3124 operand of the underlying cuBLAS GEMM — explicit-C with \
3125 D ≠ C requires an extra copy step that the Phase 44 \
3126 alpha does not yet wire; force Cutlass backend if needed)",
3127 ));
3128 }
3129
3130 let s = slices & 0x1F;
3136 let v = slices >> 5;
3137 let slice_choice = match s {
3138 0 => OzakiSlices::Auto,
3139 3 => OzakiSlices::S3,
3140 4 => OzakiSlices::S4,
3141 5 => OzakiSlices::S5,
3142 6 => OzakiSlices::S6,
3143 7 => OzakiSlices::S7,
3144 8 => OzakiSlices::S8,
3145 9 => OzakiSlices::S9,
3146 10 => OzakiSlices::S10,
3147 11 => OzakiSlices::S11,
3148 12 => OzakiSlices::S12,
3149 13 => OzakiSlices::S13,
3150 14 => OzakiSlices::S14,
3151 15 => OzakiSlices::S15,
3152 16 => OzakiSlices::S16,
3153 17 => OzakiSlices::S17,
3154 18 => OzakiSlices::S18,
3155 _ => {
3156 return Err(Error::Unsupported(
3157 "ozIMMU slice count out of range (validated at select; \
3158 this is unreachable)",
3159 ));
3160 }
3161 };
3162 let variant_choice = match v {
3163 0 => OzakiVariant::Base,
3164 1 => OzakiVariant::EF,
3165 2 => OzakiVariant::RN,
3166 3 => OzakiVariant::H,
3167 _ => {
3168 return Err(Error::Unsupported(
3169 "ozIMMU variant out of range (validated at select; \
3170 this is unreachable)",
3171 ));
3172 }
3173 };
3174
3175 let handle = ozimmu_backend::handle_for(stream)?;
3176
3177 let (transa, transb) = match self.desc.layout {
3186 LayoutSku::Rcr => (OzakiOp::T, OzakiOp::N),
3187 LayoutSku::Rrr => (OzakiOp::N, OzakiOp::N),
3188 };
3189 let m = self.desc.m as usize;
3190 let n = self.desc.n as usize;
3191 let k = self.desc.k as usize;
3192 let lda = args.b.ld as usize; let ldb = args.a.ld as usize; let ldc = args.d.ld as usize;
3195
3196 let a_ptr = args.a.data.as_raw().0 as *const f64;
3197 let b_ptr = args.b.data.as_raw().0 as *const f64;
3198 let d_ptr = args.d.data.as_raw().0 as *mut f64;
3199 let alpha = args.alpha.to_f64();
3200 let beta = beta_eff.to_f64();
3201
3202 unsafe {
3210 handle.dgemm_with_variant(
3211 transa, transb,
3212 n, m, k,
3214 alpha,
3215 b_ptr, lda,
3216 a_ptr, ldb,
3217 beta,
3218 d_ptr, ldc,
3219 slice_choice,
3220 variant_choice,
3221 )
3222 .map_err(|e| {
3223 use baracuda_ozimmu::Error as OzErr;
3224 match e {
3225 OzErr::DgemmFailed(s) => Error::CutlassInternal(s),
3226 _ => Error::Unsupported(
3227 "ozIMMU dgemm rejected the request (see logs)",
3228 ),
3229 }
3230 })
3231 }
3232 }
3233}
3234
/// Per-thread cache of ozIMMU handles, keyed by the raw context pointer of
/// the stream they are requested for.
#[cfg(feature = "ozimmu")]
mod ozimmu_backend {
    use core::cell::RefCell;
    use std::rc::Rc;

    use baracuda_driver::Stream;
    use baracuda_ozimmu::Handle as OzimmuHandle;

    /// How many times `OzimmuHandle::new()` is attempted before giving up.
    const MAX_CREATE_ATTEMPTS: u64 = 5;
    /// Base back-off in milliseconds; the wait after the i-th failed attempt
    /// (1-based) is `i` times this value.
    const RETRY_BACKOFF_MS: u64 = 50;

    thread_local! {
        // `(context key, handle)` pairs created on this thread. The key is
        // the stream's raw context pointer cast to `usize`.
        static HANDLE_CACHE: RefCell<Vec<(usize, Rc<OzimmuHandle>)>> =
            const { RefCell::new(Vec::new()) };
    }

    /// Returns the ozIMMU handle for `stream`'s context with `stream` set on it.
    ///
    /// A handle already cached on this thread for the same context is reused
    /// (shared through `Rc`); otherwise the context is made current and a new
    /// handle is created, retrying with a linearly growing back-off.
    ///
    /// # Errors
    ///
    /// * `Error::Driver` if making the context current fails.
    /// * `Error::Unsupported` if handle creation fails on every attempt.
    pub(super) fn handle_for(stream: &Stream) -> crate::Result<Rc<OzimmuHandle>> {
        let ctx_key = stream.context().as_raw() as usize;
        let handle = HANDLE_CACHE.with(|cache| -> crate::Result<Rc<OzimmuHandle>> {
            let mut cache = cache.borrow_mut();
            if let Some((_, cached)) = cache.iter().find(|(key, _)| *key == ctx_key) {
                return Ok(cached.clone());
            }
            stream
                .context()
                .set_current()
                .map_err(crate::Error::Driver)?;

            let mut created: Option<OzimmuHandle> = None;
            for attempt in 1..=MAX_CREATE_ATTEMPTS {
                match OzimmuHandle::new() {
                    Ok(h) => {
                        created = Some(h);
                        break;
                    }
                    // Back off only when another attempt follows; sleeping
                    // after the final failure would just delay the error.
                    Err(_) if attempt < MAX_CREATE_ATTEMPTS => {
                        std::thread::sleep(std::time::Duration::from_millis(
                            RETRY_BACKOFF_MS * attempt,
                        ));
                    }
                    Err(_) => {}
                }
            }
            let fresh = match created {
                Some(h) => Rc::new(h),
                None => {
                    return Err(crate::Error::Unsupported(
                        "ozIMMU handle creation failed after 5 retries \
                         (library missing, device unavailable, or persistent \
                         init contention)",
                    ));
                }
            };
            cache.push((ctx_key, fresh.clone()));
            Ok(fresh)
        })?;
        // The stream is (re)set on every call, cached handle or not.
        handle.set_stream(stream);
        Ok(handle)
    }
}
3314
3315fn check_batched_descriptor(desc: &BatchedGemmDescriptor) -> Result<()> {
3326 if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
3327 return Err(Error::InvalidProblem("M, N, K must all be positive"));
3328 }
3329 if desc.batch_count <= 0 {
3330 return Err(Error::InvalidProblem("batch_count must be positive"));
3331 }
3332 if desc.epilogue != EpilogueKind::Identity {
3333 return Err(Error::Unsupported(
3334 "BatchedGemmPlan v1 supports only EpilogueKind::Identity",
3335 ));
3336 }
3337 Ok(())
3338}
3339
3340fn check_batched_args<T: CutlassElement>(
3341 desc: &BatchedGemmDescriptor,
3342 args: &BatchedGemmArgs<'_, T>,
3343) -> Result<()> {
3344 if args.a.rows != desc.m || args.a.cols != desc.k {
3348 return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
3349 }
3350 if args.b.rows != desc.k || args.b.cols != desc.n {
3351 return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
3352 }
3353 if args.d.rows != desc.m || args.d.cols != desc.n {
3354 return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
3355 }
3356 if let Some(c) = &args.c {
3357 if c.rows != desc.m || c.cols != desc.n {
3358 return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
3359 }
3360 }
3361 if args.a.ld < desc.k as i64 {
3362 return Err(Error::InvalidProblem("A leading dimension must be >= K"));
3363 }
3364 let b_min_ld = match desc.layout {
3365 LayoutSku::Rcr => desc.k as i64,
3366 LayoutSku::Rrr => desc.n as i64,
3367 };
3368 if args.b.ld < b_min_ld {
3369 return Err(Error::InvalidProblem("B leading dimension too small for layout"));
3370 }
3371 if args.d.ld < desc.n as i64 {
3372 return Err(Error::InvalidProblem("D leading dimension must be >= N"));
3373 }
3374 if let Some(c) = &args.c {
3375 if c.ld < desc.n as i64 {
3376 return Err(Error::InvalidProblem("C leading dimension must be >= N"));
3377 }
3378 }
3379
3380 fn need_for_batches(
3384 per_batch_min: usize,
3385 stride: i64,
3386 batch_count: i32,
3387 ) -> Option<usize> {
3388 if batch_count <= 1 || stride == 0 {
3389 return Some(per_batch_min);
3390 }
3391 let extra = stride.checked_mul((batch_count - 1) as i64)?;
3392 let extra = usize::try_from(extra).ok()?;
3393 per_batch_min.checked_add(extra)
3394 }
3395
3396 let a_per = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
3397 .ok_or(Error::InvalidProblem("A storage size overflow"))?;
3398 let need_a = need_for_batches(a_per, args.stride_a, desc.batch_count)
3399 .ok_or(Error::InvalidProblem("A batched storage size overflow"))?;
3400 if args.a.data.len() < need_a {
3401 return Err(Error::BufferTooSmall {
3402 needed: need_a,
3403 got: args.a.data.len(),
3404 });
3405 }
3406
3407 let b_per = match desc.layout {
3408 LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
3409 LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
3410 }
3411 .ok_or(Error::InvalidProblem("B storage size overflow"))?;
3412 let need_b = need_for_batches(b_per, args.stride_b, desc.batch_count)
3413 .ok_or(Error::InvalidProblem("B batched storage size overflow"))?;
3414 if args.b.data.len() < need_b {
3415 return Err(Error::BufferTooSmall {
3416 needed: need_b,
3417 got: args.b.data.len(),
3418 });
3419 }
3420
3421 let d_per = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
3422 .ok_or(Error::InvalidProblem("D storage size overflow"))?;
3423 let need_d = need_for_batches(d_per, args.stride_d, desc.batch_count)
3424 .ok_or(Error::InvalidProblem("D batched storage size overflow"))?;
3425 if args.d.data.len() < need_d {
3426 return Err(Error::BufferTooSmall {
3427 needed: need_d,
3428 got: args.d.data.len(),
3429 });
3430 }
3431
3432 if let Some(c) = &args.c {
3433 let c_per = min_elements_row_major(c.rows, c.cols, c.ld)
3434 .ok_or(Error::InvalidProblem("C storage size overflow"))?;
3435 let need_c = need_for_batches(c_per, args.stride_c, desc.batch_count)
3436 .ok_or(Error::InvalidProblem("C batched storage size overflow"))?;
3437 if c.data.len() < need_c {
3438 return Err(Error::BufferTooSmall {
3439 needed: need_c,
3440 got: c.data.len(),
3441 });
3442 }
3443 }
3444 Ok(())
3445}
3446
/// A planned strided-batched GEMM over element type `T`.
///
/// Built by `BatchedGemmPlan::select`; `can_implement`, `workspace_size` and
/// `run` then operate on the fixed descriptor and SKU.
#[derive(Debug)]
pub struct BatchedGemmPlan<T: CutlassElement> {
    /// Copy of the batched problem descriptor the plan was selected for.
    desc: BatchedGemmDescriptor,
    /// Kernel SKU (arch, layout, epilogue, element) resolved at selection time.
    sku: GemmSku,
    /// Ties the plan to `T` without storing a value of that type.
    _element: PhantomData<T>,
}
3459
3460impl<T: CutlassElement> BatchedGemmPlan<T> {
3461 pub fn select(
3467 stream: &Stream,
3468 desc: &BatchedGemmDescriptor,
3469 pref: PlanPreference,
3470 ) -> Result<Self> {
3471 check_batched_descriptor(desc)?;
3472 let one_off_desc = GemmDescriptor {
3473 m: desc.m,
3474 n: desc.n,
3475 k: desc.k,
3476 layout: desc.layout,
3477 epilogue: desc.epilogue,
3478 };
3479 let arch = pick_arch(stream, &one_off_desc, pref)?;
3480 match (desc.layout, T::KIND) {
3482 (LayoutSku::Rcr, ElementKind::F16) | (LayoutSku::Rcr, ElementKind::Bf16) => {}
3483 _ => {
3484 return Err(Error::Unsupported(
3485 "BatchedGemmPlan v1 only ships Rcr × {F16, Bf16} on sm_80",
3486 ));
3487 }
3488 }
3489 let sku = GemmSku {
3490 arch,
3491 layout: desc.layout,
3492 epilogue: desc.epilogue,
3493 element: T::KIND,
3494 bias_element: None,
3499 };
3500 Ok(Self {
3501 desc: *desc,
3502 sku,
3503 _element: PhantomData,
3504 })
3505 }
3506
    /// Checks whether this plan can run with `args`.
    ///
    /// Host-side validation (`check_batched_args`) runs first; then the
    /// kernel-side `can_implement` entry point for the plan's architecture is
    /// queried.
    ///
    /// # Errors
    ///
    /// * Whatever `check_batched_args` reports.
    /// * `Error::Unsupported` for `Sm90a`, `Sm89`, or `Sm80` without the
    ///   `sm80` feature.
    /// * The kernel status translated by `status_to_result`.
    pub fn can_implement(&self, args: &BatchedGemmArgs<'_, T>) -> Result<()> {
        check_batched_args(&self.desc, args)?;

        let a_ptr = args.a.data.as_raw().0 as *const c_void;
        let b_ptr = args.b.data.as_raw().0 as *const c_void;
        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
        // A missing C operand is passed as a null pointer with ld and stride 0.
        let (c_ptr, ldc, stride_c) = match &args.c {
            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld, args.stride_c),
            None => (core::ptr::null(), 0i64, 0i64),
        };

        let status = match self.sku.arch {
            #[cfg(feature = "sm80")]
            ArchSku::Sm80 => unsafe {
                dispatch::batched_gemm_sm80_can_implement(
                    self.sku.layout,
                    T::KIND,
                    self.desc.m,
                    self.desc.n,
                    self.desc.k,
                    a_ptr,
                    args.a.ld,
                    args.stride_a,
                    b_ptr,
                    args.b.ld,
                    args.stride_b,
                    c_ptr,
                    ldc,
                    stride_c,
                    d_ptr,
                    args.d.ld,
                    args.stride_d,
                    self.desc.batch_count,
                )
            },
            // Without the `sm80` feature no sm80 kernels are compiled in.
            #[cfg(not(feature = "sm80"))]
            ArchSku::Sm80 => {
                return Err(Error::Unsupported(
                    "sm80 selected but the `sm80` feature isn't enabled",
                ));
            }
            ArchSku::Sm90a => {
                return Err(Error::Unsupported(
                    "sm90a batched kernels not yet shipped",
                ));
            }
            ArchSku::Sm89 => {
                return Err(Error::Unsupported(
                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
                ));
            }
        };

        status_to_result(status)
    }
3564
    /// Returns the workspace size the plan's batched kernel reports for this
    /// problem.
    ///
    /// The value is what `run` compares against the borrowed workspace's
    /// `len()`. Architectures without batched kernels in this build
    /// (`Sm90a`, `Sm89`, or `Sm80` without the `sm80` feature) report 0.
    pub fn workspace_size(&self) -> usize {
        match self.sku.arch {
            #[cfg(feature = "sm80")]
            ArchSku::Sm80 => dispatch::batched_gemm_sm80_workspace_size(
                self.sku.layout,
                T::KIND,
                self.desc.m,
                self.desc.n,
                self.desc.k,
                self.desc.batch_count,
            ),
            #[cfg(not(feature = "sm80"))]
            ArchSku::Sm80 => 0,
            ArchSku::Sm90a => 0,
            ArchSku::Sm89 => 0,
        }
    }
3583
    /// Returns a copy of the kernel SKU resolved when the plan was selected.
    pub fn sku(&self) -> GemmSku {
        self.sku
    }
3588
    /// Returns the precision guarantee of the plan's SKU (delegates to
    /// `GemmSku::precision_guarantee`).
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee()
    }
3593
    /// Executes the planned batched GEMM on `stream`.
    ///
    /// Steps, in order: `can_implement(&args)` must succeed, `workspace` must
    /// provide at least `workspace_size()`, then the batched kernel for the
    /// plan's architecture is launched through the `dispatch` module.
    ///
    /// # Errors
    ///
    /// * Anything `can_implement` reports.
    /// * `Error::WorkspaceTooSmall` when the workspace is missing or too short.
    /// * `Error::Unsupported` for architectures without batched kernels in
    ///   this build.
    /// * The kernel status translated by `status_to_result`.
    pub fn run(
        &self,
        stream: &Stream,
        workspace: Workspace<'_>,
        args: BatchedGemmArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;

        let needed = self.workspace_size();
        let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
            // No workspace is only acceptable when none is required.
            Workspace::None => {
                if needed != 0 {
                    return Err(Error::WorkspaceTooSmall { needed, got: 0 });
                }
                (core::ptr::null_mut(), 0)
            }
            // The full borrowed length is forwarded, not just `needed`.
            Workspace::Borrowed(slice) => {
                if slice.len() < needed {
                    return Err(Error::WorkspaceTooSmall {
                        needed,
                        got: slice.len(),
                    });
                }
                (slice.as_raw().0 as *mut c_void, slice.len())
            }
        };

        let a_ptr = args.a.data.as_raw().0 as *const c_void;
        let b_ptr = args.b.data.as_raw().0 as *const c_void;
        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
        // A missing C operand is passed as a null pointer with ld and stride 0.
        let (c_ptr, ldc, stride_c) = match &args.c {
            Some(c) => (c.data.as_raw().0 as *const c_void, c.ld, args.stride_c),
            None => (core::ptr::null(), 0i64, 0i64),
        };
        // Without a C operand, beta is replaced by the scalar's `Default`
        // value (presumably zero — confirm against the `ScalarType` impls).
        let beta_eff = if args.c.is_some() { args.beta } else { <T::Scalar as Default>::default() };
        let stream_raw = stream.as_raw();

        let status = match self.sku.arch {
            #[cfg(feature = "sm80")]
            ArchSku::Sm80 => unsafe {
                dispatch::batched_gemm_sm80_run(
                    self.sku.layout,
                    T::KIND,
                    self.desc.m,
                    self.desc.n,
                    self.desc.k,
                    a_ptr,
                    args.a.ld,
                    args.stride_a,
                    b_ptr,
                    args.b.ld,
                    args.stride_b,
                    c_ptr,
                    ldc,
                    stride_c,
                    d_ptr,
                    args.d.ld,
                    args.stride_d,
                    args.alpha.to_f32(),
                    beta_eff.to_f32(),
                    self.desc.batch_count,
                    ws_ptr,
                    ws_bytes,
                    stream_raw,
                )
            },
            #[cfg(not(feature = "sm80"))]
            ArchSku::Sm80 => {
                return Err(Error::Unsupported(
                    "sm80 selected but the `sm80` feature isn't enabled",
                ));
            }
            ArchSku::Sm90a => {
                return Err(Error::Unsupported(
                    "sm90a batched kernels not yet shipped",
                ));
            }
            ArchSku::Sm89 => {
                return Err(Error::Unsupported(
                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
                ));
            }
        };

        status_to_result(status)
    }
3684}
3685
3686fn pick_arch(
3687 stream: &Stream,
3688 _desc: &GemmDescriptor,
3689 pref: PlanPreference,
3690) -> Result<ArchSku> {
3691 let (major, _minor) = stream.context().device().compute_capability()?;
3703
3704 if pref.allow_sm90a && cfg!(feature = "sm90a") && major >= 9 {
3705 return Ok(ArchSku::Sm90a);
3706 }
3707
3708 if cfg!(feature = "sm80") {
3709 if major >= 8 {
3711 return Ok(ArchSku::Sm80);
3712 }
3713 return Err(Error::Unsupported(
3714 "device compute capability < 8.0; sm_80 kernels won't run here",
3715 ));
3716 }
3717
3718 Err(Error::Unsupported(
3719 "no arch features enabled — build with --features sm80",
3720 ))
3721}
3722
/// Bytes per packed problem-size entry: three `i32` values (M, N, K).
const COORD_BYTES: usize = 12;
/// Bytes per packed device pointer entry (written as a `u64`).
const PTR_BYTES: usize = 8;
/// Bytes per packed leading-dimension entry (written as an `i64`).
const LD_BYTES: usize = 8;
/// Alignment applied to the scratch region that follows the metadata tables.
const SCRATCH_ALIGN: usize = 256;

/// Rounds `x` up to the next multiple of `align`.
///
/// Uses mask arithmetic, so the result is only a true multiple when `align`
/// is a power of two.
#[inline]
fn align_up(x: usize, align: usize) -> usize {
    let bumped = x + align - 1;
    let low_bits = align - 1;
    bumped & !low_bits
}
3760
/// Byte offsets of each grouped-GEMM metadata table inside the workspace,
/// as produced by `MetadataLayout::compute`.
#[derive(Copy, Clone, Debug)]
struct MetadataLayout {
    /// Start of the packed problem-size entries (always 0).
    problem_sizes_offset: usize,
    /// Start of the per-group A pointer table.
    ptr_a_offset: usize,
    /// Start of the per-group B pointer table.
    ptr_b_offset: usize,
    /// Start of the per-group C pointer table.
    ptr_c_offset: usize,
    /// Start of the per-group D pointer table.
    ptr_d_offset: usize,
    /// Start of the per-group A leading-dimension table.
    lda_offset: usize,
    /// Start of the per-group B leading-dimension table.
    ldb_offset: usize,
    /// Start of the per-group C leading-dimension table.
    ldc_offset: usize,
    /// Start of the per-group D leading-dimension table.
    ldd_offset: usize,
    /// One past the last metadata byte (end of the ldd table).
    metadata_end: usize,
    /// `metadata_end` rounded up to `SCRATCH_ALIGN`; where kernel scratch begins.
    scratch_offset: usize,
    /// `scratch_offset` plus the requested scratch bytes.
    total_workspace_bytes: usize,
}
3780
3781impl MetadataLayout {
3782 fn compute(group_count: usize, scratch_bytes: usize) -> Self {
3783 let mut off = 0usize;
3784 let problem_sizes_offset = off;
3785 off += COORD_BYTES * group_count;
3786 off = align_up(off, 8);
3787
3788 let ptr_a_offset = off;
3789 off += PTR_BYTES * group_count;
3790 let ptr_b_offset = off;
3791 off += PTR_BYTES * group_count;
3792 let ptr_c_offset = off;
3793 off += PTR_BYTES * group_count;
3794 let ptr_d_offset = off;
3795 off += PTR_BYTES * group_count;
3796 let lda_offset = off;
3797 off += LD_BYTES * group_count;
3798 let ldb_offset = off;
3799 off += LD_BYTES * group_count;
3800 let ldc_offset = off;
3801 off += LD_BYTES * group_count;
3802 let ldd_offset = off;
3803 off += LD_BYTES * group_count;
3804 let metadata_end = off;
3805
3806 let scratch_offset = align_up(metadata_end, SCRATCH_ALIGN);
3807 let total_workspace_bytes = scratch_offset + scratch_bytes;
3808
3809 Self {
3810 problem_sizes_offset,
3811 ptr_a_offset,
3812 ptr_b_offset,
3813 ptr_c_offset,
3814 ptr_d_offset,
3815 lda_offset,
3816 ldb_offset,
3817 ldc_offset,
3818 ldd_offset,
3819 metadata_end,
3820 scratch_offset,
3821 total_workspace_bytes,
3822 }
3823 }
3824}
3825
/// A resolved grouped-GEMM plan for element type `T`.
///
/// Built by `GroupedGemmPlan::select`; `prepare` turns it plus a list of
/// groups into a launchable `PreparedGroupedGemm`.
#[derive(Debug)]
pub struct GroupedGemmPlan<T: CutlassElement> {
    /// Kernel SKU chosen at selection time.
    sku: GemmSku,
    /// Scheduling mode copied from the caller's preference.
    schedule: GroupedScheduleMode,
    /// Clone of the selecting stream's context; used to allocate the pinned staging buffer.
    context: Context,
    /// Ties the plan to its element type without storing a value.
    _element: PhantomData<T>,
}
3843
3844impl<T: CutlassElement> GroupedGemmPlan<T> {
3845 pub fn select(
3850 stream: &Stream,
3851 epilogue: EpilogueKind,
3852 pref: GroupedPlanPreference,
3853 ) -> Result<Self> {
3854 if epilogue != EpilogueKind::Identity {
3855 return Err(Error::Unsupported(
3856 "v0 grouped GEMM supports only EpilogueKind::Identity",
3857 ));
3858 }
3859
3860 let dummy_desc = GemmDescriptor {
3861 m: 1,
3862 n: 1,
3863 k: 1,
3864 layout: LayoutSku::Rcr,
3865 epilogue,
3866 };
3867 let arch = pick_arch(stream, &dummy_desc, pref.base)?;
3868 let sku = GemmSku {
3869 arch,
3870 layout: LayoutSku::Rcr,
3871 epilogue,
3872 element: T::KIND,
3873 bias_element: None,
3874 };
3875 Ok(Self {
3876 sku,
3877 schedule: pref.schedule,
3878 context: stream.context().clone(),
3879 _element: PhantomData,
3880 })
3881 }
3882
3883 pub fn sku(&self) -> GemmSku {
3885 self.sku
3886 }
3887
3888 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
3892 self.sku.precision_guarantee()
3893 }
3894
3895 pub fn schedule(&self) -> GroupedScheduleMode {
3897 self.schedule
3898 }
3899
3900 pub fn prepare<'a, 'g>(
3912 &'a self,
3913 groups: &'g [GroupedProblem<'g, T>],
3914 ) -> Result<PreparedGroupedGemm<'a, T>> {
3915 if groups.is_empty() {
3916 return Err(Error::InvalidProblem("grouped GEMM requires at least one group"));
3917 }
3918
3919 let first_alpha = groups[0].alpha;
3924 let first_beta = groups[0].beta;
3925 let first_has_c = groups[0].c.is_some();
3926 for g in groups {
3927 if g.m <= 0 || g.n <= 0 || g.k <= 0 {
3928 return Err(Error::InvalidProblem("group M, N, K must all be positive"));
3929 }
3930 if g.a.rows != g.m || g.a.cols != g.k {
3931 return Err(Error::InvalidProblem("group A shape doesn't match (M, K)"));
3932 }
3933 if g.b.rows != g.k || g.b.cols != g.n {
3934 return Err(Error::InvalidProblem("group B shape doesn't match (K, N)"));
3935 }
3936 if g.d.rows != g.m || g.d.cols != g.n {
3937 return Err(Error::InvalidProblem("group D shape doesn't match (M, N)"));
3938 }
3939 if let Some(c) = &g.c {
3940 if c.rows != g.m || c.cols != g.n {
3941 return Err(Error::InvalidProblem("group C shape doesn't match (M, N)"));
3942 }
3943 }
3944 if g.a.ld < g.k as i64 || g.b.ld < g.k as i64 || g.d.ld < g.n as i64 {
3945 return Err(Error::InvalidProblem("group leading dimension too small"));
3946 }
3947 if g.alpha != first_alpha {
3948 return Err(Error::Unsupported(
3949 "v0 grouped GEMM requires all groups to share alpha",
3950 ));
3951 }
3952 if g.beta != first_beta {
3953 return Err(Error::Unsupported(
3954 "v0 grouped GEMM requires all groups to share beta",
3955 ));
3956 }
3957 if g.c.is_some() != first_has_c {
3958 return Err(Error::Unsupported(
3959 "v0 grouped GEMM requires all groups to consistently have c=None or c=Some",
3960 ));
3961 }
3962 }
3963
3964 let group_count = groups.len();
3967 let mut h_m: Vec<i32> = Vec::with_capacity(group_count);
3968 let mut h_n: Vec<i32> = Vec::with_capacity(group_count);
3969 let mut h_k: Vec<i32> = Vec::with_capacity(group_count);
3970 for g in groups {
3971 h_m.push(g.m);
3972 h_n.push(g.n);
3973 h_k.push(g.k);
3974 }
3975
3976 let kind = T::KIND;
3977 let group_count_i32 = group_count as i32;
3978
3979 let ci_status = match self.sku.arch {
3981 #[cfg(feature = "sm80")]
3982 ArchSku::Sm80 => unsafe {
3983 dispatch::grouped_gemm_rcr_sm80_can_implement(
3984 kind,
3985 h_m.as_ptr(),
3986 h_n.as_ptr(),
3987 h_k.as_ptr(),
3988 group_count_i32,
3989 )
3990 },
3991 #[cfg(not(feature = "sm80"))]
3992 ArchSku::Sm80 => {
3993 return Err(Error::Unsupported(
3994 "sm80 selected but the `sm80` feature isn't enabled",
3995 ));
3996 }
3997 ArchSku::Sm90a => {
3998 return Err(Error::Unsupported(
3999 "sm90a grouped kernels not yet shipped (deferred until Hopper hardware available)",
4000 ));
4001 }
4002 ArchSku::Sm89 => {
4003 return Err(Error::Unsupported(
4004 "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4005 ));
4006 }
4007 };
4008 status_to_result(ci_status)?;
4009
4010 let threadblock_count = match self.sku.arch {
4012 #[cfg(feature = "sm80")]
4013 ArchSku::Sm80 => unsafe {
4014 dispatch::grouped_gemm_rcr_sm80_sufficient(
4015 kind,
4016 h_m.as_ptr(),
4017 h_n.as_ptr(),
4018 h_k.as_ptr(),
4019 group_count_i32,
4020 )
4021 },
4022 #[cfg(not(feature = "sm80"))]
4023 ArchSku::Sm80 => 0,
4024 ArchSku::Sm90a => 0,
4025 ArchSku::Sm89 => 0,
4026 };
4027 if threadblock_count <= 0 {
4028 return Err(Error::CutlassInternal(threadblock_count));
4029 }
4030
4031 let scratch_bytes = match self.sku.arch {
4032 #[cfg(feature = "sm80")]
4033 ArchSku::Sm80 => unsafe {
4034 dispatch::grouped_gemm_rcr_sm80_scratch_bytes(
4035 kind,
4036 h_m.as_ptr(),
4037 h_n.as_ptr(),
4038 h_k.as_ptr(),
4039 group_count_i32,
4040 threadblock_count,
4041 )
4042 },
4043 #[cfg(not(feature = "sm80"))]
4044 ArchSku::Sm80 => 0,
4045 ArchSku::Sm90a => 0,
4046 ArchSku::Sm89 => 0,
4047 };
4048
4049 let layout = MetadataLayout::compute(group_count, scratch_bytes);
4050
4051 let mut pinned: PinnedBuffer<u8> = PinnedBuffer::new(&self.context, layout.metadata_end)?;
4056
4057 let ptr_a: Vec<u64> = groups.iter().map(|g| g.a.data.as_raw().0).collect();
4061 let ptr_b: Vec<u64> = groups.iter().map(|g| g.b.data.as_raw().0).collect();
4062 let ptr_d: Vec<u64> = groups.iter().map(|g| g.d.data.as_raw().0).collect();
4063 let ptr_c: Vec<u64> = groups
4067 .iter()
4068 .map(|g| {
4069 g.c.as_ref()
4070 .map(|c| c.data.as_raw().0)
4071 .unwrap_or_else(|| g.d.data.as_raw().0)
4072 })
4073 .collect();
4074 let lda: Vec<i64> = groups.iter().map(|g| g.a.ld).collect();
4075 let ldb: Vec<i64> = groups.iter().map(|g| g.b.ld).collect();
4076 let ldd: Vec<i64> = groups.iter().map(|g| g.d.ld).collect();
4077 let ldc: Vec<i64> = groups
4078 .iter()
4079 .map(|g| g.c.as_ref().map(|c| c.ld).unwrap_or(g.d.ld))
4080 .collect();
4081
4082 {
4085 let host_packed: &mut [u8] = &mut pinned;
4086
4087 let mut p = layout.problem_sizes_offset;
4088 for g in groups {
4089 host_packed[p..p + 4].copy_from_slice(&g.m.to_ne_bytes());
4090 host_packed[p + 4..p + 8].copy_from_slice(&g.n.to_ne_bytes());
4091 host_packed[p + 8..p + 12].copy_from_slice(&g.k.to_ne_bytes());
4092 p += COORD_BYTES;
4093 }
4094
4095 let pack_ptrs = |dst: &mut [u8], offset: usize, ptrs: &[u64]| {
4096 let mut p = offset;
4097 for &val in ptrs {
4098 dst[p..p + 8].copy_from_slice(&val.to_ne_bytes());
4099 p += PTR_BYTES;
4100 }
4101 };
4102 pack_ptrs(host_packed, layout.ptr_a_offset, &ptr_a);
4103 pack_ptrs(host_packed, layout.ptr_b_offset, &ptr_b);
4104 pack_ptrs(host_packed, layout.ptr_c_offset, &ptr_c);
4105 pack_ptrs(host_packed, layout.ptr_d_offset, &ptr_d);
4106
4107 let pack_lds = |dst: &mut [u8], offset: usize, lds: &[i64]| {
4108 let mut p = offset;
4109 for &val in lds {
4110 dst[p..p + 8].copy_from_slice(&val.to_ne_bytes());
4111 p += LD_BYTES;
4112 }
4113 };
4114 pack_lds(host_packed, layout.lda_offset, &lda);
4115 pack_lds(host_packed, layout.ldb_offset, &ldb);
4116 pack_lds(host_packed, layout.ldc_offset, &ldc);
4117 pack_lds(host_packed, layout.ldd_offset, &ldd);
4118 }
4119
4120 let mut host_problem_sizes: Vec<i32> = Vec::with_capacity(group_count * 3);
4125 for g in groups {
4126 host_problem_sizes.push(g.m);
4127 host_problem_sizes.push(g.n);
4128 host_problem_sizes.push(g.k);
4129 }
4130
4131 let beta_eff = if first_has_c { first_beta } else { <T::Scalar as Default>::default() };
4132
4133 Ok(PreparedGroupedGemm {
4137 plan: self,
4138 pinned,
4139 host_problem_sizes,
4140 layout,
4141 threadblock_count,
4142 alpha: first_alpha.to_f32(),
4143 beta: beta_eff.to_f32(),
4144 _element: PhantomData,
4145 })
4146 }
4147}
4148
/// A grouped GEMM whose per-group metadata has been validated and staged,
/// ready to be launched with `run`.
///
/// Produced by `GroupedGemmPlan::prepare`; borrows the plan for `'a`.
#[derive(Debug)]
pub struct PreparedGroupedGemm<'a, T: CutlassElement> {
    /// The plan this preparation was derived from.
    plan: &'a GroupedGemmPlan<T>,
    /// Host staging buffer holding the packed metadata tables; copied into the workspace by `run`.
    pinned: PinnedBuffer<u8>,
    /// Flat host copy of the problem sizes: `[m, n, k]` per group.
    host_problem_sizes: Vec<i32>,
    /// Byte offsets of the metadata tables and scratch region within the workspace.
    layout: MetadataLayout,
    /// Threadblock count reported by the kernel during preparation.
    threadblock_count: i32,
    /// Alpha shared by all groups, converted to `f32`.
    alpha: f32,
    /// Effective beta, converted to `f32`.
    beta: f32,
    /// Ties the preparation to its element type without storing a value.
    _element: PhantomData<T>,
}
4186
4187impl<'a, T: CutlassElement> PreparedGroupedGemm<'a, T> {
4188 pub fn workspace_size(&self) -> usize {
4193 self.layout.total_workspace_bytes
4194 }
4195
4196 pub fn sku(&self) -> GemmSku {
4199 self.plan.sku
4200 }
4201
4202 pub fn group_count(&self) -> usize {
4204 self.host_problem_sizes.len() / 3
4205 }
4206
4207 pub fn run(&self, stream: &Stream, workspace: Workspace<'_>) -> Result<()> {
4213 let needed = self.workspace_size();
4214 let workspace_slice = match workspace {
4215 Workspace::None => {
4216 return Err(Error::WorkspaceTooSmall { needed, got: 0 });
4217 }
4218 Workspace::Borrowed(slice) => {
4219 if slice.len() < needed {
4220 return Err(Error::WorkspaceTooSmall {
4221 needed,
4222 got: slice.len(),
4223 });
4224 }
4225 slice
4226 }
4227 };
4228
4229 let workspace_base = workspace_slice.as_raw().0;
4230
4231 {
4237 let mut workspace_for_meta = workspace_slice;
4238 let metadata_dst = workspace_for_meta.slice_mut(0..self.layout.metadata_end);
4239 metadata_dst.copy_from_host_async(&self.pinned, stream)?;
4240 }
4241
4242 let off = |o: usize| (workspace_base + o as u64) as *const c_void;
4245 let off_mut = |o: usize| (workspace_base + o as u64) as *mut c_void;
4246 let d_problem_sizes = off(self.layout.problem_sizes_offset);
4247 let d_ptr_a = off(self.layout.ptr_a_offset);
4248 let d_ptr_b = off(self.layout.ptr_b_offset);
4249 let d_ptr_c = off(self.layout.ptr_c_offset);
4250 let d_ptr_d = off_mut(self.layout.ptr_d_offset);
4251 let d_lda = off(self.layout.lda_offset);
4252 let d_ldb = off(self.layout.ldb_offset);
4253 let d_ldc = off(self.layout.ldc_offset);
4254 let d_ldd = off(self.layout.ldd_offset);
4255 let scratch_ptr = off_mut(self.layout.scratch_offset);
4256 let scratch_bytes = self.layout.total_workspace_bytes - self.layout.scratch_offset;
4257
4258 let h_problem_sizes = self.host_problem_sizes.as_ptr() as *const c_void;
4259 let stream_raw = stream.as_raw();
4260 let group_count = self.group_count() as i32;
4261
4262 let status = match self.plan.sku.arch {
4263 #[cfg(feature = "sm80")]
4264 ArchSku::Sm80 => unsafe {
4265 dispatch::grouped_gemm_rcr_sm80_run(
4266 T::KIND,
4267 group_count,
4268 self.threadblock_count,
4269 d_problem_sizes,
4270 d_ptr_a,
4271 d_ptr_b,
4272 d_ptr_c,
4273 d_ptr_d,
4274 d_lda,
4275 d_ldb,
4276 d_ldc,
4277 d_ldd,
4278 h_problem_sizes,
4279 self.alpha,
4280 self.beta,
4281 scratch_ptr,
4282 scratch_bytes,
4283 stream_raw,
4284 )
4285 },
4286 #[cfg(not(feature = "sm80"))]
4287 ArchSku::Sm80 => {
4288 return Err(Error::Unsupported(
4289 "sm80 selected but the `sm80` feature isn't enabled",
4290 ));
4291 }
4292 ArchSku::Sm90a => {
4293 return Err(Error::Unsupported(
4294 "sm90a grouped kernels not yet shipped",
4295 ));
4296 }
4297 ArchSku::Sm89 => {
4298 return Err(Error::Unsupported(
4299 "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4300 ));
4301 }
4302 };
4303
4304 status_to_result(status)
4305 }
4306}
4307
/// A resolved integer GEMM plan for element type `T` with bias element type
/// `BT` (defaulting to `f32`).
#[derive(Debug)]
pub struct IntGemmPlan<T: IntElement, BT: BiasElement = f32> {
    /// Copy of the descriptor the plan was selected for.
    desc: IntGemmDescriptor,
    /// Kernel SKU chosen at selection time.
    sku: GemmSku,
    /// Ties the plan to its element type without storing a value.
    _element: PhantomData<T>,
    /// Ties the plan to its bias element type without storing a value.
    _bias_element: PhantomData<BT>,
}
4346
4347impl<T: IntElement, BT: BiasElement> IntGemmPlan<T, BT> {
4348 pub fn select(
4355 stream: &Stream,
4356 desc: &IntGemmDescriptor,
4357 pref: PlanPreference,
4358 ) -> Result<Self> {
4359 check_int_descriptor(desc)?;
4360 let arch = pick_int_arch(stream, pref)?;
4361 if !matches!(desc.layout, LayoutSku::Rcr) {
4365 return Err(Error::Unsupported(
4366 "int8 GEMM kernels are RCR-only in this release \
4367 (CUTLASS 4.2.0 lacks 8-bit `TensorOpMultiplicandCongruous` \
4368 warp iterators for RRR / row-major-B layout)",
4369 ));
4370 }
4371 let bias_element = if desc.epilogue.requires_bias() {
4375 Some(BT::KIND)
4376 } else {
4377 None
4378 };
4379 let sku = GemmSku {
4380 arch,
4381 layout: desc.layout,
4382 epilogue: desc.epilogue,
4383 element: T::KIND,
4384 bias_element,
4385 };
4386 Ok(Self {
4387 desc: *desc,
4388 sku,
4389 _element: PhantomData,
4390 _bias_element: PhantomData,
4391 })
4392 }
4393
4394 pub fn can_implement(&self, args: &IntGemmArgs<'_, T, BT>) -> Result<()> {
4398 check_int_args(&self.desc, args)?;
4399
4400 let a_ptr = args.a.data.as_raw().0 as *const c_void;
4401 let b_ptr = args.b.data.as_raw().0 as *const c_void;
4402 let d_ptr = args.d.data.as_raw().0 as *mut c_void;
4403 let (c_ptr, ldc) = match &args.c {
4404 Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
4405 None => (core::ptr::null(), 0i64),
4406 };
4407 let bias_ptr = args
4408 .bias
4409 .as_ref()
4410 .map(|b| b.data.as_raw().0 as *const c_void)
4411 .unwrap_or(core::ptr::null());
4412
4413 let bias_family = self.sku.epilogue.requires_bias();
4414 let status = match (self.sku.arch, bias_family) {
4415 #[cfg(feature = "sm80")]
4416 (ArchSku::Sm80, false) => unsafe {
4417 dispatch::int_gemm_rcr_sm80_can_implement(
4418 self.sku.layout,
4419 T::KIND,
4420 self.desc.m, self.desc.n, self.desc.k,
4421 a_ptr, args.a.ld,
4422 b_ptr, args.b.ld,
4423 c_ptr, ldc,
4424 d_ptr, args.d.ld,
4425 )
4426 },
4427 #[cfg(feature = "sm80")]
4428 (ArchSku::Sm80, true) => unsafe {
4429 dispatch::int_gemm_bias_rcr_sm80_can_implement(
4430 self.sku.layout,
4431 T::KIND,
4432 self.sku.epilogue,
4433 BT::KIND,
4434 self.desc.m, self.desc.n, self.desc.k,
4435 a_ptr, args.a.ld,
4436 b_ptr, args.b.ld,
4437 c_ptr, ldc,
4438 d_ptr, args.d.ld,
4439 bias_ptr,
4440 )
4441 },
4442 #[cfg(not(feature = "sm80"))]
4443 (ArchSku::Sm80, _) => {
4444 return Err(Error::Unsupported(
4445 "sm80 selected but the `sm80` feature isn't enabled",
4446 ));
4447 }
4448 (ArchSku::Sm90a, _) => {
4449 return Err(Error::Unsupported(
4450 "sm90a int8 kernels not yet shipped",
4451 ));
4452 }
4453 (ArchSku::Sm89, _) => {
4454 return Err(Error::Unsupported(
4455 "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4456 ));
4457 }
4458 };
4459
4460 status_to_result(status)
4461 }
4462
4463 pub fn workspace_size(&self) -> usize {
4465 let bias_family = self.sku.epilogue.requires_bias();
4466 match (self.sku.arch, bias_family) {
4467 #[cfg(feature = "sm80")]
4468 (ArchSku::Sm80, false) => dispatch::int_gemm_rcr_sm80_workspace_size(
4469 self.sku.layout,
4470 T::KIND,
4471 self.desc.m, self.desc.n, self.desc.k,
4472 ),
4473 #[cfg(feature = "sm80")]
4474 (ArchSku::Sm80, true) => dispatch::int_gemm_bias_rcr_sm80_workspace_size(
4475 self.sku.layout,
4476 T::KIND,
4477 self.sku.epilogue,
4478 BT::KIND,
4479 self.desc.m, self.desc.n, self.desc.k,
4480 ),
4481 #[cfg(not(feature = "sm80"))]
4482 (ArchSku::Sm80, _) => 0,
4483 (ArchSku::Sm90a, _) => 0,
4484 (ArchSku::Sm89, _) => 0,
4485 }
4486 }
4487
4488 pub fn sku(&self) -> GemmSku {
4490 self.sku
4491 }
4492
4493 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
4495 self.sku.precision_guarantee()
4496 }
4497
4498 pub fn run(
4500 &self,
4501 stream: &Stream,
4502 workspace: Workspace<'_>,
4503 args: IntGemmArgs<'_, T, BT>,
4504 ) -> Result<()> {
4505 self.can_implement(&args)?;
4506
4507 let needed = self.workspace_size();
4508 let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
4509 Workspace::None => {
4510 if needed != 0 {
4511 return Err(Error::WorkspaceTooSmall { needed, got: 0 });
4512 }
4513 (core::ptr::null_mut(), 0)
4514 }
4515 Workspace::Borrowed(slice) => {
4516 if slice.len() < needed {
4517 return Err(Error::WorkspaceTooSmall {
4518 needed,
4519 got: slice.len(),
4520 });
4521 }
4522 (slice.as_raw().0 as *mut c_void, slice.len())
4523 }
4524 };
4525
4526 let a_ptr = args.a.data.as_raw().0 as *const c_void;
4527 let b_ptr = args.b.data.as_raw().0 as *const c_void;
4528 let d_ptr = args.d.data.as_raw().0 as *mut c_void;
4529 let (c_ptr, ldc) = match &args.c {
4530 Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
4531 None => (core::ptr::null(), 0i64),
4532 };
4533 let bias_ptr = args
4534 .bias
4535 .as_ref()
4536 .map(|b| b.data.as_raw().0 as *const c_void)
4537 .unwrap_or(core::ptr::null());
4538 let beta_eff: f32 = if args.c.is_some() { args.beta } else { 0.0 };
4542 let stream_raw = stream.as_raw();
4543
4544 let bias_family = self.sku.epilogue.requires_bias();
4545 let status = match (self.sku.arch, bias_family) {
4546 #[cfg(feature = "sm80")]
4547 (ArchSku::Sm80, false) => unsafe {
4548 dispatch::int_gemm_rcr_sm80_run(
4549 self.sku.layout,
4550 T::KIND,
4551 self.desc.m, self.desc.n, self.desc.k,
4552 a_ptr, args.a.ld,
4553 b_ptr, args.b.ld,
4554 c_ptr, ldc,
4555 d_ptr, args.d.ld,
4556 args.alpha,
4557 beta_eff,
4558 ws_ptr, ws_bytes, stream_raw,
4559 )
4560 },
4561 #[cfg(feature = "sm80")]
4562 (ArchSku::Sm80, true) => unsafe {
4563 dispatch::int_gemm_bias_rcr_sm80_run(
4564 self.sku.layout,
4565 T::KIND,
4566 self.sku.epilogue,
4567 BT::KIND,
4568 self.desc.m, self.desc.n, self.desc.k,
4569 a_ptr, args.a.ld,
4570 b_ptr, args.b.ld,
4571 c_ptr, ldc,
4572 d_ptr, args.d.ld,
4573 bias_ptr,
4574 args.alpha,
4575 beta_eff,
4576 ws_ptr, ws_bytes, stream_raw,
4577 )
4578 },
4579 #[cfg(not(feature = "sm80"))]
4580 (ArchSku::Sm80, _) => {
4581 return Err(Error::Unsupported(
4582 "sm80 selected but the `sm80` feature isn't enabled",
4583 ));
4584 }
4585 (ArchSku::Sm90a, _) => {
4586 return Err(Error::Unsupported("sm90a int8 kernels not yet shipped"));
4587 }
4588 (ArchSku::Sm89, _) => {
4589 return Err(Error::Unsupported(
4590 "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
4591 ));
4592 }
4593 };
4594
4595 status_to_result(status)
4596 }
4597}
4598
4599fn check_int_descriptor(desc: &IntGemmDescriptor) -> Result<()> {
4603 if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
4604 return Err(Error::InvalidProblem("M, N, K must all be positive"));
4605 }
4606 Ok(())
4607}
4608
4609fn check_int_args<T: IntElement, BT: BiasElement>(
4610 desc: &IntGemmDescriptor,
4611 args: &IntGemmArgs<'_, T, BT>,
4612) -> Result<()> {
4613 match (desc.epilogue.requires_bias(), &args.bias) {
4615 (false, Some(_)) => {
4616 return Err(Error::InvalidProblem(
4617 "args.bias must be None when descriptor.epilogue is Identity",
4618 ));
4619 }
4620 (true, None) => {
4621 return Err(Error::InvalidProblem(
4622 "args.bias is required when descriptor.epilogue is in the Bias family \
4623 (Bias / BiasRelu / BiasGelu / BiasSilu)",
4624 ));
4625 }
4626 (false, None) | (true, Some(_)) => {}
4627 }
4628 if let Some(bias) = &args.bias {
4629 if bias.len != desc.n {
4630 return Err(Error::InvalidProblem("bias vector length must equal N"));
4631 }
4632 if bias.stride != 1 {
4633 return Err(Error::Unsupported(
4634 "bias vector must be contiguous (stride 1) — strided bias not supported",
4635 ));
4636 }
4637 if bias.data.len() < desc.n as usize {
4638 return Err(Error::BufferTooSmall {
4639 needed: desc.n as usize,
4640 got: bias.data.len(),
4641 });
4642 }
4643 }
4644 if args.a.rows != desc.m || args.a.cols != desc.k {
4645 return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
4646 }
4647 if args.b.rows != desc.k || args.b.cols != desc.n {
4648 return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
4649 }
4650 if args.d.rows != desc.m || args.d.cols != desc.n {
4651 return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
4652 }
4653 if let Some(c) = &args.c {
4654 if c.rows != desc.m || c.cols != desc.n {
4655 return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
4656 }
4657 }
4658 if args.a.ld < desc.k as i64 {
4659 return Err(Error::InvalidProblem("A leading dimension must be >= K"));
4660 }
4661 let b_min_ld = match desc.layout {
4662 LayoutSku::Rcr => desc.k as i64,
4663 LayoutSku::Rrr => desc.n as i64,
4664 };
4665 if args.b.ld < b_min_ld {
4666 return Err(Error::InvalidProblem(match desc.layout {
4667 LayoutSku::Rcr => "B leading dimension must be >= K (column-major Rcr layout)",
4668 LayoutSku::Rrr => "B leading dimension must be >= N (row-major Rrr layout)",
4669 }));
4670 }
4671 if args.d.ld < desc.n as i64 {
4672 return Err(Error::InvalidProblem("D leading dimension must be >= N"));
4673 }
4674 if let Some(c) = &args.c {
4675 if c.ld < desc.n as i64 {
4676 return Err(Error::InvalidProblem("C leading dimension must be >= N"));
4677 }
4678 }
4679 let need_a = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
4680 .ok_or(Error::InvalidProblem("A storage size overflow"))?;
4681 if args.a.data.len() < need_a {
4682 return Err(Error::BufferTooSmall {
4683 needed: need_a,
4684 got: args.a.data.len(),
4685 });
4686 }
4687 let need_b = match desc.layout {
4688 LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
4689 LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
4690 }
4691 .ok_or(Error::InvalidProblem("B storage size overflow"))?;
4692 if args.b.data.len() < need_b {
4693 return Err(Error::BufferTooSmall {
4694 needed: need_b,
4695 got: args.b.data.len(),
4696 });
4697 }
4698 let need_d = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
4699 .ok_or(Error::InvalidProblem("D storage size overflow"))?;
4700 if args.d.data.len() < need_d {
4701 return Err(Error::BufferTooSmall {
4702 needed: need_d,
4703 got: args.d.data.len(),
4704 });
4705 }
4706 if let Some(c) = &args.c {
4707 let need_c = min_elements_row_major(c.rows, c.cols, c.ld)
4708 .ok_or(Error::InvalidProblem("C storage size overflow"))?;
4709 if c.data.len() < need_c {
4710 return Err(Error::BufferTooSmall {
4711 needed: need_c,
4712 got: c.data.len(),
4713 });
4714 }
4715 }
4716 Ok(())
4717}
4718
4719fn pick_int_arch(stream: &Stream, pref: PlanPreference) -> Result<ArchSku> {
4720 let (major, _minor) = stream.context().device().compute_capability()?;
4724 if pref.allow_sm90a && cfg!(feature = "sm90a") && major >= 9 {
4725 }
4727 if cfg!(feature = "sm80") {
4728 if major >= 8 {
4729 return Ok(ArchSku::Sm80);
4730 }
4731 return Err(Error::Unsupported(
4732 "device compute capability < 8.0; sm_80 int8 kernels won't run here",
4733 ));
4734 }
4735 Err(Error::Unsupported(
4736 "no arch features enabled — build with --features sm80",
4737 ))
4738}
4739
#[cfg(test)]
mod buffer_size_tests {
    //! Pins the minimum-element counts returned by the RCR sizing helpers.
    //!
    //! NOTE(review): the padded expectations (56) are consistent with a
    //! "(outer - 1) * ld + inner" footprint; the helpers themselves are
    //! defined elsewhere — confirm against their implementation.

    use super::{min_elements_rcr_a, min_elements_rcr_b, min_elements_rcr_cd};

    #[test]
    fn rcr_a_tight_layout() {
        let footprint = min_elements_rcr_a(4, 8, 8);
        assert_eq!(footprint, Some(32));
    }

    #[test]
    fn rcr_a_padded_layout_accepts_smaller_count() {
        let footprint = min_elements_rcr_a(4, 8, 16);
        assert_eq!(footprint, Some(56));
    }

    #[test]
    fn rcr_b_tight_layout() {
        let footprint = min_elements_rcr_b(8, 4, 8);
        assert_eq!(footprint, Some(32));
    }

    #[test]
    fn rcr_b_padded_layout_accepts_smaller_count() {
        let footprint = min_elements_rcr_b(8, 4, 16);
        assert_eq!(footprint, Some(56));
    }

    #[test]
    fn rcr_cd_tight_layout() {
        let footprint = min_elements_rcr_cd(4, 8, 8);
        assert_eq!(footprint, Some(32));
    }

    #[test]
    fn rcr_cd_padded_layout_accepts_smaller_count() {
        let footprint = min_elements_rcr_cd(4, 8, 16);
        assert_eq!(footprint, Some(56));
    }

    #[test]
    fn single_row_matrix_does_not_underflow() {
        // With one row the leading dimension must not affect the count.
        let tight = min_elements_rcr_a(1, 8, 8);
        let padded = min_elements_rcr_a(1, 8, 256);
        assert_eq!(tight, Some(8));
        assert_eq!(padded, Some(8));
    }

    #[test]
    fn overflow_returns_none() {
        let footprint = min_elements_rcr_a(i32::MAX, 1, i64::MAX);
        assert_eq!(footprint, None);
    }
}