burn_cubecl/tensor/
quantization.rs1use burn_tensor::{DType, Shape, quantization::QParamTensor};
2use cubecl::{client::ComputeClient, server::Handle};
3use cubecl_quant::scheme::{QuantStore, QuantValue};
4
5use crate::CubeRuntime;
6
7use super::CubeTensor;
8
9pub type QParams = burn_tensor::quantization::QParams<QParamTensor>;
12
13impl<R: CubeRuntime> CubeTensor<R> {
14 pub fn new_quantized(
16 client: ComputeClient<R::Server>,
17 handle: Handle,
18 shape: Shape,
19 device: R::Device,
20 strides: Vec<usize>,
21 dtype: DType,
22 qparams: QParams,
23 ) -> Self {
24 CubeTensor {
25 client,
26 handle,
27 shape,
28 device,
29 strides,
30 dtype,
31 qparams: Some(qparams),
32 }
33 }
34
35 pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {
37 let params = self.scales()?;
38 let scheme = match self.dtype {
39 DType::QFloat(sc) => sc,
40 _ => return None,
41 };
42 let values = match scheme.store {
43 QuantStore::Native => match scheme.value {
44 QuantValue::Q8F | QuantValue::Q8S => CubeTensor {
45 client: self.client.clone(),
46 handle: self.handle.clone(),
47 shape: self.shape.clone(),
48 device: self.device.clone(),
49 strides: self.strides.clone(),
50 dtype: DType::I8,
51 qparams: None,
52 },
53 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
54 unimplemented!("Not yet supported")
55 }
56 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
57 panic!("Can't store native sub-byte values")
58 }
59 },
60 QuantStore::U32 => {
61 let major_dim = self.major_dim();
62 let mut shape = self.shape.clone();
63 shape[major_dim] = shape[major_dim].div_ceil(scheme.num_quants());
64
65 CubeTensor {
66 client: self.client.clone(),
67 handle: self.handle.clone(),
68 shape,
69 device: self.device.clone(),
70 strides: self.strides.clone(),
71 dtype: DType::U32,
72 qparams: None,
73 }
74 }
75 };
76
77 Some((values, params))
78 }
79
80 fn major_dim(&self) -> usize {
81 let rank = self.shape.num_dims();
82 self.strides
83 .iter()
84 .enumerate()
85 .rev()
86 .find(|(_, s)| **s == 1)
87 .map(|(i, _)| i)
88 .unwrap_or(rank - 1)
89 }
90
91 pub fn scales(&self) -> Option<CubeTensor<R>> {
93 let qparams = self.qparams.as_ref()?;
94 let mut handle = self.handle.clone();
95 handle.offset_start = Some(qparams.scales.offset_start as u64);
96 handle.offset_end = Some(qparams.scales.offset_end as u64);
97
98 Some(CubeTensor::new(
99 self.client.clone(),
100 handle,
101 qparams.scales.shape.clone(),
102 self.device.clone(),
103 qparams.scales.strides.clone(),
104 qparams.scales.dtype,
105 ))
106 }
107}