burn_cubecl/kernel/quantization/dequantize.rs

use crate::tensor::CubeTensor;
use crate::{CubeElement, CubeRuntime};
use burn_tensor::DType;
use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType};
use cubecl::calculate_cube_count_elemwise;
use cubecl::prelude::*;

use super::{QParams, QTensor};

#[cube]
fn dequantize_symmetric_int8<F: Float>(value: Line<i32>, scale: f32) -> Line<F> {
    // x = scale * x_q
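    // e.g., with scale = 0.1, the quantized values [-128, 0, 127] map back to [-12.8, 0.0, 12.7]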
    Line::cast_from(scale) * Line::cast_from(value)
}

#[cube]
fn extract_i8(value: u32, offset: u32) -> i32 {
    // Extract 8-bit segment
    let value = (value >> offset) & 0xFF;
    // Check whether the value is negative by inspecting the MSB, and subtract 256 if it is.
    // Subtracting 0 or 256 circumvents unsupported conditional assignment (let x = if {} else {};)
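    // e.g., a raw byte of 0xFF (255) has the MSB set and becomes 255 - 256 = -1, while 0x7F stays 127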
    let sub = i32::cast_from(value & 0x80 != 0) * 256;
    i32::cast_from(value) - sub
}

#[cube]
fn unpack_i8s(value: u32) -> Line<i32> {
    let mut line = Line::empty(4_u32);
    // Extract each 8-bit segment
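    // (lowest byte first, e.g., 0x0201FF80 unpacks to [-128, -1, 1, 2])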
    line[0] = extract_i8(value, 0);
    line[1] = extract_i8(value, 8);
    line[2] = extract_i8(value, 16);
    line[3] = extract_i8(value, 24);

    line
}

#[cube(launch_unchecked)]
fn dequantize_per_tensor_symmetric_int8_kernel(
    input: &QTensor,
    output: &mut Tensor<Line<f32>>,
    #[comptime] scheme: QuantizationScheme,
) {
    // Last position contains the qparam
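    // (the scale is stored after the packed data, so that slot is skipped here and read via QParams below)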
    if ABSOLUTE_POS >= input.len() - 1 {
        terminate!();
    }

    let qparams = QParams::new(scheme);
    let (scale, _) = qparams.values(input);

    let value = input[ABSOLUTE_POS];

    // Input line size is fixed to 1
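    // Each packed u32 expands to 4 floats, so an output line size of 4 writes them in a single store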
    if comptime!(output.line_size() == 4) {
        output[ABSOLUTE_POS] = dequantize_symmetric_int8(unpack_i8s(value[0]), scale);
    } else {
        // For very small inputs where the number of elements is < 4, the output line size is 1
        let out = dequantize_symmetric_int8::<f32>(unpack_i8s(value[0]), scale);

        #[unroll]
        for j in 0..out.size() {
            output[ABSOLUTE_POS * out.size() + j] = Line::cast_from(out[j]);
        }
    }
}

/// Convert the tensor back to a higher precision data type.
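///
/// Four int8 values are packed into each u32 of the input, with the quantization scale stored at
/// the end of the buffer; the per-tensor symmetric kernel unpacks them and multiplies by the scale.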
pub fn dequantize<R, F>(tensor: CubeTensor<R>) -> CubeTensor<R>
where
    R: CubeRuntime,
    F: CubeElement,
{
    // The actual number of stored elements is 1/4 of the output count (four int8 values are packed
    // into a single u32), so we choose a line size that matches a valid input binding size.
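    // e.g., a 10-element output only needs ceil(10 / 4) = 3 packed u32 words of quantized data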
    let num_out_elems = tensor.shape.num_elements();
    let num_elems = usize::div_ceil(num_out_elems, 4);
    let line_size_in = 1;
    let line_size_out = 1;
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elems / line_size_in as usize, cube_dim);

    let client = tensor.client.clone();
    let handle = client.empty(num_out_elems * core::mem::size_of::<F>());

    let output = CubeTensor::new_contiguous(
        client.clone(),
        tensor.device.clone(),
        tensor.shape.clone(),
        handle,
        F::dtype(),
    );

    if let DType::QFloat(scheme) = tensor.dtype {
        match scheme {
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
                unsafe {
                    dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
                        &client,
                        cube_count,
                        cube_dim,
                        tensor.as_array_arg::<u32>(line_size_in),
                        output.as_tensor_arg::<F>(line_size_out),
                        scheme,
                    )
                };
            }
        }
    }

    output
}