baracuda_kernels/quantize/
per_group_backward.rs1use core::ffi::c_void;
7use core::marker::PhantomData;
8
9use baracuda_cutlass::{Error, Result};
10use baracuda_driver::Stream;
11use baracuda_kernels_types::{
12 Element, ElementKind, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut,
13 TensorRef, Workspace,
14};
15
16use super::map_status;
17use super::per_group::build_sku_group;
18use super::validate_input_element;
19
20#[derive(Copy, Clone, Debug)]
22pub struct QuantizePerGroupBackwardDescriptor {
23 pub outer_size: i32,
25 pub axis_size: i32,
27 pub group_size: i32,
29 pub q_min: i32,
31 pub q_max: i32,
33 pub input_element: ElementKind,
35}
36
37impl QuantizePerGroupBackwardDescriptor {
38 #[inline]
40 pub fn num_groups(&self) -> i32 {
41 if self.group_size <= 0 {
42 0
43 } else {
44 self.axis_size / self.group_size
45 }
46 }
47}
48
49pub struct QuantizePerGroupBackwardArgs<'a, TIn: Element> {
51 pub d_output: TensorRef<'a, TIn, 2>,
53 pub input: TensorRef<'a, TIn, 2>,
55 pub scale: TensorRef<'a, TIn, 2>,
57 pub zero_point: TensorRef<'a, i32, 2>,
59 pub d_input: TensorMut<'a, TIn, 2>,
61}
62
63pub struct QuantizePerGroupBackwardPlan<TIn: Element> {
81 desc: QuantizePerGroupBackwardDescriptor,
82 sku: KernelSku,
83 _marker: PhantomData<TIn>,
84}
85
86impl<TIn: Element> QuantizePerGroupBackwardPlan<TIn> {
87 pub fn select(
89 _stream: &Stream,
90 desc: &QuantizePerGroupBackwardDescriptor,
91 _pref: PlanPreference,
92 ) -> Result<Self> {
93 if desc.input_element != TIn::KIND {
94 return Err(Error::Unsupported(
95 "QuantizePerGroupBackwardPlan: descriptor input_element != TIn",
96 ));
97 }
98 validate_input_element(
99 TIn::KIND,
100 "QuantizePerGroupBackwardPlan: unsupported TIn dtype",
101 )?;
102 if desc.outer_size < 0 || desc.axis_size < 0 {
103 return Err(Error::InvalidProblem(
104 "QuantizePerGroupBackwardPlan: outer_size and axis_size must be non-negative",
105 ));
106 }
107 if desc.group_size <= 0 {
108 return Err(Error::InvalidProblem(
109 "QuantizePerGroupBackwardPlan: group_size must be > 0",
110 ));
111 }
112 if desc.axis_size % desc.group_size != 0 {
113 return Err(Error::InvalidProblem(
114 "QuantizePerGroupBackwardPlan: axis_size must be a multiple of group_size",
115 ));
116 }
117 if desc.q_max < desc.q_min {
118 return Err(Error::InvalidProblem(
119 "QuantizePerGroupBackwardPlan: q_max < q_min",
120 ));
121 }
122 let sku =
123 build_sku_group::<TIn, baracuda_kernels_types::S8>(QuantizeKind::PerGroupBackward);
124 Ok(Self {
125 desc: *desc,
126 sku,
127 _marker: PhantomData,
128 })
129 }
130
131 pub fn can_implement(&self, args: &QuantizePerGroupBackwardArgs<'_, TIn>) -> Result<()> {
133 let expect_io = [self.desc.outer_size, self.desc.axis_size];
134 if args.d_output.shape != expect_io
135 || args.input.shape != expect_io
136 || args.d_input.shape != expect_io
137 {
138 return Err(Error::InvalidProblem(
139 "QuantizePerGroupBackwardPlan: I/O tensor shape != [outer, axis_size]",
140 ));
141 }
142 let expect_sg = [self.desc.outer_size, self.desc.num_groups()];
143 if args.scale.shape != expect_sg || args.zero_point.shape != expect_sg {
144 return Err(Error::InvalidProblem(
145 "QuantizePerGroupBackwardPlan: scale / zp shape != [outer, num_groups]",
146 ));
147 }
148 Ok(())
149 }
150
151 #[inline]
153 pub fn workspace_size(&self) -> usize {
154 0
155 }
156
157 #[inline]
159 pub fn sku(&self) -> KernelSku {
160 self.sku
161 }
162
163 #[inline]
165 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
166 self.sku.precision_guarantee
167 }
168
169 pub fn run(
171 &self,
172 stream: &Stream,
173 _workspace: Workspace<'_>,
174 args: QuantizePerGroupBackwardArgs<'_, TIn>,
175 ) -> Result<()> {
176 self.can_implement(&args)?;
177 let total = (self.desc.outer_size as i64) * (self.desc.axis_size as i64);
178 if total == 0 {
179 return Ok(());
180 }
181 let dy_ptr = args.d_output.data.as_raw().0 as *const c_void;
182 let x_ptr = args.input.data.as_raw().0 as *const c_void;
183 let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
184 let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
185 let dx_ptr = args.d_input.data.as_raw().0 as *mut c_void;
186 let stream_ptr = stream.as_raw() as *mut c_void;
187 let (outer, axis, g, qmin, qmax) = (
188 self.desc.outer_size,
189 self.desc.axis_size,
190 self.desc.group_size,
191 self.desc.q_min,
192 self.desc.q_max,
193 );
194
195 let status = match TIn::KIND {
196 ElementKind::F32 => unsafe {
197 baracuda_kernels_sys::baracuda_kernels_quantize_per_group_backward_f32_run(
198 outer, axis, g, qmin, qmax,
199 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
200 core::ptr::null_mut(), 0, stream_ptr,
201 )
202 },
203 ElementKind::F64 => unsafe {
204 baracuda_kernels_sys::baracuda_kernels_quantize_per_group_backward_f64_run(
205 outer, axis, g, qmin, qmax,
206 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
207 core::ptr::null_mut(), 0, stream_ptr,
208 )
209 },
210 ElementKind::F16 => unsafe {
211 baracuda_kernels_sys::baracuda_kernels_quantize_per_group_backward_f16_run(
212 outer, axis, g, qmin, qmax,
213 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
214 core::ptr::null_mut(), 0, stream_ptr,
215 )
216 },
217 ElementKind::Bf16 => unsafe {
218 baracuda_kernels_sys::baracuda_kernels_quantize_per_group_backward_bf16_run(
219 outer, axis, g, qmin, qmax,
220 dy_ptr, x_ptr, sc_ptr, zp_ptr, dx_ptr,
221 core::ptr::null_mut(), 0, stream_ptr,
222 )
223 },
224 _ => {
225 return Err(Error::Unsupported(
226 "QuantizePerGroupBackwardPlan::run unsupported TIn dtype",
227 ))
228 }
229 };
230 map_status(status)
231 }
232}