baracuda_kernels/quantize/
dequantize_per_token.rs1use core::ffi::c_void;
7use core::marker::PhantomData;
8
9use baracuda_cutlass::{Error, Result};
10use baracuda_driver::Stream;
11use baracuda_kernels_types::{
12 Element, ElementKind, IntElement, KernelSku, PlanPreference, PrecisionGuarantee, QuantizeKind,
13 TensorMut, TensorRef, Workspace,
14};
15
16use super::map_status;
17use super::per_token::build_sku;
18use super::{validate_input_element, validate_output_element};
19
20#[derive(Copy, Clone, Debug)]
22pub struct DequantizePerTokenDescriptor {
23 pub n: i32,
25 pub d: i32,
27 pub input_element: ElementKind,
30 pub output_element: ElementKind,
32}
33
34pub struct DequantizePerTokenArgs<'a, TIn: Element, TOut: IntElement> {
41 pub input: TensorRef<'a, TOut, 2>,
43 pub scale: TensorRef<'a, TIn, 1>,
45 pub zero_point: TensorRef<'a, i32, 1>,
47 pub output: TensorMut<'a, TIn, 2>,
49}
50
51pub struct DequantizePerTokenPlan<TIn: Element, TOut: IntElement> {
69 desc: DequantizePerTokenDescriptor,
70 sku: KernelSku,
71 _marker: PhantomData<(TIn, TOut)>,
72}
73
74impl<TIn: Element, TOut: IntElement> DequantizePerTokenPlan<TIn, TOut> {
75 pub fn select(
77 _stream: &Stream,
78 desc: &DequantizePerTokenDescriptor,
79 _pref: PlanPreference,
80 ) -> Result<Self> {
81 if desc.input_element != TIn::KIND {
82 return Err(Error::Unsupported(
83 "DequantizePerTokenPlan: descriptor input_element != TIn",
84 ));
85 }
86 if desc.output_element != TOut::KIND {
87 return Err(Error::Unsupported(
88 "DequantizePerTokenPlan: descriptor output_element != TOut",
89 ));
90 }
91 validate_input_element(TIn::KIND, "DequantizePerTokenPlan: unsupported TIn dtype")?;
92 validate_output_element(TOut::KIND, "DequantizePerTokenPlan: unsupported TOut dtype")?;
93 if desc.n < 0 || desc.d < 0 {
94 return Err(Error::InvalidProblem(
95 "DequantizePerTokenPlan: n and d must be non-negative",
96 ));
97 }
98 let sku = build_sku::<TIn, TOut>(QuantizeKind::DequantizePerToken);
99 Ok(Self {
100 desc: *desc,
101 sku,
102 _marker: PhantomData,
103 })
104 }
105
106 pub fn can_implement(&self, args: &DequantizePerTokenArgs<'_, TIn, TOut>) -> Result<()> {
108 if args.input.shape != [self.desc.n, self.desc.d] {
109 return Err(Error::InvalidProblem(
110 "DequantizePerTokenPlan: input shape != [n, d]",
111 ));
112 }
113 if args.output.shape != [self.desc.n, self.desc.d] {
114 return Err(Error::InvalidProblem(
115 "DequantizePerTokenPlan: output shape != [n, d]",
116 ));
117 }
118 if args.scale.shape != [self.desc.n] {
119 return Err(Error::InvalidProblem(
120 "DequantizePerTokenPlan: scale shape != [n]",
121 ));
122 }
123 if args.zero_point.shape != [self.desc.n] {
124 return Err(Error::InvalidProblem(
125 "DequantizePerTokenPlan: zero_point shape != [n]",
126 ));
127 }
128 Ok(())
129 }
130
131 #[inline]
133 pub fn workspace_size(&self) -> usize {
134 0
135 }
136
137 #[inline]
139 pub fn sku(&self) -> KernelSku {
140 self.sku
141 }
142
143 #[inline]
145 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
146 self.sku.precision_guarantee
147 }
148
149 pub fn run(
151 &self,
152 stream: &Stream,
153 _workspace: Workspace<'_>,
154 args: DequantizePerTokenArgs<'_, TIn, TOut>,
155 ) -> Result<()> {
156 self.can_implement(&args)?;
157 let total = (self.desc.n as i64) * (self.desc.d as i64);
158 if total == 0 {
159 return Ok(());
160 }
161 let in_ptr = args.input.data.as_raw().0 as *const c_void;
162 let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
163 let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
164 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
165 let stream_ptr = stream.as_raw() as *mut c_void;
166
167 let status = match (TIn::KIND, TOut::KIND) {
168 (ElementKind::F32, ElementKind::S8) => unsafe {
169 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_f32_s8_run(
170 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
171 core::ptr::null_mut(), 0, stream_ptr,
172 )
173 },
174 (ElementKind::F32, ElementKind::U8) => unsafe {
175 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_f32_u8_run(
176 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
177 core::ptr::null_mut(), 0, stream_ptr,
178 )
179 },
180 (ElementKind::F64, ElementKind::S8) => unsafe {
181 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_f64_s8_run(
182 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
183 core::ptr::null_mut(), 0, stream_ptr,
184 )
185 },
186 (ElementKind::F64, ElementKind::U8) => unsafe {
187 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_f64_u8_run(
188 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
189 core::ptr::null_mut(), 0, stream_ptr,
190 )
191 },
192 (ElementKind::F16, ElementKind::S8) => unsafe {
193 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_f16_s8_run(
194 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
195 core::ptr::null_mut(), 0, stream_ptr,
196 )
197 },
198 (ElementKind::F16, ElementKind::U8) => unsafe {
199 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_f16_u8_run(
200 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
201 core::ptr::null_mut(), 0, stream_ptr,
202 )
203 },
204 (ElementKind::Bf16, ElementKind::S8) => unsafe {
205 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_bf16_s8_run(
206 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
207 core::ptr::null_mut(), 0, stream_ptr,
208 )
209 },
210 (ElementKind::Bf16, ElementKind::U8) => unsafe {
211 baracuda_kernels_sys::baracuda_kernels_dequantize_per_token_bf16_u8_run(
212 self.desc.n, self.desc.d, in_ptr, sc_ptr, zp_ptr, out_ptr,
213 core::ptr::null_mut(), 0, stream_ptr,
214 )
215 },
216 _ => {
217 return Err(Error::Unsupported(
218 "DequantizePerTokenPlan::run unsupported (TIn, TOut)",
219 ))
220 }
221 };
222 map_status(status)
223 }
224}