// burn_jit/kernel/quantization/dequantize.rs

1use crate::tensor::JitTensor;
2use crate::FloatElement;
3use crate::{JitElement, JitRuntime};
4use burn_tensor::quantization::{QuantizationScheme, QuantizationType};
5use burn_tensor::DType;
6use cubecl::calculate_cube_count_elemwise;
7use cubecl::prelude::*;
8
9use super::{QParams, QTensor};
10
/// Dequantizes a line of int8 values with the affine scheme: `x = scale * (x_q - offset)`.
///
/// `scale` and `offset` are broadcast across the line before the elementwise ops.
#[cube]
pub(crate) fn dequantize_affine_int8<F: Float>(
    value: Line<i32>,
    scale: f32,
    offset: i32,
) -> Line<F> {
    // x = scale * (x_q - offset)
    Line::cast_from(scale) * Line::cast_from(value - Line::cast_from(offset))
}
20
/// Extracts one signed 8-bit value from a packed `u32`.
///
/// `offset` is the bit position of the byte to read (expected 0, 8, 16 or 24,
/// matching the call sites in `extract_i8s`). The result is sign-extended to i32.
#[cube]
pub(crate) fn extract_i8(value: u32, offset: u32) -> i32 {
    // Extract 8-bit segment
    let value = (value >> offset) & 0xFF;
    // Check if the value is negative by inspecting the MSB and subtract 256 if it is
    // Subtract 0 or 256 to circumvent unsupported conditional assignment (let x = if {} else {};)
    // (the cast of the bool comparison yields 0 or 1, scaled to the two's-complement correction)
    let sub = i32::cast_from(value & 0x80 != 0) * 256;
    i32::cast_from(value) - sub
}
30
/// Unpacks four signed 8-bit values from one `u32` into a 4-wide `Line<i32>`.
///
/// Bytes are taken from least to most significant bits (lane 0 = bits 0..8).
#[cube]
pub(crate) fn extract_i8s(value: u32) -> Line<i32> {
    let mut line = Line::empty(4);
    // Extract each 8-bit segment
    line[0] = extract_i8(value, 0);
    line[1] = extract_i8(value, 8);
    line[2] = extract_i8(value, 16);
    line[3] = extract_i8(value, 24);

    line
}
42
/// Kernel: dequantizes a tensor quantized with the per-tensor affine int8 scheme.
///
/// `input` stores four int8 values packed per `u32`, with the quantization
/// parameters in its last two positions (read via [`QParams`]). Each invocation
/// unpacks one `u32` and writes four dequantized f32 values.
#[cube(launch_unchecked)]
pub(crate) fn dequantize_per_tensor_affine_int8_kernel(
    input: &QTensor,
    output: &mut Tensor<Line<f32>>,
    #[comptime] scheme: QuantizationScheme,
) {
    // Last two positions contain the qparams
    if ABSOLUTE_POS >= input.len() - 2 {
        return;
    }

    let qparams = QParams::new(scheme);
    let (scale, offset) = qparams.values(input);

    let value = input[ABSOLUTE_POS];

    // Input line size is fixed to 1
    if comptime!(output.line_size() == 4) {
        // Vectorized output: one packed u32 maps to one 4-wide output line.
        output[ABSOLUTE_POS] = dequantize_affine_int8(extract_i8s(value[0]), scale, offset);
    } else {
        // For very small inputs where number of elements < 4, the output line size is 1
        let out = dequantize_affine_int8::<f32>(extract_i8s(value[0]), scale, offset);

        // NOTE(review): `out` always has 4 lanes, so for tensors with fewer than
        // 4 elements the trailing stores go past the logical length; the launch
        // is unchecked — confirm the output buffer tolerates this.
        #[unroll]
        for j in 0..out.size() {
            output[ABSOLUTE_POS + j] = Line::cast_from(out[j]);
        }
    }
}
72
/// Dequantizes a line of int8 values with the symmetric scheme: `x = scale * x_q`.
#[cube]
pub(crate) fn dequantize_symmetric_int8<F: Float>(value: Line<i32>, scale: f32) -> Line<F> {
    // x = scale * x_q
    Line::cast_from(scale) * Line::cast_from(value)
}
78
// Would have wrapped symmetric with the same affine kernel but cube doesn't support Option<Tensor> for offset.
/// Kernel: dequantizes a tensor quantized with the per-tensor symmetric int8 scheme.
///
/// `input` stores four int8 values packed per `u32`, with the scale in its last
/// position (read via [`QParams`]). Each invocation unpacks one `u32` and writes
/// four dequantized f32 values.
#[cube(launch_unchecked)]
pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel(
    input: &QTensor,
    output: &mut Tensor<Line<f32>>,
    #[comptime] scheme: QuantizationScheme,
) {
    // Last position contains the qparam
    if ABSOLUTE_POS >= input.len() - 1 {
        return;
    }

    let qparams = QParams::new(scheme);
    // Symmetric scheme has no offset; only the scale is used.
    let (scale, _) = qparams.values(input);

    let value = input[ABSOLUTE_POS];

    // Input line size is fixed to 1
    if comptime!(output.line_size() == 4) {
        // Vectorized output: one packed u32 maps to one 4-wide output line.
        output[ABSOLUTE_POS] = dequantize_symmetric_int8(extract_i8s(value[0]), scale);
    } else {
        // For very small inputs where number of elements < 4, the output line size is 1
        let out = dequantize_symmetric_int8::<f32>(extract_i8s(value[0]), scale);

        // NOTE(review): `out` always has 4 lanes, so for tensors with fewer than
        // 4 elements the trailing stores go past the logical length; the launch
        // is unchecked — confirm the output buffer tolerates this.
        #[unroll]
        for j in 0..out.size() {
            output[ABSOLUTE_POS + j] = Line::cast_from(out[j]);
        }
    }
}
109
110pub(crate) fn dequantize_per_tensor<R, F>(tensor: JitTensor<R>) -> JitTensor<R>
111where
112    R: JitRuntime,
113    F: JitElement,
114{
115    // The actual number of elements is 1/4 (four int8 values packed in a single u32)
116    // so we choose a line size to match a valid input binding size.
117    let num_out_elems = tensor.shape.num_elements();
118    let num_elems = usize::div_ceil(num_out_elems, 4);
119    let line_size_in = 1;
120    let line_size_out = if num_out_elems < 4 { 1 } else { 4 };
121    let cube_dim = CubeDim::default();
122    let cube_count = calculate_cube_count_elemwise(num_elems / line_size_in as usize, cube_dim);
123
124    let client = tensor.client.clone();
125    let handle = client.empty(num_out_elems * core::mem::size_of::<F>());
126
127    let output = JitTensor::new_contiguous(
128        client.clone(),
129        tensor.device.clone(),
130        tensor.shape.clone(),
131        handle,
132        F::dtype(),
133    );
134
135    if let DType::QFloat(scheme) = tensor.dtype {
136        match scheme {
137            QuantizationScheme::PerTensorAffine(QuantizationType::QInt8) => {
138                unsafe {
139                    dequantize_per_tensor_affine_int8_kernel::launch_unchecked::<R>(
140                        &client,
141                        cube_count,
142                        cube_dim,
143                        tensor.as_array_arg::<u32>(line_size_in),
144                        output.as_tensor_arg::<F>(line_size_out),
145                        scheme,
146                    )
147                };
148            }
149            QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
150                unsafe {
151                    dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
152                        &client,
153                        cube_count,
154                        cube_dim,
155                        tensor.as_array_arg::<u32>(line_size_in),
156                        output.as_tensor_arg::<F>(line_size_out),
157                        scheme,
158                    )
159                };
160            }
161        }
162    }
163
164    output
165}
166
/// Convert the tensor back to a higher precision data type.
///
/// Thin public entry point; forwards to [`dequantize_per_tensor`], which
/// dispatches on the tensor's per-tensor quantization scheme.
pub fn dequantize<R, F>(tensor: JitTensor<R>) -> JitTensor<R>
where
    R: JitRuntime,
    F: FloatElement,
{
    dequantize_per_tensor::<R, F>(tensor)
}