burn_cubecl/tensor/
quantization.rs1use burn_backend::{DType, Shape, TensorMetadata as _, quantization::QParamTensor};
2use burn_std::{Metadata, Strides};
3use cubecl::quant::scheme::{QuantStore, QuantValue};
4use cubecl::{client::ComputeClient, server::Handle};
5
6use crate::CubeRuntime;
7
8use super::CubeTensor;
9
/// Quantization parameters carried by a quantized [`CubeTensor`].
///
/// The scales are described by a [`QParamTensor`], which records where they live inside the
/// tensor's buffer (`offset_start` / `offset_end`) along with their metadata and dtype.
pub type QParams = burn_backend::quantization::QParams<QParamTensor>;
13
14impl<R: CubeRuntime> CubeTensor<R> {
15 pub fn new_quantized(
17 client: ComputeClient<R>,
18 handle: Handle,
19 shape: Shape,
20 device: R::Device,
21 strides: Strides,
22 dtype: DType,
23 qparams: QParams,
24 ) -> Self {
25 CubeTensor {
26 client,
27 handle,
28 meta: Box::new(Metadata::new(shape, strides)),
29 device,
30 dtype,
31 qparams: Some(qparams),
32 }
33 }
34
35 pub fn quantized_handles(&self) -> Option<(CubeTensor<R>, CubeTensor<R>)> {
40 let params = self.scales()?;
41 let scheme = match self.dtype {
42 DType::QFloat(sc) => sc,
43 _ => return None,
44 };
45 let values = match scheme.store {
46 QuantStore::Native => match scheme.value {
47 QuantValue::Q8F | QuantValue::Q8S => CubeTensor {
48 client: self.client.clone(),
49 handle: self.handle.clone(),
50 meta: self.meta.clone(),
51 device: self.device.clone(),
52 dtype: DType::I8,
53 qparams: None,
54 },
55 QuantValue::E4M3 | QuantValue::E5M2 => CubeTensor {
56 client: self.client.clone(),
57 handle: self.handle.clone(),
58 meta: self.meta.clone(),
59 device: self.device.clone(),
60 dtype: DType::U8,
61 qparams: None,
62 },
63 QuantValue::Q4F
64 | QuantValue::Q4S
65 | QuantValue::Q2F
66 | QuantValue::Q2S
67 | QuantValue::E2M1 => {
68 panic!("Can't store native sub-byte values")
69 }
70 },
71 QuantStore::PackedU32(packed_dim) => {
72 let packed_dim = self.rank() - packed_dim - 1;
73 let mut shape = self.shape();
74 shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
75
76 CubeTensor {
77 client: self.client.clone(),
78 handle: self.handle.clone(),
79 meta: Box::new(Metadata::new(shape, self.meta.strides.clone())),
80 device: self.device.clone(),
81 dtype: DType::U32,
82 qparams: None,
83 }
84 }
85 QuantStore::PackedNative(packed_dim) => match scheme.value {
86 QuantValue::E2M1 => {
87 let packed_dim = self.rank() - packed_dim - 1;
88 let mut shape = self.shape();
89 shape[packed_dim] = shape[packed_dim].div_ceil(scheme.num_quants());
90
91 CubeTensor {
92 client: self.client.clone(),
93 handle: self.handle.clone(),
94 meta: Box::new(Metadata::new(shape, self.meta.strides.clone())),
95 device: self.device.clone(),
96 dtype: DType::U8,
97 qparams: None,
98 }
99 }
100 other => panic!("{other:?} doesn't support native packing"),
101 },
102 };
103
104 Some((values, params))
105 }
106
107 pub fn scales(&self) -> Option<CubeTensor<R>> {
109 let qparams = self.qparams.as_ref()?;
110 let mut handle = self.handle.clone();
111 handle.offset_start = Some(qparams.scales.offset_start as u64);
112 handle.offset_end = Some(qparams.scales.offset_end as u64);
113
114 Some(CubeTensor::new(
115 self.client.clone(),
116 handle,
117 qparams.scales.metadata.clone(),
118 self.device.clone(),
119 qparams.scales.dtype,
120 ))
121 }
122}