Skip to main content

baracuda_kernels/quantize/
per_token.rs

//! `quantize_per_token` forward plan.
//!
//! Per-row quantization for 2-D activations: the input is `[N, D]` and each
//! token row has its own `(scale, zero_point)` pair. Used for W8A8 LLM
//! activation quantization at inference time; the caller computes
//! `scale[n]` from each row's max-abs dynamic range.
//!
//! Forward: `q[n, d] = clamp(round(x[n, d] / scale[n]) + zero_point[n],
//!                           q_min, q_max)`.
//!
//! Backward: see [`crate::quantize::QuantizePerTokenBackwardPlan`].
12
13use core::ffi::c_void;
14use core::marker::PhantomData;
15
16use baracuda_cutlass::{Error, Result};
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19    ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
20    PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace,
21};
22
23use super::{map_status, validate_input_element, validate_output_element};
24
25/// Descriptor for a `quantize_per_token` forward op.
26#[derive(Copy, Clone, Debug)]
27pub struct QuantizePerTokenDescriptor {
28    /// Number of token rows (first axis of input/output).
29    pub n: i32,
30    /// Feature dim (second axis of input/output).
31    pub d: i32,
32    /// Quantization range lower bound (e.g. `-128` for s8 symmetric).
33    pub q_min: i32,
34    /// Quantization range upper bound (e.g. `127` for s8 symmetric).
35    pub q_max: i32,
36    /// Input FP element kind. Must match `TIn::KIND`.
37    pub input_element: ElementKind,
38    /// Output int element kind (s8 or u8). Must match `TOut::KIND`.
39    pub output_element: ElementKind,
40}
41
42/// Args bundle for a `quantize_per_token` forward launch.
43pub struct QuantizePerTokenArgs<'a, TIn: Element, TOut: IntElement> {
44    /// Input `[N, D]` in FP.
45    pub input: TensorRef<'a, TIn, 2>,
46    /// Per-row scale `[N]` in FP.
47    pub scale: TensorRef<'a, TIn, 1>,
48    /// Per-row zero-point `[N]` in i32.
49    pub zero_point: TensorRef<'a, i32, 1>,
50    /// Output `[N, D]` in int.
51    pub output: TensorMut<'a, TOut, 2>,
52}
53
54/// `quantize_per_token` forward plan.
55///
56/// `q[n, d] = clamp(round(x[n, d] / scale[n]) + zero_point[n], qmin, qmax)`.
57/// Per-row quantization for 2-D activations (W8A8 LLM-style).
58///
59/// **When to use**: forward activation quantization at inference (one
60/// `(scale, zp)` pair per token row, computed from the row's max-abs
61/// range upstream). For weight quantization use
62/// [`QuantizePerChannelPlan`](crate::QuantizePerChannelPlan); for
63/// global scale use [`QuantizePerTensorPlan`](crate::QuantizePerTensorPlan).
64/// Pair with [`QuantizePerTokenBackwardPlan`](crate::QuantizePerTokenBackwardPlan)
65/// for STE.
66///
67/// **Dtypes**: input FP `{f32, f64, f16, bf16}` × output int
68/// `{s8, u8}`. `scale[]` is input dtype; `zero_point[]` is `i32`.
69///
70/// **Shape limits**: rank-2 `[N, D]`; `scale` and `zero_point` are
71/// `[N]`. `q_max ≥ q_min`.
72///
73/// **Workspace**: none.
74///
75/// **Precision guarantee**: deterministic, bit-stable. One thread
76/// per output cell, no atomics. Round-ties-even.
77pub struct QuantizePerTokenPlan<TIn: Element, TOut: IntElement> {
78    desc: QuantizePerTokenDescriptor,
79    sku: KernelSku,
80    _marker: PhantomData<(TIn, TOut)>,
81}
82
83impl<TIn: Element, TOut: IntElement> QuantizePerTokenPlan<TIn, TOut> {
84    /// Pick a kernel for `desc`.
85    pub fn select(
86        _stream: &Stream,
87        desc: &QuantizePerTokenDescriptor,
88        _pref: PlanPreference,
89    ) -> Result<Self> {
90        if desc.input_element != TIn::KIND {
91            return Err(Error::Unsupported(
92                "QuantizePerTokenPlan: descriptor input_element != type parameter TIn",
93            ));
94        }
95        if desc.output_element != TOut::KIND {
96            return Err(Error::Unsupported(
97                "QuantizePerTokenPlan: descriptor output_element != type parameter TOut",
98            ));
99        }
100        validate_input_element(TIn::KIND, "QuantizePerTokenPlan: unsupported TIn dtype")?;
101        validate_output_element(TOut::KIND, "QuantizePerTokenPlan: unsupported TOut dtype")?;
102        if desc.n < 0 || desc.d < 0 {
103            return Err(Error::InvalidProblem(
104                "QuantizePerTokenPlan: n and d must be non-negative",
105            ));
106        }
107        if desc.q_max < desc.q_min {
108            return Err(Error::InvalidProblem(
109                "QuantizePerTokenPlan: q_max < q_min",
110            ));
111        }
112        let sku = build_sku::<TIn, TOut>(QuantizeKind::PerToken);
113        Ok(Self {
114            desc: *desc,
115            sku,
116            _marker: PhantomData,
117        })
118    }
119
120    /// Validate args at run time.
121    pub fn can_implement(&self, args: &QuantizePerTokenArgs<'_, TIn, TOut>) -> Result<()> {
122        if args.input.shape != [self.desc.n, self.desc.d] {
123            return Err(Error::InvalidProblem(
124                "QuantizePerTokenPlan: input shape != [n, d]",
125            ));
126        }
127        if args.output.shape != [self.desc.n, self.desc.d] {
128            return Err(Error::InvalidProblem(
129                "QuantizePerTokenPlan: output shape != [n, d]",
130            ));
131        }
132        if args.scale.shape != [self.desc.n] {
133            return Err(Error::InvalidProblem(
134                "QuantizePerTokenPlan: scale shape != [n]",
135            ));
136        }
137        if args.zero_point.shape != [self.desc.n] {
138            return Err(Error::InvalidProblem(
139                "QuantizePerTokenPlan: zero_point shape != [n]",
140            ));
141        }
142        Ok(())
143    }
144
145    /// Workspace bytes — none.
146    #[inline]
147    pub fn workspace_size(&self) -> usize {
148        0
149    }
150
151    /// Identity of the selected kernel.
152    #[inline]
153    pub fn sku(&self) -> KernelSku {
154        self.sku
155    }
156
157    /// Numerical guarantees.
158    #[inline]
159    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
160        self.sku.precision_guarantee
161    }
162
163    /// Launch.
164    pub fn run(
165        &self,
166        stream: &Stream,
167        _workspace: Workspace<'_>,
168        args: QuantizePerTokenArgs<'_, TIn, TOut>,
169    ) -> Result<()> {
170        self.can_implement(&args)?;
171        let total = (self.desc.n as i64) * (self.desc.d as i64);
172        if total == 0 {
173            return Ok(());
174        }
175        let in_ptr = args.input.data.as_raw().0 as *const c_void;
176        let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
177        let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
178        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
179        let stream_ptr = stream.as_raw() as *mut c_void;
180
181        let status = match (TIn::KIND, TOut::KIND) {
182            (ElementKind::F32, ElementKind::S8) => unsafe {
183                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f32_s8_run(
184                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
185                    in_ptr, sc_ptr, zp_ptr, out_ptr,
186                    core::ptr::null_mut(), 0, stream_ptr,
187                )
188            },
189            (ElementKind::F32, ElementKind::U8) => unsafe {
190                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f32_u8_run(
191                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
192                    in_ptr, sc_ptr, zp_ptr, out_ptr,
193                    core::ptr::null_mut(), 0, stream_ptr,
194                )
195            },
196            (ElementKind::F64, ElementKind::S8) => unsafe {
197                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f64_s8_run(
198                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
199                    in_ptr, sc_ptr, zp_ptr, out_ptr,
200                    core::ptr::null_mut(), 0, stream_ptr,
201                )
202            },
203            (ElementKind::F64, ElementKind::U8) => unsafe {
204                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f64_u8_run(
205                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
206                    in_ptr, sc_ptr, zp_ptr, out_ptr,
207                    core::ptr::null_mut(), 0, stream_ptr,
208                )
209            },
210            (ElementKind::F16, ElementKind::S8) => unsafe {
211                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f16_s8_run(
212                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
213                    in_ptr, sc_ptr, zp_ptr, out_ptr,
214                    core::ptr::null_mut(), 0, stream_ptr,
215                )
216            },
217            (ElementKind::F16, ElementKind::U8) => unsafe {
218                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f16_u8_run(
219                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
220                    in_ptr, sc_ptr, zp_ptr, out_ptr,
221                    core::ptr::null_mut(), 0, stream_ptr,
222                )
223            },
224            (ElementKind::Bf16, ElementKind::S8) => unsafe {
225                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_bf16_s8_run(
226                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
227                    in_ptr, sc_ptr, zp_ptr, out_ptr,
228                    core::ptr::null_mut(), 0, stream_ptr,
229                )
230            },
231            (ElementKind::Bf16, ElementKind::U8) => unsafe {
232                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_bf16_u8_run(
233                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
234                    in_ptr, sc_ptr, zp_ptr, out_ptr,
235                    core::ptr::null_mut(), 0, stream_ptr,
236                )
237            },
238            _ => {
239                return Err(Error::Unsupported(
240                    "QuantizePerTokenPlan::run reached unsupported (TIn, TOut) combination",
241                ))
242            }
243        };
244        map_status(status)
245    }
246}
247
248/// Build the [`KernelSku`] for a quantize-per-token-family plan.
249pub(crate) fn build_sku<TIn: Element, TOut: IntElement>(op: QuantizeKind) -> KernelSku {
250    let precision_guarantee = PrecisionGuarantee {
251        math_precision: if TIn::KIND == ElementKind::F64 {
252            MathPrecision::F64
253        } else {
254            MathPrecision::F32
255        },
256        accumulator: ElementKind::F32,
257        // Deterministic — one thread per output cell, no atomics.
258        bit_stable_on_same_hardware: true,
259        deterministic: true,
260    };
261    KernelSku {
262        category: OpCategory::Quantization,
263        op: op as u16,
264        element: TIn::KIND,
265        aux_element: Some(TOut::KIND),
266        layout: None,
267        epilogue: None,
268        arch: ArchSku::Sm80,
269        backend: BackendKind::Bespoke,
270        precision_guarantee,
271    }
272}