burn_cubecl/tensor/
quantization.rs

1use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor};
2use cubecl::quant::scheme::{QuantStore, QuantValue};
3use cubecl::{client::ComputeClient, server::Handle};
4
5use crate::CubeRuntime;
6
7use super::CubeTensor;
8
/// Runtime parameters for quantization.
///
/// Instantiated with [`QParamTensor`]: its `scales` entry describes where the quantization
/// scales live inside the base tensor's buffer (offset range, shape, strides and dtype).
/// This is enough to construct a scales handle from the base tensor handle without a
/// separate allocation (see [`CubeTensor::scales`]).
pub type QParams = burn_backend::quantization::QParams<QParamTensor>;
12
13impl<R: CubeRuntime> CubeTensor<R> {
14    /// Create a new quantized tensor
15    pub fn new_quantized(
16        client: ComputeClient<R>,
17        handle: Handle,
18        shape: Shape,
19        device: R::Device,
20        strides: Vec<usize>,
21        dtype: DType,
22        qparams: QParams,
23    ) -> Self {
24        CubeTensor {
25            client,
26            handle,
27            shape,
28            device,
29            strides,
30            dtype,
31            qparams: Some(qparams),
32        }
33    }
34
35    /// Returns the two tensors: (values, params) for a quantized tensor.
36    /// For the values, native types that aren't supported as a normal `DType` will be returned
37    /// as an unsigned integer tensor representing the bits. Should be reconstructed using `from_bits`
38    /// in kernels.
39    pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {
40        let params = self.scales()?;
41        let scheme = match self.dtype {
42            DType::QFloat(sc) => sc,
43            _ => return None,
44        };
45        let values = match scheme.store {
46            QuantStore::Native => match scheme.value {
47                QuantValue::Q8F | QuantValue::Q8S => CubeTensor {
48                    client: self.client.clone(),
49                    handle: self.handle.clone(),
50                    shape: self.shape.clone(),
51                    device: self.device.clone(),
52                    strides: self.strides.clone(),
53                    dtype: DType::I8,
54                    qparams: None,
55                },
56                QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor {
57                    client: self.client.clone(),
58                    handle: self.handle.clone(),
59                    shape: self.shape.clone(),
60                    device: self.device.clone(),
61                    strides: self.strides.clone(),
62                    dtype: DType::U8,
63                    qparams: None,
64                },
65                QuantValue::Q4F
66                | QuantValue::Q4S
67                | QuantValue::Q2F
68                | QuantValue::Q2S
69                | QuantValue::E2M1 => {
70                    panic!("Can't store native sub-byte values")
71                }
72            },
73            QuantStore::PackedU32(packed_dim) => {
74                let packed_dim = self.rank() - packed_dim - 1;
75                let mut shape = self.shape.clone();
76                shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
77
78                CubeTensor {
79                    client: self.client.clone(),
80                    handle: self.handle.clone(),
81                    shape,
82                    device: self.device.clone(),
83                    strides: self.strides.clone(),
84                    dtype: DType::U32,
85                    qparams: None,
86                }
87            }
88            QuantStore::PackedNative(packed_dim) => match scheme.value {
89                QuantValue::E2M1 => {
90                    let packed_dim = self.rank() - packed_dim - 1;
91                    let mut shape = self.shape.clone();
92                    shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
93
94                    CubeTensor {
95                        client: self.client.clone(),
96                        handle: self.handle.clone(),
97                        shape,
98                        device: self.device.clone(),
99                        strides: self.strides.clone(),
100                        dtype: DType::U8,
101                        qparams: None,
102                    }
103                }
104                other => panic!("{other:?} doesn't support native packing"),
105            },
106        };
107
108        Some((values, params))
109    }
110
111    /// Construct a separate tensor for the quantization scales, if present
112    pub fn scales(&self) -> Option<CubeTensor<R>> {
113        let qparams = self.qparams.as_ref()?;
114        let mut handle = self.handle.clone();
115        handle.offset_start = Some(qparams.scales.offset_start as u64);
116        handle.offset_end = Some(qparams.scales.offset_end as u64);
117
118        Some(CubeTensor::new(
119            self.client.clone(),
120            handle,
121            qparams.scales.shape.clone(),
122            self.device.clone(),
123            qparams.scales.strides.clone(),
124            qparams.scales.dtype,
125        ))
126    }
127}