Skip to main content

baracuda_kernels/quantize/
per_token_backward.rs

1//! `quantize_per_token` backward plan (Straight-Through Estimator).
2//!
//! `dx[n, d] = (dy[n, d] / scale[n]) * 1[qmin ≤ round(x[n, d] / scale[n]) + zp[n] ≤ qmax]`.
4//! The in-range mask is recomputed in the kernel from the saved input
5//! tensor — no separate "mask" output from FW.
6
7use core::ffi::c_void;
8use core::marker::PhantomData;
9
10use baracuda_cutlass::{Error, Result};
11use baracuda_driver::Stream;
12use baracuda_kernels_types::{
13    Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut,
14    TensorRef, Workspace,
15};
16
17use super::map_status;
18use super::per_token::build_sku;
19use super::validate_input_element;
20
21/// Descriptor for a `quantize_per_token` backward op.
22#[derive(Copy, Clone, Debug)]
23pub struct QuantizePerTokenBackwardDescriptor {
24    /// Number of token rows.
25    pub n: i32,
26    /// Feature dim.
27    pub d: i32,
28    /// Lower clip bound (FW's qmin).
29    pub q_min: i32,
30    /// Upper clip bound (FW's qmax).
31    pub q_max: i32,
32    /// Input FP element kind.
33    pub input_element: ElementKind,
34}
35
36/// Args bundle for the per-token BW launch.
37pub struct QuantizePerTokenBackwardArgs<'a, TIn: Element> {
38    /// Upstream gradient `[N, D]`.
39    pub d_output: TensorRef<'a, TIn, 2>,
40    /// Saved input from FW (needed for the in-range mask) `[N, D]`.
41    pub input: TensorRef<'a, TIn, 2>,
42    /// Saved scale `[N]`.
43    pub scale: TensorRef<'a, TIn, 1>,
44    /// Saved zero-point `[N]`.
45    pub zero_point: TensorRef<'a, i32, 1>,
46    /// Output `dx` `[N, D]`.
47    pub d_input: TensorMut<'a, TIn, 2>,
48}
49
50/// `quantize_per_token` backward plan.
51///
52/// STE: `dx[n, d] = (dy[n, d] / scale[n]) * 1[qmin ≤ round(x[n,d]/scale[n])+zp[n] ≤ qmax]`.
53/// Mask recomputed in-kernel.
54///
55/// **When to use**: backward for
56/// [`QuantizePerTokenPlan`](crate::QuantizePerTokenPlan). Caller
57/// retains FW input, scale, zero_point.
58///
59/// **Dtypes**: gradients in `{f32, f64, f16, bf16}`; no int output —
60/// hence the single-type-parameter signature.
61///
62/// **Shape limits**: rank-2 `[N, D]`; per-row `scale` and `zp` of
63/// length `N`.
64///
65/// **Workspace**: none.
66///
67/// **Precision guarantee**: deterministic, bit-stable.
68pub struct QuantizePerTokenBackwardPlan<TIn: Element> {
69    desc: QuantizePerTokenBackwardDescriptor,
70    sku: KernelSku,
71    _marker: PhantomData<TIn>,
72}
73
74// Phantom 2nd type for the SKU build. We pick S8 as a stand-in for
75// `aux_element` slot — the BW kernel is TOut-agnostic (no integer
76// storage is touched), but the SKU expects a concrete tag.
77impl<TIn: Element> QuantizePerTokenBackwardPlan<TIn> {
78    /// Pick a kernel for `desc`.
79    pub fn select(
80        _stream: &Stream,
81        desc: &QuantizePerTokenBackwardDescriptor,
82        _pref: PlanPreference,
83    ) -> Result<Self> {
84        if desc.input_element != TIn::KIND {
85            return Err(Error::Unsupported(
86                "QuantizePerTokenBackwardPlan: descriptor input_element != type parameter TIn",
87            ));
88        }
89        validate_input_element(
90            TIn::KIND,
91            "QuantizePerTokenBackwardPlan: unsupported TIn dtype",
92        )?;
93        if desc.n < 0 || desc.d < 0 {
94            return Err(Error::InvalidProblem(
95                "QuantizePerTokenBackwardPlan: n and d must be non-negative",
96            ));
97        }
98        if desc.q_max < desc.q_min {
99            return Err(Error::InvalidProblem(
100                "QuantizePerTokenBackwardPlan: q_max < q_min",
101            ));
102        }
103        // SKU's aux_element slot reflects "the output int kind FW would
104        // have used". BW doesn't actually touch int storage, but
105        // selectors / telemetry treat the BW SKU as related to its FW
106        // peer — we publish S8 as the default.
107        let sku = build_sku::<TIn, baracuda_kernels_types::S8>(QuantizeKind::PerTokenBackward);
108        Ok(Self {
109            desc: *desc,
110            sku,
111            _marker: PhantomData,
112        })
113    }
114
115    /// Validate args.
116    pub fn can_implement(&self, args: &QuantizePerTokenBackwardArgs<'_, TIn>) -> Result<()> {
117        let expect = [self.desc.n, self.desc.d];
118        if args.d_output.shape != expect
119            || args.input.shape != expect
120            || args.d_input.shape != expect
121        {
122            return Err(Error::InvalidProblem(
123                "QuantizePerTokenBackwardPlan: tensor shape != [n, d]",
124            ));
125        }
126        if args.scale.shape != [self.desc.n] {
127            return Err(Error::InvalidProblem(
128                "QuantizePerTokenBackwardPlan: scale shape != [n]",
129            ));
130        }
131        if args.zero_point.shape != [self.desc.n] {
132            return Err(Error::InvalidProblem(
133                "QuantizePerTokenBackwardPlan: zero_point shape != [n]",
134            ));
135        }
136        Ok(())
137    }
138
139    /// Workspace bytes — none.
140    #[inline]
141    pub fn workspace_size(&self) -> usize {
142        0
143    }
144
145    /// Identity of the kernel.
146    #[inline]
147    pub fn sku(&self) -> KernelSku {
148        self.sku
149    }
150
151    /// Numerical guarantees.
152    #[inline]
153    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
154        self.sku.precision_guarantee
155    }
156
157    /// Launch.
158    pub fn run(
159        &self,
160        stream: &Stream,
161        _workspace: Workspace<'_>,
162        args: QuantizePerTokenBackwardArgs<'_, TIn>,
163    ) -> Result<()> {
164        self.can_implement(&args)?;
165        let total = (self.desc.n as i64) * (self.desc.d as i64);
166        if total == 0 {
167            return Ok(());
168        }
169        let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
170        let x_ptr = args.input.data.as_raw().0 as *const c_void;
171        let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
172        let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
173        let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
174        let stream_ptr = stream.as_raw() as *mut c_void;
175
176        let status = match TIn::KIND {
177            ElementKind::F32 => unsafe {
178                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_f32_run(
179                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
180                    dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
181                    core::ptr::null_mut(), 0, stream_ptr,
182                )
183            },
184            ElementKind::F64 => unsafe {
185                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_f64_run(
186                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
187                    dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
188                    core::ptr::null_mut(), 0, stream_ptr,
189                )
190            },
191            ElementKind::F16 => unsafe {
192                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_f16_run(
193                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
194                    dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
195                    core::ptr::null_mut(), 0, stream_ptr,
196                )
197            },
198            ElementKind::Bf16 => unsafe {
199                baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_bf16_run(
200                    self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
201                    dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
202                    core::ptr::null_mut(), 0, stream_ptr,
203                )
204            },
205            _ => {
206                return Err(Error::Unsupported(
207                    "QuantizePerTokenBackwardPlan::run reached unsupported TIn dtype",
208                ))
209            }
210        };
211        map_status(status)
212    }
213}