1use std::marker::PhantomData;
2
3use super::*;
4use crate::tensor::{
5 View, ViewExpand, ViewOperations, ViewOperationsExpand,
6 launch::{ViewArg, ViewCompilationArg},
7 layout::Coordinates,
8};
9use cubecl::prelude::*;
10use cubecl_common::{
11 e2m1x2, e4m3, e5m2,
12 quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue},
13 ue8m0,
14};
15use cubecl_core::{
16 self as cubecl, define_size,
17 ir::{ElemType, FloatKind, StorageType, VectorSize},
18 prelude::barrier::BarrierExpand,
19 unexpanded,
20};
21use half::{bf16, f16};
22
/// Read-only view that dequantizes on every load.
///
/// Each read fetches one `Vector<Q, NQ>` from `values` and one `S` from `scales` at the *same*
/// coordinates, then dequantizes them into a `Vector<F, NF>` according to `scheme`.
///
/// # Layout requirements
/// - `scales` is indexed with the value coordinates, so its layout must map each value position
///   to the scale that covers it.
/// - Only one scale is read per load, so all elements of one loaded value vector must share that
///   scale. NOTE(review): presumably this means the quantization block size must be a multiple
///   of the dequantized vector size — confirm.
#[expect(dead_code, reason = "only used in expand")]
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct QuantizedView<
    Q: Scalar,
    NQ: Size,
    S: Scalar,
    F: Numeric,
    NF: Size,
    C: Coordinates + 'static,
