// burn_jit/kernel/quantization/quantize.rs
use crate::tensor::JitTensor;
2use crate::FloatElement;
3use crate::{IntElement, JitElement, JitRuntime};
4use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
5use cubecl::calculate_cube_count_elemwise;
6use cubecl::prelude::*;
7
/// Quantize a line of float values to `int8` with the affine (asymmetric)
/// scheme: `q = clamp(round(value / scale + offset), range_min, range_max)`.
///
/// The clamped result is shifted by `+256` before the cast to `u32` so that
/// negative quantized values (down to `range_min`, which the host passes as
/// `i8::MIN`) become non-negative floats whose low 8 bits are the
/// two's-complement byte; callers keep only those low 8 bits when packing
/// (see `pack_i8s_to_u32s`).
#[cube]
pub(crate) fn quantize_affine_int8<F: Float>(
    value: Line<F>,
    scale: f32,
    offset: i32,
    range_min: f32,
    range_max: f32,
) -> Line<u32> {
    Line::cast_from(
        Line::clamp(
            Line::round((value / Line::cast_from(scale)) + Line::cast_from(offset)),
            Line::cast_from(range_min),
            Line::cast_from(range_max),
        ) + Line::cast_from(comptime!(256f32)),
    )
}
26
/// Quantize a whole tensor to packed `int8` using the per-tensor affine scheme.
///
/// Output layout (`u32` words, `n = output.len()`):
/// - words `0..n-2`: quantized values, four `i8` bytes per word;
/// - word `n-2`: the `i32` offset, bitcast to `u32`;
/// - word `n-1`: the `f32` scale, bitcast to `u32`.
#[cube(launch_unchecked)]
pub(crate) fn quantize_per_tensor_affine_int8_kernel(
    input: &Tensor<Line<f32>>,
    scale: &Tensor<f32>,
    offset: &Tensor<i32>,
    range_min: f32,
    range_max: f32,
    output: &mut Array<u32>,
) {
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    // Per-tensor quantization: one scale and one offset shared by all values.
    let scale = scale[0];
    let offset = offset[0];

    // The thread mapped to the last word stores the scale (bitcast) and exits.
    if ABSOLUTE_POS == output.len() - 1 {
        output[ABSOLUTE_POS] = u32::bitcast_from(scale);
        return;
    }

    // The thread mapped to the second-to-last word stores the offset (bitcast).
    if ABSOLUTE_POS == output.len() - 2 {
        output[ABSOLUTE_POS] = u32::bitcast_from(offset);
        return;
    }

    let line_size = comptime!(input.line_size());
    if comptime!(line_size == 4) {
        // Vectorized path: one line of 4 floats quantizes into one packed u32.
        let value =
            quantize_affine_int8::<f32>(input[ABSOLUTE_POS], scale, offset, range_min, range_max);
        output[ABSOLUTE_POS] = pack_i8s_to_u32s(value);
    } else {
        // Scalar path: pack 4 consecutive scalar inputs into one u32 by hand.
        // NOTE(review): the host only selects line_size 1 when num_elems < 4,
        // yet this reads `input[ABSOLUTE_POS + i]` for i in 0..4 — confirm the
        // input allocation is padded so these reads stay in bounds.
        let mut v_packed = 0;
        let num_packed = comptime!(4);
        #[unroll]
        for i in 0..num_packed {
            let v = quantize_affine_int8::<f32>(
                input[ABSOLUTE_POS + i],
                scale,
                offset,
                range_min,
                range_max,
            );
            // Keep only the low byte (two's-complement i8) of each result.
            v_packed |= (v[0] & 0xFF) << (8 * i);
        }
        output[ABSOLUTE_POS] = v_packed;
    }
}
80
/// Quantize a line of float values to `int8` with the symmetric scheme
/// (no zero-point): `q = clamp(round(value / scale), range_min, range_max)`.
///
/// As in `quantize_affine_int8`, `+256` is added before the `u32` cast so the
/// low 8 bits of each lane hold the two's-complement byte of the (possibly
/// negative) quantized value; callers mask with `0xFF` when packing.
#[cube]
pub(crate) fn quantize_symmetric_int8<F: Float>(
    value: Line<F>,
    scale: f32,
    range_min: F,
    range_max: F,
) -> Line<u32> {
    Line::cast_from(
        Line::clamp(
            Line::round(value / Line::cast_from(scale)),
            Line::new(range_min),
            Line::new(range_max),
        ) + Line::cast_from(comptime!(256f32)),
    )
}
98
/// Pack the low byte of every lane of `value` into a single `u32`,
/// little-endian style: lane `i` occupies bits `8*i .. 8*i + 8`.
#[cube]
pub(crate) fn pack_i8s_to_u32s(value: Line<u32>) -> u32 {
    let line_size = value.size();
    let mut v_packed = 0;

    #[unroll]
    for i in 0..line_size {
        // Keep only the low byte (the two's-complement i8) of each lane.
        v_packed |= (value[i] & 0xFF) << (8 * i);
    }
    v_packed
}
112
/// Quantize a whole tensor to packed `int8` using the per-tensor symmetric
/// scheme (no zero-point).
///
/// Output layout (`u32` words, `n = output.len()`):
/// - words `0..n-1`: quantized values, four `i8` bytes per word;
/// - word `n-1`: the `f32` scale, bitcast to `u32`.
#[cube(launch_unchecked)]
pub(crate) fn quantize_per_tensor_symmetric_int8_kernel(
    input: &Tensor<Line<f32>>,
    scale: &Tensor<f32>,
    range_min: f32,
    range_max: f32,
    output: &mut Array<u32>,
) {
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    // Per-tensor quantization: one scale shared by all values.
    let scale = scale[0];

    // The thread mapped to the last word stores the scale (bitcast) and exits.
    if ABSOLUTE_POS == output.len() - 1 {
        output[ABSOLUTE_POS] = u32::bitcast_from(scale);
        return;
    }

    let line_size = comptime!(input.line_size());
    if comptime!(line_size == 4) {
        // Vectorized path: one line of 4 floats quantizes into one packed u32.
        let value =
            quantize_symmetric_int8::<f32>(input[ABSOLUTE_POS], scale, range_min, range_max);
        output[ABSOLUTE_POS] = pack_i8s_to_u32s(value);
    } else {
        // Scalar path: pack 4 consecutive scalar inputs into one u32 by hand.
        // NOTE(review): the host only selects line_size 1 when num_elems < 4,
        // yet this reads `input[ABSOLUTE_POS + i]` for i in 0..4 — confirm the
        // input allocation is padded so these reads stay in bounds.
        let num_packed = comptime!(4);
        let mut v_packed = 0;
        #[unroll]
        for i in 0..num_packed {
            let v = quantize_symmetric_int8::<f32>(
                input[ABSOLUTE_POS + i],
                scale,
                range_min,
                range_max,
            );
            // Keep only the low byte (two's-complement i8) of each result.
            v_packed |= (v[0] & 0xFF) << (8 * i);
        }
        output[ABSOLUTE_POS] = v_packed;
    }
}
158
159pub(crate) fn quantize_per_tensor<R, F, I>(
160 tensor: JitTensor<R>,
161 scale: JitTensor<R>,
162 offset: Option<JitTensor<R>>,
163 scheme: QuantizationScheme,
164) -> JitTensor<R>
165where
166 R: JitRuntime,
167 F: JitElement,
168 I: IntElement,
169{
170 let ndims = tensor.shape.num_dims();
171 let num_elems = tensor.shape.num_elements();
172 let client = tensor.client.clone();
173 let output_num_elems = usize::div_ceil(num_elems, 4) * core::mem::size_of::<u32>();
175
176 let line_size: u8 = if num_elems < 4 { 1 } else { 4 };
178 let cube_dim = CubeDim::default();
179 let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
180
181 let dummy_array = vec![1; ndims];
182 if let Some(offset) = offset {
183 let handle = client
185 .empty(output_num_elems + core::mem::size_of::<f32>() + core::mem::size_of::<i32>());
186 let output = JitTensor::new_contiguous(
187 client.clone(),
188 tensor.device.clone(),
189 tensor.shape.clone(),
190 handle,
191 burn_tensor::DType::QFloat(scheme),
192 );
193
194 unsafe {
195 quantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
196 &client,
197 cube_count,
198 cube_dim,
199 tensor.as_tensor_arg::<F>(line_size),
200 TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
202 TensorArg::from_raw_parts::<I>(&offset.handle, &dummy_array, &dummy_array, 1),
203 ScalarArg::new(i8::MIN as f32),
204 ScalarArg::new(i8::MAX as f32),
205 output.as_array_arg::<u32>(1),
206 )
207 };
208 output
209 } else {
210 let handle = client.empty(output_num_elems + core::mem::size_of::<f32>());
212 let output = JitTensor::new_contiguous(
213 client.clone(),
214 tensor.device.clone(),
215 tensor.shape.clone(),
216 handle,
217 burn_tensor::DType::QFloat(scheme),
218 );
219
220 unsafe {
221 quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
222 &client,
223 cube_count,
224 cube_dim,
225 tensor.as_tensor_arg::<F>(line_size),
226 TensorArg::from_raw_parts::<F>(&scale.handle, &dummy_array, &dummy_array, 1),
228 ScalarArg::new(-i8::MAX as f32),
229 ScalarArg::new(i8::MAX as f32),
230 output.as_array_arg::<u32>(1),
231 )
232 };
233
234 output
235 }
236}
237
238pub fn quantize<R, F, I>(
240 tensor: JitTensor<R>,
241 scheme: &QuantizationScheme,
242 scale: JitTensor<R>,
243 offset: Option<JitTensor<R>>,
244) -> JitTensor<R>
245where
246 R: JitRuntime,
247 F: FloatElement,
248 I: IntElement,
249{
250 match scheme {
251 QuantizationScheme::PerTensorAffine(dtype)
252 | QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
253 QuantizationType::QInt8 => {
254 quantize_per_tensor::<R, F, I>(tensor, scale, offset, *scheme)
255 }
256 },
257 }
258}