# baracuda-cutlass
Safe Rust wrapper for compiled CUTLASS kernels in the baracuda ecosystem.
`baracuda-cutlass` provides a **plan-based GEMM and grouped-GEMM API** with
caller-supplied workspace, typed device-buffer arguments, and capture-safe
launches. It sits above [`baracuda-cutlass-kernels-sys`] (the compiled
kernels) and below framework integration crates like Fuel's `fuel-cublaslt`.
## Scope
- **Op families**: `GemmPlan` (single GEMM), `BatchedGemmPlan`
(uniform-shape batched GEMM), `GroupedGemmPlan` +
`PreparedGroupedGemm` (variable-M-per-group, MoE-friendly), and
`IntGemmPlan` (8-bit integer GEMM; see the int family notes below).
- **Element types**: `half::f16`, `half::bf16`, `f32` (routed through
TF32 tensor cores at ~10-bit mantissa precision),
[`F32Strict`](crate::F32Strict) (full IEEE 754 binary32 via SIMT CUDA
cores — bit-stable, no tensor-core warp-reduction nondeterminism),
and `f64` (DGEMM via Ampere FP64 tensor cores). See
[`PrecisionGuarantee::math_precision`] and
[`ScalarType`](crate::ScalarType) for the per-element math precision
and alpha/beta scalar mapping.
- **Layouts**: `RCR` (A row-major, B column-major, C/D row-major) and
`RRR` (all three operands row-major — natural for activation@weight
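matmul without a transpose pass). For the float `GemmPlan`, every element
type ships both layouts, chosen per launch via `GemmDescriptor::layout`;
`BatchedGemmPlan`, `GroupedGemmPlan`, and this crate's `IntGemmPlan` are
RCR-only (see the SKU coverage table below).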
- **Epilogues**: `Identity`, `Bias`, `BiasRelu`, `BiasGelu`, `BiasSilu`.
The `Bias*` family computes
`D = activation(α·A·B + β·C + bias_broadcast(N))` in a single fused
kernel pass via `cutlass::gemm::device::GemmUniversalWithBroadcast` +
`LinearCombinationBiasElementwise` — the bias add and activation
both happen inside the epilogue, no extra memory traffic over plain
`Bias`. The bias vector has length `N` and must be contiguous
(stride 1). `GemmArgs::bias` is required iff
`descriptor.epilogue.requires_bias()` is `true`. GELU is the exact
(erf-based) form, matching PyTorch's default `nn.GELU()`.
- **Architectures**: `sm_80` shipped today (runs on Ampere, Ada, and
forward-compatibly on Hopper). `sm_90a` selection wiring is in place;
the Hopper-specialized kernels themselves land when Hopper hardware is
available for validation.
- **Workspace**: caller-supplied — `Workspace::None` or
`Workspace::Borrowed(DeviceSliceMut<u8>)`. Plans never own device memory.
Grouped GEMM additionally packs its per-group metadata into the front of
the workspace via async H2D, with CUTLASS's internal scratch at the tail.
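
When validating device output on small shapes, a host-side reference of the fused epilogue math and the two layouts can be handy. The sketch below is plain Rust and not part of this crate; `Layout`, `Act`, and `reference_gemm` are hypothetical names, it computes in `f64` for a clean baseline, and it omits GELU because Rust's standard library has no stable `erf`.

```rust
/// Hypothetical local mirror of the `Rcr` / `Rrr` layout SKUs.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Layout {
    /// A row-major, B column-major, C/D row-major.
    Rcr,
    /// A, B, C/D all row-major.
    Rrr,
}

/// Activations applied after the bias add (GELU omitted: no stable `erf` in std).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Act {
    Identity,
    Relu,
    Silu,
}

fn apply(act: Act, x: f64) -> f64 {
    match act {
        Act::Identity => x,
        Act::Relu => x.max(0.0),
        // SiLU(x) = x * sigmoid(x)
        Act::Silu => x / (1.0 + (-x).exp()),
    }
}

/// Host reference for D = act(alpha*A*B + beta*C + bias_broadcast(N)).
/// `a` is m x k row-major (ld = k); `b` is k x n, column-major (ld = k) for
/// `Rcr` or row-major (ld = n) for `Rrr`; `c`, if present, is m x n row-major.
#[allow(clippy::too_many_arguments)]
pub fn reference_gemm(
    layout: Layout,
    m: usize,
    n: usize,
    k: usize,
    a: &[f64],
    b: &[f64],
    c: Option<&[f64]>,
    bias: Option<&[f64]>,
    alpha: f64,
    beta: f64,
    act: Act,
) -> Vec<f64> {
    let mut d = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                let b_pj = match layout {
                    Layout::Rcr => b[j * k + p], // column j of B is contiguous
                    Layout::Rrr => b[p * n + j], // row p of B is contiguous
                };
                acc += a[i * k + p] * b_pj;
            }
            let mut v = alpha * acc;
            if let Some(c) = c {
                v += beta * c[i * n + j];
            }
            if let Some(bias) = bias {
                v += bias[j]; // one bias value per output column (length N)
            }
            d[i * n + j] = apply(act, v);
        }
    }
    d
}
```

Comparing against a device result still needs a tolerance that matches the element type's math precision (looser for TF32 than for `F32Strict`).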
### Kernel SKU coverage
| Plan (epilogue) | Shipped SKUs |
|---|---|
| `GemmPlan` (Identity) | `{Rcr, Rrr} × {F16, Bf16, F32 (TF32), F32Strict (SIMT), F64 (DGEMM)}` |
| `GemmPlan` (Bias / BiasRelu / BiasGelu / BiasSilu) | `{Rcr, Rrr} × {F16, Bf16, F32 (TF32), F32Strict (SIMT), F64 (DGEMM)}` |
| `IntGemmPlan` (Identity) | `Rcr × {S8, U8}` (this crate) · `Rrr × {S8, U8}` via [`baracuda-kernels`] |
| `IntGemmPlan` (Bias / BiasRelu / BiasGelu / BiasSilu) | `Rcr × {S8, U8} × {bias = f32, bias = i32}` (this crate) · `Rrr` via [`baracuda-kernels`] |
| `BatchedGemmPlan` | `Rcr × {F16, Bf16}` |
| `GroupedGemmPlan` | `Rcr × {F16, Bf16}` |

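
Code that wants to reject an unsupported combination before calling `select` could encode the float rows of this table as data. A minimal sketch using hypothetical local enums (not the crate's types); plan selection in the crate stays the source of truth, and the int rows are left out:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum PlanKind {
    Gemm,
    Batched,
    Grouped,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Layout {
    Rcr,
    Rrr,
}

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Elem {
    F16,
    Bf16,
    F32Tf32,
    F32Strict,
    F64,
}

/// True if the coverage table lists a kernel for this combination
/// (epilogue is not a factor for the float rows: all epilogues are covered).
pub fn is_covered(plan: PlanKind, layout: Layout, elem: Elem) -> bool {
    match plan {
        // GemmPlan: both layouts, all five float element types.
        PlanKind::Gemm => true,
        // Batched and grouped: RCR only, f16 and bf16 only.
        PlanKind::Batched | PlanKind::Grouped => {
            layout == Layout::Rcr && matches!(elem, Elem::F16 | Elem::Bf16)
        }
    }
}
```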
**Per-element scalar (alpha / beta) types:**

| Element | Alpha / beta scalar | Math and accumulator |
|---|---|---|
| `f16` | `f32` | Tensor-core math, F32 accumulator |
| `bf16` | `f32` | Tensor-core math, F32 accumulator |
| `f32` | `f32` | TF32 tensor-core math (10-bit mantissa), F32 accumulator |
| `F32Strict` | `f32` | SIMT full-precision math, F32 accumulator, bit-stable |
| `f64` | `f64` | DGEMM tensor-core math, F64 accumulator |
| `S8` / `U8` | `f32` | Int8 tensor-core math, int32 accumulator, bit-stable. Float alpha/beta let the epilogue double as a dequantization step. |

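
In generic code, this element-to-scalar mapping is naturally an associated type. The sketch below uses hypothetical marker types and a hypothetical `ElementScalar` trait to illustrate the idea; the crate documents its own mapping under `ScalarType`, whose exact shape is not assumed here.

```rust
// Hypothetical marker types standing in for the table's element types.
pub struct F16;
pub struct Bf16;
pub struct F32Tf32;
pub struct F32Strict;
pub struct F64;
pub struct S8;
pub struct U8;

/// Maps a matrix element type to the scalar type used for alpha/beta.
pub trait ElementScalar {
    type Scalar: Copy + From<u8>;
}

impl ElementScalar for F16 { type Scalar = f32; }
impl ElementScalar for Bf16 { type Scalar = f32; }
impl ElementScalar for F32Tf32 { type Scalar = f32; }
impl ElementScalar for F32Strict { type Scalar = f32; }
impl ElementScalar for F64 { type Scalar = f64; }
impl ElementScalar for S8 { type Scalar = f32; }
impl ElementScalar for U8 { type Scalar = f32; }

/// `(alpha, beta) = (1, 0)` in whatever scalar type the element uses.
pub fn unit_alpha_beta<T: ElementScalar>() -> (T::Scalar, T::Scalar) {
    (1u8.into(), 0u8.into())
}
```

Only `f64` widens the scalars; every other row, including the int8 ones, uses `f32`.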
**Int family notes:**
`IntGemmPlan<T: IntElement, BT: BiasElement = f32>` is a sibling type to
`GemmPlan`. The matrix element `T` picks the kernel family
(`S8` / `U8` today; `s4` / `u4` / 1-bit deferred to follow-ups). For
`Bias*` epilogues, `BT` picks the bias broadcast element type — `f32`
(default; matches the float-bias convention used elsewhere) or `i32`
(matches TensorRT's int8 inference convention). Both routes use
`LinearCombinationBiasElementwise` with `ElementCompute = float`, so
the fused activation runs in float space after the int32→float dequant,
and the final round-to-nearest-even saturating cast back to `s8` or `u8`
happens via the `cvt.rni.sat.{s8,u8}.f32` PTX instruction.
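
A host model of that float-space pipeline can make the rounding and saturation behavior concrete. This is a sketch, not the kernel: it assumes no C operand (beta = 0) and an `f32` bias added after the alpha scale. How an `i32` bias is scaled relative to alpha is not specified here, so it is left out. `round_ties_even` models `.rni` and the clamp models `.sat`.

```rust
/// alpha * acc + bias in float space, then the optional ReLU.
/// Assumes beta = 0 (no C operand) and an f32 bias applied after alpha.
fn fused_value(acc: i32, alpha: f32, bias: f32, relu: bool) -> f32 {
    let x = alpha * acc as f32 + bias;
    if relu { x.max(0.0) } else { x }
}

/// Models `cvt.rni.sat.s8.f32`: round to nearest, ties to even, then saturate.
pub fn epilogue_s8(acc: i32, alpha: f32, bias: f32, relu: bool) -> i8 {
    let x = fused_value(acc, alpha, bias, relu).round_ties_even();
    x.clamp(i8::MIN as f32, i8::MAX as f32) as i8
}

/// Models `cvt.rni.sat.u8.f32`: same rounding, saturating to [0, 255].
pub fn epilogue_u8(acc: i32, alpha: f32, bias: f32, relu: bool) -> u8 {
    let x = fused_value(acc, alpha, bias, relu).round_ties_even();
    x.clamp(0.0, u8::MAX as f32) as u8
}
```

Ties round to even (2.5 becomes 2, 3.5 becomes 4), and out-of-range values pin to the type's limits instead of wrapping.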
`Rrr` for the int family is **not in this crate** — CUTLASS 4.2.0
lacks the warp-level iterator specializations for the 8-bit
`TensorOpMultiplicandCongruous` shared-memory layout that
`RowMajor × RowMajor × OpClassTensorOp` would select for int8.
Selecting `LayoutSku::Rrr` on **this crate's** `IntGemmPlan` returns
`Error::Unsupported` at plan selection time. The bespoke RRR kernels
live in [`baracuda-kernels`] — selecting `Rrr` on
`baracuda_kernels::IntGemmPlan` dispatches to a hand-rolled
`mma.sync.m16n8k32` kernel set covering all 18 SKUs: `{S8, U8}` ×
(`Identity`, plus `{Bias, BiasRelu, BiasGelu, BiasSilu}` × `{f32, i32}`
bias), i.e. 2 × (1 + 4 × 2). Callers building new code should import from
`baracuda-kernels`, which is a strict superset of this crate's int-GEMM surface.
Remaining int / quantized dtypes (`s4`/`u4`/`b1`) are planned
follow-ups in `baracuda-kernels` and not yet shipped.
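
The count of 18 comes from `Identity` having no bias operand, so only the four `Bias*` epilogues multiply by the two bias types. A short enumeration with hypothetical local enums (not the types exported by either crate) confirms the arithmetic:

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum IntElem { S8, U8 }

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Epilogue { Identity, Bias, BiasRelu, BiasGelu, BiasSilu }

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BiasElem { F32, I32 }

/// Enumerates (element, epilogue, bias element) triples for the int RRR set.
/// `Identity` has no bias, so it contributes one SKU per element, not two.
pub fn int_rrr_skus() -> Vec<(IntElem, Epilogue, Option<BiasElem>)> {
    let mut skus = Vec::new();
    for elem in [IntElem::S8, IntElem::U8] {
        skus.push((elem, Epilogue::Identity, None));
        for epi in [Epilogue::Bias, Epilogue::BiasRelu, Epilogue::BiasGelu, Epilogue::BiasSilu] {
            for bias in [BiasElem::F32, BiasElem::I32] {
                skus.push((elem, epi, Some(bias)));
            }
        }
    }
    skus
}
```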

[`baracuda-kernels`]: https://docs.rs/baracuda-kernels

All SKUs above target `sm_80` (Ampere); `sm_90a` kernels are deferred until
Hopper validation.
## Why plan-based, not handle-based?
CUTLASS isn't cuBLAS. There is no persistent driver-side state that lives
across kernel launches. Every kernel is a self-contained instantiation of
a template. A `Plan` holds the *selected kernel ID and its host-side
metadata* — not a handle, not a workspace. This makes plans cheap to clone,
trivially `Send + Sync`, and capture-safe by construction (no host
allocations during `run`).
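
The thread-safety claim is structural: a type that holds only plain host data gets `Clone` (here also `Copy`) by derive and `Send + Sync` automatically. A sketch with a hypothetical `SketchPlan` (the crate's real plan types carry their own fields) plus a compile-time check:

```rust
/// Hypothetical stand-in for a plan: a kernel ID plus host-side metadata.
/// No device pointers or driver handles, so it is Copy, Send, and Sync.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SketchPlan {
    pub kernel_id: u32,
    pub threadblock_count: u32,
    pub scratch_bytes: usize,
}

/// Fails to compile if `T` is not `Send + Sync`.
pub const fn assert_send_sync<T: Send + Sync>() {}

const _: () = assert_send_sync::<SketchPlan>();

/// Copies the plan into several threads; each reads its own copy.
pub fn share_across_threads(plan: SketchPlan) -> Vec<u32> {
    let handles: Vec<_> = (0..4)
        .map(|_| std::thread::spawn(move || plan.kernel_id))
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```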
## Quick start
```rust,no_run
use baracuda_cutlass::{
EpilogueKind, GemmArgs, GemmDescriptor, GemmPlan, LayoutSku,
MatrixMut, MatrixRef, PlanPreference, Workspace,
};
use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
use half::f16;
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let ctx = Context::new(&Device::get(0)?)?;
let stream = Stream::new(&ctx)?;
let m = 128i32; let n = 128i32; let k = 128i32;
let dev_a: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * k) as usize)?;
let dev_b: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (k * n) as usize)?;
let mut dev_d: DeviceBuffer<f16> = DeviceBuffer::zeros(&ctx, (m * n) as usize)?;
let desc = GemmDescriptor {
m, n, k,
layout: LayoutSku::Rcr,
epilogue: EpilogueKind::Identity,
};
let plan = GemmPlan::<f16>::select(&stream, &desc, PlanPreference::default())?;
let args = GemmArgs::<f16> {
a: MatrixRef { data: dev_a.as_slice(), rows: m, cols: k, ld: k as i64 },
b: MatrixRef { data: dev_b.as_slice(), rows: k, cols: n, ld: k as i64 },
c: None,
d: MatrixMut { data: dev_d.as_slice_mut(), rows: m, cols: n, ld: n as i64 },
bias: None,
alpha: 1.0,
beta: 0.0,
};
plan.can_implement(&args)?;
plan.run(&stream, Workspace::None, args)?;
# Ok(()) }
```
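
In the quick start every operand is tightly packed, so A uses `ld = k` (row-major, `k` columns), B uses `ld = k` (column-major, `k` rows), and D uses `ld = n`. Switching the descriptor to `LayoutSku::Rrr` changes only B, which must then be stored row-major with `ld = n`. A plain-Rust sketch of both cases, plus the minimum element count a strided view must cover (helper names are hypothetical, not crate API):

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Layout { Rcr, Rrr }

/// Packed leading dimensions `(ld_a, ld_b, ld_d)` for A (m x k), B (k x n), D (m x n).
/// A and D are row-major in both layouts, so only B's ld depends on the layout.
pub fn packed_lds(layout: Layout, n: i64, k: i64) -> (i64, i64, i64) {
    let ld_b = match layout {
        Layout::Rcr => k, // column-major B: stride between columns = k rows
        Layout::Rrr => n, // row-major B: stride between rows = n columns
    };
    (k, ld_b, n)
}

/// Minimum element count covered by a strided matrix view.
/// Row-major: `outer` = rows, `inner` = cols (swap them for column-major).
pub fn min_len(outer: i64, inner: i64, ld: i64) -> i64 {
    assert!(ld >= inner, "leading dimension smaller than contiguous extent");
    if outer == 0 || inner == 0 { 0 } else { (outer - 1) * ld + inner }
}
```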
## Grouped GEMM quick start (MoE-friendly)
```rust,no_run
use baracuda_cutlass::{
EpilogueKind, GroupedGemmPlan, GroupedPlanPreference, GroupedProblem,
MatrixMut, MatrixRef, Workspace,
};
use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
use half::f16;
# fn run() -> Result<(), Box<dyn std::error::Error>> {
let ctx = Context::new(&Device::get(0)?)?;
let stream = Stream::new(&ctx)?;
// One "group" per expert. Variable M (token count), shared K/N.
// Build your GroupedProblem<f16> slice from per-expert device buffers
// (omitted here for brevity — see tests/grouped_gemm_smoke.rs).
let groups: Vec<GroupedProblem<'_, f16>> = todo!("build per-expert problems");
let plan = GroupedGemmPlan::<f16>::select(
&stream,
EpilogueKind::Identity,
GroupedPlanPreference::default(),
)?;
let prepared = plan.prepare(&groups)?;
// Allocate one workspace big enough for the packed metadata + CUTLASS scratch.
let mut workspace: DeviceBuffer<u8> = DeviceBuffer::zeros(&ctx, prepared.workspace_size())?;
prepared.run(&stream, Workspace::Borrowed(workspace.as_slice_mut()))?;
# Ok(()) }
```
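
The `todo!` above elides how per-expert problems are carved out of a shared activation buffer. One common approach (an assumption here, not something this crate prescribes) sorts tokens by expert and takes an exclusive prefix sum of the per-expert token counts to get each group's starting row and `M`. Whether zero-token groups should be dropped before calling `prepare` depends on what its shape validation accepts.

```rust
/// For tokens already sorted by expert, returns each expert's
/// `(first_row, m)` in the packed activation buffer. `first_row` is the
/// exclusive prefix sum of the preceding experts' token counts.
pub fn expert_row_ranges(tokens_per_expert: &[usize]) -> Vec<(usize, usize)> {
    let mut offset = 0;
    tokens_per_expert
        .iter()
        .map(|&m| {
            let range = (offset, m);
            offset += m;
            range
        })
        .collect()
}

/// Byte offset of `row` in a row-major buffer with `k` columns of
/// `elem_size`-byte elements (2 for f16 / bf16).
pub fn row_byte_offset(row: usize, k: usize, elem_size: usize) -> usize {
    row * k * elem_size
}
```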
`prepare` validates per-group shapes and v0 invariants (all groups share
α/β; all groups consistently use `c = None` or `c = Some(_)`), packs host
arrays for `problem_sizes`, pointers, and leading dimensions, and queries
CUTLASS for the threadblock count + scratch size. `run` uploads the
metadata to the start of the workspace via async H2D and launches the
grouped kernel using the remainder as CUTLASS internal scratch.
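
This is why `workspace_size()` grows with the group count as well as with CUTLASS's scratch query. The sketch below models a front-packed layout; the per-group record (three `i32` sizes, four 8-byte pointers, four `i64` leading dimensions), the 256-byte scratch alignment, and `workspace_layout` itself are assumptions for illustration. Always size the real allocation from `prepared.workspace_size()`.

```rust
/// Assumed per-group metadata: (m, n, k) as i32, A/B/C/D pointers as u64,
/// and four i64 leading dimensions.
const PER_GROUP_BYTES: usize = 3 * 4 + 4 * 8 + 4 * 8;
/// Assumed alignment for the start of the CUTLASS scratch region.
const SCRATCH_ALIGN: usize = 256;

/// Returns `(scratch_offset, total_bytes)` for a workspace laid out as
/// [packed metadata | padding | CUTLASS scratch].
pub fn workspace_layout(num_groups: usize, scratch_bytes: usize) -> (usize, usize) {
    let metadata = num_groups * PER_GROUP_BYTES;
    let scratch_offset = metadata.next_multiple_of(SCRATCH_ALIGN);
    (scratch_offset, scratch_offset + scratch_bytes)
}
```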
## Integration notes
### Calling from a byte-storage substrate
Frameworks that store all device tensors as `DeviceBuffer<u8>` (e.g.
Fuel's unified-binding-table dispatch path) can construct typed
`MatrixRef` / `MatrixMut` views without copying or transmuting:
```rust,no_run
use baracuda_cutlass::{MatrixMut, MatrixRef};
use baracuda_driver::DeviceBuffer;
use half::bf16;
# fn demo(byte_a: &DeviceBuffer<u8>, byte_d: &mut DeviceBuffer<u8>) {
let m = 128i32; let n = 128i32; let k = 128i32;
let a_view: MatrixRef<bf16> = MatrixRef {
data: byte_a.view_as::<bf16>(),
rows: m, cols: k, ld: k as i64,
};
let d_view: MatrixMut<bf16> = MatrixMut {
data: byte_d.view_as_mut::<bf16>(),
rows: m, cols: n, ld: n as i64,
};
# let _ = (a_view, d_view); }
```
[`DeviceBuffer<u8>::view_as`](baracuda_driver::DeviceBuffer::view_as)
asserts byte-count divisibility and reuses the buffer's existing
allocation — no copy, no `unsafe` at the consumer site. For non-baracuda
allocations, the lower-level
[`DeviceSlice::from_raw_parts`](baracuda_driver::DeviceSlice::from_raw_parts)
escape hatch is available.
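
The divisibility requirement is plain arithmetic over `size_of::<T>()`. A host-side sketch of the condition (a hypothetical helper, not the driver's implementation), using `u16` and `u32` as stand-ins for 2- and 4-byte element types so the block needs no external crates:

```rust
/// Number of `T` elements a `byte_len`-byte buffer can be viewed as, or
/// `None` when the length is not a multiple of `size_of::<T>()`, which is
/// the condition a typed view over raw bytes has to reject.
pub fn typed_len<T>(byte_len: usize) -> Option<usize> {
    let size = std::mem::size_of::<T>();
    if size == 0 || byte_len % size != 0 {
        None
    } else {
        Some(byte_len / size)
    }
}
```

For the 128 × 128 `bf16` views in the example, each byte buffer must hold at least 32,768 bytes, which views as 16,384 elements.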
### Sharing a stream across launchers
A consumer that holds an `Arc<Stream>` (e.g. one stream per device,
shared across many kernel launches) can pass it to `plan.run` directly.
`&shared` deref-coerces from `&Arc<Stream>` to `&Stream` (or call
`shared.as_ref()` explicitly), so the borrow has the same shape as for an
owned `Stream`:
```rust,no_run
# use std::sync::Arc;
# use baracuda_cutlass::{GemmArgs, GemmPlan, Workspace};
# use baracuda_driver::Stream;
# use half::f16;
# fn demo(
#     shared: Arc<Stream>,
#     plan: GemmPlan<f16>,
#     args: GemmArgs<f16>,
# ) -> Result<(), Box<dyn std::error::Error>> {
// `&shared` coerces to `&Stream`; no extra `Stream::new` per launcher.
plan.run(&shared, Workspace::None, args)?;
# Ok(()) }
```
### Mapping kernels to precision guarantees
For consumers maintaining a per-decision-point alternatives table
(picking between cuBLAS and CUTLASS at a given precision contract),
[`GemmPlan::precision_guarantee`] (and the grouped equivalent) returns
a `PrecisionGuarantee` value — math-instruction precision, accumulator
type, bit-stability and determinism flags — without re-derivation from
per-kernel docs.
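
A per-decision-point alternatives table then reduces to "pick the first candidate whose guarantee meets the contract". The sketch below uses a hypothetical local `Guarantee` that mirrors the four properties listed above; its field names and the example values are illustrative, not the crate's `PrecisionGuarantee` API.

```rust
/// Hypothetical mirror of the properties a precision guarantee reports.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Guarantee {
    pub mantissa_bits: u32,    // math-instruction precision
    pub accumulator_bits: u32, // accumulator width
    pub bit_stable: bool,
    pub deterministic: bool,
}

/// What a given decision point requires.
#[derive(Clone, Copy, Debug)]
pub struct Contract {
    pub min_mantissa_bits: u32,
    pub min_accumulator_bits: u32,
    pub require_bit_stable: bool,
    pub require_deterministic: bool,
}

impl Guarantee {
    pub fn satisfies(&self, c: &Contract) -> bool {
        self.mantissa_bits >= c.min_mantissa_bits
            && self.accumulator_bits >= c.min_accumulator_bits
            && (self.bit_stable || !c.require_bit_stable)
            && (self.deterministic || !c.require_deterministic)
    }
}

/// Returns the name of the first candidate that meets the contract.
pub fn pick<'a>(candidates: &[(&'a str, Guarantee)], c: &Contract) -> Option<&'a str> {
    candidates.iter().find(|(_, g)| g.satisfies(c)).map(|(name, _)| *name)
}
```

Ordering the candidate list by preference (e.g. fastest first) keeps the selection rule a single linear scan.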
## Acknowledgments
API specification by the Fuel ML library team. Underlying CUTLASS by NVIDIA.
See `NOTICE` for full attribution.

[`baracuda-cutlass-kernels-sys`]: ../baracuda-cutlass-kernels-sys