> {
    // Quantized storage, read as vectors of `NQ` lanes of `Q`.
    values: View<Vector<Q, NQ>, C>,
    // Scale factors, read at the same coordinates as `values`.
    scales: View<S, C>,
    // Quantization scheme, a comptime parameter.
    #[cube(comptime)]
    scheme: QuantScheme,
    // Marker for the dequantized element type `F` and dequantized vector size `NF`.
    #[cube(comptime)]
    _ty: PhantomData<(F, NF)>,
}
47
#[cube]
impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
    QuantizedView<Q, NQ, S, F, NF, C>
{
    /// Creates a quantized view from a view of quantized `values`, a view of `scales` that is
    /// addressed with the same coordinates, and the comptime quantization `scheme`.
    pub fn new(
        values: View<Vector<Q, NQ>, C>,
        scales: View<S, C>,
        #[comptime] scheme: QuantScheme,
    ) -> Self {
        QuantizedView::<Q, NQ, S, F, NF, C> {
            values,
            scales,
            scheme,
            _ty: PhantomData,
        }
    }
}
65
66impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
67 QuantizedView<Q, NQ, S, F, NF, C>
68{
69 pub fn view(self) -> View<Vector<F, NF>, C> {
70 unexpanded!()
71 }
72
73 pub fn __expand_view(
74 scope: &mut Scope,
75 this: QuantizedViewExpand<Q, NQ, S, F, NF, C>,
76 ) -> ViewExpand<Vector<F, NF>, C, ReadOnly> {
77 this.__expand_view_method(scope)
78 }
79}
80
81impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
82 QuantizedViewExpand<Q, NQ, S, F, NF, C>
83{
84 pub fn new(
85 values: ViewExpand<Vector<Q, NQ>, C>,
86 scales: ViewExpand<S, C>,
87 scheme: QuantScheme,
88 ) -> Self {
89 QuantizedViewExpand::<Q, NQ, S, F, NF, C> {
90 values,
91 scales,
92 scheme,
93 _ty: PhantomData,
94 }
95 }
96
97 pub fn __expand_view_method(
98 self,
99 _scope: &mut Scope,
100 ) -> ViewExpand<Vector<F, NF>, C, ReadOnly> {
101 ViewExpand::new(self)
102 }
103}
104
// Marker impl on the cube-side type. The vector size itself is computed by the
// `VectorizedExpand` impl on `QuantizedViewExpand`.
impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static> Vectorized
    for QuantizedView<Q, NQ, S, F, NF, C>
{
}
109impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
110 VectorizedExpand for QuantizedViewExpand<Q, NQ, S, F, NF, C>
111{
112 fn vector_size(&self) -> VectorSize {
113 self.values.vector_size() * self.scheme.num_quants()
114 }
115}
116
// Marker impl on the cube-side type. Reads, bounds checks and shape queries are provided by the
// `ViewOperationsExpand` impl on `QuantizedViewExpand`.
impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
    ViewOperations<Vector<F, NF>, C> for QuantizedView<Q, NQ, S, F, NF, C>
{
}
121
122impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
123 ViewOperationsExpand<Vector<F, NF>, C> for QuantizedViewExpand<Q, NQ, S, F, NF, C>
124{
125 fn __expand_read_method(
126 &self,
127 scope: &mut Scope,
128 pos: <C>::ExpandType,
129 ) -> NativeExpand<Vector<F, NF>> {
130 let value = self.values.clone().__expand_read_method(scope, pos.clone());
131 let scale = self.scales.clone().__expand_read_method(scope, pos);
132
133 dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme)
134 }
135
136 fn __expand_read_checked_method(
137 &self,
138 scope: &mut Scope,
139 pos: <C>::ExpandType,
140 ) -> NativeExpand<Vector<F, NF>> {
141 let value = self
142 .values
143 .clone()
144 .__expand_read_checked_method(scope, pos.clone());
145 let scale = self
146 .scales
147 .clone()
148 .__expand_read_checked_method(scope, pos.clone());
149
150 dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme)
151 }
152
153 fn __expand_read_masked_method(
154 &self,
155 scope: &mut Scope,
156 pos: <C>::ExpandType,
157 mask_value: NativeExpand<Vector<F, NF>>,
158 ) -> NativeExpand<Vector<F, NF>> {
159 let value = self
160 .values
161 .clone()
162 .__expand_read_checked_method(scope, pos.clone());
163 let scale = self
164 .scales
165 .clone()
166 .__expand_read_checked_method(scope, pos.clone());
167 let in_bounds = self.__expand_is_in_bounds_method(scope, pos);
168
169 let value = dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme);
170 select::expand::<Vector<F, NF>>(scope, in_bounds, value, mask_value)
171 }
172
173 fn __expand_read_unchecked_method(
174 &self,
175 scope: &mut Scope,
176 pos: <C>::ExpandType,
177 ) -> NativeExpand<Vector<F, NF>> {
178 let value = self
179 .values
180 .clone()
181 .__expand_read_unchecked_method(scope, pos.clone());
182 let scale = self
183 .scales
184 .clone()
185 .__expand_read_unchecked_method(scope, pos);
186
187 dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme)
188 }
189
190 fn __expand_to_linear_slice_method(
191 &self,
192 _scope: &mut Scope,
193 _pos: <C>::ExpandType,
194 _end: <C>::ExpandType,
195 ) -> SliceExpand<Vector<F, NF>, ReadOnly> {
196 panic!("Can't create raw slice for quantized view")
197 }
198
199 fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
200 self.values.clone().__expand_shape_method(scope)
201 }
202
203 fn __expand_is_in_bounds_method(
204 &self,
205 scope: &mut Scope,
206 pos: C::ExpandType,
207 ) -> NativeExpand<bool> {
208 self.values.clone().__expand_is_in_bounds_method(scope, pos)
209 }
210
211 fn __expand_tensor_map_load_method(
212 &self,
213 _scope: &mut Scope,
214 _barrier: BarrierExpand,
215 _shared_memory: SliceExpand<Vector<F, NF>, ReadWrite>,
216 _pos: C::ExpandType,
217 ) {
218 panic!("Can't use tensor map functions on quantized view");
219 }
220}
221
/// [`RunWithQuantType`] callback that expands the `values` and `scales` compilation arguments
/// into a [`QuantizedViewExpand`] once the concrete stored value type and scale type are known.
///
/// `E` and `N` are the dequantized element type and vector size of the resulting view.
struct ExpandDynamic<'a, E: Numeric, N: Size, C: Coordinates + 'static> {
    values: &'a ViewCompilationArg<C>,
    scales: &'a ViewCompilationArg<C>,
    scheme: QuantScheme,
    builder: &'a mut KernelBuilder,
    _ty: PhantomData<(E, N)>,
}
229
230impl<'a, E: Numeric, N: Size, C: Coordinates + 'static> RunWithQuantType
231 for ExpandDynamic<'a, E, N, C>
232{
233 type Output = ViewExpand<Vector<E, N>, C>;
234
235 fn execute<Q: Scalar, S: Scalar>(self) -> Self::Output {
236 define_size!(NQ);
237
238 let vector_size = N::__expand_value(&self.builder.scope);
239 let vector_size_q = vector_size / self.scheme.num_quants();
240 self.builder.scope.register_size::<NQ>(vector_size_q);
241
242 let values = View::<Vector<Q, NQ>, C>::expand(self.values, self.builder);
243 let scales = View::<S, C>::expand(self.scales, self.builder);
244 let view = QuantizedViewExpand::new(values, scales, self.scheme);
245 ViewExpand::new(view)
246 }
247}
248
/// [`RunWithQuantType`] callback that registers the `values` and `scales` view arguments on the
/// kernel `launcher` with the concrete stored value and scale types. It produces a
/// [`ViewCompilationArg::Quantized`].
///
/// `E` is the dequantized element type; its vector size determines the stored vector size.
pub(crate) struct RegisterDynamic<'a, E: CubePrimitive, C: Coordinates + 'static, R: Runtime> {
    pub values: ViewArg<C, R>,
    pub scales: ViewArg<C, R>,
    pub scheme: QuantScheme,
    pub launcher: &'a mut KernelLauncher<R>,
    pub _ty: PhantomData<E>,
}
256
257impl<'a, E: CubePrimitive, C: Coordinates + 'static, R: Runtime> RunWithQuantType
258 for RegisterDynamic<'a, E, C, R>
259{
260 type Output = ViewCompilationArg<C>;
261
262 fn execute<Q: Scalar, S: Scalar>(self) -> Self::Output {
263 define_size!(NQ);
264
265 self.launcher.with_scope(|scope| {
266 let vector_size_q = E::__expand_vector_size(scope) / self.scheme.num_quants();
267 scope.register_size::<NQ>(vector_size_q);
268 });
269
270 let values = View::<Vector<Q, NQ>, C>::register(self.values, self.launcher);
271 let scales = View::<S, C>::register(self.scales, self.launcher);
272 ViewCompilationArg::Quantized {
273 values: Box::new(values),
274 scales: Box::new(scales),
275 scheme: self.scheme,
276 }
277 }
278}
279
280pub fn run_with_quant_type<F: RunWithQuantType>(func: F, scheme: QuantScheme) -> F::Output {
283 fn run_with_q<F: RunWithQuantType, Q: Scalar>(func: F, scheme: QuantScheme) -> F::Output {
284 match scheme.param {
285 QuantParam::F32 => func.execute::<Q, f32>(),
286 QuantParam::F16 => func.execute::<Q, f16>(),
287 QuantParam::BF16 => func.execute::<Q, bf16>(),
288 QuantParam::UE8M0 => func.execute::<Q, ue8m0>(),
289 QuantParam::UE4M3 => func.execute::<Q, e4m3>(),
290 }
291 }
292
293 let run_q = match scheme.store {
294 QuantStore::Native => match scheme.value {
295 QuantValue::Q8F => run_with_q::<F, i8>,
296 QuantValue::Q8S => run_with_q::<F, i8>,
297 QuantValue::E5M2 => run_with_q::<F, e5m2>,
298 QuantValue::E4M3 => run_with_q::<F, e4m3>,
299 QuantValue::Q4F
300 | QuantValue::Q4S
301 | QuantValue::Q2F
302 | QuantValue::Q2S
303 | QuantValue::E2M1 => {
304 panic!("Sub-byte quantization can't be native")
305 }
306 },
307 QuantStore::PackedU32(_) => run_with_q::<F, u32>,
308 QuantStore::PackedNative(_) => run_with_q::<F, e2m1x2>,
309 };
310 run_q(func, scheme)
311}
312
/// Expands a quantized view whose dequantized element type is `E`.
///
/// `E`'s vector size is registered as the dequantized vector size `NF`. `E`'s float kind then
/// selects the concrete float type `F`, and the stored value and scale types are resolved from
/// `scheme` by `run_with_quant_type`.
///
/// # Panics
/// Panics (via `unreachable!`) if `E` is not a float, or is one of the minifloat kinds.
pub(crate) fn expand_dynamic<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility>(
    values: &ViewCompilationArg<C>,
    scales: &ViewCompilationArg<C>,
    scheme: QuantScheme,
    builder: &mut KernelBuilder,
) -> ViewExpand<E, C, IO> {
    use core::mem::transmute as t;

    // Runs the stored-value/scale-type dispatch for one fixed float type `F`.
    fn expand_dynamic_f<F: Numeric, NF: Size, C: Coordinates + 'static>(
        values: &ViewCompilationArg<C>,
        scales: &ViewCompilationArg<C>,
        scheme: QuantScheme,
        builder: &mut KernelBuilder,
    ) -> ViewExpand<Vector<F, NF>, C> {
        let func = ExpandDynamic {
            values,
            scales,
            scheme,
            builder,
            _ty: PhantomData::<(F, NF)>,
        };
        run_with_quant_type(func, scheme)
    }

    // Bind `E`'s vector size to the local size type `NF` used by every dispatch arm below.
    define_size!(NF);

    let vector_size = E::__expand_vector_size(&builder.scope);

    builder.scope.register_size::<NF>(vector_size);

    // SAFETY: NOTE(review): each arm transmutes `ViewExpand<Vector<F, NF>, C>` into
    // `ViewExpand<E, C, IO>`. The intent is that `Vector<F, NF>` describes the same element as
    // `E`: `F` is chosen by matching on `E`'s storage type, and `NF` is `E`'s registered vector
    // size. Soundness also requires `ViewExpand` to have an identical layout regardless of the
    // element type and `IO` parameters. That is not visible from this file — confirm against
    // `ViewExpand`'s definition.
    // Annotations are omitted because every transmute targets this function's return type.
    #[allow(clippy::missing_transmute_annotations)]
    unsafe {
        match E::as_type(&builder.scope).storage_type() {
            StorageType::Scalar(ElemType::Float(ty)) => match ty {
                FloatKind::F16 => t(expand_dynamic_f::<f16, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::BF16 => t(expand_dynamic_f::<bf16, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::Flex32 => t(expand_dynamic_f::<flex32, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::F32 => t(expand_dynamic_f::<f32, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::TF32 => t(expand_dynamic_f::<tf32, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::F64 => t(expand_dynamic_f::<f64, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"),
            },
            _ => unreachable!("Quantized view should only be used with floats"),
        }
    }
}