kaio-candle 0.2.0

Candle bridge for KAIO — CustomOp bindings for 12 GPU ops (matmul_tc, matmul_tc_bf16, matmul_tc_async, matmul_tc_bf16_async, matmul_int4, matmul_int8, attention_tc, attention_tc_causal, attention_flash, attention_flash_causal, qkv_project_int8, qkv_project_int4). All four matmul TC variants (f16+bf16, sync+async) support backward (autograd); FlashAttention (plain+causal) supports backward via dedicated PTX kernels. Build with `cargo build --features cuda`.
//! # kaio-candle — Candle bridge for KAIO
//!
//! `CustomOp2` / `CustomOp3` bindings between
//! [candle](https://github.com/huggingface/candle) and the
//! [KAIO](https://github.com/dmriding/kaio) GPU kernel library.
//!
//! ## Status — v0.2.0
//!
//! Bridges 12 ops across two patterns:
//!
//! **CustomOp-based** (single-output, return `f32`):
//! - `matmul_tc` — f16 × f16 → f32 matmul via KAIO tensor-core kernel. **Backward supported.**
//! - `matmul_tc_bf16` — bf16 × bf16 → f32 matmul via KAIO tensor-core kernel. **Backward supported.** _(Sprint 9.1.3 fwd, 9.1.4 bwd)_
//! - `matmul_tc_async` — same as `matmul_tc`, `cp.async` variant (92.5% cuBLAS sgemm at 4096² on sm_89). **Backward supported.**
//! - `matmul_tc_bf16_async` — bf16 × bf16 → f32 matmul, `cp.async` variant. **Backward supported.** _(Sprint 9.1.3 fwd, 9.1.4 bwd)_
//! - `matmul_int4` — GPTQ-style INT4 dequantize-matmul with f16 group scales. Forward-only.
//! - `matmul_int8` — W8A8 symmetric-quant matmul with scalar f32 scale (80–94 TOPS at 4096³ on sm_89). Forward-only.
//! - `attention_tc` — fused tensor-core scaled-dot-product attention. Forward-only.
//! - `attention_tc_causal` — same, with decoder causal mask. Forward-only.
//! - `attention_flash` — FlashAttention, f32 single-head self-attention, no O(seq²) memory and no seq cap. **Backward supported.** _(Sprint 9.2, dedicated backward PTX kernels)_
//! - `attention_flash_causal` — same, with decoder causal mask. **Backward supported.** _(Sprint 9.2)_
//!
//! **Direct-call** (multi-output, return `f16`, forward-only):
//! - `qkv_project_int8` — W8A16 fused tri-output QKV projection (4 inputs → 3 f16 outputs).
//! - `qkv_project_int4` — W4A16 fused tri-output QKV projection (7 inputs → 3 f16 outputs).
//!
//! CustomOp-based ops return `f32` matching the kaio-ops accumulator.
//! Direct-call ops return `f16` because the fused kernel performs the
//! `f32→f16` conversion internally as part of the projection fusion.
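//!
//! The CPU sketch below illustrates the generic "W8A8 symmetric quant"
//! scheme. Each operand is mapped to `i8` with a zero-point-free scale
//! (`q = round(x / scale)`, clamped to ±127). Products accumulate in `i32`,
//! and a single f32 scale rescales the result. The helpers
//! `quantize_symmetric` and `int8_dot` are hypothetical. Forming the
//! kernel's scalar scale as the product of per-operand scales is an
//! assumption. The sketch shows the technique, not the kaio-ops kernel.
//!
//! ```rust
//! /// Symmetric quantization: q = round(x / scale), clamped to [-127, 127].
//! fn quantize_symmetric(xs: &[f32]) -> (Vec<i8>, f32) {
//!     let max_abs = xs.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
//!     // Guard against an all-zero input so the scale stays finite.
//!     let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
//!     let q = xs
//!         .iter()
//!         .map(|&x| (x / scale).round().clamp(-127.0, 127.0) as i8)
//!         .collect();
//!     (q, scale)
//! }
//!
//! /// Integer dot product with i32 accumulation, rescaled by one f32 scale
//! /// (assumed here to be scale_a * scale_b).
//! fn int8_dot(a: &[i8], b: &[i8], combined_scale: f32) -> f32 {
//!     let acc: i32 = a.iter().zip(b).map(|(&x, &y)| x as i32 * y as i32).sum();
//!     acc as f32 * combined_scale
//! }
//! ```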
//!
//! ## Backward support
//!
//! Two backward families ship:
//!
//! **Forward-reuse (matmul TC):** `matmul_tc`, `matmul_tc_bf16`,
//! `matmul_tc_async`, and `matmul_tc_bf16_async` all implement
//! `CustomOp2::bwd()`. The backward pass computes `dA = grad @ B^T`
//! and `dB = A^T @ grad` by reusing the same forward kernel — no new
//! PTX in either precision. The f16 backward shipped in Sprint 7.4d;
//! the bf16 backward shipped in Sprint 9.1.4.
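//!
//! The forward-reuse identity can be checked on the CPU. In the sketch
//! below, `matmul` and `transpose` are hypothetical plain-`f64`, row-major
//! stand-ins for the tensor-core kernel. For `C = A @ B`, the two gradients
//! come from two more calls to the same forward `matmul` on transposed
//! operands.
//!
//! ```rust
//! /// Row-major matrix multiply: (m x k) @ (k x n) -> (m x n).
//! fn matmul(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
//!     let mut c = vec![0.0; m * n];
//!     for i in 0..m {
//!         for p in 0..k {
//!             let a_ip = a[i * k + p];
//!             for j in 0..n {
//!                 c[i * n + j] += a_ip * b[p * n + j];
//!             }
//!         }
//!     }
//!     c
//! }
//!
//! /// Transpose a row-major (rows x cols) matrix into (cols x rows).
//! fn transpose(x: &[f64], rows: usize, cols: usize) -> Vec<f64> {
//!     let mut t = vec![0.0; rows * cols];
//!     for r in 0..rows {
//!         for c in 0..cols {
//!             t[c * rows + r] = x[r * cols + c];
//!         }
//!     }
//!     t
//! }
//!
//! /// Gradients of C = A @ B for upstream grad dC (m x n), reusing `matmul`:
//! /// dA = dC @ B^T (m x k) and dB = A^T @ dC (k x n).
//! fn matmul_bwd(a: &[f64], b: &[f64], grad: &[f64], m: usize, k: usize, n: usize)
//!     -> (Vec<f64>, Vec<f64>) {
//!     let d_a = matmul(grad, &transpose(b, k, n), m, n, k);
//!     let d_b = matmul(&transpose(a, m, k), grad, k, m, n);
//!     (d_a, d_b)
//! }
//! ```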
//!
//! **Dedicated backward kernels (FlashAttention):** `attention_flash`
//! and `attention_flash_causal` implement `CustomOp3::bwd()` backed by
//! three purpose-built PTX kernels in kaio-ops (a D-term preprocess,
//! a dK/dV kernel, and a dQ kernel) — attention backward has no
//! forward-reuse identity. The kernels rebuild the softmax from a
//! per-row logsumexp rather than materializing the O(seq²) probability
//! matrix, preserving FlashAttention's memory profile through the
//! backward pass. Because candle's `CustomOp3` has no fwd→bwd
//! saved-intermediate channel, each backward call re-runs the
//! stats-saving forward once to recover the logsumexp (the forward is
//! deterministic, so the recomputed stats are bit-identical); direct
//! kaio-ops callers can keep the stats buffer and skip that cost.
//! Gradients are f32 end-to-end — no dtype casts.
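//!
//! The per-row math behind rebuilding the softmax from a per-row
//! logsumexp can be sketched on the CPU for one query row. The sketch uses
//! the standard FlashAttention-2 backward identities:
//! `P_j = exp(S_j - L)` from the saved logsumexp `L`,
//! `D = sum_j dO_j * O_j`, and `dS_j = P_j * (dP_j - D)`. It is an
//! assumption that the three kaio-ops kernels split the work exactly this
//! way. The function names are hypothetical, and softmax scaling and
//! causal masking are omitted.
//!
//! ```rust
//! /// Numerically stable log-sum-exp of one row of attention scores.
//! fn logsumexp(scores: &[f32]) -> f32 {
//!     let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
//!     max + scores.iter().map(|&s| (s - max).exp()).sum::<f32>().ln()
//! }
//!
//! /// Rebuild one row of softmax probabilities from the saved logsumexp,
//! /// so the full seq_q x seq_k probability matrix is never stored.
//! fn probs_from_lse(scores: &[f32], lse: f32) -> Vec<f32> {
//!     scores.iter().map(|&s| (s - lse).exp()).collect()
//! }
//!
//! /// D-term for one query row: dot product of dO and O over the head dim.
//! fn d_term(d_out: &[f32], out: &[f32]) -> f32 {
//!     d_out.iter().zip(out).map(|(g, o)| g * o).sum()
//! }
//!
//! /// Softmax backward for one row: dS_j = P_j * (dP_j - D).
//! fn softmax_bwd_row(probs: &[f32], d_probs: &[f32], d: f32) -> Vec<f32> {
//!     probs.iter().zip(d_probs).map(|(p, dp)| p * (dp - d)).collect()
//! }
//! ```
//!
//! In this formulation `D` equals `sum_j P_j * dP_j`, so it can be computed
//! once per row from `dO` and `O` (both length `head_dim`). This avoids a
//! reduction over `seq_k`, and each row of `dS` sums to zero.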
//!
//! **Numerically approximate (f16 matmul backward):** the f32 upstream
//! gradient is downcast to f16 before the tensor-core matmul. The
//! resulting f32 input gradients are then cast back to f16 to satisfy
//! candle's dtype-matching constraint. This is an initial autograd
//! integration, not a finished mixed-precision training stack.
//!
//! **Numerically approximate (bf16 matmul backward):** the same cast
//! pattern applies (f32 → bf16 → matmul → f32 → bf16). bf16 shares
//! f32's 8-bit exponent, so values stay representable at magnitudes
//! where f16 would overflow or underflow. However, its 7-bit mantissa is
//! less precise than f16's 10-bit mantissa, so the round trip adds more
//! per-element quantization noise. The dual-tolerance gradient check
//! (`rel < 1e-2 || abs < 1e-3`, identical to f16) covers both
//! precisions at the shapes tested in
//! `kaio-candle/tests/candle_gpu_roundtrip.rs`. Larger shapes or
//! different magnitude regimes may require recalibration.
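//!
//! The sketch below makes the bf16 trade-off concrete. It emulates one
//! f32 → bf16 → f32 round trip using round-to-nearest-even on the raw bits,
//! the standard bf16 conversion, which is assumed here to match the GPU
//! path. It also implements a dual-tolerance check of the form quoted
//! above. `bf16_round_trip` and `grad_close` are hypothetical helpers, and
//! the exact way the test computes `rel` is an assumption. f16 is not
//! modeled because stable Rust has no built-in f16 type.
//!
//! ```rust
//! /// One f32 -> bf16 -> f32 round trip with round-to-nearest-even.
//! /// bf16 keeps f32's sign and 8-bit exponent plus the top 7 mantissa bits.
//! fn bf16_round_trip(x: f32) -> f32 {
//!     if x.is_nan() {
//!         return x; // leave NaN alone; rounding its payload is meaningless
//!     }
//!     let bits = x.to_bits();
//!     // Add just under half an ulp (plus the kept LSB, for ties-to-even)
//!     // to the 16 low bits bf16 drops, then truncate them.
//!     let lsb = (bits >> 16) & 1;
//!     f32::from_bits(bits.wrapping_add(0x7FFF + lsb) & 0xFFFF_0000)
//! }
//!
//! /// Dual-tolerance check: pass if EITHER the relative OR absolute error is small.
//! fn grad_close(got: f32, want: f32) -> bool {
//!     let abs = (got - want).abs();
//!     // Assumed denominator: |want|, floored to avoid dividing by zero.
//!     let rel = abs / want.abs().max(f32::MIN_POSITIVE);
//!     rel < 1e-2 || abs < 1e-3
//! }
//! ```
//!
//! Under these assumptions, `bf16_round_trip(1.0e6)` stays finite, whereas
//! f16 tops out at 65504. A value such as `1.0 + 1.0 / 256.0` lies midway
//! between two bf16 neighbors and snaps to the even one (`1.0`), which
//! shows the coarser mantissa.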
//!
//! Remaining ops are forward-only: `attention_tc` / `attention_tc_causal`
//! are short-sequence inference ops (training users route to
//! `attention_flash`, which has backward and no `seq_k` cap); quantized
//! ops are inference-only by design.
//!
//! ## Build requirements
//!
//! This crate MUST be built with the `cuda` feature enabled:
//!
//! ```toml
//! [dependencies]
//! kaio-candle = { version = "0.2", features = ["cuda"] }
//! ```
//!
//! The default (feature-less) build produces an empty shell — attempting
//! to call a bridge function like `kaio_candle::matmul_tc(...)` surfaces
//! a "function not found" compile error pointing at the missing feature.
//! This matches candle-core's own opt-in model and keeps `cargo doc` /
//! no-CUDA CI legs working without the toolkit.
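//!
//! Downstream crates can mirror this opt-in model. The minimal sketch below
//! assumes the downstream crate defines its own hypothetical `cuda`
//! feature that forwards to `kaio-candle/cuda`. `matmul_backend` and its
//! return strings are placeholders.
//!
//! ```rust
//! /// Compiled only when the downstream crate's `cuda` feature is on,
//! /// so feature-less builds never reference `kaio_candle` items.
//! #[cfg(feature = "cuda")]
//! fn matmul_backend() -> &'static str {
//!     // A real crate would call into kaio_candle (e.g. matmul_tc) here.
//!     "kaio-candle"
//! }
//!
//! /// Fallback for feature-less builds (e.g. `cargo doc` on no-CUDA CI).
//! #[cfg(not(feature = "cuda"))]
//! fn matmul_backend() -> &'static str {
//!     "candle-cpu"
//! }
//! ```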
//!
//! The `cuda` feature requires the CUDA toolkit at build time (candle-core's
//! cudarc feature uses `dynamic-linking`). Downstream consumers who already
//! build candle with the `cuda` feature see no new system requirement.
//!
//! ## Standalone-crate rationale
//!
//! `kaio-candle` is NOT a member of the main KAIO workspace. cudarc rejects
//! builds with both `dynamic-loading` and `dynamic-linking` active. The
//! main KAIO workspace defaults to `dynamic-loading` (no CUDA toolkit
//! required to build), while candle requires `dynamic-linking`. Because
//! Cargo unions features across a workspace build, including
//! `kaio-candle` in the main workspace would break every no-CUDA CI
//! runner. A standalone crate keeps the two worlds apart.
//!
//! See `kaio-candle/README.md` for the full rationale.
//!
//! ## Device lifetime
//!
//! The [`std::sync::Arc<kaio::prelude::KaioDevice>`] you construct and pass
//! to `kaio-candle` wrapper functions is independent of the
//! `candle_core::Device` you use for your tensors. Both retain the same CUDA
//! primary context via `cuDevicePrimaryCtxRetain`; neither owns the other.
//! Drop order between them is unconstrained.
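//!
//! The drop-order guarantee follows from reference counting. Each side
//! holds its own retain on the primary context, and the context is released
//! only when the last retain goes away. The sketch below models this with a
//! plain counter. `PrimaryCtx` and `Handle` are hypothetical stand-ins for
//! the `cuDevicePrimaryCtxRetain` / `cuDevicePrimaryCtxRelease`
//! bookkeeping, not real CUDA, candle, or KAIO types.
//!
//! ```rust
//! use std::cell::Cell;
//! use std::rc::Rc;
//!
//! /// Stand-in for the driver's per-device primary-context retain count.
//! struct PrimaryCtx {
//!     retains: Cell<u32>,
//! }
//!
//! /// Stand-in for a device handle (candle's or KAIO's) that retains the
//! /// primary context on creation and releases it on drop.
//! struct Handle {
//!     ctx: Rc<PrimaryCtx>,
//! }
//!
//! impl Handle {
//!     fn retain(ctx: &Rc<PrimaryCtx>) -> Self {
//!         ctx.retains.set(ctx.retains.get() + 1); // models cuDevicePrimaryCtxRetain
//!         Handle { ctx: Rc::clone(ctx) }
//!     }
//! }
//!
//! impl Drop for Handle {
//!     fn drop(&mut self) {
//!         // Models cuDevicePrimaryCtxRelease; the context is freed at zero.
//!         self.ctx.retains.set(self.ctx.retains.get() - 1);
//!     }
//! }
//! ```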
//!
//! ## Known limitations (v0.2)
//!
//! - **Non-contiguous tensors rejected.** Call `.contiguous()?` upstream.
//! - **Non-zero storage offset rejected** (e.g. from `.narrow(...)` / `.slice(...)`).
//!   Call `.contiguous()?` to compact.
//! - **Rank-2 only.** Reshape higher-rank inputs (e.g. multi-head
//!   attention) to rank 2 before calling. The wrappers reject higher-rank
//!   inputs with an error that includes a concrete reshape hint.
//! - **CUDA Graph capture partially unblocked.** Event-based sync (Sprint
//!   7.4c) removes the prior `cuCtxSynchronize` blocker. However, full CUDA
//!   Graph capture requires non-default streams on both the candle and KAIO
//!   sides, which is not yet verified. Default-stream users should not
//!   attempt graph capture.
//! - **f32 output contract (CustomOp ops).** Every CustomOp-based op
//!   (`matmul_tc`, `matmul_tc_bf16`, `matmul_tc_async`,
//!   `matmul_tc_bf16_async`, `matmul_int4`, `matmul_int8`,
//!   `attention_tc`, `attention_tc_causal`, `attention_flash`, and
//!   `attention_flash_causal`) returns `DType::F32`, matching the
//!   kaio-ops accumulator. Direct-call ops (`qkv_project_int{4,8}`)
//!   return `DType::F16` because the fused kernel converts internally.
//! - **Bridge overhead vs direct kaio-ops calls.** Each bridge call
//!   synchronizes streams with events: two `join()` calls, each issuing
//!   `cuEventRecord` + `cuStreamWaitEvent` at its sync point. This
//!   replaced the heavier `cuCtxSynchronize` fencing used during early
//!   bridge development, but it still allocates a transient `CudaEvent`
//!   per call. KAIO's published %-of-cuBLAS numbers are measured via
//!   direct kaio-ops calls, not through the bridge, so bridge-level
//!   benchmarks may trail them.
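//!
//! The first three limitations amount to a layout check before launch.
//! Below is a simplified sketch operating on plain shape, stride, and
//! offset values (row-major, strides in elements). `check_rank2_contiguous`
//! is a hypothetical helper, and its messages are illustrative rather than
//! the wrappers' actual wording. It is also stricter than necessary for
//! size-1 dimensions.
//!
//! ```rust
//! /// Hypothetical pre-launch check: rank 2, zero storage offset, and
//! /// row-major contiguous strides (in elements).
//! fn check_rank2_contiguous(
//!     shape: &[usize],
//!     strides: &[usize],
//!     offset: usize,
//! ) -> Result<(), String> {
//!     if shape.len() != 2 {
//!         // Suggest flattening all leading dims into rows.
//!         let rows: usize = shape[..shape.len().saturating_sub(1)].iter().product();
//!         let cols = shape.last().copied().unwrap_or(1);
//!         return Err(format!(
//!             "expected a rank-2 tensor, got shape {:?}; try .reshape(({}, {}))?",
//!             shape, rows, cols
//!         ));
//!     }
//!     if offset != 0 {
//!         return Err(format!("storage offset {} != 0; call .contiguous()? first", offset));
//!     }
//!     // Row-major contiguous: unit column stride, row stride = column count.
//!     if strides.len() != 2 || strides[1] != 1 || strides[0] != shape[1] {
//!         return Err(format!("non-contiguous strides {:?}; call .contiguous()? first", strides));
//!     }
//!     Ok(())
//! }
//! ```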

#![warn(missing_docs)]

// Without the `cuda` feature, kaio-candle is an empty shell. This matches
// candle-core's own opt-in `cuda` model: consumers who forget the feature
// get a clear "function not found" when they try to call into the bridge
// (e.g. `kaio_candle::matmul_tc(...)`), rather than a lib-level
// `compile_error!` that breaks `cargo check` / `cargo doc` on no-CUDA CI
// legs. The supported no-CUDA commands are `cargo check
// --no-default-features` and `cargo doc --no-deps --no-default-features`,
// exercised by the CI `candle-no-cuda` leg — they must succeed on a
// no-CUDA-toolkit host. `cargo test` is NOT in that set: Cargo cannot
// feature-gate dev-dependencies, so the GPU tests' unconditional cudarc
// dev-dependency probes the CUDA toolkit at build time even with default
// features off.
#[cfg(feature = "cuda")]
mod bridge;

#[cfg(feature = "cuda")]
mod matmul_tc;
#[cfg(feature = "cuda")]
pub use matmul_tc::{MatmulTcOp, matmul_tc};

#[cfg(feature = "cuda")]
mod matmul_tc_bf16;
#[cfg(feature = "cuda")]
pub use matmul_tc_bf16::{MatmulTcBf16Op, matmul_tc_bf16};

#[cfg(feature = "cuda")]
mod matmul_tc_async;
#[cfg(feature = "cuda")]
pub use matmul_tc_async::{MatmulTcAsyncOp, matmul_tc_async};

#[cfg(feature = "cuda")]
mod matmul_tc_bf16_async;
#[cfg(feature = "cuda")]
pub use matmul_tc_bf16_async::{MatmulTcBf16AsyncOp, matmul_tc_bf16_async};

#[cfg(feature = "cuda")]
mod matmul_int4;
#[cfg(feature = "cuda")]
pub use matmul_int4::{MatmulInt4Op, matmul_int4};

#[cfg(feature = "cuda")]
mod attention_tc;
#[cfg(feature = "cuda")]
pub use attention_tc::{AttentionTcOp, attention_tc, attention_tc_causal};

#[cfg(feature = "cuda")]
mod attention_flash;
#[cfg(feature = "cuda")]
pub use attention_flash::{AttentionFlashOp, attention_flash, attention_flash_causal};

#[cfg(feature = "cuda")]
mod matmul_int8;
#[cfg(feature = "cuda")]
pub use matmul_int8::{MatmulInt8Op, matmul_int8};

#[cfg(feature = "cuda")]
mod qkv_project_int8;
#[cfg(feature = "cuda")]
pub use qkv_project_int8::qkv_project_int8;

#[cfg(feature = "cuda")]
mod qkv_project_int4;
#[cfg(feature = "cuda")]
pub use qkv_project_int4::qkv_project_int4;