Skip to main content

baracuda_cutlass_kernels_sys/
lib.rs

1//! # baracuda-cutlass-kernels-sys
2//!
3//! Raw `extern "C"` entry points for compiled CUTLASS template
4//! instantiations. **You almost certainly want [`baracuda-cutlass`]
5//! instead** — that crate wraps these unsafe calls with typed plans,
6//! lifetime-checked device buffers, and a proper Rust API.
7//!
8//! Functions in this crate take raw `void*` pointers, integer dimensions,
9//! and a `cudaStream_t` cast as `*mut c_void`. They are unsafe because:
10//!
11//! - They dereference the pointer arguments without bounds-checking.
12//! - They assume the pointers are valid device addresses.
13//! - They assume the workspace pointer (when non-null) points to at least
14//!   `workspace_bytes` of writable device memory.
15//! - They assume the stream is a valid CUDA stream owned by the calling
16//!   thread's current context.
17//!
18//! ## Status codes
19//!
20//! All `*_run` and `*_can_implement` functions return an [`i32`] status:
21//! - `0`: success.
22//! - `1`: misaligned operand.
23//! - `2`: invalid problem (e.g. M, N, or K is non-positive).
24//! - `3`: not supported (this kernel doesn't implement the requested shape).
25//! - `4`: workspace too small or null when required.
26//! - `5`: internal CUTLASS error (typically a kernel launch failure).
27//!
28//! [`baracuda-cutlass`]: https://docs.rs/baracuda-cutlass
29
30#![no_std]
31
32use core::ffi::c_void;
33
34// ============================================================================
35// GEMM — RCR layout, sm_80 instantiation
36// ============================================================================
37//
38// Layout convention `RCR`:
39//   A: row-major    [M, K], leading dimension `lda`
40//   B: column-major [K, N], leading dimension `ldb`
41//   C: row-major    [M, N], leading dimension `ldc` (optional; pass null
42//                                                    + beta = 0 to skip)
43//   D: row-major    [M, N], leading dimension `ldd` (always written)
44//
// Accumulator and alpha/beta scalars are FP32. The entry points in this
// section use the identity epilogue only (`D = alpha * AB + beta * C`) and
// take no `bias` argument. Bias-fused variants (plain bias, or bias followed
// by a ReLU, GELU, or SiLU activation) are separate entry points, declared
// in the "bias-fused" sections further down this file.
50
// Every `*_run` declaration below takes 16 arguments and every
// `*_can_implement` takes 11, above clippy's default `too_many_arguments`
// threshold of 7. The parameter lists must match the C symbols exported by
// the compiled CUTLASS trampolines exactly, so they cannot be grouped into a
// struct on this side of the ABI. The lint is therefore allowed once for the
// whole block rather than on individual declarations.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
#[allow(clippy::too_many_arguments)]
unsafe extern "C" {
    /// `f16` GEMM, RCR layout, sm_80.
    ///
    /// Computes `D = alpha * AB + beta * C` with FP32 accumulation and
    /// returns a status code as listed in the crate-level docs (`0` on
    /// success). `c` may be null when `beta` is `0`.
    ///
    /// # Safety
    /// All pointer args must be device-resident (or null where allowed) and
    /// remain valid for the duration of the launch. `stream` must be a live
    /// CUDA stream in the current context. When non-null, `workspace` must
    /// point to at least `workspace_bytes` bytes of writable device memory.
    pub fn baracuda_cutlass_gemm_f16_rcr_sm80_run(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        alpha: f32,
        beta: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for `f16` RCR sm_80 GEMM at the given problem size.
    ///
    /// Use it to size the `workspace` buffer passed to
    /// [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f16_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for `f16` RCR sm_80.
    ///
    /// Returns `0` when the kernel can launch with the given shape, leading
    /// dimensions, and pointer alignments. Otherwise returns a non-zero code
    /// from the standard status-code mapping. Does not launch a kernel and
    /// does not require a stream.
    ///
    /// # Safety
    /// Same pointer-validity contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_run`],
    /// but no device dereferences occur — only host-side checks of pointer
    /// alignment and the leading-dimension fields.
    pub fn baracuda_cutlass_gemm_f16_rcr_sm80_can_implement(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
    ) -> i32;

    /// `bf16` GEMM, RCR layout, sm_80.
    ///
    /// Computes `D = alpha * AB + beta * C` with FP32 accumulation. Returns
    /// a status code with the same meaning as for the `f16` variant.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bf16_rcr_sm80_run(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        alpha: f32,
        beta: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for `bf16` RCR sm_80 GEMM at the given problem size.
    ///
    /// Use it to size the `workspace` buffer passed to
    /// [`baracuda_cutlass_gemm_bf16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bf16_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for `bf16` RCR sm_80.
    ///
    /// Returns `0` when the kernel can launch with the given arguments, or a
    /// non-zero status code otherwise. Does not launch a kernel.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bf16_rcr_sm80_can_implement(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
    ) -> i32;
}
150
151// ============================================================================
152// GEMM — bias-fused (with optional activation), RCR layout, sm_80
153// ============================================================================
154//
155// Computes `D = activation(alpha*AB + beta*C + bias_broadcast(N))` in a
156// single kernel pass via `cutlass::gemm::device::GemmUniversalWithBroadcast`
157// + `LinearCombinationBiasElementwise`. The bias vector has length `N`
158// (one element per output column) and is broadcast across rows. Layout
159// matches the standard RCR variant (A row-major, B column-major, C/D
160// row-major).
161//
162// Symbol naming: `..._gemm_<flavor>_<dtype>_rcr_sm80_<op>` where
163//   flavor ∈ {bias, bias_relu, bias_gelu, bias_silu}
164//   dtype  ∈ {f16, bf16}
165//   op     ∈ {run, workspace_size, can_implement}
166// = 24 entry points. The `bias` flavor uses Identity activation; the
167// others fuse the named CUTLASS activation functor into the same
168// epilogue pass (no extra memory traffic vs plain bias).
169
// Every `*_run` declaration below takes 17 arguments and every
// `*_can_implement` takes 12, above clippy's default `too_many_arguments`
// threshold of 7. The parameter lists must match the C symbols exported by
// the compiled CUTLASS trampolines exactly, so they cannot be grouped into a
// struct on this side of the ABI. One block-level allow covers all of them,
// including the `can_implement` declarations.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
#[allow(clippy::too_many_arguments)]
unsafe extern "C" {
    // ---- plain bias (Identity activation) -------------------------------

    /// `f16` bias-fused GEMM, RCR layout, sm_80.
    /// Computes `D = alpha*AB + beta*C + bias_broadcast(N)` and returns a
    /// status code as listed in the crate-level docs (`0` on success).
    ///
    /// # Safety
    /// All pointer args must be device-resident. `bias` must be a
    /// device-resident length-`n` vector. See
    /// [`baracuda_cutlass_gemm_f16_rcr_sm80_run`] for the rest.
    pub fn baracuda_cutlass_gemm_bias_f16_rcr_sm80_run(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        bias: *const c_void,
        alpha: f32,
        beta: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes needed by the `f16` bias-fused RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_f16_rcr_sm80_workspace_size(
        m: i32,
        n: i32,
        k: i32,
    ) -> usize;

    /// Pre-launch implementability check for `f16` bias RCR sm_80.
    ///
    /// Takes the same arguments as the plain check plus the `bias` pointer
    /// (the length-`n` vector passed to the matching `_run`).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_f16_rcr_sm80_can_implement(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias-fused GEMM, RCR layout, sm_80.
    /// Computes `D = alpha*AB + beta*C + bias_broadcast(N)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_bf16_rcr_sm80_run(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        bias: *const c_void,
        alpha: f32,
        beta: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes needed by the `bf16` bias-fused RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_bf16_rcr_sm80_workspace_size(
        m: i32,
        n: i32,
        k: i32,
    ) -> usize;

    /// Pre-launch implementability check for `bf16` bias RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_bf16_rcr_sm80_can_implement(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + ReLU activation ---------------------------------------

    /// `f16` bias + ReLU activation GEMM, RCR layout, sm_80.
    /// Computes `D = max(alpha*AB + beta*C + bias_broadcast(N), 0)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `f16` bias+ReLU RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `f16` bias+ReLU RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias + ReLU activation GEMM, RCR layout, sm_80.
    /// Computes `D = max(alpha*AB + beta*C + bias_broadcast(N), 0)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `bf16` bias+ReLU RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `bf16` bias+ReLU RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + GELU activation (exact, erf-based) ---------------------

    /// `f16` bias + GELU activation GEMM, RCR layout, sm_80.
    /// Computes `D = gelu(alpha*AB + beta*C + bias_broadcast(N))` using
    /// the exact (erf-based) GELU formulation, matching PyTorch's
    /// default `nn.GELU()`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `f16` bias+GELU RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `f16` bias+GELU RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias + GELU activation GEMM, RCR layout, sm_80.
    /// Computes `D = gelu(alpha*AB + beta*C + bias_broadcast(N))` using
    /// the exact (erf-based) GELU formulation.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `bf16` bias+GELU RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `bf16` bias+GELU RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + SiLU activation (x * sigmoid(x)) -----------------------

    /// `f16` bias + SiLU activation GEMM, RCR layout, sm_80.
    /// Computes `D = silu(alpha*AB + beta*C + bias_broadcast(N))` where
    /// `silu(x) = x * sigmoid(x)`. Also known as Swish.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `f16` bias+SiLU RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `f16` bias+SiLU RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias + SiLU activation GEMM, RCR layout, sm_80.
    /// Computes `D = silu(alpha*AB + beta*C + bias_broadcast(N))` where
    /// `silu(x) = x * sigmoid(x)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `bf16` bias+SiLU RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `bf16` bias+SiLU RCR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;
}
489
490// ============================================================================
491// GEMM — RRR layout, sm_80 instantiation
492// ============================================================================
493//
494// Layout convention `RRR`:
495//   A: row-major [M, K], leading dimension `lda`
496//   B: row-major [K, N], leading dimension `ldb`
497//   C: row-major [M, N], leading dimension `ldc` (optional; null + beta = 0)
498//   D: row-major [M, N], leading dimension `ldd`
499//
500// Same accumulator (FP32), epilogue (Identity), and status-code mapping as
501// the RCR variant. This is the natural shape for activations stored
502// row-major and weights stored row-major (no transpose copy).
503
// Every `*_run` declaration below takes 16 arguments and every
// `*_can_implement` takes 11, above clippy's default `too_many_arguments`
// threshold of 7. The parameter lists must match the C symbols exported by
// the compiled CUTLASS trampolines exactly, so they cannot be grouped into a
// struct on this side of the ABI. The lint is therefore allowed once for the
// whole block.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
#[allow(clippy::too_many_arguments)]
unsafe extern "C" {
    /// `f16` GEMM, RRR layout, sm_80.
    ///
    /// Computes `D = alpha * AB + beta * C` with FP32 accumulation, `B`
    /// row-major. Returns a status code as listed in the crate-level docs
    /// (`0` on success). `c` may be null when `beta` is `0`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f16_rrr_sm80_run(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        alpha: f32,
        beta: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for `f16` RRR sm_80 GEMM.
    ///
    /// Use it to size the `workspace` buffer passed to
    /// [`baracuda_cutlass_gemm_f16_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f16_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for `f16` RRR sm_80.
    ///
    /// Returns `0` when the kernel can launch with the given arguments, or a
    /// non-zero status code otherwise. Does not launch a kernel.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_f16_rrr_sm80_can_implement(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
    ) -> i32;

    /// `bf16` GEMM, RRR layout, sm_80.
    ///
    /// Computes `D = alpha * AB + beta * C` with FP32 accumulation, `B`
    /// row-major.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bf16_rrr_sm80_run(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
        alpha: f32,
        beta: f32,
        workspace: *mut c_void,
        workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for `bf16` RRR sm_80 GEMM.
    ///
    /// Use it to size the `workspace` buffer passed to
    /// [`baracuda_cutlass_gemm_bf16_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bf16_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for `bf16` RRR sm_80.
    ///
    /// Returns `0` when the kernel can launch with the given arguments, or a
    /// non-zero status code otherwise. Does not launch a kernel.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bf16_rrr_sm80_can_implement(
        m: i32,
        n: i32,
        k: i32,
        a: *const c_void,
        lda: i64,
        b: *const c_void,
        ldb: i64,
        c: *const c_void,
        ldc: i64,
        d: *mut c_void,
        ldd: i64,
    ) -> i32;
}
594
595// ============================================================================
596// GEMM — bias-fused (with optional activation), RRR layout, sm_80
597// ============================================================================
598//
599// Mirror of the RCR bias family but with `B` row-major rather than
600// column-major. Computes
601// `D = activation(alpha*AB + beta*C + bias_broadcast(N))` in a single
602// fused kernel pass. Symbol naming mirrors the RCR set, with `_rrr_`
603// in place of `_rcr_`. 24 entry points total (4 flavors × 2 dtypes ×
604// 3 ops).
605
// Every `*_run` declaration below takes 17 arguments and every
// `*_can_implement` takes 12, above clippy's default `too_many_arguments`
// threshold of 7. The parameter lists must match the C symbols exported by
// the compiled CUTLASS trampolines exactly, so they cannot be grouped into a
// struct on this side of the ABI. One block-level allow covers all of them,
// including the `can_implement` declarations.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
#[allow(clippy::too_many_arguments)]
unsafe extern "C" {
    // ---- plain bias (Identity activation) -------------------------------

    /// `f16` bias-fused GEMM, RRR layout, sm_80.
    /// Computes `D = alpha*AB + beta*C + bias_broadcast(N)` with `B`
    /// row-major. Returns a status code as listed in the crate-level docs.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_f16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes needed by the `f16` bias-fused RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_f16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for `f16` bias RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_f16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias-fused GEMM, RRR layout, sm_80.
    /// Computes `D = alpha*AB + beta*C + bias_broadcast(N)` with `B`
    /// row-major.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_bf16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes needed by the `bf16` bias-fused RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_bf16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for `bf16` bias RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_bf16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + ReLU activation ---------------------------------------

    /// `f16` bias+ReLU GEMM, RRR layout, sm_80.
    /// Computes `D = max(alpha*AB + beta*C + bias_broadcast(N), 0)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `f16` bias+ReLU RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `f16` bias+ReLU RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias+ReLU GEMM, RRR layout, sm_80.
    /// Computes `D = max(alpha*AB + beta*C + bias_broadcast(N), 0)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `bf16` bias+ReLU RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `bf16` bias+ReLU RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + GELU activation (exact, erf-based) --------------------

    /// `f16` bias+GELU GEMM, RRR layout, sm_80.
    /// Computes `D = gelu(alpha*AB + beta*C + bias_broadcast(N))` using
    /// the exact (erf-based) GELU formulation.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `f16` bias+GELU RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `f16` bias+GELU RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias+GELU GEMM, RRR layout, sm_80.
    /// Computes `D = gelu(alpha*AB + beta*C + bias_broadcast(N))` using
    /// the exact (erf-based) GELU formulation.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `bf16` bias+GELU RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `bf16` bias+GELU RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + SiLU activation ---------------------------------------

    /// `f16` bias+SiLU GEMM, RRR layout, sm_80.
    /// Computes `D = silu(alpha*AB + beta*C + bias_broadcast(N))` where
    /// `silu(x) = x * sigmoid(x)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `f16` bias+SiLU RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `f16` bias+SiLU RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `bf16` bias+SiLU GEMM, RRR layout, sm_80.
    /// Computes `D = silu(alpha*AB + beta*C + bias_broadcast(N))` where
    /// `silu(x) = x * sigmoid(x)`.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace bytes for `bf16` bias+SiLU RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch check for `bf16` bias+SiLU RRR sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;
}
872
873// ============================================================================
874// GEMM — TF32 (f32 input via TF32 tensor cores), RCR layout, sm_80
875// ============================================================================
876//
// Inputs are IEEE 754 binary32 stored in device memory. The math
// instruction reduces inputs to TF32 (10-bit mantissa, 8-bit exponent)
// and accumulates into FP32. Faster than full-F32 SIMT GEMM, at the cost
// of input precision (~10-bit mantissa instead of binary32's 23-bit) —
// analogous to cuBLAS's `CUBLAS_COMPUTE_32F_FAST_TF32`.
882
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    /// TF32 tensor-core GEMM over `f32` storage, RCR layout, sm_80.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_tf32_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Bytes of workspace the `tf32` RCR sm_80 GEMM needs for an `m × n × k` problem.
    pub fn baracuda_cutlass_gemm_tf32_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Checks, before launch, whether the `tf32` RCR sm_80 kernel accepts the problem.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_tf32_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;

    /// TF32 tensor-core GEMM over `f32` storage, RRR layout, sm_80.
    ///
    /// Same numerical behavior as the RCR TF32 kernel; only `B` differs, being
    /// row-major here. This is the natural fit for f32 activations × f32 weights
    /// when both tensors are stored row-major, since no transpose pass is needed
    /// before launch.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_tf32_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Bytes of workspace the `tf32` RRR sm_80 GEMM needs for an `m × n × k` problem.
    pub fn baracuda_cutlass_gemm_tf32_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Checks, before launch, whether the `tf32` RRR sm_80 kernel accepts the problem.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_tf32_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;
}
977
978// ============================================================================
979// GEMM — bias-fused (with optional activation), TF32 path, RCR layout, sm_80
980// ============================================================================
981//
982// f32 inputs reduced through Ampere TF32 tensor cores, with bias and
983// optional activation fused into the epilogue. Mirrors the f16/bf16
984// bias family but uses the TF32 tile shape (4 elements per access).
985// 12 entry points total (4 flavors × 3 ops; single element type since
986// TF32 implies f32 storage).
987
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    // ---- plain bias (Identity activation) ----

    /// `f32` (TF32) bias-fused GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_tf32_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_tf32_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_tf32_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias-fused RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_tf32_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + ReLU activation ----

    /// `f32` (TF32) bias+ReLU GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias+ReLU RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + GELU activation (exact, erf-based) ----

    /// `f32` (TF32) bias+GELU GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias+GELU RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + SiLU activation ----

    /// `f32` (TF32) bias+SiLU GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias+SiLU RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;
}
1126
1127// ============================================================================
1128// GEMM — bias-fused (with optional activation), TF32 path, RRR layout, sm_80
1129// ============================================================================
1130//
1131// Mirror of the TF32 RCR bias family with `B` row-major rather than
1132// column-major. Same TF32 tile shape, same epilogue family. 12 entry
1133// points total (4 flavors × 3 ops; single element type since TF32
1134// implies f32).
1135
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    // ---- plain bias (Identity activation) ----

    /// `f32` (TF32) bias-fused GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_tf32_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_tf32_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_tf32_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias-fused RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_tf32_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + ReLU activation ----

    /// `f32` (TF32) bias+ReLU GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias+ReLU RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + GELU activation (exact, erf-based) ----

    /// `f32` (TF32) bias+GELU GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias+GELU RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- bias + SiLU activation ----

    /// `f32` (TF32) bias+SiLU GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` (TF32) bias+SiLU RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;
}
1274
1275// ============================================================================
1276// GEMM — f32-SIMT path (CUDA cores, no tensor cores), RCR + RRR, sm_80
1277// ============================================================================
1278//
1279// Strict-precision counterpart to the TF32 family. f32 inputs are
1280// multiplied through the SIMT mainloop (full IEEE 754 binary32 FMA, no
1281// tensor-core warp-reduction nondeterminism) and accumulated into f32.
1282// Identical layout conventions to the f16/bf16 kernels.
1283//
1284// Bias variants use a vendored partial specialization of
1285// `cutlass::gemm::kernel::DefaultGemmWithBroadcast` for `OpClassSimt`
1286// (see `kernels/include/baracuda_simt_broadcast_epilogue.h`) so that
1287// `GemmUniversalWithBroadcast` can route through the SIMT broadcast
1288// epilogue path that CUTLASS ships but doesn't wire by default.
1289
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    /// `f32` GEMM via SIMT (CUDA cores), RCR layout, sm_80.
    /// Full-precision counterpart to the TF32 RCR kernel.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f32_simt_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for `f32_simt` RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_f32_simt_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for the `f32_simt` RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_f32_simt_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;

    /// `f32` GEMM via SIMT (CUDA cores), RRR layout, sm_80.
    /// Full-precision counterpart to the TF32 RRR kernel.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f32_simt_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for `f32_simt` RRR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_f32_simt_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for the `f32_simt` RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_f32_simt_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;
}
1349
1350// ============================================================================
1351// GEMM — bias-fused (with optional activation), f32-SIMT path, sm_80
1352// ============================================================================
1353//
1354// Routes through the vendored `OpClassSimt` partial specialization of
1355// `DefaultGemmWithBroadcast`. 24 entry points total (4 flavors × 2 layouts ×
1356// 3 ops). All scalars and the bias vector are `float`.
1357
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    // ---- RCR layout ----

    /// `f32` SIMT bias-fused GEMM (Identity activation), RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias-fused RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `f32` SIMT bias+ReLU GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias+ReLU RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `f32` SIMT bias+GELU GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias+GELU RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `f32` SIMT bias+SiLU GEMM, RCR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias+SiLU RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- RRR layout ----

    /// `f32` SIMT bias-fused GEMM (Identity activation), RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias-fused RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `f32` SIMT bias+ReLU GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias+ReLU RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `f32` SIMT bias+GELU GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias+GELU RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// `f32` SIMT bias+SiLU GEMM, RRR layout, sm_80.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_bias_f16_rcr_sm80_run`]. `alpha`, `beta`, and
    /// the `bias` vector are `f32`.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `f32` SIMT bias+SiLU RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;
}
1564
1565// ============================================================================
1566// GEMM — f64 (DGEMM via Ampere FP64 tensor cores), RCR + RRR, sm_80
1567// ============================================================================
1568//
1569// Full IEEE 754 binary64 throughout: inputs, accumulator, alpha/beta, and
1570// output. Routes through the Ampere DGEMM mma instruction (`m8n8k4` in
1571// double). Analogous to cuBLAS's `CUBLAS_COMPUTE_64F`.
1572//
1573// Note the FFI signature difference: `alpha` and `beta` are `f64` (vs
1574// `f32` for every other shipped element type). Bias-family kernels
1575// follow the same f64-scalar convention. The plan layer routes through
1576// these symbols when `T::Scalar::IS_F64` is true.
1577
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    /// `f64` GEMM via Ampere FP64 tensor cores, RCR layout, sm_80.
    ///
    /// Unlike the other element types, `alpha` and `beta` are `f64` here.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f64_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_f64_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f64_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for the `f64` RCR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_f64_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;

    /// `f64` GEMM via Ampere FP64 tensor cores, RRR layout, sm_80.
    ///
    /// Unlike the other element types, `alpha` and `beta` are `f64` here.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f64_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Workspace size in bytes for [`baracuda_cutlass_gemm_f64_rrr_sm80_run`].
    pub fn baracuda_cutlass_gemm_f64_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Pre-launch implementability check for the `f64` RRR sm_80 GEMM.
    ///
    /// Returns `0` when the kernel can run this problem; otherwise a nonzero
    /// status code from the crate-level table.
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_f64_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;
}
1636
1637// ============================================================================
1638// GEMM — bias-fused (with optional activation), f64 (DGEMM), sm_80
1639// ============================================================================
1640//
1641// 24 entry points (4 flavors × 2 layouts × 3 ops). All scalars and the
1642// bias vector are `double` / `f64`.
1643
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    // ---- RCR layout ----

    /// CUTLASS GEMM trampoline (launch gemm_bias_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_f64_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_f64_rcr_sm80).
    pub fn baracuda_cutlass_gemm_bias_f64_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_f64_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// CUTLASS GEMM trampoline (launch gemm_bias_relu_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_relu_f64_rcr_sm80).
    pub fn baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_relu_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// CUTLASS GEMM trampoline (launch gemm_bias_gelu_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_gelu_f64_rcr_sm80).
    pub fn baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_gelu_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// CUTLASS GEMM trampoline (launch gemm_bias_silu_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_silu_f64_rcr_sm80).
    pub fn baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_silu_f64_rcr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    // ---- RRR layout ----

    /// CUTLASS GEMM trampoline (launch gemm_bias_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_f64_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_f64_rrr_sm80).
    pub fn baracuda_cutlass_gemm_bias_f64_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_f64_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// CUTLASS GEMM trampoline (launch gemm_bias_relu_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_relu_f64_rrr_sm80).
    pub fn baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_relu_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// CUTLASS GEMM trampoline (launch gemm_bias_gelu_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_gelu_f64_rrr_sm80).
    pub fn baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_gelu_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;

    /// CUTLASS GEMM trampoline (launch gemm_bias_silu_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_run`]. `bias` must also be a
    /// valid device address of `f64` data that stays valid for the launch.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
        alpha: f64, beta: f64,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
    /// CUTLASS GEMM trampoline (workspace-bytes query for gemm_bias_silu_f64_rrr_sm80).
    pub fn baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;
    /// CUTLASS GEMM trampoline (implementability check for gemm_bias_silu_f64_rrr_sm80).
    ///
    /// # Safety
    /// See [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        bias: *const c_void,
    ) -> i32;
}
1850
1851// ============================================================================
1852// Batched GEMM — uniform-shape, RCR layout, sm_80 instantiation
1853// ============================================================================
1854//
1855// All batches share the same (M, N, K). Per-tensor `stride_*` (in
1856// elements, not bytes) gives the offset between batch slabs. Layout
1857// matches the single-GEMM RCR case. This is the natural fit for
1858// equal-batch attention / repeated linear layers; for variable-shape
1859// per-group problems use the grouped-GEMM API.
1860
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    // ---------- f16 operands ----------

    /// Uniform-shape batched GEMM on `f16` operands (RCR layout, sm_80).
    ///
    /// All `batch_count` problems share `(m, n, k)`; batch `i` uses the
    /// operand slabs at `base + i * stride_*` (strides counted in elements).
    ///
    /// # Safety
    /// The contract of [`baracuda_cutlass_gemm_f16_rcr_sm80_run`] applies to
    /// every batch: each address derived as base + `i * stride_*` must be
    /// device-resident in the current context.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_batched_f16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64, stride_a: i64,
        b: *const c_void, ldb: i64, stride_b: i64,
        c: *const c_void, ldc: i64, stride_c: i64,
        d: *mut c_void, ldd: i64, stride_d: i64,
        alpha: f32, beta: f32,
        batch_count: i32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Bytes of workspace the `f16` batched RCR sm_80 launch requires.
    pub fn baracuda_cutlass_gemm_batched_f16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
        batch_count: i32,
    ) -> usize;

    /// Checks, before launching, whether the `f16` batched RCR sm_80 kernel
    /// can handle the given problem.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_batched_f16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64, stride_a: i64,
        b: *const c_void, ldb: i64, stride_b: i64,
        c: *const c_void, ldc: i64, stride_c: i64,
        d: *mut c_void, ldd: i64, stride_d: i64,
        batch_count: i32,
    ) -> i32;

    // ---------- bf16 operands ----------

    /// Uniform-shape batched GEMM on `bf16` operands (RCR layout, sm_80).
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_batched_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_batched_bf16_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64, stride_a: i64,
        b: *const c_void, ldb: i64, stride_b: i64,
        c: *const c_void, ldc: i64, stride_c: i64,
        d: *mut c_void, ldd: i64, stride_d: i64,
        alpha: f32, beta: f32,
        batch_count: i32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Bytes of workspace the `bf16` batched RCR sm_80 launch requires.
    pub fn baracuda_cutlass_gemm_batched_bf16_rcr_sm80_workspace_size(
        m: i32, n: i32, k: i32,
        batch_count: i32,
    ) -> usize;

    /// Checks, before launching, whether the `bf16` batched RCR sm_80
    /// kernel can handle the given problem.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_f16_rcr_sm80_can_implement`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_gemm_batched_bf16_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64, stride_a: i64,
        b: *const c_void, ldb: i64, stride_b: i64,
        c: *const c_void, ldc: i64, stride_c: i64,
        d: *mut c_void, ldd: i64, stride_d: i64,
        batch_count: i32,
    ) -> i32;
}
1987
1988// ============================================================================
1989// Grouped GEMM — RCR layout, sm_80 instantiation
1990// ============================================================================
1991//
1992// Per-group layout matches the single-GEMM `RCR` case. The safe Rust layer
1993// (`baracuda-cutlass`) packs per-group `problem_sizes`, pointer arrays,
1994// and leading-dimension arrays into a caller-supplied workspace, then
1995// hands us pointers to those packed regions. The CUTLASS internal scratch
1996// (size from `*_scratch_bytes`) lives at the tail of the same workspace.
1997//
1998// `h_problem_sizes` is a HOST pointer to the same `[GemmCoord; G]` data
1999// that's also packed into device memory at `d_problem_sizes` — CUTLASS
2000// uses the host copy for `sufficient` / tile-count math.
2001
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    /// Compute the number of threadblocks to launch for an `f16` grouped
    /// GEMM with the given per-group `(M, N, K)` shapes. CUTLASS chooses
    /// based on device SM count vs total tile count.
    ///
    /// # Safety
    /// `h_m`, `h_n`, `h_k` must each be valid pointers to at least
    /// `group_count` `i32`s of host memory.
    pub fn baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient(
        h_m: *const i32,
        h_n: *const i32,
        h_k: *const i32,
        group_count: i32,
    ) -> i32;

    /// CUTLASS-internal scratch bytes needed for the launch.
    ///
    /// # Safety
    /// Same as [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient`].
    pub fn baracuda_cutlass_grouped_gemm_f16_rcr_sm80_scratch_bytes(
        h_m: *const i32,
        h_n: *const i32,
        h_k: *const i32,
        group_count: i32,
        threadblock_count: i32,
    ) -> usize;

    /// Pre-launch implementability check (host-only, no CUDA traffic).
    ///
    /// # Safety
    /// Same as [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient`].
    pub fn baracuda_cutlass_grouped_gemm_f16_rcr_sm80_can_implement(
        h_m: *const i32,
        h_n: *const i32,
        h_k: *const i32,
        group_count: i32,
    ) -> i32;

    /// Launch the grouped GEMM.
    ///
    /// # Safety
    /// All `d_*` pointers must be device-resident, in the current context,
    /// and remain valid for the duration of the launch. `h_problem_sizes`
    /// must be a host pointer to a `[GemmCoord; group_count]` array (same
    /// data as `d_problem_sizes`). `scratch` must be at least
    /// `scratch_bytes` bytes of writable device memory. `stream` must be a
    /// live CUDA stream.
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_grouped_gemm_f16_rcr_sm80_run(
        group_count: i32,
        threadblock_count: i32,
        d_problem_sizes: *const c_void,
        d_ptr_a: *const c_void,
        d_ptr_b: *const c_void,
        d_ptr_c: *const c_void,
        d_ptr_d: *mut c_void,
        d_lda: *const c_void,
        d_ldb: *const c_void,
        d_ldc: *const c_void,
        d_ldd: *const c_void,
        h_problem_sizes: *const c_void,
        alpha: f32,
        beta: f32,
        scratch: *mut c_void,
        scratch_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Threadblock count to launch for a `bf16` grouped GEMM; see
    /// [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient`] for details.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient`].
    pub fn baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_sufficient(
        h_m: *const i32,
        h_n: *const i32,
        h_k: *const i32,
        group_count: i32,
    ) -> i32;

    /// CUTLASS-internal scratch bytes needed for the `bf16` grouped launch.
    ///
    /// # Safety
    /// Same as [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_scratch_bytes`].
    pub fn baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_scratch_bytes(
        h_m: *const i32,
        h_n: *const i32,
        h_k: *const i32,
        group_count: i32,
        threadblock_count: i32,
    ) -> usize;

    /// Pre-launch implementability check for the `bf16` grouped GEMM.
    ///
    /// # Safety
    /// Same as [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_can_implement(
        h_m: *const i32,
        h_n: *const i32,
        h_k: *const i32,
        group_count: i32,
    ) -> i32;

    /// Launch the `bf16` grouped GEMM.
    ///
    /// # Safety
    /// Same as [`baracuda_cutlass_grouped_gemm_f16_rcr_sm80_run`].
    #[allow(clippy::too_many_arguments)]
    pub fn baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_run(
        group_count: i32,
        threadblock_count: i32,
        d_problem_sizes: *const c_void,
        d_ptr_a: *const c_void,
        d_ptr_b: *const c_void,
        d_ptr_c: *const c_void,
        d_ptr_d: *mut c_void,
        d_lda: *const c_void,
        d_ldb: *const c_void,
        d_ldc: *const c_void,
        d_ldd: *const c_void,
        h_problem_sizes: *const c_void,
        alpha: f32,
        beta: f32,
        scratch: *mut c_void,
        scratch_bytes: usize,
        stream: *mut c_void,
    ) -> i32;
}
2124
2125// ============================================================================
2126// int8 GEMM — RCR layout, sm_80 instantiations (Phase 2)
2127// ============================================================================
2128//
2129// Layout convention `RCR`:
2130//   A: row-major    [M, K], leading dimension `lda`  (int8 or uint8)
2131//   B: column-major [K, N], leading dimension `ldb`  (matches A signedness)
2132//   C: row-major    [M, N], leading dimension `ldc`  (matches A signedness; optional)
2133//   D: row-major    [M, N], leading dimension `ldd`  (matches A signedness)
2134//
2135// Accumulator is int32; alpha/beta are f32 (the standard CUTLASS dequant-
2136// in-epilogue convention for integer GEMM). The final saturating cast
2137// from float compute back to int8/uint8 uses the `cvt.rni.sat.{s8,u8}.f32`
2138// PTX instruction. Operator = `OpMultiplyAddSaturate` — the accumulator
2139// clamps on overflow rather than wrapping.
2140//
2141// Bias-family symbols carry both an activation suffix (`bias`, `bias_relu`,
2142// `bias_gelu`, `bias_silu`) and a bias-element suffix (`f32bias` or
2143// `i32bias`) that picks the broadcast-vector dtype. The activation runs
2144// in float (after int32→float dequant), so GELU/SiLU compose without a
2145// custom epilogue.
2146//
2147// `RRR` (row-major × row-major) is **not** present at int8 — CUTLASS 4.2.0
2148// has no warp-level `MmaTensorOpMultiplicandTileIterator` specialization
2149// for `TensorOpMultiplicandCongruous<8, ...>`, so `RowMajor × RowMajor ×
2150// OpClassTensorOp` cannot be instantiated for 8-bit operands. The safe
2151// layer reports `Error::Unsupported` for an int8 RRR descriptor.
2152
#[cfg(any(feature = "sm80", feature = "sm90a"))]
unsafe extern "C" {
    // ---------- signed int8 (s8), identity epilogue, RCR, sm_80 ----------

    /// RCR-layout GEMM on signed int8 operands, compiled for sm_80.
    ///
    /// Evaluates `D = saturating_cast<i8>(alpha * (A * B) + beta * C)` using
    /// int8 inputs, an int32 accumulator, and f32 `alpha`/`beta`. Passing a
    /// null `c` together with `beta = 0.0` skips reading the source tensor.
    ///
    /// # Safety
    /// Every pointer argument must be a device address (or null where
    /// permitted) that stays valid for the duration of the launch, and
    /// `stream` must be a live CUDA stream in the current context.
    pub fn baracuda_cutlass_gemm_s8_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Bytes of workspace required by the `s8` RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_s8_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Checks, before launching, whether the `s8` RCR sm_80 GEMM supports
    /// the given problem.
    ///
    /// # Safety
    /// Pointers follow the same validity rules as the matching `_run`;
    /// only host-side alignment and leading-dimension checks are performed.
    pub fn baracuda_cutlass_gemm_s8_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;

    // ---------- unsigned int8 (u8), identity epilogue, RCR, sm_80 ----------

    /// RCR-layout GEMM on unsigned uint8 operands, compiled for sm_80.
    ///
    /// Built from the same template family as the signed kernel; results
    /// are clamped to `[0, 255]` through `cvt.rni.sat.u8.f32`.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_s8_rcr_sm80_run`].
    pub fn baracuda_cutlass_gemm_u8_rcr_sm80_run(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
        alpha: f32, beta: f32,
        workspace: *mut c_void, workspace_bytes: usize,
        stream: *mut c_void,
    ) -> i32;

    /// Bytes of workspace required by the `u8` RCR sm_80 GEMM.
    pub fn baracuda_cutlass_gemm_u8_rcr_sm80_workspace_size(m: i32, n: i32, k: i32) -> usize;

    /// Checks, before launching, whether the `u8` RCR sm_80 GEMM supports
    /// the given problem.
    ///
    /// # Safety
    /// Same contract as [`baracuda_cutlass_gemm_s8_rcr_sm80_can_implement`].
    pub fn baracuda_cutlass_gemm_u8_rcr_sm80_can_implement(
        m: i32, n: i32, k: i32,
        a: *const c_void, lda: i64,
        b: *const c_void, ldb: i64,
        c: *const c_void, ldc: i64,
        d: *mut c_void, ldd: i64,
    ) -> i32;
}
2257
2258// ---- int8 bias-family FFI decls (Bias + 3 activations × 2 sgn × 2 bias-types) ----
2259//
2260// 16 kernel families × 3 ops each = 48 extern decls. All share the same
2261// three signatures (run / workspace_size / can_implement); only the
2262// function name differs across families. A local macro_rules!
2263// generates each triplet to keep this section readable.
2264//
2265// Naming: `baracuda_cutlass_gemm_<epi>_<bias-elem>_<sgn>_rcr_sm80_<op>`
2266//   epi       ∈ {bias, bias_relu, bias_gelu, bias_silu}
2267//   bias-elem ∈ {f32bias, i32bias}
2268//   sgn       ∈ {s8, u8}
2269//   op        ∈ {run, workspace_size, can_implement}
2270
/// Internal: stamps out the (run, workspace_size, can_implement) extern
/// decl triple for one int8 bias-kernel family. Each `$run` / `$ws` /
/// `$ck` is the fully-qualified C symbol name as a Rust identifier.
///
/// The expansion names `c_void` unqualified, so the invoking scope must
/// have `c_void` in scope.
///
/// Gated on the same features as the declarations it generates: without
/// the gate, a build with neither `sm80` nor `sm90a` enabled reports this
/// macro as unused (`unused_macros`), which fails `-D warnings` builds.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
macro_rules! int8_bias_ffi {
    ($run:ident, $ws:ident, $ck:ident) => {
        unsafe extern "C" {
            // The Safety link uses an explicit `crate::` target: the
            // generated items live in a child module that does not import
            // the s8 `_run` symbol, so a bare intra-doc link cannot resolve.
            #[doc = concat!(
                "int8 bias-fused GEMM with optional fused activation.\n\n",
                "Computes `D = saturating_cast(activation(alpha * (A * B) ",
                "+ beta * C + bias_broadcast(N)))`. See the section header for ",
                "the layout / accumulator / clamp contract.\n\n",
                "# Safety\nSame contract as ",
                "[`baracuda_cutlass_gemm_s8_rcr_sm80_run`]",
                "(crate::baracuda_cutlass_gemm_s8_rcr_sm80_run)."
            )]
            pub fn $run(
                m: i32,
                n: i32,
                k: i32,
                a: *const c_void,
                lda: i64,
                b: *const c_void,
                ldb: i64,
                c: *const c_void,
                ldc: i64,
                d: *mut c_void,
                ldd: i64,
                bias: *const c_void,
                alpha: f32,
                beta: f32,
                workspace: *mut c_void,
                workspace_bytes: usize,
                stream: *mut c_void,
            ) -> i32;

            #[doc = "Workspace size in bytes for the corresponding `_run` entry point."]
            pub fn $ws(m: i32, n: i32, k: i32) -> usize;

            #[doc = concat!(
                "Pre-launch implementability check for the corresponding ",
                "`_run` entry point.\n\n# Safety\nSame pointer-validity ",
                "contract as the matching `_run`, but only host-side ",
                "alignment and leading-dimension checks occur."
            )]
            pub fn $ck(
                m: i32,
                n: i32,
                k: i32,
                a: *const c_void,
                lda: i64,
                b: *const c_void,
                ldb: i64,
                c: *const c_void,
                ldc: i64,
                d: *mut c_void,
                ldd: i64,
                bias: *const c_void,
            ) -> i32;
        }
    };
}
2331
#[cfg(any(feature = "sm80", feature = "sm90a"))]
mod int8_bias_decls {
    use super::c_void;

    // Each invocation below declares one family's run / workspace_size /
    // can_implement triple. Families are grouped by output signedness,
    // then by bias-vector element type, then by epilogue activation.

    // ----- signed int8 output, f32 bias vector -----
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_can_implement
    }

    // ----- signed int8 output, i32 bias vector -----
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_can_implement
    }

    // ----- unsigned uint8 output, f32 bias vector -----
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_can_implement
    }

    // ----- unsigned uint8 output, i32 bias vector -----
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_can_implement
    }
    int8_bias_ffi! {
        baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_run,
        baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_workspace_size,
        baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_can_implement
    }
}

// Surface every generated symbol at the crate root, next to the other FFI decls.
#[cfg(any(feature = "sm80", feature = "sm90a"))]
pub use self::int8_bias_decls::*;
2427