baracuda-kernels 0.0.1-alpha.68

Unified ML op facade for the baracuda CUDA ecosystem. Exposes every primitive an ML framework would expect (union of PyTorch torch.* + nn.functional and JAX lax.* / numpy ops) through a single Plan-based Rust surface, internally dispatching to baracuda-cutlass, the baracuda-* NVIDIA-library wrappers, or bespoke baracuda-kernels-sys kernels.

baracuda-kernels

Unified Rust ML-op facade for the baracuda CUDA ecosystem.

One crate, one API style. Internally dispatches to whichever backend (NVIDIA library wrapper or bespoke .cu kernel) is best for the selected op × dtype × layout × architecture. The dispatch choice is observable via Plan::sku() for telemetry but doesn't leak into the call site.

For the high-level project pitch, the layered design, and the full roadmap see the workspace README.md and ARCHITECTURE.md.

What this crate exposes

The full PyTorch (torch.* + nn.functional) and JAX (jax.lax.* + jax.numpy.*) op union as Plan-based Rust types. As of alpha.25 the following families are wired with forward + backward kernels across the expected dtype matrices:

  • GEMM: GemmPlan, BatchedGemmPlan, GroupedGemmPlan, IntGemmPlan, Fp8GemmPlan, Int4GemmPlan, BinGemmPlan.
  • Elementwise: UnaryPlan, BinaryPlan, TernaryPlan, BinaryCmpPlan, WherePlan, AffinePlan, CastPlan, GatedActivationPlan, UnaryParamPlan / BinaryParamPlan (PReLU, Lerp, Threshold, …), all with paired *BackwardPlan.
  • Shape / layout: ConcatPlan, PadPlan, RepeatPlan, RollPlan, FlipPlan, PermutePlan, FillPlan, all with BW.
  • Reductions: ReducePlan, ArgReducePlan, BoolReducePlan, CountReducePlan, TracePlan.
  • Scans: ScanPlan (cumsum, cumprod, cummax, cummin, logcumsumexp).
  • Softmax family: SoftmaxPlan, GumbelSoftmaxPlan, SparsemaxPlan, all with BW.
  • Normalization: RMSNormPlan, LayerNormPlan, BatchNormPlan, GroupNormPlan, InstanceNormPlan, all with BW.
  • Loss: MSE / L1 / Huber / SmoothL1 / NLL / CrossEntropy / BCE / BCEWithLogits / KLDiv / GaussianNLL / PoissonNLL / Cosine / HingeEmbedding / MarginRanking / MultiMargin / MultilabelMargin / MultilabelSoftMargin / TripletMargin / CTCLoss.
  • Random: RandomPlan, DropoutPlan.
  • Attention: SdpaPlan, FlashSdpaPlan, FlashSdpaSm89Plan (sm_89 sibling), RopePlan, AlibiPlan, KvCacheAppendPlan.
  • Linalg: CholeskyPlan, LuPlan, QrPlan, BatchedQrPlan, SvdPlan, BatchedSvdPlan, BatchedSvdaPlan, EigPlan, EighPlan, InversePlan, SolvePlan, LstSqPlan, BatchedOrmqrPlan / BatchedOrmqrWyPlan, BatchedQrMaterializePlan.
  • Convolution + Pooling (cuDNN-backed; cudnn feature): Conv2dPlan, MaxPool2dPlan, AvgPool2dPlan.
  • FFT: FftPlan, RfftPlan, IrfftPlan, FftNdPlan, RfftNdPlan, IrfftNdPlan, FftShiftPlan, FftShiftNdPlan.
  • Indexing: GatherPlan, ScatterAddPlan, IndexSelectPlan, MaskedFillPlan, OneHotPlan, NonzeroPlan.
  • Embedding: EmbeddingPlan, EmbeddingBagPlan.
  • Segment ops: sorted + unsorted variants of segment sum / mean / max / min / prod.
  • Quantization: per-tensor / per-channel / per-token / per-group quantize + dequantize, FakeQuantizePlan, DynamicRangeQuantizePlan, QuantizedLinearPlan, plus GGUF block-format dequant + MMVQ for Q4_0..Q8_K + k-quants.
  • MoE: fused per-token-dispatch + expert-matmul + accumulate (MoePlan with MoeVariant::Wmma, ScalarGguf, WmmaGguf).
  • Image: InterpolatePlan, GridSamplePlan, AffineGridPlan, PixelShufflePlan, RoiAlignPlan, RoiPoolPlan, NmsPlan.
  • Sort / topk: SortPlan, ArgsortPlan, TopkPlan, KthvaluePlan, MsortPlan, SearchsortedPlan, BincountPlan, HistogramPlan, UniquePlan, UniqueConsecutivePlan.

The shared vocabulary (KernelDtype, Element, TensorRef, KernelSku, PlanPreference, Workspace, every op-kind enum) is re-exported from baracuda-kernels-types so callers import one crate.

Element vs KernelDtype — which to bound on

KernelDtype is the umbrella marker for every kernel-usable dtype, including the sub-byte / FP8 / packed-bit newtypes (S4, U4, S8, U8, Fp8E4M3, Fp8E5M2, Bin) that have their own kernel families. The op-shaped sub-traits (Element, IntElement, FpElement, BinElement) all extend KernelDtype, so a function bounded by <T: KernelDtype> accepts any kernel-usable type.

In practice, plans bound on Element, IntElement, FpElement, or BinElement (whichever family the plan's kernel set fits), because these traits parameterize the plan shape. Reach for the umbrella KernelDtype bound only when the receiver needs to handle the union of every dtype (a generic dtype-size helper, a telemetry function, a wrapper crate downstream).

See baracuda-kernels-types for the full trait map and the per-trait dtype list.
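
As a rough illustration of the bounding advice above, the following sketch rebuilds a two-level trait hierarchy locally. Every item in it is a hypothetical stand-in: the real KernelDtype, FpElement, IntElement, and S8 live in baracuda-kernels-types and may carry different associated items. The point is only that a helper needing every dtype bounds on the umbrella trait, while kernel-family code bounds on the narrower sub-trait.

```rust
use std::mem::size_of;

// Hypothetical stand-ins for the traits in baracuda-kernels-types.
// The real traits may expose different associated items.

/// Umbrella marker: every kernel-usable dtype implements this.
trait KernelDtype: Copy + 'static {
    const NAME: &'static str;
}

/// Op-shaped sub-traits extend the umbrella via a supertrait bound.
trait FpElement: KernelDtype {}
trait IntElement: KernelDtype {}

/// Toy signed 8-bit newtype, standing in for the crate's S8.
#[allow(dead_code)]
#[derive(Clone, Copy)]
struct S8(i8);

impl KernelDtype for f32 {
    const NAME: &'static str = "f32";
}
impl KernelDtype for S8 {
    const NAME: &'static str = "s8";
}
impl FpElement for f32 {}
impl IntElement for S8 {}

/// Telemetry-style helper: it must accept the union of all dtypes,
/// so it bounds on the umbrella trait.
fn describe<T: KernelDtype>() -> String {
    format!("{} ({} bytes)", T::NAME, size_of::<T>())
}

/// Plan-style helper: only meaningful for floating-point kernels,
/// so it bounds on the narrower sub-trait. Calling it with S8 is a
/// compile error, which is the purpose of the narrower bound.
fn fp_only_name<T: FpElement>() -> &'static str {
    T::NAME
}

/// Integer-family counterpart of `fp_only_name`.
fn int_only_name<T: IntElement>() -> &'static str {
    T::NAME
}
```

With these stand-ins, describe::<S8>() and describe::<f32>() both compile, whereas fp_only_name::<S8>() is rejected at compile time.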

#[non_exhaustive] and forward-compat

Phase 28 marked the op-family discriminant enums and several auxiliary tag enums #[non_exhaustive] in preparation for the 1.0 freeze. Downstream match arms must include a _ => catch-all — adding new op variants in future phases then no longer breaks the build. ElementKind, LayoutSku, ArchSku, EpilogueKind, ActivationKind, and Workspace<'a> are intentionally left exhaustive because they're hot-path-matched by the kernel dispatchers; new variants there are a deliberate breaking-change event. See the baracuda-kernels-types README for the full classification.
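
A minimal sketch of the catch-all requirement, using a hypothetical ToyUnaryKind enum rather than one of the crate's real op-kind enums. Note that #[non_exhaustive] only forces the wildcard arm on matches written in other crates; inside the defining crate (as in this single-file sketch) the compiler would also accept an exhaustive match.

```rust
// Hypothetical stand-in for one of the crate's op-family enums.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ToyUnaryKind {
    Relu,
    Gelu,
    Silu,
}

/// Maps an op kind to a telemetry label.
pub fn label(kind: ToyUnaryKind) -> &'static str {
    match kind {
        ToyUnaryKind::Relu => "relu",
        ToyUnaryKind::Gelu => "gelu",
        // Downstream crates must keep this arm: it absorbs variants
        // the caller does not handle by name (Silu here) as well as
        // any variant added in a later release.
        _ => "other",
    }
}
```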

Quick start

use baracuda_driver::{Context, Device, DeviceBuffer, Stream};
use baracuda_kernels::{
    EpilogueKind, IntGemmArgs, IntGemmDescriptor, IntGemmPlan,
    LayoutSku, MatrixMut, MatrixRef, PlanPreference, S8, Workspace,
};

fn run() -> Result<(), Box<dyn std::error::Error>> {
    let ctx = Context::new(&Device::get(0)?)?;
    let stream = Stream::new(&ctx)?;

    let m = 128i32; let n = 128i32; let k = 128i32;
    let dev_a: DeviceBuffer<S8> = DeviceBuffer::zeros(&ctx, (m * k) as usize)?;
    let dev_b: DeviceBuffer<S8> = DeviceBuffer::zeros(&ctx, (k * n) as usize)?;
    let mut dev_d: DeviceBuffer<S8> = DeviceBuffer::zeros(&ctx, (m * n) as usize)?;

    // Rrr dispatches to bespoke int8 kernels in baracuda-kernels-sys.
    // Switching to Rcr would dispatch the same call through CUTLASS.
    let desc = IntGemmDescriptor {
        m, n, k,
        layout: LayoutSku::Rrr,
        epilogue: EpilogueKind::Identity,
    };
    let plan = IntGemmPlan::<S8>::select(&stream, &desc, PlanPreference::default())?;

    let args = IntGemmArgs::<S8, f32> {
        a: MatrixRef { data: dev_a.as_slice(), rows: m, cols: k, ld: k as i64 },
        b: MatrixRef { data: dev_b.as_slice(), rows: k, cols: n, ld: n as i64 },
        c: None,
        d: MatrixMut { data: dev_d.as_slice_mut(), rows: m, cols: n, ld: n as i64 },
        bias: None,
        alpha: 0.125,
        beta: 0.0,
    };
    plan.run(&stream, Workspace::None, args)?;
    stream.synchronize()?;
    Ok(())
}

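The same lifecycle (Descriptor → Plan::select → query_workspace_size → Args → Plan::run) applies to every op family in the crate. The quick start above passes Workspace::None and so omits the query_workspace_size step. See ARCHITECTURE.md for the design rationale behind the Descriptor / Plan / Args triple.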
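
To make the order of the lifecycle steps concrete, here is a CPU-only toy that mirrors it, including the workspace query. All Toy* names, the workspace-size rule, and the reading of Rrr as all-row-major and Rcr as column-major B are assumptions made for this sketch. None of it reflects the crate's actual signatures or kernels.

```rust
// CPU-only toy: every Toy* item is a hypothetical stand-in that mirrors
// the call order only, not the real baracuda-kernels signatures.

#[derive(Clone, Copy, Debug, PartialEq)]
enum ToyLayout { Rrr, Rcr }

#[derive(Clone, Copy, Debug, PartialEq)]
enum ToyBackend { Bespoke, Cutlass }

#[derive(Clone, Copy, Debug)]
struct ToyGemmDescriptor { m: usize, n: usize, k: usize, layout: ToyLayout }

struct ToyGemmPlan { desc: ToyGemmDescriptor, backend: ToyBackend }

struct ToyGemmArgs<'a> { a: &'a [f32], b: &'a [f32], d: &'a mut [f32], alpha: f32 }

impl ToyGemmPlan {
    /// Pick a backend from the descriptor alone.
    fn select(desc: ToyGemmDescriptor) -> Result<Self, String> {
        if desc.m == 0 || desc.n == 0 || desc.k == 0 {
            return Err("empty problem".into());
        }
        let backend = match desc.layout {
            ToyLayout::Rrr => ToyBackend::Bespoke,
            ToyLayout::Rcr => ToyBackend::Cutlass,
        };
        Ok(Self { desc, backend })
    }

    /// The dispatch choice is observable without changing the call site.
    fn sku(&self) -> ToyBackend { self.backend }

    /// Scratch bytes the caller must supply (an invented rule for the toy).
    fn query_workspace_size(&self) -> usize {
        match self.backend {
            ToyBackend::Bespoke => 0,
            ToyBackend::Cutlass => self.desc.m * self.desc.n * std::mem::size_of::<f32>(),
        }
    }

    /// Validate the buffers, then compute D = alpha * A * B.
    fn run(&self, workspace: &mut [u8], args: ToyGemmArgs<'_>) -> Result<(), String> {
        let ToyGemmDescriptor { m, n, k, layout } = self.desc;
        let ToyGemmArgs { a, b, d, alpha } = args;
        if workspace.len() < self.query_workspace_size() {
            return Err("workspace too small".into());
        }
        if a.len() != m * k || b.len() != k * n || d.len() != m * n {
            return Err("buffer size mismatch".into());
        }
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    // Toy assumption: Rrr stores B row-major, Rcr column-major.
                    let b_pj = match layout {
                        ToyLayout::Rrr => b[p * n + j],
                        ToyLayout::Rcr => b[j * k + p],
                    };
                    acc += a[i * k + p] * b_pj;
                }
                d[i * n + j] = alpha * acc;
            }
        }
        Ok(())
    }
}
```

A caller walks the steps in order: build the descriptor, call ToyGemmPlan::select, allocate a buffer of query_workspace_size() bytes, fill in ToyGemmArgs, then call run. Switching the layout changes what sku() reports but leaves the calling code untouched.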

Cargo features

| Feature | Default | Effect |
|---|---|---|
| sm80 | yes | Build the Ampere-baseline kernel set. Runs forward-compatibly on Ada and Hopper. |
| sm89 | no | Build the Ada Lovelace specializations: FP8 GEMM, FlashSdpaSm89Plan. |
| sm90a | no | Build the Hopper-specialized kernels (stubs today). |
| cudnn | no | Link cuDNN and enable Conv2dPlan, MaxPool2dPlan, AvgPool2dPlan, CtcLossCudnnPlan. |

cudnn is off by default because cuDNN is a separate NVIDIA download not bundled with the stock CUDA toolkit installer. See the workspace README.md for the auto-discovery paths the build script probes.

Phase 32 builder migration — descriptor structs

Phase 32 marked as #[non_exhaustive] the descriptor structs whose field sets have grown from phase to phase, and added pub fn new(...) constructors plus chainable with_* setters. Downstream callers that previously constructed descriptors via struct literals must migrate to the builder, because a struct literal for a #[non_exhaustive] struct no longer compiles outside this crate.

| Descriptor | ::new required args | Optional fields (defaults) |
|---|---|---|
| Conv1dDescriptor | batch, c_in, l_in, c_out, l_filt, element | pad_l=0, stride_l=1, dilation_l=1, groups=1 |
| Conv2dDescriptor | batch, c_in, h_in, w_in, c_out, h_filt, w_filt, element | pad=(0,0), stride=(1,1), dilation=(1,1), groups=1 |
| Conv3dDescriptor | batch, c_in, d_in, h_in, w_in, c_out, d_filt, h_filt, w_filt, element | pad=(0,0,0), stride=(1,1,1), dilation=(1,1,1), groups=1 |
| ConvTranspose1dDescriptor | batch, c_in, l_in, c_out, l_filt, element | pad=0, stride=1, dilation=1, output_pad=0, groups=1 |
| ConvTranspose2dDescriptor | batch, c_in, h_in, w_in, c_out, h_filt, w_filt, element | pad=(0,0), stride=(1,1), dilation=(1,1), output_pad=(0,0), groups=1 |
| ConvTranspose3dDescriptor | batch, c_in, d_in, h_in, w_in, c_out, d_filt, h_filt, w_filt, element | pad=(0,0,0), stride=(1,1,1), dilation=(1,1,1), output_pad=(0,0,0), groups=1 |
| Pool1dDescriptor | batch, channels, l_in, window, mode, element | pad=0, stride=window (PyTorch default) |
| Pool2dDescriptor | batch, channels, h_in, w_in, window_h, window_w, mode, element | pad=(0,0), stride=(window_h, window_w) |
| Pool3dDescriptor | batch, channels, d_in, h_in, w_in, window_d, window_h, window_w, mode, element | pad=(0,0,0), stride=(window_d, window_h, window_w) |
| AdaptivePool1dDescriptor | batch, channels, l_in, l_out, element | (no optional fields) |
| AdaptivePool2dDescriptor | batch, channels, h_in, w_in, h_out, w_out, element | (no optional fields) |
| AdaptivePool3dDescriptor | batch, channels, d_in, h_in, w_in, d_out, h_out, w_out, element | (no optional fields) |
| LpPool1dDescriptor | batch, channels, l_in, window, p, element | stride=window, ceil_mode=false |
| LpPool2dDescriptor | batch, channels, h_in, w_in, window_h, window_w, p, element | stride=(window_h, window_w), ceil_mode=false |
| FractionalMaxPool2dDescriptor | batch, channels, h_in, w_in, window_h, window_w, h_out, w_out, element | seed=0 (currently unused; caller supplies samples) |
| FractionalMaxPool3dDescriptor | batch, channels, d_in, h_in, w_in, window_d, window_h, window_w, d_out, h_out, w_out, element | seed=0 |
| InterpolateDescriptor | n, c, ih, iw, oh, ow, mode, element | align_corners=false, scale_h=None, scale_w=None |
| InterpolateBackwardDescriptor | n, c, ih, iw, oh, ow, mode, element | align_corners=false, scale_h=None, scale_w=None |

Setter naming convention:

  • with_padding(...), with_stride(...), with_dilation(...), with_output_padding(...), with_groups(...) — Conv / ConvTranspose family. Each takes the per-spatial-axis values as positional args (e.g. Conv2d::with_padding(pad_h, pad_w)).
  • with_stride(...), with_padding(...) — Pool family (Pool, LpPool). Pool's stride defaults to the per-axis window extent (matching PyTorch's pooling-default convention), so most call sites can leave it implicit and just set with_padding(...) if needed.
  • with_ceil_mode(bool) — LpPool.
  • with_seed(u64) — FractionalMaxPool.
  • with_align_corners(bool) / with_scale_h(Option<f64>) / with_scale_w(Option<f64>) — Interpolate / InterpolateBackward.

Example — Phase 11 conv with explicit groups + padding:

use baracuda_kernels::{Conv2dDescriptor, ElementKind};

let desc = Conv2dDescriptor::new(
    /* batch */ 16, /* c_in */ 64, /* h_in */ 28, /* w_in */ 28,
    /* c_out */ 128, /* h_filt */ 3, /* w_filt */ 3,
    ElementKind::F32,
)
.with_padding(1, 1)
.with_stride(1, 1)
.with_groups(2);  // grouped conv

The adaptive-pool family (AdaptivePool{1,2,3}dDescriptor) has no optional fields and a single positional constructor — adaptive pooling takes only its in / out extents.
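
The builder shape described in this section can be sketched with a local stand-in. ToyPool2dDescriptor below is hypothetical: its field names loosely follow the Pool2dDescriptor row of the table, but it drops batch, channels, mode, and element for brevity and is not the crate's real struct. It shows how a #[non_exhaustive] struct pairs a positional new with chainable with_* setters, and how the stride default can be derived from the required window arguments.

```rust
// Hypothetical stand-in, not the crate's Pool2dDescriptor.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ToyPool2dDescriptor {
    pub h_in: usize,
    pub w_in: usize,
    pub window: (usize, usize),
    pub pad: (usize, usize),
    pub stride: (usize, usize),
}

impl ToyPool2dDescriptor {
    /// Required arguments only. Optional fields get their defaults here:
    /// no padding, and a stride equal to the per-axis window extent.
    pub fn new(h_in: usize, w_in: usize, window_h: usize, window_w: usize) -> Self {
        Self {
            h_in,
            w_in,
            window: (window_h, window_w),
            pad: (0, 0),
            stride: (window_h, window_w),
        }
    }

    /// Chainable setter: takes self by value and returns it.
    pub fn with_padding(mut self, pad_h: usize, pad_w: usize) -> Self {
        self.pad = (pad_h, pad_w);
        self
    }

    /// Chainable setter overriding the window-derived stride default.
    pub fn with_stride(mut self, stride_h: usize, stride_w: usize) -> Self {
        self.stride = (stride_h, stride_w);
        self
    }
}
```

Because the struct is #[non_exhaustive], a downstream crate cannot write ToyPool2dDescriptor { .. } as a literal, so new fields with defaults can be added later without breaking callers that go through new and the setters.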

Verifying the API surface compiles

cargo check -p baracuda-kernels --features sm89,cudnn

The GPU integration tests are gated behind #[ignore]; run them with cargo test -p baracuda-kernels --release -- --ignored on a host with a working NVIDIA driver. The full regression covers ~1630 tests on an RTX 4070.

See also