// burn_cubecl/kernel/quantization/quantize.rs
use crate::tensor::CubeTensor;
use crate::{CubeElement, CubeRuntime, IntElement};
use burn_tensor::Shape;
use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;
7
/// Packs the low byte of each lane of `value` into a single `u32`, little-endian:
/// lane 0 lands in bits 0..8, lane 1 in bits 8..16, and so on.
///
/// NOTE(review): assumes the line size is at most 4 so all lanes fit in one `u32` —
/// confirm callers never pass a wider line.
#[cube]
fn pack_i8s_to_u32s(value: Line<u32>) -> u32 {
    // `size()` is known at compile time, so the loop below fully unrolls.
    let line_size = value.size();
    let mut v_packed = 0;

    #[unroll]
    for i in 0..line_size {
        // Keep only the int8 payload (low 8 bits) and shift it into byte slot `i`.
        v_packed |= (value[i] & 0xFF) << (8 * i);
    }
    v_packed
}
21
/// Quantizes a line of floating-point values with a symmetric int8 scheme,
/// returning the quantized values widened into `u32` lanes.
///
/// Computes `x_q = clamp(round(x / scale), range_min, range_max)`, then adds 256
/// before the cast to unsigned so that negative quantized values end up with their
/// two's-complement bit pattern in the low byte (the extra bit is discarded when
/// the packer masks each lane with `0xFF`).
#[cube]
fn quantize_symmetric_int8<F: Float>(
    value: Line<F>,
    scale: f32,
    range_min: F,
    range_max: F,
) -> Line<u32> {
    Line::cast_from(
        Line::clamp(
            Line::round(value / Line::cast_from(scale)),
            Line::new(range_min),
            Line::new(range_max),
        ) + Line::cast_from(comptime!(256f32)),
    )
}
39
/// Quantizes a line of `f32` values with a symmetric int8 scheme and packs the
/// resulting bytes into a single `u32` (little-endian lane order).
#[cube]
fn quantize_symmetric_int8_packed(
    input: Line<f32>,
    scale: f32,
    range_min: f32,
    range_max: f32,
) -> u32 {
    let value = quantize_symmetric_int8::<f32>(input, scale, range_min, range_max);
    pack_i8s_to_u32s(value)
}
52
/// Per-tensor symmetric int8 quantization kernel.
///
/// Each unit writes one `u32` of `output` (4 packed int8 values). The last `u32`
/// of `output` does not hold quantized data: it stores the bit pattern of the
/// `f32` scale, so the quantization parameter travels inside the output buffer.
#[cube(launch_unchecked)]
fn quantize_per_tensor_symmetric_int8_kernel(
    input: &Tensor<Line<f32>>,
    scale: &Tensor<f32>,
    range_min: f32,
    range_max: f32,
    output: &mut Array<u32>,
) {
    // The launcher may spawn more units than output words; drop the excess.
    if ABSOLUTE_POS >= output.len() {
        terminate!();
    }

    let scale = scale[0];

    // Last slot: write the scale itself (bit-reinterpreted), then stop.
    if ABSOLUTE_POS == output.len() - 1 {
        output[ABSOLUTE_POS] = u32::reinterpret(scale);
        terminate!();
    }

    if comptime!(input.line_size() == 4) {
        // Vectorized path: one input line already holds the 4 values for this word.
        output[ABSOLUTE_POS] =
            quantize_symmetric_int8_packed(input[ABSOLUTE_POS], scale, range_min, range_max);
    } else {
        // Scalar path: gather 4 consecutive input values into a temporary line.
        // NOTE(review): only lane [0] of each input line is read here, i.e. this
        // branch assumes line_size == 1 — confirm against the launcher.
        let num_packed = comptime!(4);
        let mut values = Line::<f32>::empty(num_packed);
        #[unroll]
        for i in 0..num_packed {
            values[i] = input[ABSOLUTE_POS * num_packed + i][0];
        }
        output[ABSOLUTE_POS] = quantize_symmetric_int8_packed(values, scale, range_min, range_max);
    }
}
87
88fn create_quantized_output<R: CubeRuntime>(
89 client: ComputeClient<R::Server, R::Channel>,
90 num_input_elems: usize,
91 device: R::Device,
92 shape: Shape,
93 scheme: QuantizationScheme,
94) -> CubeTensor<R> {
95 let output_elems_size = usize::div_ceil(num_input_elems, 4) * core::mem::size_of::<u32>();
97
98 let qparams_size = match &scheme {
100 QuantizationScheme::PerTensor(mode, ..) => match mode {
101 QuantizationMode::Symmetric => core::mem::size_of::<f32>(),
102 },
103 };
104
105 let handle = client.empty(output_elems_size + qparams_size);
106 CubeTensor::new_contiguous(
107 client,
108 device,
109 shape,
110 handle,
111 burn_tensor::DType::QFloat(scheme),
112 )
113}
114
115pub fn quantize<R, F, I>(
117 tensor: CubeTensor<R>,
118 scheme: &QuantizationScheme,
119 scale: CubeTensor<R>,
120) -> CubeTensor<R>
121where
122 R: CubeRuntime,
123 F: CubeElement,
124 I: IntElement,
125{
126 let client = tensor.client.clone();
127 let num_elems = tensor.shape.num_elements();
129
130 let line_size: u8 = 1;
132 let cube_dim = CubeDim::default();
133 let cube_count =
134 calculate_cube_count_elemwise(num_elems.div_ceil(line_size as usize), cube_dim);
135
136 let output = create_quantized_output(
137 client.clone(),
138 num_elems,
139 tensor.device.clone(),
140 tensor.shape.clone(),
141 *scheme,
142 );
143
144 match scheme {
145 QuantizationScheme::PerTensor(mode, QuantizationType::QInt8) => {
146 let ndims = tensor.shape.num_dims();
147 let dummy_array = vec![1; ndims];
148
149 match mode {
150 QuantizationMode::Symmetric => {
151 unsafe {
152 quantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
153 &client,
154 cube_count,
155 cube_dim,
156 tensor.as_tensor_arg::<F>(line_size),
157 TensorArg::from_raw_parts::<F>(
159 &scale.handle,
160 &dummy_array,
161 &dummy_array,
162 1,
163 ),
164 ScalarArg::new(-i8::MAX as f32),
165 ScalarArg::new(i8::MAX as f32),
166 output.as_array_arg::<u32>(1),
167 )
168 };
169 }
170 }
171 }
172 }
173
174 output
175}