1use core::ffi::c_void;
14use core::marker::PhantomData;
15
16use baracuda_cutlass::{Error, Result};
17use baracuda_driver::Stream;
18use baracuda_kernels_types::{
19 ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
20 PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace,
21};
22
23use super::{map_status, validate_input_element, validate_output_element};
24
25#[derive(Copy, Clone, Debug)]
27pub struct QuantizePerTokenDescriptor {
28 pub n: i32,
30 pub d: i32,
32 pub q_min: i32,
34 pub q_max: i32,
36 pub input_element: ElementKind,
38 pub output_element: ElementKind,
40}
41
42pub struct QuantizePerTokenArgs<'a, TIn: Element, TOut: IntElement> {
44 pub input: TensorRef<'a, TIn, 2>,
46 pub scale: TensorRef<'a, TIn, 1>,
48 pub zero_point: TensorRef<'a, i32, 1>,
50 pub output: TensorMut<'a, TOut, 2>,
52}
53
54pub struct QuantizePerTokenPlan<TIn: Element, TOut: IntElement> {
78 desc: QuantizePerTokenDescriptor,
79 sku: KernelSku,
80 _marker: PhantomData<(TIn, TOut)>,
81}
82
83impl<TIn: Element, TOut: IntElement> QuantizePerTokenPlan<TIn, TOut> {
84 pub fn select(
86 _stream: &Stream,
87 desc: &QuantizePerTokenDescriptor,
88 _pref: PlanPreference,
89 ) -> Result<Self> {
90 if desc.input_element != TIn::KIND {
91 return Err(Error::Unsupported(
92 "QuantizePerTokenPlan: descriptor input_element != type parameter TIn",
93 ));
94 }
95 if desc.output_element != TOut::KIND {
96 return Err(Error::Unsupported(
97 "QuantizePerTokenPlan: descriptor output_element != type parameter TOut",
98 ));
99 }
100 validate_input_element(TIn::KIND, "QuantizePerTokenPlan: unsupported TIn dtype")?;
101 validate_output_element(TOut::KIND, "QuantizePerTokenPlan: unsupported TOut dtype")?;
102 if desc.n < 0 || desc.d < 0 {
103 return Err(Error::InvalidProblem(
104 "QuantizePerTokenPlan: n and d must be non-negative",
105 ));
106 }
107 if desc.q_max < desc.q_min {
108 return Err(Error::InvalidProblem(
109 "QuantizePerTokenPlan: q_max < q_min",
110 ));
111 }
112 let sku = build_sku::<TIn, TOut>(QuantizeKind::PerToken);
113 Ok(Self {
114 desc: *desc,
115 sku,
116 _marker: PhantomData,
117 })
118 }
119
120 pub fn can_implement(&self, args: &QuantizePerTokenArgs<'_, TIn, TOut>) -> Result<()> {
122 if args.input.shape != [self.desc.n, self.desc.d] {
123 return Err(Error::InvalidProblem(
124 "QuantizePerTokenPlan: input shape != [n, d]",
125 ));
126 }
127 if args.output.shape != [self.desc.n, self.desc.d] {
128 return Err(Error::InvalidProblem(
129 "QuantizePerTokenPlan: output shape != [n, d]",
130 ));
131 }
132 if args.scale.shape != [self.desc.n] {
133 return Err(Error::InvalidProblem(
134 "QuantizePerTokenPlan: scale shape != [n]",
135 ));
136 }
137 if args.zero_point.shape != [self.desc.n] {
138 return Err(Error::InvalidProblem(
139 "QuantizePerTokenPlan: zero_point shape != [n]",
140 ));
141 }
142 Ok(())
143 }
144
145 #[inline]
147 pub fn workspace_size(&self) -> usize {
148 0
149 }
150
151 #[inline]
153 pub fn sku(&self) -> KernelSku {
154 self.sku
155 }
156
157 #[inline]
159 pub fn precision_guarantee(&self) -> PrecisionGuarantee {
160 self.sku.precision_guarantee
161 }
162
163 pub fn run(
165 &self,
166 stream: &Stream,
167 _workspace: Workspace<'_>,
168 args: QuantizePerTokenArgs<'_, TIn, TOut>,
169 ) -> Result<()> {
170 self.can_implement(&args)?;
171 let total = (self.desc.n as i64) * (self.desc.d as i64);
172 if total == 0 {
173 return Ok(());
174 }
175 let in_ptr = args.input.data.as_raw().0 as *const c_void;
176 let sc_ptr = args.scale.data.as_raw().0 as *const c_void;
177 let zp_ptr = args.zero_point.data.as_raw().0 as *const c_void;
178 let out_ptr = args.output.data.as_raw().0 as *mut c_void;
179 let stream_ptr = stream.as_raw() as *mut c_void;
180
181 let status = match (TIn::KIND, TOut::KIND) {
182 (ElementKind::F32, ElementKind::S8) => unsafe {
183 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f32_s8_run(
184 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
185 in_ptr, sc_ptr, zp_ptr, out_ptr,
186 core::ptr::null_mut(), 0, stream_ptr,
187 )
188 },
189 (ElementKind::F32, ElementKind::U8) => unsafe {
190 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f32_u8_run(
191 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
192 in_ptr, sc_ptr, zp_ptr, out_ptr,
193 core::ptr::null_mut(), 0, stream_ptr,
194 )
195 },
196 (ElementKind::F64, ElementKind::S8) => unsafe {
197 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f64_s8_run(
198 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
199 in_ptr, sc_ptr, zp_ptr, out_ptr,
200 core::ptr::null_mut(), 0, stream_ptr,
201 )
202 },
203 (ElementKind::F64, ElementKind::U8) => unsafe {
204 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f64_u8_run(
205 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
206 in_ptr, sc_ptr, zp_ptr, out_ptr,
207 core::ptr::null_mut(), 0, stream_ptr,
208 )
209 },
210 (ElementKind::F16, ElementKind::S8) => unsafe {
211 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f16_s8_run(
212 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
213 in_ptr, sc_ptr, zp_ptr, out_ptr,
214 core::ptr::null_mut(), 0, stream_ptr,
215 )
216 },
217 (ElementKind::F16, ElementKind::U8) => unsafe {
218 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_f16_u8_run(
219 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
220 in_ptr, sc_ptr, zp_ptr, out_ptr,
221 core::ptr::null_mut(), 0, stream_ptr,
222 )
223 },
224 (ElementKind::Bf16, ElementKind::S8) => unsafe {
225 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_bf16_s8_run(
226 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
227 in_ptr, sc_ptr, zp_ptr, out_ptr,
228 core::ptr::null_mut(), 0, stream_ptr,
229 )
230 },
231 (ElementKind::Bf16, ElementKind::U8) => unsafe {
232 baracuda_kernels_sys::baracuda_kernels_quantize_per_token_bf16_u8_run(
233 self.desc.n, self.desc.d, self.desc.q_min, self.desc.q_max,
234 in_ptr, sc_ptr, zp_ptr, out_ptr,
235 core::ptr::null_mut(), 0, stream_ptr,
236 )
237 },
238 _ => {
239 return Err(Error::Unsupported(
240 "QuantizePerTokenPlan::run reached unsupported (TIn, TOut) combination",
241 ))
242 }
243 };
244 map_status(status)
245 }
246}
247
248pub(crate) fn build_sku<TIn: Element, TOut: IntElement>(op: QuantizeKind) -> KernelSku {
250 let precision_guarantee = PrecisionGuarantee {
251 math_precision: if TIn::KIND == ElementKind::F64 {
252 MathPrecision::F64
253 } else {
254 MathPrecision::F32
255 },
256 accumulator: ElementKind::F32,
257 bit_stable_on_same_hardware: true,
259 deterministic: true,
260 };
261 KernelSku {
262 category: OpCategory::Quantization,
263 op: op as u16,
264 element: TIn::KIND,
265 aux_element: Some(TOut::KIND),
266 layout: None,
267 epilogue: None,
268 arch: ArchSku::Sm80,
269 backend: BackendKind::Bespoke,
270 precision_guarantee,
271 }
272}