burn_cubecl/ops/
qtensor.rs

1use burn_tensor::{
2    DType, Device, Shape, TensorData, TensorPrimitive,
3    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
4    quantization::{
5        QParamTensor, QTensorPrimitive, QuantLevel, QuantMode, QuantParam, QuantPropagation,
6        QuantScheme, QuantValue, QuantizationParametersPrimitive, params_shape,
7    },
8};
9use cubecl::{
10    matmul::components::{AccG, AccS, LhsG, LhsS, RhsG, RhsS},
11    server::{Allocation, AllocationDescriptor, AllocationKind},
12};
13use cubecl_quant::scheme::QuantStore;
14
15use crate::{
16    CubeBackend, CubeRuntime, FloatElement, IntElement,
17    element::BoolElement,
18    execute_with_dtype,
19    kernel::{self, matmul::MatmulStrategy},
20    tensor::{CubeTensor, QParams},
21};
22
23use super::{into_data, permute, swap_dims};
24
25/// Create a quantized tensor with packed values (u32).
26fn new_qtensor_optimized<R: CubeRuntime>(
27    data: &[u8],
28    shape: impl Into<Shape>,
29    scheme: QuantScheme,
30    device: &R::Device,
31) -> CubeTensor<R> {
32    new_qtensor(data, shape, scheme, device, AllocationKind::Optimized)
33}
34
35/// Create a quantized tensor with packed values (u32).
36fn new_qtensor<R: CubeRuntime>(
37    data: &[u8],
38    shape: impl Into<Shape>,
39    scheme: QuantScheme,
40    device: &R::Device,
41    kind: AllocationKind,
42) -> CubeTensor<R> {
43    new_quantized(shape, scheme, device, Some(data), kind)
44}
45
46/// Create an empty quantized tensor.
47pub fn empty_qtensor_optimized<R: CubeRuntime>(
48    shape: impl Into<Shape>,
49    scheme: QuantScheme,
50    device: &R::Device,
51) -> CubeTensor<R> {
52    empty_qtensor(shape, scheme, device, AllocationKind::Optimized)
53}
54
55/// Create an empty quantized tensor.
56pub fn empty_qtensor<R: CubeRuntime>(
57    shape: impl Into<Shape>,
58    scheme: QuantScheme,
59    device: &R::Device,
60    kind: AllocationKind,
61) -> CubeTensor<R> {
62    new_quantized(shape, scheme, device, None, kind)
63}
64
/// Allocate (and optionally initialize) the value and scale buffers backing a
/// quantized tensor, returning them wrapped in a single `CubeTensor`.
///
/// When `data` is `Some`, its layout must be the packed quantized values
/// followed immediately by the serialized quantization parameters (scales) —
/// the first `num_elements * data_size` bytes go to the value buffer, the
/// remainder to the scales buffer.
///
/// # Panics
/// - `QuantStore::U32` with a last dimension that is not a multiple of the
///   number of values packed per u32.
/// - `QuantStore::Native` with a sub-byte integer value type (q4/q2).
fn new_quantized<R: CubeRuntime>(
    shape: impl Into<Shape>,
    scheme: QuantScheme,
    device: &R::Device,
    data: Option<&[u8]>,
    alloc_kind: AllocationKind,
) -> CubeTensor<R> {
    let client = R::client(device);
    let shape: Shape = shape.into();
    // Storage shape: diverges from the logical shape when several quantized
    // values are packed into one u32 storage word (last dim shrinks).
    let mut shape_value: Shape = shape.clone();

    let rank = shape.rank();
    let shape_last = shape[rank - 1];
    let num_quants = scheme.num_quants();

    // Per-element byte size of the value buffer, plus the packed last-dim
    // adjustment for the u32 store.
    let data_size = match scheme.store {
        QuantStore::U32 => {
            // Packing requires the innermost dimension to contain a whole
            // number of u32 words; partial words are not supported.
            if !shape_last.is_multiple_of(num_quants) {
                panic!("Can't store in u32")
            }
            shape_value.dims[rank - 1] = shape_last.div_ceil(num_quants);
            size_of::<u32>()
        }
        QuantStore::Native => match scheme.value {
            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {
                size_of::<i8>()
            }
            // Native e2m1 is packed in u8
            QuantValue::E2M1 => size_of::<u8>(),
            QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                panic!("Can't store native sub-byte values")
            }
        },
    };

    // Storage dtype of the scale/parameter tensor.
    let scales_dtype = match scheme.param {
        QuantParam::F32 => DType::F32,
        QuantParam::F16 => DType::F16,
        QuantParam::BF16 => DType::BF16,
        // Represented by U8 and reinterpreted in the kernel
        QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,
    };

    // Scale tensor shape depends on the quantization level (per-tensor/block).
    let scales_shape = params_shape(&shape, scheme.level);
    let data_desc = AllocationDescriptor::new(alloc_kind, &shape_value.dims, data_size);
    let scales_desc =
        AllocationDescriptor::new(alloc_kind, &scales_shape.dims, scales_dtype.size());

    // Allocate both buffers in a single client call; when initial bytes are
    // given, split them into values (first `num_bytes`) and scales (rest).
    let mut tensors = match data {
        Some(data) => {
            let num_bytes = shape_value.num_elements() * data_size;
            client.create_tensors(vec![
                (data_desc, &data[..num_bytes]),
                (scales_desc, &data[num_bytes..]),
            ])
        }
        None => client.empty_tensors(vec![data_desc, scales_desc]),
    };
    // Allocations come back in descriptor order: index 0 = values, 1 = scales.
    // Remove the scales entry first so the values entry stays at index 0.
    let Allocation {
        handle: scales_handle,
        strides: scales_strides,
    } = tensors.remove(1);
    let Allocation { handle, strides } = tensors.remove(0);

    // Scales are described by byte offsets taken from their handle.
    // NOTE(review): offsets default to 0 when the backend reports none —
    // presumably the scales then start at the beginning of their allocation;
    // confirm against the client implementation.
    let scales = QParamTensor {
        offset_start: scales_handle.offset_start.unwrap_or(0) as usize,
        offset_end: scales_handle.offset_end.unwrap_or(0) as usize,
        shape: scales_shape,
        strides: scales_strides,
        dtype: scales_dtype,
    };
    let qparams = QParams { scales };

    // The returned tensor keeps the *logical* shape; the packed storage shape
    // is only reflected in the handle/strides.
    CubeTensor::new_quantized(
        client,
        handle,
        shape,
        device.clone(),
        strides,
        DType::QFloat(scheme),
        qparams,
    )
}
148
/// Quantized-tensor operations for the cube backend.
impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    /// Build a quantized tensor from serialized `TensorData`.
    ///
    /// Accepts symmetric per-tensor/per-block schemes over every supported
    /// value type; panics on any non-`QFloat` dtype.
    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => match scheme {
                QuantScheme {
                    level: QuantLevel::Tensor | QuantLevel::Block(_),
                    mode: QuantMode::Symmetric,
                    value:
                        QuantValue::Q8F
                        | QuantValue::Q8S
                        | QuantValue::Q4F
                        | QuantValue::Q4S
                        | QuantValue::Q2F
                        | QuantValue::Q2S
                        | QuantValue::E4M3
                        | QuantValue::E5M2
                        | QuantValue::E2M1,
                    ..
                } => {
                    // TensorData quantized representation is the same, with multiple quantized values
                    // packed into u32 and quantization parameters appended to the bytes
                    new_qtensor_optimized(data.as_bytes(), data.shape.clone(), scheme, device)
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    // TODO: quantize_dynamic (we can compute min-max on the fly and scale, especially when not per-tensor)

    /// Quantize a float tensor with the given scheme and precomputed scales.
    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        kernel::quantization::quantize::<R, F>(tensor, scheme, qparams.scales)
    }

    /// Dequantize back to a float tensor (delegates to the kernel).
    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        kernel::quantization::dequantize::<R, F>(tensor)
    }

    /// Device on which the quantized tensor lives.
    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    /// Move the quantized tensor to another device.
    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
        super::to_device(tensor, device)
    }

    /// Reshape without changing the quantized contents.
    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        super::q_reshape(tensor, shape)
    }

    /// Read the tensor back as `TensorData`: packed quantized values followed
    /// by the serialized quantization parameters.
    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        // Fast path: no quantization parameters attached — read the raw
        // buffer at its stored dtype.
        if tensor.qparams.is_none() {
            return execute_with_dtype!(tensor.dtype, E, into_data::<R, E>(tensor).await);
        }

        let (shape, dtype) = (tensor.shape.dims.clone(), tensor.dtype);
        let scheme = match dtype {
            DType::QFloat(val) => val,
            _ => unreachable!("Already checked if quantized."),
        };
        // Split the tensor into its value handle and its parameter handle.
        let (values, params) = tensor.quantized_handles().unwrap();

        // Read the values with an element type matching the storage layout.
        let mut data_values = match scheme.store {
            QuantStore::Native => match scheme.value {
                QuantValue::Q8F | QuantValue::Q8S => into_data::<R, i8>(values).await,
                QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
                    into_data::<R, u8>(values).await
                }
                QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
                    panic!("Can't store native sub-byte values")
                }
            },
            QuantStore::U32 => into_data::<R, u32>(values).await,
        };
        // Read the parameters at the scheme's parameter dtype.
        let data_params = match scheme.param {
            QuantParam::UE8M0 | QuantParam::UE4M3 => into_data::<R, u8>(params).await,
            QuantParam::F16 => into_data::<R, half::f16>(params).await,
            QuantParam::BF16 => into_data::<R, half::bf16>(params).await,
            QuantParam::F32 => into_data::<R, f32>(params).await,
        };

        // Append the parameter bytes after the value bytes, matching the
        // serialized layout expected by `q_from_data`.
        data_values.bytes.extend_from_byte_slice(&data_params.bytes);

        TensorData {
            bytes: data_values.bytes,
            shape,
            dtype,
        }
    }

    /// Swap two dimensions (stride manipulation only).
    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        swap_dims(tensor, dim1, dim2)
    }

    /// Permute dimensions according to `axes`.
    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        permute(tensor, axes)
    }

    /// Not supported for quantized tensors on this backend.
    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    /// Not supported for quantized tensors on this backend.
    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    /// Not supported for quantized tensors on this backend.
    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    /// Not supported for quantized tensors on this backend.
    fn q_slice(
        _tensor: QuantizedTensor<Self>,
        _slices: &[burn_tensor::Slice],
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    /// Not supported for quantized tensors on this backend.
    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    /// Matrix multiplication where either side (or both) may be quantized.
    ///
    /// The output is requantized with the quantized input's scheme when
    /// propagation is enabled, otherwise returned as a float tensor.
    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
        // Take the propagation policy and scheme from whichever operand is
        // quantized (lhs wins if both are).
        let (propagation, scheme) = match (&lhs, &rhs) {
            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),
            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),
            _ => unreachable!(),
        };

        // Inherit precision for mixed inputs, default to `FloatElem` for fully quantized.
        let out_dtype = match (&lhs, &rhs) {
            (TensorPrimitive::Float(lhs), _) => lhs.dtype,
            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,
            _ => F::dtype(),
        };

        // Quantized operands are computed at the output precision.
        let (lhs_dtype, lhs) = match lhs {
            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),
            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),
        };
        let (rhs_dtype, rhs) = match rhs {
            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),
            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),
        };

        // Dispatch the matmul kernel over the three (lhs, rhs, out) dtypes.
        let out = execute_with_dtype!(float(lhs_dtype), LP, {
            execute_with_dtype!(float(rhs_dtype), RP, {
                execute_with_dtype!(float(out_dtype), OP, {
                    type MP = (LhsG<LP>, RhsG<RP>, AccG<OP>, LhsS<LP>, RhsS<RP>, AccS<OP>);

                    kernel::matmul::matmul::<R, MP>(lhs, rhs, None, MatmulStrategy::default())
                        .unwrap()
                })
            })
        });

        // Requantize the result only when the scheme asks for propagation.
        match propagation {
            QuantPropagation::Propagate => {
                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))
            }
            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
        }
    }
}
336}