Skip to main content

baracuda_kernels/quantize/
fake_quantize_backward.rs

//! `fake_quantize` backward plan via the straight-through estimator (STE).
//!
//! `dx = dy * 1[qmin <= round(x/scale)+zp <= qmax]`. **No `1/scale`
//! factor** — the forward (FW) pass's dequant-side multiply by `scale`
//! exactly cancels the STE's `1/scale`. This is the key difference from
//! [`super::QuantizePerTensorBackwardPlan`], which DOES include `1/scale`.
//!
//! The in-range mask is recomputed in the kernel from the saved input
//! `x`. Callers must retain `x` from the FW pass.
10
11use core::ffi::c_void;
12use core::marker::PhantomData;
13
14use baracuda_cutlass::{Error, Result};
15use baracuda_driver::Stream;
16use baracuda_kernels_types::{
17    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind, ScalarType,
18    TensorMut, TensorRef, Workspace,
19};
20
21use super::fake_quantize::build_sku;
22use super::{map_status, validate_input_element};
23
24/// Descriptor for a `fake_quantize` backward op.
25#[derive(Copy, Clone, Debug)]
26pub struct FakeQuantizeBackwardDescriptor {
27    /// Total element count.
28    pub numel: i32,
29    /// Lower clip bound from FW.
30    pub q_min: i32,
31    /// Upper clip bound from FW.
32    pub q_max: i32,
33    /// Input FP element kind.
34    pub input_element: ElementKind,
35}
36
37/// Args bundle for a `fake_quantize` backward launch.
38pub struct FakeQuantizeBackwardArgs<'a, TIn: Element> {
39    /// Saved FW input `[numel]` — required for mask recomputation.
40    pub input: TensorRef<'a, TIn, 1>,
41    /// Scalar scale (same value used in FW).
42    pub scale: <TIn as Element>::Scalar,
43    /// Scalar zero point (same value used in FW).
44    pub zero_point: i32,
45    /// Upstream gradient `[numel]` in FP.
46    pub d_output: TensorRef<'a, TIn, 1>,
47    /// Output `[numel]` in FP.
48    pub d_input: TensorMut<'a, TIn, 1>,
49}
50
51/// `fake_quantize` backward plan.
52///
53/// STE: `dx = dy * 1[qmin ≤ round(x/scale)+zp ≤ qmax]`. **No
54/// `1/scale` factor** — the FW dequant-side multiply by `scale`
55/// exactly cancels the STE's `1/scale`. This is the key difference
56/// vs [`QuantizePerTensorBackwardPlan`](crate::QuantizePerTensorBackwardPlan),
57/// which DOES include `1/scale`. Mask recomputed in-kernel.
58///
59/// **When to use**: backward for
60/// [`FakeQuantizePlan`](crate::FakeQuantizePlan).
61///
62/// **Dtypes**: `{f32, f64, f16, bf16}`.
63///
64/// **Shape limits**: flat `[numel]`.
65///
66/// **Workspace**: none.
67///
68/// **Precision guarantee**: deterministic, bit-stable.
69pub struct FakeQuantizeBackwardPlan<TIn: Element> {
70    desc: FakeQuantizeBackwardDescriptor,
71    sku: KernelSku,
72    _marker: PhantomData<TIn>,
73}
74
75impl<TIn: Element> FakeQuantizeBackwardPlan<TIn> {
76    /// Pick a kernel.
77    pub fn select(
78        _stream: &Stream,
79        desc: &FakeQuantizeBackwardDescriptor,
80        _pref: PlanPreference,
81    ) -> Result<Self> {
82        if desc.input_element != TIn::KIND {
83            return Err(Error::Unsupported(
84                "FakeQuantizeBackwardPlan: descriptor input_element != TIn",
85            ));
86        }
87        validate_input_element(TIn::KIND, "FakeQuantizeBackwardPlan: unsupported TIn dtype")?;
88        if desc.numel < 0 {
89            return Err(Error::InvalidProblem(
90                "FakeQuantizeBackwardPlan: numel must be non-negative",
91            ));
92        }
93        let sku = build_sku::<TIn>(QuantizeKind::FakeQuantizeBackward);
94        Ok(Self {
95            desc: *desc,
96            sku,
97            _marker: PhantomData,
98        })
99    }
100
101    /// Validate args.
102    pub fn can_implement(&self, args: &FakeQuantizeBackwardArgs<'_, TIn>) -> Result<()> {
103        let expected = [self.desc.numel];
104        if args.input.shape != expected
105            || args.d_output.shape != expected
106            || args.d_input.shape != expected
107        {
108            return Err(Error::InvalidProblem(
109                "FakeQuantizeBackwardPlan: tensor shape != [numel]",
110            ));
111        }
112        Ok(())
113    }
114
115    /// Workspace bytes.
116    #[inline]
117    pub fn workspace_size(&self) -> usize {
118        0
119    }
120
121    /// Identity.
122    #[inline]
123    pub fn sku(&self) -> KernelSku {
124        self.sku
125    }
126
127    /// Numerical guarantees.
128    #[inline]
129    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
130        self.sku.precision_guarantee
131    }
132
133    /// Launch.
134    pub fn run(
135        &self,
136        stream: &Stream,
137        _workspace: Workspace<'_>,
138        args: FakeQuantizeBackwardArgs<'_, TIn>,
139    ) -> Result<()> {
140        self.can_implement(&args)?;
141        let numel = self.desc.numel as i64;
142        if numel == 0 {
143            return Ok(());
144        }
145        let x_ptr = args.input.data.as_raw().0 as *const c_void;
146        let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
147        let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
148        let stream_ptr = stream.as_raw() as *mut c_void;
149        let zp = args.zero_point;
150        let qmin = self.desc.q_min;
151        let qmax = self.desc.q_max;
152
153        let status = if <TIn::Scalar as ScalarType>::IS_F64 {
154            let scale_f64 = args.scale.to_f64();
155            unsafe {
156                baracuda_kernels_sys::baracuda_kernels_fake_quantize_backward_f64_run(
157                    numel, scale_f64, zp, qmin, qmax,
158                    x_ptr, dy_ptr, dx_ptr,
159                    core::ptr::null_mut(), 0, stream_ptr,
160                )
161            }
162        } else {
163            let scale_f32 = args.scale.to_f32();
164            match TIn::KIND {
165                ElementKind::F32 => unsafe {
166                    baracuda_kernels_sys::baracuda_kernels_fake_quantize_backward_f32_run(
167                        numel, scale_f32, zp, qmin, qmax,
168                        x_ptr, dy_ptr, dx_ptr,
169                        core::ptr::null_mut(), 0, stream_ptr,
170                    )
171                },
172                ElementKind::F16 => unsafe {
173                    baracuda_kernels_sys::baracuda_kernels_fake_quantize_backward_f16_run(
174                        numel, scale_f32, zp, qmin, qmax,
175                        x_ptr, dy_ptr, dx_ptr,
176                        core::ptr::null_mut(), 0, stream_ptr,
177                    )
178                },
179                ElementKind::Bf16 => unsafe {
180                    baracuda_kernels_sys::baracuda_kernels_fake_quantize_backward_bf16_run(
181                        numel, scale_f32, zp, qmin, qmax,
182                        x_ptr, dy_ptr, dx_ptr,
183                        core::ptr::null_mut(), 0, stream_ptr,
184                    )
185                },
186                _ => return Err(Error::Unsupported(
187                    "FakeQuantizeBackwardPlan: unsupported TIn at run()",
188                )),
189            }
190        };
191        map_status(status)
192    }
193}