Skip to main content

baracuda_kernels/quantize/
smoothquant.rs

//! `SmoothQuantLinearPlan` — Phase 45 zero-new-CUDA composition.
//!
//! Implements the inference-time linear pass of **SmoothQuant**
//! (Xiao et al., ICML 2023, MIT;
//! [mit-han-lab/smoothquant](https://github.com/mit-han-lab/smoothquant)).
//!
//! SmoothQuant is an **offline algorithmic recipe** — a Python
//! post-training preprocessing pass that migrates outlier difficulty from
//! activations to weights via a per-channel divisor `s[K]`. After it, both
//! the smoothed activation `A_smooth = A · diag(s)⁻¹` and the smoothed
//! weight `W_smooth = diag(s) · W` (with `W` viewed as `[K, N]`) quantize
//! cleanly under a single per-tensor activation scale plus a
//! per-output-channel weight scale. The smoothing itself is **not** a CUDA
//! kernel; it runs in the Python flow at offline calibration time.
//! baracuda only needs to consume the already-smoothed-and-quantized
//! tensors.
//!
//! The inference-time math is the standard W8A8 dequant:
//!
//! ```text
//! y[m, n] = act_scale · weight_scale[n] · Σ_k a_q[m, k] · w_q[n, k]
//! ```
//!
//! Differences from [`super::QuantizedLinearPlan`]:
//!
//! - **Per-tensor activation scale** (a single `f32`) instead of a
//!   per-token `[M]` dynamic-range scale. The whole point of SmoothQuant is
//!   that one static scalar suffices once outliers are migrated.
//! - **Caller-pre-quantized int8 activation** — no internal
//!   `dynamic_range_quantize_per_token_sym` pass.
//!
//! Composition strategy: the bespoke `quantized_linear_w8a8_*` kernel
//! (vendored in Milestone 8.3) consumes `scale_a: [M]`. We reuse it
//! verbatim by having the caller supply an `[M]` scratch buffer. Before the
//! matmul, we fill that buffer with the constant `act_scale`, calling the
//! fill kernel behind [`super::super::FillPlan`] directly through the FFI.
//! Zero new CUDA.
//!
//! ## Layout
//!
//! - `act_q`        : `[M, K]` int8 (row-major).
//! - `weight_q`     : `[N, K]` int8 (row-major — one row per output channel,
//!   matching the PyTorch `nn.Linear.weight` layout, same as
//!   [`super::QuantizedLinearPlan`]).
//! - `weight_scale` : `[N]` FP (per output channel; saved alongside
//!   the smoothed-then-quantized weights by the offline flow).
//! - `act_scale`    : single `f32` scalar in the descriptor (per tensor).
//! - `output`       : `[M, N]` FP.
//!
//! ## Trailblazer scope
//!
//! - `TIn ∈ {f32, f64}` for the weight scale, the `[M]` act-scale scratch
//!   and the output. The descriptor's `act_scale` itself is always `f32`.
//!   `TWQ = S8` weight; activation = `S8`.
//!   (f16 / bf16 / u8-weight follow the same matrix as
//!   [`super::QuantizedLinearPlan`]; they are deferred until the underlying
//!   bespoke `quantized_linear_w8a8` kernel grows those dtypes.)
//! - **Inference-only** — no backward. The W8A8 path is forward-only
//!   by convention; downstream users who need gradients should use
//!   [`super::FakeQuantizePlan`] for QAT.
//! - **No bias fusion in trailblazer** — bias addition is a separate
//!   downstream op (Affine / Binary Add). The underlying
//!   `quantized_linear_w8a8` kernel doesn't take a bias today, and
//!   we don't synthesize one here.

