baracuda_kernels/quantize/
per_token_backward.rs1use core::ffi::c_void;
8use core::marker::PhantomData;
9
10use baracuda_cutlass::{Error, Result};
11use baracuda_driver::Stream;
12use baracuda_kernels_types::{
13 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut,
14 TensorRef, Workspace,
15};
16
17use super::map_status;
18use super::per_token::build_sku;
19use super::validate_input_element;
20
21#[derive(Copy, Clone, Debug)]
23pub struct QuantizePerTokenBackwardDescriptor {
24 pub n: i32,
26 pub d: i32,
28 pub q_min: i32,
30 pub q_max: i32,
32 pub input_element: ElementKind,
34}
35
36pub struct QuantizePerTokenBackwardArgs<'a, TIn: Element> {
38 pub d_output: TensorRef<'a, TIn, 2>,
40 pub input: TensorRef<'a, TIn, 2>,
42 pub scale: TensorRef<'a, TIn, 1>,
44 pub zero_point: TensorRef<'a, i32, 1>,
46 pub d_input: TensorMut<'a, TIn, 2>,
48}
49
50pub struct QuantizePerTokenBackwardPlan<TIn: Element> {
69 desc: QuantizePerTokenBackwardDescriptor,
70 sku: KernelSku,
71 _marker: PhantomData<TIn>,
72}
73
74impl<TIn: Element> QuantizePerTokenBackwardPlan<TIn> {
78 pub fn select(
80 _stream: &Stream,
81 desc: &QuantizePerTokenBackwardDescriptor,
82 _pref: PlanPreference,
83 ) -> Result<Self> {
84 if desc.input_element != TIn::KIND {
85 return Err(Error::Unsupported(
86 "QuantizePerTokenBackwardPlan: descriptor input_element != type parameter TIn",
87 ));
88 }
89 validate_input_element(
90 TIn::KIND,
91 "QuantizePerTokenBackwardPlan: unsupported TIn dtype",
92 )?;
93 if desc.n < 0 || desc.d < 0 {
94 return Err(Error::InvalidProblem(
95 "QuantizePerTokenBackwardPlan: n and d must be non-negative",
96 ));
97 }
98 if desc.q_max < desc.q_min {
99 return Err(Error::InvalidProblem(
100 "QuantizePerTokenBackwardPlan: q_max < q_min",
101 ));
102 }
103 let sku = build_sku::<TIn, baracuda_kernels_types::S8>(QuantizeKind::PerTokenBackward);
108 Ok(Self {
109 desc: *desc,
110 sku,
111 _marker: PhantomData,
112 })
113 }
114
115 pub fn can_implement(&self, args: &QuantizePerTokenBackwardArgs<'_, TIn>) -> Result<()> {
117 let expect = [self.desc.n, self.desc.d];
118 if args.d_output.shape != expect
119 || args.input.shape != expect
120 || args.d_input.shape != expect
121 {
122 return Err(Error::InvalidProblem(
123 "QuantizePerTokenBackwardPlan: tensor shape != [n, d]",
124 ));
125 }
126 if args.scale.shape != [self.desc.n] {
127 return Err(Error::InvalidProblem(
128 "QuantizePerTokenBackwardPlan: scale shape != [n]",
129 ));
130 }
131 if args.zero_point.shape != [self.desc.n] {
132 return Err(Error::InvalidProblem(
133 "QuantizePerTokenBackwardPlan: zero_point shape != [n]",
134 ));
135 }
136 Ok(())
137 }
138
139 #[inline]
141 pub fn workspace_size(&self) -> usize {
142 0
143 }
144
145 #[inline]
147 pub fn sku(&self) -> KernelSku {
148 self.sku
149 }
150
151 #[inline]
153 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
154 self.sku.precision_guarantee
155 }
156
157 pub fn run(
159 &self,
160 stream: &Stream,
161 _workspace: Workspace<'_>,
162 args: QuantizePerTokenBackwardArgs<'_, TIn>,
163 ) -> Result<()> {
164 self.can_implement(&args)?;
165 let total = (self.desc.n as i64) * (self.desc.d as i64);
166 if total == 0 {
167 return Ok(());
168 }
169 let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
170 let x_ptr = args.input.data.as_raw().0 as *const c_void;
171 let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
172 let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
173 let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
174 let stream_ptr = stream.as_raw() as *mut c_void;
175
176 let status = match TIn::KIND {
177 ElementKind::F32 => unsafe {
178 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_f32_run(
179 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
180 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
181 core::ptr::null_mut(), 0, stream_ptr,
182 )
183 },
184 ElementKind::F64 => unsafe {
185 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_f64_run(
186 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
187 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
188 core::ptr::null_mut(), 0, stream_ptr,
189 )
190 },
191 ElementKind::F16 => unsafe {
192 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_f16_run(
193 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
194 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
195 core::ptr::null_mut(), 0, stream_ptr,
196 )
197 },
198 ElementKind::Bf16 => unsafe {
199 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_backward_bf16_run(
200 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
201 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
202 core::ptr::null_mut(), 0, stream_ptr,
203 )
204 },
205 _ => {
206 return Err(Error::Unsupported(
207 "QuantizePerTokenBackwardPlan::run reached unsupported TIn dtype",
208 ))
209 }
210 };
211 map_status(status)
212 }
213}