Skip to main content

baracuda_kernels/quantize/
per_group.rs

//! `quantize_per_group` forward plan.
//!
//! Per-group quantization: input is `[outer, axis_size]` (higher-rank
//! tensors are flattened by the caller before reaching this plan), and
//! the quantization axis (the rightmost dim) is partitioned into
//! contiguous groups of size `group_size`. Each group gets its own
//! `(scale, zp)`, so `scale` and `zero_point` have shape
//! `[outer, num_groups]` where `num_groups = axis_size / group_size`.
//!
//! Used by INT4 LLM weight quantization (GPTQ / AWQ / GGML), typically
//! with `group_size = 128`.
//!
//! Trailblazer scope: quant axis must be the **rightmost** axis so the
//! layout is naturally group-contiguous. Higher-rank tensors with a
//! non-last quant axis require a permute first (caller's
//! responsibility).
17
18use core::ffi::c_void;
19use core::marker::PhantomData;
20
21use baracuda_cutlass::{Error, Result};
22use baracuda_driver::Stream;
23use baracuda_kernels_types::{
24    ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
25    PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace,
26};
27
28use super::{map_status, validate_input_element, validate_output_element};
29
30/// Descriptor for a `quantize_per_group` forward op.
31#[derive(Copy, Clone, Debug)]
32pub struct QuantizePerGroupDescriptor {
33    /// Product of all dims except the quant axis (the flattened
34    /// non-quant prefix).
35    pub outer_size: i32,
36    /// Length of the quant axis. Must be `>= 0` and divisible by
37    /// `group_size`.
38    pub axis_size: i32,
39    /// Group size — number of consecutive elements along the quant
40    /// axis that share a `(scale, zp)` pair. Typical: `128` for GPTQ
41    /// INT4 weights.
42    pub group_size: i32,
43    /// Quantization range lower bound.
44    pub q_min: i32,
45    /// Quantization range upper bound.
46    pub q_max: i32,
47    /// Input FP element kind.
48    pub input_element: ElementKind,
49    /// Output int element kind.
50    pub output_element: ElementKind,
51}
52
53impl QuantizePerGroupDescriptor {
54    /// Number of groups along the quant axis. Equals
55    /// `axis_size / group_size` (validated `axis_size % group_size == 0`).
56    #[inline]
57    pub fn num_groups(&self) -> i32 {
58        if self.group_size <= 0 {
59            0
60        } else {
61            self.axis_size / self.group_size
62        }
63    }
64}
65
66/// Args bundle for a `quantize_per_group` forward launch.
67pub struct QuantizePerGroupArgs<'a, TIn: Element, TOut: IntElement> {
68    /// Input `[outer_size, axis_size]` in FP.
69    pub input: TensorRef<'a, TIn, 2>,
70    /// Per-group scale `[outer_size, num_groups]` in FP.
71    pub scale: TensorRef<'a, TIn, 2>,
72    /// Per-group zero-point `[outer_size, num_groups]` in i32.
73    pub zero_point: TensorRef<'a, i32, 2>,
74    /// Output `[outer_size, axis_size]` in int.
75    pub output: TensorMut<'a, TOut, 2>,
76}
77
78/// `quantize_per_group` forward plan.
79///
80/// Per-group quantization along the rightmost axis. Each contiguous
81/// group of `group_size` elements gets its own `(scale, zp)` pair.
82///
83/// **When to use**: INT4 LLM weight quantization (GPTQ / AWQ / GGML),
84/// typically `group_size = 128`. Pair with
85/// [`QuantizePerGroupBackwardPlan`](crate::QuantizePerGroupBackwardPlan)
86/// for STE. For per-channel quant use
87/// [`QuantizePerChannelPlan`](crate::QuantizePerChannelPlan).
88///
89/// **Dtypes**: input FP `{f32, f64, f16, bf16}` × output int
90/// `{s8, u8}`. `scale` is input dtype; `zero_point` is `i32`.
91///
92/// **Shape limits**: rank-2 `[outer_size, axis_size]` (caller
93/// flattens higher-rank inputs); `axis_size % group_size == 0`;
94/// `group_size > 0`; `scale` and `zero_point` are
95/// `[outer_size, num_groups]`. Quant axis must be the **rightmost**
96/// axis (a permute is the caller's responsibility otherwise).
97/// `q_max ≥ q_min`.
98///
99/// **Workspace**: none.
100///
101/// **Precision guarantee**: deterministic, bit-stable. Round-ties-
102/// even.
103pub struct QuantizePerGroupPlan<TIn: Element, TOut: IntElement> {
104    desc: QuantizePerGroupDescriptor,
105    sku: KernelSku,
106    _marker: PhantomData<(TIn, TOut)>,
107}
108
109impl<TIn: Element, TOut: IntElement> QuantizePerGroupPlan<TIn, TOut> {
110    /// Pick a kernel for `desc`.
111    pub fn select(
112        _stream: &Stream,
113        desc: &QuantizePerGroupDescriptor,
114        _pref: PlanPreference,
115    ) -> Result<Self> {
116        if desc.input_element != TIn::KIND {
117            return Err(Error::Unsupported(
118                "QuantizePerGroupPlan: descriptor input_element != TIn",
119            ));
120        }
121        if desc.output_element != TOut::KIND {
122            return Err(Error::Unsupported(
123                "QuantizePerGroupPlan: descriptor output_element != TOut",
124            ));
125        }
126        validate_input_element(TIn::KIND, "QuantizePerGroupPlan: unsupported TIn dtype")?;
127        validate_output_element(TOut::KIND, "QuantizePerGroupPlan: unsupported TOut dtype")?;
128        if desc.outer_size < 0 || desc.axis_size < 0 {
129            return Err(Error::InvalidProblem(
130                "QuantizePerGroupPlan: outer_size and axis_size must be non-negative",
131            ));
132        }
133        if desc.group_size <= 0 {
134            return Err(Error::InvalidProblem(
135                "QuantizePerGroupPlan: group_size must be > 0",
136            ));
137        }
138        if desc.axis_size % desc.group_size != 0 {
139            return Err(Error::InvalidProblem(
140                "QuantizePerGroupPlan: axis_size must be a multiple of group_size",
141            ));
142        }
143        if desc.q_max < desc.q_min {
144            return Err(Error::InvalidProblem(
145                "QuantizePerGroupPlan: q_max < q_min",
146            ));
147        }
148        let sku = build_sku_group::<TIn, TOut>(QuantizeKind::PerGroup);
149        Ok(Self {
150            desc: *desc,
151            sku,
152            _marker: PhantomData,
153        })
154    }
155
156    /// Validate args.
157    pub fn can_implement(&self, args: &QuantizePerGroupArgs<'_, TIn, TOut>) -> Result<()> {
158        let expect_io = [self.desc.outer_size, self.desc.axis_size];
159        if args.input.shape != expect_io || args.output.shape != expect_io {
160            return Err(Error::InvalidProblem(
161                "QuantizePerGroupPlan: I/O tensor shape != [outer, axis_size]",
162            ));
163        }
164        let expect_sg = [self.desc.outer_size, self.desc.num_groups()];
165        if args.scale.shape != expect_sg || args.zero_point.shape != expect_sg {
166            return Err(Error::InvalidProblem(
167                "QuantizePerGroupPlan: scale / zp shape != [outer, num_groups]",
168            ));
169        }
170        Ok(())
171    }
172
173    /// Workspace bytes — none.
174    #[inline]
175    pub fn workspace_size(&self) -> usize {
176        0
177    }
178
179    /// Identity.
180    #[inline]
181    pub fn sku(&self) -> KernelSku {
182        self.sku
183    }
184
185    /// Numerical guarantees.
186    #[inline]
187    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
188        self.sku.precision_guarantee
189    }
190
191    /// Launch.
192    pub fn run(
193        &self,
194        stream: &Stream,
195        _workspace: Workspace<'_>,
196        args: QuantizePerGroupArgs<'_, TIn, TOut>,
197    ) -> Result<()> {
198        self.can_implement(&args)?;
199        let total = (self.desc.outer_size as i64) * (self.desc.axis_size as i64);
200        if total == 0 {
201            return Ok(());
202        }
203        let in_ptr = args.input.data.as_raw().0 as *const c_void;
204        let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
205        let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
206        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
207        let stream_ptr = stream.as_raw() as *mut c_void;
208        let (outer, axis, g, qmin, qmax) = (
209            self.desc.outer_size,
210            self.desc.axis_size,
211            self.desc.group_size,
212            self.desc.q_min,
213            self.desc.q_max,
214        );
215        let status = match (TIn::KIND, TOut::KIND) {
216            (ElementKind::F32, ElementKind::S8) => unsafe {
217                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_f32_s8_run(
218                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
219                    core::ptr::null_mut(), 0, stream_ptr,
220                )
221            },
222            (ElementKind::F32, ElementKind::U8) => unsafe {
223                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_f32_u8_run(
224                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
225                    core::ptr::null_mut(), 0, stream_ptr,
226                )
227            },
228            (ElementKind::F64, ElementKind::S8) => unsafe {
229                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_f64_s8_run(
230                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
231                    core::ptr::null_mut(), 0, stream_ptr,
232                )
233            },
234            (ElementKind::F64, ElementKind::U8) => unsafe {
235                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_f64_u8_run(
236                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
237                    core::ptr::null_mut(), 0, stream_ptr,
238                )
239            },
240            (ElementKind::F16, ElementKind::S8) => unsafe {
241                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_f16_s8_run(
242                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
243                    core::ptr::null_mut(), 0, stream_ptr,
244                )
245            },
246            (ElementKind::F16, ElementKind::U8) => unsafe {
247                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_f16_u8_run(
248                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
249                    core::ptr::null_mut(), 0, stream_ptr,
250                )
251            },
252            (ElementKind::Bf16, ElementKind::S8) => unsafe {
253                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_bf16_s8_run(
254                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
255                    core::ptr::null_mut(), 0, stream_ptr,
256                )
257            },
258            (ElementKind::Bf16, ElementKind::U8) => unsafe {
259                baracuda_kernels_sys::baracuda_kernels_quantize_per_group_bf16_u8_run(
260                    outer, axis, g, qmin, qmax, in_ptr, sc_ptr, zp_ptr, out_ptr,
261                    core::ptr::null_mut(), 0, stream_ptr,
262                )
263            },
264            _ => {
265                return Err(Error::Unsupported(
266                    "QuantizePerGroupPlan::run unsupported (TIn, TOut)",
267                ))
268            }
269        };
270        map_status(status)
271    }
272}
273
274/// Build the [`KernelSku`] for a quantize-per-group-family plan.
275pub(crate) fn build_sku_group<TIn: Element, TOut: IntElement>(op: QuantizeKind) -> KernelSku {
276    let precision_guarantee = PrecisionGuarantee {
277        math_precision: if TIn::KIND == ElementKind::F64 {
278            MathPrecision::F64
279        } else {
280            MathPrecision::F32
281        },
282        accumulator: ElementKind::F32,
283        bit_stable_on_same_hardware: true,
284        deterministic: true,
285    };
286    KernelSku {
287        category: OpCategory::Quantization,
288        op: op as u16,
289        element: TIn::KIND,
290        aux_element: Some(TOut::KIND),
291        layout: None,
292        epilogue: None,
293        arch: ArchSku::Sm80,
294        backend: BackendKind::Bespoke,
295        precision_guarantee,
296    }
297}