rlx-cuda
NVIDIA CUDA backend for RLX. cuBLAS / cuBLASLt handle matmul, cuDNN
handles convolution when libcudnn is present, and NVRTC-compiled kernels
cover everything else, all via the pure-Rust cudarc crate: no nvcc at
workspace build time and no CUDA SDK install on developer machines. CUDA
C++ kernel sources live as &'static str and are JIT-compiled to PTX via
NVRTC on first dispatch (the same pattern as rlx-wgpu's WGSL kernels).
Stack
- Matmul: cuBLAS (FP32 SGEMM; mixed precision via cublasGemmEx) and cuBLASLt (fused bias + activation epilogues).
- Convolution / pooling: cuDNN.
- Tensor cores: WMMA path for FP16 GEMM on Volta+ (SM 70); BF16 WMMA fragments need Ampere+ (SM 80).
- Custom kernels: NVRTC-compiled .cu sources, with PTX cached on disk keyed by a hash of the kernel source.
- CUDA Graphs: capture + replay for inference-shaped workloads.
- Multi-stream: async copy + compute overlap.
- NVTX: span markers wired through Perfetto export.
Install
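At runtime, the CUDA driver and libraries (libcuda, libcudart, libcublas)
must be on the loader path; libcudnn is optional, and without it
convolutions fall back to the custom direct-conv kernels. The crate is
feature-gated in rlx-runtime: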
[dependencies]
rlx-runtime = { version = "0.1", features = ["cuda"] }
Mac-side iteration
cudarc's dynamic-loading feature loads libcuda via dlopen at
first FFI call. On Mac there's no libcuda, so:
- cargo build -p rlx-cuda --release: compiles cleanly. The crate links against cudarc's stub bindings; libcuda is only resolved at runtime.
- cargo test -p rlx-cuda --release: runs the basic tests. Each test checks is_available() first; on Mac that returns false (the libcuda load fails inside cudarc and we catch the panic), so the tests no-op cleanly.
- ./rlx-cuda/check-compile.sh: builds the crate inside an nvidia/cuda:12.6.0-devel-ubuntu22.04 Docker image. It validates that our CUDA C++ sources compile against a real NVRTC and that cudarc links against the real libcuda. Apple Silicon runs the amd64 image under qemu emulation; the first build takes a few minutes, and cache hits are much faster.
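The panic-catching availability check can be sketched with only the standard library. This is a minimal illustration under stated assumptions, not the crate's device.rs: probe_driver is a hypothetical stand-in for the cudarc call that panics when libcuda cannot be dlopen'd, and the result is memoized in a OnceLock so the probe runs at most once per process.

```rust
use std::panic::{self, UnwindSafe};
use std::sync::OnceLock;

/// Hypothetical stand-in for the first cudarc FFI call on a host without
/// libcuda: like the real call, it panics instead of returning an Err.
fn probe_driver() -> u32 {
    panic!("libcuda could not be dlopen'd")
}

/// Runs `probe` and converts a panic into None.
fn try_init<F: FnOnce() -> u32 + UnwindSafe>(probe: F) -> Option<u32> {
    // Silence the default panic message while probing, then restore it.
    let previous_hook = panic::take_hook();
    panic::set_hook(Box::new(|_| {}));
    let result = panic::catch_unwind(probe);
    panic::set_hook(previous_hook);
    result.ok()
}

static DEVICE: OnceLock<Option<u32>> = OnceLock::new();

/// Process-wide check: the probe runs once, later calls read the cached result.
fn is_available() -> bool {
    DEVICE.get_or_init(|| try_init(probe_driver)).is_some()
}
```

A test that needs a GPU would start with `if !is_available() { return; }`, which is how a test can no-op on a Mac instead of failing.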
There's no path to actually run CUDA kernels on Mac — Apple Silicon has no NVIDIA GPU, and Docker Desktop's VM has no GPU passthrough even when running on a hypothetical Intel Mac with NVIDIA hardware. For benchmarks: use a cloud GPU (vast.ai, Lambda Labs, RunPod) or a self-hosted Linux box.
What's here
- device.rs: CudaContext singleton with panic-catching init, so a missing libcuda returns None instead of crashing.
- arena.rs: single device buffer + per-node offsets, mirroring the rlx-wgpu f32-uniform arena. Reshape and Cast alias the input slot.
- kernels/*.cu: CUDA C++ sources (binary, unary, copy, matmul, attention, conv, etc.), compiled via NVRTC at first dispatch and cached behind OnceLocks.
- kernels/mod.rs: NVRTC compile + module/function loader.
- backend.rs: CudaExecutable. Full IR coverage; matmul dispatch follows the tier ladder below.
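The single-buffer arena, where Reshape and Cast alias their input slot, can be pictured with a host-side sketch. Node, Op and plan_offsets are hypothetical names, a plain element count stands in for the device buffer size, and slot reuse and alignment padding are deliberately not modeled.

```rust
/// Hypothetical op kinds; only the aliasing behaviour matters here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Op {
    Input,
    Compute,        // any op that writes a fresh output
    Reshape(usize), // aliases the slot of the node at this index
    Cast(usize),    // aliases too: every slot in the arena is f32
}

struct Node {
    op: Op,
    len: usize, // element count of the node's output
}

/// Assigns each node an element offset into one shared buffer.
/// Nodes must be in topological order (aliased sources come first).
/// Returns (per-node offsets, total buffer length in elements).
fn plan_offsets(nodes: &[Node]) -> (Vec<usize>, usize) {
    let mut offsets = Vec::with_capacity(nodes.len());
    let mut next = 0usize;
    for node in nodes {
        let off = match node.op {
            // Views reuse the producer's storage: no new allocation.
            Op::Reshape(src) | Op::Cast(src) => offsets[src],
            Op::Input | Op::Compute => {
                let off = next;
                next += node.len;
                off
            }
        };
        offsets.push(off);
    }
    (offsets, next)
}
```

With this layout a single device allocation of `total` f32s backs the whole graph, and each kernel launch only needs base pointer + offset.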
Matmul dispatch tier decision tree
Step::Matmul walks down a tier ladder; each tier checks its
preconditions and either dispatches or falls through. With
RLX_CUDA_LOG_FALLBACK=1 you'll see exactly which tier ran.
                Step::Matmul(m, k, n, …)
                            │
       ┌────────────────────┴────────────────────┐
       │ Is weight (B) in half-arena?            │
       │ (set_param_half was called for B)       │
       └────────────────────┬────────────────────┘
                            │
             ┌──── yes ─────┴───── no ──────┐
             ▼                              ▼
┌─────────────────────────┐   ┌────────────────────────────┐
│ Tier 0: mixed-precision │   │ Tier 1: cublasLt fused     │
│ cast f32 act → f16/bf16 │   │ matmul + bias + relu/gelu  │
│ scratch; cublasGemmEx;  │   │ in one launch              │
│ epilogue kernel for     │   │ only when act ∈ {Relu,     │
│ bias/act (any kind)     │   │ Gelu, none}                │
│                         │   │                            │
│ ✓ ½ weight memory       │   │ ✓ Saves epilogue launch    │
│ ✓ Tensor Core compute   │   │ ✓ Bias broadcast inline    │
└─────────────────────────┘   └─────────────┬──────────────┘
                                            │ act not relu/gelu
                                            ▼
                              ┌────────────────────────────┐
                              │ Tier 2: cublasSgemm        │
                              │ + matmul_epilogue.cu       │
                              │   if has_bias || act ≠ id  │
                              │                            │
                              │ ✓ TF32 Tensor Core (auto)  │
                              │ ✓ Handles all 12 acts      │
                              └─────────────┬──────────────┘
                                            │ cuBLAS unavailable
                                            ▼
                              ┌────────────────────────────┐
                              │ Tier 3: WMMA Tensor Core   │
                              │ kernel (matmul_wmma.cu)    │
                              │ only if RLX_CUDA_WMMA=1    │
                              │ + SM 70+ NVRTC compile OK  │
                              └─────────────┬──────────────┘
                                            │ env not set / SM<70
                                            ▼
                              ┌────────────────────────────┐
                              │ Tier 4: scalar SGEMM       │
                              │ 64×64 block + 4×4 reg tile │
                              │ float4 vec loads when      │
                              │ K%4==0 && N%4==0           │
                              └────────────────────────────┘
Concrete examples
| Shape (m×k×n) | Bias | Act | Half-arena? | Tier picked |
|---|---|---|---|---|
| 1024×4096×4096 | yes | gelu | yes (f16) | 0 mixed-precision GemmEx + epilogue |
| 1024×4096×4096 | yes | gelu | no | 1 cublasLt fused |
| 1024×4096×4096 | yes | silu | no | 2 sgemm + epilogue (silu not in cublasLt) |
| 1×3×2 (test) | no | none | no | 2 sgemm (cuBLAS handles tiny shapes fine) |
| any | any | any | no (cuBLAS not loaded, RLX_CUDA_WMMA unset) | 4 scalar fallback |
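The decision tree and table can be condensed into a pure selection function. This is a sketch, not backend.rs: the enum and field names are hypothetical, an NVRTC compile failure of the WMMA kernel is not modeled, Tier 0 is assumed to need cuBLAS like Tiers 1 and 2, and, to agree with the 1×3×2 row of the table, Tier 1 is assumed to apply only when there is a bias or activation to fuse.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Act { Identity, Relu, Gelu, Silu }

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Tier {
    MixedPrecision, // 0: cast + cublasGemmEx + epilogue
    LtFused,        // 1: cublasLt matmul + bias + relu/gelu
    SgemmEpilogue,  // 2: cublasSgemm (+ matmul_epilogue.cu if needed)
    Wmma,           // 3: matmul_wmma.cu
    Scalar,         // 4: 64×64 block SGEMM
}

#[derive(Clone, Copy)]
struct MatmulCtx {
    weight_is_half: bool, // set_param_half was called for B
    has_bias: bool,
    act: Act,
    cublas_loaded: bool,
    wmma_env: bool, // RLX_CUDA_WMMA=1
    sm: u32,        // compute capability, e.g. 80 for SM 8.0
}

fn pick_tier(c: &MatmulCtx) -> Tier {
    if c.cublas_loaded {
        if c.weight_is_half {
            return Tier::MixedPrecision;
        }
        let lt_can_fuse = matches!(c.act, Act::Identity | Act::Relu | Act::Gelu);
        let has_epilogue = c.has_bias || c.act != Act::Identity;
        if lt_can_fuse && has_epilogue {
            return Tier::LtFused;
        }
        return Tier::SgemmEpilogue;
    }
    if c.wmma_env && c.sm >= 70 {
        return Tier::Wmma;
    }
    Tier::Scalar
}
```

Keeping the choice in a side-effect-free function like this makes each table row a one-line unit test, independent of whether a GPU is present.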
Conv dispatch
Step::Conv1d / Conv2d / Conv3d are simpler: cuDNN when libcudnn is
loaded, a custom direct-convolution kernel otherwise. Conv1d reuses the
Conv2d helper with H = kh = sh = 1, ph = 0, dh = 1 (a degenerate 2-D
convolution); Conv3d uses cuDNN's N-d descriptor APIs.
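The Conv1d-as-Conv2d mapping works because of the standard convolution output-size formula (the one cuDNN documents): with H = kh = sh = 1, ph = 0, dh = 1 the height axis collapses to exactly one output row, so the 2-D path computes a 1-D convolution along W. A sketch of that arithmetic, with hypothetical helper names:

```rust
/// Output length of one spatial axis of a convolution (floor division):
/// out = (in + 2*pad - (dilation*(kernel - 1) + 1)) / stride + 1
fn conv_out_len(input: usize, kernel: usize, stride: usize, pad: usize, dilation: usize) -> usize {
    let effective_kernel = dilation * (kernel - 1) + 1;
    assert!(input + 2 * pad >= effective_kernel, "kernel larger than padded input");
    (input + 2 * pad - effective_kernel) / stride + 1
}

/// Maps Conv1d parameters onto a 2-D (H, W) problem with a degenerate H axis.
/// Returns (out_h, out_w); out_h is always 1.
fn conv1d_as_conv2d(len: usize, k: usize, s: usize, p: usize, d: usize) -> (usize, usize) {
    let out_h = conv_out_len(1, 1, 1, 0, 1); // H = kh = sh = 1, ph = 0, dh = 1
    let out_w = conv_out_len(len, k, s, p, d);
    (out_h, out_w)
}
```

For example, a length-10 signal with a 3-tap kernel, stride 2 and no padding maps to a 1×4 output.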
Compile + execution modes
CudaExecutable::compile_with(graph, compile_mode, exec_mode) selects:
- CompileMode::Jit (default): kernels are NVRTC-compiled on first dispatch, then live in the cuModule cache for the rest of the process. The first run() pays the JIT cost (~10-100 ms per kernel, 32 kernels in total).
- CompileMode::Aot: pre-compiles all 32 kernels at executable construction. This moves the JIT cost out of the critical path at the price of ~1-3 s upfront, which suits inference servers that build the executable once and keep running it.
- Persistent PTX disk cache. All NVRTC compiles cache their PTX under $RLX_CUDA_PTX_CACHE (falling back to $XDG_CACHE_HOME/rlx-cuda, then ~/.cache/rlx-cuda), namespaced by the CUDA toolkit version. The cache key is <entry>-<fnv1a64(source)>.ptx. FNV-1a is not cryptographic; it only makes filenames unique per source, so an edited kernel source hashes to a new filename and recompiles instead of loading stale PTX. Writes are atomic (write to a temp file, then rename). After the first run, cross-process cold start drops from ~1-3 s to ~50 ms.
- TF32 fast math in cublasLt. The compute type for f32 matmul is CUBLAS_COMPUTE_32F_FAST_TF32, which uses Tensor Cores on Ampere+ for a ~2× speedup: inputs are rounded to TF32 (10-bit mantissa) while accumulation stays in f32. This is the same trade-off cublasSgemm makes when the handle's math mode is CUBLAS_TF32_TENSOR_OP_MATH (TF32 is opt-in for cuBLAS, not its default), and it is well within transformer-inference precision tolerance.
- NVTX profiling ranges. Each Step dispatch is wrapped in an nvtx::scoped_range named rlx::<StepKind>. Overhead is negligible when no profiler is attached, and Nsight Systems / nvprof traces show step boundaries cleanly, so you can see where time goes.
- Backend-level element-wise fusion. fuse_elementwise_chains runs after the schedule is built and merges adjacent Binary → Unary pairs into a single FusedBinaryUnary step when the intermediate offset has exactly one consumer in the schedule.
- Half-precision params side-buffer + mixed-precision matmul. Arena.half_buffer is an optional CudaSlice<u16> (raw bits, f16 or bf16 per node, tagged via HalfDtype) for storing weights; activations stay f32 in the main buffer. Upload weights in half precision with CudaExecutable::set_param_half(name, dtype, &[u16]) instead of set_param. The matmul dispatch detects half-stored weights via Arena.half_by_f32_off and then:
  - Casts the f32 activations to f16/bf16 into a scratch buffer (cast_f32_to_half.cu kernel).
  - Calls cublasGemmEx with both inputs in f16/bf16, compute type CUBLAS_COMPUTE_32F_FAST_16F / CUBLAS_COMPUTE_32F_FAST_16BF, and an f32 accumulator that writes back to the main arena.
  - Runs the optional bias / activation epilogue as a separate matmul_epilogue.cu pass afterwards.
- ExecMode::Stream (default): every run() dispatches each step on the default stream.
- ExecMode::Graph: the first run() captures the schedule into a CUDA Graph; subsequent runs replay the captured graph. This saves per-launch dispatch overhead (~10-20% on small-batch decode).
- ExecMode::Eager: CudaExecutable::eager(graph, inputs) is a one-shot helper that compiles, runs, and drops the executable in one call.
- ExecMode::MultiStream(n): allocates a pool of n streams and assigns each Step to one of them based on producer-consumer relations on arena offsets (computed by step_offsets). Independent ops run in parallel; cross-stream sync uses CUDA events at fork/join points. Incompatible with ExecMode::Graph.
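The PTX disk cache's naming and atomic write, from the list above, can be sketched with std alone. The FNV-1a constants are the standard 64-bit ones; the directory layout (one subdirectory per toolkit version) and the function names are assumptions for illustration, and resolving $RLX_CUDA_PTX_CACHE / $XDG_CACHE_HOME is left to the caller.

```rust
use std::fs;
use std::io::{self, Write};
use std::path::{Path, PathBuf};

/// 64-bit FNV-1a: non-cryptographic, only keeps filenames unique per source.
fn fnv1a64(bytes: &[u8]) -> u64 {
    let mut hash: u64 = 0xcbf2_9ce4_8422_2325; // offset basis
    for &b in bytes {
        hash ^= u64::from(b);
        hash = hash.wrapping_mul(0x0000_0100_0000_01b3); // FNV prime
    }
    hash
}

/// <cache_root>/<toolkit>/<entry>-<fnv1a64(source)>.ptx (layout assumed).
fn ptx_cache_path(cache_root: &Path, toolkit: &str, entry: &str, source: &str) -> PathBuf {
    let file = format!("{}-{:016x}.ptx", entry, fnv1a64(source.as_bytes()));
    cache_root.join(toolkit).join(file)
}

/// Writes to a temporary sibling file, then renames it into place. On POSIX,
/// rename within one filesystem is atomic, so a concurrent reader sees either
/// no file or the complete PTX, never a partial write.
fn write_atomic(path: &Path, contents: &[u8]) -> io::Result<()> {
    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)?;
    }
    let tmp = path.with_extension(format!("ptx.tmp.{}", std::process::id()));
    {
        let mut f = fs::File::create(&tmp)?;
        f.write_all(&contents)?;
        f.sync_all()?;
    }
    fs::rename(&tmp, path)
}
```

Because the hash covers the full kernel source, editing a .cu string changes the filename, and the old entry is simply never looked up again.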
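The element-wise fusion rule described in the list above (merge Binary → Unary when the intermediate offset has exactly one consumer) can be sketched as a single pass over a schedule. The Step shape here is hypothetical: each step reads arena offsets and writes one, ops are named by strings, and "adjacent" is taken to mean consecutive in the schedule.

```rust
#[derive(Clone, Debug, PartialEq)]
enum Step {
    Binary { op: &'static str, lhs: usize, rhs: usize, dst: usize },
    Unary { op: &'static str, src: usize, dst: usize },
    FusedBinaryUnary { bin: &'static str, un: &'static str, lhs: usize, rhs: usize, dst: usize },
    Other { reads: Vec<usize>, dst: usize },
}

impl Step {
    /// Arena offsets this step reads.
    fn reads(&self) -> Vec<usize> {
        match self {
            Step::Binary { lhs, rhs, .. } => vec![*lhs, *rhs],
            Step::Unary { src, .. } => vec![*src],
            Step::FusedBinaryUnary { lhs, rhs, .. } => vec![*lhs, *rhs],
            Step::Other { reads, .. } => reads.clone(),
        }
    }
}

/// Number of steps in the schedule that read arena offset `off`.
fn consumers(schedule: &[Step], off: usize) -> usize {
    schedule.iter().filter(|s| s.reads().contains(&off)).count()
}

fn fuse_elementwise_chains(schedule: Vec<Step>) -> Vec<Step> {
    let mut out = Vec::with_capacity(schedule.len());
    let mut i = 0;
    while i < schedule.len() {
        if let (Step::Binary { op: bin, lhs, rhs, dst: mid }, Some(Step::Unary { op: un, src, dst })) =
            (&schedule[i], schedule.get(i + 1))
        {
            // Fuse only if the Unary is the sole reader of the intermediate.
            if src == mid && consumers(&schedule, *mid) == 1 {
                out.push(Step::FusedBinaryUnary { bin: *bin, un: *un, lhs: *lhs, rhs: *rhs, dst: *dst });
                i += 2;
                continue;
            }
        }
        out.push(schedule[i].clone());
        i += 1;
    }
    out
}
```

The single-consumer check matters: if any other step reads the intermediate, the Binary's output must still be materialized in the arena, so that pair is left unfused.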
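For the bf16 half of the mixed-precision path, the f32 → bf16 conversion is typically a round-to-nearest-even on the upper 16 bits of the f32. The host-side sketch below shows that bit manipulation; it is not the crate's cast_f32_to_half.cu, NaN handling is simplified to a quiet NaN, and to_bf16_slice is a hypothetical helper for producing the &[u16] that set_param_half takes for a bf16 dtype.

```rust
/// Converts an f32 to bf16 bits with round-to-nearest-even.
fn f32_to_bf16_bits(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep sign and upper payload bits, force the quiet-NaN bit.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Adding 0x7FFF plus the lowest kept bit rounds exact ties to even.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7FFF + lsb);
    (rounded >> 16) as u16
}

/// Widens bf16 bits back to f32 (exact: bf16 is a truncated f32).
fn bf16_bits_to_f32(h: u16) -> f32 {
    f32::from_bits(u32::from(h) << 16)
}

/// Converts a weight slice to raw bf16 bits for upload.
fn to_bf16_slice(weights: &[f32]) -> Vec<u16> {
    weights.iter().map(|&w| f32_to_bf16_bits(w)).collect()
}
```

Rounding to nearest-even rather than truncating keeps the conversion unbiased, which matters when millions of weights are quantized the same way.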
Build / test
Status
Functional; less battle-tested than the Apple Silicon path. The kernel
sources are shared with rlx-rocm (sister crate) so coverage moves in
lock-step.
Dev: HIP-CPU validation path
--features hip-cpu-validate is an opt-in dev feature that lets us
run the same .cu kernel sources on CPU threads via HIP-CPU.
Useful for catching kernel-logic and IR-lowering bugs on Mac (or any
host without an NVIDIA driver) before paying for cloud-GPU time.
Off by default. Never enabled in production builds.
Workflow
# One-time: pull HIP-CPU as a submodule.
# Compile + test the CPU-execution path.
# In Docker (any architecture, no GPU needed):
Architecture
           ┌─── shared sources: src/kernels/*.cu ───┐
           │                                        │
   cudarc + libcuda                        HIP-CPU + cc::Build
           │                                        │
  NVIDIA GPU dispatch                      CPU thread dispatch
(production: rlx-cuda)                   (dev: hip-cpu-validate)
build.rs compiles cpp/cpu_dispatch.cpp against HIP-CPU headers when
the feature is on. The TU #includes each .cu file directly and
exposes one extern "C" launch_<kernel> wrapper per kernel using
hipLaunchKernelGGL. Rust calls those via FFI in src/cpu_dispatch.rs.
Coverage
All 32 kernel entry points are wired end-to-end (30 .cu files, with
matmul and scatter_add contributing the extra entry points). Each one has:
- #include "<kernel>.cu" in cpp/cpu_dispatch.cpp, plus an extern "C" launch_<kernel>(...) wrapper that calls hipLaunchKernelGGL with the kernel's argument tuple.
- The matching extern "C" declaration and a safe Rust wrapper in src/cpu_dispatch.rs (one run_<kernel>(...) fn per family).
- A unit test under tests/hip_cpu_validate.rs that exercises the FFI dispatch on a tiny representative shape.
Caveats
- HIP-CPU is CPU emulation of CUDA semantics, not a full reimplementation. __shared__ works, __syncthreads() is a barrier, and atomics use std::atomic. We avoid __shfl_* warp-level primitives because HIP-CPU's wavefront size differs from CUDA's 32-thread warp.
- Translation differences between NVIDIA's compilers (NVRTC / NVCC) and clang (sign extension, FMA fusion ordering, intrinsic lowering) won't surface here. Real CUDA validation requires a real CUDA box.
- HIP-CPU is far slower than a real GPU (~1000×). Don't benchmark against it; use it only for correctness.
Gotchas
- dynamic-loading panics on missing libcuda. Even calling cudarc::driver::CudaContext::new(0) panics rather than returning an Err when libcuda can't be dlopen'd. We wrap the first call in panic::catch_unwind so is_available() returns false cleanly.
- FlashAttention-1 KV blocking. attention.cu launches one block per (batch, head, q-tile): BR=16 query rows against BC=32-row KV tiles, 128 threads/block. K and V tiles are loaded into shared memory once per tile and reused for both the QK and PV passes. An online softmax across KV tiles maintains row_max / row_sum and rescales the running V accumulator on every tile. head_dim is statically capped at 128 (enough for Llama 70B); the kernel early-returns for larger head_dim.
- cuDNN conv dispatch. Conv1d / Conv2d / Conv3d all route through the forward-conv algorithm picked by cuDNN's v7 heuristic when libcudnn is available. The workspace is a 32 MiB scratch buffer per executable.
- Grouped matmul (MoE) sorted-batch path. Step::GroupedMatmul downloads the expert-id buffer to the host, detects runs of identical consecutive ids, and issues one cublasSgemm per run when the run count is ≤ m/4. Otherwise (e.g. randomly ordered ids, where per-run cuBLAS launch overhead would dominate) it falls back to the per-token expert-lookup kernel.
- Kernels JIT-compile on first dispatch. Under CompileMode::Jit, the first run() per kernel pays an NVRTC compile (~10-100 ms each); subsequent calls reuse the cached cuModule. CompileMode::Aot pre-warms every kernel at executable construction instead, which moves that cold path from the first run() into compile time; the persistent PTX disk cache makes both cheap once a previous process has populated it.
- Native ElementwiseRegion (PLAN L2). Op::ElementwiseRegion is lowered to an NVRTC interpreted-chain kernel (kernels/elementwise_region.cu). One thread per output element walks a runtime chain encoding (4 u32s per step: op_kind / op_sub / lhs_enc / rhs_enc), keeps intermediates in a private float scratch[16] register array, and writes the last step's result to arena[dst_off + i]. Operand bit 31 picks the source (0 = Input → arena[input_offs[idx] + i], 1 = Step → scratch[idx]). Caps: 16 chain steps and 8 inputs, the same as the Metal MSL / wgpu WGSL kernels, so the encoder in rlx-opt produces one byte stream that all three backends interpret identically.
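The online-softmax recurrence that attention.cu applies across KV tiles (running row_max and row_sum, rescaling the V accumulator per tile) can be checked on the host against a plain softmax. This single-query-row Rust sketch is not the kernel: the bc parameter stands in for BC, there is no shared memory or BR blocking, and the function names are hypothetical.

```rust
/// One query row attended over K/V in tiles of `bc` rows (online softmax).
fn attend_online(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], bc: usize) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let mut row_max = f32::NEG_INFINITY;
    let mut row_sum = 0.0f32;
    let mut acc = vec![0.0f32; v[0].len()];

    for start in (0..k.len()).step_by(bc) {
        let end = (start + bc).min(k.len());
        let scores: Vec<f32> = (start..end)
            .map(|j| q.iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f32>() * scale)
            .collect();
        let tile_max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let new_max = row_max.max(tile_max);
        // Rescale everything accumulated so far to the new running max.
        let correction = (row_max - new_max).exp();
        row_sum *= correction;
        for a in acc.iter_mut() {
            *a *= correction;
        }
        for (s, j) in scores.iter().zip(start..end) {
            let p = (s - new_max).exp();
            row_sum += p;
            for (a, vj) in acc.iter_mut().zip(&v[j]) {
                *a += p * vj;
            }
        }
        row_max = new_max;
    }
    acc.iter().map(|a| a / row_sum).collect()
}

/// Reference: full softmax over all scores, then a weighted sum of V rows.
fn attend_reference(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = k
        .iter()
        .map(|kj| q.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
        .collect();
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
    let z: f32 = w.iter().sum();
    let mut out = vec![0.0f32; v[0].len()];
    for (wj, vj) in w.iter().zip(v) {
        for (o, x) in out.iter_mut().zip(vj) {
            *o += wj * x / z;
        }
    }
    out
}
```

Any tile size gives the same result up to rounding, which is what lets the kernel stream K/V through shared memory without ever holding a full score row.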
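The sorted-batch decision for Step::GroupedMatmul comes down to run-length encoding the expert ids on the host and comparing the run count with m/4. A sketch of that check, with hypothetical names; that the threshold uses integer division m/4 is an assumption.

```rust
/// A maximal run of identical consecutive expert ids: rows [start, start + len).
#[derive(Debug, PartialEq, Eq)]
struct Run {
    expert: u32,
    start: usize,
    len: usize,
}

fn expert_runs(ids: &[u32]) -> Vec<Run> {
    let mut runs: Vec<Run> = Vec::new();
    for (row, &id) in ids.iter().enumerate() {
        if let Some(run) = runs.last_mut() {
            if run.expert == id {
                run.len += 1;
                continue;
            }
        }
        runs.push(Run { expert: id, start: row, len: 1 });
    }
    runs
}

/// Some(runs) when one GEMM per run is worthwhile (runs <= m/4);
/// None means: use the per-token expert-lookup kernel instead.
fn plan_grouped(ids: &[u32]) -> Option<Vec<Run>> {
    let m = ids.len();
    let runs = expert_runs(ids);
    if runs.len() <= m / 4 { Some(runs) } else { None }
}
```

Each returned run maps to one cublasSgemm over rows start..start + len with that expert's weight matrix.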
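The ElementwiseRegion encoding (4 u32s per step, bit 31 of an operand selecting Input vs Step, 16-step / 8-input caps) can be illustrated with a host-side interpreter for one output element. The op_kind / op_sub numbering below is invented for the example (the real numbering is whatever rlx-opt's encoder emits), and this sketch assumes unary steps ignore rhs_enc.

```rust
const MAX_STEPS: usize = 16;
const MAX_INPUTS: usize = 8;
const STEP_BIT: u32 = 1 << 31;

// Hypothetical op numbering, for illustration only.
const KIND_BINARY: u32 = 0;
const KIND_UNARY: u32 = 1;
const BIN_ADD: u32 = 0;
const BIN_MUL: u32 = 1;
const UN_NEG: u32 = 0;
const UN_RELU: u32 = 1;

/// Resolves an operand: bit 31 clear = input index, set = earlier step index.
fn operand(enc: u32, inputs: &[f32], scratch: &[f32; MAX_STEPS]) -> f32 {
    let idx = (enc & !STEP_BIT) as usize;
    if enc & STEP_BIT == 0 { inputs[idx] } else { scratch[idx] }
}

/// Evaluates the chain for output element `i`; inputs are read as
/// arena[input_offs[idx] + i], mirroring the kernel's addressing.
fn eval_region(chain: &[u32], arena: &[f32], input_offs: &[usize], i: usize) -> f32 {
    assert!(chain.len() >= 4 && chain.len() % 4 == 0 && chain.len() / 4 <= MAX_STEPS);
    assert!(input_offs.len() <= MAX_INPUTS);
    let inputs: Vec<f32> = input_offs.iter().map(|&off| arena[off + i]).collect();
    let mut scratch = [0.0f32; MAX_STEPS];
    let n = chain.len() / 4;
    for s in 0..n {
        let (kind, sub, lhs, rhs) = (chain[4 * s], chain[4 * s + 1], chain[4 * s + 2], chain[4 * s + 3]);
        let a = operand(lhs, &inputs, &scratch);
        let r = match (kind, sub) {
            (KIND_BINARY, BIN_ADD) => a + operand(rhs, &inputs, &scratch),
            (KIND_BINARY, BIN_MUL) => a * operand(rhs, &inputs, &scratch),
            (KIND_UNARY, UN_NEG) => -a,
            (KIND_UNARY, UN_RELU) => a.max(0.0),
            _ => panic!("unknown op {kind}/{sub}"),
        };
        scratch[s] = r;
    }
    scratch[n - 1] // the kernel writes this to arena[dst_off + i]
}
```

For relu(x0 * x1 + x0), the chain is three steps: MUL(in0, in1), ADD(step0, in0), RELU(step1).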
License
GPL-3.0-only.