baracuda_kernels/quantize/smoothquant.rs
1//! `SmoothQuantLinearPlan` — Phase 45 zero-new-CUDA composition.
2//!
3//! Implements the inference-time linear pass of **SmoothQuant**
4//! (Xiao et al. ICML 2023, MIT;
5//! [mit-han-lab/smoothquant](https://github.com/mit-han-lab/smoothquant)).
6//!
7//! SmoothQuant is an **offline algorithmic recipe** — a Python
8//! preprocessing pass that migrates outlier difficulty from
9//! activations to weights via a per-channel divisor `s[K]` so that
10//! both the smoothed activation `A_smooth = A / diag(s)` and the
11//! smoothed weight `W_smooth = diag(s) · W` quantize cleanly under a
12//! single per-tensor activation scale + a per-output-channel weight
13//! scale. The smoothing itself is **not** a CUDA kernel; it lives in
14//! the Python flow at training-prep time. baracuda only needs to
15//! consume the already-smoothed-and-quantized tensors.
16//!
17//! The inference-time math is the standard W8A8 dequant:
18//!
19//! ```text
20//! y[m, n] = act_scale · weight_scale[n] · Σ_k a_q[m, k] · w_q[n, k]
21//! ```
22//!
23//! Differences from [`super::QuantizedLinearPlan`]:
24//!
25//! - **Per-tensor activation scale** (single `f32`) vs per-token
26//! `[M]` dynamic-range scale. The whole point of SmoothQuant is
27//! that one static scalar suffices once outliers are migrated.
28//! - **Caller-pre-quantized int8 activation** — no internal
29//! `dynamic_range_quantize_per_token_sym` pass.
30//!
31//! Composition strategy: the bespoke `quantized_linear_w8a8_*`
32//! kernel (vendored Milestone 8.3) consumes `scale_a: [M]` — we
33//! reuse it verbatim by having the caller supply an `[M]` scratch
34//! buffer that we fill with the constant `act_scale` via
35//! [`super::super::FillPlan`] before the matmul. Zero new CUDA.
36//!
37//! ## Layout
38//!
39//! - `act_q` : `[M, K]` int8 (row-major).
40//! - `weight_q` : `[N, K]` int8 (row-major — one row per output channel,
41//! matching PyTorch `nn.Linear.weight` layout, same as
42//! [`super::QuantizedLinearPlan`]).
43//! - `weight_scale` : `[N]` FP (per-output-channel; saved alongside
44//! the smoothed-then-quantized weights from the offline flow).
45//! - `act_scale` : single `f32` scalar in the descriptor (per-tensor).
46//! - `output` : `[M, N]` FP.
47//!
48//! ## Trailblazer scope
49//!
50//! - `TIn ∈ {f32, f64}` activation-scale + weight-scale + output;
51//! `TWQ = S8` weight; activation = `S8`.
52//! (f16 / bf16 / u8-weight follow the same matrix as
53//! [`QuantizedLinearPlan`]; deferred until the underlying bespoke
54//! `quantized_linear_w8a8` kernel grows the dtypes.)
55//! - **Inference-only** — no backward. The W8A8 path is forward-only
56//! by convention; if a downstream needs gradients, they should use
57//! [`super::FakeQuantizePlan`] for QAT.
58//! - **No bias fusion in trailblazer** — bias addition is a separate
59//! downstream op (Affine / Binary Add). The underlying
60//! `quantized_linear_w8a8` kernel doesn't take a bias today and
61//! we don't synthesize one here.
62
63use core::ffi::c_void;
64use core::marker::PhantomData;
65
66use baracuda_cutlass::{Error, Result};
67use baracuda_driver::Stream;
68use baracuda_kernels_types::{
69 ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
70 PlanPreference, PrecisionGuarantee, QuantizeKind, S8, TensorMut, TensorRef, Workspace,
71};
72
73use super::map_status;
74
75/// Descriptor for a `SmoothQuant` linear op.
76///
77/// The per-tensor activation scale lives in the descriptor (not the
78/// args) because in the SmoothQuant flow it's part of the model's
79/// frozen quantization metadata — it doesn't change between launches
80/// for the same layer.
81#[derive(Copy, Clone, Debug)]
82#[non_exhaustive]
83pub struct SmoothQuantLinearDescriptor {
84 /// Number of token rows in the activation (and rows of the output).
85 pub m: i32,
86 /// Number of output channels (rows of `weight_q`, cols of output).
87 pub n: i32,
88 /// Inner reduction dim (cols of `act_q` and `weight_q`).
89 pub k: i32,
90 /// Per-tensor activation scale produced by the offline SmoothQuant
91 /// Python flow. Always `f32` regardless of `TIn` — the underlying
92 /// `quantized_linear_w8a8` kernel does the scale multiply in float
93 /// space irrespective of output dtype.
94 pub act_scale: f32,
95 /// Activation int element kind. Today wired only for `S8`.
96 pub activation_element: ElementKind,
97 /// Weight int element kind. Today wired only for `S8`.
98 pub weight_element: ElementKind,
99 /// Output FP element kind. Must match `TIn::KIND`.
100 pub output_element: ElementKind,
101}
102
103impl SmoothQuantLinearDescriptor {
104 /// Construct a `SmoothQuantLinearDescriptor` for the given problem
105 /// shape and per-tensor activation scale. Defaults `S8` for both
106 /// activation and weight; output element matches `TIn::KIND`.
107 pub fn new<TIn: Element>(m: i32, n: i32, k: i32, act_scale: f32) -> Self {
108 Self {
109 m,
110 n,
111 k,
112 act_scale,
113 activation_element: ElementKind::S8,
114 weight_element: ElementKind::S8,
115 output_element: TIn::KIND,
116 }
117 }
118}
119
120/// Args bundle for a `SmoothQuant` linear launch.
121///
122/// `act_scale_scratch` is a caller-owned `[M]` FP scratch buffer used
123/// to broadcast the descriptor's per-tensor `act_scale` into the
124/// per-row form the underlying `quantized_linear_w8a8` kernel
125/// consumes. Caller-owned so it can be reused across launches without
126/// re-allocation — the Plan's `workspace_size()` returns 0.
127pub struct SmoothQuantLinearArgs<'a, TIn: Element, TWQ: IntElement> {
128 /// Pre-quantized int8 activation `[M, K]`.
129 pub act_q: TensorRef<'a, S8, 2>,
130 /// Pre-smoothed-then-quantized int8 weight `[N, K]`.
131 pub weight_q: TensorRef<'a, TWQ, 2>,
132 /// Per-output-channel weight scale `[N]` in FP.
133 pub weight_scale: TensorRef<'a, TIn, 1>,
134 /// FP output `[M, N]`.
135 pub output: TensorMut<'a, TIn, 2>,
136 /// Scratch for the per-row broadcast of `act_scale`. `[M]` FP.
137 /// Caller-owned; reused across launches. Populated by the plan
138 /// before the matmul launch.
139 pub act_scale_scratch: TensorMut<'a, TIn, 1>,
140}
141
142/// `SmoothQuant` linear plan — pure Rust composition over the
143/// bespoke `quantized_linear_w8a8` kernel.
144///
145/// **When to use**: SmoothQuant inference matmul. Activation has
146/// already been smoothed (divided by per-channel `s[K]`) and
147/// quantized per-tensor to int8; weight has already been smoothed
148/// (multiplied by `s[K]`) and quantized per-output-channel to int8;
149/// caller passes both, plus the static per-tensor act-scale + per-N
150/// weight-scale, to this plan.
151///
152/// **Dtypes (trailblazer)**: `TIn (scales/out) ∈ {f32, f64}`;
153/// `TWQ = S8` weight; activation is fixed at `S8`. f16 / bf16 / u8
154/// weight follow once the underlying `quantized_linear_w8a8` kernel
155/// grows those dtypes (same matrix as
156/// [`super::QuantizedLinearPlan`]).
157///
158/// **Shape limits**: `act_q` `[M, K]`; `weight_q` `[N, K]`;
159/// `weight_scale` `[N]`; `output` `[M, N]`. `[N, K]` weight layout
160/// matches `y = x · W^T` (PyTorch `nn.Linear.weight`).
161///
162/// **Workspace**: zero in [`Workspace`]. Caller supplies
163/// `act_scale_scratch` `[M]` (FP) in [`SmoothQuantLinearArgs`] for
164/// the act-scale broadcast.
165///
166/// **Precision guarantee**: deterministic, bit-stable on the same
167/// hardware (inherits from the underlying `quantized_linear_w8a8`
168/// kernel — register-only int32 accumulator + serial FP scale
169/// multiply, no atomics).
170pub struct SmoothQuantLinearPlan<TIn: Element, TWQ: IntElement> {
171 desc: SmoothQuantLinearDescriptor,
172 sku: KernelSku,
173 _marker: PhantomData<(TIn, TWQ)>,
174}
175
176impl<TIn: Element, TWQ: IntElement> SmoothQuantLinearPlan<TIn, TWQ> {
177 /// Pick a kernel for `desc`.
178 pub fn select(
179 _stream: &Stream,
180 desc: &SmoothQuantLinearDescriptor,
181 _pref: PlanPreference,
182 ) -> Result<Self> {
183 if desc.output_element != TIn::KIND {
184 return Err(Error::Unsupported(
185 "SmoothQuantLinearPlan: descriptor output_element != TIn",
186 ));
187 }
188 if desc.weight_element != TWQ::KIND {
189 return Err(Error::Unsupported(
190 "SmoothQuantLinearPlan: descriptor weight_element != TWQ",
191 ));
192 }
193 if desc.activation_element != ElementKind::S8 {
194 return Err(Error::Unsupported(
195 "SmoothQuantLinearPlan: trailblazer only wires S8 activation \
196 (matches underlying quantized_linear_w8a8 kernel)",
197 ));
198 }
199 // Trailblazer dtype matrix mirrors QuantizedLinearPlan exactly
200 // (same underlying kernel): TIn ∈ {f32, f64}, TWQ = S8.
201 if !matches!(TIn::KIND, ElementKind::F32 | ElementKind::F64) {
202 return Err(Error::Unsupported(
203 "SmoothQuantLinearPlan: trailblazer only wires f32 / f64 \
204 output (f16 / bf16 follow when quantized_linear_w8a8 grows them)",
205 ));
206 }
207 if TWQ::KIND != ElementKind::S8 {
208 return Err(Error::Unsupported(
209 "SmoothQuantLinearPlan: trailblazer only wires S8 weight (U8 deferred)",
210 ));
211 }
212 if desc.m < 0 || desc.n < 0 || desc.k < 0 {
213 return Err(Error::InvalidProblem(
214 "SmoothQuantLinearPlan: m, n, k must be non-negative",
215 ));
216 }
217 if !desc.act_scale.is_finite() {
218 return Err(Error::InvalidProblem(
219 "SmoothQuantLinearPlan: act_scale must be finite",
220 ));
221 }
222 // We don't require act_scale > 0 strictly — a zero scale produces
223 // a zero output (degenerate but well-defined). Negative scales
224 // are unusual but mathematically valid (SmoothQuant's offline
225 // flow always yields positive scales; we don't enforce here).
226 let precision_guarantee = PrecisionGuarantee {
227 math_precision: if TIn::KIND == ElementKind::F64 {
228 MathPrecision::F64
229 } else {
230 MathPrecision::F32
231 },
232 accumulator: ElementKind::F32,
233 bit_stable_on_same_hardware: true,
234 deterministic: true,
235 };
236 let sku = KernelSku {
237 category: OpCategory::Quantization,
238 // SmoothQuant rides on the same op-kind discriminant as the
239 // existing W8A8 path — they're variants of the same logical
240 // op (W8A8 fused matmul), just with different
241 // activation-scale provenance.
242 op: QuantizeKind::QuantizedLinear as u16,
243 element: TIn::KIND,
244 aux_element: Some(TWQ::KIND),
245 layout: None,
246 epilogue: None,
247 arch: ArchSku::Sm80,
248 backend: BackendKind::Bespoke,
249 precision_guarantee,
250 };
251 Ok(Self {
252 desc: *desc,
253 sku,
254 _marker: PhantomData,
255 })
256 }
257
258 /// Validate args at run time.
259 pub fn can_implement(&self, args: &SmoothQuantLinearArgs<'_, TIn, TWQ>) -> Result<()> {
260 if args.act_q.shape != [self.desc.m, self.desc.k] {
261 return Err(Error::InvalidProblem(
262 "SmoothQuantLinearPlan: act_q shape != [M, K]",
263 ));
264 }
265 if args.weight_q.shape != [self.desc.n, self.desc.k] {
266 return Err(Error::InvalidProblem(
267 "SmoothQuantLinearPlan: weight_q shape != [N, K]",
268 ));
269 }
270 if args.weight_scale.shape != [self.desc.n] {
271 return Err(Error::InvalidProblem(
272 "SmoothQuantLinearPlan: weight_scale shape != [N]",
273 ));
274 }
275 if args.output.shape != [self.desc.m, self.desc.n] {
276 return Err(Error::InvalidProblem(
277 "SmoothQuantLinearPlan: output shape != [M, N]",
278 ));
279 }
280 if args.act_scale_scratch.shape != [self.desc.m] {
281 return Err(Error::InvalidProblem(
282 "SmoothQuantLinearPlan: act_scale_scratch shape != [M]",
283 ));
284 }
285 Ok(())
286 }
287
288 /// Workspace bytes — none. The act-scale `[M]` broadcast buffer is
289 /// caller-owned via the args bundle (`act_scale_scratch`),
290 /// allowing reuse across launches.
291 #[inline]
292 pub fn workspace_size(&self) -> usize {
293 0
294 }
295
296 /// Identity of the selected kernel.
297 #[inline]
298 pub fn sku(&self) -> KernelSku {
299 self.sku
300 }
301
302 /// Numerical guarantees.
303 #[inline]
304 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
305 self.sku.precision_guarantee
306 }
307
308 /// Launch.
309 ///
310 /// Two-pass: (1) fill the `[M]` scratch with the descriptor's
311 /// `act_scale`; (2) launch the `quantized_linear_w8a8` kernel
312 /// directly via the FFI (skips the dynamic-range-quantize pass
313 /// that [`super::QuantizedLinearPlan`] does).
314 pub fn run(
315 &self,
316 stream: &Stream,
317 _workspace: Workspace<'_>,
318 args: SmoothQuantLinearArgs<'_, TIn, TWQ>,
319 ) -> Result<()> {
320 self.can_implement(&args)?;
321 if (self.desc.m as i64) * (self.desc.n as i64) == 0 || self.desc.k == 0 {
322 return Ok(());
323 }
324
325 let stream_ptr = stream.as_raw() as *mut c_void;
326
327 // ---- Pass 1: broadcast `act_scale` across [M]. ---------------
328 //
329 // We invoke the fill FFI directly rather than constructing a
330 // FillPlan — both paths land on the same underlying kernel,
331 // and the direct call sidesteps the FillPlan's Element trait
332 // bound which would force us through transmute_copy gymnastics
333 // (the descriptor knows the scale as f32; the scratch is TIn).
334 let fill_ptr = args.act_scale_scratch.data.as_raw().0 as *mut c_void;
335 let fill_status = match TIn::KIND {
336 ElementKind::F32 => unsafe {
337 baracuda_kernels_sys::baracuda_kernels_fill_f32_run(
338 self.desc.m as i64,
339 fill_ptr,
340 self.desc.act_scale,
341 core::ptr::null_mut(),
342 0,
343 stream_ptr,
344 )
345 },
346 ElementKind::F64 => unsafe {
347 baracuda_kernels_sys::baracuda_kernels_fill_f64_run(
348 self.desc.m as i64,
349 fill_ptr,
350 self.desc.act_scale as f64,
351 core::ptr::null_mut(),
352 0,
353 stream_ptr,
354 )
355 },
356 _ => {
357 return Err(Error::Unsupported(
358 "SmoothQuantLinearPlan::run reached unsupported TIn at \
359 act-scale broadcast (select should have caught)",
360 ))
361 }
362 };
363 map_status(fill_status)?;
364
365 // ---- Pass 2: fused quantized-linear (int8 GEMM + dequant + FP store). ----
366 let weight_ptr = args.weight_q.data.as_raw().0 as *const c_void;
367 let act_q_ptr = args.act_q.data.as_raw().0 as *const c_void;
368 let act_scale_const = args.act_scale_scratch.data.as_raw().0 as *const c_void;
369 let w_scale_ptr = args.weight_scale.data.as_raw().0 as *const c_void;
370 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
371 let ql_status = match TIn::KIND {
372 ElementKind::F32 => unsafe {
373 baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f32_run(
374 self.desc.m,
375 self.desc.n,
376 self.desc.k,
377 weight_ptr,
378 act_q_ptr,
379 act_scale_const,
380 w_scale_ptr,
381 out_ptr,
382 core::ptr::null_mut(),
383 0,
384 stream_ptr,
385 )
386 },
387 ElementKind::F64 => unsafe {
388 baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f64_run(
389 self.desc.m,
390 self.desc.n,
391 self.desc.k,
392 weight_ptr,
393 act_q_ptr,
394 act_scale_const,
395 w_scale_ptr,
396 out_ptr,
397 core::ptr::null_mut(),
398 0,
399 stream_ptr,
400 )
401 },
402 _ => {
403 return Err(Error::Unsupported(
404 "SmoothQuantLinearPlan::run reached unsupported TIn at \
405 quantized-linear pass (select should have caught)",
406 ))
407 }
408 };
409 map_status(ql_status)
410 }
411}