burn_cubecl/kernel/quantization/
quantize.rs

1use crate::tensor::CubeTensor;
2use crate::{CubeElement, CubeRuntime, IntElement};
3use burn_tensor::Shape;
4use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType};
5use cubecl::calculate_cube_count_elemwise;
6use cubecl::prelude::*;
7
/// Packs the lanes of `value` — each holding an 8-bit quantized value in its low
/// byte — into a single `u32`, with lane `i` occupying bits `8*i..8*i+8`
/// (little-endian lane order).
#[cube]
fn pack_i8s_to_u32s(value: Line<u32>) -> u32 {
    // NOTE: assuming line size of 4 (4 bytes exactly fill the u32; with more
    // lanes the shift amount would exceed 31 bits)
    let line_size = value.size();
    let mut v_packed = 0;

    #[unroll]
    for i in 0..line_size {
        // Mask to the low byte (the int8 two's-complement encoding),
        // then shift and combine into the u32
        v_packed |= (value[i] & 0xFF) << (8 * i);
    }
    v_packed
}
21
/// Quantizes a line of floating-point values to signed 8-bit integers with a
/// symmetric (zero-offset) scheme, returning each result widened into a `u32` lane.
///
/// `range_min`/`range_max` are the clamp bounds of the int8 range (e.g. -127/127).
#[cube]
fn quantize_symmetric_int8<F: Float>(
    value: Line<F>,
    scale: f32,
    range_min: F,
    range_max: F,
) -> Line<u32> {
    // x_q = clamp(round(x / scale), a, b)
    // NOTE: we add 256 before casting to unsigned to correctly represent negative values
    // (x_q + 256 leaves the low byte equal to the two's-complement encoding of x_q,
    // which is what the packing step extracts with `& 0xFF`)
    Line::cast_from(
        Line::clamp(
            Line::round(value / Line::cast_from(scale)),
            Line::new(range_min),
            Line::new(range_max),
        ) + Line::cast_from(comptime!(256f32)),
    )
}
39
/// Quantizes a line of four `f32` values to symmetric int8 and packs them into
/// a single `u32` (one byte per value, little-endian lane order).
#[cube]
fn quantize_symmetric_int8_packed(
    input: Line<f32>,
    scale: f32,
    range_min: f32,
    range_max: f32,
) -> u32 {
    // Assuming a line size of 4 (equal to the number of values packed)
    let value = quantize_symmetric_int8::<f32>(input, scale, range_min, range_max);
    // Shift and combine into u32
    pack_i8s_to_u32s(value)
}
52
/// Per-tensor symmetric int8 quantization kernel.
///
/// Each position writes one `u32` of `output`: four quantized int8 values packed
/// together. The last `u32` of `output` instead holds the f32 scale bit-cast to
/// `u32`, so the qparams travel in the same buffer as the data.
///
/// Supported input line sizes are 4 (one packed word per input line) and 1
/// (each position gathers 4 consecutive scalar lines).
#[cube(launch_unchecked)]
fn quantize_per_tensor_symmetric_int8_kernel(
    input: &Tensor<Line<f32>>,
    scale: &Tensor<f32>,
    range_min: f32,
    range_max: f32,
    output: &mut Array<u32>,
) {
    // Bounds guard: extra positions from the dispatch grid exit early.
    if ABSOLUTE_POS >= output.len() {
        terminate!();
    }

    // Per-tensor scheme: a single scale shared by every value.
    let scale = scale[0];

    // Cast the scale to u32 (bitwise reinterpret) and write the value in the
    // last slot of the output
    if ABSOLUTE_POS == output.len() - 1 {
        output[ABSOLUTE_POS] = u32::reinterpret(scale);
        terminate!();
    }

    if comptime!(input.line_size() == 4) {
        output[ABSOLUTE_POS] =
            quantize_symmetric_int8_packed(input[ABSOLUTE_POS], scale, range_min, range_max);
    } else {
        // line size 1: gather 4 consecutive scalar values and pack them.
        // NOTE(review): when the number of input elements is not a multiple of 4,
        // the last position reads past the logical end of `input` (launch is
        // unchecked) — confirm the allocation is padded or the element count is
        // always a multiple of 4.
        let num_packed = comptime!(4);
        let mut values = Line::<f32>::empty(num_packed);
        #[unroll]
        for i in 0..num_packed {
            values[i] = input[ABSOLUTE_POS * num_packed + i][0];
        }
        output[ABSOLUTE_POS] = quantize_symmetric_int8_packed(values, scale, range_min, range_max);
    }
}
87
88fn create_quantized_output<R: CubeRuntime>(
89    client: ComputeClient<R::Server, R::Channel>,
90    num_input_elems: usize,
91    device: R::Device,
92    shape: Shape,
93    scheme: QuantizationScheme,
94) -> CubeTensor<R> {
95    // Output tensor contains 4x less elements (four int8 values packed in a single u32)
96    let output_elems_size = usize::div_ceil(num_input_elems, 4) * core::mem::size_of::<u32>();
97
98    // Scale and offset (optional) qparams are also packed in the tensor data
99    let qparams_size = match &scheme {
100        QuantizationScheme::PerTensor(mode, ..) => match mode {
101            QuantizationMode::Symmetric => core::mem::size_of::<f32>(),
102        },
103    };
104
105    let handle = client.empty(output_elems_size + qparams_size);
106    CubeTensor::new_contiguous(
107        client,
108        device,
109        shape,
110        handle,
111        burn_tensor::DType::QFloat(scheme),
112    )
113}
114
115/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
116pub fn quantize<R, F, I>(
117    tensor: CubeTensor<R>,
118    scheme: &QuantizationScheme,
119    scale: CubeTensor<R>,
120) -> CubeTensor<R>
121where
122    R: CubeRuntime,
123    F: CubeElement,
124    I: IntElement,
125{
126    let client = tensor.client.clone();
127    // Output tensor contains 4x less elements (four int8 values packed in a single u32)
128    let num_elems = tensor.shape.num_elements();
129
130    // Force vectorization to process 4 quantized values packed for 1 output value
131    let line_size: u8 = 1;
132    let cube_dim = CubeDim::default();
133    let cube_count =
134        calculate_cube_count_elemwise(num_elems.div_ceil(line_size as usize), cube_dim);
135
136    let output = create_quantized_output(
137        client.clone(),
138        num_elems,
139        tensor.device.clone(),
140        tensor.shape.clone(),
141        *scheme,
142    );
143
144    match scheme {
145        QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) => {
146            let ndims = tensor.shape.num_dims();
147            let dummy_array = vec![1; ndims];
148
149            match mode {
150                QuantizationMode::Symmetric => {
151                    unsafe {
152                        quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
153                            &client,
154                            cube_count,
155                            cube_dim,
156                            tensor.as_tensor_arg::<F>(line_size),
157                            // Ignore shape and stride
158                            TensorArg::from_raw_parts::<F>(
159                                &scale.handle,
160                                &dummy_array,
161                                &dummy_array,
162                                1,
163                            ),
164                            ScalarArg::new(-i8::MAX as f32),
165                            ScalarArg::new(i8::MAX as f32),
166                            output.as_array_arg::<u32>(1),
167                        )
168                    };
169                }
170            }
171        }
172    }
173
174    output
175}