cubecl_quant/quantize.rs

use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
use cubecl_core::ir::ElemType;
use cubecl_core::tensor_line_size_parallel;
use cubecl_core::{self as cubecl};
use cubecl_runtime::TypeUsage;
use cubecl_std::scalar::InputScalar;
use cubecl_std::tensor::layout::linear::LinearView;
use cubecl_std::tensor::{View, layout::linear::linear_view};

use crate::{
    layout::{ScalesLayout, scales_view},
    utils::check_block_size_compat,
};
use crate::{
    layout::{ScalesView, scales_layout},
    scheme::{QuantLevel, QuantMode, QuantScheme, QuantStore, QuantValue},
};

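/// Symmetric quantization of a line of values: `q = clamp(round(x / scale), range_min, range_max)`.
/// The result stays in the float domain; casting to the quantized type is left to the caller.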
#[cube]
fn quantize_symmetric<F: Float, FS: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<F> {
    Line::clamp(
        Line::round(value / Line::cast_from(scale)),
        Line::new(range_min),
        Line::new(range_max),
    )
}

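/// Like [`quantize_symmetric`], but casts the resulting line to the quantized primitive `Q`.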
#[cube]
fn quantize_symmetric_q<F: Float, FS: CubePrimitive, Q: CubePrimitive>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
) -> Line<Q> {
    Line::cast_from(quantize_symmetric::<F, FS>(
        value, scale, range_min, range_max,
    ))
}

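/// Quantize a line of values and pack them into a single integer of the storage type `QS`.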
#[cube]
fn quantize_packed_value<F: Float, FS: CubePrimitive, QS: Int>(
    value: Line<F>,
    scale: FS,
    range_min: F,
    range_max: F,
    #[comptime] scheme: QuantScheme,
) -> QS {
    let value = quantize_symmetric::<F, FS>(value, scale, range_min, range_max);
    pack_q::<F, QS>(value, scheme.value)
}

/// Pack a line of quantized floating-point values into a single integer of the storage type,
/// according to the quantized value precision.
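///
/// For example, with 8-bit values stored in a `u32` (illustrative; matches the
/// least-significant-bits-first packing implemented below):
///
/// ```text
/// packed = (q0 & 0xFF) | (q1 & 0xFF) << 8 | (q2 & 0xFF) << 16 | (q3 & 0xFF) << 24
/// ```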
#[allow(clippy::explicit_counter_loop)]
#[cube]
fn pack_q<F: Float, QS: Int>(value: Line<F>, #[comptime] quant: QuantValue) -> QS {
    let size_quant = comptime!(quant.size_bits() as u32);

    let size_store = comptime!(QS::size_bits().unwrap() as u32);
    let num_quants = comptime!(size_store / size_quant);

    let mask = i32::cast_from(comptime!((1 << size_quant) - 1));
    let mut position = comptime!(0);
    let mut packed = QS::cast_from(0);

    // Mask each value to `size_quant` bits (going through i32 so negative values keep their
    // two's-complement bit pattern), then shift it into its slot and combine into QS
    #[unroll]
    for _ in 0..num_quants {
        let offset = QS::cast_from(comptime!(position * size_quant));
        let shifted = QS::cast_from(i32::cast_from(value[position]) & mask) << offset;
        packed |= shifted;
        comptime!(position += 1);
    }

    packed
}

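/// Cast the scale of the block containing `in_pos` to the parameter precision `FS`, write it
/// to the output scales buffer once per block, and return it for use by the current unit.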
#[cube]
fn write_scale<F: Float, FS: CubePrimitive>(
    in_pos: u32,
    scale: &View<F, u32>,
    out_scale: &mut View<FS, u32, ReadWrite>,
    scales_layout: ScalesLayout,
) -> FS {
    let scale = FS::cast_from(scale[in_pos]);

    // Write the scale to the output buffer once per block (only at the block start)
    if scales_layout.is_block_start(in_pos) {
        out_scale[in_pos] = scale;
    }

    scale
}

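/// Kernel for symmetric quantization with native storage: each output line directly holds
/// values of the quantized type `Q`, with no manual bit packing.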
#[cube(launch_unchecked)]
fn quantize_symmetric_native_kernel<F: Float, FS: Numeric, Q: Numeric>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: InputScalar,
    range_max: InputScalar,
    output: &mut LinearView<Line<Q>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[define(F, FS, Q)] _dtypes: [StorageType; 3],
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let native_packing = Q::packing_factor();
    let in_pos = ABSOLUTE_POS * input.line_size() * native_packing;
    let scale = write_scale(in_pos, scale, out_scale, scales_layout);

    // No cube-wide synchronization: units are fully independent, and a barrier after the
    // divergent `terminate!()` above would be unsound
    output[ABSOLUTE_POS] = quantize_symmetric_q::<F, FS, Q>(
        input[ABSOLUTE_POS],
        scale,
        range_min.get::<F>(),
        range_max.get::<F>(),
    );
}

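/// Kernel for symmetric quantization with packed storage: `num_quants` values are quantized
/// and bit-packed into each `u32` of the output.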
#[cube(launch_unchecked)]
fn quantize_symmetric_packed_kernel<F: Float, FS: Numeric>(
    input: &LinearView<Line<F>>,
    scale: &ScalesView<F>,
    range_min: InputScalar,
    range_max: InputScalar,
    output: &mut LinearView<Line<u32>, ReadWrite>,
    out_scale: &mut ScalesView<FS, ReadWrite>,
    scales_layout: ScalesLayout,
    #[comptime] scheme: QuantScheme,
    #[define(F, FS)] _dtypes: [StorageType; 2],
) {
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let num_quants = comptime!(scheme.num_quants() as u32);
    let packed_pos = ABSOLUTE_POS * num_quants;
    let scale = write_scale(packed_pos, scale, out_scale, scales_layout);

    if comptime!(input.line_size() == num_quants) {
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            input[ABSOLUTE_POS],
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    } else {
        // Input line size is 1: gather `num_quants` scalar values into a line before packing
        let mut values = Line::<F>::empty(num_quants);
        #[unroll]
        for i in 0..num_quants {
            values[i] = input[packed_pos + i][0];
        }
        output[ABSOLUTE_POS] = Line::cast_from(quantize_packed_value::<F, FS, u32>(
            values,
            scale,
            range_min.get::<F>(),
            range_max.get::<F>(),
            scheme,
        ));
    }
}

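/// Launch the quantization kernel matching the given scheme, dispatching on the storage
/// strategy: packed `u32` storage, or native storage for the supported quantized value types.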
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    output: &TensorHandleRef<R>,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    scheme: &QuantScheme,
    input_elem: ElemType,
) -> Result<(), LaunchError> {
    let param_elem = ElemType::from_quant_param(scheme.param);

    match scheme {
        QuantScheme {
            store: QuantStore::U32,
            ..
        } => quantize_packed(
            client, input, scheme, scale, out_scale, output, input_elem, param_elem,
        ),
        QuantScheme {
            value:
                QuantValue::Q8F
                | QuantValue::Q8S
                | QuantValue::E4M3
                | QuantValue::E5M2
                | QuantValue::E2M1,
            store: QuantStore::Native,
            ..
        } => {
            if !i8::supported_uses(client).contains(TypeUsage::Conversion) {
                panic!(
                    "{:?} is not supported for native quantization",
                    scheme.value
                );
            }

            quantize_native(
                client, input, scheme, scale, out_scale, output, input_elem, param_elem,
            )
        }
        QuantScheme {
            store: QuantStore::Native,
            value,
            ..
        } => {
            panic!("{value:?} is not supported for native quantization");
        }
    }
}

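/// Quantize with native storage: the output tensor's element type is the quantized type itself.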
#[allow(clippy::too_many_arguments)]
fn quantize_native<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
    input_dtype: ElemType,
    scale_dtype: ElemType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes_unchecked(input.elem_size),
        input.shape,
        input.strides,
        input.shape.len() - 1,
    );
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            store: QuantStore::Native,
            ..
        } => {
            // We could use line_size = block_size when it's among the supported line sizes,
            // but let's keep it simple
            check_block_size_compat(scheme, line_size as usize);
            let quant_type = ElemType::from_quant_value(scheme.value);

            unsafe {
                quantize_symmetric_native_kernel::launch_unchecked(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    // scale is computed based on input float dtype, but stored based on qparams precision
                    scales_view(client, output, scale, 1, scheme),
                    InputScalar::new(range_min, input_dtype),
                    InputScalar::new(range_max, input_dtype),
                    linear_view(client, output, line_size),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                    [input_dtype.into(), scale_dtype.into(), quant_type.into()],
                )
            }
        }
        _ => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}

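/// Quantize with packed storage: multiple quantized values are bit-packed into each `u32`
/// element of the output tensor.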
#[allow(clippy::too_many_arguments)]
fn quantize_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<R>,
    scheme: &QuantScheme,
    scale: &TensorHandleRef<'_, R>,
    out_scale: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<R>,
    dtype_input: ElemType,
    dtype_param: ElemType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let num_quants = scheme.num_quants() as u8;
    let line_size = num_quants;

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(line_size as usize), cube_dim);
    let (range_min, range_max) = scheme.value.range();

    match scheme {
        QuantScheme {
            level: QuantLevel::Tensor | QuantLevel::Block(_),
            mode: QuantMode::Symmetric,
            store: QuantStore::U32,
            ..
        } => {
            check_block_size_compat(scheme, num_quants as usize); // e.g. a 32-bit store of 8-bit values packs 32 / 8 = 4 values
            unsafe {
                quantize_symmetric_packed_kernel::launch_unchecked(
                    client,
                    cube_count,
                    cube_dim,
                    linear_view(client, input, line_size),
                    // scale is computed based on input float dtype, but stored based on qparams precision
                    scales_view(client, output, scale, 1, scheme),
                    InputScalar::new(range_min, dtype_input),
                    InputScalar::new(range_max, dtype_input),
                    linear_view(client, output, 1),
                    scales_view(client, output, out_scale, 1, scheme),
                    scales_layout(client, output, scale, 1, scheme),
                    *scheme,
                    [dtype_input.into(), dtype_param.into()],
                )
            }
        }
        QuantScheme { .. } => panic!("Unsupported quantization scheme {scheme:?}"),
    }
}
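
// A minimal host-side sketch of the quantize-and-pack math above, for illustration only.
// `quantize_symmetric_scalar` and `pack_q8` are hypothetical plain-Rust mirrors of
// `quantize_symmetric` and `pack_q`, assuming 8-bit values packed into a `u32` with a
// signed [-128, 127] range; the cube kernels above remain the source of truth.
#[cfg(test)]
mod tests {
    /// Scalar equivalent of `quantize_symmetric`: divide by the scale, round, clamp.
    fn quantize_symmetric_scalar(value: f32, scale: f32, min: f32, max: f32) -> i32 {
        (value / scale).round().clamp(min, max) as i32
    }

    /// Scalar equivalent of `pack_q` for four 8-bit values in a `u32`: mask each value to
    /// 8 bits (two's complement) and shift it into its slot, least significant first.
    fn pack_q8(values: [i32; 4]) -> u32 {
        let mut packed = 0u32;
        for (i, &q) in values.iter().enumerate() {
            packed |= ((q & 0xFF) as u32) << (i as u32 * 8);
        }
        packed
    }

    #[test]
    fn packs_four_q8_values() {
        let scale = 0.5;
        let q: Vec<i32> = [1.0f32, -1.0, 63.5, -64.0]
            .iter()
            .map(|&v| quantize_symmetric_scalar(v, scale, -128.0, 127.0))
            .collect();
        // 2 -> 0x02, -2 -> 0xFE, 127 -> 0x7F, -128 -> 0x80, packed least significant first
        assert_eq!(pack_q8([q[0], q[1], q[2], q[3]]), 0x807F_FE02);
    }
}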