burn_cubecl/tensor/quantization.rs

1use burn_tensor::{DType, Shape, quantization::QParamTensor};
2use cubecl::{client::ComputeClient, server::Handle};
3use cubecl_quant::scheme::{QuantStore, QuantValue};
4
5use crate::CubeRuntime;
6
7use super::CubeTensor;
8
9/// Runtime parameters for quantization. Can be used to construct a scales handle from the base
10/// tensor handle.
11pub type QParams = burn_tensor::quantization::QParams<QParamTensor>;
12
13impl<R: CubeRuntime> CubeTensor<R> {
14    /// Create a new quantized tensor
15    pub fn new_quantized(
16        client: ComputeClient<R::Server>,
17        handle: Handle,
18        shape: Shape,
19        device: R::Device,
20        strides: Vec<usize>,
21        dtype: DType,
22        qparams: QParams,
23    ) -> Self {
24        CubeTensor {
25            client,
26            handle,
27            shape,
28            device,
29            strides,
30            dtype,
31            qparams: Some(qparams),
32        }
33    }
34
35    /// Returns the two tensors: (values, params) for a quantized tensor.
36    pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {
37        let params = self.scales()?;
38        let scheme = match self.dtype {
39            DType::QFloat(sc) => sc,
40            _ => return None,
41        };
42        let values = match scheme.store {
43            QuantStore::Native => match scheme.value {
44                QuantValue::Q8F | QuantValue::Q8S => CubeTensor {
45                    client: self.client.clone(),
46                    handle: self.handle.clone(),
47                    shape: self.shape.clone(),
48                    device: self.device.clone(),
49                    strides: self.strides.clone(),
50                    dtype: DType::I8,
51                    qparams: None,
52                },
53                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
54                    unimplemented!("Not yet supported")
55                }
56                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
57                    panic!("Can't store native sub-byte values")
58                }
59            },
60            QuantStore::U32 => {
61                let major_dim = self.major_dim();
62                let mut shape = self.shape.clone();
63                shape[major_dim] = shape[major_dim].div_ceil(scheme.num_quants());
64
65                CubeTensor {
66                    client: self.client.clone(),
67                    handle: self.handle.clone(),
68                    shape,
69                    device: self.device.clone(),
70                    strides: self.strides.clone(),
71                    dtype: DType::U32,
72                    qparams: None,
73                }
74            }
75        };
76
77        Some((values, params))
78    }
79
80    fn major_dim(&self) -> usize {
81        let rank = self.shape.num_dims();
82        self.strides
83            .iter()
84            .enumerate()
85            .rev()
86            .find(|(_, s)| **s == 1)
87            .map(|(i, _)| i)
88            .unwrap_or(rank - 1)
89    }
90
91    /// Construct a separate tensor for the quantization scales, if present
92    pub fn scales(&self) -> Option<CubeTensor<R>> {
93        let qparams = self.qparams.as_ref()?;
94        let mut handle = self.handle.clone();
95        handle.offset_start = Some(qparams.scales.offset_start as u64);
96        handle.offset_end = Some(qparams.scales.offset_end as u64);
97
98        Some(CubeTensor::new(
99            self.client.clone(),
100            handle,
101            qparams.scales.shape.clone(),
102            self.device.clone(),
103            qparams.scales.strides.clone(),
104            qparams.scales.dtype,
105        ))
106    }
107}