Skip to main content

baracuda_cutlass/
types.rs

1//! Value types for the CUTLASS plan-based API.
2//!
3//! All types here are pure data — no hidden device allocations, no
4//! handles, no driver state. Plans cache *selection metadata* on top of
5//! these descriptors but never own device memory.
6//!
7//! The shared type vocabulary (element / layout / epilogue / matrix-view
8//! / plan-preference / precision-guarantee / workspace) lives in
9//! [`baracuda_kernels_types`]. This module re-exports it for back-compat
10//! and additionally hosts the CUTLASS-specific descriptors (GEMM /
11//! batched GEMM / grouped GEMM / int-GEMM problem + args structs and
12//! the [`GemmSku`] tag) that aren't shared with the wider facade.
13//!
14//! The trait formerly known as `CutlassElement` is now
15//! [`baracuda_kernels_types::Element`]; the old name is preserved as a
16//! type alias below (see [`CutlassElement`]).
17
// Re-export the shared vocabulary so that downstream callers importing
// from `baracuda_cutlass::types::*` (and any code inside this crate
// referencing these types via `crate::types::Foo`) continue to work
// unchanged.
22pub use baracuda_kernels_types::{
23    ActivationKind, ArchSku, BackendKind, BiasElement, BiasElementKind, Element, ElementKind,
24    EpilogueKind, F32Strict, IntElement, LayoutSku, MathPrecision, MatrixMut, MatrixRef,
25    PlanPreference, PrecisionGuarantee, S8, ScalarType, U8, VectorRef, Workspace,
26};
27
28/// Back-compat alias for [`Element`].
29///
30/// Originally the float-family element trait was named `CutlassElement`;
31/// it was renamed to `Element` in workspace alpha.16 when the shared
32/// type vocabulary moved into [`baracuda_kernels_types`]. The old name
33/// is preserved here so existing downstream imports keep working.
34///
35/// Prefer importing `Element` from `baracuda_kernels_types` (or, via
36/// re-export, from `baracuda_cutlass` / `baracuda_kernels`) in new code.
37pub use baracuda_kernels_types::Element as CutlassElement;
38
39/// Problem shape and configuration handed to [`GemmPlan::select`](crate::GemmPlan::select).
40#[derive(Copy, Clone, Debug)]
41pub struct GemmDescriptor {
42    /// Output row count.
43    pub m: i32,
44    /// Output column count.
45    pub n: i32,
46    /// Reduction depth.
47    pub k: i32,
48    /// Layout SKU.
49    pub layout: LayoutSku,
50    /// Epilogue kind.
51    pub epilogue: EpilogueKind,
52}
53
54/// Per-launch arguments for a [`GemmPlan::run`](crate::GemmPlan::run) call.
55///
56/// `c` is optional: when `None`, `β` is ignored at the safe layer (treated
57/// as `0`) and the kernel computes `D = α · A · B`. When `Some`, the
58/// kernel computes `D = α · A · B + β · C` — including the
59/// `c.data == d.data` case for in-place accumulation.
60///
61/// `bias` is required iff the descriptor's epilogue is one of the
62/// `Bias*` variants, in which case the kernel computes
63/// `D = activation(α · A · B + β · C + bias_broadcast(N))`.
64#[derive(Debug)]
65pub struct GemmArgs<'a, T: Element> {
66    /// Left input. Row-major `[M, K]`.
67    pub a: MatrixRef<'a, T>,
68    /// Right input. Layout depends on the descriptor's [`LayoutSku`]:
69    /// column-major `[K, N]` for [`LayoutSku::Rcr`], row-major `[K, N]`
70    /// for [`LayoutSku::Rrr`].
71    pub b: MatrixRef<'a, T>,
72    /// Optional accumulation source. Row-major `[M, N]`.
73    pub c: Option<MatrixRef<'a, T>>,
74    /// Output. Row-major `[M, N]`.
75    pub d: MatrixMut<'a, T>,
76    /// Optional bias vector. Required (`Some`) when the descriptor's
77    /// epilogue is any `Bias*` variant; must be `None` for
78    /// [`EpilogueKind::Identity`]. Length-`N`, contiguous (stride 1)
79    /// device memory; broadcast across rows of `D`.
80    pub bias: Option<VectorRef<'a, T>>,
81    /// Multiplier on the matrix-multiply accumulator. Scalar type
82    /// matches `T::Scalar` — `f32` for f16/bf16/f32/[`F32Strict`], `f64`
83    /// for [`prim@f64`].
84    pub alpha: T::Scalar,
85    /// Multiplier on `c`. Forced to `0` internally when `c` is `None`,
86    /// so callers don't need to pre-zero it for the no-accumulate case.
87    pub beta: T::Scalar,
88}
89
90/// Problem shape and configuration handed to
91/// [`BatchedGemmPlan::select`](crate::BatchedGemmPlan::select).
92///
93/// All batches share the same `(M, N, K)` and per-batch operands are
94/// addressed by adding `i * stride_*` (in elements) to the base
95/// pointer — see [`BatchedGemmArgs`]. For variable-shape grouped
96/// problems use [`GroupedGemmPlan`](crate::GroupedGemmPlan) instead.
97#[derive(Copy, Clone, Debug)]
98pub struct BatchedGemmDescriptor {
99    /// Output row count (per batch).
100    pub m: i32,
101    /// Output column count (per batch).
102    pub n: i32,
103    /// Reduction depth (per batch).
104    pub k: i32,
105    /// Number of batches launched in a single kernel invocation.
106    pub batch_count: i32,
107    /// Layout SKU. v1 supports only [`LayoutSku::Rcr`].
108    pub layout: LayoutSku,
109    /// Epilogue kind. v1 supports only [`EpilogueKind::Identity`].
110    pub epilogue: EpilogueKind,
111}
112
113/// Per-launch arguments for a
114/// [`BatchedGemmPlan::run`](crate::BatchedGemmPlan::run) call.
115///
116/// `stride_*` fields are in **elements**, not bytes — matching CUTLASS's
117/// `GemmBatched` API. Pass `0` for stride if the same matrix should be
118/// reused across all batches (broadcast).
119#[derive(Debug)]
120pub struct BatchedGemmArgs<'a, T: Element> {
121    /// Left input — base pointer for batch 0.
122    pub a: MatrixRef<'a, T>,
123    /// Element offset between consecutive A batches.
124    pub stride_a: i64,
125    /// Right input — base pointer for batch 0.
126    pub b: MatrixRef<'a, T>,
127    /// Element offset between consecutive B batches.
128    pub stride_b: i64,
129    /// Optional accumulation source.
130    pub c: Option<MatrixRef<'a, T>>,
131    /// Element offset between consecutive C batches. Ignored when `c` is `None`.
132    pub stride_c: i64,
133    /// Output — base pointer for batch 0.
134    pub d: MatrixMut<'a, T>,
135    /// Element offset between consecutive D batches.
136    pub stride_d: i64,
137    /// α multiplier (shared across batches). Scalar type matches
138    /// `T::Scalar` — `f32` for f16/bf16/f32/[`F32Strict`], `f64` for
139    /// [`prim@f64`].
140    pub alpha: T::Scalar,
141    /// β multiplier (shared across batches). Forced to `0` internally
142    /// when `c` is `None`.
143    pub beta: T::Scalar,
144}
145
146/// One per-group entry for a grouped GEMM launch.
147///
148/// Each group has its own shape and pointers; CUTLASS dispatches them in
149/// a single kernel invocation. Passed as a slice to
150/// [`GroupedGemmPlan::prepare`](crate::GroupedGemmPlan::prepare), which
151/// returns a [`PreparedGroupedGemm`](crate::PreparedGroupedGemm) whose
152/// [`run`](crate::PreparedGroupedGemm::run) method performs the launch.
153#[derive(Debug)]
154pub struct GroupedProblem<'a, T: Element> {
155    /// Group `M`.
156    pub m: i32,
157    /// Group `N`.
158    pub n: i32,
159    /// Group `K`.
160    pub k: i32,
161    /// Left input.
162    pub a: MatrixRef<'a, T>,
163    /// Right input.
164    pub b: MatrixRef<'a, T>,
165    /// Optional accumulation source.
166    pub c: Option<MatrixRef<'a, T>>,
167    /// Output.
168    pub d: MatrixMut<'a, T>,
169    /// α for this group. Scalar type matches `T::Scalar` — `f32` for
170    /// f16/bf16/f32/[`F32Strict`], `f64` for [`prim@f64`].
171    pub alpha: T::Scalar,
172    /// β for this group. Forced to `0` internally when `c` is `None`.
173    pub beta: T::Scalar,
174}
175
176/// Problem shape and configuration handed to
177/// [`IntGemmPlan::select`](crate::IntGemmPlan::select).
178///
179/// Parallel to [`GemmDescriptor`] for the integer GEMM family.
180/// `LayoutSku` and [`EpilogueKind`] are shared with the float family,
181/// but coverage on int8 is limited to [`LayoutSku::Rcr`] in this
182/// release — selecting [`LayoutSku::Rrr`] returns
183/// [`Error::Unsupported`](crate::Error::Unsupported). The
184/// `RowMajor × RowMajor` integer SKU lives in the bespoke
185/// `baracuda-kernels-sys` kernel family (lands in workspace alpha.16).
186#[derive(Copy, Clone, Debug)]
187pub struct IntGemmDescriptor {
188    /// Output row count.
189    pub m: i32,
190    /// Output column count.
191    pub n: i32,
192    /// Reduction depth.
193    pub k: i32,
194    /// Layout SKU. Today's int8 CUTLASS SKUs require [`LayoutSku::Rcr`].
195    pub layout: LayoutSku,
196    /// Epilogue kind. All five variants are supported on int8 RCR.
197    pub epilogue: EpilogueKind,
198}
199
200/// Per-launch arguments for an
201/// [`IntGemmPlan::run`](crate::IntGemmPlan::run) call.
202///
203/// Parallel to [`GemmArgs`] for the integer GEMM family. The matrix
204/// operands carry the kernel element type `T: IntElement`
205/// (today: [`S8`] or [`U8`]); the optional `bias` carries the
206/// independent bias element type `BT: BiasElement` (today: `f32` or
207/// `i32`). Scalar `alpha` / `beta` are always `f32` regardless of `T`
208/// or `BT` — CUTLASS's `LinearCombinationClamp` /
209/// `LinearCombinationBiasElementwise` epilogues do the entire
210/// alpha/beta/bias/activation chain in float (after int32→float
211/// dequant of the accumulator) and saturating-cast back to the int
212/// output range on store.
213#[derive(Debug)]
214pub struct IntGemmArgs<'a, T: IntElement, BT: BiasElement = f32> {
215    /// Left input. Row-major `[M, K]`.
216    pub a: MatrixRef<'a, T>,
217    /// Right input. Column-major `[K, N]` (RCR).
218    pub b: MatrixRef<'a, T>,
219    /// Optional accumulation source. Row-major `[M, N]`.
220    pub c: Option<MatrixRef<'a, T>>,
221    /// Output. Row-major `[M, N]`.
222    pub d: MatrixMut<'a, T>,
223    /// Optional bias vector. Required when the descriptor's epilogue
224    /// is any `Bias*` variant; must be `None` for
225    /// [`EpilogueKind::Identity`]. Length-`N`, contiguous (stride 1)
226    /// device memory; broadcast across rows of `D`.
227    pub bias: Option<VectorRef<'a, BT>>,
228    /// Multiplier on the matrix-multiply accumulator. Always `f32`
229    /// for int GEMM — CUTLASS does the entire epilogue compute in
230    /// float space.
231    pub alpha: f32,
232    /// Multiplier on `c`. Forced to `0` internally when `c` is `None`.
233    pub beta: f32,
234}
235
/// Strategy CUTLASS uses to schedule tiles across the grouped problem
/// set.
///
/// [`GroupedScheduleMode::DeviceOnly`] is the only mode shipped in v0.
/// CUTLASS additionally has a `HostPrecompute` mode, which pre-walks
/// the schedule on the host and uploads it; it will be added later if
/// profiling justifies the extra API surface.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GroupedScheduleMode {
    /// The kernel itself makes every schedule decision on-device.
    #[default]
    DeviceOnly,
}
247
248/// Hints for [`GroupedGemmPlan::select`](crate::GroupedGemmPlan::select).
249///
250/// Wraps a [`PlanPreference`] for the underlying GEMM tile selection,
251/// plus grouped-specific knobs.
252#[derive(Copy, Clone, Debug, Default)]
253pub struct GroupedPlanPreference {
254    /// Tile-selection preferences (forwarded to the underlying GEMM picker).
255    pub base: PlanPreference,
256    /// CUTLASS schedule mode (v0: only [`GroupedScheduleMode::DeviceOnly`]).
257    pub schedule: GroupedScheduleMode,
258}
259
260/// Identity of the kernel a plan picked.
261///
262/// Useful for caching plan selections in higher layers and for telemetry
263/// (e.g., logging which SKU the autotuner picked).
264///
265/// `bias_element` distinguishes int-GEMM bias kernels at the SKU level:
266/// the same `(arch, layout, epilogue=Bias, element=S8)` tuple maps to
267/// two distinct kernels depending on whether the bias broadcast is `f32`
268/// or `i32`. Float-GEMM bias kernels and Identity kernels leave this
269/// field `None` because the bias element (when present) is implied by
270/// `element`.
271#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
272pub struct GemmSku {
273    /// Architecture the kernel was compiled for.
274    pub arch: ArchSku,
275    /// Layout the kernel implements.
276    pub layout: LayoutSku,
277    /// Epilogue the kernel implements.
278    pub epilogue: EpilogueKind,
279    /// Element type the kernel operates on.
280    pub element: ElementKind,
281    /// Bias broadcast element type. `Some` only for int-GEMM bias
282    /// kernels (which can have either `f32` or `i32` bias);
283    /// `None` for Identity kernels and for float-GEMM bias kernels
284    /// (where the bias element is implied to match `element`).
285    pub bias_element: Option<BiasElementKind>,
286}
287
288impl GemmSku {
289    /// Numerical guarantees for the kernel identified by this SKU.
290    ///
291    /// Pure host-side lookup; returns the same value for the same SKU
292    /// across calls. The mapping is part of the public contract: a
293    /// stable SKU implies a stable precision guarantee.
294    pub fn precision_guarantee(self) -> PrecisionGuarantee {
295        // All shipped kernels accumulate into F32 (floats) or int32
296        // (int8). None use cross-block atomics, so all are deterministic.
297        // Tensor-core warp-reduction order isn't pinned by the spec for
298        // float MMA, so float tensor-core kernels are not bit-stable
299        // cross-driver. Integer MMA reductions are deterministic — the
300        // int32 accumulator has no rounding nondeterminism — so int8
301        // kernels ARE bit-stable on the same hardware.
302        let math_precision = match self.element {
303            ElementKind::F16 => MathPrecision::F16,
304            ElementKind::Bf16 => MathPrecision::Bf16,
305            ElementKind::F32 => MathPrecision::Tf32,
306            ElementKind::F32Strict => MathPrecision::F32,
307            ElementKind::F64 => MathPrecision::F64,
308            ElementKind::S8 | ElementKind::U8 => MathPrecision::Int8,
309            // `I32` is an accumulator-only kind, never a kernel input
310            // element. A `GemmSku` constructed with `element = I32` is
311            // a programming error; report Int8 math precision (the
312            // only int kernel family that produces an int32 accum) as
313            // a defensive fallback.
314            ElementKind::I32 => MathPrecision::Int8,
315            // `I64` and `Bool` are elementwise-only input element types
316            // (added in baracuda-kernels Phase 3.3). No CUTLASS GEMM
317            // SKU consumes them; defensive arm reports Int8 math
318            // precision as a placeholder — a `GemmSku` constructed
319            // with these is a programming error.
320            ElementKind::I64 | ElementKind::Bool => MathPrecision::Int8,
321            // FP8 kernels live in baracuda-kernels-sys, not baracuda-cutlass.
322            // No CUTLASS SKU produces these element kinds; defensive arm.
323            ElementKind::Fp8E4M3 => MathPrecision::Fp8E4M3,
324            ElementKind::Fp8E5M2 => MathPrecision::Fp8E5M2,
325            // Int4 kernels (S4 / U4) live in baracuda-kernels-sys, not
326            // baracuda-cutlass. Defensive arm.
327            ElementKind::S4 | ElementKind::U4 => MathPrecision::Int4,
328            // Binary (Bin) GEMM lives in baracuda-kernels-sys. Defensive arm.
329            ElementKind::Bin => MathPrecision::Binary,
330            // Complex32 / Complex64 are FFT-family element types — no
331            // CUTLASS GEMM SKU consumes them. Report the matching float
332            // math precision as a defensive fallback.
333            ElementKind::Complex32 => MathPrecision::F32,
334            ElementKind::Complex64 => MathPrecision::F64,
335        };
336        // F32Strict (SIMT CUDA cores) and int8 (integer tensor cores)
337        // are bit-stable on the same hardware. Float tensor-core
338        // kernels (F16 / Bf16 / Tf32 / F64) don't pin the warp-level
339        // reduction order so they can differ in the last bit.
340        let bit_stable_on_same_hardware = matches!(
341            self.element,
342            ElementKind::F32Strict
343                | ElementKind::S8
344                | ElementKind::U8
345                | ElementKind::S4
346                | ElementKind::U4
347                | ElementKind::Bin,
348        );
349        let accumulator = match self.element {
350            ElementKind::F64 => ElementKind::F64,
351            ElementKind::S8
352            | ElementKind::U8
353            | ElementKind::S4
354            | ElementKind::U4
355            | ElementKind::Bin => ElementKind::I32,
356            _ => ElementKind::F32,
357        };
358        PrecisionGuarantee {
359            math_precision,
360            accumulator,
361            bit_stable_on_same_hardware,
362            deterministic: true,
363        }
364    }
365}
366