1use burn_backend::{
2 Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorMetadata,
3 TensorPrimitive,
4 ops::QTensorOps,
5 quantization::{
6 QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue,
7 QuantizationParametersPrimitive, params_shape,
8 },
9 tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
10};
11use burn_std::{FloatDType, Metadata};
12use cubecl::server::{MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutStrategy};
13use cubecl::{e2m1x2, quant::scheme::QuantStore};
14
15use crate::{
16 CubeBackend, CubeRuntime, FloatElement, IntElement,
17 element::BoolElement,
18 kernel::{self, matmul::MatmulStrategy},
19 tensor::{CubeTensor, QParams},
20};
21
22use super::{into_data, permute, swap_dims};
23
24fn new_qtensor_optimized<R: CubeRuntime>(
26 data: Bytes,
27 shape: impl Into<Shape>,
28 scheme: QuantScheme,
29 device: &R::Device,
30) -> CubeTensor<R> {
31 new_qtensor(data, shape, scheme, device, MemoryLayoutStrategy::Optimized)
32}
33
34fn new_qtensor<R: CubeRuntime>(
36 data: Bytes,
37 shape: impl Into<Shape>,
38 scheme: QuantScheme,
39 device: &R::Device,
40 kind: MemoryLayoutStrategy,
41) -> CubeTensor<R> {
42 new_quantized(shape, scheme, device, Some(data), kind)
43}
44
45pub fn empty_qtensor_optimized<R: CubeRuntime>(
47 shape: impl Into<Shape>,
48 scheme: QuantScheme,
49 device: &R::Device,
50) -> CubeTensor<R> {
51 empty_qtensor(shape, scheme, device, MemoryLayoutStrategy::Optimized)
52}
53
54pub fn empty_qtensor<R: CubeRuntime>(
56 shape: impl Into<Shape>,
57 scheme: QuantScheme,
58 device: &R::Device,
59 kind: MemoryLayoutStrategy,
60) -> CubeTensor<R> {
61 new_quantized(shape, scheme, device, None, kind)
62}
63
64fn new_quantized<R: CubeRuntime>(
65 shape: impl Into<Shape>,
66 scheme: QuantScheme,
67 device: &R::Device,
68 data: Option<Bytes>,
69 alloc_kind: MemoryLayoutStrategy,
70) -> CubeTensor<R> {
71 let client = R::client(device);
72 let shape: Shape = shape.into();
73 let mut shape_value: Shape = shape.clone();
74
75 let rank = shape.rank();
76 let shape_last = shape[rank - 1];
77 let num_quants = scheme.num_quants();
78
79 let data_size = match scheme.store {
80 QuantStore::PackedU32(_) => {
81 if !shape_last.is_multiple_of(num_quants) {
82 panic!("Can't store in u32")
83 }
84 shape_value[rank - 1] = shape_last.div_ceil(num_quants);
85 size_of::<u32>()
86 }
87 QuantStore::Native => match scheme.value {
88 QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {
89 size_of::<i8>()
90 }
91 QuantValue::Q4F
92 | QuantValue::Q4S
93 | QuantValue::Q2F
94 | QuantValue::Q2S
95 | QuantValue::E2M1 => {
96 panic!("Can't store native sub-byte values")
97 }
98 },
99 QuantStore::PackedNative(_) => match scheme.value {
100 QuantValue::E2M1 => size_of::<e2m1x2>(),
101 other => panic!("{other:?} doesn't support native packing"),
102 },
103 };
104
105 let scales_dtype = match scheme.param {
106 QuantParam::F32 => DType::F32,
107 QuantParam::F16 => DType::F16,
108 QuantParam::BF16 => DType::BF16,
109 QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,
111 };
112
113 let scales_shape = params_shape(&shape, scheme.level);
114 let data_desc = MemoryLayoutDescriptor::new(alloc_kind, shape_value.clone(), data_size);
115 let scales_desc =
116 MemoryLayoutDescriptor::new(alloc_kind, scales_shape.clone(), scales_dtype.size());
117
118 let mut tensors = match data {
119 Some(data) => {
120 let num_bytes = shape_value.num_elements() * data_size;
121
122 match data.split(num_bytes) {
123 Ok((bytes_data, bytes_scales)) => client
124 .create_tensors(vec![(data_desc, bytes_data), (scales_desc, bytes_scales)]),
125 Err((data, _)) => client.create_tensors_from_slices(vec![
126 (data_desc, &data[..num_bytes]),
127 (scales_desc, &data[num_bytes..]),
128 ]),
129 }
130 }
131 None => client.empty_tensors(vec![data_desc, scales_desc]),
132 };
133 let MemoryLayout {
134 memory: scales_handle,
135 strides: scales_strides,
136 } = tensors.remove(1);
137 let MemoryLayout { memory, strides } = tensors.remove(0);
138
139 let scales = QParamTensor {
140 offset_start: scales_handle.offset_start.unwrap_or(0) as usize,
141 offset_end: scales_handle.offset_end.unwrap_or(0) as usize,
142 metadata: Metadata::new(scales_shape, scales_strides),
143 dtype: scales_dtype,
144 };
145 let qparams = QParams { scales };
146
147 CubeTensor::new_quantized(
148 client,
149 memory,
150 shape,
151 device.clone(),
152 strides,
153 DType::QFloat(scheme),
154 qparams,
155 )
156}
157
/// Quantized tensor operations for the CubeCL backend.
impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => match scheme {
                // All currently supported symmetric schemes go through the same
                // construction path; the pattern is exhaustive on purpose so a
                // new `QuantValue` variant forces a compile error here.
                QuantScheme {
                    level: QuantLevel::Tensor | QuantLevel::Block(_),
                    mode: QuantMode::Symmetric,
                    value:
                        QuantValue::Q8F
                        | QuantValue::Q8S
                        | QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S
                        | QuantValue::E4M3
                        | QuantValue::E5M2
                        | QuantValue::E2M1,
                    ..
                } => {
                    // `data.bytes` holds the value bytes followed by the scale
                    // bytes (the inverse of `q_into_data` below).
                    new_qtensor_optimized(data.bytes, data.shape.clone(), scheme, device)
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    // Quantize a float tensor on-device using the provided scales.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        kernel::quantization::quantize(tensor, scheme, qparams.scales)
    }

    // Dequantize back to the requested float dtype on-device.
    fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
        kernel::quantization::dequantize(tensor, dtype.into())
    }

    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
        super::to_device(tensor, device)
    }

    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        super::q_reshape(tensor, shape)
    }

    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
        // Without qparams there is nothing to recombine: read back directly.
        if tensor.qparams.is_none() {
            return into_data(tensor).await;
        }

        // Capture logical shape/dtype before splitting the tensor into its
        // value and parameter handles.
        let (shape, dtype) = (tensor.shape(), tensor.dtype);
        let (values, params) = tensor.quantized_handles().unwrap();

        let mut data_values = into_data(values).await?;
        let data_params = into_data(params).await?;

        // Serialize as value bytes followed by scale bytes — the layout
        // `q_from_data` expects when reconstructing.
        data_values.bytes.extend_from_byte_slice(&data_params.bytes);

        Ok(TensorData {
            bytes: data_values.bytes,
            shape,
            dtype,
        })
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        swap_dims(tensor, dim1, dim2)
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        permute(tensor, axes)
    }

    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
        // Propagation/scheme come from whichever operand is quantized
        // (lhs preferred when both are).
        let (propagation, scheme) = match (&lhs, &rhs) {
            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),
            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),
            _ => unreachable!(),
        };

        // Output dtype follows a float operand when present; otherwise the
        // backend's default float element type.
        let out_dtype = match (&lhs, &rhs) {
            (TensorPrimitive::Float(lhs), _) => lhs.dtype,
            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,
            _ => F::dtype(),
        };

        // Unwrap both operands to raw CubeTensors for the matmul kernel.
        let (_lhs_dtype, lhs) = match lhs {
            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),
            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),
        };
        let (_rhs_dtype, rhs) = match rhs {
            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),
            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),
        };

        let out =
            kernel::matmul::matmul(lhs, rhs, None, MatmulStrategy::default(), out_dtype).unwrap();

        // Either re-quantize the float result with the source scheme, or keep
        // it as a float tensor, depending on the scheme's propagation policy.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
        }
    }
}