Skip to main content

baracuda_kernels/quantize/
dequantize_per_token_backward.rs

1//! `dequantize_per_token` backward plan — straight-through
2//! (`dq = dy * scale[n]`).
3//!
4//! Phase 8 Milestone 8.2. The plan is parameterized on both `TIn` (the
5//! FP element type the gradient flows in) and `TOut` (the int storage
6//! type the FW would have produced). `TOut` is unused by the BW kernel
7//! itself — the gradient continues in FP — but lives in the type
8//! signature for parity with [`super::DequantizePerTokenPlan`], so a
9//! caller can ascribe an autograd node's BW Plan with the same
10//! `(TIn, TOut)` tuple it used for FW.
11
12use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18    Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
19    TensorMut, TensorRef, Workspace,
20};
21
22use super::map_status;
23use super::per_token::build_sku;
24use super::validate_input_element;
25
/// Problem extents for a `dequantize_per_token` backward op.
///
/// The backward plan's `select` rejects negative values for either field. Besides
/// `Copy`/`Clone`/`Debug`, the descriptor implements `PartialEq`, `Eq` and `Hash`, so two
/// descriptors can be compared directly or one can serve as a key when caching plans.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct DequantizePerTokenBackwardDescriptor {
    /// Number of token rows (`N`); also the expected length of the per-row scale.
    pub n: i32,
    /// Feature dimension (`D`), i.e. the length of each row.
    pub d: i32,
}
34
35/// Args bundle for a `dequantize_per_token` backward launch.
36pub struct DequantizePerTokenBackwardArgs<'a, TIn: Element, TOut: IntElement> {
37    /// Per-row scale `[N]` in FP. Used to scale dy.
38    pub scale: TensorRef<'a, TIn, 1>,
39    /// Upstream gradient `[N, D]` in FP.
40    pub d_output: TensorRef<'a, TIn, 2>,
41    /// Output `[N, D]` in FP — same dtype as `d_output` (the q-input is
42    /// integer but the gradient continues in FP).
43    pub d_input: TensorMut<'a, TIn, 2>,
44    /// Phantom for the int output dtype carried by the plan type
45    /// parameter (needed so the plan can be parametric the same way the
46    /// sibling FW plan is).
47    pub _phantom: PhantomData<TOut>,
48}
49
50/// `dequantize_per_token` backward plan.
51///
52/// Straight-through linear:
53/// `dq_FP[n, d] = scale[n] * dy[n, d]`. Int input is non-differentiable;
54/// FP gradient flows through.
55///
56/// **When to use**: backward for
57/// [`DequantizePerTokenPlan`](crate::DequantizePerTokenPlan).
58///
59/// **Dtypes**: gradients in `{f32, f64, f16, bf16}`.
60///
61/// **Shape limits**: rank-2 `[N, D]`.
62///
63/// **Workspace**: none.
64///
65/// **Precision guarantee**: deterministic, bit-stable.
66pub struct DequantizePerTokenBackwardPlan<TIn: Element, TOut: IntElement> {
67    desc: DequantizePerTokenBackwardDescriptor,
68    sku: KernelSku,
69    _marker: PhantomData<(TIn, TOut)>,
70}
71
72impl<TIn: Element, TOut: IntElement> DequantizePerTokenBackwardPlan<TIn, TOut> {
73    /// Pick a kernel for `desc`.
74    pub fn select(
75        _stream: &Stream,
76        desc: &DequantizePerTokenBackwardDescriptor,
77        _pref: PlanPreference,
78    ) -> Result<Self> {
79        validate_input_element(
80            TIn::KIND,
81            "DequantizePerTokenBackwardPlan: unsupported TIn dtype",
82        )?;
83        if !matches!(TOut::KIND, ElementKind::S8 | ElementKind::U8) {
84            return Err(Error::Unsupported(
85                "DequantizePerTokenBackwardPlan: TOut must be S8 or U8",
86            ));
87        }
88        if desc.n < 0 || desc.d < 0 {
89            return Err(Error::InvalidProblem(
90                "DequantizePerTokenBackwardPlan: n and d must be non-negative",
91            ));
92        }
93        let sku = build_sku::<TIn, TOut>(QuantizeKind::DequantizePerTokenBackward);
94        Ok(Self {
95            desc: *desc,
96            sku,
97            _marker: PhantomData,
98        })
99    }
100
101    /// Validate args.
102    pub fn can_implement(
103        &self,
104        args: &DequantizePerTokenBackwardArgs<'_, TIn, TOut>,
105    ) -> Result<()> {
106        let expect = [self.desc.n, self.desc.d];
107        if args.d_output.shape != expect || args.d_input.shape != expect {
108            return Err(Error::InvalidProblem(
109                "DequantizePerTokenBackwardPlan: tensor shape != [n, d]",
110            ));
111        }
112        if args.scale.shape != [self.desc.n] {
113            return Err(Error::InvalidProblem(
114                "DequantizePerTokenBackwardPlan: scale shape != [n]",
115            ));
116        }
117        Ok(())
118    }
119
120    /// Workspace bytes — none.
121    #[inline]
122    pub fn workspace_size(&self) -> usize {
123        0
124    }
125
126    /// Identity.
127    #[inline]
128    pub fn sku(&self) -> KernelSku {
129        self.sku
130    }
131
132    /// Numerical guarantees.
133    #[inline]
134    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
135        self.sku.precision_guarantee
136    }
137
138    /// Launch.
139    pub fn run(
140        &self,
141        stream: &Stream,
142        _workspace: Workspace<'_>,
143        args: DequantizePerTokenBackwardArgs<'_, TIn, TOut>,
144    ) -> Result<()> {
145        self.can_implement(&args)?;
146        let total = (self.desc.n as i64) * (self.desc.d as i64);
147        if total == 0 {
148            return Ok(());
149        }
150        let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
151        let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
152        let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
153        let stream_ptr = stream.as_raw() as *mut c_void;
154
155        let status = match TIn::KIND {
156            ElementKind::F32 => unsafe {
157                baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_f32_run(
158                    self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
159                    core::ptr::null_mut(), 0, stream_ptr,
160                )
161            },
162            ElementKind::F64 => unsafe {
163                baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_f64_run(
164                    self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
165                    core::ptr::null_mut(), 0, stream_ptr,
166                )
167            },
168            ElementKind::F16 => unsafe {
169                baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_f16_run(
170                    self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
171                    core::ptr::null_mut(), 0, stream_ptr,
172                )
173            },
174            ElementKind::Bf16 => unsafe {
175                baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_bf16_run(
176                    self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
177                    core::ptr::null_mut(), 0, stream_ptr,
178                )
179            },
180            _ => {
181                return Err(Error::Unsupported(
182                    "DequantizePerTokenBackwardPlan::run unsupported TIn dtype",
183                ))
184            }
185        };
186        map_status(status)
187    }
188}