1use core::ffi::c_void;
8use core::marker::PhantomData;
9
10use baracuda_cutlass::{Error, Result};
11use baracuda_driver::Stream;
12use baracuda_kernels_types::{
13 Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
14 TensorMut, TensorRef, Workspace,
15};
16
17use super::map_status;
18use super::per_group::build_sku_group;
19use super::{validate_input_element, validate_output_element};
20
21#[derive(Copy, Clone, Debug)]
23pub struct DequantizePerGroupDescriptor {
24 pub outer_size: i32,
26 pub axis_size: i32,
28 pub group_size: i32,
30 pub input_element: ElementKind,
32 pub output_element: ElementKind,
34}
35
36impl DequantizePerGroupDescriptor {
37 #[inline]
39 pub fn num_groups(&self) -> i32 {
40 if self.group_size <= 0 {
41 0
42 } else {
43 self.axis_size / self.group_size
44 }
45 }
46}
47
48pub struct DequantizePerGroupArgs<'a, TIn: Element, TOut: IntElement> {
50 pub input: TensorRef<'a, TOut, 2>,
52 pub scale: TensorRef<'a, TIn, 2>,
54 pub zero_point: TensorRef<'a, i32, 2>,
56 pub output: TensorMut<'a, TIn, 2>,
58}
59
60pub struct DequantizePerGroupPlan<TIn: Element, TOut: IntElement> {
79 desc: DequantizePerGroupDescriptor,
80 sku: KernelSku,
81 _marker: PhantomData<(TIn, TOut)>,
82}
83
84impl<TIn: Element, TOut: IntElement> DequantizePerGroupPlan<TIn, TOut> {
85 pub fn select(
87 _stream: &Stream,
88 desc: &DequantizePerGroupDescriptor,
89 _pref: PlanPreference,
90 ) -> Result<Self> {
91 if desc.input_element != TIn::KIND {
92 return Err(Error::Unsupported(
93 "DequantizePerGroupPlan: descriptor input_element != TIn",
94 ));
95 }
96 if desc.output_element != TOut::KIND {
97 return Err(Error::Unsupported(
98 "DequantizePerGroupPlan: descriptor output_element != TOut",
99 ));
100 }
101 validate_input_element(TIn::KIND, "DequantizePerGroupPlan: unsupported TIn dtype")?;
102 validate_output_element(TOut::KIND, "DequantizePerGroupPlan: unsupported TOut dtype")?;
103 if desc.outer_size < 0 || desc.axis_size < 0 {
104 return Err(Error::InvalidProblem(
105 "DequantizePerGroupPlan: outer_size and axis_size must be non-negative",
106 ));
107 }
108 if desc.group_size <= 0 {
109 return Err(Error::InvalidProblem(
110 "DequantizePerGroupPlan: group_size must be > 0",
111 ));
112 }
113 if desc.axis_size % desc.group_size != 0 {
114 return Err(Error::InvalidProblem(
115 "DequantizePerGroupPlan: axis_size must be a multiple of group_size",
116 ));
117 }
118 let sku = build_sku_group::<TIn, TOut>(QuantizeKind::DequantizePerGroup);
119 Ok(Self {
120 desc: *desc,
121 sku,
122 _marker: PhantomData,
123 })
124 }
125
126 pub fn can_implement(&self, args: &DequantizePerGroupArgs<'_, TIn, TOut>) -> Result<()> {
128 let expect_io = [self.desc.outer_size, self.desc.axis_size];
129 if args.input.shape != expect_io || args.output.shape != expect_io {
130 return Err(Error::InvalidProblem(
131 "DequantizePerGroupPlan: I/O tensor shape != [outer, axis_size]",
132 ));
133 }
134 let expect_sg = [self.desc.outer_size, self.desc.num_groups()];
135 if args.scale.shape != expect_sg || args.zero_point.shape != expect_sg {
136 return Err(Error::InvalidProblem(
137 "DequantizePerGroupPlan: scale / zp shape != [outer, num_groups]",
138 ));
139 }
140 Ok(())
141 }
142
143 #[inline]
145 pub fn workspace_size(&self) -> usize {
146 0
147 }
148
149 #[inline]
151 pub fn sku(&self) -> KernelSku {
152 self.sku
153 }
154
155 #[inline]
157 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
158 self.sku.precision_guarantee
159 }
160
161 pub fn run(
163 &self,
164 stream: &Stream,
165 _workspace: Workspace<'_>,
166 args: DequantizePerGroupArgs<'_, TIn, TOut>,
167 ) -> Result<()> {
168 self.can_implement(&args)?;
169 let total = (self.desc.outer_size as i64) * (self.desc.axis_size as i64);
170 if total == 0 {
171 return Ok(());
172 }
173 let in_ptr = args.input.data.as_raw().0 as *const c_void;
174 let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
175 let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
176 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
177 let stream_ptr = stream.as_raw() as *mut c_void;
178 let (outer, axis, g) = (
179 self.desc.outer_size,
180 self.desc.axis_size,
181 self.desc.group_size,
182 );
183 let status = match (TIn::KIND, TOut::KIND) {
184 (ElementKind::F32, ElementKind::S8) => unsafe {
185 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_f32_s8_run(
186 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
187 core::ptr::null_mut(), 0, stream_ptr,
188 )
189 },
190 (ElementKind::F32, ElementKind::U8) => unsafe {
191 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_f32_u8_run(
192 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
193 core::ptr::null_mut(), 0, stream_ptr,
194 )
195 },
196 (ElementKind::F64, ElementKind::S8) => unsafe {
197 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_f64_s8_run(
198 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
199 core::ptr::null_mut(), 0, stream_ptr,
200 )
201 },
202 (ElementKind::F64, ElementKind::U8) => unsafe {
203 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_f64_u8_run(
204 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
205 core::ptr::null_mut(), 0, stream_ptr,
206 )
207 },
208 (ElementKind::F16, ElementKind::S8) => unsafe {
209 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_f16_s8_run(
210 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
211 core::ptr::null_mut(), 0, stream_ptr,
212 )
213 },
214 (ElementKind::F16, ElementKind::U8) => unsafe {
215 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_f16_u8_run(
216 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
217 core::ptr::null_mut(), 0, stream_ptr,
218 )
219 },
220 (ElementKind::Bf16, ElementKind::S8) => unsafe {
221 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_bf16_s8_run(
222 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
223 core::ptr::null_mut(), 0, stream_ptr,
224 )
225 },
226 (ElementKind::Bf16, ElementKind::U8) => unsafe {
227 baracuda_kernels_sys::baracuda_kernels_dequantize_per_group_bf16_u8_run(
228 outer, axis, g, in_ptr, sc_ptr, zp_ptr, out_ptr,
229 core::ptr::null_mut(), 0, stream_ptr,
230 )
231 },
232 _ => {
233 return Err(Error::Unsupported(
234 "DequantizePerGroupPlan::run unsupported (TIn, TOut)",
235 ))
236 }
237 };
238 map_status(status)
239 }
240}