Skip to main content

baracuda_kernels/quantize/
per_channel.rs

1//! `quantize_per_channel` forward plan.
2//!
3//! Per-axis-slice quantization: `q[..., c, ...] = clamp(round(x[..., c, ...]
4//! / scale[c]) + zero_point[c], q_min, q_max)`. `scale` and `zero_point`
5//! are 1-D tensors of length `C` (the extent of the channel axis). The
6//! typical use is per-output-channel weight quantization for conv /
7//! linear layers. PyTorch `torch.quantize_per_channel`.
8//!
9//! Trailblazer layout: the input / output are **rank-4 contiguous** —
10//! the caller pads lower-rank tensors to 4 with leading or trailing
11//! extents of `1`. `axis` selects which of the 4 dims indexes
12//! `scale[]` / `zero_point[]`. Strided per-channel is deferred.
13
14use core::ffi::c_void;
15use core::marker::PhantomData;
16
17use baracuda_cutlass::{Error, Result};
18use baracuda_driver::Stream;
19use baracuda_kernels_types::{
20    Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
21    TensorMut, TensorRef, Workspace,
22};
23
24use super::map_status;
25use super::per_tensor::build_sku;
26use super::{validate_input_element, validate_output_element};
27
/// Highest tensor rank accepted by the per-channel quantize kernels.
///
/// Keep in sync with `MAX_RANK` in `kernels/include/baracuda_quantize.cuh`.
pub const MAX_RANK: usize = 4;
31
32/// Descriptor for a `quantize_per_channel` forward op.
33#[derive(Copy, Clone, Debug)]
34pub struct QuantizePerChannelDescriptor {
35    /// 4-D shape (caller pads rank with `1`'s).
36    pub shape: [i32; MAX_RANK],
37    /// Logical rank (used for validation; the kernel always sees rank 4).
38    pub rank: u8,
39    /// Axis index in `[0, 4)` that indexes the per-channel `scale[]` /
40    /// `zero_point[]` vectors.
41    pub axis: u8,
42    /// Quantization range lower bound.
43    pub q_min: i32,
44    /// Quantization range upper bound.
45    pub q_max: i32,
46    /// Input FP element kind.
47    pub input_element: ElementKind,
48    /// Output int element kind.
49    pub output_element: ElementKind,
50}
51
52/// Args bundle for a `quantize_per_channel` forward launch.
53pub struct QuantizePerChannelArgs<'a, TIn: Element, TOut: IntElement> {
54    /// Input `[D0, D1, D2, D3]` in FP.
55    pub input: TensorRef<'a, TIn, 4>,
56    /// Per-channel scale `[C]` in FP, where `C = shape[axis]`.
57    pub scale: TensorRef<'a, TIn, 1>,
58    /// Per-channel zero point `[C]` in i32.
59    pub zero_point: TensorRef<'a, i32, 1>,
60    /// Output `[D0, D1, D2, D3]` in int.
61    pub output: TensorMut<'a, TOut, 4>,
62}
63
64/// `quantize_per_channel` forward plan.
65///
66/// `q[..., c, ...] = clamp(round(x[..., c, ...] / scale[c]) + zero_point[c], q_min, q_max)`.
67/// Per-axis quantization (PyTorch `torch.quantize_per_channel`).
68///
69/// **When to use**: post-training quantization of conv / linear
70/// weights along the output-channel axis. For activations use
71/// [`QuantizePerTokenPlan`](crate::QuantizePerTokenPlan); for whole-
72/// tensor scale use [`QuantizePerTensorPlan`](crate::QuantizePerTensorPlan).
73/// Pair with [`QuantizePerChannelBackwardPlan`](crate::QuantizePerChannelBackwardPlan)
74/// for STE.
75///
76/// **Dtypes**: input FP `{f32, f64, f16, bf16}` × output int
77/// `{s8, u8}`. `scale[]` is input dtype; `zero_point[]` is `i32`.
78/// Sub-byte (`s4` / `u4`) deferred.
79///
80/// **Shape limits**: rank-4 contiguous (caller pads lower-rank
81/// tensors with `1`'s); `axis ∈ [0, 4)`; per-channel vectors have
82/// length `shape[axis]`. `q_max ≥ q_min`. Strided per-channel is
83/// deferred.
84///
85/// **Workspace**: none.
86///
87/// **Precision guarantee**: deterministic, bit-stable on same
88/// hardware. Round-ties-even.
89pub struct QuantizePerChannelPlan<TIn: Element, TOut: IntElement> {
90    desc: QuantizePerChannelDescriptor,
91    sku: KernelSku,
92    _marker: PhantomData<(TIn, TOut)>,
93}
94
95impl<TIn: Element, TOut: IntElement> QuantizePerChannelPlan<TIn, TOut> {
96    /// Pick a kernel.
97    pub fn select(
98        _stream: &Stream,
99        desc: &QuantizePerChannelDescriptor,
100        _pref: PlanPreference,
101    ) -> Result<Self> {
102        if desc.input_element != TIn::KIND {
103            return Err(Error::Unsupported(
104                "QuantizePerChannelPlan: descriptor input_element != TIn",
105            ));
106        }
107        if desc.output_element != TOut::KIND {
108            return Err(Error::Unsupported(
109                "QuantizePerChannelPlan: descriptor output_element != TOut",
110            ));
111        }
112        validate_input_element(TIn::KIND, "QuantizePerChannelPlan: unsupported TIn dtype")?;
113        validate_output_element(TOut::KIND, "QuantizePerChannelPlan: unsupported TOut dtype")?;
114        if (desc.axis as usize) >= MAX_RANK {
115            return Err(Error::InvalidProblem(
116                "QuantizePerChannelPlan: axis out of range [0, MAX_RANK)",
117            ));
118        }
119        if (desc.rank as usize) == 0 || (desc.rank as usize) > MAX_RANK {
120            return Err(Error::InvalidProblem(
121                "QuantizePerChannelPlan: rank must be in [1, MAX_RANK]",
122            ));
123        }
124        for &d in desc.shape.iter() {
125            if d < 0 {
126                return Err(Error::InvalidProblem(
127                    "QuantizePerChannelPlan: shape dims must be non-negative",
128                ));
129            }
130        }
131        if desc.q_max < desc.q_min {
132            return Err(Error::InvalidProblem("QuantizePerChannelPlan: q_max < q_min"));
133        }
134        let sku = build_sku::<TIn, TOut>(QuantizeKind::PerChannel);
135        Ok(Self {
136            desc: *desc,
137            sku,
138            _marker: PhantomData,
139        })
140    }
141
142    /// Validate args.
143    pub fn can_implement(&self, args: &QuantizePerChannelArgs<'_, TIn, TOut>) -> Result<()> {
144        if args.input.shape != self.desc.shape {
145            return Err(Error::InvalidProblem(
146                "QuantizePerChannelPlan: input shape mismatch with descriptor",
147            ));
148        }
149        if args.output.shape != self.desc.shape {
150            return Err(Error::InvalidProblem(
151                "QuantizePerChannelPlan: output shape mismatch with descriptor",
152            ));
153        }
154        let c = self.desc.shape[self.desc.axis as usize];
155        if args.scale.shape != [c] {
156            return Err(Error::InvalidProblem(
157                "QuantizePerChannelPlan: scale shape != [shape[axis]]",
158            ));
159        }
160        if args.zero_point.shape != [c] {
161            return Err(Error::InvalidProblem(
162                "QuantizePerChannelPlan: zero_point shape != [shape[axis]]",
163            ));
164        }
165        Ok(())
166    }
167
168    /// Workspace bytes.
169    #[inline]
170    pub fn workspace_size(&self) -> usize {
171        0
172    }
173
174    /// Identity.
175    #[inline]
176    pub fn sku(&self) -> KernelSku {
177        self.sku
178    }
179
180    /// Numerical guarantees.
181    #[inline]
182    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
183        self.sku.precision_guarantee
184    }
185
186    /// Launch.
187    pub fn run(
188        &self,
189        stream: &Stream,
190        _workspace: Workspace<'_>,
191        args: QuantizePerChannelArgs<'_, TIn, TOut>,
192    ) -> Result<()> {
193        self.can_implement(&args)?;
194        let numel = args.output.numel();
195        if numel == 0 {
196            return Ok(());
197        }
198        let x_ptr = args.input.data.as_raw().0 as *const c_void;
199        let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
200        let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
201        let q_ptr = args.output.data.as_raw().0 as *mut c_void;
202        let stream_ptr = stream.as_raw() as *mut c_void;
203        let shape4 = self.desc.shape.as_ptr();
204        let axis = self.desc.axis as i32;
205        let qmin = self.desc.q_min;
206        let qmax = self.desc.q_max;
207
208        let status = match (TIn::KIND, TOut::KIND) {
209            (ElementKind::F32, ElementKind::S8) => unsafe {
210                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_f32_s8_run(
211                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
212                    core::ptr::null_mut(), 0, stream_ptr,
213                )
214            },
215            (ElementKind::F32, ElementKind::U8) => unsafe {
216                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_f32_u8_run(
217                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
218                    core::ptr::null_mut(), 0, stream_ptr,
219                )
220            },
221            (ElementKind::F16, ElementKind::S8) => unsafe {
222                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_f16_s8_run(
223                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
224                    core::ptr::null_mut(), 0, stream_ptr,
225                )
226            },
227            (ElementKind::F16, ElementKind::U8) => unsafe {
228                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_f16_u8_run(
229                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
230                    core::ptr::null_mut(), 0, stream_ptr,
231                )
232            },
233            (ElementKind::Bf16, ElementKind::S8) => unsafe {
234                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_bf16_s8_run(
235                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
236                    core::ptr::null_mut(), 0, stream_ptr,
237                )
238            },
239            (ElementKind::Bf16, ElementKind::U8) => unsafe {
240                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_bf16_u8_run(
241                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
242                    core::ptr::null_mut(), 0, stream_ptr,
243                )
244            },
245            (ElementKind::F64, ElementKind::S8) => unsafe {
246                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_f64_s8_run(
247                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
248                    core::ptr::null_mut(), 0, stream_ptr,
249                )
250            },
251            (ElementKind::F64, ElementKind::U8) => unsafe {
252                baracuda_kernels_sys::baracuda_kernels_quantize_per_channel_f64_u8_run(
253                    numel, shape4, axis, qmin, qmax, x_ptr, sc_ptr, zp_ptr, q_ptr,
254                    core::ptr::null_mut(), 0, stream_ptr,
255                )
256            },
257            _ => return Err(Error::Unsupported(
258                "QuantizePerChannelPlan: unsupported (TIn, TOut) at run()",
259            )),
260        };
261        map_status(status)
262    }
263}