cubecl_quant/
quantize.rs

use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
use cubecl_common::{e2m1x2, e4m3, e5m2, ue8m0};
use cubecl_core::tensor_line_size_parallel;
use cubecl_core::{self as cubecl};
use cubecl_runtime::TypeUsage;
use cubecl_std::tensor::layout::linear::LinearView;
use cubecl_std::tensor::{View, layout::linear::linear_view};

use crate::{
    layout::{ScalesLayout, scales_view},
    utils::check_block_size_compat,
};
use crate::{
    layout::{ScalesView, scales_layout},
    scheme::{QuantLevel, QuantMode, QuantParam, QuantScheme, QuantStore, QuantValue},
};
use half::{bf16, f16};

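/// Symmetric quantization of a line of values: `q = clamp(round(x / scale), range_min, range_max)`.
///
/// Worked example (illustrative values): with `scale = 0.1` and the 8-bit symmetric
/// range `[-127, 127]`, an input of `3.14` maps to `round(31.4) = 31`, while `20.0`
/// saturates to `127`.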
#[cube]
fn quantize_symmetric<F: Float, FS: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<F> {
    Line::clamp(
        Line::round(value / Line::cast_from(scale)),
        Line::new(range_min),
        Line::new(range_max),
    )
}

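/// Quantize a line symmetrically, then cast the result to the quantized element type `Q`.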
#[cube]
fn quantize_symmetric_q<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<Q> {
    Line::cast_from(quantize_symmetric::<F, FS>(
        value, scale, range_min, range_max,
    ))
}

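/// Quantize a line of values symmetrically and pack the results into a single
/// integer of the storage type `QS`.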
#[cube]
fn quantize_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
    #[comptime] scheme: QuantScheme,
) -> QS {
    let value = quantize_symmetric::<F, FS>(value, scale, range_min, range_max);
    pack_q::<F, QS>(value, scheme.value)
}

/// Pack a line of quantized floating-point values into a single integer (the stored
/// quantization type), according to the bit width of the quantized value type.
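///
/// Worked example (illustrative values): packing the already-quantized line
/// `[1.0, -2.0, 3.0, -4.0]` as 8-bit values into a `u32` masks each value to its low
/// 8 bits and shifts it into place, least-significant first, yielding `0xFC03_FE01`
/// (`1 -> 0x01`, `-2 -> 0xFE`, `3 -> 0x03`, `-4 -> 0xFC`).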
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn pack_q<F: Float, QS: Int>(value: Line<F>, #[comptime] quant: QuantValue) -> QS {
    let size_quant = comptime!(quant.size_bits() as u32);

    let size_store = comptime!(QS::size_bits().unwrap() as u32);
    let num_quants = comptime!(size_store / size_quant);

    let mask = i32::cast_from(comptime!((1 << size_quant) - 1));
    let mut position = comptime!(0);
    let mut packed = QS::cast_from(0);

    // Shift and combine into QS (using i32 for sign extension)
    #[unroll]
    for _ in 0..num_quants {
        let offset = QS::cast_from(comptime!(position * size_quant));
        let shifted = QS::cast_from(i32::cast_from(value[position]) & mask) << offset;
        packed |= shifted;
        comptime!(position += 1);
    }

    packed
}

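/// Cast the scale at `in_pos` to the parameter precision `FS` and return it, writing
/// it to the output scales buffer when `in_pos` is the first element of its
/// quantization block (so each block's scale is stored exactly once).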
#[cube]
fn write_scale<F: Float, FS: CubePrimitive>(
    in_pos: u32,
    scale: &View<F, u32>,
    out_scale: &mut View<FS, u32, ReadWrite>,
    scales_layout: ScalesLayout,
) -> FS {
    let scale = FS::cast_from(scale[in_pos]);

    // Write the scale into the output buffer
    if scales_layout.is_block_start(in_pos) {
        out_scale[in_pos] = scale;
    }

    scale
}

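/// Quantize one output line per thread to the native quantized type `Q`. The
/// `packing_factor` accounts for types such as `e2m1x2` that store two quantized
/// values per element when computing the input position used for the scale lookup.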
#[cube(launch_unchecked)]
fn quantize_symmetric_native_kernel<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: F,
    range_max: F,
    output: &mut LinearView<Line<Q>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let native_packing = Q::packing_factor();
    let in_pos = ABSOLUTE_POS * input.line_size() * native_packing;
    let scale = write_scale(in_pos, scale, out_scale, scales_layout);

    output[ABSOLUTE_POS] =
        quantize_symmetric_q::<F, FS, Q>(input[ABSOLUTE_POS], scale, range_min, range_max);
}

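/// Quantize `num_quants` input values and pack them into each `u32` output element.
/// The fast path reads a single input line of exactly `num_quants` values; otherwise
/// the input line size is 1 and the values are gathered individually.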
#[cube(launch_unchecked)]
fn quantize_symmetric_packed_kernel<F: Float, FS: CubePrimitive>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: F,
    range_max: F,
    output: &mut LinearView<Line<u32>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[comptime] scheme: QuantScheme,
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let num_quants = comptime!(scheme.num_quants() as u32);
    let packed_pos = ABSOLUTE_POS * num_quants;
    let scale = write_scale(packed_pos, scale, out_scale, scales_layout);

    if comptime!(input.line_size() == num_quants) {
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            input[ABSOLUTE_POS],
            scale,
            range_min,
            range_max,
            scheme,
        ));
    } else {
        // Input line size = 1: gather the values one by one before packing
        let mut values = Line::<F>::empty(num_quants);
        #[unroll]
        for i in 0..num_quants {
            values[i] = input[packed_pos + i][0];
        }
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            values, scale, range_min, range_max, scheme,
        ));
    }
}

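/// Launch the quantization kernel matching the scheme: values stored in packed `u32`
/// words dispatch to `quantize_packed`, natively stored values to `quantize_native`.
/// Panics when the combination of value and store type is unsupported.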
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, F: Float + CubeElement>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    scheme: &QuantScheme,
) {
    match scheme {
        QuantScheme {
            store: QuantStore::U32,
            ..
        } => match scheme.param {
            QuantParam::F32 => {
                quantize_packed::<R, F, f32>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::F16 => {
                quantize_packed::<R, F, f16>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::BF16 => {
                quantize_packed::<R, F, bf16>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::UE8M0 => {
                quantize_packed::<R, F, ue8m0>(client, input, scheme, scale, out_scale, output)
            }
            QuantParam::UE4M3 => {
                quantize_packed::<R, F, e4m3>(client, input, scheme, scale, out_scale, output)
            }
        },
        QuantScheme {
            value:
                QuantValue::Q8F
                | QuantValue::Q8S
                | QuantValue::E4M3
                | QuantValue::E5M2
                | QuantValue::E2M1,
            store: QuantStore::Native,
            ..
        } => {
            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
                panic!(
                    "{:?} is not supported for native quantization",
                    scheme.value
                );
            }

            match scheme.param {
                QuantParam::F32 => {
                    quantize_native::<R, F, f32>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::F16 => {
                    quantize_native::<R, F, f16>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::BF16 => {
                    quantize_native::<R, F, bf16>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::UE8M0 => {
                    quantize_native::<R, F, ue8m0>(client, input, scheme, scale, out_scale, output)
                }
                QuantParam::UE4M3 => {
                    quantize_native::<R, F, e4m3>(client, input, scheme, scale, out_scale, output)
                }
            }
        }
        QuantScheme {
            store: QuantStore::Native,
            value,
            ..
        } => {
            panic!("{value:?} is not supported for native quantization");
        }
    }
}

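/// Quantize into natively stored values, selecting the quantized element type from
/// `scheme.value`. The line size is chosen from the runtime's IO-optimized sizes
/// along the contiguous last dimension.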
fn quantize_native<R: Runtime, F: Float, FS: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
) {
    let num_elems: usize = input.shape.iter().product();
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes_unchecked(size_of::<F>()),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            value,
            store: QuantStore::Native,
            ..
        } => {
            // We could use line_size = block_size when it is among the supported line
            // sizes, but we keep it simple here.
            check_block_size_compat(scheme, line_size as usize);

            let launch = match value {
                QuantValue::Q8F | QuantValue::Q8S => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, i8, R>
                }
                QuantValue::E4M3 => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, e4m3, R>
                }
                QuantValue::E5M2 => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, e5m2, R>
                }
                QuantValue::E2M1 => {
                    quantize_symmetric_native_kernel::launch_unchecked::<F, FS, e2m1x2, R>
                }
                other => panic!("Unsupported quant value {other:?}"),
            };

            unsafe {
                launch(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    // The scale is computed in the input float dtype, but stored in the
                    // qparams precision
                    scales_view(client, output, scale, 1, scheme),
                    ScalarArg::new(F::from_int(range_min as i64)),
                    ScalarArg::new(F::from_int(range_max as i64)),
                    linear_view(client, output, line_size),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                )
            };
        }
        _ => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}

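/// Quantize into `u32`-packed values. The input line size is fixed to the number of
/// values per `u32` word, so each thread produces exactly one packed output element.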
fn quantize_packed<R: Runtime, F: Float, FS: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
) {
    let num_elems: usize = input.shape.iter().product();

    let num_quants = scheme.num_quants() as u8;
    let line_size = num_quants;

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(line_size as usize), cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            store: QuantStore::U32,
            ..
        } => {
            // e.g. a u32 store with 8-bit values packs 32 / 8 = 4 values per word
            check_block_size_compat(scheme, num_quants as usize);
            unsafe {
                quantize_symmetric_packed_kernel::launch_unchecked::<F, FS, R>(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    // The scale is computed in the input float dtype, but stored in the
                    // qparams precision
                    scales_view(client, output, scale, 1, scheme),
                    ScalarArg::new(F::from_int(range_min as i64)),
                    ScalarArg::new(F::from_int(range_max as i64)),
                    linear_view(client, output, 1),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                    *scheme,
                )
            };
        }
        _ => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}