1use core::ffi::c_void;
23use core::marker::PhantomData;
24
25use baracuda_cutlass::{Error, Result};
26use baracuda_driver::Stream;
27use baracuda_kernels_types::{
28 ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
29 PlanPreference, PrecisionGuarantee, QuantizeKind, ScalarType, TensorMut, TensorRef, Workspace,
30};
31
32use super::{map_status, validate_input_element, validate_output_element};
33
34#[derive(Copy, Clone, Debug)]
36pub struct QuantizePerTensorDescriptor {
37 pub numel: i32,
39 pub q_min: i32,
41 pub q_max: i32,
43 pub input_element: ElementKind,
45 pub output_element: ElementKind,
47}
48
49pub struct QuantizePerTensorArgs<'a, TIn: Element, TOut: IntElement> {
55 pub input: TensorRef<'a, TIn, 1>,
57 pub scale: <TIn as Element>::Scalar,
60 pub zero_point: i32,
62 pub output: TensorMut<'a, TOut, 1>,
64}
65
66pub struct QuantizePerTensorPlan<TIn: Element, TOut: IntElement> {
93 desc: QuantizePerTensorDescriptor,
94 sku: KernelSku,
95 _marker: PhantomData<(TIn, TOut)>,
96}
97
98impl<TIn: Element, TOut: IntElement> QuantizePerTensorPlan<TIn, TOut> {
99 pub fn select(
101 _stream: &Stream,
102 desc: &QuantizePerTensorDescriptor,
103 _pref: PlanPreference,
104 ) -> Result<Self> {
105 if desc.input_element != TIn::KIND {
106 return Err(Error::Unsupported(
107 "QuantizePerTensorPlan: descriptor input_element != type parameter TIn",
108 ));
109 }
110 if desc.output_element != TOut::KIND {
111 return Err(Error::Unsupported(
112 "QuantizePerTensorPlan: descriptor output_element != type parameter TOut",
113 ));
114 }
115 validate_input_element(TIn::KIND, "QuantizePerTensorPlan: unsupported TIn dtype")?;
116 validate_output_element(TOut::KIND, "QuantizePerTensorPlan: unsupported TOut dtype")?;
117 if desc.numel < 0 {
118 return Err(Error::InvalidProblem(
119 "QuantizePerTensorPlan: numel must be non-negative",
120 ));
121 }
122 if desc.q_max < desc.q_min {
123 return Err(Error::InvalidProblem(
124 "QuantizePerTensorPlan: q_max < q_min",
125 ));
126 }
127 let sku = build_sku::<TIn, TOut>(QuantizeKind::PerTensor);
128 Ok(Self {
129 desc: *desc,
130 sku,
131 _marker: PhantomData,
132 })
133 }
134
135 pub fn can_implement(&self, args: &QuantizePerTensorArgs<'_, TIn, TOut>) -> Result<()> {
137 if args.input.shape != [self.desc.numel] {
138 return Err(Error::InvalidProblem(
139 "QuantizePerTensorPlan: input shape != [numel]",
140 ));
141 }
142 if args.output.shape != [self.desc.numel] {
143 return Err(Error::InvalidProblem(
144 "QuantizePerTensorPlan: output shape != [numel]",
145 ));
146 }
147 Ok(())
148 }
149
150 #[inline]
152 pub fn workspace_size(&self) -> usize {
153 0
154 }
155
156 #[inline]
158 pub fn sku(&self) -> KernelSku {
159 self.sku
160 }
161
162 #[inline]
164 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
165 self.sku.precision_guarantee
166 }
167
168 pub fn run(
170 &self,
171 stream: &Stream,
172 _workspace: Workspace<'_>,
173 args: QuantizePerTensorArgs<'_, TIn, TOut>,
174 ) -> Result<()> {
175 self.can_implement(&args)?;
176 let numel = self.desc.numel as i64;
177 if numel == 0 {
178 return Ok(());
179 }
180 let x_ptr = args.input.data.as_raw().0 as *const c_void;
181 let q_ptr = args.output.data.as_raw().0 as *mut c_void;
182 let stream_ptr = stream.as_raw() as *mut c_void;
183 let zp = args.zero_point;
184 let qmin = self.desc.q_min;
185 let qmax = self.desc.q_max;
186
187 let status = if <TIn::Scalar as ScalarType>::IS_F64 {
188 let scale_f64 = args.scale.to_f64();
190 match TOut::KIND {
191 ElementKind::S8 => unsafe {
192 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_f64_s8_run(
193 numel, scale_f64, zp, qmin, qmax,
194 x_ptr, q_ptr,
195 core::ptr::null_mut(), 0, stream_ptr,
196 )
197 },
198 ElementKind::U8 => unsafe {
199 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_f64_u8_run(
200 numel, scale_f64, zp, qmin, qmax,
201 x_ptr, q_ptr,
202 core::ptr::null_mut(), 0, stream_ptr,
203 )
204 },
205 _ => return Err(Error::Unsupported(
206 "QuantizePerTensorPlan: unsupported TOut at run() (select should have caught)",
207 )),
208 }
209 } else {
210 let scale_f32 = args.scale.to_f32();
212 match (TIn::KIND, TOut::KIND) {
213 (ElementKind::F32, ElementKind::S8) => unsafe {
214 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_f32_s8_run(
215 numel, scale_f32, zp, qmin, qmax,
216 x_ptr, q_ptr,
217 core::ptr::null_mut(), 0, stream_ptr,
218 )
219 },
220 (ElementKind::F32, ElementKind::U8) => unsafe {
221 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_f32_u8_run(
222 numel, scale_f32, zp, qmin, qmax,
223 x_ptr, q_ptr,
224 core::ptr::null_mut(), 0, stream_ptr,
225 )
226 },
227 (ElementKind::F16, ElementKind::S8) => unsafe {
228 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_f16_s8_run(
229 numel, scale_f32, zp, qmin, qmax,
230 x_ptr, q_ptr,
231 core::ptr::null_mut(), 0, stream_ptr,
232 )
233 },
234 (ElementKind::F16, ElementKind::U8) => unsafe {
235 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_f16_u8_run(
236 numel, scale_f32, zp, qmin, qmax,
237 x_ptr, q_ptr,
238 core::ptr::null_mut(), 0, stream_ptr,
239 )
240 },
241 (ElementKind::Bf16, ElementKind::S8) => unsafe {
242 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_bf16_s8_run(
243 numel, scale_f32, zp, qmin, qmax,
244 x_ptr, q_ptr,
245 core::ptr::null_mut(), 0, stream_ptr,
246 )
247 },
248 (ElementKind::Bf16, ElementKind::U8) => unsafe {
249 baracuda_kernels_sys::baracuda_kernels_quantize_per_tensor_bf16_u8_run(
250 numel, scale_f32, zp, qmin, qmax,
251 x_ptr, q_ptr,
252 core::ptr::null_mut(), 0, stream_ptr,
253 )
254 },
255 _ => return Err(Error::Unsupported(
256 "QuantizePerTensorPlan: unsupported (TIn, TOut) at run()",
257 )),
258 }
259 };
260 map_status(status)
261 }
262}
263
264pub(crate) fn build_sku<TIn: Element, TOut: IntElement>(op: QuantizeKind) -> KernelSku {
266 let precision_guarantee = PrecisionGuarantee {
267 math_precision: if TIn::KIND == ElementKind::F64 {
268 MathPrecision::F64
269 } else {
270 MathPrecision::F32
271 },
272 accumulator: ElementKind::F32,
273 bit_stable_on_same_hardware: true,
274 deterministic: true,
275 };
276 KernelSku {
277 category: OpCategory::Quantization,
278 op: op as u16,
279 element: TIn::KIND,
280 aux_element: Some(TOut::KIND),
281 layout: None,
282 epilogue: None,
283 arch: ArchSku::Sm80,
284 backend: BackendKind::Bespoke,
285 precision_guarantee,
286 }
287}