Skip to main content

baracuda_kernels/quantize/
fake_quantize.rs

1//! `fake_quantize` forward plan — per-tensor, FP roundtrip.
2//!
3//! `y = scale * (clamp(round(x / scale) + zp, q_min, q_max) - zp)`. The
4//! roundtrip of `quantize` followed by `dequantize`, executed entirely
5//! in FP — produces a lossy FP output of the same dtype as the input.
//! No integer storage is involved. Mirrors PyTorch's
//! `torch.fake_quantize_per_tensor_affine`.
8//!
//! The descriptor carries the integer range (`q_min` / `q_max`) but no
//! output dtype: together with `scale`, the integer range defines the lossy
//! grid the output is snapped to, while the output keeps the input's dtype.
//! Callers pick the integer range matching their downstream `qint` storage
//! (`[-128, 127]` for s8, `[0, 255]` for u8); no `TOut` plan parameter is
//! needed.
14
15use core::ffi::c_void;
16use core::marker::PhantomData;
17
18use baracuda_cutlass::{Error, Result};
19use baracuda_driver::Stream;
20use baracuda_kernels_types::{
21    ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
22    PlanPreference, PrecisionGuarantee, QuantizeKind, ScalarType, TensorMut, TensorRef, Workspace,
23};
24
25use super::{map_status, validate_input_element};
26
27/// Descriptor for a `fake_quantize` forward op.
28#[derive(Copy, Clone, Debug)]
29pub struct FakeQuantizeDescriptor {
30    /// Total element count.
31    pub numel: i32,
32    /// Lower clip bound.
33    pub q_min: i32,
34    /// Upper clip bound.
35    pub q_max: i32,
36    /// Input FP element kind.
37    pub input_element: ElementKind,
38}
39
40/// Args bundle for a `fake_quantize` forward launch.
41pub struct FakeQuantizeArgs<'a, TIn: Element> {
42    /// Input FP tensor `[numel]`.
43    pub input: TensorRef<'a, TIn, 1>,
44    /// Scalar scale (FP).
45    pub scale: <TIn as Element>::Scalar,
46    /// Scalar zero point.
47    pub zero_point: i32,
48    /// Output FP tensor `[numel]` — same dtype as input.
49    pub output: TensorMut<'a, TIn, 1>,
50}
51
52/// `fake_quantize` forward plan.
53///
54/// `y = scale * (clamp(round(x / scale) + zp, q_min, q_max) - zp)`.
55/// The FP-only roundtrip of quantize-then-dequantize, no integer
56/// storage (PyTorch `torch.fake_quantize_per_tensor_affine`).
57///
58/// **When to use**: QAT (quantization-aware training) — produces a
59/// lossy FP output of the same dtype as the input. Pair with
60/// [`FakeQuantizeBackwardPlan`](crate::FakeQuantizeBackwardPlan) for
61/// STE autograd.
62///
63/// **Dtypes**: `{f32, f64, f16, bf16}` in and out (same dtype).
64///
65/// **Shape limits**: flat `[numel]`.
66///
67/// **Workspace**: none.
68///
69/// **Precision guarantee**: deterministic, bit-stable. Round-ties-
70/// even matches FW quantize.
71pub struct FakeQuantizePlan<TIn: Element> {
72    desc: FakeQuantizeDescriptor,
73    sku: KernelSku,
74    _marker: PhantomData<TIn>,
75}
76
77impl<TIn: Element> FakeQuantizePlan<TIn> {
78    /// Pick a kernel.
79    pub fn select(
80        _stream: &Stream,
81        desc: &FakeQuantizeDescriptor,
82        _pref: PlanPreference,
83    ) -> Result<Self> {
84        if desc.input_element != TIn::KIND {
85            return Err(Error::Unsupported(
86                "FakeQuantizePlan: descriptor input_element != TIn",
87            ));
88        }
89        validate_input_element(TIn::KIND, "FakeQuantizePlan: unsupported TIn dtype")?;
90        if desc.numel < 0 {
91            return Err(Error::InvalidProblem(
92                "FakeQuantizePlan: numel must be non-negative",
93            ));
94        }
95        if desc.q_max < desc.q_min {
96            return Err(Error::InvalidProblem("FakeQuantizePlan: q_max < q_min"));
97        }
98        let sku = build_sku::<TIn>(QuantizeKind::FakeQuantize);
99        Ok(Self {
100            desc: *desc,
101            sku,
102            _marker: PhantomData,
103        })
104    }
105
106    /// Validate args.
107    pub fn can_implement(&self, args: &FakeQuantizeArgs<'_, TIn>) -> Result<()> {
108        let expected = [self.desc.numel];
109        if args.input.shape != expected || args.output.shape != expected {
110            return Err(Error::InvalidProblem(
111                "FakeQuantizePlan: tensor shape != [numel]",
112            ));
113        }
114        Ok(())
115    }
116
117    /// Workspace bytes.
118    #[inline]
119    pub fn workspace_size(&self) -> usize {
120        0
121    }
122
123    /// Identity.
124    #[inline]
125    pub fn sku(&self) -> KernelSku {
126        self.sku
127    }
128
129    /// Numerical guarantees.
130    #[inline]
131    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
132        self.sku.precision_guarantee
133    }
134
135    /// Launch.
136    pub fn run(
137        &self,
138        stream: &Stream,
139        _workspace: Workspace<'_>,
140        args: FakeQuantizeArgs<'_, TIn>,
141    ) -> Result<()> {
142        self.can_implement(&args)?;
143        let numel = self.desc.numel as i64;
144        if numel == 0 {
145            return Ok(());
146        }
147        let x_ptr = args.input.data.as_raw().0 as *const c_void;
148        let y_ptr = args.output.data.as_raw().0 as *mut c_void;
149        let stream_ptr = stream.as_raw() as *mut c_void;
150        let zp = args.zero_point;
151        let qmin = self.desc.q_min;
152        let qmax = self.desc.q_max;
153
154        let status = if <TIn::Scalar as ScalarType>::IS_F64 {
155            let scale_f64 = args.scale.to_f64();
156            unsafe {
157                baracuda_kernels_sys::baracuda_kernels_fake_quantize_f64_run(
158                    numel, scale_f64, zp, qmin, qmax, x_ptr, y_ptr,
159                    core::ptr::null_mut(), 0, stream_ptr,
160                )
161            }
162        } else {
163            let scale_f32 = args.scale.to_f32();
164            match TIn::KIND {
165                ElementKind::F32 => unsafe {
166                    baracuda_kernels_sys::baracuda_kernels_fake_quantize_f32_run(
167                        numel, scale_f32, zp, qmin, qmax, x_ptr, y_ptr,
168                        core::ptr::null_mut(), 0, stream_ptr,
169                    )
170                },
171                ElementKind::F16 => unsafe {
172                    baracuda_kernels_sys::baracuda_kernels_fake_quantize_f16_run(
173                        numel, scale_f32, zp, qmin, qmax, x_ptr, y_ptr,
174                        core::ptr::null_mut(), 0, stream_ptr,
175                    )
176                },
177                ElementKind::Bf16 => unsafe {
178                    baracuda_kernels_sys::baracuda_kernels_fake_quantize_bf16_run(
179                        numel, scale_f32, zp, qmin, qmax, x_ptr, y_ptr,
180                        core::ptr::null_mut(), 0, stream_ptr,
181                    )
182                },
183                _ => return Err(Error::Unsupported(
184                    "FakeQuantizePlan: unsupported TIn at run()",
185                )),
186            }
187        };
188        map_status(status)
189    }
190}
191
192/// Build the [`KernelSku`] for a fake-quantize-family plan. Sibling of
193/// [`super::per_tensor::build_sku`]; no TOut surfaces in the SKU because
194/// fake_quantize stays in FP space.
195pub(crate) fn build_sku<TIn: Element>(op: QuantizeKind) -> KernelSku {
196    let precision_guarantee = PrecisionGuarantee {
197        math_precision: if TIn::KIND == ElementKind::F64 {
198            MathPrecision::F64
199        } else {
200            MathPrecision::F32
201        },
202        accumulator: ElementKind::F32,
203        bit_stable_on_same_hardware: true,
204        deterministic: true,
205    };
206    KernelSku {
207        category: OpCategory::Quantization,
208        op: op as u16,
209        element: TIn::KIND,
210        aux_element: None,
211        layout: None,
212        epilogue: None,
213        arch: ArchSku::Sm80,
214        backend: BackendKind::Bespoke,
215        precision_guarantee,
216    }
217}