58//! - **No bias fusion in trailblazer** — bias addition is a separate
59//!   downstream op (Affine / Binary Add). The underlying
60//!   `quantized_linear_w8a8` kernel doesn't take a bias today and
61//!   we don't synthesize one here.
62
63use core::ffi::c_void;
64use core::marker::PhantomData;
65
66use baracuda_cutlass::{Error, Result};
67use baracuda_driver::Stream;
68use baracuda_kernels_types::{
69    ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
70    PlanPreference, PrecisionGuarantee, QuantizeKind, S8, TensorMut, TensorRef, Workspace,
71};
72
73use super::map_status;
74
75/// Descriptor for a `SmoothQuant` linear op.
76///
77/// The per-tensor activation scale lives in the descriptor (not the
78/// args) because in the SmoothQuant flow it's part of the model's
79/// frozen quantization metadata — it doesn't change between launches
80/// for the same layer.
81#[derive(Copy, Clone, Debug)]
82#[non_exhaustive]
83pub struct SmoothQuantLinearDescriptor {
84    /// Number of token rows in the activation (and rows of the output).
85    pub m: i32,
86    /// Number of output channels (rows of `weight_q`, cols of output).
87    pub n: i32,
88    /// Inner reduction dim (cols of `act_q` and `weight_q`).
89    pub k: i32,
90    /// Per-tensor activation scale produced by the offline SmoothQuant
91    /// Python flow. Always `f32` regardless of `TIn` — the underlying
92    /// `quantized_linear_w8a8` kernel does the scale multiply in float
93    /// space irrespective of output dtype.
94    pub act_scale: f32,
95    /// Activation int element kind. Today wired only for `S8`.
96    pub activation_element: ElementKind,
97    /// Weight int element kind. Today wired only for `S8`.
98    pub weight_element: ElementKind,
99    /// Output FP element kind. Must match `TIn::KIND`.
100    pub output_element: ElementKind,
101}
102
103impl SmoothQuantLinearDescriptor {
104    /// Construct a `SmoothQuantLinearDescriptor` for the given problem
105    /// shape and per-tensor activation scale. Defaults `S8` for both
106    /// activation and weight; output element matches `TIn::KIND`.
107    pub fn new<TIn: Element>(m: i32, n: i32, k: i32, act_scale: f32) -> Self {
108        Self {
109            m,
110            n,
111            k,
112            act_scale,
113            activation_element: ElementKind::S8,
114            weight_element: ElementKind::S8,
115            output_element: TIn::KIND,
116        }
117    }
118}
119
120/// Args bundle for a `SmoothQuant` linear launch.
121///
122/// `act_scale_scratch` is a caller-owned `[M]` FP scratch buffer used
123/// to broadcast the descriptor's per-tensor `act_scale` into the
124/// per-row form the underlying `quantized_linear_w8a8` kernel
125/// consumes. Caller-owned so it can be reused across launches without
126/// re-allocation — the Plan's `workspace_size()` returns 0.
127pub struct SmoothQuantLinearArgs<'a, TIn: Element, TWQ: IntElement> {
128    /// Pre-quantized int8 activation `[M, K]`.
129    pub act_q: TensorRef<'a, S8, 2>,
130    /// Pre-smoothed-then-quantized int8 weight `[N, K]`.
131    pub weight_q: TensorRef<'a, TWQ, 2>,
132    /// Per-output-channel weight scale `[N]` in FP.
133    pub weight_scale: TensorRef<'a, TIn, 1>,
134    /// FP output `[M, N]`.
135    pub output: TensorMut<'a, TIn, 2>,
136    /// Scratch for the per-row broadcast of `act_scale`. `[M]` FP.
137    /// Caller-owned; reused across launches. Populated by the plan
138    /// before the matmul launch.
139    pub act_scale_scratch: TensorMut<'a, TIn, 1>,
140}
141
142/// `SmoothQuant` linear plan — pure Rust composition over the
143/// bespoke `quantized_linear_w8a8` kernel.
144///
145/// **When to use**: SmoothQuant inference matmul. Activation has
146/// already been smoothed (divided by per-channel `s[K]`) and
147/// quantized per-tensor to int8; weight has already been smoothed
148/// (multiplied by `s[K]`) and quantized per-output-channel to int8;
149/// caller passes both, plus the static per-tensor act-scale + per-N
150/// weight-scale, to this plan.
151///
152/// **Dtypes (trailblazer)**: `TIn (scales/out) ∈ {f32, f64}`;
153/// `TWQ = S8` weight; activation is fixed at `S8`. f16 / bf16 / u8
154/// weight follow once the underlying `quantized_linear_w8a8` kernel
155/// grows those dtypes (same matrix as
156/// [`super::QuantizedLinearPlan`]).
157///
158/// **Shape limits**: `act_q` `[M, K]`; `weight_q` `[N, K]`;
159/// `weight_scale` `[N]`; `output` `[M, N]`. `[N, K]` weight layout
160/// matches `y = x · W^T` (PyTorch `nn.Linear.weight`).
161///
162/// **Workspace**: zero in [`Workspace`]. Caller supplies
163/// `act_scale_scratch` `[M]` (FP) in [`SmoothQuantLinearArgs`] for
164/// the act-scale broadcast.
165///
166/// **Precision guarantee**: deterministic, bit-stable on the same
167/// hardware (inherits from the underlying `quantized_linear_w8a8`
168/// kernel — register-only int32 accumulator + serial FP scale
169/// multiply, no atomics).
170pub struct SmoothQuantLinearPlan<TIn: Element, TWQ: IntElement> {
171    desc: SmoothQuantLinearDescriptor,
172    sku: KernelSku,
173    _marker: PhantomData<(TIn, TWQ)>,
174}
175
176impl<TIn: Element, TWQ: IntElement> SmoothQuantLinearPlan<TIn, TWQ> {
177    /// Pick a kernel for `desc`.
178    pub fn select(
179        _stream: &Stream,
180        desc: &SmoothQuantLinearDescriptor,
181        _pref: PlanPreference,
182    ) -> Result<Self> {
183        if desc.output_element != TIn::KIND {
184            return Err(Error::Unsupported(
185                "SmoothQuantLinearPlan: descriptor output_element != TIn",
186            ));
187        }
188        if desc.weight_element != TWQ::KIND {
189            return Err(Error::Unsupported(
190                "SmoothQuantLinearPlan: descriptor weight_element != TWQ",
191            ));
192        }
193        if desc.activation_element != ElementKind::S8 {
194            return Err(Error::Unsupported(
195                "SmoothQuantLinearPlan: trailblazer only wires S8 activation \
196                 (matches underlying quantized_linear_w8a8 kernel)",
197            ));
198        }
199        // Trailblazer dtype matrix mirrors QuantizedLinearPlan exactly
200        // (same underlying kernel): TIn ∈ {f32, f64}, TWQ = S8.
201        if !matches!(TIn::KIND, ElementKind::F32 | ElementKind::F64) {
202            return Err(Error::Unsupported(
203                "SmoothQuantLinearPlan: trailblazer only wires f32 / f64 \
204                 output (f16 / bf16 follow when quantized_linear_w8a8 grows them)",
205            ));
206        }
207        if TWQ::KIND != ElementKind::S8 {
208            return Err(Error::Unsupported(
209                "SmoothQuantLinearPlan: trailblazer only wires S8 weight (U8 deferred)",
210            ));
211        }
212        if desc.m < 0 || desc.n < 0 || desc.k < 0 {
213            return Err(Error::InvalidProblem(
214                "SmoothQuantLinearPlan: m, n, k must be non-negative",
215            ));
216        }
217        if !desc.act_scale.is_finite() {
218            return Err(Error::InvalidProblem(
219                "SmoothQuantLinearPlan: act_scale must be finite",
220            ));
221        }
222        // We don't require act_scale > 0 strictly — a zero scale produces
223        // a zero output (degenerate but well-defined). Negative scales
224        // are unusual but mathematically valid (SmoothQuant's offline
225        // flow always yields positive scales; we don't enforce here).
226        let precision_guarantee = PrecisionGuarantee {
227            math_precision: if TIn::KIND == ElementKind::F64 {
228                MathPrecision::F64
229            } else {
230                MathPrecision::F32
231            },
232            accumulator: ElementKind::F32,
233            bit_stable_on_same_hardware: true,
234            deterministic: true,
235        };
236        let sku = KernelSku {
237            category: OpCategory::Quantization,
238            // SmoothQuant rides on the same op-kind discriminant as the
239            // existing W8A8 path — they're variants of the same logical
240            // op (W8A8 fused matmul), just with different
241            // activation-scale provenance.
242            op: QuantizeKind::QuantizedLinear as u16,
243            element: TIn::KIND,
244            aux_element: Some(TWQ::KIND),
245            layout: None,
246            epilogue: None,
247            arch: ArchSku::Sm80,
248            backend: BackendKind::Bespoke,
249            precision_guarantee,
250        };
251        Ok(Self {
252            desc: *desc,
253            sku,
254            _marker: PhantomData,
255        })
256    }
257
258    /// Validate args at run time.
259    pub fn can_implement(&self, args: &SmoothQuantLinearArgs<'_, TIn, TWQ>) -> Result<()> {
260        if args.act_q.shape != [self.desc.m, self.desc.k] {
261            return Err(Error::InvalidProblem(
262                "SmoothQuantLinearPlan: act_q shape != [M, K]",
263            ));
264        }
265        if args.weight_q.shape != [self.desc.n, self.desc.k] {
266            return Err(Error::InvalidProblem(
267                "SmoothQuantLinearPlan: weight_q shape != [N, K]",
268            ));
269        }
270        if args.weight_scale.shape != [self.desc.n] {
271            return Err(Error::InvalidProblem(
272                "SmoothQuantLinearPlan: weight_scale shape != [N]",
273            ));
274        }
275        if args.output.shape != [self.desc.m, self.desc.n] {
276            return Err(Error::InvalidProblem(
277                "SmoothQuantLinearPlan: output shape != [M, N]",
278            ));
279        }
280        if args.act_scale_scratch.shape != [self.desc.m] {
281            return Err(Error::InvalidProblem(
282                "SmoothQuantLinearPlan: act_scale_scratch shape != [M]",
283            ));
284        }
285        Ok(())
286    }
287
288    /// Workspace bytes — none. The act-scale `[M]` broadcast buffer is
289    /// caller-owned via the args bundle (`act_scale_scratch`),
290    /// allowing reuse across launches.
291    #[inline]
292    pub fn workspace_size(&self) -> usize {
293        0
294    }
295
296    /// Identity of the selected kernel.
297    #[inline]
298    pub fn sku(&self) -> KernelSku {
299        self.sku
300    }
301
302    /// Numerical guarantees.
303    #[inline]
304    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
305        self.sku.precision_guarantee
306    }
307
308    /// Launch.
309    ///
310    /// Two-pass: (1) fill the `[M]` scratch with the descriptor's
311    /// `act_scale`; (2) launch the `quantized_linear_w8a8` kernel
312    /// directly via the FFI (skips the dynamic-range-quantize pass
313    /// that [`super::QuantizedLinearPlan`] does).
314    pub fn run(
315        &self,
316        stream: &Stream,
317        _workspace: Workspace<'_>,
318        args: SmoothQuantLinearArgs<'_, TIn, TWQ>,
319    ) -> Result<()> {
320        self.can_implement(&args)?;
321        if (self.desc.m as i64) * (self.desc.n as i64) == 0 || self.desc.k == 0 {
322            return Ok(());
323        }
324
325        let stream_ptr = stream.as_raw() as *mut c_void;
326
327        // ---- Pass 1: broadcast `act_scale` across [M]. ---------------
328        //
329        // We invoke the fill FFI directly rather than constructing a
330        // FillPlan — both paths land on the same underlying kernel,
331        // and the direct call sidesteps the FillPlan's Element trait
332        // bound which would force us through transmute_copy gymnastics
333        // (the descriptor knows the scale as f32; the scratch is TIn).
334        let fill_ptr = args.act_scale_scratch.data.as_raw().0 as *mut c_void;
335        let fill_status = match TIn::KIND {
336            ElementKind::F32 => unsafe {
337                baracuda_kernels_sys::baracuda_kernels_fill_f32_run(
338                    self.desc.m as i64,
339                    fill_ptr,
340                    self.desc.act_scale,
341                    core::ptr::null_mut(),
342                    0,
343                    stream_ptr,
344                )
345            },
346            ElementKind::F64 => unsafe {
347                baracuda_kernels_sys::baracuda_kernels_fill_f64_run(
348                    self.desc.m as i64,
349                    fill_ptr,
350                    self.desc.act_scale as f64,
351                    core::ptr::null_mut(),
352                    0,
353                    stream_ptr,
354                )
355            },
356            _ => {
357                return Err(Error::Unsupported(
358                    "SmoothQuantLinearPlan::run reached unsupported TIn at \
359                     act-scale broadcast (select should have caught)",
360                ))
361            }
362        };
363        map_status(fill_status)?;
364
365        // ---- Pass 2: fused quantized-linear (int8 GEMM + dequant + FP store). ----
366        let weight_ptr = args.weight_q.data.as_raw().0 as *const c_void;
367        let act_q_ptr = args.act_q.data.as_raw().0 as *const c_void;
368        let act_scale_const = args.act_scale_scratch.data.as_raw().0 as *const c_void;
369        let w_scale_ptr = args.weight_scale.data.as_raw().0 as *const c_void;
370        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
371        let ql_status = match TIn::KIND {
372            ElementKind::F32 => unsafe {
373                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f32_run(
374                    self.desc.m,
375                    self.desc.n,
376                    self.desc.k,
377                    weight_ptr,
378                    act_q_ptr,
379                    act_scale_const,
380                    w_scale_ptr,
381                    out_ptr,
382                    core::ptr::null_mut(),
383                    0,
384                    stream_ptr,
385                )
386            },
387            ElementKind::F64 => unsafe {
388                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f64_run(
389                    self.desc.m,
390                    self.desc.n,
391                    self.desc.k,
392                    weight_ptr,
393                    act_q_ptr,
394                    act_scale_const,
395                    w_scale_ptr,
396                    out_ptr,
397                    core::ptr::null_mut(),
398                    0,
399                    stream_ptr,
400                )
401            },
402            _ => {
403                return Err(Error::Unsupported(
404                    "SmoothQuantLinearPlan::run reached unsupported TIn at \
405                     quantized-linear pass (select should have caught)",
406                ))
407            }
408        };
409        map_status(ql_status)
410    }
411}