baracuda_kernels/quantize/
dequantize_per_token_backward.rs1use core::ffi::c_void;
13use core::marker::PhantomData;
14
15use baracuda_cutlass::{Error, Result};
16use baracuda_driver::Stream;
17use baracuda_kernels_types::{
18 Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
19 TensorMut, TensorRef, Workspace,
20};
21
22use super::map_status;
23use super::per_token::build_sku;
24use super::validate_input_element;
25
/// Problem extents for the per-token dequantize backward kernel.
///
/// The plan checks the gradient tensors against `[n, d]` and the scale
/// tensor against `[n]`; both extents must be non-negative for plan
/// selection to succeed.
#[derive(Clone, Copy, Debug)]
pub struct DequantizePerTokenBackwardDescriptor {
    /// Leading extent: rows of the gradient tensors and length of `scale`.
    /// Presumably the token count, going by the kernel name.
    pub n: i32,
    /// Trailing extent: columns of the gradient tensors.
    pub d: i32,
}
34
35pub struct DequantizePerTokenBackwardArgs<'a, TIn: Element, TOut: IntElement> {
37 pub scale: TensorRef<'a, TIn, 1>,
39 pub d_output: TensorRef<'a, TIn, 2>,
41 pub d_input: TensorMut<'a, TIn, 2>,
44 pub _phantom: PhantomData<TOut>,
48}
49
50pub struct DequantizePerTokenBackwardPlan<TIn: Element, TOut: IntElement> {
67 desc: DequantizePerTokenBackwardDescriptor,
68 sku: KernelSku,
69 _marker: PhantomData<(TIn, TOut)>,
70}
71
72impl<TIn: Element, TOut: IntElement> DequantizePerTokenBackwardPlan<TIn, TOut> {
73 pub fn select(
75 _stream: &Stream,
76 desc: &DequantizePerTokenBackwardDescriptor,
77 _pref: PlanPreference,
78 ) -> Result<Self> {
79 validate_input_element(
80 TIn::KIND,
81 "DequantizePerTokenBackwardPlan: unsupported TIn dtype",
82 )?;
83 if !matches!(TOut::KIND, ElementKind::S8 | ElementKind::U8) {
84 return Err(Error::Unsupported(
85 "DequantizePerTokenBackwardPlan: TOut must be S8 or U8",
86 ));
87 }
88 if desc.n < 0 || desc.d < 0 {
89 return Err(Error::InvalidProblem(
90 "DequantizePerTokenBackwardPlan: n and d must be non-negative",
91 ));
92 }
93 let sku = build_sku::<TIn, TOut>(QuantizeKind::DequantizePerTokenBackward);
94 Ok(Self {
95 desc: *desc,
96 sku,
97 _marker: PhantomData,
98 })
99 }
100
101 pub fn can_implement(
103 &self,
104 args: &DequantizePerTokenBackwardArgs<'_, TIn, TOut>,
105 ) -> Result<()> {
106 let expect = [self.desc.n, self.desc.d];
107 if args.d_output.shape != expect || args.d_input.shape != expect {
108 return Err(Error::InvalidProblem(
109 "DequantizePerTokenBackwardPlan: tensor shape != [n, d]",
110 ));
111 }
112 if args.scale.shape != [self.desc.n] {
113 return Err(Error::InvalidProblem(
114 "DequantizePerTokenBackwardPlan: scale shape != [n]",
115 ));
116 }
117 Ok(())
118 }
119
120 #[inline]
122 pub fn workspace_size(&self) -> usize {
123 0
124 }
125
126 #[inline]
128 pub fn sku(&self) -> KernelSku {
129 self.sku
130 }
131
132 #[inline]
134 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
135 self.sku.precision_guarantee
136 }
137
138 pub fn run(
140 &self,
141 stream: &Stream,
142 _workspace: Workspace<'_>,
143 args: DequantizePerTokenBackwardArgs<'_, TIn, TOut>,
144 ) -> Result<()> {
145 self.can_implement(&args)?;
146 let total = (self.desc.n as i64) * (self.desc.d as i64);
147 if total == 0 {
148 return Ok(());
149 }
150 let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
151 let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
152 let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
153 let stream_ptr = stream.as_raw() as *mut c_void;
154
155 let status = match TIn::KIND {
156 ElementKind::F32 => unsafe {
157 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_f32_run(
158 self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
159 core::ptr::null_mut(), 0, stream_ptr,
160 )
161 },
162 ElementKind::F64 => unsafe {
163 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_f64_run(
164 self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
165 core::ptr::null_mut(), 0, stream_ptr,
166 )
167 },
168 ElementKind::F16 => unsafe {
169 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_f16_run(
170 self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
171 core::ptr::null_mut(), 0, stream_ptr,
172 )
173 },
174 ElementKind::Bf16 => unsafe {
175 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_backward_bf16_run(
176 self.desc.n, self.desc.d, dy_ptr, sc_ptr, dx_ptr,
177 core::ptr::null_mut(), 0, stream_ptr,
178 )
179 },
180 _ => {
181 return Err(Error::Unsupported(
182 "DequantizePerTokenBackwardPlan::run unsupported TIn dtype",
183 ))
184 }
185 };
186 map_status(status)
187 }
188}