baracuda-cutlass
Safe Rust wrapper for compiled CUTLASS kernels in the baracuda ecosystem.
baracuda-cutlass provides a plan-based GEMM and grouped-GEMM API with
caller-supplied workspace, typed device-buffer arguments, and capture-safe
launches. It sits above baracuda-cutlass-kernels-sys (the compiled
kernels) and below framework integration crates like Fuel's fuel-cublaslt.
Scope
- Op families: GemmPlan (single GEMM), BatchedGemmPlan (uniform-shape batched GEMM), GroupedGemmPlan + PreparedGroupedGemm (variable-M-per-group, MoE-friendly).
- Element types: half::f16, half::bf16, f32 (routed through TF32 tensor cores at ~10-bit mantissa precision), F32Strict (full IEEE 754 binary32 via SIMT CUDA cores; bit-stable, no tensor-core warp-reduction nondeterminism), and f64 (DGEMM via Ampere FP64 tensor cores). See [PrecisionGuarantee::math_precision] and ScalarType for the per-element math precision and alpha/beta scalar mapping.
- Layouts: RCR (A row-major, B column-major, C/D row-major) and RRR (all three operands row-major, natural for activation@weight matmul without a transpose pass). Every GemmPlan element type ships both layouts; layout is a per-launch choice on GemmDescriptor.
- Epilogues: Identity, Bias, BiasRelu, BiasGelu, BiasSilu. The Bias* family computes D = activation(α·A·B + β·C + bias_broadcast(N)) in a single fused kernel pass via cutlass::gemm::device::GemmUniversalWithBroadcast + LinearCombinationBiasElementwise: the bias add and the activation both happen inside the epilogue, with no extra memory traffic over plain Bias. The bias vector has length N and must be contiguous (stride 1). GemmArgs::bias is required iff descriptor.epilogue.requires_bias() is true. GELU is the exact (erf-based) form, matching PyTorch's default nn.GELU().
- Architectures: sm_80 shipped today (runs on Ampere, Ada, and forward-compatibly on Hopper). sm_90a selection wiring is in place; the Hopper-specialized kernels themselves land when Hopper hardware is available for validation.
- Workspace: caller-supplied, either Workspace::None or Workspace::Borrowed(DeviceSliceMut<u8>). Plans never own device memory. Grouped GEMM additionally packs its per-group metadata into the front of the workspace via async H2D, with CUTLASS's internal scratch at the tail.
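As a host-side illustration of the Bias* formula and the two layouts listed above, the following sketch computes D = activation(α·A·B + β·C + bias) on the CPU. It is meant as a reference for checking results, not as the crate's API: Layout, Activation, and reference_gemm are hypothetical names, operands are assumed to be tightly packed, and GELU is left out because Rust's standard library has no erf.

```rust
/// Hypothetical stand-ins for illustration; not the crate's types.
#[derive(Clone, Copy)]
pub enum Layout {
    /// A row-major, B column-major, C/D row-major.
    Rcr,
    /// A, B, and C/D all row-major.
    Rrr,
}

#[derive(Clone, Copy)]
pub enum Activation {
    Identity,
    Relu,
    Silu,
}

fn activate(act: Activation, x: f32) -> f32 {
    match act {
        Activation::Identity => x,
        Activation::Relu => x.max(0.0),
        Activation::Silu => x / (1.0 + (-x).exp()),
    }
}

/// CPU reference: D[i][j] = activation(alpha * (A*B)[i][j] + beta * C[i][j] + bias[j]).
/// A is M x K, B is K x N, C and D are M x N, bias has length N.
/// Assumes tightly packed operands (no padding in the leading dimension).
pub fn reference_gemm(
    layout: Layout,
    act: Activation,
    m: usize,
    n: usize,
    k: usize,
    alpha: f32,
    a: &[f32],
    b: &[f32],
    beta: f32,
    c: &[f32],
    bias: &[f32],
) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    assert_eq!(c.len(), m * n);
    assert_eq!(bias.len(), n);
    let mut d = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                let b_pj = match layout {
                    // Column-major B: column j is contiguous.
                    Layout::Rcr => b[j * k + p],
                    // Row-major B: row p is contiguous.
                    Layout::Rrr => b[p * n + j],
                };
                acc += a[i * k + p] * b_pj;
            }
            d[i * n + j] = activate(act, alpha * acc + beta * c[i * n + j] + bias[j]);
        }
    }
    d
}
```

The only difference between the two layouts in this sketch is how element (p, j) of B is addressed; A, C, and D are row-major in both.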
Kernel SKU coverage
| API | Layout × Element |
|---|---|
| GemmPlan (Identity) | {Rcr, Rrr} × {F16, Bf16, F32 (TF32), F32Strict (SIMT), F64 (DGEMM)} |
| GemmPlan (Bias / BiasRelu / BiasGelu / BiasSilu) | {Rcr, Rrr} × {F16, Bf16, F32 (TF32), F32Strict (SIMT), F64 (DGEMM)} |
| IntGemmPlan (Identity) | Rcr × {S8, U8} (this crate) · Rrr × {S8, U8} via baracuda-kernels |
| IntGemmPlan (Bias / BiasRelu / BiasGelu / BiasSilu) | Rcr × {S8, U8} × {bias = f32, bias = i32} (this crate) · Rrr via baracuda-kernels |
| BatchedGemmPlan | Rcr × {F16, Bf16} |
| GroupedGemmPlan | Rcr × {F16, Bf16} |
Per-element scalar (alpha / beta) types:
| Element | T::Scalar | Notes |
|---|---|---|
| f16 | f32 | Tensor-core math, F32 accumulator |
| bf16 | f32 | Tensor-core math, F32 accumulator |
| f32 | f32 | TF32 tensor-core math (10-bit mantissa), F32 accumulator |
| F32Strict | f32 | SIMT full-precision math, F32 accumulator, bit-stable |
| f64 | f64 | DGEMM tensor-core math, F64 accumulator |
| S8 / U8 | f32 | Int8 tensor-core math, int32 accumulator, bit-stable. Float alpha/beta let the epilogue act as a dequantize. |
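To make the "10-bit mantissa" note for f32 concrete, here is a small host-side sketch that reduces an f32 to the TF32 field widths (1 sign bit, 8 exponent bits, 10 mantissa bits). The function name is hypothetical, and the sketch assumes plain truncation of the low 13 mantissa bits; the actual kernels may round operands instead, so treat it as an approximation of the precision loss rather than a bit-exact model.

```rust
/// Keep the sign, the 8 exponent bits, and the top 10 mantissa bits of an f32.
/// Assumption: truncation (low bits zeroed); real kernels may round instead.
pub fn truncate_to_tf32(x: f32) -> f32 {
    // f32 carries 23 explicit mantissa bits; TF32 keeps 10 of them.
    const DROPPED_BITS: u32 = 23 - 10;
    let mask = !((1u32 << DROPPED_BITS) - 1);
    f32::from_bits(x.to_bits() & mask)
}
```

Under this model an increment of 2^-10 relative to 1.0 survives, while 2^-11 is lost, which is why F32Strict exists for callers that need full binary32 behavior.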
Int family notes:
IntGemmPlan<T: IntElement, BT: BiasElement = f32> is a sibling type to
GemmPlan. The matrix element T picks the kernel family
(S8 / U8 today; s4 / u4 / 1-bit deferred to follow-ups). For
Bias* epilogues, BT picks the bias broadcast element type — f32
(default; matches the float-bias convention used elsewhere) or i32
(matches TensorRT's int8 inference convention). Both routes use
LinearCombinationBiasElementwise with ElementCompute = float, so
the fused activation runs in float space after int32→float dequant and
the final saturating cast back to int8 happens via the
cvt.rni.sat.{s8,u8}.f32 PTX instruction.
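The int epilogue path described above (int32 accumulator, float dequantize, float activation, saturating cast) can be modelled on the host as follows. This is a sketch with hypothetical function names, shown for the S8 output and the BiasRelu case only; it mirrors the round-to-nearest-even, saturating behavior that the rni and sat modifiers denote, with NaN mapping to 0.

```rust
/// Round to the nearest integer value, ties to even.
fn round_ties_even(x: f32) -> f32 {
    let floor = x.floor();
    let diff = x - floor;
    if diff < 0.5 {
        floor
    } else if diff > 0.5 {
        floor + 1.0
    } else if floor % 2.0 == 0.0 {
        floor
    } else {
        floor + 1.0
    }
}

/// Saturating float -> i8 conversion: round to nearest even, clamp to
/// [-128, 127]. Rust's `as` cast maps NaN to 0.
pub fn saturate_to_i8(x: f32) -> i8 {
    round_ties_even(x).clamp(-128.0, 127.0) as i8
}

/// One output element of a BiasRelu int8 epilogue (beta term omitted).
pub fn int8_bias_relu(acc: i32, alpha: f32, bias: f32) -> i8 {
    // int32 -> float dequantize via alpha, then the bias add.
    let x = alpha * acc as f32 + bias;
    // ReLU runs in float space; the cast back to int8 saturates.
    saturate_to_i8(x.max(0.0))
}
```

Because alpha is a float, choosing it as the product of the input scales turns the epilogue into the dequantize step mentioned in the scalar table.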
Rrr for the int family is not in this crate — CUTLASS 4.2.0
lacks the warp-level iterator specializations for the 8-bit
TensorOpMultiplicandCongruous shared-memory layout that
RowMajor × RowMajor × OpClassTensorOp would select for int8.
Selecting LayoutSku::Rrr on this crate's IntGemmPlan returns
Error::Unsupported at plan selection time. The bespoke RRR kernels
live in baracuda-kernels — selecting Rrr on
baracuda_kernels::IntGemmPlan dispatches to a hand-rolled
mma.sync.m16n8k32 kernel set covering all 18 SKUs
({S8, U8} × {Identity, Bias, BiasRelu, BiasGelu, BiasSilu} × {f32, i32}
bias). Callers building new code should import from baracuda-kernels
— it's a strict superset of this crate's int-GEMM surface.
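The count of 18 is not the naive product 2 × 5 × 2 = 20. A quick check, assuming (as the SKU table suggests) that Identity takes no bias operand and therefore has no bias-type dimension:

```rust
/// Counts the RRR int-GEMM SKUs under the assumption that only the
/// Bias* epilogues are specialized on the bias element type.
pub fn int_rrr_sku_count() -> usize {
    let elements = ["S8", "U8"];
    let bias_epilogues = ["Bias", "BiasRelu", "BiasGelu", "BiasSilu"];
    let bias_types = ["f32", "i32"];
    // Identity: one SKU per element type.
    let identity = elements.len();
    // Bias*: element type x epilogue x bias type.
    let fused = elements.len() * bias_epilogues.len() * bias_types.len();
    identity + fused
}
```

That gives 2 + 2 × 4 × 2 = 18.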
Remaining int / quantized dtypes (s4/u4/b1) are planned
follow-ups in baracuda-kernels and not yet shipped.
All SKUs above target sm_80 (Ampere); sm_90a is deferred until Hopper validation.
Why plan-based, not handle-based?
CUTLASS isn't cuBLAS. There is no persistent driver-side state that lives
across kernel launches. Every kernel is a self-contained instantiation of
a template. A Plan holds the selected kernel ID and its host-side
metadata — not a handle, not a workspace. This makes plans cheap to clone,
trivially Send + Sync, and capture-safe by construction (no host
allocations during run).
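A minimal sketch of what "a kernel ID plus host-side metadata" implies for the type. SketchPlan and its fields are hypothetical, not the crate's definition; the point is that a struct of plain values is Copy, Send, and Sync without any extra work, and the helper below makes the compiler check that.

```rust
/// Hypothetical plan shape: a selected kernel ID and host-side metadata only.
/// No device pointers, no handles, no owned workspace.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SketchPlan {
    pub kernel_id: u32,
    pub workspace_bytes: usize,
}

/// Compiles only when T is both Send and Sync.
pub fn assert_send_sync<T: Send + Sync>() {}

/// Returns true; the useful work is the compile-time bound check.
pub fn plan_is_thread_safe() -> bool {
    assert_send_sync::<SketchPlan>();
    true
}
```

A handle-based design would typically hold a raw pointer or driver object, which is exactly what prevents these auto traits from being derived.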
Quick start
use ;
use ;
use f16;
#
Grouped GEMM quick start (MoE-friendly)
use ;
use ;
use f16;
#
prepare validates per-group shapes and v0 invariants (all groups share
α/β; all groups consistently use c = None or c = Some(_)), packs host
arrays for problem_sizes, pointers, and leading dimensions, and queries
CUTLASS for the threadblock count + scratch size. run uploads the
metadata to the start of the workspace via async H2D and launches the
grouped kernel using the remainder as CUTLASS internal scratch.
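The host-side part of that flow can be sketched without a GPU. Everything below is an assumption made for illustration: the Group and Prepared types, the specific shape check, the per-group metadata size (one m/n/k triple, four pointers, four leading dimensions), and the alignment parameter are hypothetical and do not describe the crate's actual layout.

```rust
/// Hypothetical per-group description; field names are illustrative only.
#[derive(Clone, Copy)]
pub struct Group {
    pub m: u32,
    pub n: u32,
    pub k: u32,
    pub has_c: bool,
}

/// Host-side result: packed problem sizes plus the workspace split.
#[derive(Debug, PartialEq)]
pub struct Prepared {
    pub problem_sizes: Vec<[u32; 3]>,
    pub metadata_bytes: usize,
    pub scratch_offset: usize,
}

/// Assumed metadata per group: one (m, n, k) triple of 4-byte ints, four
/// 8-byte pointers (A, B, C, D), and four 8-byte leading dimensions.
const METADATA_BYTES_PER_GROUP: usize = 3 * 4 + 4 * 8 + 4 * 8;

pub fn prepare(groups: &[Group], scratch_align: usize) -> Result<Prepared, String> {
    let first = match groups.first() {
        Some(g) => g,
        None => return Err("at least one group is required".to_string()),
    };
    if scratch_align == 0 {
        return Err("scratch alignment must be nonzero".to_string());
    }
    let mut problem_sizes = Vec::with_capacity(groups.len());
    for (i, g) in groups.iter().enumerate() {
        // Assumed shape rule: M may vary per group, N and K must be nonzero.
        if g.n == 0 || g.k == 0 {
            return Err(format!("group {}: n and k must be nonzero", i));
        }
        // Invariant from the text: c is None for all groups or Some for all.
        if g.has_c != first.has_c {
            return Err(format!("group {}: inconsistent use of c", i));
        }
        problem_sizes.push([g.m, g.n, g.k]);
    }
    // Metadata sits at the front; scratch starts at the next aligned offset.
    let metadata_bytes = groups.len() * METADATA_BYTES_PER_GROUP;
    let scratch_offset = (metadata_bytes + scratch_align - 1) / scratch_align * scratch_align;
    Ok(Prepared { problem_sizes, metadata_bytes, scratch_offset })
}
```

Doing all validation and packing up front is what keeps the launch step free of host allocations.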
Integration notes
Calling from a byte-storage substrate
Frameworks that store all device tensors as DeviceBuffer<u8> (e.g.
Fuel's unified-binding-table dispatch path) can construct typed
MatrixRef / MatrixMut views without copying or transmuting:
use ;
use DeviceBuffer;
use bf16;
#
DeviceBuffer<u8>::view_as
asserts byte-count divisibility and reuses the buffer's existing
allocation — no copy, no unsafe at the consumer site. For non-baracuda
allocations, the lower-level
DeviceSlice::from_raw_parts
escape hatch is available.
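The divisibility rule behind a typed view can be shown on the host. This is a sketch of the check only, under the assumption that a view of T over N bytes is valid when N is a multiple of size_of::<T>(); typed_len is a hypothetical helper, and it returns None where the text says view_as asserts.

```rust
/// Number of T elements that exactly tile `byte_len` bytes, or None when the
/// byte count is not a multiple of the element size (or T is zero-sized).
pub fn typed_len<T>(byte_len: usize) -> Option<usize> {
    let elem = std::mem::size_of::<T>();
    if elem == 0 || byte_len % elem != 0 {
        return None;
    }
    Some(byte_len / elem)
}
```

For a 2-byte element such as bf16 (u16 stands in for it in a test), 8 bytes yield 4 elements and 7 bytes are rejected.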
Sharing a stream across launchers
A consumer that holds an Arc<Stream> (e.g. one stream per device,
shared across many kernel launches) can pass it to plan.run directly
via Arc::as_ref — the &Stream borrow shape is the same as for an
owned Stream:
# use Arc;
# use Stream;
#
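The borrow shape can be demonstrated with stand-in types. Stream and launch_on below are hypothetical placeholders for the real stream type and for a launcher such as plan.run; only std::sync::Arc is the real thing here.

```rust
use std::sync::Arc;

/// Stand-in for a CUDA stream; not the baracuda Stream type.
pub struct Stream {
    pub id: u32,
}

/// Stand-in for a launcher that only needs a shared borrow of the stream.
pub fn launch_on(stream: &Stream) -> u32 {
    stream.id
}

/// Passes a shared stream to the launcher without cloning the Arc.
pub fn launch_shared(shared: &Arc<Stream>) -> u32 {
    // Arc::as_ref yields &Stream and leaves the reference count unchanged.
    launch_on(Arc::as_ref(shared))
}
```

The launcher signature is identical for an owned Stream (`launch_on(&stream)`) and a shared one, so no separate Arc-taking overload is needed.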
Mapping kernels to precision guarantees
For consumers maintaining a per-decision-point alternatives table
(picking between cuBLAS and CUTLASS at a given precision contract),
[GemmPlan::precision_guarantee] (and the grouped equivalent) returns
a PrecisionGuarantee value covering math-instruction precision,
accumulator type, and bit-stability and determinism flags, so the
contract does not have to be re-derived from per-kernel docs.
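One way such an alternatives table might be queried is sketched below. Guarantee is a hypothetical mirror of the fields listed above, not the crate's PrecisionGuarantee definition, and the candidate values used with it are illustrative rather than measured.

```rust
/// Hypothetical mirror of the listed fields; not the crate's type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Guarantee {
    /// Explicit mantissa bits of the math instruction.
    pub mantissa_bits: u32,
    /// Width of the accumulator type in bits.
    pub accumulator_bits: u32,
    pub bit_stable: bool,
    pub deterministic: bool,
}

/// One row of a per-decision-point alternatives table.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Candidate {
    pub name: &'static str,
    pub guarantee: Guarantee,
}

/// Returns the first candidate, in table order, that meets the contract.
pub fn pick(
    table: &[Candidate],
    min_mantissa_bits: u32,
    need_bit_stable: bool,
) -> Option<&'static str> {
    table
        .iter()
        .find(|c| {
            c.guarantee.mantissa_bits >= min_mantissa_bits
                && (!need_bit_stable || c.guarantee.bit_stable)
        })
        .map(|c| c.name)
}
```

Ordering the table by preference (for example, fastest first) makes the first match the preferred kernel that still satisfies the contract.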
Acknowledgments
API specification by the Fuel ML library team. Underlying CUTLASS by NVIDIA.
See NOTICE for full attribution.