1use burn_tensor::{
2 DType, Device, Shape, TensorData, TensorPrimitive,
3 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
4 quantization::{
5 QParamTensor, QTensorPrimitive, QuantLevel, QuantMode, QuantParam, QuantPropagation,
6 QuantScheme, QuantValue, QuantizationParametersPrimitive, params_shape,
7 },
8};
9use cubecl::{
10 matmul::components::{AccG, AccS, LhsG, LhsS, RhsG, RhsS},
11 server::{Allocation, AllocationDescriptor, AllocationKind},
12};
13use cubecl_quant::scheme::QuantStore;
14
15use crate::{
16 CubeBackend, CubeRuntime, FloatElement, IntElement,
17 element::BoolElement,
18 execute_with_dtype,
19 kernel::{self, matmul::MatmulStrategy},
20 tensor::{CubeTensor, QParams},
21};
22
23use super::{into_data, permute, swap_dims};
24
25fn new_qtensor_optimized<R: CubeRuntime>(
27 data: &[u8],
28 shape: impl Into<Shape>,
29 scheme: QuantScheme,
30 device: &R::Device,
31) -> CubeTensor<R> {
32 new_qtensor(data, shape, scheme, device, AllocationKind::Optimized)
33}
34
35fn new_qtensor<R: CubeRuntime>(
37 data: &[u8],
38 shape: impl Into<Shape>,
39 scheme: QuantScheme,
40 device: &R::Device,
41 kind: AllocationKind,
42) -> CubeTensor<R> {
43 new_quantized(shape, scheme, device, Some(data), kind)
44}
45
46pub fn empty_qtensor_optimized<R: CubeRuntime>(
48 shape: impl Into<Shape>,
49 scheme: QuantScheme,
50 device: &R::Device,
51) -> CubeTensor<R> {
52 empty_qtensor(shape, scheme, device, AllocationKind::Optimized)
53}
54
55pub fn empty_qtensor<R: CubeRuntime>(
57 shape: impl Into<Shape>,
58 scheme: QuantScheme,
59 device: &R::Device,
60 kind: AllocationKind,
61) -> CubeTensor<R> {
62 new_quantized(shape, scheme, device, None, kind)
63}
64
/// Shared constructor for quantized tensors: allocates (and optionally fills)
/// the packed-values tensor and its scales tensor in a single backing
/// allocation, then wraps both in a `CubeTensor`.
///
/// When `data` is `Some`, the buffer layout is the packed quantized values
/// followed immediately by the scale bytes.
///
/// # Panics
/// - If `scheme.store` is `U32` and the last dimension is not a multiple of
///   the number of values packed per `u32`.
/// - If `scheme.store` is `Native` with a sub-byte value type (Q4/Q2).
fn new_quantized<R: CubeRuntime>(
    shape: impl Into<Shape>,
    scheme: QuantScheme,
    device: &R::Device,
    data: Option<&[u8]>,
    alloc_kind: AllocationKind,
) -> CubeTensor<R> {
    let client = R::client(device);
    let shape: Shape = shape.into();
    // Shape of the *storage* tensor; may shrink along the last dim when
    // several quantized values are packed into one u32.
    let mut shape_value: Shape = shape.clone();

    let rank = shape.rank();
    let shape_last = shape[rank - 1];
    let num_quants = scheme.num_quants();

    // Bytes per stored element of the values tensor.
    let data_size = match scheme.store {
        QuantStore::U32 => {
            // Packing requires the last dim to split evenly into u32 words.
            if !shape_last.is_multiple_of(num_quants) {
                panic!("Can't store in u32")
            }
            // Exact division here, since the multiple-of check passed.
            shape_value.dims[rank - 1] = shape_last.div_ceil(num_quants);
            size_of::<u32>()
        }
        QuantStore::Native => match scheme.value {
            // One byte per value for 8-bit integer and FP8-style formats.
            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {
                size_of::<i8>()
            }
            QuantValue::E2M1 => size_of::<u8>(),
            // Sub-byte values cannot be addressed natively; they must use
            // the packed `U32` store instead.
            QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                panic!("Can't store native sub-byte values")
            }
        },
    };

    // Element type of the scales tensor. UE8M0/UE4M3 params are stored as
    // raw bytes (U8).
    let scales_dtype = match scheme.param {
        QuantParam::F32 => DType::F32,
        QuantParam::F16 => DType::F16,
        QuantParam::BF16 => DType::BF16,
        QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,
    };

    // Scales shape depends on the quantization level (per-tensor / per-block).
    let scales_shape = params_shape(&shape, scheme.level);
    let data_desc = AllocationDescriptor::new(alloc_kind, &shape_value.dims, data_size);
    let scales_desc =
        AllocationDescriptor::new(alloc_kind, &scales_shape.dims, scales_dtype.size());

    // Allocate values + scales together; split the input bytes when provided.
    let mut tensors = match data {
        Some(data) => {
            let num_bytes = shape_value.num_elements() * data_size;
            client.create_tensors(vec![
                (data_desc, &data[..num_bytes]),
                (scales_desc, &data[num_bytes..]),
            ])
        }
        None => client.empty_tensors(vec![data_desc, scales_desc]),
    };
    // Pop index 1 (scales) first so index 0 (values) stays valid.
    let Allocation {
        handle: scales_handle,
        strides: scales_strides,
    } = tensors.remove(1);
    let Allocation { handle, strides } = tensors.remove(0);

    // Record where the scales live inside the allocation so they can be
    // recovered from the main handle later.
    let scales = QParamTensor {
        offset_start: scales_handle.offset_start.unwrap_or(0) as usize,
        offset_end: scales_handle.offset_end.unwrap_or(0) as usize,
        shape: scales_shape,
        strides: scales_strides,
        dtype: scales_dtype,
    };
    let qparams = QParams { scales };

    // Note: the logical `shape` (not the packed `shape_value`) is what the
    // resulting tensor advertises.
    CubeTensor::new_quantized(
        client,
        handle,
        shape,
        device.clone(),
        strides,
        DType::QFloat(scheme),
        qparams,
    )
}
148
/// Quantized tensor operations for the cube backend.
impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    /// Build a quantized tensor from serialized `TensorData`.
    ///
    /// The data bytes are expected to contain the packed quantized values
    /// followed by the quantization parameters (see `q_into_data`, which
    /// produces that layout).
    ///
    /// # Panics
    /// If `data.dtype` is not `DType::QFloat`.
    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => match scheme {
                // Single arm listing every supported scheme combination;
                // exhaustive today, and adding a variant upstream will force
                // a compile error here rather than silent misbehavior.
                QuantScheme {
                    level: QuantLevel::Tensor | QuantLevel::Block(_),
                    mode: QuantMode::Symmetric,
                    value:
                        QuantValue::Q8F
                        | QuantValue::Q8S
                        | QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S
                        | QuantValue::E4M3
                        | QuantValue::E5M2
                        | QuantValue::E2M1,
                    ..
                } => {
                    new_qtensor_optimized(data.as_bytes(), data.shape.clone(), scheme, device)
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    /// Quantize a float tensor using the given scheme and scale parameters.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        kernel::quantization::quantize::<R, F>(tensor, scheme, qparams.scales)
    }

    /// Dequantize back to a float tensor with element type `F`.
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        kernel::quantization::dequantize::<R, F>(tensor)
    }

    /// Device on which the quantized tensor lives.
    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    /// Move the tensor (values and params) to another device.
    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
        super::to_device(tensor, device)
    }

    /// Reshape without changing the underlying quantized data.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        super::q_reshape(tensor, shape)
    }

    /// Read the tensor back as `TensorData`: packed quantized values followed
    /// by the serialized quantization parameters, under the original QFloat
    /// dtype and logical shape.
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        // No qparams attached: read it back as a plain tensor of its dtype.
        if tensor.qparams.is_none() {
            return execute_with_dtype!(tensor.dtype, E, into_data::<R, E>(tensor).await);
        }

        let (shape, dtype) = (tensor.shape.dims.clone(), tensor.dtype);
        let scheme = match dtype {
            DType::QFloat(val) => val,
            _ => unreachable!("Already checked if quantized."),
        };
        // Split the tensor into its values handle and its params handle.
        let (values, params) = tensor.quantized_handles().unwrap();

        // Read the packed values with the element type implied by the store.
        let mut data_values = match scheme.store {
            QuantStore::Native => match scheme.value {
                QuantValue::Q8F | QuantValue::Q8S => into_data::<R, i8>(values).await,
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    into_data::<R, u8>(values).await
                }
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    panic!("Can't store native sub-byte values")
                }
            },
            QuantStore::U32 => into_data::<R, u32>(values).await,
        };
        // Read the scales with the element type implied by the param format.
        let data_params = match scheme.param {
            QuantParam::UE8M0 | QuantParam::UE4M3 => into_data::<R, u8>(params).await,
            QuantParam::F16 => into_data::<R, half::f16>(params).await,
            QuantParam::BF16 => into_data::<R, half::bf16>(params).await,
            QuantParam::F32 => into_data::<R, f32>(params).await,
        };

        // Serialized layout: values bytes, then params bytes appended.
        data_values.bytes.extend_from_byte_slice(&data_params.bytes);

        TensorData {
            bytes: data_values.bytes,
            shape,
            dtype,
        }
    }

    /// Swap two dimensions (stride manipulation; data untouched).
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        swap_dims(tensor, dim1, dim2)
    }

    /// Permute dimensions according to `axes` (stride manipulation).
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        permute(tensor, axes)
    }

    // The following ops are not yet supported on quantized tensors for this
    // backend; callers must dequantize first.
    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_slice(
        _tensor: QuantizedTensor<Self>,
        _slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    /// Matmul where either side may be quantized. The inputs are dequantized
    /// on the fly by the kernel; the output is re-quantized only when the
    /// quantized operand's scheme asks for propagation.
    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
        // Take propagation/scheme from whichever operand is quantized
        // (lhs wins when both are).
        let (propagation, scheme) = match (&lhs, &rhs) {
            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),
            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),
            _ => unreachable!(),
        };

        // Output float dtype: follow a float operand if present, otherwise
        // the backend's default float element type.
        let out_dtype = match (&lhs, &rhs) {
            (TensorPrimitive::Float(lhs), _) => lhs.dtype,
            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,
            _ => F::dtype(),
        };

        // Quantized operands are treated as `out_dtype` for precision
        // selection in the kernel.
        let (lhs_dtype, lhs) = match lhs {
            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),
            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),
        };
        let (rhs_dtype, rhs) = match rhs {
            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),
            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),
        };

        // Triple dtype dispatch to build the matmul precision config `MP`.
        let out = execute_with_dtype!(float(lhs_dtype), LP, {
            execute_with_dtype!(float(rhs_dtype), RP, {
                execute_with_dtype!(float(out_dtype), OP, {
                    type MP = (LhsG<LP>, RhsG<RP>, AccG<OP>, LhsS<LP>, RhsS<RP>, AccS<OP>);

                    kernel::matmul::matmul::<R, MP>(lhs, rhs, None, MatmulStrategy::default())
                        .unwrap()
                })
            })
        });

        // Re-quantize the result only when the scheme propagates quantization.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
        }
    }
}