cubek_quant/
quantize.rs

1use cubecl::calculate_cube_count_elemwise;
2use cubecl::features::TypeUsage;
3use cubecl::ir::ElemType;
4use cubecl::prelude::*;
5use cubecl::std::tensor::layout::linear::LinearView;
6use cubecl::std::tensor::{View, layout::linear::linear_view};
7use cubecl::tensor_line_size_parallel;
8
9use crate::{
10    layout::{ScalesLayout, scales_view},
11    utils::check_block_size_compat,
12};
13use crate::{
14    layout::{ScalesView, scales_layout},
15    scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
16};
17
18#[cube]
19fn quantize_symmetric<F: Float, FS: CubePrimitive>(
20    value: Line<F>,
21    scale: FS,
22    range_min: F,
23    range_max: F,
24) -> Line<F> {
25    clamp(
26        Line::round(value / Line::cast_from(scale)),
27        Line::new(range_min),
28        Line::new(range_max),
29    )
30}
31
32#[cube]
33fn quantize_symmetric_q<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
34    value: Line<F>,
35    scale: FS,
36    range_min: F,
37    range_max: F,
38) -> Line<Q> {
39    Line::cast_from(quantize_symmetric::<F, FS>(
40        value, scale, range_min, range_max,
41    ))
42}
43
44#[cube]
45fn quantize_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
46    value: Line<F>,
47    scale: FS,
48    range_min: F,
49    range_max: F,
50    #[comptime] scheme: QuantScheme,
51) -> QS {
52    let value = quantize_symmetric::<F, FS>(value, scale, range_min, range_max);
53    pack_q::<F, QS>(value, scheme.value)
54}
55
/// Pack a line of quantized values (still held as floats) into a single integer of the
/// storage type `QS`, according to the bit width of the quantization value type `quant`.
///
/// `QS::type_size_bits() / quant.size_bits()` lanes are packed. Lane `position` is placed at
/// bits `[position * size_quant, (position + 1) * size_quant)`, so lane 0 occupies the
/// least-significant bits. Each lane is converted to `i32` and masked to its low `size_quant`
/// bits, so a negative value is stored as the low bits of its two's-complement pattern.
///
/// Reads `value[position]` for every packed lane, so `value` must have at least that many
/// lanes (presumably exactly that many — confirm at call sites).
// NOTE(review): the loop below has no explicit counter; this `allow` may be stale or may be
// needed for code generated by `#[cube]` / `#[unroll]` — confirm before removing.
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn pack_q<F: Float, QS: Int>(value: Line<F>, #[comptime] quant: QuantValue) -> QS {
    // Bits used by one quantized value.
    let size_quant = quant.size_bits();

    // Bits in one storage word, and how many quantized values fit in it (both comptime).
    let size_store = QS::type_size_bits().comptime();
    let num_quants = size_store / size_quant;

    // Keeps only the low `size_quant` bits of each lane.
    let mask = (1 << size_quant) - 1;
    let mut packed = QS::from_int(0);

    // Go through a signed i32 so negative values keep their two's-complement low bits after
    // masking, then shift each lane into its slot and OR it into the word.
    #[unroll]
    for position in 0..num_quants {
        let offset = QS::cast_from(position * size_quant);
        let shifted = QS::cast_from(i32::cast_from(value[position]) & mask) << offset;
        packed |= shifted;
    }

    packed
}
79
80#[cube]
81fn write_scale<F: Float, FS: CubePrimitive>(
82    in_pos: usize,
83    scale: &View<F, usize>,
84    out_scale: &mut View<FS, usize, ReadWrite>,
85    scales_layout: ScalesLayout,
86) -> FS {
87    let scale = FS::cast_from(scale[in_pos]);
88
89    // Write the scale into the output buffer
90    if scales_layout.is_block_start(in_pos) {
91        out_scale[in_pos] = scale;
92    }
93
94    scale
95}
96
97#[cube(launch_unchecked)]
98fn quantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
99    input: &LinearView<Line<F>>,
100    scale: &ScalesView<F>,
101    range_min: InputScalar,
102    range_max: InputScalar,
103    output: &mut LinearView<Line<Q>, ReadWrite>,
104    out_scale: &mut ScalesView<FS, ReadWrite>,
105    scales_layout: ScalesLayout,
106    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
107) {
108    if !output.is_in_bounds(ABSOLUTE_POS) {
109        terminate!();
110    }
111
112    let native_packing = Q::packing_factor();
113    let in_pos = ABSOLUTE_POS * input.line_size() * native_packing;
114    let scale = write_scale(in_pos, scale, out_scale, scales_layout);
115
116    output[ABSOLUTE_POS] = quantize_symmetric_q::<F, FS, Q>(
117        input[ABSOLUTE_POS],
118        scale,
119        range_min.get::<F>(),
120        range_max.get::<F>(),
121    );
122    sync_cube();
123}
124
/// Quantizes `scheme.num_quants()` consecutive input values per unit and packs them into a
/// single `u32` stored at `output[ABSOLUTE_POS]`.
///
/// The values packed by a unit start at input element `ABSOLUTE_POS * num_quants`. The
/// scale is looked up at that element and written to `out_scale` when that element starts a
/// scale block. Only the first element of each packed word is checked, so scale blocks
/// presumably need to start on word boundaries — confirm against `check_block_size_compat`.
#[cube(launch_unchecked)]
fn quantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: InputScalar,
    range_max: InputScalar,
    output: &mut LinearView<Line<u32>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    // Units past the end of the packed output have no work.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // Quantized values per packed `u32` word.
    let num_quants = scheme.num_quants();
    // Input element index of the first value packed into this unit's word.
    let packed_pos = ABSOLUTE_POS * num_quants;
    let scale = write_scale(packed_pos, scale, out_scale, scales_layout);

    // Comptime branch: when one input line holds exactly one word's worth of values, read it
    // directly; otherwise gather the values one line at a time.
    if input.line_size().comptime() == num_quants {
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            input[ABSOLUTE_POS],
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    } else {
        // Input line size = 1
        // NOTE(review): `packed_pos + i` is an element offset used as a *line* index, so this
        // path is only correct when the input line size is 1. The launcher in this file always
        // views `input` with line size `num_quants`, so it presumably never takes this branch.
        let mut values = Line::<F>::empty(num_quants);
        #[unroll]
        for i in 0..num_quants {
            values[i] = input[packed_pos + i][0];
        }
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            values,
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    }
}
169
170#[allow(clippy::result_large_err)]
171pub fn launch_ref<R: Runtime>(
172    client: &ComputeClient<R>,
173    input: &TensorHandleRef<R>,
174    output: &TensorHandleRef<R>,
175    scale: &TensorHandleRef<'_, R>,
176    out_scale: &TensorHandleRef<'_, R>,
177    scheme: &QuantScheme,
178    input_elem: ElemType,
179) -> Result<(), LaunchError> {
180    let param_elem = ElemType::from_quant_param(scheme.param);
181
182    match scheme {
183        QuantScheme {
184            store: QuantStore::PackedU32(_),
185            ..
186        } => quantize_packed(
187            client, input, scheme, scale, out_scale, output, input_elem, param_elem,
188        ),
189        QuantScheme {
190            value: QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2,
191            store: QuantStore::Native,
192            ..
193        }
194        | QuantScheme {
195            value: QuantValue::E2M1,
196            store: QuantStore::PackedNative(_),
197            ..
198        } => {
199            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
200                panic!(
201                    "{:?} is not supported for native quantization",
202                    scheme.value
203                );
204            }
205
206            quantize_native(
207                client, input, scheme, scale, out_scale, output, input_elem, param_elem,
208            )
209        }
210        QuantScheme {
211            store: QuantStore::Native | QuantStore::PackedNative(_),
212            value,
213            ..
214        } => {
215            panic!("{value:?} is not supported for native quantization");
216        }
217    }
218}
219
220#[allow(clippy::too_many_arguments)]
221fn quantize_native<R: Runtime>(
222    client: &ComputeClient<R>,
223    input: &TensorHandleRef<R>,
224    scheme: &QuantScheme,
225    scale: &TensorHandleRef<'_, R>,
226    out_scale: &TensorHandleRef<'_, R>,
227    output: &TensorHandleRef<R>,
228    input_dtype: ElemType,
229    scale_dtype: ElemType,
230) -> Result<(), LaunchError> {
231    let num_elems: usize = input.shape.iter().product();
232    let line_size = tensor_line_size_parallel(
233        client.io_optimized_line_sizes_unchecked(input.elem_size),
234        input.shape,
235        input.strides,
236        input.shape.len() - 1,
237    );
238    let working_units = num_elems / line_size as usize;
239    let cube_dim = CubeDim::new(client, working_units);
240    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
241    let (range_min, range_max) = scheme.value.range();
242
243    match scheme {
244        QuantScheme {
245            level: QuantLevel::Tensor | QuantLevel::Block(_),
246            mode: QuantMode::Symmetric,
247            store: QuantStore::Native,
248            ..
249        } => {
250            // We could use line_size = block_size if it's in the supported line sizes.. but let's keep it simple
251            check_block_size_compat(scheme, line_size as usize);
252            let quant_type = ElemType::from_quant_value(scheme.value);
253
254            unsafe {
255                quantize_symmetric_native_kernel::launch_unchecked(
256                    client,
257                    cube_count,
258                    cube_dim,
259                    linear_view(client, input, line_size),
260                    // scale is computed based on input float dtype, but stored based on qparams precision
261                    scales_view(client, output, scale, 1, scheme),
262                    InputScalar::new(range_min, input_dtype),
263                    InputScalar::new(range_max, input_dtype),
264                    linear_view(client, output, line_size),
265                    scales_view(client, output, out_scale, 1, scheme),
266                    scales_layout(client, output, scale, 1, scheme),
267                    [input_dtype.into(), scale_dtype.into(), quant_type.into()],
268                )
269            }
270        }
271        _ => panic!("Unsupported quantization scheme {scheme:?}"),
272    }
273}
274
275#[allow(clippy::too_many_arguments)]
276fn quantize_packed<R: Runtime>(
277    client: &ComputeClient<R>,
278    input: &TensorHandleRef<R>,
279    scheme: &QuantScheme,
280    scale: &TensorHandleRef<'_, R>,
281    out_scale: &TensorHandleRef<'_, R>,
282    output: &TensorHandleRef<R>,
283    dtype_input: ElemType,
284    dtype_param: ElemType,
285) -> Result<(), LaunchError> {
286    let num_elems: usize = input.shape.iter().product();
287
288    let num_quants = scheme.num_quants();
289    let line_size = num_quants;
290
291    let working_units = num_elems.div_ceil(line_size);
292    let cube_dim = CubeDim::new(client, working_units);
293    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
294    let (range_min, range_max) = scheme.value.range();
295
296    match scheme {
297        QuantScheme {
298            level: QuantLevel::Tensor | QuantLevel::Block(_),
299            mode: QuantMode::Symmetric,
300            store: QuantStore::PackedU32(_),
301            ..
302        } => {
303            check_block_size_compat(scheme, num_quants); // 32 / 8 = 4
304            unsafe {
305                quantize_symmetric_packed_kernel::launch_unchecked(
306                    client,
307                    cube_count,
308                    cube_dim,
309                    linear_view(client, input, line_size),
310                    // scale is computed based on input float dtype, but stored based on qparams precision
311                    scales_view(client, output, scale, 1, scheme),
312                    InputScalar::new(range_min, dtype_input),
313                    InputScalar::new(range_max, dtype_input),
314                    linear_view(client, output, 1),
315                    scales_view(client, output, out_scale, 1, scheme),
316                    scales_layout(client, output, scale, 1, scheme),
317                    *scheme,
318                    [dtype_input.into(), dtype_param.into()],
319                )
320            }
321        }
322        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
323    }
324}