burn_cubecl/kernel/quantization/
qtensor.rs

1#![allow(missing_docs)] // cube derive macros
2
3use burn_tensor::quantization::{QuantizationMode, QuantizationScheme};
4use cubecl::prelude::*;
5
6/// Quantization parameters.
7#[derive(CubeLaunch, CubeType)]
8pub struct QParams {
9    #[cube(comptime)]
10    scheme: QuantizationScheme,
11}
12
13/// Quantized tensor representation.
14pub type QTensor = Array<Line<u32>>;
15
16#[cube]
17impl QParams {
18    /// Create a new quantization parameters instance.
19    pub fn new(scheme: QuantizationScheme) -> Self {
20        QParams { scheme }
21    }
22
23    /// Get the quantization parameters values.
24    pub fn values(&self, tensor: &QTensor) -> (f32, i32) {
25        let len = tensor.len();
26        match comptime!(self.scheme) {
27            // Symmetric quantization only contains the scaling factor as the last element
28            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, _) => {
29                (f32::reinterpret(tensor[len - 1][tensor.line_size() - 1]), 0)
30            }
31        }
32    }
33}