Skip to main content

burn_cubecl/tensor/
quantization.rs

1use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor};
2use burn_std::{Metadata, Strides};
3use cubecl::quant::scheme::{QuantStore, QuantValue};
4use cubecl::{client::ComputeClient, server::Handle};
5
6use crate::CubeRuntime;
7
8use super::CubeTensor;
9
10/// Runtime parameters for quantization. Can be used to construct a scales handle from the base
11/// tensor handle.
12pub type QParams = burn_backend::quantization::QParams<QParamTensor>;
13
14impl<R: CubeRuntime> CubeTensor<R> {
15    /// Create a new quantized tensor
16    pub fn new_quantized(
17        client: ComputeClient<R>,
18        handle: Handle,
19        shape: Shape,
20        device: R::Device,
21        strides: Strides,
22        dtype: DType,
23        qparams: QParams,
24    ) -> Self {
25        CubeTensor {
26            client,
27            handle,
28            meta: Box::new(Metadata::new(shape, strides)),
29            device,
30            dtype,
31            qparams: Some(qparams),
32        }
33    }
34
35    /// Returns the two tensors: (values, params) for a quantized tensor.
36    /// For the values, native types that aren't supported as a normal `DType` will be returned
37    /// as an unsigned integer tensor representing the bits. Should be reconstructed using `from_bits`
38    /// in kernels.
39    pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {
40        let params = self.scales()?;
41        let scheme = match self.dtype {
42            DType::QFloat(sc) => sc,
43            _ => return None,
44        };
45        let values = match scheme.store {
46            QuantStore::Native => match scheme.value {
47                QuantValue::Q8F | QuantValue::Q8S => CubeTensor {
48                    client: self.client.clone(),
49                    handle: self.handle.clone(),
50                    meta: self.meta.clone(),
51                    device: self.device.clone(),
52                    dtype: DType::I8,
53                    qparams: None,
54                },
55                QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor {
56                    client: self.client.clone(),
57                    handle: self.handle.clone(),
58                    meta: self.meta.clone(),
59                    device: self.device.clone(),
60                    dtype: DType::U8,
61                    qparams: None,
62                },
63                QuantValue::Q4F
64                | QuantValue::Q4S
65                | QuantValue::Q2F
66                | QuantValue::Q2S
67                | QuantValue::E2M1 => {
68                    panic!("Can't store native sub-byte values")
69                }
70            },
71            QuantStore::PackedU32(packed_dim) => {
72                let packed_dim = self.rank() - packed_dim - 1;
73                let mut shape = self.shape();
74                shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
75
76                CubeTensor {
77                    client: self.client.clone(),
78                    handle: self.handle.clone(),
79                    meta: Box::new(Metadata::new(shape, self.meta.strides.clone())),
80                    device: self.device.clone(),
81                    dtype: DType::U32,
82                    qparams: None,
83                }
84            }
85            QuantStore::PackedNative(packed_dim) => match scheme.value {
86                QuantValue::E2M1 => {
87                    let packed_dim = self.rank() - packed_dim - 1;
88                    let mut shape = self.shape();
89                    shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
90
91                    CubeTensor {
92                        client: self.client.clone(),
93                        handle: self.handle.clone(),
94                        meta: Box::new(Metadata::new(shape, self.meta.strides.clone())),
95                        device: self.device.clone(),
96                        dtype: DType::U8,
97                        qparams: None,
98                    }
99                }
100                other => panic!("{other:?} doesn't support native packing"),
101            },
102        };
103
104        Some((values, params))
105    }
106
107    /// Construct a separate tensor for the quantization scales, if present
108    pub fn scales(&self) -> Option<CubeTensor<R>> {
109        let qparams = self.qparams.as_ref()?;
110        let mut handle = self.handle.clone();
111        handle.offset_start = Some(qparams.scales.offset_start as u64);
112        handle.offset_end = Some(qparams.scales.offset_end as u64);
113
114        Some(CubeTensor::new(
115            self.client.clone(),
116            handle,
117            qparams.scales.metadata.clone(),
118            self.device.clone(),
119            qparams.scales.dtype,
120        ))
121    }
122}