1use burn_backend::{
2 Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,
3 ops::QTensorOps,
4 quantization::{
5 QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue,
6 QuantizationParametersPrimitive, params_shape,
7 },
8 tensor::{Device, FloatElem, FloatTensor, IntTensor, QuantizedTensor},
9};
10use cubecl::server::{Allocation, AllocationDescriptor, AllocationKind};
11use cubecl::{e2m1x2, quant::scheme::QuantStore};
12
13use crate::{
14 CubeBackend, CubeRuntime, FloatElement, IntElement,
15 element::BoolElement,
16 kernel::{self, matmul::MatmulStrategy},
17 tensor::{CubeTensor, QParams},
18};
19
20use super::{into_data, permute, swap_dims};
21
22fn new_qtensor_optimized<R: CubeRuntime>(
24 data: Bytes,
25 shape: impl Into<Shape>,
26 scheme: QuantScheme,
27 device: &R::Device,
28) -> CubeTensor<R> {
29 new_qtensor(data, shape, scheme, device, AllocationKind::Optimized)
30}
31
32fn new_qtensor<R: CubeRuntime>(
34 data: Bytes,
35 shape: impl Into<Shape>,
36 scheme: QuantScheme,
37 device: &R::Device,
38 kind: AllocationKind,
39) -> CubeTensor<R> {
40 new_quantized(shape, scheme, device, Some(data), kind)
41}
42
43pub fn empty_qtensor_optimized<R: CubeRuntime>(
45 shape: impl Into<Shape>,
46 scheme: QuantScheme,
47 device: &R::Device,
48) -> CubeTensor<R> {
49 empty_qtensor(shape, scheme, device, AllocationKind::Optimized)
50}
51
52pub fn empty_qtensor<R: CubeRuntime>(
54 shape: impl Into<Shape>,
55 scheme: QuantScheme,
56 device: &R::Device,
57 kind: AllocationKind,
58) -> CubeTensor<R> {
59 new_quantized(shape, scheme, device, None, kind)
60}
61
62fn new_quantized<R: CubeRuntime>(
63 shape: impl Into<Shape>,
64 scheme: QuantScheme,
65 device: &R::Device,
66 data: Option<Bytes>,
67 alloc_kind: AllocationKind,
68) -> CubeTensor<R> {
69 let client = R::client(device);
70 let shape: Shape = shape.into();
71 let mut shape_value: Shape = shape.clone();
72
73 let rank = shape.rank();
74 let shape_last = shape[rank - 1];
75 let num_quants = scheme.num_quants();
76
77 let data_size = match scheme.store {
78 QuantStore::PackedU32(_) => {
79 if !shape_last.is_multiple_of(num_quants) {
80 panic!("Can't store in u32")
81 }
82 shape_value.dims[rank - 1] = shape_last.div_ceil(num_quants);
83 size_of::<u32>()
84 }
85 QuantStore::Native => match scheme.value {
86 QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {
87 size_of::<i8>()
88 }
89 QuantValue::Q4F
90 | QuantValue::Q4S
91 | QuantValue::Q2F
92 | QuantValue::Q2S
93 | QuantValue::E2M1 => {
94 panic!("Can't store native sub-byte values")
95 }
96 },
97 QuantStore::PackedNative(_) => match scheme.value {
98 QuantValue::E2M1 => size_of::<e2m1x2>(),
99 other => panic!("{other:?} doesn't support native packing"),
100 },
101 };
102
103 let scales_dtype = match scheme.param {
104 QuantParam::F32 => DType::F32,
105 QuantParam::F16 => DType::F16,
106 QuantParam::BF16 => DType::BF16,
107 QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,
109 };
110
111 let scales_shape = params_shape(&shape, scheme.level);
112 let data_desc = AllocationDescriptor::new(alloc_kind, &shape_value.dims, data_size);
113 let scales_desc =
114 AllocationDescriptor::new(alloc_kind, &scales_shape.dims, scales_dtype.size());
115
116 let mut tensors = match data {
117 Some(data) => {
118 let num_bytes = shape_value.num_elements() * data_size;
119
120 match data.split(num_bytes) {
121 Ok((bytes_data, bytes_scales)) => client
122 .create_tensors(vec![(data_desc, bytes_data), (scales_desc, bytes_scales)]),
123 Err((data, _)) => client.create_tensors_from_slices(vec![
124 (data_desc, &data[..num_bytes]),
125 (scales_desc, &data[num_bytes..]),
126 ]),
127 }
128 }
129 None => client.empty_tensors(vec![data_desc, scales_desc]),
130 };
131 let Allocation {
132 handle: scales_handle,
133 strides: scales_strides,
134 } = tensors.remove(1);
135 let Allocation { handle, strides } = tensors.remove(0);
136
137 let scales = QParamTensor {
138 offset_start: scales_handle.offset_start.unwrap_or(0) as usize,
139 offset_end: scales_handle.offset_end.unwrap_or(0) as usize,
140 shape: scales_shape,
141 strides: scales_strides,
142 dtype: scales_dtype,
143 };
144 let qparams = QParams { scales };
145
146 CubeTensor::new_quantized(
147 client,
148 handle,
149 shape,
150 device.clone(),
151 strides,
152 DType::QFloat(scheme),
153 qparams,
154 )
155}
156
/// Quantized tensor operations for the CubeCL backend.
impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
        match data.dtype {
            // Single exhaustive arm: every supported (level, mode, value)
            // combination goes through the same allocation path.
            // `data.bytes` is expected to hold the packed values followed by
            // the scale params — the layout `q_into_data` below produces.
            DType::QFloat(scheme) => match scheme {
                QuantScheme {
                    level: QuantLevel::Tensor | QuantLevel::Block(_),
                    mode: QuantMode::Symmetric,
                    value:
                        QuantValue::Q8F
                        | QuantValue::Q8S
                        | QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S
                        | QuantValue::E4M3
                        | QuantValue::E5M2
                        | QuantValue::E2M1,
                    ..
                } => {
                    new_qtensor_optimized(data.bytes, data.shape.clone(), scheme, device)
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        // Delegates to the quantization kernel; only the scales are used as
        // parameters (symmetric quantization — no zero-point tensor).
        kernel::quantization::quantize(tensor, scheme, qparams.scales)
    }

    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        // Output element type is the backend's float element.
        kernel::quantization::dequantize(tensor, FloatElem::<Self>::dtype())
    }

    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
        super::to_device(tensor, device)
    }

    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        super::q_reshape(tensor, shape)
    }

    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        // No qparams: treat as a plain tensor read-back.
        if tensor.qparams.is_none() {
            return into_data(tensor).await;
        }

        let (shape, dtype) = (tensor.shape.dims.clone(), tensor.dtype);
        // Split into the value tensor and the scale-param tensor so each can
        // be read back independently.
        let (values, params) = tensor.quantized_handles().unwrap();

        let mut data_values = into_data(values).await?;
        let data_params = into_data(params).await?;

        // Serialized layout: packed values first, then scale params appended
        // (the inverse of the split performed when creating the tensor).
        data_values.bytes.extend_from_byte_slice(&data_params.bytes);

        Ok(TensorData {
            bytes: data_values.bytes,
            shape,
            dtype,
        })
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        swap_dims(tensor, dim1, dim2)
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        permute(tensor, axes)
    }

    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
        // Take the propagation policy and scheme from whichever operand is
        // quantized (lhs wins if both are).
        let (propagation, scheme) = match (&lhs, &rhs) {
            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),
            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),
            _ => unreachable!(),
        };

        // Output dtype: prefer a float operand's dtype; if both operands are
        // quantized, fall back to the backend float element.
        let out_dtype = match (&lhs, &rhs) {
            (TensorPrimitive::Float(lhs), _) => lhs.dtype,
            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,
            _ => F::dtype(),
        };

        // Unwrap both operands to raw CubeTensors for the matmul kernel.
        let (_lhs_dtype, lhs) = match lhs {
            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),
            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),
        };
        let (_rhs_dtype, rhs) = match rhs {
            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),
            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),
        };

        let out =
            kernel::matmul::matmul(lhs, rhs, None, MatmulStrategy::default(), out_dtype).unwrap();

        // Propagate: re-quantize the float result with the operand's scheme;
        // Inhibit: return the float result as-is.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
        }
    }
}