burn_jit/kernel/quantization/quantize.rs

use crate::tensor::JitTensor;
use crate::FloatElement;
use crate::{IntElement, JitElement, JitRuntime};
use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;

#[cube]
pub(crate) fn quantize_affine_int8<F: Float>(
    value: Line<F>,
    scale: f32,
    offset: i32,
    range_min: f32,
    range_max: f32,
) -> Line<u32> {
    // x_q = clamp(round(x / scale + offset), a, b)
    // NOTE: we add 256 before casting to unsigned to correctly represent negative values
    Line::cast_from(
        Line::clamp(
            Line::round((value / Line::cast_from(scale)) + Line::cast_from(offset)),
            Line::cast_from(range_min),
            Line::cast_from(range_max),
        ) + Line::cast_from(comptime!(256f32)),
    )
}
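
// Worked example of the +256 trick (hypothetical values): with scale = 0.1, offset = 0
// and input -0.3, round(-0.3 / 0.1 + 0) = -3, which lies within [-128, 127]. Adding 256
// gives 253, so the cast to u32 yields 0x000000FD; the low byte 0xFD is exactly the
// two's-complement encoding of -3i8, which the 0xFF mask in the packing step below
// extracts intact. (Non-negative values land in 256..=383, and the mask discards bit 8.)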

#[cube(launch_unchecked)]
pub(crate) fn quantize_per_tensor_affine_int8_kernel(
    input: &Tensor<Line<f32>>,
    scale: &Tensor<f32>,
    offset: &Tensor<i32>,
    range_min: f32,
    range_max: f32,
    output: &mut Array<u32>,
) {
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    let scale = scale[0];
    let offset = offset[0];

    // Bitcast the scale to u32 and write it in the output
    if ABSOLUTE_POS == output.len() - 1 {
        output[ABSOLUTE_POS] = u32::bitcast_from(scale);
        return;
    }

    // Bitcast the offset to u32 and write it in the output
    if ABSOLUTE_POS == output.len() - 2 {
        output[ABSOLUTE_POS] = u32::bitcast_from(offset);
        return;
    }

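    // Each remaining output word packs 4 int8 values. With a line size of 4, a single
    // vectorized load provides all 4 values; the scalar fallback below reads them one
    // by one. Note (from `quantize_per_tensor` below) that the scalar path is only
    // launched when num_elems < 4, so `ABSOLUTE_POS` is 0 for the packing thread and
    // `ABSOLUTE_POS + i` indexes the first elements of the input.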
    let line_size = comptime!(input.line_size());
    if comptime!(line_size == 4) {
        // Assuming a line size of 4 (equal to the number of values packed)
        let value =
            quantize_affine_int8::<f32>(input[ABSOLUTE_POS], scale, offset, range_min, range_max);
        // Shift and combine into u32
        output[ABSOLUTE_POS] = pack_i8s_to_u32s(value);
    } else {
        let mut v_packed = 0;
        let num_packed = comptime!(4);
        #[unroll]
        for i in 0..num_packed {
            let v = quantize_affine_int8::<f32>(
                input[ABSOLUTE_POS + i],
                scale,
                offset,
                range_min,
                range_max,
            );
            // Shift and combine into u32
            v_packed |= (v[0] & 0xFF) << (8 * i);
        }
        output[ABSOLUTE_POS] = v_packed;
    }
}

#[cube]
pub(crate) fn quantize_symmetric_int8<F: Float>(
    value: Line<F>,
    scale: f32,
    range_min: F,
    range_max: F,
) -> Line<u32> {
    // x_q = clamp(round(x / scale), a, b)
    // NOTE: we add 256 before casting to unsigned to correctly represent negative values
    Line::cast_from(
        Line::clamp(
            Line::round(value / Line::cast_from(scale)),
            Line::new(range_min),
            Line::new(range_max),
        ) + Line::cast_from(comptime!(256f32)),
    )
}

#[cube]
pub(crate) fn pack_i8s_to_u32s(value: Line<u32>) -> u32 {
    // NOTE: assuming line size of 4
    let line_size = value.size();
    let mut v_packed = 0;

    #[unroll]
    for i in 0..line_size {
        // Shift and combine into u32
        v_packed |= (value[i] & 0xFF) << (8 * i);
    }
    v_packed
}
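
// For example (hypothetical bytes): an input line whose low bytes are 0xFD, 0x00,
// 0x01 and 0x7F packs to 0x7F0100FD, i.e. value i occupies byte i, little-endian
// byte order within the u32.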

// NOTE: the symmetric scheme would ideally reuse the affine kernel, but cube doesn't
// support Option<Tensor> for the offset.
#[cube(launch_unchecked)]
pub(crate) fn quantize_per_tensor_symmetric_int8_kernel(
    input: &Tensor<Line<f32>>,
    scale: &Tensor<f32>,
    range_min: f32,
    range_max: f32,
    output: &mut Array<u32>,
) {
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    let scale = scale[0];

    // Bitcast the scale to u32 and write it in the output
    if ABSOLUTE_POS == output.len() - 1 {
        output[ABSOLUTE_POS] = u32::bitcast_from(scale);
        return;
    }

    let line_size = comptime!(input.line_size());
    if comptime!(line_size == 4) {
        // Assuming a line size of 4 (equal to the number of values packed)
        let value =
            quantize_symmetric_int8::<f32>(input[ABSOLUTE_POS], scale, range_min, range_max);
        // Shift and combine into u32
        output[ABSOLUTE_POS] = pack_i8s_to_u32s(value);
    } else {
        let num_packed = comptime!(4);
        let mut v_packed = 0;
        #[unroll]
        for i in 0..num_packed {
            let v = quantize_symmetric_int8::<f32>(
                input[ABSOLUTE_POS + i],
                scale,
                range_min,
                range_max,
            );
            // Shift and combine into u32
            v_packed |= (v[0] & 0xFF) << (8 * i);
        }
        output[ABSOLUTE_POS] = v_packed;
    }
}
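
// Packed output layout produced by the two kernels above (one u32 per 4 int8 values):
//   affine:    [packed data ..., offset bits, scale bits]
//   symmetric: [packed data ..., scale bits]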

pub(crate) fn quantize_per_tensor<R, F, I>(
    tensor: JitTensor<R>,
    scale: JitTensor<R>,
    offset: Option<JitTensor<R>>,
    scheme: QuantizationScheme,
) -> JitTensor<R>
where
    R: JitRuntime,
    F: JitElement,
    I: IntElement,
{
    let ndims = tensor.shape.num_dims();
    let num_elems = tensor.shape.num_elements();
    let client = tensor.client.clone();
    // Output tensor contains 4x fewer elements (four int8 values packed in a single u32);
    // the size is computed in bytes for the raw buffer allocation.
    let output_num_elems = usize::div_ceil(num_elems, 4) * core::mem::size_of::<u32>();
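    // e.g. (hypothetical) num_elems = 10 -> div_ceil(10, 4) = 3 packed u32 words,
    // so output_num_elems = 3 * 4 = 12 bytes, before the qparam bytes added below.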

    // Force vectorization to process 4 quantized values packed for 1 output value
    let line_size: u8 = if num_elems < 4 { 1 } else { 4 };
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);

    let dummy_array = vec![1; ndims];
    if let Some(offset) = offset {
        // Scale and offset qparams are also packed in the tensor data
        let handle = client
            .empty(output_num_elems + core::mem::size_of::<f32>() + core::mem::size_of::<i32>());
        let output = JitTensor::new_contiguous(
            client.clone(),
            tensor.device.clone(),
            tensor.shape.clone(),
            handle,
            burn_tensor::DType::QFloat(scheme),
        );

        unsafe {
            quantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
                &client,
                cube_count,
                cube_dim,
                tensor.as_tensor_arg::<F>(line_size),
                // Ignore shape and stride
                TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
                TensorArg::from_raw_parts::<I>(&offset.handle, &dummy_array, &dummy_array, 1),
                ScalarArg::new(i8::MIN as f32),
                ScalarArg::new(i8::MAX as f32),
                output.as_array_arg::<u32>(1),
            )
        };
        output
    } else {
        // Scale qparam is also packed in the tensor data
        let handle = client.empty(output_num_elems + core::mem::size_of::<f32>());
        let output = JitTensor::new_contiguous(
            client.clone(),
            tensor.device.clone(),
            tensor.shape.clone(),
            handle,
            burn_tensor::DType::QFloat(scheme),
        );

        unsafe {
            quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
                &client,
                cube_count,
                cube_dim,
                tensor.as_tensor_arg::<F>(line_size),
                // Ignore shape and stride
                TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
                ScalarArg::new(-i8::MAX as f32),
                ScalarArg::new(i8::MAX as f32),
                output.as_array_arg::<u32>(1),
            )
        };

        output
    }
}

/// Convert the tensor to a lower precision data type based on the quantization scheme and parameters.
pub fn quantize<R, F, I>(
    tensor: JitTensor<R>,
    scheme: &QuantizationScheme,
    scale: JitTensor<R>,
    offset: Option<JitTensor<R>>,
) -> JitTensor<R>
where
    R: JitRuntime,
    F: FloatElement,
    I: IntElement,
{
    match scheme {
        QuantizationScheme::PerTensorAffine(dtype)
        | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
            QuantizationType::QInt8 => {
                quantize_per_tensor::<R, F, I>(tensor, scale, offset, *scheme)
            }
        },
    }
}
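
// A minimal CPU-side sketch of the packing scheme used above, added for illustration
// (an assumption for clarity, not part of the original kernels): it mirrors
// `pack_i8s_to_u32s` in plain Rust so the byte layout can be checked without
// launching a kernel.
#[cfg(test)]
mod pack_sketch_tests {
    /// Plain-Rust equivalent of the kernel-side packing: value `i` lands in byte `i`.
    fn pack_i8s_to_u32(values: [i8; 4]) -> u32 {
        values
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &v)| acc | (((v as u8) as u32) << (8 * i)))
    }

    #[test]
    fn packs_in_little_endian_byte_order() {
        // -3i8 is 0xFD in two's complement and occupies the lowest byte.
        assert_eq!(pack_i8s_to_u32([-3, 0, 1, 127]), 0x7F01_00FD);
    }
}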