//! # kaio-candle — Candle bridge for KAIO
//!
//! `CustomOp2` / `CustomOp3` bindings between
//! [candle](https://github.com/huggingface/candle) and the
//! [KAIO](https://github.com/dmriding/kaio) GPU kernel library.
//!
//! ## Status — v0.2.0
//!
//! Bridges 12 ops across two patterns:
//!
//! **CustomOp-based** (single-output, return `f32`):
//! - `matmul_tc` — f16 × f16 → f32 matmul via KAIO tensor-core kernel. **Backward supported.**
//! - `matmul_tc_bf16` — bf16 × bf16 → f32 matmul via KAIO tensor-core kernel. **Backward supported.** _(Sprint 9.1.3 fwd, 9.1.4 bwd)_
//! - `matmul_tc_async` — same as `matmul_tc`, `cp.async` variant (92.5% of cuBLAS sgemm at 4096² on sm_89). **Backward supported.**
//! - `matmul_tc_bf16_async` — bf16 × bf16 → f32 matmul, `cp.async` variant. **Backward supported.** _(Sprint 9.1.3 fwd, 9.1.4 bwd)_
//! - `matmul_int4` — GPTQ-style INT4 dequantize-matmul with f16 group scales. Forward-only.
//! - `matmul_int8` — W8A8 symmetric-quant matmul with scalar f32 scale (80–94 TOPS at 4096³ on sm_89). Forward-only.
//! - `attention_tc` — fused tensor-core scaled-dot-product attention. Forward-only.
//! - `attention_tc_causal` — same, with decoder causal mask. Forward-only.
//! - `attention_flash` — FlashAttention, f32 single-head self-attention, no O(seq²) memory and no seq cap. **Backward supported.** _(Sprint 9.2, dedicated backward PTX kernels)_
//! - `attention_flash_causal` — same, with decoder causal mask. **Backward supported.** _(Sprint 9.2)_
//!
//! **Direct-call** (multi-output, return `f16`, forward-only):
//! - `qkv_project_int8` — W8A16 fused tri-output QKV projection (4 inputs → 3 f16 outputs).
//! - `qkv_project_int4` — W4A16 fused tri-output QKV projection (7 inputs → 3 f16 outputs).
//!
//! CustomOp-based ops return `f32` matching the kaio-ops accumulator.
//! Direct-call ops return `f16` because the fused kernel performs the
//! `f32→f16` conversion internally as part of the projection fusion.
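//!
//! As an illustration of the `matmul_int8` contract listed above, the
//! CPU sketch below shows W8A8 symmetric quantization with a single f32
//! scale. It is not the kaio-ops kernel: `quantize_symmetric` and
//! `matmul_int8_ref` are hypothetical helpers, and treating the scalar
//! scale as `scale_a * scale_b` is an assumption.
//!
//! ```rust
//! // Per-tensor symmetric quantization (hypothetical helper):
//! // scale = max|x| / 127, q = round(x / scale), clamped to [-127, 127].
//! fn quantize_symmetric(x: &[f32]) -> (Vec<i8>, f32) {
//!     let amax = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
//!     // Avoid dividing by zero for an all-zero tensor.
//!     let scale = if amax > 0.0 { amax / 127.0 } else { 1.0 };
//!     let q: Vec<i8> = x
//!         .iter()
//!         .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
//!         .collect();
//!     (q, scale)
//! }
//!
//! // Reference matmul (hypothetical): row-major `a` is m x k, `b` is k x n.
//! // Products accumulate in i32 and are rescaled once by `scale`
//! // (assumed here to be scale_a * scale_b).
//! fn matmul_int8_ref(a: &[i8], b: &[i8], m: usize, k: usize, n: usize, scale: f32) -> Vec<f32> {
//!     let mut out = vec![0.0f32; m * n];
//!     for i in 0..m {
//!         for j in 0..n {
//!             let mut acc: i32 = 0;
//!             for p in 0..k {
//!                 acc += a[i * k + p] as i32 * b[p * n + j] as i32;
//!             }
//!             out[i * n + j] = acc as f32 * scale;
//!         }
//!     }
//!     out
//! }
//! ```
//!
//! Accumulating in i32 and rescaling once at the end keeps the inner loop
//! in integer arithmetic; the single rescale produces the f32 output that
//! the CustomOp contract describes.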
//!
//! ## Backward support
//!
//! Two backward families ship:
//!
//! **Forward-reuse (matmul TC):** `matmul_tc`, `matmul_tc_bf16`,
//! `matmul_tc_async`, and `matmul_tc_bf16_async` all implement
//! `CustomOp2::bwd()`. The backward pass computes `dA = grad @ B^T`
//! and `dB = A^T @ grad` by reusing the same forward kernel — no new
//! PTX in either precision. The f16 backward shipped in Sprint 7.4d;
//! the bf16 backward shipped in Sprint 9.1.4.
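//!
//! The forward-reuse identity (`dA = grad @ B^T`, `dB = A^T @ grad`) can be
//! checked on the CPU with a plain f32 matmul. `matmul`, `transpose`, and
//! `matmul_bwd` are hypothetical helpers for illustration; the bridge runs
//! the same two products through the KAIO tensor-core kernel, with f16 or
//! bf16 casts around it.
//!
//! ```rust
//! // Row-major matmul: `a` is m x k, `b` is k x n, result is m x n.
//! fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
//!     let mut c = vec![0.0f32; m * n];
//!     for i in 0..m {
//!         for p in 0..k {
//!             for j in 0..n {
//!                 c[i * n + j] += a[i * k + p] * b[p * n + j];
//!             }
//!         }
//!     }
//!     c
//! }
//!
//! // Transpose a row-major `rows x cols` matrix into `cols x rows`.
//! fn transpose(x: &[f32], rows: usize, cols: usize) -> Vec<f32> {
//!     let mut t = vec![0.0f32; rows * cols];
//!     for r in 0..rows {
//!         for c in 0..cols {
//!             t[c * rows + r] = x[r * cols + c];
//!         }
//!     }
//!     t
//! }
//!
//! // Backward of C = A @ B, reusing the forward matmul twice:
//! // dA = grad @ B^T (m x k), dB = A^T @ grad (k x n).
//! fn matmul_bwd(a: &[f32], b: &[f32], grad: &[f32], m: usize, k: usize, n: usize) -> (Vec<f32>, Vec<f32>) {
//!     let da = matmul(grad, &transpose(b, k, n), m, n, k);
//!     let db = matmul(&transpose(a, m, k), grad, k, m, n);
//!     (da, db)
//! }
//! ```
//!
//! With an all-ones upstream gradient (the gradient of `C.sum()`), `dA`
//! holds the row sums of `B` and `dB` the column sums of `A`, which makes
//! a quick sanity check.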
//!
//! **Dedicated backward kernels (FlashAttention):** `attention_flash`
//! and `attention_flash_causal` implement `CustomOp3::bwd()` backed by
//! three purpose-built PTX kernels in kaio-ops (a D-term preprocess,
//! a dK/dV kernel, and a dQ kernel) — attention backward has no
//! forward-reuse identity. The kernels rebuild the softmax from a
//! per-row logsumexp rather than materializing the O(seq²) probability
//! matrix, preserving FlashAttention's memory profile through the
//! backward pass. Because candle's `CustomOp3` has no fwd→bwd
//! saved-intermediate channel, each backward call re-runs the
//! stats-saving forward once to recover the logsumexp (the forward is
//! deterministic, so the recomputed stats are bit-identical); direct
//! kaio-ops callers can keep the stats buffer and skip that cost.
//! Gradients are f32 end-to-end — no dtype casts.
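//!
//! The logsumexp rebuild can be sketched one row at a time on the CPU. The
//! helpers below are hypothetical and f32-only. `d_term` follows the
//! standard FlashAttention-2 definition of the preprocess term,
//! D_i = sum_j dO_ij * O_ij; we assume the kaio-ops preprocess kernel
//! computes the same quantity.
//!
//! ```rust
//! // Per-row logsumexp, computed stably by subtracting the row max.
//! // Assumes at least one finite score in the row.
//! fn row_logsumexp(scores: &[f32]) -> f32 {
//!     let max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
//!     let sum: f32 = scores.iter().map(|s| (s - max).exp()).sum();
//!     max + sum.ln()
//! }
//!
//! // Rebuild one row of softmax probabilities from the saved logsumexp:
//! // P_ij = exp(S_ij - lse_i). No O(seq^2) probability matrix is stored.
//! fn softmax_from_lse(scores: &[f32], lse: f32) -> Vec<f32> {
//!     scores.iter().map(|s| (s - lse).exp()).collect()
//! }
//!
//! // D_i = sum_j dO_ij * O_ij for one row of the output and its gradient.
//! fn d_term(d_out_row: &[f32], out_row: &[f32]) -> f32 {
//!     d_out_row.iter().zip(out_row).map(|(g, o)| g * o).sum()
//! }
//! ```
//!
//! Only `lse` (one f32 per query row) has to survive from the forward
//! pass; each probability is recomputed as exp(S_ij - lse_i) when the
//! backward kernels need it.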
//!
//! **Numerically approximate (f16 matmul backward):** the f32 upstream
//! gradient is downcast to f16 before the tensor-core matmul, and the f32
//! output gradients are cast back to f16 to match the input dtype, as
//! candle requires. This is an initial autograd integration, not a
//! finished mixed-precision training stack.
//!
//! **Numerically approximate (bf16 matmul backward):** the same cast
//! sequence applies (f32 → bf16 → matmul → f32 → bf16). bf16's 8-bit
//! exponent keeps values representable at magnitudes where f16 would
//! overflow or underflow, but its 7-bit mantissa is less precise than
//! f16's 10-bit mantissa, so the round-trip adds more per-element
//! quantization noise in absolute terms. The dual-tolerance gradient check
//! (`rel < 1e-2 || abs < 1e-3`, the same thresholds as f16) is wide enough
//! for both precisions on the shapes tested in
//! `kaio-candle/tests/candle_gpu_roundtrip.rs`; larger shapes or
//! different magnitude regimes may require recalibration.
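//!
//! The bf16 round-trip and the dual-tolerance check can be sketched in
//! plain Rust. `to_bf16_bits`, `bf16_round_trip`, and `grad_close` are
//! hypothetical helpers: the conversion rounds to nearest-even on the upper
//! 16 bits of the f32, and using `|want|` as the relative-error denominator
//! is an assumption about how the tests compute `rel`.
//!
//! ```rust
//! // f32 -> bf16 bits with round-to-nearest-even. bf16 keeps the f32 sign,
//! // the 8-bit exponent, and the top 7 mantissa bits.
//! fn to_bf16_bits(x: f32) -> u16 {
//!     if x.is_nan() {
//!         return 0x7FC0; // canonical quiet NaN
//!     }
//!     let bits = x.to_bits();
//!     // For non-NaN inputs this add cannot overflow u32.
//!     let bias = 0x7FFF + ((bits >> 16) & 1);
//!     ((bits + bias) >> 16) as u16
//! }
//!
//! // bf16 -> f32 is exact: the dropped low 16 bits come back as zero.
//! fn bf16_round_trip(x: f32) -> f32 {
//!     f32::from_bits((to_bf16_bits(x) as u32) << 16)
//! }
//!
//! // Dual tolerance: pass if either the relative or the absolute error is small.
//! fn grad_close(got: f32, want: f32) -> bool {
//!     let abs = (got - want).abs();
//!     let rel = abs / want.abs().max(f32::MIN_POSITIVE);
//!     rel < 1e-2 || abs < 1e-3
//! }
//! ```
//!
//! Under this rounding, 1 + 2^-8 rounds back to exactly 1.0 (a visible
//! consequence of the 7-bit mantissa), while 1e30 stays finite even though
//! it is far beyond f16's largest finite value of 65504.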
//!
//! Remaining ops are forward-only. `attention_tc` / `attention_tc_causal`
//! are short-sequence inference ops; for training, use `attention_flash`,
//! which supports backward and has no `seq_k` cap. The quantized ops are
//! inference-only by design.
//!
//! ## Build requirements
//!
//! This crate MUST be built with the `cuda` feature enabled:
//!
//! ```toml
//! [dependencies]
//! kaio-candle = { version = "0.2", features = ["cuda"] }
//! ```
//!
//! The default (feature-less) build produces an empty shell — attempting
//! to call a bridge function like `kaio_candle::matmul_tc(...)` surfaces
//! a "function not found" compile error pointing at the missing feature.
//! This matches candle-core's own opt-in model and keeps `cargo doc` /
//! no-CUDA CI legs working without the toolkit.
//!
//! The `cuda` feature requires the CUDA toolkit at build time (candle-core's
//! cudarc feature uses `dynamic-linking`). Downstream consumers who already
//! build candle with the `cuda` feature see no new system requirement.
//!
//! ## Standalone-crate rationale
//!
//! `kaio-candle` is NOT a member of the main KAIO workspace. cudarc rejects
//! builds with `dynamic-loading` and `dynamic-linking` both enabled. The
//! main KAIO workspace defaults to `dynamic-loading` (no CUDA toolkit
//! required to build), while candle requires `dynamic-linking`. Cargo
//! unifies features across a workspace build, so including `kaio-candle` in
//! the main workspace would break every no-CUDA CI runner. A standalone
//! crate keeps the two worlds apart.
//!
//! See `kaio-candle/README.md` for the full rationale.
//!
//! ## Device lifetime
//!
//! The [`std::sync::Arc<kaio::prelude::KaioDevice>`] you construct and pass
//! to `kaio-candle` wrapper functions is independent of the
//! `candle_core::Device` you use for your tensors. Both retain the same CUDA
//! primary context via `cuDevicePrimaryCtxRetain`; neither owns the other.
//! Drop order between them is unconstrained.
//!
//! ## Known limitations (v0.2)
//!
//! - **Non-contiguous tensors rejected.** Call `.contiguous()?` upstream.
//! - **Non-zero storage offset rejected** (e.g. from `.narrow(...)` / `.slice(...)`).
//! Call `.contiguous()?` to compact.
//! - **Rank-2 only.** For multi-head attention, reshape inputs to rank-2
//! before calling; wrappers return an error with a concrete reshape hint
//! for higher-rank inputs.
//! - **CUDA Graph capture partially unblocked.** Event-based sync (Sprint
//! 7.4c) removes the prior `cuCtxSynchronize` blocker. However, full CUDA
//! Graph capture requires non-default streams on both the candle and KAIO
//! sides, a configuration that has not yet been verified. Users on the
//! default stream should not attempt graph capture.
//! - **f32 output contract (CustomOp ops).** All CustomOp-based ops
//! (`matmul_tc*`, `matmul_int4`, `matmul_int8`, `attention_tc*`, and
//! `attention_flash*`) return `DType::F32`, matching the kaio-ops
//! accumulator. Direct-call ops (`qkv_project_int{4,8}`) return
//! `DType::F16` because the fused kernel converts internally.
//! - **Bridge overhead vs direct-call bench numbers.** Each bridge call
//! synchronizes streams with events: two `join()` calls, one per sync
//! point, each issuing `cuEventRecord` + `cuStreamWaitEvent`. This replaced
//! the heavier `cuCtxSynchronize` fencing used during early bridge
//! development, but it still allocates a transient `CudaEvent` per call.
//! KAIO's published %-of-cuBLAS numbers are measured via direct kaio-ops
//! calls, not through the bridge, so throughput measured through the
//! bridge may be lower.
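//!
//! The layout rules in the list above (rank 2, row-major contiguous, zero
//! storage offset) amount to a check like the one below, written against
//! plain shape/stride/offset values rather than candle's layout types.
//! `check_bridge_layout` and its error messages are hypothetical, not the
//! wrappers' actual code.
//!
//! ```rust
//! // Hypothetical validation mirroring the documented bridge restrictions.
//! fn check_bridge_layout(dims: &[usize], strides: &[usize], offset: usize) -> Result<(), String> {
//!     if dims.len() != 2 {
//!         return Err(format!(
//!             "expected a rank-2 tensor, got rank {} (shape {:?}); reshape to rank-2 before calling",
//!             dims.len(),
//!             dims
//!         ));
//!     }
//!     if strides.len() != dims.len() {
//!         return Err("stride/shape rank mismatch".into());
//!     }
//!     // Row-major contiguity: each dim's stride equals the product of the
//!     // dims after it. Size-1 dims are skipped; their stride never matters.
//!     let mut expected = 1usize;
//!     for (d, s) in dims.iter().zip(strides).rev() {
//!         if *d > 1 && *s != expected {
//!             return Err(format!("non-contiguous strides {:?}; call .contiguous()? first", strides));
//!         }
//!         expected *= *d;
//!     }
//!     if offset != 0 {
//!         return Err(format!("non-zero storage offset {}; compact the tensor first", offset));
//!     }
//!     Ok(())
//! }
//! ```
//!
//! A transposed view (strides `[1, rows]`) fails the contiguity test, while
//! a row-range view from `.narrow(0, ...)` can pass it and still carry a
//! non-zero offset, which is why the two checks are separate.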
// Without the `cuda` feature, kaio-candle is an empty shell. This matches
// candle-core's own opt-in `cuda` model: consumers who forget the feature
// get a clear "function not found" when they try to call into the bridge
// (e.g. `kaio_candle::matmul_tc(...)`), rather than a lib-level
// `compile_error!` that breaks `cargo check` / `cargo doc` on no-CUDA CI
// legs. The supported no-CUDA commands are `cargo check
// --no-default-features` and `cargo doc --no-deps --no-default-features`,
// exercised by the CI `candle-no-cuda` leg — they must succeed on a
// no-CUDA-toolkit host. `cargo test` is NOT in that set: Cargo cannot
// feature-gate dev-dependencies, so the GPU tests' unconditional cudarc
// dev-dependency probes the CUDA toolkit at build time even with default
// features off.
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use ;
pub use qkv_project_int8;
pub use qkv_project_int4;