Skip to main content

cubecl_std/quant/
view.rs

1use std::marker::PhantomData;
2
3use super::*;
4use crate::tensor::{
5    View, ViewExpand, ViewOperations, ViewOperationsExpand,
6    launch::{ViewArg, ViewCompilationArg},
7    layout::Coordinates,
8};
9use cubecl::prelude::*;
10use cubecl_common::{
11    e2m1x2, e4m3, e5m2,
12    quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue},
13    ue8m0,
14};
15use cubecl_core::{
16    self as cubecl, define_size,
17    ir::{ElemType, FloatKind, StorageType, VectorSize},
18    prelude::barrier::BarrierExpand,
19    unexpanded,
20};
21use half::{bf16, f16};
22
/// View that dequantizes after loads. Scales layout should take values coordinates and map them
/// to the corresponding scale.
///
/// Reads through this view yield `Vector<F, NF>`. The expanded view reports a vector size equal to
/// the values view's vector size multiplied by `scheme.num_quants()`.
///
/// # Warning
/// Assumes each load maps to exactly one scale. Adjust the vector size of values or the block
/// size to ensure this.
/// Must ensure `block_size.is_multiple_of(vector_size * scheme.num_quants())`.
#[expect(dead_code, reason = "only used in expand")]
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct QuantizedView<
    Q: Scalar,
    NQ: Size,
    S: Scalar,
    F: Numeric,
    NF: Size,
    C: Coordinates + 'static,
