kaio-candle
Candle bridge for KAIO — CustomOp bindings that let you call KAIO's tensor-core GPU kernels directly on candle_core::Tensor.
Ships twelve ops: matmul_tc, matmul_tc_bf16, matmul_tc_async, matmul_tc_bf16_async, matmul_int4, matmul_int8, attention_tc, attention_tc_causal, attention_flash, attention_flash_causal, qkv_project_int8, qkv_project_int4. All four matmul TC variants (f16 + bf16, sync + async) support backward (autograd) via the forward-reuse pattern — no new PTX in either precision. FlashAttention (attention_flash + attention_flash_causal) supports backward through dedicated backward PTX kernels, preserving the no-O(seq²)-memory profile through the backward pass. attention_tc and quantized ops are forward-only.
Status — v0.2.0
v0.2.0 adds the bf16 matmul family (matmul_tc_bf16, matmul_tc_bf16_async, forward + backward) and FlashAttention backward (plain + causal) to the v0.1 surface.
All ops are bit-exact verified against direct kaio-ops calls with the same input bits.
Why a separate crate?
kaio-candle is not a member of the main KAIO workspace. cudarc rejects dynamic-loading + dynamic-linking as simultaneously active features:
- Main KAIO defaults to dynamic-loading: no CUDA toolkit required to build. Host tests pass on bare GitHub runners.
- candle-core with its cuda feature activates dynamic-linking: it links against libcuda at compile time.
Cargo unions features across a workspace build, so including kaio-candle in the main workspace would force every main-workspace build to also carry candle's dynamic-linking, breaking no-CUDA CI. The standalone crate keeps the two worlds apart.
Consumers who already build candle with the cuda feature see no new system requirement beyond what candle itself needs.
Build
The cuda feature is required for any actual bridge functionality. Without it, kaio-candle is an empty shell (matches candle-core's own opt-in cuda pattern) — attempting to call kaio_candle::matmul_tc(...) surfaces a "function not found" compile error pointing at the missing feature.
Build requirements with cuda:
- CUDA toolkit (candle-core's cudarc feature uses dynamic-linking).
- NVIDIA GPU with SM 8.0 or newer (Ampere, Ada, Hopper).
Quickstart
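    # Cargo.toml
    [dependencies]
    kaio-candle = { version = "0.2", features = ["cuda"] }
    kaio = "0.5"
    candle-core = { version = "0.10", features = ["cuda"] }
    half = "2"

In your Rust source:

    use std::sync::Arc;
    use candle_core::{Device, Tensor};
    use half::f16;
    use kaio::prelude::KaioDevice;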
Three runnable examples ship in examples/.
Op surface
| Op | Trait | Shapes | Dtype |
|---|---|---|---|
| matmul_tc(kd, a, b) | CustomOp2 | a: [M, K], b: [K, N] → [M, N] | f16 × f16 → f32 |
| matmul_tc_async(kd, a, b) | CustomOp2 | same | f16 × f16 → f32 |
| matmul_tc_bf16(kd, a, b) | CustomOp2 | same | bf16 × bf16 → f32 |
| matmul_tc_bf16_async(kd, a, b) | CustomOp2 | same | bf16 × bf16 → f32 |
| matmul_int4(kd, a, b_packed, scales) | CustomOp3 | a: [M, K], b_packed: [K/8, N], scales: [K/128, N] → [M, N] | f16 × u32 × f16 → f32 |
| matmul_int8(kd, a, b, scale) | CustomOp2 | a: [M, K], b: [K, N] → [M, N] | u8-as-i8 × u8-as-i8 → f32 (× f32 scale) |
| attention_tc(kd, q, k, v) | CustomOp3 | q: [seq_q, d_k], k: [seq_k, d_k], v: [seq_k, d_v] → [seq_q, d_v] | f16 × f16 × f16 → f32 |
| attention_tc_causal(kd, q, k, v) | CustomOp3 | same | f16 × f16 × f16 → f32 |
| attention_flash(kd, q, k, v) | CustomOp3 | q, k, v all [seq_len, d_k] → [seq_len, d_k] | f32 × f32 × f32 → f32 |
| attention_flash_causal(kd, q, k, v) | CustomOp3 | same | f32 × f32 × f32 → f32 |
| qkv_project_int8(kd, x, wq, wk, wv, sq, sk, sv) | Direct-call | x: [M, K], wq/wk/wv: [K, N] → (Q, K, V) each [M, N] | f16 × u8-as-i8 → f16 |
| qkv_project_int4(kd, x, wq, wk, wv, sq, sk, sv) | Direct-call | x: [M, K], wq/wk/wv: [K/8, N], sq/sk/sv: [K/128, N] → (Q, K, V) each [M, N] | f16 × u32 × f16 → f16 |
matmul_int4 is GPTQ-style: group_size=128 is locked in by the kaio-ops kernel contract. K must be a multiple of 128, weights are packed 8 INT4 values per u32, one f16 scale per group of 128 elements.
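The packing arithmetic above can be sketched on the host side as follows. This is only an illustration: the nibble order (element 0 in the lowest 4 bits), the two's-complement nibble encoding, and the helper names pack_int4, unpack_int4, and int4_shapes are assumptions, not the kaio-ops contract. Check the kaio-ops documentation for the layout the kernel actually expects.

```rust
/// Group size fixed by the kernel contract described above.
const GROUP_SIZE: usize = 128;
/// Number of 4-bit values stored in one u32.
const VALUES_PER_WORD: usize = 8;

/// Pack eight signed 4-bit values (-8..=7) into one u32.
/// Assumption: element 0 occupies the lowest nibble.
fn pack_int4(values: [i8; 8]) -> u32 {
    let mut word = 0u32;
    for (i, v) in values.iter().enumerate() {
        debug_assert!((-8..=7).contains(v), "value out of INT4 range");
        // Keep only the low 4 bits of the two's-complement byte.
        let nibble = (*v as u8 & 0x0F) as u32;
        word |= nibble << (4 * i);
    }
    word
}

/// Inverse of `pack_int4` under the same (assumed) nibble order.
fn unpack_int4(word: u32) -> [i8; 8] {
    let mut out = [0i8; 8];
    for (i, slot) in out.iter_mut().enumerate() {
        let nibble = ((word >> (4 * i)) & 0x0F) as u8;
        // Sign-extend: move the nibble to the high bits, then shift back arithmetically.
        *slot = ((nibble << 4) as i8) >> 4;
    }
    out
}

/// Shapes of (b_packed, scales) for a [K, N] weight matrix:
/// [K/8, N] and [K/128, N]. Rejects K that is not a multiple of 128.
fn int4_shapes(k: usize, n: usize) -> Result<([usize; 2], [usize; 2]), String> {
    if k == 0 || k % GROUP_SIZE != 0 {
        return Err(format!("K = {k} must be a non-zero multiple of {GROUP_SIZE}"));
    }
    Ok(([k / VALUES_PER_WORD, n], [k / GROUP_SIZE, n]))
}
```

For example, int4_shapes(256, 64) gives a packed shape of [32, 64] and a scales shape of [2, 64], while K = 100 is rejected.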
matmul_int8 is W8A8 symmetric quant. Candle has no DType::I8, so the convention is DType::U8 tensors whose bytes are interpreted as signed INT8 (-128..=127) by the kernel. The bridge reinterprets the storage via a same-layout transmute. scale is a scalar f32 applied in the accumulator; a typical realistic value is max_abs / 127.
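A minimal host-side sketch of the u8-as-i8 convention, assuming per-tensor symmetric quantization with round-to-nearest and a clamp to ±127. The function names are hypothetical, and plain `as` casts stand in for the bridge's same-layout transmute; both preserve the byte's bit pattern.

```rust
/// Symmetric per-tensor scale: max_abs / 127.
/// Returns 1.0 for an all-zero input to avoid dividing by zero later.
fn symmetric_scale(values: &[f32]) -> f32 {
    let max_abs = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 }
}

/// Quantize to signed INT8, then store each value's bit pattern in a u8,
/// which is what a DType::U8 tensor would hold.
fn quantize_to_u8_storage(values: &[f32], scale: f32) -> Vec<u8> {
    values
        .iter()
        .map(|v| {
            let q = (v / scale).round().clamp(-127.0, 127.0) as i8;
            q as u8 // same bits, unsigned storage type
        })
        .collect()
}

/// Read the stored bytes back as signed INT8 (-128..=127).
fn storage_as_i8(bytes: &[u8]) -> Vec<i8> {
    bytes.iter().map(|b| *b as i8).collect()
}
```

With inputs [-2.0, 0.0, 0.5, 2.0] the scale is 2/127, the stored bytes are [129, 0, 32, 127], and reading them back as signed values gives [-127, 0, 32, 127].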
attention_tc uses a shared-memory scores buffer capped at seq_k ≤ 384. attention_flash has no seq cap (online softmax — no materialized score matrix) but is strictly single-head self-attention: Q, K, V must all be [seq_len, d_k] with d_k ≤ 256; cross-attention shapes are rejected with a pointer back to attention_tc.
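The "online softmax" idea can be illustrated on the CPU for a single query row. This is a sketch of the general technique, not the KAIO kernel: the 1/sqrt(d_k) scaling, the function names, and the row-at-a-time structure are assumptions for illustration, and inputs are assumed non-empty.

```rust
/// One query row of attention with online softmax: scores are consumed one at
/// a time, so no seq_len-sized score buffer is kept.
fn attention_row_online(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt(); // assumed scaling
    let mut running_max = f32::NEG_INFINITY;
    let mut running_sum = 0.0f32;
    let mut acc = vec![0.0f32; values[0].len()];

    for (k, v) in keys.iter().zip(values) {
        let score = scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>();
        let new_max = running_max.max(score);
        // Rescale everything accumulated under the old maximum.
        let correction = (running_max - new_max).exp();
        let weight = (score - new_max).exp();
        running_sum = running_sum * correction + weight;
        for (o, x) in acc.iter_mut().zip(v) {
            *o = *o * correction + weight * x;
        }
        running_max = new_max;
    }
    acc.iter().map(|o| o / running_sum).collect()
}

/// Reference: materialize every score first, then apply softmax.
fn attention_row_naive(q: &[f32], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<f32> {
    let scale = 1.0 / (q.len() as f32).sqrt();
    let scores: Vec<f32> = keys
        .iter()
        .map(|k| scale * q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>())
        .collect();
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let weights: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = weights.iter().sum();
    let mut out = vec![0.0f32; values[0].len()];
    for (w, v) in weights.iter().zip(values) {
        for (o, x) in out.iter_mut().zip(v) {
            *o += w * x / sum;
        }
    }
    out
}
```

Both functions should agree to within floating-point rounding; the online version only ever holds a running maximum, a running sum, and one output row.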
qkv_project_int8 and qkv_project_int4 are direct-call functions (not CustomOpN — candle's trait maxes at 3 inputs and single output). They return (Tensor, Tensor, Tensor) with DType::F16 output because the fused kernel performs the f32→f16 conversion internally as part of the projection fusion. Gradient-tracked inputs are rejected with a loud error requiring .detach() — these ops are forward-only.
Backward support
| Op | Backward | Notes |
|---|---|---|
| matmul_tc | Supported | dA = grad @ B^T, dB = A^T @ grad via forward kernel |
| matmul_tc_async | Supported | Same, uses cp.async variant in both directions |
| matmul_tc_bf16 | Supported | Same forward-reuse pattern in bf16 |
| matmul_tc_bf16_async | Supported | Same, cp.async variant in bf16 |
| attention_flash / attention_flash_causal | Supported | Dedicated backward PTX kernels (D-preprocess + dK/dV + dQ) rebuilding the softmax from a per-row logsumexp; f32 end-to-end, no dtype casts |
| attention_tc / attention_tc_causal | No | Short-sequence inference op; training users route to attention_flash, which has backward and no seq_k cap |
| matmul_int4 / matmul_int8 | No | Quantized inference ops: frozen weights, no backprop in practice |
| qkv_project_int8 / qkv_project_int4 | No | Direct-call ops, inference-only by design |
Numerically approximate (matmul TC backward): The backward implementation downcasts the f32 upstream gradient to f16 (or bf16) to reuse the existing tensor-core forward kernels, and casts the output gradients back to the input dtype to satisfy candle's dtype-matching constraint. This is an initial autograd integration proving the bwd() bridge pattern, not a final mixed-precision training stack. The FlashAttention backward has no such round-trip — it is f32 end-to-end.
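To get a feel for what the downcast costs, the bf16 case can be emulated with plain bit operations, since bf16 is the upper 16 bits of an f32. The sketch below is an illustration only: round_trip_bf16 is a hypothetical helper, round-to-nearest-even is an assumed rounding mode, NaN handling is omitted, and the f16 case (which also narrows the exponent range) is not covered.

```rust
/// Round an f32 to bf16 precision (round-to-nearest-even) and widen it back.
/// bf16 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits of f32.
/// NaN inputs are not handled in this sketch.
fn round_trip_bf16(x: f32) -> f32 {
    let bits = x.to_bits();
    // Lowest bit that survives the truncation; used to break ties to even.
    let lsb = (bits >> 16) & 1;
    let rounded = bits.wrapping_add(0x7FFF + lsb);
    f32::from_bits(rounded & 0xFFFF_0000)
}
```

Values such as 1.0 or -0.5 survive unchanged, 0.1 comes back with a relative error below 2^-8, and 1.0 + 2^-10 collapses to exactly 1.0, which is the kind of small gradient detail this round-trip drops.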
Memory (matmul TC backward): The backward pass materializes transposed tensors in VRAM (.t()?.contiguous() = allocation + copy). Peak backward memory is approximately 2–3x the forward input size. Designed for integration testing and light training, not high-throughput training loops where allocator overhead matters.
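The forward-reuse pattern and its transpose cost can be sketched on the CPU with row-major f32 buffers. Here a naive matmul stands in for the tensor-core forward kernel, and transpose plays the role of .t()?.contiguous(); all function names are hypothetical and no dtype casts are modeled.

```rust
/// Row-major matmul: a is [m, k], b is [k, n], result is [m, n].
/// Stands in for the forward kernel that the backward pass reuses.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    assert_eq!(a.len(), m * k);
    assert_eq!(b.len(), k * n);
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for p in 0..k {
            let a_ip = a[i * k + p];
            for j in 0..n {
                out[i * n + j] += a_ip * b[p * n + j];
            }
        }
    }
    out
}

/// Materialized transpose of a row-major [rows, cols] matrix.
/// This is a fresh allocation plus a copy, like a contiguous transpose on the GPU.
fn transpose(x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(x.len(), rows * cols);
    let mut out = vec![0.0f32; x.len()];
    for r in 0..rows {
        for c in 0..cols {
            out[c * rows + r] = x[r * cols + c];
        }
    }
    out
}

/// Backward of C = A @ B given grad = dL/dC ([m, n]), using only the forward matmul.
/// Returns (dA, dB) with shapes [m, k] and [k, n].
fn matmul_backward(
    a: &[f32],
    b: &[f32],
    grad: &[f32],
    m: usize,
    k: usize,
    n: usize,
) -> (Vec<f32>, Vec<f32>) {
    let b_t = transpose(b, k, n); // [n, k], extra allocation
    let a_t = transpose(a, m, k); // [k, m], extra allocation
    let d_a = matmul(grad, &b_t, m, n, k); // [m, n] @ [n, k] = [m, k]
    let d_b = matmul(&a_t, grad, k, m, n); // [k, m] @ [m, n] = [k, n]
    (d_a, d_b)
}
```

The two transposed copies plus the two gradient outputs are where the extra backward memory comes from in this sketch.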
Recompute (FlashAttention backward): candle's CustomOp3 has no fwd→bwd saved-intermediate channel, so each backward call re-runs the stats-saving forward once to recover the per-row logsumexp before launching the backward kernels (the forward is deterministic; recomputed stats are bit-identical to saved ones). Direct kaio_ops::attention_flash_with_stats + attention_flash_bwd callers can hold the stats buffer themselves and skip the recompute. Measured cost of both tiers is in docs/performance.md.
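Why a single per-row logsumexp is enough to rebuild the softmax can be shown in a few lines. This is a CPU illustration of the identity p_j = exp(s_j - lse), not the backward kernel; the function names are hypothetical.

```rust
/// Per-row logsumexp of a score row, computed stably by subtracting the maximum.
fn logsumexp(scores: &[f32]) -> f32 {
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    max + scores.iter().map(|s| (s - max).exp()).sum::<f32>().ln()
}

/// Rebuild the softmax probabilities of a row from its saved logsumexp:
/// p_j = exp(s_j - lse). No separate max or sum pass is needed.
fn softmax_from_lse(scores: &[f32], lse: f32) -> Vec<f32> {
    scores.iter().map(|s| (s - lse).exp()).collect()
}
```

Because logsumexp here is a deterministic function of the scores, computing it a second time yields the same bits as the first, which mirrors the recompute argument above: one f32 per row is all the backward pass needs to carry or regenerate.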
Device lifetime
The Arc<kaio::prelude::KaioDevice> you construct and pass to kaio-candle wrapper functions is independent of the candle_core::Device you use for your tensors. Both retain the same CUDA primary context via cuDevicePrimaryCtxRetain; neither owns the other. Drop order between them is unconstrained.
Every wrapper call checks that the KAIO device and candle device share the same CUDA ordinal; a mismatch is a loud error.
Candle version policy
kaio-candle = 0.2 pins candle-core = "=0.10.2" exactly (unchanged from the 0.1.x line — candle 0.10.2 remains the current release). This is deliberate:
- The CustomOp2 / CustomOp3 surface has changed between candle minor versions in the past.
- cudarc feature conventions change with candle releases.
We re-pin kaio-candle against each new candle minor release. Use kaio-candle 0.1.x–0.2.x with candle-core 0.10.x; the next kaio-candle minor after a candle minor bump will target whichever candle release is current when it publishes.
A GitHub Actions workflow (.github/workflows/candle-head.yml) builds kaio-candle against candle-core's git main branch every Monday. If the workflow's badge stays red for more than two consecutive weeks, either the pin moves to the new candle minor or this section documents the divergence.
Known limitations (v0.2)
- Non-contiguous tensors rejected. Call .contiguous()? upstream.
- Non-zero storage offset rejected (e.g. from .narrow(...) / .slice(...)). Call .contiguous()? to compact.
- Rank-2 only. Multi-head attention callers must reshape [heads, seq, d] to [heads * seq, d] or call per-head with rank-2 slices. Wrappers error with a concrete reshape hint for higher-rank inputs.
- CUDA Graph capture partially unblocked. Event-based sync (Sprint 7.4c) removes the prior cuCtxSynchronize blocker. However, full CUDA Graph capture requires non-default streams on both the candle and KAIO sides, which is not yet verified.
- f32 output (CustomOp ops) / f16 output (direct-call ops). matmul_tc, matmul_tc_bf16 (and the async siblings), matmul_int4, matmul_int8, attention_tc, and attention_flash return f32 matching the kaio-ops accumulator. qkv_project_int{4,8} return f16 because the fused kernel converts internally.
- No CPU fallback. cpu_fwd returns a loud error rather than silently routing to candle.matmul(). KAIO's value is GPU-specific PTX; a silent CPU fallback would mask every perf claim.
- Bench numbers vs direct-call gap. Each bridge call issues event-based stream sync (two join() calls per op). This replaced the heavier cuCtxSynchronize fencing used during early bridge development but still allocates a transient CudaEvent per call. KAIO's published %-of-cuBLAS numbers are measured via direct kaio-ops calls, not through the bridge.
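The rank-2 restriction comes down to index arithmetic on a contiguous row-major buffer, sketched below with plain slices rather than candle tensors. The helper names are hypothetical. With real tensors, a per-head view taken by narrowing would carry a non-zero storage offset and so, per the list above, would need .contiguous()? before being passed to a wrapper.

```rust
/// Split a contiguous row-major [heads, seq, d] buffer into per-head
/// [seq, d] slices without copying.
fn per_head_slices(data: &[f32], heads: usize, seq: usize, d: usize) -> Vec<&[f32]> {
    assert!(seq * d > 0, "seq and d must be non-zero");
    assert_eq!(data.len(), heads * seq * d, "buffer does not match [heads, seq, d]");
    data.chunks_exact(seq * d).collect()
}

/// Row index of (head, row) after reshaping [heads, seq, d] to [heads * seq, d].
fn flat_row(head: usize, row: usize, seq: usize) -> usize {
    head * seq + row
}
```

For heads = 2, seq = 3, d = 2, head 1 starts at element 6 of the buffer, and its first row is row 3 of the flattened [6, 2] matrix.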
License
Dual-licensed under MIT or Apache-2.0, at your option.