Skip to main content

baracuda_kernels/quantize/
mod.rs

1//! Quantization op family — Category P.
2//!
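//! Phase 8's base quantize / dequantize ops split across two parallel
//! milestones (milestones 8.3 / 8.4 and later phases are described next
//! to their module declarations below):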
4//!
5//! - **Milestone 8.1** (sibling): per-tensor + per-channel quantize /
6//!   dequantize plus `fake_quantize`. Owns
7//!   `crates/baracuda-kernels-sys/kernels/quantize/per_tensor.cu` /
8//!   `per_channel.cu` / `fake_quantize.cu` and the Rust plans for those
9//!   ops in this `quantize/` module.
10//!
11//! - **Milestone 8.2** (this work): per-token + per-group quantize /
12//!   dequantize plus their STE backwards. Used by LLM activation (W8A8
13//!   per-row) and weight (GPTQ-style INT4 per-group, `g=128`) quant.
14//!   Owns
15//!   `crates/baracuda-kernels-sys/kernels/quantize/per_token.cu` /
16//!   `per_group.cu` and the plans in this module.
17//!
18//! The two milestones share **append-only** edits to this file, to
19//! `crate::lib`'s re-exports, and to `baracuda-kernels-sys/src/lib.rs`.
20//! No existing entry is rewritten.
21//!
22//! Trailblazer dtype coverage: input FP ∈ {`f32`, `f64`, `f16`, `bf16`};
23//! output int ∈ {`s8`, `u8`}. Sub-byte packed types (`s4` / `u4`) are
24//! deferred to a later milestone.
25//!
26//! Backward convention is the Straight-Through Estimator (STE):
27//! `dx = (dy / scale) * 1[qmin < q < qmax]`. The in-range mask is
28//! recomputed inside the BW kernel from the saved input — callers must
29//! retain the input tensor for autograd (which they would do anyway).
30
31// --- Milestone 8.1 modules (per-tensor + per-channel + fake_quantize). ---
32pub mod dequantize_per_channel;
33pub mod dequantize_per_channel_backward;
34pub mod dequantize_per_tensor;
35pub mod dequantize_per_tensor_backward;
36pub mod fake_quantize;
37pub mod fake_quantize_backward;
38pub mod per_channel;
39pub mod per_channel_backward;
40pub mod per_tensor;
41pub mod per_tensor_backward;
42
43// --- Milestone 8.2 modules (per-token + per-group). Full coverage of
44//     FW + STE BW + dequant FW + straight-through dequant BW for both
45//     per-token and per-group quant. ---
46pub mod dequantize_per_group;
47pub mod dequantize_per_group_backward;
48pub mod dequantize_per_token;
49pub mod dequantize_per_token_backward;
50pub mod per_group;
51pub mod per_group_backward;
52pub mod per_token;
53pub mod per_token_backward;
54
55// --- Milestone 8.3 modules — composing ops on top of 8.1 / 8.2. ----
56pub mod dynamic_range;
57pub mod quantized_linear;
58
59// --- Phase 45 module — SmoothQuant linear (pure Rust composition over
60//     the existing `quantized_linear_w8a8` kernel; zero new CUDA). ----
61pub mod smoothquant;
62
63// --- Milestone 8.4 module — GGUF block-format quant family (vendored
64//     from llama.cpp via fuel-cuda-kernels). Full block-format coverage
65//     for both dequant and MMVQ. Phase 11.4 added a bespoke Q8_K MMVQ
66//     (upstream llama.cpp / Fuel ship only Q8_K dequant). ----
67pub mod gguf;
68
69// --- Phase 53 — bitsandbytes NF4 (NormalFloat 4-bit) dequant + GEMV.
70//     Vendored kernel sources at
71//     `crates/baracuda-kernels-sys/vendor/bitsandbytes/` (MIT,
72//     Dettmers et al. arXiv:2305.14314). Gated behind the `bnb_nf4`
73//     cargo feature; the Rust plan types compile unconditionally so
74//     the public API surface is stable. ----
75pub mod nf4;
76
77pub use dequantize_per_channel::{
78    DequantizePerChannelArgs, DequantizePerChannelDescriptor, DequantizePerChannelPlan,
79};
80pub use dequantize_per_channel_backward::{
81    DequantizePerChannelBackwardArgs, DequantizePerChannelBackwardDescriptor,
82    DequantizePerChannelBackwardPlan,
83};
84pub use dequantize_per_tensor::{
85    DequantizePerTensorArgs, DequantizePerTensorDescriptor, DequantizePerTensorPlan,
86};
87pub use dequantize_per_tensor_backward::{
88    DequantizePerTensorBackwardArgs, DequantizePerTensorBackwardDescriptor,
89    DequantizePerTensorBackwardPlan,
90};
91pub use fake_quantize::{FakeQuantizeArgs, FakeQuantizeDescriptor, FakeQuantizePlan};
92pub use fake_quantize_backward::{
93    FakeQuantizeBackwardArgs, FakeQuantizeBackwardDescriptor, FakeQuantizeBackwardPlan,
94};
95pub use per_channel::{QuantizePerChannelArgs, QuantizePerChannelDescriptor, QuantizePerChannelPlan};
96pub use per_channel_backward::{
97    QuantizePerChannelBackwardArgs, QuantizePerChannelBackwardDescriptor,
98    QuantizePerChannelBackwardPlan,
99};
100pub use per_tensor::{QuantizePerTensorArgs, QuantizePerTensorDescriptor, QuantizePerTensorPlan};
101pub use per_tensor_backward::{
102    QuantizePerTensorBackwardArgs, QuantizePerTensorBackwardDescriptor,
103    QuantizePerTensorBackwardPlan,
104};
105
106pub use per_token::{QuantizePerTokenArgs, QuantizePerTokenDescriptor, QuantizePerTokenPlan};
107pub use per_token_backward::{
108    QuantizePerTokenBackwardArgs, QuantizePerTokenBackwardDescriptor, QuantizePerTokenBackwardPlan,
109};
110pub use dequantize_per_token::{
111    DequantizePerTokenArgs, DequantizePerTokenDescriptor, DequantizePerTokenPlan,
112};
113pub use dequantize_per_token_backward::{
114    DequantizePerTokenBackwardArgs, DequantizePerTokenBackwardDescriptor,
115    DequantizePerTokenBackwardPlan,
116};
117pub use per_group::{QuantizePerGroupArgs, QuantizePerGroupDescriptor, QuantizePerGroupPlan};
118pub use per_group_backward::{
119    QuantizePerGroupBackwardArgs, QuantizePerGroupBackwardDescriptor, QuantizePerGroupBackwardPlan,
120};
121pub use dequantize_per_group::{
122    DequantizePerGroupArgs, DequantizePerGroupDescriptor, DequantizePerGroupPlan,
123};
124pub use dequantize_per_group_backward::{
125    DequantizePerGroupBackwardArgs, DequantizePerGroupBackwardDescriptor,
126    DequantizePerGroupBackwardPlan,
127};
128
129// --- Milestone 8.3 exports ---
130pub use dynamic_range::{
131    DynamicRangeMode, DynamicRangeQuantizeArgs, DynamicRangeQuantizeDescriptor,
132    DynamicRangeQuantizePlan, DynamicRangeScope,
133};
134pub use quantized_linear::{
135    QuantizedLinearArgs, QuantizedLinearDescriptor, QuantizedLinearPlan,
136};
137
138// --- Phase 45 export — SmoothQuant linear (pure Rust composition). ---
139pub use smoothquant::{
140    SmoothQuantLinearArgs, SmoothQuantLinearDescriptor, SmoothQuantLinearPlan,
141};
142
143// --- Milestone 8.4 exports — GGUF block-format dequant + MMVQ ---
144pub use gguf::{
145    BlockQ2K, BlockQ3K, BlockQ4_0, BlockQ4_1, BlockQ4K, BlockQ5_0, BlockQ5_1, BlockQ5K, BlockQ6K,
146    BlockQ8_0, BlockQ8K, GgufDequantizeArgs, GgufDequantizeDescriptor, GgufDequantizePlan,
147    GgufMmvqArgs, GgufMmvqDescriptor, GgufMmvqPlan,
148};
149
150// --- Phase 20.1 export — GGUF batched MMVQ × N-experts (general-purpose
151//     routing primitive). 33 quant FFI symbols + 3 pure-FP FFI symbols. ---
152pub use gguf::{
153    GgufMmvqBatchedActivation, GgufMmvqBatchedArgs, GgufMmvqBatchedDescriptor,
154    GgufMmvqBatchedFormat, GgufMmvqBatchedPlan,
155};
156
157// --- Phase 33 export — multi-M MMVQ via Q8_1 activation staging (prefill
158//     speedup). Q8_0 weights only; 4 compile-time M sizes (1/2/4/8). ---
159pub use gguf::{GgufMmvqMultiMArgs, GgufMmvqMultiMDescriptor, GgufMmvqMultiMPlan};
160
161// --- Phase 53 exports — bitsandbytes NF4 dequant + GEMV (QLoRA
162//     inference path; behind `bnb_nf4` feature). Plan types are
163//     always exported; the FFI dispatch is feature-gated inside the
164//     plan's `run()` method. ----
165pub use nf4::{
166    Nf4Activation, Nf4DequantizeArgs, Nf4DequantizePlan, Nf4Descriptor, Nf4MmvqArgs,
167    Nf4MmvqMultiMArgs, Nf4MmvqMultiMDescriptor, Nf4MmvqMultiMPlan, Nf4MmvqPlan, NF4_CODEBOOK,
168};
169
170use baracuda_cutlass::{Error, Result};
171
172/// Shared status-code mapper (mirrors `indexing::gather::map_status` and
173/// `segment::map_status`).
174pub(crate) fn map_status(code: i32) -> Result<()> {
175    match code {
176        0 => Ok(()),
177        1 => Err(Error::MisalignedOperand),
178        2 => Err(Error::InvalidProblem(
179            "baracuda-kernels-sys reported invalid problem",
180        )),
181        3 => Err(Error::Unsupported(
182            "baracuda-kernels-sys reported unsupported configuration",
183        )),
184        4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
185        n => Err(Error::CutlassInternal(n)),
186    }
187}
188
189/// Element-kind check shared across the per-token / per-group plans.
190/// Returns Ok if `tin_kind` is one of the four supported FP dtypes.
191pub(crate) fn validate_input_element(
192    tin_kind: baracuda_kernels_types::ElementKind,
193    plan_name: &'static str,
194) -> Result<()> {
195    use baracuda_kernels_types::ElementKind;
196    if !matches!(
197        tin_kind,
198        ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
199    ) {
200        return Err(Error::Unsupported(plan_name));
201    }
202    Ok(())
203}
204
205/// Output-kind check shared across the per-token / per-group plans.
206/// Returns Ok if `tout_kind` is `S8` or `U8`.
207pub(crate) fn validate_output_element(
208    tout_kind: baracuda_kernels_types::ElementKind,
209    plan_name: &'static str,
210) -> Result<()> {
211    use baracuda_kernels_types::ElementKind;
212    if !matches!(tout_kind, ElementKind::S8 | ElementKind::U8) {
213        return Err(Error::Unsupported(plan_name));
214    }
215    Ok(())
216}
217
218/// Default `qmin` / `qmax` for an output integer dtype. Today wired for
219/// the two trailblazer output kinds — [`baracuda_kernels_types::S8`]
220/// (`[-128, 127]`) and [`baracuda_kernels_types::U8`] (`[0, 255]`).
221///
222/// Returns `None` for unsupported kinds; the plan's `select()` returns
223/// `Error::Unsupported` in that case.
224#[inline]
225pub fn default_q_range(out_kind: baracuda_kernels_types::ElementKind) -> Option<(i32, i32)> {
226    use baracuda_kernels_types::ElementKind;
227    match out_kind {
228        ElementKind::S8 => Some((-128, 127)),
229        ElementKind::U8 => Some((0, 255)),
230        _ => None,
231    }
232}