cubecl_std/quant/
view.rs

1use std::marker::PhantomData;
2
3use super::*;
4use crate::{
5    CubeOption, CubeOptionExpand,
6    tensor::{
7        View, ViewExpand, ViewOperations, ViewOperationsExpand, launch::ViewCompilationArg,
8        layout::Coordinates,
9    },
10};
11use cubecl::prelude::*;
12use cubecl_common::{
13    e2m1x2, e4m3, e5m2,
14    quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue},
15    ue8m0,
16};
17use cubecl_core::{
18    self as cubecl,
19    ir::{ElemType, FloatKind, StorageType},
20    prelude::barrier::BarrierExpand,
21    unexpanded,
22};
23use half::{bf16, f16};
24
/// View that dequantizes after loads. Scales layout should take values coordinates and map them
/// to the corresponding scale.
///
/// Type parameters: `Q` is the element type of the stored (quantized) values, `S` is the type of
/// the scales, `F` is the numeric element type of the lines produced by the view, and `C` is the
/// coordinate type shared by the values and scales views.
///
/// # Warning
/// Assumes only one scale maps to a single load. Adjust line size of values or block size to ensure
/// this.
/// Must ensure `block_size.is_multiple_of(line_size * scheme.num_quants())`.
#[expect(dead_code, reason = "only used in expand")]
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct QuantizedView<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> {
    // Quantized values, read as lines of `Q`.
    values: View<Line<Q>, C>,
    // Scales; every read uses the same coordinates as the matching values read.
    scales: View<S, C>,
    // Quantization scheme, marked `comptime`.
    #[cube(comptime)]
    scheme: QuantScheme,
    // Zero-sized marker tying the view to its output element type `F`.
    #[cube(comptime)]
    _ty: PhantomData<F>,
}
42
#[cube]
impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
    QuantizedView<Q, S, F, C>
{
    /// Builds a quantized view from a view over the quantized `values`, a view over the `scales`
    /// and the comptime quantization `scheme`.
    pub fn new(
        values: View<Line<Q>, C>,
        scales: View<S, C>,
        #[comptime] scheme: QuantScheme,
    ) -> Self {
        QuantizedView::<Q, S, F, C> {
            values,
            scales,
            scheme,
            _ty: PhantomData,
        }
    }
}
60
61impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
62    QuantizedView<Q, S, F, C>
63{
64    pub fn view(self) -> View<Line<F>, C> {
65        unexpanded!()
66    }
67
68    pub fn __expand_view(
69        scope: &mut Scope,
70        this: QuantizedViewExpand<Q, S, F, C>,
71    ) -> ViewExpand<Line<F>, C, ReadOnly> {
72        this.__expand_view_method(scope)
73    }
74}
75
76impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
77    QuantizedViewExpand<Q, S, F, C>
78{
79    pub fn new(
80        values: ViewExpand<Line<Q>, C>,
81        scales: ViewExpand<S, C>,
82        scheme: QuantScheme,
83    ) -> Self {
84        QuantizedViewExpand::<Q, S, F, C> {
85            values,
86            scales,
87            scheme,
88            _ty: PhantomData,
89        }
90    }
91
92    pub fn __expand_view_method(self, _scope: &mut Scope) -> ViewExpand<Line<F>, C, ReadOnly> {
93        ViewExpand::new(self)
94    }
95}
96
// Empty impl of `Lined` for the quantized view; the line size itself is computed on the expand
// side by the `LinedExpand` impl for `QuantizedViewExpand`.
impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> Lined
    for QuantizedView<Q, S, F, C>
{
}
101impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> LinedExpand
102    for QuantizedViewExpand<Q, S, F, C>
103{
104    fn line_size(&self) -> u32 {
105        self.values.line_size() * self.scheme.num_quants() as u32
106    }
107}
108
// Empty impl: no methods are defined here. The expand-side behavior lives in the
// `ViewOperationsExpand` impl for `QuantizedViewExpand`.
impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
    ViewOperations<Line<F>, C> for QuantizedView<Q, S, F, C>
{
}
113
114impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
115    ViewOperationsExpand<Line<F>, C> for QuantizedViewExpand<Q, S, F, C>
116{
117    fn __expand_read_method(
118        &self,
119        scope: &mut Scope,
120        pos: <C>::ExpandType,
121    ) -> ExpandElementTyped<Line<F>> {
122        let value = self.values.clone().__expand_read_method(scope, pos.clone());
123        let scale = self.scales.clone().__expand_read_method(scope, pos);
124
125        dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
126    }
127
128    fn __expand_read_checked_method(
129        &self,
130        scope: &mut Scope,
131        pos: <C>::ExpandType,
132    ) -> ExpandElementTyped<Line<F>> {
133        let value = self
134            .values
135            .clone()
136            .__expand_read_checked_method(scope, pos.clone());
137        let scale = self
138            .scales
139            .clone()
140            .__expand_read_checked_method(scope, pos.clone());
141
142        dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
143    }
144
145    fn __expand_read_masked_method(
146        &self,
147        scope: &mut Scope,
148        pos: <C>::ExpandType,
149        mask_value: ExpandElementTyped<Line<F>>,
150    ) -> ExpandElementTyped<Line<F>> {
151        let value = self
152            .values
153            .clone()
154            .__expand_read_checked_method(scope, pos.clone());
155        let scale = self
156            .scales
157            .clone()
158            .__expand_read_checked_method(scope, pos.clone());
159        let in_bounds = self.__expand_is_in_bounds_method(scope, pos);
160
161        let value = dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme);
162        select::expand::<Line<F>>(scope, in_bounds, value, mask_value)
163    }
164
165    fn __expand_read_unchecked_method(
166        &self,
167        scope: &mut Scope,
168        pos: <C>::ExpandType,
169    ) -> ExpandElementTyped<Line<F>> {
170        let value = self
171            .values
172            .clone()
173            .__expand_read_unchecked_method(scope, pos.clone());
174        let scale = self
175            .scales
176            .clone()
177            .__expand_read_unchecked_method(scope, pos);
178
179        dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
180    }
181
182    fn __expand_to_linear_slice_method(
183        &self,
184        _scope: &mut Scope,
185        _pos: <C>::ExpandType,
186        _end: <C>::ExpandType,
187    ) -> SliceExpand<Line<F>, ReadOnly> {
188        panic!("Can't create raw slice for quantized view")
189    }
190
191    fn __expand_as_tensor_map_method(
192        &self,
193        scope: &mut Scope,
194    ) -> CubeOptionExpand<TensorMap<Line<F>>> {
195        CubeOption::__expand_new_None(scope)
196    }
197
198    fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
199        self.values.clone().__expand_shape_method(scope)
200    }
201
202    fn __expand_is_in_bounds_method(
203        &self,
204        scope: &mut Scope,
205        pos: C::ExpandType,
206    ) -> ExpandElementTyped<bool> {
207        self.values.clone().__expand_is_in_bounds_method(scope, pos)
208    }
209
210    fn __expand_tensor_map_load_method(
211        &self,
212        _scope: &mut Scope,
213        _barrier: BarrierExpand,
214        _shared_memory: SliceExpand<Line<F>, ReadWrite>,
215        _pos: C::ExpandType,
216    ) {
217        panic!("Can't use tensor map functions on quantized view");
218    }
219}
220
// Bundles the inputs needed to expand a quantized view once the concrete quantized storage type
// and scale type have been picked by `run_with_quant_type`. `E` is the element type of the
// resulting dequantized view.
struct ExpandDynamic<'a, E: Numeric, C: Coordinates + 'static> {
    // Compilation argument used to expand the values view.
    values: &'a ViewCompilationArg<C>,
    // Compilation argument used to expand the scales view.
    scales: &'a ViewCompilationArg<C>,
    // Quantization scheme forwarded to the quantized view.
    scheme: QuantScheme,
    // Kernel builder both views are expanded with.
    builder: &'a mut KernelBuilder,
    // Zero-sized marker for the output element type `E`.
    _ty: PhantomData<E>,
}
228
229impl<'a, E: Numeric, C: Coordinates + 'static> RunWithQuantType for ExpandDynamic<'a, E, C> {
230    type Output = ViewExpand<Line<E>, C>;
231
232    fn execute<Q: CubePrimitive, S: CubePrimitive>(self) -> Self::Output {
233        let values = View::<Line<Q>, C>::expand(self.values, self.builder);
234        let scales = View::<S, C>::expand(self.scales, self.builder);
235        let view = QuantizedViewExpand::new(values, scales, self.scheme);
236        ViewExpand::new(view)
237    }
238}
239
240/// Run a function with the quantization storage type and scale. Useful when concrete types are
241/// required but aren't available, and only the dynamic schema is known.
242pub fn run_with_quant_type<F: RunWithQuantType>(func: F, scheme: QuantScheme) -> F::Output {
243    fn run_with_q<F: RunWithQuantType, Q: CubePrimitive>(
244        func: F,
245        scheme: QuantScheme,
246    ) -> F::Output {
247        match scheme.param {
248            QuantParam::F32 => func.execute::<Q, f32>(),
249            QuantParam::F16 => func.execute::<Q, f16>(),
250            QuantParam::BF16 => func.execute::<Q, bf16>(),
251            QuantParam::UE8M0 => func.execute::<Q, ue8m0>(),
252            QuantParam::UE4M3 => func.execute::<Q, e4m3>(),
253        }
254    }
255
256    let run_q = match scheme.store {
257        QuantStore::Native => match scheme.value {
258            QuantValue::Q8F => run_with_q::<F, i8>,
259            QuantValue::Q8S => run_with_q::<F, i8>,
260            QuantValue::E5M2 => run_with_q::<F, e5m2>,
261            QuantValue::E4M3 => run_with_q::<F, e4m3>,
262            QuantValue::E2M1 => run_with_q::<F, e2m1x2>,
263            QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
264                panic!("Sub-byte quantization can't be native")
265            }
266        },
267        QuantStore::U32 => run_with_q::<F, u32>,
268    };
269    run_q(func, scheme)
270}
271
/// Dynamically expand based on the quantization scheme. Ugly, but the only way to fully hide the
/// quantization from the kernel using the view.
///
/// The float kind reported by `E::as_type` selects the concrete float type used to build the
/// quantized view; the resulting view expand is then transmuted to the requested
/// `ViewExpand<E, C, IO>`.
///
/// # Panics
/// Hits `unreachable!` if `E::as_type` is not a scalar float storage type or is one of the
/// minifloat kinds. Also panics if `run_with_quant_type` rejects the scheme (native storage with a
/// sub-byte value).
pub(crate) fn expand_dynamic<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility>(
    values: &ViewCompilationArg<C>,
    scales: &ViewCompilationArg<C>,
    scheme: QuantScheme,
    builder: &mut KernelBuilder,
) -> ViewExpand<E, C, IO> {
    use core::mem::transmute as t;

    // To specify tighter trait bounds
    fn expand_dynamic_f<F: Numeric, C: Coordinates + 'static>(
        values: &ViewCompilationArg<C>,
        scales: &ViewCompilationArg<C>,
        scheme: QuantScheme,
        builder: &mut KernelBuilder,
    ) -> ViewExpand<Line<F>, C> {
        let func = ExpandDynamic {
            values,
            scales,
            scheme,
            builder,
            _ty: PhantomData::<F>,
        };
        run_with_quant_type(func, scheme)
    }

    // SAFETY — NOTE(review), not verifiable from this file: each arm transmutes a
    // `ViewExpand<Line<F>, C>` into the caller's `ViewExpand<E, C, IO>`. This presumably relies on
    // `E` being `Line<F>` for the matched float kind and on `ViewExpand` having the same layout
    // for every element type and `IO` marker — confirm against the `ViewExpand` definition.
    #[allow(clippy::missing_transmute_annotations)]
    unsafe {
        match E::as_type(&builder.scope) {
            StorageType::Scalar(ElemType::Float(ty)) => match ty {
                FloatKind::F16 => t(expand_dynamic_f::<f16, C>(values, scales, scheme, builder)),
                FloatKind::BF16 => t(expand_dynamic_f::<bf16, C>(values, scales, scheme, builder)),
                FloatKind::Flex32 => t(expand_dynamic_f::<flex32, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::F32 => t(expand_dynamic_f::<f32, C>(values, scales, scheme, builder)),
                FloatKind::TF32 => t(expand_dynamic_f::<tf32, C>(values, scales, scheme, builder)),
                FloatKind::F64 => t(expand_dynamic_f::<f64, C>(values, scales, scheme, builder)),
                // Minifloat kinds are rejected outright rather than expanded.
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"),
            },
            _ => unreachable!("Quantized view should only be used with floats"),
        }
    }
}
321}