// baracuda_kernels/quantize/dequantize_per_tensor.rs
use core::ffi::c_void;
9use core::marker::PhantomData;
10
11use baracuda_cutlass::{Error, Result};
12use baracuda_driver::Stream;
13use baracuda_kernels_types::{
14 Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
15 ScalarType, TensorMut, TensorRef, Workspace,
16};
17
18use super::map_status;
19use super::per_tensor::build_sku;
20use super::{validate_input_element, validate_output_element};
21
22#[derive(Copy, Clone, Debug)]
24pub struct DequantizePerTensorDescriptor {
25 pub numel: i32,
27 pub input_element: ElementKind,
29 pub output_element: ElementKind,
31}
32
33pub struct DequantizePerTensorArgs<'a, TIn: Element, TOut: IntElement> {
35 pub input: TensorRef<'a, TOut, 1>,
37 pub scale: <TIn as Element>::Scalar,
39 pub zero_point: i32,
41 pub output: TensorMut<'a, TIn, 1>,
43}
44
45pub struct DequantizePerTensorPlan<TIn: Element, TOut: IntElement> {
63 desc: DequantizePerTensorDescriptor,
64 sku: KernelSku,
65 _marker: PhantomData<(TIn, TOut)>,
66}
67
68impl<TIn: Element, TOut: IntElement> DequantizePerTensorPlan<TIn, TOut> {
69 pub fn select(
71 _stream: &Stream,
72 desc: &DequantizePerTensorDescriptor,
73 _pref: PlanPreference,
74 ) -> Result<Self> {
75 if desc.input_element != TIn::KIND {
76 return Err(Error::Unsupported(
77 "DequantizePerTensorPlan: descriptor input_element != TIn",
78 ));
79 }
80 if desc.output_element != TOut::KIND {
81 return Err(Error::Unsupported(
82 "DequantizePerTensorPlan: descriptor output_element != TOut",
83 ));
84 }
85 validate_input_element(TIn::KIND, "DequantizePerTensorPlan: unsupported TIn dtype")?;
86 validate_output_element(TOut::KIND, "DequantizePerTensorPlan: unsupported TOut dtype")?;
87 if desc.numel < 0 {
88 return Err(Error::InvalidProblem(
89 "DequantizePerTensorPlan: numel must be non-negative",
90 ));
91 }
92 let sku = build_sku::<TIn, TOut>(QuantizeKind::DequantizePerTensor);
93 Ok(Self {
94 desc: *desc,
95 sku,
96 _marker: PhantomData,
97 })
98 }
99
100 pub fn can_implement(&self, args: &DequantizePerTensorArgs<'_, TIn, TOut>) -> Result<()> {
102 let expected = [self.desc.numel];
103 if args.input.shape != expected || args.output.shape != expected {
104 return Err(Error::InvalidProblem(
105 "DequantizePerTensorPlan: tensor shape != [numel]",
106 ));
107 }
108 Ok(())
109 }
110
111 #[inline]
113 pub fn workspace_size(&self) -> usize {
114 0
115 }
116
117 #[inline]
119 pub fn sku(&self) -> KernelSku {
120 self.sku
121 }
122
123 #[inline]
125 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
126 self.sku.precision_guarantee
127 }
128
129 pub fn run(
131 &self,
132 stream: &Stream,
133 _workspace: Workspace<'_>,
134 args: DequantizePerTensorArgs<'_, TIn, TOut>,
135 ) -> Result<()> {
136 self.can_implement(&args)?;
137 let numel = self.desc.numel as i64;
138 if numel == 0 {
139 return Ok(());
140 }
141 let q_ptr = args.input.data.as_raw().0 as *const c_void;
142 let x_ptr = args.output.data.as_raw().0 as *mut c_void;
143 let stream_ptr = stream.as_raw() as *mut c_void;
144 let zp = args.zero_point;
145
146 let status = if <TIn::Scalar as ScalarType>::IS_F64 {
147 let scale_f64 = args.scale.to_f64();
148 match TOut::KIND {
149 ElementKind::S8 => unsafe {
150 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f64_s8_run(
151 numel, scale_f64, zp, q_ptr, x_ptr,
152 core::ptr::null_mut(), 0, stream_ptr,
153 )
154 },
155 ElementKind::U8 => unsafe {
156 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f64_u8_run(
157 numel, scale_f64, zp, q_ptr, x_ptr,
158 core::ptr::null_mut(), 0, stream_ptr,
159 )
160 },
161 _ => return Err(Error::Unsupported(
162 "DequantizePerTensorPlan: unsupported TOut at run()",
163 )),
164 }
165 } else {
166 let scale_f32 = args.scale.to_f32();
167 match (TIn::KIND, TOut::KIND) {
168 (ElementKind::F32, ElementKind::S8) => unsafe {
169 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f32_s8_run(
170 numel, scale_f32, zp, q_ptr, x_ptr,
171 core::ptr::null_mut(), 0, stream_ptr,
172 )
173 },
174 (ElementKind::F32, ElementKind::U8) => unsafe {
175 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f32_u8_run(
176 numel, scale_f32, zp, q_ptr, x_ptr,
177 core::ptr::null_mut(), 0, stream_ptr,
178 )
179 },
180 (ElementKind::F16, ElementKind::S8) => unsafe {
181 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f16_s8_run(
182 numel, scale_f32, zp, q_ptr, x_ptr,
183 core::ptr::null_mut(), 0, stream_ptr,
184 )
185 },
186 (ElementKind::F16, ElementKind::U8) => unsafe {
187 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_f16_u8_run(
188 numel, scale_f32, zp, q_ptr, x_ptr,
189 core::ptr::null_mut(), 0, stream_ptr,
190 )
191 },
192 (ElementKind::Bf16, ElementKind::S8) => unsafe {
193 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_bf16_s8_run(
194 numel, scale_f32, zp, q_ptr, x_ptr,
195 core::ptr::null_mut(), 0, stream_ptr,
196 )
197 },
198 (ElementKind::Bf16, ElementKind::U8) => unsafe {
199 baracuda_kernels_sys::baracuda_kernels_dequantize_per_tensor_bf16_u8_run(
200 numel, scale_f32, zp, q_ptr, x_ptr,
201 core::ptr::null_mut(), 0, stream_ptr,
202 )
203 },
204 _ => return Err(Error::Unsupported(
205 "DequantizePerTensorPlan: unsupported (TIn, TOut) at run()",
206 )),
207 }
208 };
209 map_status(status)
210 }
211}