baracuda-kernels
Unified Rust ML-op facade for the baracuda CUDA ecosystem.
One crate, one API style. Internally dispatches to whichever backend
(NVIDIA library wrapper or bespoke .cu kernel) is best for the
selected op × dtype × layout × architecture. The dispatch choice is
observable via Plan::sku() for telemetry but doesn't leak into the
call site.
For the high-level project pitch, the layered design, and the full
roadmap, see the workspace README.md and ARCHITECTURE.md.
What this crate exposes
The union of the PyTorch (torch.* + nn.functional) and JAX (jax.lax.* +
jax.numpy.*) op sets, exposed as Plan-based Rust types. As of alpha.25, the
following families are wired with forward + backward kernels across the
expected dtype matrices:
- GEMM — GemmPlan, BatchedGemmPlan, GroupedGemmPlan, IntGemmPlan, Fp8GemmPlan, Int4GemmPlan, BinGemmPlan.
- Elementwise — UnaryPlan, BinaryPlan, TernaryPlan, BinaryCmpPlan, WherePlan, AffinePlan, CastPlan, GatedActivationPlan, UnaryParamPlan / BinaryParamPlan (PReLU, Lerp, Threshold, …), all with paired *BackwardPlan.
- Shape / layout — ConcatPlan, PadPlan, RepeatPlan, RollPlan, FlipPlan, PermutePlan, FillPlan, all with backward.
- Reductions — ReducePlan, ArgReducePlan, BoolReducePlan, CountReducePlan, TracePlan.
- Scans — ScanPlan (cumsum, cumprod, cummax, cummin, logcumsumexp).
- Softmax family — SoftmaxPlan, GumbelSoftmaxPlan, SparsemaxPlan, all with backward.
- Normalization — RMSNormPlan, LayerNormPlan, BatchNormPlan, GroupNormPlan, InstanceNormPlan, all with backward.
- Loss — MSE / L1 / Huber / SmoothL1 / NLL / CrossEntropy / BCE / BCEWithLogits / KLDiv / GaussianNLL / PoissonNLL / Cosine / HingeEmbedding / MarginRanking / MultiMargin / MultilabelMargin / MultilabelSoftMargin / TripletMargin / CTCLoss.
- Random — RandomPlan, DropoutPlan.
- Attention — SdpaPlan, FlashSdpaPlan, FlashSdpaSm89Plan (sm_89 sibling), RopePlan, AlibiPlan, KvCacheAppendPlan.
- Linalg — CholeskyPlan, LuPlan, QrPlan, BatchedQrPlan, SvdPlan, BatchedSvdPlan, BatchedSvdaPlan, EigPlan, EighPlan, InversePlan, SolvePlan, LstSqPlan, BatchedOrmqrPlan / BatchedOrmqrWyPlan, BatchedQrMaterializePlan.
- Convolution + Pooling (cuDNN-backed; cudnn feature) — Conv2dPlan, MaxPool2dPlan, AvgPool2dPlan.
- FFT — FftPlan, RfftPlan, IrfftPlan, FftNdPlan, RfftNdPlan, IrfftNdPlan, FftShiftPlan, FftShiftNdPlan.
- Indexing — GatherPlan, ScatterAddPlan, IndexSelectPlan, MaskedFillPlan, OneHotPlan, NonzeroPlan.
- Embedding — EmbeddingPlan, EmbeddingBagPlan.
- Segment ops — sorted + unsorted variants of segment sum / mean / max / min / prod.
- Quantization — per-tensor / per-channel / per-token / per-group quantize + dequantize, FakeQuantizePlan, DynamicRangeQuantizePlan, QuantizedLinearPlan, plus GGUF block-format dequant + MMVQ for Q4_0..Q8_K + k-quants.
- MoE — fused per-token-dispatch + expert-matmul + accumulate (MoePlan with MoeVariant::Wmma, ScalarGguf, WmmaGguf).
- Image — InterpolatePlan, GridSamplePlan, AffineGridPlan, PixelShufflePlan, RoiAlignPlan, RoiPoolPlan, NmsPlan.
- Sort / topk — SortPlan, ArgsortPlan, TopkPlan, KthvaluePlan, MsortPlan, SearchsortedPlan, BincountPlan, HistogramPlan, UniquePlan, UniqueConsecutivePlan.
The shared vocabulary (KernelDtype, Element, TensorRef,
KernelSku, PlanPreference, Workspace, every op-kind enum) is
re-exported from baracuda-kernels-types so callers import one crate.
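The following minimal sketch shows what the re-export buys a caller. Two inline modules stand in for the two crates; this is an assumption made for self-containment, since in the workspace they are separate crates. The KernelSku variant names and the describe helper are hypothetical. A re-exported item is the same item, not a copy. A value obtained through the facade path can therefore be passed to code written against the types-crate path with no conversion.

```rust
// Hypothetical stand-ins: `types` plays baracuda-kernels-types and `facade`
// plays baracuda-kernels. In the real workspace these are separate crates.
mod types {
    /// Hypothetical SKU tag; the real KernelSku variants are not listed here.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum KernelSku {
        LibraryWrapper,
        BespokeKernel,
    }

    /// A helper written against the types-crate path.
    pub fn describe(sku: KernelSku) -> &'static str {
        match sku {
            KernelSku::LibraryWrapper => "library wrapper",
            KernelSku::BespokeKernel => "bespoke .cu kernel",
        }
    }
}

mod facade {
    // The facade re-exports the shared vocabulary unchanged.
    pub use crate::types::{describe, KernelSku};
}

fn main() {
    // Callers name only the facade path...
    let sku: facade::KernelSku = facade::KernelSku::BespokeKernel;
    // ...and the value already is a types::KernelSku: no conversion needed.
    let same: types::KernelSku = sku;
    println!("{}", types::describe(same));
    println!("{}", facade::describe(facade::KernelSku::LibraryWrapper));
}
```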
Element vs KernelDtype — which to bound on
KernelDtype is the umbrella marker for every kernel-usable
dtype, including the sub-byte / FP8 / packed-bit newtypes (S4,
U4, S8, U8, Fp8E4M3, Fp8E5M2, Bin) that have their own
kernel families. The op-shaped sub-traits (Element, IntElement,
FpElement, BinElement) all extend KernelDtype, so a function
bounded by <T: KernelDtype> accepts any kernel-usable type.
In practice, plans bound on Element, IntElement, FpElement, or
BinElement, whichever family the plan's kernel set fits, because
that bound parameterizes the plan shape. Reach for the umbrella
KernelDtype bound only when the receiver needs to handle the
union of every dtype: a generic dtype-size helper, a telemetry
function, or a downstream wrapper crate.
See baracuda-kernels-types for the full trait map and the per-trait
dtype list.
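The bound-selection advice can be sketched with a hypothetical mirror of the trait layout. The trait names follow this section. The BITS and NAME constants, the S4 and Bin newtype layouts, and FpPlan are illustrative assumptions, not the contents of baracuda-kernels-types. The storage-size helper takes the umbrella bound because it must accept every dtype. The plan takes a family bound, which rejects at compile time any dtype its kernels cannot handle.

```rust
/// Umbrella marker for every kernel-usable dtype (hypothetical mirror).
pub trait KernelDtype: Copy + 'static {
    /// Bits of storage per logical element; sub-byte types report < 8.
    const BITS: usize;
    const NAME: &'static str;
}

// Op-shaped sub-traits, each extending the umbrella marker.
pub trait Element: KernelDtype {}
pub trait IntElement: KernelDtype {}
pub trait FpElement: KernelDtype {}
pub trait BinElement: KernelDtype {}

/// Hypothetical packed signed 4-bit newtype (two values per byte).
#[derive(Clone, Copy)]
pub struct S4(pub u8);
/// Hypothetical packed 1-bit newtype (eight values per byte).
#[derive(Clone, Copy)]
pub struct Bin(pub u8);

impl KernelDtype for f32 { const BITS: usize = 32; const NAME: &'static str = "f32"; }
impl Element for f32 {}
impl FpElement for f32 {}

impl KernelDtype for S4 { const BITS: usize = 4; const NAME: &'static str = "s4"; }
impl IntElement for S4 {}

impl KernelDtype for Bin { const BITS: usize = 1; const NAME: &'static str = "bin"; }
impl BinElement for Bin {}

/// Umbrella bound: a dtype-size helper must accept every kernel dtype.
pub fn storage_bytes<T: KernelDtype>(len: usize) -> usize {
    (len * T::BITS + 7) / 8 // round up to whole bytes
}

/// Family bound: a hypothetical floating-point plan accepts only FpElement.
pub struct FpPlan<T: FpElement> {
    _dtype: std::marker::PhantomData<T>,
}

impl<T: FpElement> FpPlan<T> {
    pub fn new() -> Self {
        FpPlan { _dtype: std::marker::PhantomData }
    }
    pub fn dtype_name(&self) -> &'static str {
        T::NAME
    }
}

fn main() {
    println!("{}", storage_bytes::<f32>(3));
    println!("{}", storage_bytes::<S4>(3));
    println!("{}", storage_bytes::<Bin>(9));
    println!("{}", FpPlan::<f32>::new().dtype_name());
    // FpPlan::<S4>::new(); // would not compile: S4 is not an FpElement
}
```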
#[non_exhaustive] and forward-compat
Phase 28 marked the op-family discriminant enums and several
auxiliary tag enums #[non_exhaustive] in preparation for the 1.0
freeze. Downstream match expressions on these enums must include a
_ => catch-all arm, so that adding new op variants in future phases
does not break the build. ElementKind, LayoutSku, ArchSku, EpilogueKind,
ActivationKind, and Workspace<'a> are intentionally left
exhaustive because the kernel dispatchers match on them in the hot
path; adding a variant to any of them is a deliberate breaking
change. See the baracuda-kernels-types README for the full
classification.
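Here is a minimal sketch of the catch-all rule. A hypothetical UnaryKind enum stands in for the crate's real op-kind enums, and its variant names are illustrative only. One caveat: #[non_exhaustive] is enforced only outside the defining crate. In this single-file example the compiler would accept the match even without the _ arm, but a downstream crate matching on the real enums must include it.

```rust
/// Hypothetical stand-in for an op-kind enum marked #[non_exhaustive].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryKind {
    Relu,
    Gelu,
    Silu,
    Tanh,
}

/// Downstream-style match: explicit arms for the cases this caller handles,
/// plus a `_` arm that also absorbs variants added in later phases.
pub fn describe(kind: UnaryKind) -> &'static str {
    match kind {
        UnaryKind::Relu => "relu",
        UnaryKind::Gelu => "gelu",
        _ => "other unary op",
    }
}

fn main() {
    for kind in [UnaryKind::Relu, UnaryKind::Gelu, UnaryKind::Silu, UnaryKind::Tanh] {
        println!("{:?}: {}", kind, describe(kind));
    }
}
```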
Quick start
use ;
use ;
The same lifecycle (Descriptor → Plan::select → query_workspace_size →
Args → Plan::run) applies to every op family in the crate. See
ARCHITECTURE.md for the design rationale behind the
Descriptor / Plan / Args triple.
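The lifecycle can be sketched end to end with a CPU-only mock. Every name here (ScaleDescriptor, ScalePlan, ScaleArgs, Sku) is hypothetical. The length-divisible-by-4 heuristic in select is a toy standing in for the crate's real op × dtype × layout × architecture dispatch. The sketch shows only the division of labor between the stages:

- the descriptor carries the shape;
- select validates it and fixes the implementation;
- the caller allocates the queried workspace;
- run consumes the per-launch buffers.

The chosen SKU is visible through sku() for logging but never appears at the call site.

```rust
// CPU-only mock of Descriptor -> Plan::select -> query_workspace_size ->
// Args -> Plan::run. None of these types belong to baracuda-kernels.

/// Which implementation the plan picked (the crate's analogue is Plan::sku()).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Sku {
    Scalar,
    Vectorized,
}

/// Descriptor: shape and parameters only, no buffers attached.
pub struct ScaleDescriptor {
    pub len: usize,
    pub alpha: f32,
}

/// Plan: the validated descriptor plus the implementation chosen for it.
pub struct ScalePlan {
    desc: ScaleDescriptor,
    sku: Sku,
}

/// Args: the buffers for one launch, including caller-provided workspace.
pub struct ScaleArgs<'a> {
    pub x: &'a [f32],
    pub y: &'a mut [f32],
    pub workspace: &'a mut [u8],
}

impl ScalePlan {
    /// Validate the descriptor and pick an implementation.
    pub fn select(desc: ScaleDescriptor) -> Result<Self, String> {
        if desc.len == 0 {
            return Err("empty input".to_string());
        }
        // Toy heuristic standing in for the real dispatch decision.
        let sku = if desc.len % 4 == 0 { Sku::Vectorized } else { Sku::Scalar };
        Ok(ScalePlan { desc, sku })
    }

    pub fn sku(&self) -> Sku {
        self.sku
    }

    /// This toy op needs no scratch space; real plans may report a nonzero size.
    pub fn query_workspace_size(&self) -> usize {
        0
    }

    pub fn run(&self, args: ScaleArgs<'_>) -> Result<(), String> {
        if args.x.len() != self.desc.len || args.y.len() != self.desc.len {
            return Err("shape mismatch".to_string());
        }
        if args.workspace.len() < self.query_workspace_size() {
            return Err("workspace too small".to_string());
        }
        // Both SKUs share one loop in this mock: y = alpha * x.
        for (yi, xi) in args.y.iter_mut().zip(args.x.iter()) {
            *yi = self.desc.alpha * *xi;
        }
        Ok(())
    }
}

fn main() -> Result<(), String> {
    let x = [1.0f32, 2.0, 3.0, 4.0];
    let mut y = [0.0f32; 4];
    let plan = ScalePlan::select(ScaleDescriptor { len: 4, alpha: 2.0 })?;
    let mut ws = vec![0u8; plan.query_workspace_size()];
    plan.run(ScaleArgs { x: &x, y: &mut y, workspace: ws.as_mut_slice() })?;
    // Telemetry-style log of the SKU; the call above never named it.
    println!("{:?} via {:?}", y, plan.sku());
    Ok(())
}
```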
Cargo features
| Feature | Default | Effect |
|---|---|---|
| sm80 | yes | Build the Ampere-baseline kernel set. Runs forward-compatibly on Ada and Hopper. |
| sm89 | no | Build the Ada Lovelace specializations: FP8 GEMM, FlashSdpaSm89Plan. |
| sm90a | no | Build the Hopper-specialized kernels (stubs today). |
| cudnn | no | Link cuDNN and enable Conv2dPlan, MaxPool2dPlan, AvgPool2dPlan, CtcLossCudnnPlan. |
cudnn is off by default because cuDNN is a separate NVIDIA download
not bundled with the stock CUDA toolkit installer. See the workspace
README.md for the auto-discovery paths the build
script probes.
Phase 32 builder migration — descriptor structs
Phase 32 marked #[non_exhaustive] the descriptor structs whose field
sets have grown phase by phase, and added pub fn new(...) constructors
plus chainable with_* setters. Downstream callers that previously
constructed descriptors via a struct literal must migrate to the
builder: the struct literal no longer compiles outside this crate.
| Descriptor | ::new required args | Optional fields (defaults) |
|---|---|---|
| Conv1dDescriptor | batch, c_in, l_in, c_out, l_filt, element | pad_l=0, stride_l=1, dilation_l=1, groups=1 |
| Conv2dDescriptor | batch, c_in, h_in, w_in, c_out, h_filt, w_filt, element | pad=(0,0), stride=(1,1), dilation=(1,1), groups=1 |
| Conv3dDescriptor | batch, c_in, d_in, h_in, w_in, c_out, d_filt, h_filt, w_filt, element | pad=(0,0,0), stride=(1,1,1), dilation=(1,1,1), groups=1 |
| ConvTranspose1dDescriptor | batch, c_in, l_in, c_out, l_filt, element | pad=0, stride=1, dilation=1, output_pad=0, groups=1 |
| ConvTranspose2dDescriptor | batch, c_in, h_in, w_in, c_out, h_filt, w_filt, element | pad=(0,0), stride=(1,1), dilation=(1,1), output_pad=(0,0), groups=1 |
| ConvTranspose3dDescriptor | batch, c_in, d_in, h_in, w_in, c_out, d_filt, h_filt, w_filt, element | pad=(0,0,0), stride=(1,1,1), dilation=(1,1,1), output_pad=(0,0,0), groups=1 |
| Pool1dDescriptor | batch, channels, l_in, window, mode, element | pad=0, stride=window (PyTorch default) |
| Pool2dDescriptor | batch, channels, h_in, w_in, window_h, window_w, mode, element | pad=(0,0), stride=(window_h, window_w) |
| Pool3dDescriptor | batch, channels, d_in, h_in, w_in, window_d, window_h, window_w, mode, element | pad=(0,0,0), stride=(window_d, window_h, window_w) |
| AdaptivePool1dDescriptor | batch, channels, l_in, l_out, element | (no optional fields) |
| AdaptivePool2dDescriptor | batch, channels, h_in, w_in, h_out, w_out, element | (no optional fields) |
| AdaptivePool3dDescriptor | batch, channels, d_in, h_in, w_in, d_out, h_out, w_out, element | (no optional fields) |
| LpPool1dDescriptor | batch, channels, l_in, window, p, element | stride=window, ceil_mode=false |
| LpPool2dDescriptor | batch, channels, h_in, w_in, window_h, window_w, p, element | stride=(window_h, window_w), ceil_mode=false |
| FractionalMaxPool2dDescriptor | batch, channels, h_in, w_in, window_h, window_w, h_out, w_out, element | seed=0 (currently unused; the caller supplies samples) |
| FractionalMaxPool3dDescriptor | batch, channels, d_in, h_in, w_in, window_d, window_h, window_w, d_out, h_out, w_out, element | seed=0 |
| InterpolateDescriptor | n, c, ih, iw, oh, ow, mode, element | align_corners=false, scale_h=None, scale_w=None |
| InterpolateBackwardDescriptor | n, c, ih, iw, oh, ow, mode, element | align_corners=false, scale_h=None, scale_w=None |
Setter naming convention:
- with_padding(...), with_stride(...), with_dilation(...), with_output_padding(...), with_groups(...) — Conv / ConvTranspose family. Each takes the per-spatial-axis values as positional args (e.g. Conv2dDescriptor::with_padding(pad_h, pad_w)).
- with_stride(...), with_padding(...) — Pool family (Pool, LpPool). Pool's stride defaults to the per-axis window extent (matching PyTorch's pooling default), so most call sites can leave it implicit and set with_padding(...) only if needed.
- with_ceil_mode(bool) — LpPool.
- with_seed(u64) — FractionalMaxPool.
- with_align_corners(bool) / with_scale_h(Option<f64>) / with_scale_w(Option<f64>) — Interpolate / InterpolateBackward.
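The pooling stride default can be sketched with a trimmed, hypothetical Pool2dDescriptor that keeps only the spatial fields. The real descriptor also takes batch, channels, mode and element. The output-extent method assumes PyTorch's floor-mode formula with dilation 1:

out = (in + 2 * pad - window) / stride + 1

This README says only that the stride default follows PyTorch, so treat the formula as an assumption about the real kernels.

```rust
/// Hypothetical, trimmed mirror of Pool2dDescriptor (spatial fields only).
#[derive(Debug, Clone, PartialEq)]
pub struct Pool2dDescriptor {
    pub h_in: usize,
    pub w_in: usize,
    pub window: (usize, usize),
    pub pad: (usize, usize),
    pub stride: (usize, usize),
}

impl Pool2dDescriptor {
    pub fn new(h_in: usize, w_in: usize, window_h: usize, window_w: usize) -> Self {
        Pool2dDescriptor {
            h_in,
            w_in,
            window: (window_h, window_w),
            pad: (0, 0),
            // Stride defaults to the window extent, as in PyTorch pooling.
            stride: (window_h, window_w),
        }
    }

    pub fn with_stride(mut self, stride_h: usize, stride_w: usize) -> Self {
        self.stride = (stride_h, stride_w);
        self
    }

    pub fn with_padding(mut self, pad_h: usize, pad_w: usize) -> Self {
        self.pad = (pad_h, pad_w);
        self
    }

    /// Floor-mode output extent with dilation 1 (assumed PyTorch formula).
    /// Assumes in + 2 * pad >= window on both axes.
    pub fn out_extent(&self) -> (usize, usize) {
        let axis = |i: usize, k: usize, p: usize, s: usize| (i + 2 * p - k) / s + 1;
        (
            axis(self.h_in, self.window.0, self.pad.0, self.stride.0),
            axis(self.w_in, self.window.1, self.pad.1, self.stride.1),
        )
    }
}

fn main() {
    // 2x2 window, stride left implicit: each spatial axis is halved.
    println!("{:?}", Pool2dDescriptor::new(32, 32, 2, 2).out_extent());
    // Overlapping 3x3 window with an explicit stride of 1.
    println!("{:?}", Pool2dDescriptor::new(32, 32, 3, 3).with_stride(1, 1).out_extent());
}
```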
Example — Phase 11 conv with explicit groups + padding:
use ;
let desc = new
.with_padding
.with_stride
.with_groups; // grouped conv
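For reference, here is a self-contained sketch of the same builder pattern for a 2-D convolution. It uses a hypothetical mirror of Conv2dDescriptor whose required arguments and defaults come from the table above. The Dtype enum stands in for the element argument, and the concrete sizes are illustrative. Inside the defining crate a struct literal would still compile. #[non_exhaustive] blocks it only downstream, which is what forces callers through new and the with_* setters.

```rust
/// Hypothetical stand-in for the descriptor's `element` argument.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dtype {
    F16,
    F32,
}

/// Hypothetical mirror of Conv2dDescriptor with the table's defaults.
/// #[non_exhaustive] blocks struct-literal construction in other crates.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub struct Conv2dDescriptor {
    pub batch: usize,
    pub c_in: usize,
    pub h_in: usize,
    pub w_in: usize,
    pub c_out: usize,
    pub h_filt: usize,
    pub w_filt: usize,
    pub element: Dtype,
    pub pad: (usize, usize),
    pub stride: (usize, usize),
    pub dilation: (usize, usize),
    pub groups: usize,
}

impl Conv2dDescriptor {
    /// Required arguments only; optional fields start at their defaults.
    #[allow(clippy::too_many_arguments)]
    pub fn new(batch: usize, c_in: usize, h_in: usize, w_in: usize,
               c_out: usize, h_filt: usize, w_filt: usize, element: Dtype) -> Self {
        Conv2dDescriptor {
            batch, c_in, h_in, w_in, c_out, h_filt, w_filt, element,
            pad: (0, 0),
            stride: (1, 1),
            dilation: (1, 1),
            groups: 1,
        }
    }

    pub fn with_padding(mut self, pad_h: usize, pad_w: usize) -> Self {
        self.pad = (pad_h, pad_w);
        self
    }

    pub fn with_stride(mut self, stride_h: usize, stride_w: usize) -> Self {
        self.stride = (stride_h, stride_w);
        self
    }

    pub fn with_dilation(mut self, dil_h: usize, dil_w: usize) -> Self {
        self.dilation = (dil_h, dil_w);
        self
    }

    pub fn with_groups(mut self, groups: usize) -> Self {
        self.groups = groups;
        self
    }
}

fn main() {
    // Depthwise-style grouped conv: 32 input channels in 32 groups,
    // 3x3 filter with padding 1 so the 56x56 spatial size is preserved.
    let desc = Conv2dDescriptor::new(8, 32, 56, 56, 32, 3, 3, Dtype::F16)
        .with_padding(1, 1)
        .with_stride(1, 1)
        .with_groups(32);
    println!("{:?}", desc);
}
```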
The adaptive-pool family (AdaptivePool{1,2,3}dDescriptor) has no
optional fields and a single positional constructor: beyond batch,
channels, and element, adaptive pooling takes only its input and
output extents.
Verifying the API surface compiles
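The GPU integration tests are gated behind #[ignore]; run them with
cargo test -p baracuda-kernels --release -- --ignored on a host with
a working NVIDIA driver. Without --ignored, cargo test still compiles
every test but skips the ignored ones, which checks that the API surface
compiles. The full regression covers ~1630 tests on an RTX 4070.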
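The #[ignore] gating pattern looks roughly like this in a test module. The relu_ref helper and both test names are hypothetical. A plain cargo test compiles both tests but runs only the first, while cargo test -- --ignored runs only the ignored one.

```rust
/// Hypothetical CPU reference that a GPU test would compare against.
pub fn relu_ref(x: &[f32]) -> Vec<f32> {
    x.iter().map(|&v| v.max(0.0)).collect()
}

#[cfg(test)]
mod tests {
    use super::relu_ref;

    // Runs under a plain `cargo test`.
    #[test]
    fn cpu_reference_clamps_negatives() {
        assert_eq!(relu_ref(&[-1.0, 0.5]), vec![0.0, 0.5]);
    }

    // Compiled by every `cargo test`, executed only with `-- --ignored`.
    #[test]
    #[ignore = "requires an NVIDIA GPU and driver"]
    fn gpu_relu_matches_cpu_reference() {
        // Placeholder: a real test would run the GPU plan and compare with relu_ref.
        assert!(relu_ref(&[]).is_empty());
    }
}

fn main() {
    println!("{:?}", relu_ref(&[1.0, -2.0, 3.0]));
}
```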
See also
- baracuda-kernels-types — the shared type vocabulary.
- baracuda-kernels-sys — raw FFI to the bespoke .cu kernels behind this facade.
- baracuda-cutlass — CUTLASS plan types re-exported here unchanged.
- baracuda-kernels-bench — criterion bench harness for sm_89 perf sweeps.