baracuda_cutlass/types.rs
1//! Value types for the CUTLASS plan-based API.
2//!
3//! All types here are pure data — no hidden device allocations, no
4//! handles, no driver state. Plans cache *selection metadata* on top of
5//! these descriptors but never own device memory.
6//!
7//! The shared type vocabulary (element / layout / epilogue / matrix-view
8//! / plan-preference / precision-guarantee / workspace) lives in
9//! [`baracuda_kernels_types`]. This module re-exports it for back-compat
10//! and additionally hosts the CUTLASS-specific descriptors (GEMM /
11//! batched GEMM / grouped GEMM / int-GEMM problem + args structs and
12//! the [`GemmSku`] tag) that aren't shared with the wider facade.
13//!
//! The trait formerly known as `CutlassElement` is now
//! [`baracuda_kernels_types::Element`]; the old name is preserved as a
//! re-export alias below (see [`CutlassElement`]).
17
// Re-export the shared vocabulary so downstream callers that import
// from `baracuda_cutlass::types::*` (and any code inside this crate
// referencing these types via `crate::types::Foo`) continue to work
// unchanged.
22pub use baracuda_kernels_types::{
23 ActivationKind, ArchSku, BackendKind, BiasElement, BiasElementKind, Element, ElementKind,
24 EpilogueKind, F32Strict, IntElement, LayoutSku, MathPrecision, MatrixMut, MatrixRef,
25 PlanPreference, PrecisionGuarantee, S8, ScalarType, U8, VectorRef, Workspace,
26};
27
28/// Back-compat alias for [`Element`].
29///
30/// Originally the float-family element trait was named `CutlassElement`;
31/// it was renamed to `Element` in workspace alpha.16 when the shared
32/// type vocabulary moved into [`baracuda_kernels_types`]. The old name
33/// is preserved here so existing downstream imports keep working.
34///
35/// Prefer importing `Element` from `baracuda_kernels_types` (or, via
36/// re-export, from `baracuda_cutlass` / `baracuda_kernels`) in new code.
37pub use baracuda_kernels_types::Element as CutlassElement;
38
39/// Problem shape and configuration handed to [`GemmPlan::select`](crate::GemmPlan::select).
40#[derive(Copy, Clone, Debug)]
41pub struct GemmDescriptor {
42 /// Output row count.
43 pub m: i32,
44 /// Output column count.
45 pub n: i32,
46 /// Reduction depth.
47 pub k: i32,
48 /// Layout SKU.
49 pub layout: LayoutSku,
50 /// Epilogue kind.
51 pub epilogue: EpilogueKind,
52}
53
54/// Per-launch arguments for a [`GemmPlan::run`](crate::GemmPlan::run) call.
55///
56/// `c` is optional: when `None`, `β` is ignored at the safe layer (treated
57/// as `0`) and the kernel computes `D = α · A · B`. When `Some`, the
58/// kernel computes `D = α · A · B + β · C` — including the
59/// `c.data == d.data` case for in-place accumulation.
60///
61/// `bias` is required iff the descriptor's epilogue is one of the
62/// `Bias*` variants, in which case the kernel computes
63/// `D = activation(α · A · B + β · C + bias_broadcast(N))`.
64#[derive(Debug)]
65pub struct GemmArgs<'a, T: Element> {
66 /// Left input. Row-major `[M, K]`.
67 pub a: MatrixRef<'a, T>,
68 /// Right input. Layout depends on the descriptor's [`LayoutSku`]:
69 /// column-major `[K, N]` for [`LayoutSku::Rcr`], row-major `[K, N]`
70 /// for [`LayoutSku::Rrr`].
71 pub b: MatrixRef<'a, T>,
72 /// Optional accumulation source. Row-major `[M, N]`.
73 pub c: Option<MatrixRef<'a, T>>,
74 /// Output. Row-major `[M, N]`.
75 pub d: MatrixMut<'a, T>,
76 /// Optional bias vector. Required (`Some`) when the descriptor's
77 /// epilogue is any `Bias*` variant; must be `None` for
78 /// [`EpilogueKind::Identity`]. Length-`N`, contiguous (stride 1)
79 /// device memory; broadcast across rows of `D`.
80 pub bias: Option<VectorRef<'a, T>>,
81 /// Multiplier on the matrix-multiply accumulator. Scalar type
82 /// matches `T::Scalar` — `f32` for f16/bf16/f32/[`F32Strict`], `f64`
83 /// for [`prim@f64`].
84 pub alpha: T::Scalar,
85 /// Multiplier on `c`. Forced to `0` internally when `c` is `None`,
86 /// so callers don't need to pre-zero it for the no-accumulate case.
87 pub beta: T::Scalar,
88}
89
90/// Problem shape and configuration handed to
91/// [`BatchedGemmPlan::select`](crate::BatchedGemmPlan::select).
92///
93/// All batches share the same `(M, N, K)` and per-batch operands are
94/// addressed by adding `i * stride_*` (in elements) to the base
95/// pointer — see [`BatchedGemmArgs`]. For variable-shape grouped
96/// problems use [`GroupedGemmPlan`](crate::GroupedGemmPlan) instead.
97#[derive(Copy, Clone, Debug)]
98pub struct BatchedGemmDescriptor {
99 /// Output row count (per batch).
100 pub m: i32,
101 /// Output column count (per batch).
102 pub n: i32,
103 /// Reduction depth (per batch).
104 pub k: i32,
105 /// Number of batches launched in a single kernel invocation.
106 pub batch_count: i32,
107 /// Layout SKU. v1 supports only [`LayoutSku::Rcr`].
108 pub layout: LayoutSku,
109 /// Epilogue kind. v1 supports only [`EpilogueKind::Identity`].
110 pub epilogue: EpilogueKind,
111}
112
113/// Per-launch arguments for a
114/// [`BatchedGemmPlan::run`](crate::BatchedGemmPlan::run) call.
115///
116/// `stride_*` fields are in **elements**, not bytes — matching CUTLASS's
117/// `GemmBatched` API. Pass `0` for stride if the same matrix should be
118/// reused across all batches (broadcast).
119#[derive(Debug)]
120pub struct BatchedGemmArgs<'a, T: Element> {
121 /// Left input — base pointer for batch 0.
122 pub a: MatrixRef<'a, T>,
123 /// Element offset between consecutive A batches.
124 pub stride_a: i64,
125 /// Right input — base pointer for batch 0.
126 pub b: MatrixRef<'a, T>,
127 /// Element offset between consecutive B batches.
128 pub stride_b: i64,
129 /// Optional accumulation source.
130 pub c: Option<MatrixRef<'a, T>>,
131 /// Element offset between consecutive C batches. Ignored when `c` is `None`.
132 pub stride_c: i64,
133 /// Output — base pointer for batch 0.
134 pub d: MatrixMut<'a, T>,
135 /// Element offset between consecutive D batches.
136 pub stride_d: i64,
137 /// α multiplier (shared across batches). Scalar type matches
138 /// `T::Scalar` — `f32` for f16/bf16/f32/[`F32Strict`], `f64` for
139 /// [`prim@f64`].
140 pub alpha: T::Scalar,
141 /// β multiplier (shared across batches). Forced to `0` internally
142 /// when `c` is `None`.
143 pub beta: T::Scalar,
144}
145
146/// One per-group entry for a grouped GEMM launch.
147///
148/// Each group has its own shape and pointers; CUTLASS dispatches them in
149/// a single kernel invocation. Passed as a slice to
150/// [`GroupedGemmPlan::prepare`](crate::GroupedGemmPlan::prepare), which
151/// returns a [`PreparedGroupedGemm`](crate::PreparedGroupedGemm) whose
152/// [`run`](crate::PreparedGroupedGemm::run) method performs the launch.
153#[derive(Debug)]
154pub struct GroupedProblem<'a, T: Element> {
155 /// Group `M`.
156 pub m: i32,
157 /// Group `N`.
158 pub n: i32,
159 /// Group `K`.
160 pub k: i32,
161 /// Left input.
162 pub a: MatrixRef<'a, T>,
163 /// Right input.
164 pub b: MatrixRef<'a, T>,
165 /// Optional accumulation source.
166 pub c: Option<MatrixRef<'a, T>>,
167 /// Output.
168 pub d: MatrixMut<'a, T>,
169 /// α for this group. Scalar type matches `T::Scalar` — `f32` for
170 /// f16/bf16/f32/[`F32Strict`], `f64` for [`prim@f64`].
171 pub alpha: T::Scalar,
172 /// β for this group. Forced to `0` internally when `c` is `None`.
173 pub beta: T::Scalar,
174}
175
176/// Problem shape and configuration handed to
177/// [`IntGemmPlan::select`](crate::IntGemmPlan::select).
178///
179/// Parallel to [`GemmDescriptor`] for the integer GEMM family.
180/// `LayoutSku` and [`EpilogueKind`] are shared with the float family,
181/// but coverage on int8 is limited to [`LayoutSku::Rcr`] in this
182/// release — selecting [`LayoutSku::Rrr`] returns
183/// [`Error::Unsupported`](crate::Error::Unsupported). The
184/// `RowMajor × RowMajor` integer SKU lives in the bespoke
185/// `baracuda-kernels-sys` kernel family (lands in workspace alpha.16).
186#[derive(Copy, Clone, Debug)]
187pub struct IntGemmDescriptor {
188 /// Output row count.
189 pub m: i32,
190 /// Output column count.
191 pub n: i32,
192 /// Reduction depth.
193 pub k: i32,
194 /// Layout SKU. Today's int8 CUTLASS SKUs require [`LayoutSku::Rcr`].
195 pub layout: LayoutSku,
196 /// Epilogue kind. All five variants are supported on int8 RCR.
197 pub epilogue: EpilogueKind,
198}
199
200/// Per-launch arguments for an
201/// [`IntGemmPlan::run`](crate::IntGemmPlan::run) call.
202///
203/// Parallel to [`GemmArgs`] for the integer GEMM family. The matrix
204/// operands carry the kernel element type `T: IntElement`
205/// (today: [`S8`] or [`U8`]); the optional `bias` carries the
206/// independent bias element type `BT: BiasElement` (today: `f32` or
207/// `i32`). Scalar `alpha` / `beta` are always `f32` regardless of `T`
208/// or `BT` — CUTLASS's `LinearCombinationClamp` /
209/// `LinearCombinationBiasElementwise` epilogues do the entire
210/// alpha/beta/bias/activation chain in float (after int32→float
211/// dequant of the accumulator) and saturating-cast back to the int
212/// output range on store.
213#[derive(Debug)]
214pub struct IntGemmArgs<'a, T: IntElement, BT: BiasElement = f32> {
215 /// Left input. Row-major `[M, K]`.
216 pub a: MatrixRef<'a, T>,
217 /// Right input. Column-major `[K, N]` (RCR).
218 pub b: MatrixRef<'a, T>,
219 /// Optional accumulation source. Row-major `[M, N]`.
220 pub c: Option<MatrixRef<'a, T>>,
221 /// Output. Row-major `[M, N]`.
222 pub d: MatrixMut<'a, T>,
223 /// Optional bias vector. Required when the descriptor's epilogue
224 /// is any `Bias*` variant; must be `None` for
225 /// [`EpilogueKind::Identity`]. Length-`N`, contiguous (stride 1)
226 /// device memory; broadcast across rows of `D`.
227 pub bias: Option<VectorRef<'a, BT>>,
228 /// Multiplier on the matrix-multiply accumulator. Always `f32`
229 /// for int GEMM — CUTLASS does the entire epilogue compute in
230 /// float space.
231 pub alpha: f32,
232 /// Multiplier on `c`. Forced to `0` internally when `c` is `None`.
233 pub beta: f32,
234}
235
/// Strategy CUTLASS uses to distribute tiles over a grouped problem set.
///
/// Only [`GroupedScheduleMode::DeviceOnly`] exists in v0. CUTLASS also
/// has a `HostPrecompute` mode, which walks the schedule on the host
/// ahead of time and uploads it; it will be added later if profiling
/// shows the extra API surface is worthwhile.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum GroupedScheduleMode {
    /// The kernel makes every scheduling decision itself, on-device.
    #[default]
    DeviceOnly,
}
247
248/// Hints for [`GroupedGemmPlan::select`](crate::GroupedGemmPlan::select).
249///
250/// Wraps a [`PlanPreference`] for the underlying GEMM tile selection,
251/// plus grouped-specific knobs.
252#[derive(Copy, Clone, Debug, Default)]
253pub struct GroupedPlanPreference {
254 /// Tile-selection preferences (forwarded to the underlying GEMM picker).
255 pub base: PlanPreference,
256 /// CUTLASS schedule mode (v0: only [`GroupedScheduleMode::DeviceOnly`]).
257 pub schedule: GroupedScheduleMode,
258}
259
260/// Identity of the kernel a plan picked.
261///
262/// Useful for caching plan selections in higher layers and for telemetry
263/// (e.g., logging which SKU the autotuner picked).
264///
265/// `bias_element` distinguishes int-GEMM bias kernels at the SKU level:
266/// the same `(arch, layout, epilogue=Bias, element=S8)` tuple maps to
267/// two distinct kernels depending on whether the bias broadcast is `f32`
268/// or `i32`. Float-GEMM bias kernels and Identity kernels leave this
269/// field `None` because the bias element (when present) is implied by
270/// `element`.
271#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
272pub struct GemmSku {
273 /// Architecture the kernel was compiled for.
274 pub arch: ArchSku,
275 /// Layout the kernel implements.
276 pub layout: LayoutSku,
277 /// Epilogue the kernel implements.
278 pub epilogue: EpilogueKind,
279 /// Element type the kernel operates on.
280 pub element: ElementKind,
281 /// Bias broadcast element type. `Some` only for int-GEMM bias
282 /// kernels (which can have either `f32` or `i32` bias);
283 /// `None` for Identity kernels and for float-GEMM bias kernels
284 /// (where the bias element is implied to match `element`).
285 pub bias_element: Option<BiasElementKind>,
286}
287
288impl GemmSku {
289 /// Numerical guarantees for the kernel identified by this SKU.
290 ///
291 /// Pure host-side lookup; returns the same value for the same SKU
292 /// across calls. The mapping is part of the public contract: a
293 /// stable SKU implies a stable precision guarantee.
294 pub fn precision_guarantee(self) -> PrecisionGuarantee {
295 // All shipped kernels accumulate into F32 (floats) or int32
296 // (int8). None use cross-block atomics, so all are deterministic.
297 // Tensor-core warp-reduction order isn't pinned by the spec for
298 // float MMA, so float tensor-core kernels are not bit-stable
299 // cross-driver. Integer MMA reductions are deterministic — the
300 // int32 accumulator has no rounding nondeterminism — so int8
301 // kernels ARE bit-stable on the same hardware.
302 let math_precision = match self.element {
303 ElementKind::F16 => MathPrecision::F16,
304 ElementKind::Bf16 => MathPrecision::Bf16,
305 ElementKind::F32 => MathPrecision::Tf32,
306 ElementKind::F32Strict => MathPrecision::F32,
307 ElementKind::F64 => MathPrecision::F64,
308 ElementKind::S8 | ElementKind::U8 => MathPrecision::Int8,
309 // `I32` is an accumulator-only kind, never a kernel input
310 // element. A `GemmSku` constructed with `element = I32` is
311 // a programming error; report Int8 math precision (the
312 // only int kernel family that produces an int32 accum) as
313 // a defensive fallback.
314 ElementKind::I32 => MathPrecision::Int8,
315 // `I64` and `Bool` are elementwise-only input element types
316 // (added in baracuda-kernels Phase 3.3). No CUTLASS GEMM
317 // SKU consumes them; defensive arm reports Int8 math
318 // precision as a placeholder — a `GemmSku` constructed
319 // with these is a programming error.
320 ElementKind::I64 | ElementKind::Bool => MathPrecision::Int8,
321 // FP8 kernels live in baracuda-kernels-sys, not baracuda-cutlass.
322 // No CUTLASS SKU produces these element kinds; defensive arm.
323 ElementKind::Fp8E4M3 => MathPrecision::Fp8E4M3,
324 ElementKind::Fp8E5M2 => MathPrecision::Fp8E5M2,
325 // Int4 kernels (S4 / U4) live in baracuda-kernels-sys, not
326 // baracuda-cutlass. Defensive arm.
327 ElementKind::S4 | ElementKind::U4 => MathPrecision::Int4,
328 // Binary (Bin) GEMM lives in baracuda-kernels-sys. Defensive arm.
329 ElementKind::Bin => MathPrecision::Binary,
330 // Complex32 / Complex64 are FFT-family element types — no
331 // CUTLASS GEMM SKU consumes them. Report the matching float
332 // math precision as a defensive fallback.
333 ElementKind::Complex32 => MathPrecision::F32,
334 ElementKind::Complex64 => MathPrecision::F64,
335 };
336 // F32Strict (SIMT CUDA cores) and int8 (integer tensor cores)
337 // are bit-stable on the same hardware. Float tensor-core
338 // kernels (F16 / Bf16 / Tf32 / F64) don't pin the warp-level
339 // reduction order so they can differ in the last bit.
340 let bit_stable_on_same_hardware = matches!(
341 self.element,
342 ElementKind::F32Strict
343 | ElementKind::S8
344 | ElementKind::U8
345 | ElementKind::S4
346 | ElementKind::U4
347 | ElementKind::Bin,
348 );
349 let accumulator = match self.element {
350 ElementKind::F64 => ElementKind::F64,
351 ElementKind::S8
352 | ElementKind::U8
353 | ElementKind::S4
354 | ElementKind::U4
355 | ElementKind::Bin => ElementKind::I32,
356 _ => ElementKind::F32,
357 };
358 PrecisionGuarantee {
359 math_precision,
360 accumulator,
361 bit_stable_on_same_hardware,
362 deterministic: true,
363 }
364 }
365}
366