cubecl_std/quant/view.rs

1use std::marker::PhantomData;
2
3use super::*;
4use crate::tensor::{
5    View, ViewExpand, ViewOperations, ViewOperationsExpand, launch::ViewCompilationArg,
6    layout::Coordinates,
7};
8use cubecl::prelude::*;
9use cubecl_common::{
10    e2m1x2, e4m3, e5m2,
11    quant::scheme::{QuantParam, QuantScheme, QuantStore, QuantValue},
12    ue8m0,
13};
14use cubecl_core::{
15    self as cubecl,
16    ir::{ElemType, FloatKind, StorageType},
17    prelude::barrier::BarrierExpand,
18    unexpanded,
19};
20use half::{bf16, f16};
21
/// View that dequantizes values as they are loaded.
///
/// Each read fetches one packed line from `values` and the scale from `scales` at the same
/// position, then dequantizes them according to `scheme`. The scales layout should take value
/// coordinates and map them to the corresponding scale.
///
/// # Warning
/// Assumes a single scale covers each individual load. Adjust the line size of the values or the
/// block size to ensure this.
/// Must ensure `block_size.is_multiple_of(line_size * scheme.num_quants())`.
#[expect(dead_code, reason = "only used in expand")]
#[derive(CubeType, CubeLaunch, Clone, Copy)]
pub struct QuantizedView<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> {
    // Packed quantized storage, read one `Line<Q>` at a time.
    values: View<Line<Q>, C>,
    // Scale view, read with the same coordinates as `values`.
    scales: View<S, C>,
    // Quantization scheme used to dequantize; fixed at compile time.
    #[cube(comptime)]
    scheme: QuantScheme,
    // Marker for the element type `F` yielded (as `Line<F>`) after dequantization.
    #[cube(comptime)]
    _ty: PhantomData<F>,
}
39
#[cube]
impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
    QuantizedView<Q, S, F, C>
{
    /// Creates a quantized view from a packed value view, a scale view and the comptime
    /// quantization `scheme`.
    pub fn new(
        values: View<Line<Q>, C>,
        scales: View<S, C>,
        #[comptime] scheme: QuantScheme,
    ) -> Self {
        // Spelled out with the full type path (rather than `Self`) inside the `#[cube]` DSL.
        QuantizedView::<Q, S, F, C> {
            values,
            scales,
            scheme,
            _ty: PhantomData,
        }
    }
}
57
58impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
59    QuantizedView<Q, S, F, C>
60{
61    pub fn view(self) -> View<Line<F>, C> {
62        unexpanded!()
63    }
64
65    pub fn __expand_view(
66        scope: &mut Scope,
67        this: QuantizedViewExpand<Q, S, F, C>,
68    ) -> ViewExpand<Line<F>, C, ReadOnly> {
69        this.__expand_view_method(scope)
70    }
71}
72
73impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
74    QuantizedViewExpand<Q, S, F, C>
75{
76    pub fn new(
77        values: ViewExpand<Line<Q>, C>,
78        scales: ViewExpand<S, C>,
79        scheme: QuantScheme,
80    ) -> Self {
81        QuantizedViewExpand::<Q, S, F, C> {
82            values,
83            scales,
84            scheme,
85            _ty: PhantomData,
86        }
87    }
88
89    pub fn __expand_view_method(self, _scope: &mut Scope) -> ViewExpand<Line<F>, C, ReadOnly> {
90        ViewExpand::new(self)
91    }
92}
93
94impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> Lined
95    for QuantizedView<Q, S, F, C>
96{
97}
98impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static> LinedExpand
99    for QuantizedViewExpand<Q, S, F, C>
100{
101    fn line_size(&self) -> u32 {
102        self.values.line_size() * self.scheme.num_quants() as u32
103    }
104}
105
106impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
107    ViewOperations<Line<F>, C> for QuantizedView<Q, S, F, C>
108{
109}
110
111impl<Q: CubePrimitive, S: CubePrimitive, F: Numeric, C: Coordinates + 'static>
112    ViewOperationsExpand<Line<F>, C> for QuantizedViewExpand<Q, S, F, C>
113{
114    fn __expand_read_method(
115        &self,
116        scope: &mut Scope,
117        pos: <C>::ExpandType,
118    ) -> ExpandElementTyped<Line<F>> {
119        let value = self.values.clone().__expand_read_method(scope, pos.clone());
120        let scale = self.scales.clone().__expand_read_method(scope, pos);
121
122        dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
123    }
124
125    fn __expand_read_checked_method(
126        &self,
127        scope: &mut Scope,
128        pos: <C>::ExpandType,
129    ) -> ExpandElementTyped<Line<F>> {
130        let value = self
131            .values
132            .clone()
133            .__expand_read_checked_method(scope, pos.clone());
134        let scale = self
135            .scales
136            .clone()
137            .__expand_read_checked_method(scope, pos.clone());
138
139        dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
140    }
141
142    fn __expand_read_masked_method(
143        &self,
144        scope: &mut Scope,
145        pos: <C>::ExpandType,
146        mask_value: ExpandElementTyped<Line<F>>,
147    ) -> ExpandElementTyped<Line<F>> {
148        let value = self
149            .values
150            .clone()
151            .__expand_read_checked_method(scope, pos.clone());
152        let scale = self
153            .scales
154            .clone()
155            .__expand_read_checked_method(scope, pos.clone());
156        let in_bounds = self.__expand_is_in_bounds_method(scope, pos);
157
158        let value = dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme);
159        select::expand::<Line<F>>(scope, in_bounds, value, mask_value)
160    }
161
162    fn __expand_read_unchecked_method(
163        &self,
164        scope: &mut Scope,
165        pos: <C>::ExpandType,
166    ) -> ExpandElementTyped<Line<F>> {
167        let value = self
168            .values
169            .clone()
170            .__expand_read_unchecked_method(scope, pos.clone());
171        let scale = self
172            .scales
173            .clone()
174            .__expand_read_unchecked_method(scope, pos);
175
176        dequantize_aligned::expand::<Q, S, F>(scope, value, scale, self.scheme)
177    }
178
179    fn __expand_to_linear_slice_method(
180        &self,
181        _scope: &mut Scope,
182        _pos: <C>::ExpandType,
183        _end: <C>::ExpandType,
184    ) -> SliceExpand<Line<F>, ReadOnly> {
185        panic!("Can't create raw slice for quantized view")
186    }
187
188    fn __expand_shape_method(&self, scope: &mut Scope) -> <C>::ExpandType {
189        self.values.clone().__expand_shape_method(scope)
190    }
191
192    fn __expand_is_in_bounds_method(
193        &self,
194        scope: &mut Scope,
195        pos: C::ExpandType,
196    ) -> ExpandElementTyped<bool> {
197        self.values.clone().__expand_is_in_bounds_method(scope, pos)
198    }
199
200    fn __expand_tensor_map_load_method(
201        &self,
202        _scope: &mut Scope,
203        _barrier: BarrierExpand,
204        _shared_memory: SliceExpand<Line<F>, ReadWrite>,
205        _pos: C::ExpandType,
206    ) {
207        panic!("Can't use tensor map functions on quantized view");
208    }
209}
210
211struct ExpandDynamic<'a, E: Numeric, C: Coordinates + 'static> {
212    values: &'a ViewCompilationArg<C>,
213    scales: &'a ViewCompilationArg<C>,
214    scheme: QuantScheme,
215    builder: &'a mut KernelBuilder,
216    _ty: PhantomData<E>,
217}
218
219impl<'a, E: Numeric, C: Coordinates + 'static> RunWithQuantType for ExpandDynamic<'a, E, C> {
220    type Output = ViewExpand<Line<E>, C>;
221
222    fn execute<Q: CubePrimitive, S: CubePrimitive>(self) -> Self::Output {
223        let values = View::<Line<Q>, C>::expand(self.values, self.builder);
224        let scales = View::<S, C>::expand(self.scales, self.builder);
225        let view = QuantizedViewExpand::new(values, scales, self.scheme);
226        ViewExpand::new(view)
227    }
228}
229
230/// Run a function with the quantization storage type and scale. Useful when concrete types are
231/// required but aren't available, and only the dynamic schema is known.
232pub fn run_with_quant_type<F: RunWithQuantType>(func: F, scheme: QuantScheme) -> F::Output {
233    fn run_with_q<F: RunWithQuantType, Q: CubePrimitive>(
234        func: F,
235        scheme: QuantScheme,
236    ) -> F::Output {
237        match scheme.param {
238            QuantParam::F32 => func.execute::<Q, f32>(),
239            QuantParam::F16 => func.execute::<Q, f16>(),
240            QuantParam::BF16 => func.execute::<Q, bf16>(),
241            QuantParam::UE8M0 => func.execute::<Q, ue8m0>(),
242            QuantParam::UE4M3 => func.execute::<Q, e4m3>(),
243        }
244    }
245
246    let run_q = match scheme.store {
247        QuantStore::Native => match scheme.value {
248            QuantValue::Q8F => run_with_q::<F, i8>,
249            QuantValue::Q8S => run_with_q::<F, i8>,
250            QuantValue::E5M2 => run_with_q::<F, e5m2>,
251            QuantValue::E4M3 => run_with_q::<F, e4m3>,
252            QuantValue::E2M1 => run_with_q::<F, e2m1x2>,
253            QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
254                panic!("Sub-byte quantization can't be native")
255            }
256        },
257        QuantStore::U32 => run_with_q::<F, u32>,
258    };
259    run_q(func, scheme)
260}
261
/// Dynamically expand based on the quantization scheme. Ugly, but the only way to fully hide the
/// quantization from the kernel using the view.
///
/// Inspects the storage type of `E` at expand time and, for each supported float kind, builds a
/// dequantizing view yielding `Line<F>` for the matching Rust float type `F`. The result is then
/// reinterpreted as `ViewExpand<E, C, IO>`.
///
/// # Panics
/// Panics (via `unreachable!`) if `E` is not a scalar float, or is one of the minifloat kinds.
pub(crate) fn expand_dynamic<E: CubePrimitive, C: Coordinates + 'static, IO: SliceVisibility>(
    values: &ViewCompilationArg<C>,
    scales: &ViewCompilationArg<C>,
    scheme: QuantScheme,
    builder: &mut KernelBuilder,
) -> ViewExpand<E, C, IO> {
    use core::mem::transmute as t;

    // To specify tighter trait bounds: `ExpandDynamic` requires `F: Numeric`, while `E` here is
    // only `CubePrimitive`.
    fn expand_dynamic_f<F: Numeric, C: Coordinates + 'static>(
        values: &ViewCompilationArg<C>,
        scales: &ViewCompilationArg<C>,
        scheme: QuantScheme,
        builder: &mut KernelBuilder,
    ) -> ViewExpand<Line<F>, C> {
        let func = ExpandDynamic {
            values,
            scales,
            scheme,
            builder,
            _ty: PhantomData::<F>,
        };
        run_with_quant_type(func, scheme)
    }

    // SAFETY (NOTE(review): rationale inferred, confirm): each arm transmutes a
    // `ViewExpand<Line<F>, C>` into `ViewExpand<E, C, IO>`, where `F` is the float type matching
    // the runtime kind reported by `E::as_type`. This assumes `E` is `Line<F>` for that kind, and
    // that `ViewExpand`'s layout does not depend on its element/visibility parameters.
    // `IO` is not checked here: the visibility marker is changed by the transmute as well, so
    // callers presumably must only use the result for reads.
    // The transmute target is given by the function's return type, hence the lint allowance.
    #[allow(clippy::missing_transmute_annotations)]
    unsafe {
        match E::as_type(&builder.scope) {
            StorageType::Scalar(ElemType::Float(ty)) => match ty {
                FloatKind::F16 => t(expand_dynamic_f::<f16, C>(values, scales, scheme, builder)),
                FloatKind::BF16 => t(expand_dynamic_f::<bf16, C>(values, scales, scheme, builder)),
                FloatKind::Flex32 => t(expand_dynamic_f::<flex32, C>(
                    values, scales, scheme, builder,
                )),
                FloatKind::F32 => t(expand_dynamic_f::<f32, C>(values, scales, scheme, builder)),
                FloatKind::TF32 => t(expand_dynamic_f::<tf32, C>(values, scales, scheme, builder)),
                FloatKind::F64 => t(expand_dynamic_f::<f64, C>(values, scales, scheme, builder)),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => unreachable!("Minifloats don't implement `Float` ops"),
            },
            _ => unreachable!("Quantized view should only be used with floats"),
        }
    }
}