// burn_jit/kernel/quantization/qtensor.rs

1#![allow(missing_docs)] // cube derive macros
2
3use burn_tensor::quantization::QuantizationScheme;
4use cubecl::prelude::*;
5
/// Quantization parameters for a quantized tensor.
///
/// The scheme is a compile-time (`comptime`) value: it selects which code
/// path the `#[cube]` macro generates rather than being read at runtime.
#[derive(CubeLaunch)]
pub struct QParams {
    #[cube(comptime)]
    scheme: QuantizationScheme,
}
12
/// Quantized tensor representation.
///
/// Packed quantized values are stored in a lined `u32` buffer; the
/// quantization parameters (scale, and zero-point for affine schemes) are
/// stored at the end of the same buffer and read back by [`QParams::values`].
pub type QTensor = Array<Line<u32>>;
15
#[cube]
impl QParams {
    /// Create a new quantization parameters instance.
    pub fn new(scheme: QuantizationScheme) -> Self {
        QParams { scheme }
    }

    /// Get the quantization parameters values as `(scale, zero_point)`.
    ///
    /// Both parameters are read from the tail of the quantized buffer: the
    /// scale is the `f32` bit pattern stored in the last `u32` element
    /// (hence `bitcast_from`), and for affine schemes the zero-point is the
    /// element just before it. Symmetric schemes have an implicit zero-point
    /// of 0.
    pub fn values(&self, tensor: &QTensor) -> (f32, i32) {
        let len = tensor.len();
        // `scheme` is comptime, so this match picks the generated code path
        // at kernel compile time — there is no runtime branch on it.
        match comptime!(self.scheme) {
            QuantizationScheme::PerTensorAffine(_) => match comptime!(tensor.line_size()) {
                // For line size of 1, scale is the last value in the buffer
                // and the zero-point is the second-to-last line (index 0 in
                // each, since line_size - 1 == 0 here).
                1 => (
                    f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]),
                    i32::cast_from(tensor[len - 2][tensor.line_size() - 1]),
                ),
                // For any other line size > 1, scale and zero-point offset
                // are the last two elements of the final line.
                _ => {
                    let values = tensor[len - 1];
                    (
                        f32::bitcast_from(values[tensor.line_size() - 1]),
                        i32::cast_from(values[tensor.line_size() - 2]),
                    )
                }
            },
            // Symmetric quantization only contains the scaling factor as the last element
            QuantizationScheme::PerTensorSymmetric(_) => (
                f32::bitcast_from(tensor[len - 1][tensor.line_size() - 1]),
                0,
            ),
        }
    }
}
49}