burn_cubecl/kernel/quantization/
dequantize.rs1use crate::tensor::CubeTensor;
2use crate::{CubeElement, CubeRuntime};
3use burn_tensor::DType;
4use burn_tensor::quantization::{QuantizationMode, QuantizationScheme, QuantizationType};
5use cubecl::calculate_cube_count_elemwise;
6use cubecl::prelude::*;
7
8use super::{QParams, QTensor};
9
10#[cube]
11fn dequantize_symmetric_int8<F: Float>(value: Line<i32>, scale: f32) -> Line<F> {
12 Line::cast_from(scale) * Line::cast_from(value)
14}
15
16#[cube]
17fn extract_i8(value: u32, offset: u32) -> i32 {
18 let value = (value >> offset) & 0xFF;
20 let sub = i32::cast_from(value & 0x80 != 0) * 256;
23 i32::cast_from(value) - sub
24}
25
26#[cube]
27fn unpack_i8s(value: u32) -> Line<i32> {
28 let mut line = Line::empty(4_u32);
29 line[0] = extract_i8(value, 0);
31 line[1] = extract_i8(value, 8);
32 line[2] = extract_i8(value, 16);
33 line[3] = extract_i8(value, 24);
34
35 line
36}
37
38#[cube(launch_unchecked)]
39fn dequantize_per_tensor_symmetric_int8_kernel(
40 input: &QTensor,
41 output: &mut Tensor<Line<f32>>,
42 #[comptime] scheme: QuantizationScheme,
43) {
44 if ABSOLUTE_POS >= input.len() - 1 {
46 terminate!();
47 }
48
49 let qparams = QParams::new(scheme);
50 let (scale, _) = qparams.values(input);
51
52 let value = input[ABSOLUTE_POS];
53
54 if comptime!(output.line_size() == 4) {
56 output[ABSOLUTE_POS] = dequantize_symmetric_int8(unpack_i8s(value[0]), scale);
57 } else {
58 let out = dequantize_symmetric_int8::<f32>(unpack_i8s(value[0]), scale);
60
61 #[unroll]
62 for j in 0..out.size() {
63 output[ABSOLUTE_POS * out.size() + j] = Line::cast_from(out[j]);
64 }
65 }
66}
67
68pub fn dequantize<R, F>(tensor: CubeTensor<R>) -> CubeTensor<R>
70where
71 R: CubeRuntime,
72 F: CubeElement,
73{
74 let num_out_elems = tensor.shape.num_elements();
77 let num_elems = usize::div_ceil(num_out_elems, 4);
78 let line_size_in = 1;
79 let line_size_out = 1;
80 let cube_dim = CubeDim::default();
81 let cube_count = calculate_cube_count_elemwise(num_elems / line_size_in as usize, cube_dim);
82
83 let client = tensor.client.clone();
84 let handle = client.empty(num_out_elems * core::mem::size_of::<F>());
85
86 let output = CubeTensor::new_contiguous(
87 client.clone(),
88 tensor.device.clone(),
89 tensor.shape.clone(),
90 handle,
91 F::dtype(),
92 );
93
94 if let DType::QFloat(scheme) = tensor.dtype {
95 match scheme {
96 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8) => {
97 unsafe {
98 dequantize_per_tensor_symmetric_int8_kernel::launch_unchecked::<R>(
99 &client,
100 cube_count,
101 cube_dim,
102 tensor.as_array_arg::<u32>(line_size_in),
103 output.as_tensor_arg::<F>(line_size_out),
104 scheme,
105 )
106 };
107 }
108 }
109 }
110
111 output
112}