Skip to main content

baracuda_kernels/quantize/
quantized_linear.rs

1//! `quantized_linear` — fused W8A8 quantized matmul (Phase 8.3).
2//!
3//! The canonical inference-time LLM matmul recipe:
4//!
5//! 1. Quantize the FP activation per-token (dynamic-range, symmetric).
6//! 2. Accumulate the int8 × int8 GEMM into int32.
7//! 3. Dequantize the int32 acc by `scale_a[m] · scale_w[n]` and store as FP.
8//!
9//! Used by SmoothQuant, AWQ-runtime, and most production W8A8 LLM
10//! kernels. The Plan owns the orchestration; the underlying bespoke
11//! kernel fuses the int8 mma + dequant + FP store as one launch.
12//!
13//! ## Layout
14//!
15//! - `activation`   : `[M, K]` FP (row-major).
16//! - `weight_q`     : `[C_out, K]` int8 (row-major — one row per output channel).
17//! - `weight_scale` : `[C_out]` FP (per-output-channel, saved when the
18//!   weight was quantized offline).
19//! - `output`       : `[M, C_out]` FP.
20//!
21//! `weight_q` is `[C_out, K]` rather than `[K, C_out]` so the inner-K
22//! reduction reads contiguous K spans from both the activation row and
23//! the weight row — the natural layout for the linear-layer convention
24//! `y = x · W^T` where `W` is the weight matrix in `[C_out, C_in]` form
25//! (PyTorch `nn.Linear.weight` layout).
26//!
27//! ## Trailblazer scope
28//!
29//! - Symmetric + per-token activation quantization (composes
30//!   [`super::DynamicRangeQuantizePlan`]).
31//! - Per-output-channel weight scale (caller supplies, computed offline).
32//! - `TIn ∈ {f32, f64}` activation + output; weight = `S8`.
33//! - **Naive kernel** (one thread per output cell, register-only int32
34//!   accumulator) — correctness scaffold, not throughput-optimized.
35//!   Tiled-smem / mma.sync optimizations land in a perf milestone.
36//! - **Inference-only** — no backward. The W8A8 path is forward-only
37//!   by convention; if a downstream needs gradients, they should use
38//!   [`super::FakeQuantizePlan`] for QAT (quant-aware training) and run
39//!   a normal FP matmul.
40
41use core::ffi::c_void;
42use core::marker::PhantomData;
43
44use baracuda_cutlass::{Error, Result};
45use baracuda_driver::Stream;
46use baracuda_kernels_types::{
47    ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
48    PlanPreference, PrecisionGuarantee, QuantizeKind, S8, TensorMut, TensorRef, Workspace,
49};
50
51use super::map_status;
52
53/// Descriptor for a `quantized_linear` op.
54#[derive(Copy, Clone, Debug)]
55pub struct QuantizedLinearDescriptor {
56    /// Number of token rows in the activation (and rows of the output).
57    pub m: i32,
58    /// Number of output channels (rows of `weight_q`, cols of output).
59    pub c_out: i32,
60    /// Inner reduction dim (cols of `activation` and `weight_q`).
61    pub k: i32,
62    /// Activation quantization range lower bound (symmetric: `-127`).
63    pub q_min: i32,
64    /// Activation quantization range upper bound (symmetric: `127`).
65    pub q_max: i32,
66    /// Activation FP element kind. Must match `TIn::KIND`.
67    pub activation_element: ElementKind,
68    /// Weight int element kind. Today wired only for `S8`.
69    pub weight_element: ElementKind,
70}
71
72/// Args bundle for a `quantized_linear` launch.
73///
74/// The caller supplies the already-quantized weight + its per-channel
75/// scale (offline-computed). The activation is FP; per-token
76/// activation quantization happens inside [`QuantizedLinearPlan::run`]
77/// via an internally orchestrated [`super::DynamicRangeQuantizePlan`]
78/// pass.
79///
80/// `act_q_scratch` and `act_scale_scratch` are caller-owned scratch
81/// buffers for the quantized activation + computed per-row activation
82/// scale. They are part of the args bundle (not workspace) so callers
83/// can reuse them across launches without re-allocation — the Plan's
84/// `workspace_size()` returns 0.
85pub struct QuantizedLinearArgs<'a, TIn: Element, TWQ: IntElement> {
86    /// FP activation `[M, K]`.
87    pub activation: TensorRef<'a, TIn, 2>,
88    /// Already-quantized int8 weight `[C_out, K]`.
89    pub weight_q: TensorRef<'a, TWQ, 2>,
90    /// Per-output-channel weight scale `[C_out]` in FP.
91    pub weight_scale: TensorRef<'a, TIn, 1>,
92    /// FP output `[M, C_out]`.
93    pub output: TensorMut<'a, TIn, 2>,
94    /// Scratch for the per-token quantized activation `[M, K]` in int8.
95    /// Caller-owned; reused across launches.
96    pub act_q_scratch: TensorMut<'a, S8, 2>,
97    /// Scratch for the per-token activation scale `[M]` in FP.
98    /// Caller-owned; reused across launches. Populated by the
99    /// internally orchestrated dynamic-range pass.
100    pub act_scale_scratch: TensorMut<'a, TIn, 1>,
101}
102
103/// `quantized_linear` plan (W8A8 fused).
104///
105/// Composes two passes internally:
106///
107/// 1. **Activation quantize** — per-token symmetric dynamic-range
108///    quantization, fused max-abs reduce + scale compute + quantize.
109///    Implemented by the same `dynamic_range_quantize_per_token_sym`
110///    kernel that backs [`super::DynamicRangeQuantizePlan`].
111/// 2. **Quantized matmul** — fused int8 GEMM + per-row/per-col
112///    dequantize + FP store. Implemented by the bespoke
113///    `quantized_linear_w8a8` kernel.
114///
115/// Both passes share the same stream and execute back-to-back; the Plan
116/// does NOT own an internal `DynamicRangeQuantizePlan` instance — it
117/// invokes the FFI directly to keep the launch ordering explicit.
118///
119/// **When to use**: W8A8 inference matmul (SmoothQuant / AWQ-runtime
120/// style). Inference-only — no BW; for QAT use
121/// [`FakeQuantizePlan`](crate::FakeQuantizePlan) + normal FP matmul.
122///
123/// **Dtypes (trailblazer)**: `TIn (act/out) ∈ {f32, f64}`; `TWQ = S8`.
124/// f16 / bf16 activations and u8 weight not yet wired.
125///
126/// **Shape limits**: `activation` `[M, K]`; `weight_q` `[C_out, K]`;
127/// `weight_scale` `[C_out]`; `output` `[M, C_out]`. The W4 layout
128/// `[C_out, K]` matches `y = x · W^T` (PyTorch `nn.Linear.weight`).
129///
130/// **Workspace**: zero in [`Workspace`]. Caller supplies
131/// `act_q_scratch` `[M, K]` (int8) and `act_scale_scratch` `[M]`
132/// (FP) in [`QuantizedLinearArgs`] for the fused activation-quant
133/// pass.
134///
135/// **Precision guarantee**: deterministic, bit-stable. Naive kernel
136/// (one thread per output cell, register-only int32 acc) for
137/// correctness; tiled-smem / mma.sync optimizations land in a perf
138/// milestone — current variant is **correctness-scaffold, not
139/// throughput-optimized**.
140pub struct QuantizedLinearPlan<TIn: Element, TWQ: IntElement> {
141    desc: QuantizedLinearDescriptor,
142    sku: KernelSku,
143    _marker: PhantomData<(TIn, TWQ)>,
144}
145
146impl<TIn: Element, TWQ: IntElement> QuantizedLinearPlan<TIn, TWQ> {
147    /// Pick a kernel for `desc`.
148    pub fn select(
149        _stream: &Stream,
150        desc: &QuantizedLinearDescriptor,
151        _pref: PlanPreference,
152    ) -> Result<Self> {
153        if desc.activation_element != TIn::KIND {
154            return Err(Error::Unsupported(
155                "QuantizedLinearPlan: descriptor activation_element != TIn",
156            ));
157        }
158        if desc.weight_element != TWQ::KIND {
159            return Err(Error::Unsupported(
160                "QuantizedLinearPlan: descriptor weight_element != TWQ",
161            ));
162        }
163        // Trailblazer dtype matrix: TIn ∈ {f32, f64}, TWQ = S8.
164        if !matches!(TIn::KIND, ElementKind::F32 | ElementKind::F64) {
165            return Err(Error::Unsupported(
166                "QuantizedLinearPlan: 8.3 trailblazer only wires f32 / f64 \
167                 activation (f16 / bf16 deferred)",
168            ));
169        }
170        if TWQ::KIND != ElementKind::S8 {
171            return Err(Error::Unsupported(
172                "QuantizedLinearPlan: 8.3 trailblazer only wires S8 weight \
173                 (U8 deferred)",
174            ));
175        }
176        if desc.m < 0 || desc.c_out < 0 || desc.k < 0 {
177            return Err(Error::InvalidProblem(
178                "QuantizedLinearPlan: m, c_out, k must be non-negative",
179            ));
180        }
181        if desc.q_max <= 0 {
182            return Err(Error::InvalidProblem(
183                "QuantizedLinearPlan: q_max must be > 0",
184            ));
185        }
186        if desc.q_max < desc.q_min {
187            return Err(Error::InvalidProblem(
188                "QuantizedLinearPlan: q_max < q_min",
189            ));
190        }
191        if desc.m > 65535 {
192            return Err(Error::Unsupported(
193                "QuantizedLinearPlan: M > 65535 — the internal dynamic-range pass \
194                 uses one block per row and would exceed the legacy grid limit \
195                 (lift when row tiling lands)",
196            ));
197        }
198        let sku = build_sku::<TIn, TWQ>(QuantizeKind::QuantizedLinear);
199        Ok(Self {
200            desc: *desc,
201            sku,
202            _marker: PhantomData,
203        })
204    }
205
206    /// Validate args at run time.
207    pub fn can_implement(&self, args: &QuantizedLinearArgs<'_, TIn, TWQ>) -> Result<()> {
208        if args.activation.shape != [self.desc.m, self.desc.k] {
209            return Err(Error::InvalidProblem(
210                "QuantizedLinearPlan: activation shape != [M, K]",
211            ));
212        }
213        if args.weight_q.shape != [self.desc.c_out, self.desc.k] {
214            return Err(Error::InvalidProblem(
215                "QuantizedLinearPlan: weight_q shape != [C_out, K]",
216            ));
217        }
218        if args.weight_scale.shape != [self.desc.c_out] {
219            return Err(Error::InvalidProblem(
220                "QuantizedLinearPlan: weight_scale shape != [C_out]",
221            ));
222        }
223        if args.output.shape != [self.desc.m, self.desc.c_out] {
224            return Err(Error::InvalidProblem(
225                "QuantizedLinearPlan: output shape != [M, C_out]",
226            ));
227        }
228        if args.act_q_scratch.shape != [self.desc.m, self.desc.k] {
229            return Err(Error::InvalidProblem(
230                "QuantizedLinearPlan: act_q_scratch shape != [M, K]",
231            ));
232        }
233        if args.act_scale_scratch.shape != [self.desc.m] {
234            return Err(Error::InvalidProblem(
235                "QuantizedLinearPlan: act_scale_scratch shape != [M]",
236            ));
237        }
238        Ok(())
239    }
240
241    /// Workspace bytes — none. Scratch buffers are caller-owned via the
242    /// args bundle (`act_q_scratch` + `act_scale_scratch`), allowing
243    /// reuse across launches.
244    #[inline]
245    pub fn workspace_size(&self) -> usize {
246        0
247    }
248
249    /// Identity of the selected kernel.
250    #[inline]
251    pub fn sku(&self) -> KernelSku {
252        self.sku
253    }
254
255    /// Numerical guarantees.
256    #[inline]
257    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
258        self.sku.precision_guarantee
259    }
260
261    /// Launch.
262    pub fn run(
263        &self,
264        stream: &Stream,
265        _workspace: Workspace<'_>,
266        args: QuantizedLinearArgs<'_, TIn, TWQ>,
267    ) -> Result<()> {
268        self.can_implement(&args)?;
269        if (self.desc.m as i64) * (self.desc.c_out as i64) == 0
270            || self.desc.k == 0
271        {
272            return Ok(());
273        }
274
275        let stream_ptr = stream.as_raw() as *mut c_void;
276
277        // ---- Pass 1: dynamic-range per-token symmetric quantize the
278        //              FP activation into the int8 scratch. -----------
279        let act_ptr = args.activation.data.as_raw().0 as *const c_void;
280        let act_scale_ptr = args.act_scale_scratch.data.as_raw().0 as *mut c_void;
281        let act_q_ptr = args.act_q_scratch.data.as_raw().0 as *mut c_void;
282        let drq_status = match TIn::KIND {
283            ElementKind::F32 => unsafe {
284                baracuda_kernels_sys::baracuda_kernels_dynamic_range_quantize_per_token_sym_f32_s8_run(
285                    self.desc.m,
286                    self.desc.k,
287                    self.desc.q_min,
288                    self.desc.q_max,
289                    act_ptr, act_scale_ptr, act_q_ptr,
290                    core::ptr::null_mut(), 0, stream_ptr,
291                )
292            },
293            ElementKind::F64 => unsafe {
294                baracuda_kernels_sys::baracuda_kernels_dynamic_range_quantize_per_token_sym_f64_s8_run(
295                    self.desc.m,
296                    self.desc.k,
297                    self.desc.q_min,
298                    self.desc.q_max,
299                    act_ptr, act_scale_ptr, act_q_ptr,
300                    core::ptr::null_mut(), 0, stream_ptr,
301                )
302            },
303            _ => {
304                return Err(Error::Unsupported(
305                    "QuantizedLinearPlan::run reached unsupported TIn at \
306                     activation-quantize pass (select should have caught)",
307                ))
308            }
309        };
310        map_status(drq_status)?;
311
312        // ---- Pass 2: fused quantized-linear (int8 GEMM + dequant + FP store). ----
313        let weight_ptr = args.weight_q.data.as_raw().0 as *const c_void;
314        let act_q_const = args.act_q_scratch.data.as_raw().0 as *const c_void;
315        let act_scale_const = args.act_scale_scratch.data.as_raw().0 as *const c_void;
316        let w_scale_ptr = args.weight_scale.data.as_raw().0 as *const c_void;
317        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
318        let ql_status = match TIn::KIND {
319            ElementKind::F32 => unsafe {
320                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f32_run(
321                    self.desc.m,
322                    self.desc.c_out,
323                    self.desc.k,
324                    weight_ptr, act_q_const,
325                    act_scale_const, w_scale_ptr,
326                    out_ptr,
327                    core::ptr::null_mut(), 0, stream_ptr,
328                )
329            },
330            ElementKind::F64 => unsafe {
331                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f64_run(
332                    self.desc.m,
333                    self.desc.c_out,
334                    self.desc.k,
335                    weight_ptr, act_q_const,
336                    act_scale_const, w_scale_ptr,
337                    out_ptr,
338                    core::ptr::null_mut(), 0, stream_ptr,
339                )
340            },
341            _ => {
342                return Err(Error::Unsupported(
343                    "QuantizedLinearPlan::run reached unsupported TIn at \
344                     quantized-linear pass (select should have caught)",
345                ))
346            }
347        };
348        map_status(ql_status)
349    }
350}
351
352/// Build the [`KernelSku`] for a quantized-linear plan.
353fn build_sku<TIn: Element, TWQ: IntElement>(op: QuantizeKind) -> KernelSku {
354    let precision_guarantee = PrecisionGuarantee {
355        math_precision: if TIn::KIND == ElementKind::F64 {
356            MathPrecision::F64
357        } else {
358            MathPrecision::F32
359        },
360        accumulator: ElementKind::F32,
361        // Deterministic: int32 reduction + per-row block scan + FP
362        // dequant; no atomics. Different (M, K) tile orderings would
363        // change the float sum-cancellation pattern in a later tiled
364        // kernel, but the naive trailblazer is fully serial within a
365        // thread and bit-stable on the same hardware.
366        bit_stable_on_same_hardware: true,
367        deterministic: true,
368    };
369    KernelSku {
370        category: OpCategory::Quantization,
371        op: op as u16,
372        element: TIn::KIND,
373        aux_element: Some(TWQ::KIND),
374        layout: None,
375        epilogue: None,
376        arch: ArchSku::Sm80,
377        backend: BackendKind::Bespoke,
378        precision_guarantee,
379    }
380}