baracuda_kernels/quantize/quantized_linear.rs
1//! `quantized_linear` — fused W8A8 quantized matmul (Phase 8.3).
2//!
3//! The canonical inference-time LLM matmul recipe:
4//!
5//! 1. Quantize the FP activation per-token (dynamic-range, symmetric).
6//! 2. Accumulate the int8 × int8 GEMM into int32.
7//! 3. Dequantize the int32 acc by `scale_a[m] · scale_w[n]` and store as FP.
8//!
9//! Used by SmoothQuant, AWQ-runtime, and most production W8A8 LLM
10//! kernels. The Plan owns the orchestration; the underlying bespoke
11//! kernel fuses the int8 mma + dequant + FP store as one launch.
12//!
13//! ## Layout
14//!
15//! - `activation` : `[M, K]` FP (row-major).
16//! - `weight_q` : `[C_out, K]` int8 (row-major — one row per output channel).
17//! - `weight_scale` : `[C_out]` FP (per-output-channel, saved when the
18//! weight was quantized offline).
19//! - `output` : `[M, C_out]` FP.
20//!
21//! `weight_q` is `[C_out, K]` rather than `[K, C_out]` so the inner-K
22//! reduction reads contiguous K spans from both the activation row and
23//! the weight row — the natural layout for the linear-layer convention
24//! `y = x · W^T` where `W` is the weight matrix in `[C_out, C_in]` form
25//! (PyTorch `nn.Linear.weight` layout).
26//!
27//! ## Trailblazer scope
28//!
29//! - Symmetric + per-token activation quantization (composes
30//! [`super::DynamicRangeQuantizePlan`]).
31//! - Per-output-channel weight scale (caller supplies, computed offline).
32//! - `TIn ∈ {f32, f64}` activation + output; weight = `S8`.
33//! - **Naive kernel** (one thread per output cell, register-only int32
34//! accumulator) — correctness scaffold, not throughput-optimized.
35//! Tiled-smem / mma.sync optimizations land in a perf milestone.
36//! - **Inference-only** — no backward. The W8A8 path is forward-only
37//! by convention; if a downstream needs gradients, they should use
38//! [`super::FakeQuantizePlan`] for QAT (quant-aware training) and run
39//! a normal FP matmul.
40
41use core::ffi::c_void;
42use core::marker::PhantomData;
43
44use baracuda_cutlass::{Error, Result};
45use baracuda_driver::Stream;
46use baracuda_kernels_types::{
47 ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
48 PlanPreference, PrecisionGuarantee, QuantizeKind, S8, TensorMut, TensorRef, Workspace,
49};
50
51use super::map_status;
52
53/// Descriptor for a `quantized_linear` op.
54#[derive(Copy, Clone, Debug)]
55pub struct QuantizedLinearDescriptor {
56 /// Number of token rows in the activation (and rows of the output).
57 pub m: i32,
58 /// Number of output channels (rows of `weight_q`, cols of output).
59 pub c_out: i32,
60 /// Inner reduction dim (cols of `activation` and `weight_q`).
61 pub k: i32,
62 /// Activation quantization range lower bound (symmetric: `-127`).
63 pub q_min: i32,
64 /// Activation quantization range upper bound (symmetric: `127`).
65 pub q_max: i32,
66 /// Activation FP element kind. Must match `TIn::KIND`.
67 pub activation_element: ElementKind,
68 /// Weight int element kind. Today wired only for `S8`.
69 pub weight_element: ElementKind,
70}
71
72/// Args bundle for a `quantized_linear` launch.
73///
74/// The caller supplies the already-quantized weight + its per-channel
75/// scale (offline-computed). The activation is FP; per-token
76/// activation quantization happens inside [`QuantizedLinearPlan::run`]
77/// via an internally orchestrated [`super::DynamicRangeQuantizePlan`]
78/// pass.
79///
80/// `act_q_scratch` and `act_scale_scratch` are caller-owned scratch
81/// buffers for the quantized activation + computed per-row activation
82/// scale. They are part of the args bundle (not workspace) so callers
83/// can reuse them across launches without re-allocation — the Plan's
84/// `workspace_size()` returns 0.
85pub struct QuantizedLinearArgs<'a, TIn: Element, TWQ: IntElement> {
86 /// FP activation `[M, K]`.
87 pub activation: TensorRef<'a, TIn, 2>,
88 /// Already-quantized int8 weight `[C_out, K]`.
89 pub weight_q: TensorRef<'a, TWQ, 2>,
90 /// Per-output-channel weight scale `[C_out]` in FP.
91 pub weight_scale: TensorRef<'a, TIn, 1>,
92 /// FP output `[M, C_out]`.
93 pub output: TensorMut<'a, TIn, 2>,
94 /// Scratch for the per-token quantized activation `[M, K]` in int8.
95 /// Caller-owned; reused across launches.
96 pub act_q_scratch: TensorMut<'a, S8, 2>,
97 /// Scratch for the per-token activation scale `[M]` in FP.
98 /// Caller-owned; reused across launches. Populated by the
99 /// internally orchestrated dynamic-range pass.
100 pub act_scale_scratch: TensorMut<'a, TIn, 1>,
101}
102
103/// `quantized_linear` plan (W8A8 fused).
104///
105/// Composes two passes internally:
106///
107/// 1. **Activation quantize** — per-token symmetric dynamic-range
108/// quantization, fused max-abs reduce + scale compute + quantize.
109/// Implemented by the same `dynamic_range_quantize_per_token_sym`
110/// kernel that backs [`super::DynamicRangeQuantizePlan`].
111/// 2. **Quantized matmul** — fused int8 GEMM + per-row/per-col
112/// dequantize + FP store. Implemented by the bespoke
113/// `quantized_linear_w8a8` kernel.
114///
115/// Both passes share the same stream and execute back-to-back; the Plan
116/// does NOT own an internal `DynamicRangeQuantizePlan` instance — it
117/// invokes the FFI directly to keep the launch ordering explicit.
118///
119/// **When to use**: W8A8 inference matmul (SmoothQuant / AWQ-runtime
120/// style). Inference-only — no BW; for QAT use
121/// [`FakeQuantizePlan`](crate::FakeQuantizePlan) + normal FP matmul.
122///
123/// **Dtypes (trailblazer)**: `TIn (act/out) ∈ {f32, f64}`; `TWQ = S8`.
124/// f16 / bf16 activations and u8 weight not yet wired.
125///
126/// **Shape limits**: `activation` `[M, K]`; `weight_q` `[C_out, K]`;
127/// `weight_scale` `[C_out]`; `output` `[M, C_out]`. The W4 layout
128/// `[C_out, K]` matches `y = x · W^T` (PyTorch `nn.Linear.weight`).
129///
130/// **Workspace**: zero in [`Workspace`]. Caller supplies
131/// `act_q_scratch` `[M, K]` (int8) and `act_scale_scratch` `[M]`
132/// (FP) in [`QuantizedLinearArgs`] for the fused activation-quant
133/// pass.
134///
135/// **Precision guarantee**: deterministic, bit-stable. Naive kernel
136/// (one thread per output cell, register-only int32 acc) for
137/// correctness; tiled-smem / mma.sync optimizations land in a perf
138/// milestone — current variant is **correctness-scaffold, not
139/// throughput-optimized**.
140pub struct QuantizedLinearPlan<TIn: Element, TWQ: IntElement> {
141 desc: QuantizedLinearDescriptor,
142 sku: KernelSku,
143 _marker: PhantomData<(TIn, TWQ)>,
144}
145
146impl<TIn: Element, TWQ: IntElement> QuantizedLinearPlan<TIn, TWQ> {
147 /// Pick a kernel for `desc`.
148 pub fn select(
149 _stream: &Stream,
150 desc: &QuantizedLinearDescriptor,
151 _pref: PlanPreference,
152 ) -> Result<Self> {
153 if desc.activation_element != TIn::KIND {
154 return Err(Error::Unsupported(
155 "QuantizedLinearPlan: descriptor activation_element != TIn",
156 ));
157 }
158 if desc.weight_element != TWQ::KIND {
159 return Err(Error::Unsupported(
160 "QuantizedLinearPlan: descriptor weight_element != TWQ",
161 ));
162 }
163 // Trailblazer dtype matrix: TIn ∈ {f32, f64}, TWQ = S8.
164 if !matches!(TIn::KIND, ElementKind::F32 | ElementKind::F64) {
165 return Err(Error::Unsupported(
166 "QuantizedLinearPlan: 8.3 trailblazer only wires f32 / f64 \
167 activation (f16 / bf16 deferred)",
168 ));
169 }
170 if TWQ::KIND != ElementKind::S8 {
171 return Err(Error::Unsupported(
172 "QuantizedLinearPlan: 8.3 trailblazer only wires S8 weight \
173 (U8 deferred)",
174 ));
175 }
176 if desc.m < 0 || desc.c_out < 0 || desc.k < 0 {
177 return Err(Error::InvalidProblem(
178 "QuantizedLinearPlan: m, c_out, k must be non-negative",
179 ));
180 }
181 if desc.q_max <= 0 {
182 return Err(Error::InvalidProblem(
183 "QuantizedLinearPlan: q_max must be > 0",
184 ));
185 }
186 if desc.q_max < desc.q_min {
187 return Err(Error::InvalidProblem(
188 "QuantizedLinearPlan: q_max < q_min",
189 ));
190 }
191 if desc.m > 65535 {
192 return Err(Error::Unsupported(
193 "QuantizedLinearPlan: M > 65535 — the internal dynamic-range pass \
194 uses one block per row and would exceed the legacy grid limit \
195 (lift when row tiling lands)",
196 ));
197 }
198 let sku = build_sku::<TIn, TWQ>(QuantizeKind::QuantizedLinear);
199 Ok(Self {
200 desc: *desc,
201 sku,
202 _marker: PhantomData,
203 })
204 }
205
206 /// Validate args at run time.
207 pub fn can_implement(&self, args: &QuantizedLinearArgs<'_, TIn, TWQ>) -> Result<()> {
208 if args.activation.shape != [self.desc.m, self.desc.k] {
209 return Err(Error::InvalidProblem(
210 "QuantizedLinearPlan: activation shape != [M, K]",
211 ));
212 }
213 if args.weight_q.shape != [self.desc.c_out, self.desc.k] {
214 return Err(Error::InvalidProblem(
215 "QuantizedLinearPlan: weight_q shape != [C_out, K]",
216 ));
217 }
218 if args.weight_scale.shape != [self.desc.c_out] {
219 return Err(Error::InvalidProblem(
220 "QuantizedLinearPlan: weight_scale shape != [C_out]",
221 ));
222 }
223 if args.output.shape != [self.desc.m, self.desc.c_out] {
224 return Err(Error::InvalidProblem(
225 "QuantizedLinearPlan: output shape != [M, C_out]",
226 ));
227 }
228 if args.act_q_scratch.shape != [self.desc.m, self.desc.k] {
229 return Err(Error::InvalidProblem(
230 "QuantizedLinearPlan: act_q_scratch shape != [M, K]",
231 ));
232 }
233 if args.act_scale_scratch.shape != [self.desc.m] {
234 return Err(Error::InvalidProblem(
235 "QuantizedLinearPlan: act_scale_scratch shape != [M]",
236 ));
237 }
238 Ok(())
239 }
240
241 /// Workspace bytes — none. Scratch buffers are caller-owned via the
242 /// args bundle (`act_q_scratch` + `act_scale_scratch`), allowing
243 /// reuse across launches.
244 #[inline]
245 pub fn workspace_size(&self) -> usize {
246 0
247 }
248
249 /// Identity of the selected kernel.
250 #[inline]
251 pub fn sku(&self) -> KernelSku {
252 self.sku
253 }
254
255 /// Numerical guarantees.
256 #[inline]
257 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
258 self.sku.precision_guarantee
259 }
260
261 /// Launch.
262 pub fn run(
263 &self,
264 stream: &Stream,
265 _workspace: Workspace<'_>,
266 args: QuantizedLinearArgs<'_, TIn, TWQ>,
267 ) -> Result<()> {
268 self.can_implement(&args)?;
269 if (self.desc.m as i64) * (self.desc.c_out as i64) == 0
270 || self.desc.k == 0
271 {
272 return Ok(());
273 }
274
275 let stream_ptr = stream.as_raw() as *mut c_void;
276
277 // ---- Pass 1: dynamic-range per-token symmetric quantize the
278 // FP activation into the int8 scratch. -----------
279 let act_ptr = args.activation.data.as_raw().0 as *const c_void;
280 let act_scale_ptr = args.act_scale_scratch.data.as_raw().0 as *mut c_void;
281 let act_q_ptr = args.act_q_scratch.data.as_raw().0 as *mut c_void;
282 let drq_status = match TIn::KIND {
283 ElementKind::F32 => unsafe {
284 baracuda_kernels_sys::baracuda_kernels_dynamic_range_quantize_per_token_sym_f32_s8_run(
285 self.desc.m,
286 self.desc.k,
287 self.desc.q_min,
288 self.desc.q_max,
289 act_ptr, act_scale_ptr, act_q_ptr,
290 core::ptr::null_mut(), 0, stream_ptr,
291 )
292 },
293 ElementKind::F64 => unsafe {
294 baracuda_kernels_sys::baracuda_kernels_dynamic_range_quantize_per_token_sym_f64_s8_run(
295 self.desc.m,
296 self.desc.k,
297 self.desc.q_min,
298 self.desc.q_max,
299 act_ptr, act_scale_ptr, act_q_ptr,
300 core::ptr::null_mut(), 0, stream_ptr,
301 )
302 },
303 _ => {
304 return Err(Error::Unsupported(
305 "QuantizedLinearPlan::run reached unsupported TIn at \
306 activation-quantize pass (select should have caught)",
307 ))
308 }
309 };
310 map_status(drq_status)?;
311
312 // ---- Pass 2: fused quantized-linear (int8 GEMM + dequant + FP store). ----
313 let weight_ptr = args.weight_q.data.as_raw().0 as *const c_void;
314 let act_q_const = args.act_q_scratch.data.as_raw().0 as *const c_void;
315 let act_scale_const = args.act_scale_scratch.data.as_raw().0 as *const c_void;
316 let w_scale_ptr = args.weight_scale.data.as_raw().0 as *const c_void;
317 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
318 let ql_status = match TIn::KIND {
319 ElementKind::F32 => unsafe {
320 baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f32_run(
321 self.desc.m,
322 self.desc.c_out,
323 self.desc.k,
324 weight_ptr, act_q_const,
325 act_scale_const, w_scale_ptr,
326 out_ptr,
327 core::ptr::null_mut(), 0, stream_ptr,
328 )
329 },
330 ElementKind::F64 => unsafe {
331 baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f64_run(
332 self.desc.m,
333 self.desc.c_out,
334 self.desc.k,
335 weight_ptr, act_q_const,
336 act_scale_const, w_scale_ptr,
337 out_ptr,
338 core::ptr::null_mut(), 0, stream_ptr,
339 )
340 },
341 _ => {
342 return Err(Error::Unsupported(
343 "QuantizedLinearPlan::run reached unsupported TIn at \
344 quantized-linear pass (select should have caught)",
345 ))
346 }
347 };
348 map_status(ql_status)
349 }
350}
351
352/// Build the [`KernelSku`] for a quantized-linear plan.
353fn build_sku<TIn: Element, TWQ: IntElement>(op: QuantizeKind) -> KernelSku {
354 let precision_guarantee = PrecisionGuarantee {
355 math_precision: if TIn::KIND == ElementKind::F64 {
356 MathPrecision::F64
357 } else {
358 MathPrecision::F32
359 },
360 accumulator: ElementKind::F32,
361 // Deterministic: int32 reduction + per-row block scan + FP
362 // dequant; no atomics. Different (M, K) tile orderings would
363 // change the float sum-cancellation pattern in a later tiled
364 // kernel, but the naive trailblazer is fully serial within a
365 // thread and bit-stable on the same hardware.
366 bit_stable_on_same_hardware: true,
367 deterministic: true,
368 };
369 KernelSku {
370 category: OpCategory::Quantization,
371 op: op as u16,
372 element: TIn::KIND,
373 aux_element: Some(TWQ::KIND),
374 layout: None,
375 epilogue: None,
376 arch: ArchSku::Sm80,
377 backend: BackendKind::Bespoke,
378 precision_guarantee,
379 }
380}