> {
    /// Quantized storage, read as vectors of `NQ` storage elements of type `Q`.
    values: View<Vector<Q, NQ>, C>,
    /// Scales, read with the same coordinates used for `values`.
    scales: View<S, C>,
    /// Quantization scheme, fixed at compile time; drives dequantization on every read.
    #[cube(comptime)]
    scheme: QuantScheme,
    /// Carries the dequantized element type `F` and its vector size `NF`.
    #[cube(comptime)]
    _ty: PhantomData<(F, NF)>,
}
47
#[cube]
impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
    QuantizedView<Q, NQ, S, F, NF, C>
{
    /// Creates a quantized view from a values view and a matching scales view.
    ///
    /// `scheme` is a compile-time parameter. It is forwarded to dequantization on every read.
    pub fn new(
        values: View<Vector<Q, NQ>, C>,
        scales: View<S, C>,
        #[comptime] scheme: QuantScheme,
    ) -> Self {
        QuantizedView::<Q, NQ, S, F, NF, C> {
            values,
            scales,
            scheme,
            _ty: PhantomData,
        }
    }
}
65
66impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
67    QuantizedView<Q, NQ, S, F, NF, C>
68{
69    pub fn view(self) -> View<Vector<F, NF>, C> {
70        unexpanded!()
71    }
72
73    pub fn __expand_view(
74        scope: &mut Scope,
75        this: QuantizedViewExpand<Q, NQ, S, F, NF, C>,
76    ) -> ViewExpand<Vector<F, NF>, C, ReadOnly> {
77        this.__expand_view_method(scope)
78    }
79}
80
81impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
82    QuantizedViewExpand<Q, NQ, S, F, NF, C>
83{
84    pub fn new(
85        values: ViewExpand<Vector<Q, NQ>, C>,
86        scales: ViewExpand<S, C>,
87        scheme: QuantScheme,
88    ) -> Self {
89        QuantizedViewExpand::<Q, NQ, S, F, NF, C> {
90            values,
91            scales,
92            scheme,
93            _ty: PhantomData,
94        }
95    }
96
97    pub fn __expand_view_method(
98        self,
99        _scope: &mut Scope,
100    ) -> ViewExpand<Vector<F, NF>, C, ReadOnly> {
101        ViewExpand::new(self)
102    }
103}
104
/// Item-less impl. The vector size of a quantized view is computed by the `VectorizedExpand`
/// impl on `QuantizedViewExpand`.
impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static> Vectorized
    for QuantizedView<Q, NQ, S, F, NF, C>
{
}
109impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
110    VectorizedExpand for QuantizedViewExpand<Q, NQ, S, F, NF, C>
111{
112    fn vector_size(&self) -> VectorSize {
113        self.values.vector_size() * self.scheme.num_quants()
114    }
115}
116
/// Item-less impl. All read operations are implemented on `QuantizedViewExpand` via
/// `ViewOperationsExpand`.
impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
    ViewOperations<Vector<F, NF>, C> for QuantizedView<Q, NQ, S, F, NF, C>
{
}
121
122impl<Q: Scalar, NQ: Size, S: Scalar, F: Numeric, NF: Size, C: Coordinates + 'static>
123    ViewOperationsExpand<Vector<F, NF>, C> for QuantizedViewExpand<Q, NQ, S, F, NF, C>
124{
125    fn __expand_read_method(
126        &self,
127        scope: &mut Scope,
128        pos: <C>::ExpandType,
129    ) -> NativeExpand<Vector<F, NF>> {
130        let value = self.values.clone().__expand_read_method(scope, pos.clone());
131        let scale = self.scales.clone().__expand_read_method(scope, pos);
132
133        dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme)
134    }
135
136    fn __expand_read_checked_method(
137        &self,
138        scope: &mut Scope,
139        pos: <C>::ExpandType,
140    ) -> NativeExpand<Vector<F, NF>> {
141        let value = self
142            .values
143            .clone()
144            .__expand_read_checked_method(scope, pos.clone());
145        let scale = self
146            .scales
147            .clone()
148            .__expand_read_checked_method(scope, pos.clone());
149
150        dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme)
151    }
152
153    fn __expand_read_masked_method(
154        &self,
155        scope: &mut Scope,
156        pos: <C>::ExpandType,
157        mask_value: NativeExpand<Vector<F, NF>>,
158    ) -> NativeExpand<Vector<F, NF>> {
159        let value = self
160            .values
161            .clone()
162            .__expand_read_checked_method(scope, pos.clone());
163        let scale = self
164            .scales
165            .clone()
166            .__expand_read_checked_method(scope, pos.clone());
167        let in_bounds = self.__expand_is_in_bounds_method(scope, pos);
168
169        let value = dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme);
170        select::expand::<Vector<F, NF>>(scope, in_bounds, value, mask_value)
171    }
172
173    fn __expand_read_unchecked_method(
174        &self,
175        scope: &mut Scope,
176        pos: <C>::ExpandType,
177    ) -> NativeExpand<Vector<F, NF>> {
178        let value = self
179            .values
180            .clone()
181            .__expand_read_unchecked_method(scope, pos.clone());
182        let scale = self
183            .scales
184            .clone()
185            .__expand_read_unchecked_method(scope, pos);
186
187        dequantize_aligned::expand::<Q, S, F, NQ, NF>(scope, value, scale, self.scheme)
188    }
189
190    fn __expand_to_linear_slice_method(
191        &self,
192        _scope: &mut Scope,
193        _pos: <C>::ExpandType,
194        _end: <C>::ExpandType,
195    ) -> SliceExpand<Vector<F, NF>, ReadOnly> {
196        panic!("Can't create raw slice for quantized view")
197    }
198
199    fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
200        self.values.clone().__expand_shape_method(scope)
201    }
202
203    fn __expand_is_in_bounds_method(
204        &self,
205        scope: &mut Scope,
206        pos: C::ExpandType,
207    ) -> NativeExpand<bool> {
208        self.values.clone().__expand_is_in_bounds_method(scope, pos)
209    }
210
211    fn __expand_tensor_map_load_method(
212        &self,
213        _scope: &mut Scope,
214        _barrier: BarrierExpand,
215        _shared_memory: SliceExpand<Vector<F, NF>, ReadWrite>,
216        _pos: C::ExpandType,
217    ) {
218        panic!("Can't use tensor map functions on quantized view");
219    }
220}
221
/// [`RunWithQuantType`] visitor used at kernel-build time.
///
/// Once the concrete storage type `Q` and scale type `S` are chosen, it expands the values and
/// scales views and wraps them in a [`QuantizedViewExpand`]. `E` and `N` are the dequantized
/// element type and vector size exposed to the kernel.
struct ExpandDynamic<'a, E: Numeric, N: Size, C: Coordinates + 'static> {
    /// Compilation arg describing the quantized values view.
    values: &'a ViewCompilationArg<C>,
    /// Compilation arg describing the scales view.
    scales: &'a ViewCompilationArg<C>,
    /// Scheme; `num_quants()` relates the output vector size to the storage vector size.
    scheme: QuantScheme,
    /// Builder whose scope receives the `NQ` size registration and the expanded views.
    builder: &'a mut KernelBuilder,
    _ty: PhantomData<(E, N)>,
}
229
230impl<'a, E: Numeric, N: Size, C: Coordinates + 'static> RunWithQuantType
231    for ExpandDynamic<'a, E, N, C>
232{
233    type Output = ViewExpand<Vector<E, N>, C>;
234
235    fn execute<Q: Scalar, S: Scalar>(self) -> Self::Output {
236        define_size!(NQ);
237
238        let vector_size = N::__expand_value(&self.builder.scope);
239        let vector_size_q = vector_size / self.scheme.num_quants();
240        self.builder.scope.register_size::<NQ>(vector_size_q);
241
242        let values = View::<Vector<Q, NQ>, C>::expand(self.values, self.builder);
243        let scales = View::<S, C>::expand(self.scales, self.builder);
244        let view = QuantizedViewExpand::new(values, scales, self.scheme);
245        ViewExpand::new(view)
246    }
247}
248
/// [`RunWithQuantType`] visitor used at launch time.
///
/// Once `Q` and `S` are chosen, it registers the values and scales launch args with
/// `launcher` and produces a `ViewCompilationArg::Quantized` describing them. `E` is the
/// dequantized element type; its vector size determines the storage vector size `NQ`.
pub(crate) struct RegisterDynamic<'a, E: CubePrimitive, C: Coordinates + 'static, R: Runtime> {
    /// Launch arg for the quantized values view.
    pub values: ViewArg<C, R>,
    /// Launch arg for the scales view.
    pub scales: ViewArg<C, R>,
    /// Quantization scheme, copied into the resulting compilation arg.
    pub scheme: QuantScheme,
    /// Launcher the views are registered with.
    pub launcher: &'a mut KernelLauncher<R>,
    pub _ty: PhantomData<E>,
}
256
257impl<'a, E: CubePrimitive, C: Coordinates + 'static, R: Runtime> RunWithQuantType
258    for RegisterDynamic<'a, E, C, R>
259{
260    type Output = ViewCompilationArg<C>;
261
262    fn execute<Q: Scalar, S: Scalar>(self) -> Self::Output {
263        define_size!(NQ);
264
265        self.launcher.with_scope(|scope| {
266            let vector_size_q = E::__expand_vector_size(scope) / self.scheme.num_quants();
267            scope.register_size::<NQ>(vector_size_q);
268        });
269
270        let values = View::<Vector<Q, NQ>, C>::register(self.values, self.launcher);
271        let scales = View::<S, C>::register(self.scales, self.launcher);
272        ViewCompilationArg::Quantized {
273            values: Box::new(values),
274            scales: Box::new(scales),
275            scheme: self.scheme,
276        }
277    }
278}
279
280/// Run a function with the quantization storage type and scale. Useful when concrete types are
281/// required but aren't available, and only the dynamic schema is known.
282pub fn run_with_quant_type<F: RunWithQuantType>(func: F, scheme: QuantScheme) -> F::Output {
283    fn run_with_q<F: RunWithQuantType, Q: Scalar>(func: F, scheme: QuantScheme) -> F::Output {
284        match scheme.param {
285            QuantParam::F32 => func.execute::<Q, f32>(),
286            QuantParam::F16 => func.execute::<Q, f16>(),
287            QuantParam::BF16 => func.execute::<Q, bf16>(),
288            QuantParam::UE8M0 => func.execute::<Q, ue8m0>(),
289            QuantParam::UE4M3 => func.execute::<Q, e4m3>(),
290        }
291    }
292
293    let run_q = match scheme.store {
294        QuantStore::Native => match scheme.value {
295            QuantValue::Q8F => run_with_q::<F, i8>,
296            QuantValue::Q8S => run_with_q::<F, i8>,
297            QuantValue::E5M2 => run_with_q::<F, e5m2>,
298            QuantValue::E4M3 => run_with_q::<F, e4m3>,
299            QuantValue::Q4F
300            | QuantValue::Q4S
301            | QuantValue::Q2F
302            | QuantValue::Q2S
303            | QuantValue::E2M1 => {
304                panic!("Sub-byte quantization can't be native")
305            }
306        },
307        QuantStore::PackedU32(_) => run_with_q::<F, u32>,
308        QuantStore::PackedNative(_) => run_with_q::<F, e2m1x2>,
309    };
310    run_q(func, scheme)
311}
312
/// Dynamically expand based on the quantization scheme. Ugly, but the only way to fully hide the
/// quantization from the kernel using the view.
///
/// Steps:
/// - `E`'s vector size is registered as the size `NF`.
/// - `E`'s scalar float kind selects the concrete float type `F`.
/// - [`run_with_quant_type`] then selects the storage and scale types from `scheme`.
///
/// # Panics
/// Panics (via `unreachable!`) if `E`'s storage type is not a scalar float, or is one of the
/// minifloat kinds.
pub(crate) fn expand_dynamic<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility>(
    values: &ViewCompilationArg<C>,
    scales: &ViewCompilationArg<C>,
    scheme: QuantScheme,
    builder: &mut KernelBuilder,
) -> ViewExpand<E, C, IO> {
    // Short alias: each match arm reinterprets the helper's `ViewExpand<Vector<F, NF>, C>` as
    // this function's `ViewExpand<E, C, IO>`.
    use core::mem::transmute as t;

    // To specify tighter trait bounds: the outer function only knows `E: CubePrimitive`, while
    // `ExpandDynamic` needs a concrete `F: Numeric` and `NF: Size`.
    fn expand_dynamic_f<F: Numeric, NF: Size, C: Coordinates + 'static>(
        values: &ViewCompilationArg<C>,
        scales: &ViewCompilationArg<C>,
        scheme: QuantScheme,
        builder: &mut KernelBuilder,
    ) -> ViewExpand<Vector<F, NF>, C> {
        let func = ExpandDynamic {
            values,
            scales,
            scheme,
            builder,
            _ty: PhantomData::<(F, NF)>,
        };
        run_with_quant_type(func, scheme)
    }

    define_size!(NF);

    // `NF` mirrors `E`'s vector size so `Vector<F, NF>` below has the same width as `E`.
    let vector_size = E::__expand_vector_size(&builder.scope);

    builder.scope.register_size::<NF>(vector_size);

    // SAFETY: NOTE(review): soundness relies on two invariants; confirm both.
    // 1. `E` is exactly `Vector<F, NF>` for the float kind matched below: its storage type is
    //    that float, and its vector size was registered as `NF` above.
    // 2. `ViewExpand`'s layout does not depend on the visibility parameter. The helper returns
    //    the default visibility, which is reinterpreted here as the caller's `IO`.
    // Transmute annotations are omitted because every arm has the same source/target shape,
    // fixed by `expand_dynamic_f`'s return type and this function's return type.
    #[allow(clippy::missing_transmute_annotations)]
    unsafe {
        match E::as_type(&builder.scope).storage_type() {
            StorageType::Scalar(ElemType::Float(ty)) => match ty {
                FloatKind::F16 => t(expand_dynamic_f::<f16, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::BF16 => t(expand_dynamic_f::<bf16, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::Flex32 => t(expand_dynamic_f::<flex32, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::F32 => t(expand_dynamic_f::<f32, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::TF32 => t(expand_dynamic_f::<tf32, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::F64 => t(expand_dynamic_f::<f64, NF, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"),
            },
            _ => unreachable!("Quantized view should only be used with floats"),
        }
    }
}