burn_cubecl/tensor/
quantization.rs1use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor};
2use cubecl::quant::scheme::{QuantStore, QuantValue};
3use cubecl::{client::ComputeClient, server::Handle};
4
5use crate::CubeRuntime;
6
7use super::CubeTensor;
8
9pub type QParams = burn_backend::quantization::QParams<QParamTensor>;
12
13impl<R: CubeRuntime> CubeTensor<R> {
14 pub fn new_quantized(
16 client: ComputeClient<R>,
17 handle: Handle,
18 shape: Shape,
19 device: R::Device,
20 strides: Vec<usize>,
21 dtype: DType,
22 qparams: QParams,
23 ) -> Self {
24 CubeTensor {
25 client,
26 handle,
27 shape,
28 device,
29 strides,
30 dtype,
31 qparams: Some(qparams),
32 }
33 }
34
35 pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {
40 let params = self.scales()?;
41 let scheme = match self.dtype {
42 DType::QFloat(sc) => sc,
43 _ => return None,
44 };
45 let values = match scheme.store {
46 QuantStore::Native => match scheme.value {
47 QuantValue::Q8F | QuantValue::Q8S => CubeTensor {
48 client: self.client.clone(),
49 handle: self.handle.clone(),
50 shape: self.shape.clone(),
51 device: self.device.clone(),
52 strides: self.strides.clone(),
53 dtype: DType::I8,
54 qparams: None,
55 },
56 QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor {
57 client: self.client.clone(),
58 handle: self.handle.clone(),
59 shape: self.shape.clone(),
60 device: self.device.clone(),
61 strides: self.strides.clone(),
62 dtype: DType::U8,
63 qparams: None,
64 },
65 QuantValue::Q4F
66 | QuantValue::Q4S
67 | QuantValue::Q2F
68 | QuantValue::Q2S
69 | QuantValue::E2M1 => {
70 panic!("Can't store native sub-byte values")
71 }
72 },
73 QuantStore::PackedU32(packed_dim) => {
74 let packed_dim = self.rank() - packed_dim - 1;
75 let mut shape = self.shape.clone();
76 shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
77
78 CubeTensor {
79 client: self.client.clone(),
80 handle: self.handle.clone(),
81 shape,
82 device: self.device.clone(),
83 strides: self.strides.clone(),
84 dtype: DType::U32,
85 qparams: None,
86 }
87 }
88 QuantStore::PackedNative(packed_dim) => match scheme.value {
89 QuantValue::E2M1 => {
90 let packed_dim = self.rank() - packed_dim - 1;
91 let mut shape = self.shape.clone();
92 shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
93
94 CubeTensor {
95 client: self.client.clone(),
96 handle: self.handle.clone(),
97 shape,
98 device: self.device.clone(),
99 strides: self.strides.clone(),
100 dtype: DType::U8,
101 qparams: None,
102 }
103 }
104 other => panic!("{other:?} doesn't support native packing"),
105 },
106 };
107
108 Some((values, params))
109 }
110
111 pub fn scales(&self) -> Option<CubeTensor<R>> {
113 let qparams = self.qparams.as_ref()?;
114 let mut handle = self.handle.clone();
115 handle.offset_start = Some(qparams.scales.offset_start as u64);
116 handle.offset_end = Some(qparams.scales.offset_end as u64);
117
118 Some(CubeTensor::new(
119 self.client.clone(),
120 handle,
121 qparams.scales.shape.clone(),
122 self.device.clone(),
123 qparams.scales.strides.clone(),
124 qparams.scales.dtype,
125 ))
126 }
127}