burn_jit/kernel/quantization/
dequantize.rs1use crate::tensor::JitTensor;
2use crate::FloatElement;
3use crate::{JitElement, JitRuntime};
4use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
5use burn_tensor::DType;
6use cubecl::calculate_cube_count_elemwise;
7use cubecl::prelude::*;
8
9use super::{QParams, QTensor};
10
11#[cube]
12pub(crate) fn dequantize_affine_int8<F: Float>(
13 value: Line<i32>,
14 scale: f32,
15 offset: i32,
16) -> Line<F> {
17 Line::cast_from(scale) * Line::cast_from(value - Line::cast_from(offset))
19}
20
21#[cube]
22pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
23 let value = (value >> offset) & 0xFF;
25 let sub = i32::cast_from(value & 0x80 != 0) * 256;
28 i32::cast_from(value) - sub
29}
30
31#[cube]
32pub(crate) fn extract_i8s(value: u32) -> Line<i32> {
33 let mut line = Line::empty(4);
34 line[0] = extract_i8(value, 0);
36 line[1] = extract_i8(value, 8);
37 line[2] = extract_i8(value, 16);
38 line[3] = extract_i8(value, 24);
39
40 line
41}
42
43#[cube(launch_unchecked)]
44pub(crate) fn dequantize_per_tensor_affine_int8_kernel(
45 input: &QTensor,
46 output: &mut Tensor<Line<f32>>,
47 #[comptime] scheme: QuantizationScheme,
48) {
49 if ABSOLUTE_POS >= input.len() - 2 {
51 return;
52 }
53
54 let qparams = QParams::new(scheme);
55 let (scale, offset) = qparams.values(input);
56
57 let value = input[ABSOLUTE_POS];
58
59 if comptime!(output.line_size() == 4) {
61 output[ABSOLUTE_POS] = dequantize_affine_int8(extract_i8s(value[0]), scale, offset);
62 } else {
63 let out = dequantize_affine_int8::<f32>(extract_i8s(value[0]), scale, offset);
65
66 #[unroll]
67 for j in 0..out.size() {
68 output[ABSOLUTE_POS + j] = Line::cast_from(out[j]);
69 }
70 }
71}
72
73#[cube]
74pub(crate) fn dequantize_symmetric_int8<F: Float>(value: Line<i32>, scale: f32) -> Line<F> {
75 Line::cast_from(scale) * Line::cast_from(value)
77}
78
79#[cube(launch_unchecked)]
81pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel(
82 input: &QTensor,
83 output: &mut Tensor<Line<f32>>,
84 #[comptime] scheme: QuantizationScheme,
85) {
86 if ABSOLUTE_POS >= input.len() - 1 {
88 return;
89 }
90
91 let qparams = QParams::new(scheme);
92 let (scale, _) = qparams.values(input);
93
94 let value = input[ABSOLUTE_POS];
95
96 if comptime!(output.line_size() == 4) {
98 output[ABSOLUTE_POS] = dequantize_symmetric_int8(extract_i8s(value[0]), scale);
99 } else {
100 let out = dequantize_symmetric_int8::<f32>(extract_i8s(value[0]), scale);
102
103 #[unroll]
104 for j in 0..out.size() {
105 output[ABSOLUTE_POS + j] = Line::cast_from(out[j]);
106 }
107 }
108}
109
110pub(crate) fn dequantize_per_tensor<R, F>(tensor: JitTensor<R>) -> JitTensor<R>
111where
112 R: JitRuntime,
113 F: JitElement,
114{
115 let num_out_elems = tensor.shape.num_elements();
118 let num_elems = usize::div_ceil(num_out_elems, 4);
119 let line_size_in = 1;
120 let line_size_out = if num_out_elems < 4 { 1 } else { 4 };
121 let cube_dim = CubeDim::default();
122 let cube_count = calculate_cube_count_elemwise(num_elems / line_size_in as usize, cube_dim);
123
124 let client = tensor.client.clone();
125 let handle = client.empty(num_out_elems * core::mem::size_of::<F>());
126
127 let output = JitTensor::new_contiguous(
128 client.clone(),
129 tensor.device.clone(),
130 tensor.shape.clone(),
131 handle,
132 F::dtype(),
133 );
134
135 if let DType::QFloat(scheme) = tensor.dtype {
136 match scheme {
137 QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
138 unsafe {
139 dequantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
140 &client,
141 cube_count,
142 cube_dim,
143 tensor.as_array_arg::<u32>(line_size_in),
144 output.as_tensor_arg::<F>(line_size_out),
145 scheme,
146 )
147 };
148 }
149 QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
150 unsafe {
151 dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
152 &client,
153 cube_count,
154 cube_dim,
155 tensor.as_array_arg::<u32>(line_size_in),
156 output.as_tensor_arg::<F>(line_size_out),
157 scheme,
158 )
159 };
160 }
161 }
162 }
163
164 output
165}
166
167pub fn dequantize<R, F>(tensor: JitTensor<R>) -> JitTensor<R>
169where
170 R: JitRuntime,
171 F: FloatElement,
172{
173 dequantize_per_tensor::<R, F>(tensor)
174}