burn_cubecl/ops/
qtensor.rs

1use burn_backend::{
2    Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorPrimitive,
3    ops::QTensorOps,
4    quantization::{
5        QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue,
6        QuantizationParametersPrimitive, params_shape,
7    },
8    tensor::{Device, FloatElem, FloatTensor, IntTensor, QuantizedTensor},
9};
10use cubecl::server::{Allocation, AllocationDescriptor, AllocationKind};
11use cubecl::{e2m1x2, quant::scheme::QuantStore};
12
13use crate::{
14    CubeBackend, CubeRuntime, FloatElement, IntElement,
15    element::BoolElement,
16    kernel::{self, matmul::MatmulStrategy},
17    tensor::{CubeTensor, QParams},
18};
19
20use super::{into_data, permute, swap_dims};
21
22/// Create a quantized tensor with packed values (u32).
23fn new_qtensor_optimized<R: CubeRuntime>(
24    data: Bytes,
25    shape: impl Into<Shape>,
26    scheme: QuantScheme,
27    device: &R::Device,
28) -> CubeTensor<R> {
29    new_qtensor(data, shape, scheme, device, AllocationKind::Optimized)
30}
31
32/// Create a quantized tensor with packed values (u32).
33fn new_qtensor<R: CubeRuntime>(
34    data: Bytes,
35    shape: impl Into<Shape>,
36    scheme: QuantScheme,
37    device: &R::Device,
38    kind: AllocationKind,
39) -> CubeTensor<R> {
40    new_quantized(shape, scheme, device, Some(data), kind)
41}
42
43/// Create an empty quantized tensor.
44pub fn empty_qtensor_optimized<R: CubeRuntime>(
45    shape: impl Into<Shape>,
46    scheme: QuantScheme,
47    device: &R::Device,
48) -> CubeTensor<R> {
49    empty_qtensor(shape, scheme, device, AllocationKind::Optimized)
50}
51
52/// Create an empty quantized tensor.
53pub fn empty_qtensor<R: CubeRuntime>(
54    shape: impl Into<Shape>,
55    scheme: QuantScheme,
56    device: &R::Device,
57    kind: AllocationKind,
58) -> CubeTensor<R> {
59    new_quantized(shape, scheme, device, None, kind)
60}
61
/// Allocate (and optionally initialize) the value and scale buffers for a
/// quantized tensor, returning a `CubeTensor` tagged with `DType::QFloat`.
///
/// When `data` is provided, it must contain the quantized value bytes
/// immediately followed by the scale bytes.
///
/// # Panics
/// - `PackedU32` store with an innermost dimension not divisible by the
///   number of values per `u32`.
/// - `Native` store with a sub-byte `QuantValue`.
/// - `PackedNative` store with any value other than `E2M1`.
fn new_quantized<R: CubeRuntime>(
    shape: impl Into<Shape>,
    scheme: QuantScheme,
    device: &R::Device,
    data: Option<Bytes>,
    alloc_kind: AllocationKind,
) -> CubeTensor<R> {
    let client = R::client(device);
    let shape: Shape = shape.into();
    // Storage shape of the value buffer; it can differ from the logical shape
    // when several quantized values are packed into one storage word.
    let mut shape_value: Shape = shape.clone();

    let rank = shape.rank();
    let shape_last = shape[rank - 1];
    let num_quants = scheme.num_quants();

    // Per-element byte size of the value buffer, adjusting the innermost
    // dimension when values are packed.
    let data_size = match scheme.store {
        QuantStore::PackedU32(_) => {
            // Packing requires the innermost dimension to split evenly into
            // u32 words.
            if !shape_last.is_multiple_of(num_quants) {
                panic!("Can't store in u32")
            }
            shape_value.dims[rank - 1] = shape_last.div_ceil(num_quants);
            size_of::<u32>()
        }
        QuantStore::Native => match scheme.value {
            // Byte-sized values can be stored directly.
            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {
                size_of::<i8>()
            }
            // Sub-byte values are not individually addressable.
            QuantValue::Q4F
            | QuantValue::Q4S
            | QuantValue::Q2F
            | QuantValue::Q2S
            | QuantValue::E2M1 => {
                panic!("Can't store native sub-byte values")
            }
        },
        QuantStore::PackedNative(_) => match scheme.value {
            // Two e2m1 values share one byte (`e2m1x2`).
            QuantValue::E2M1 => size_of::<e2m1x2>(),
            other => panic!("{other:?} doesn't support native packing"),
        },
    };

    let scales_dtype = match scheme.param {
        QuantParam::F32 => DType::F32,
        QuantParam::F16 => DType::F16,
        QuantParam::BF16 => DType::BF16,
        // Represented by U8 and reinterpreted in the kernel
        QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,
    };

    // Shape of the scale tensor, derived from the quantization level
    // (per-tensor vs. per-block).
    let scales_shape = params_shape(&shape, scheme.level);
    let data_desc = AllocationDescriptor::new(alloc_kind, &shape_value.dims, data_size);
    let scales_desc =
        AllocationDescriptor::new(alloc_kind, &scales_shape.dims, scales_dtype.size());

    let mut tensors = match data {
        Some(data) => {
            // Byte length of the value section; the scale bytes follow it.
            let num_bytes = shape_value.num_elements() * data_size;

            // Try to split the owned bytes at the boundary; if that fails,
            // fall back to creating the tensors from borrowed slices.
            match data.split(num_bytes) {
                Ok((bytes_data, bytes_scales)) => client
                    .create_tensors(vec![(data_desc, bytes_data), (scales_desc, bytes_scales)]),
                Err((data, _)) => client.create_tensors_from_slices(vec![
                    (data_desc, &data[..num_bytes]),
                    (scales_desc, &data[num_bytes..]),
                ]),
            }
        }
        None => client.empty_tensors(vec![data_desc, scales_desc]),
    };
    // Scales were requested second, so they sit at index 1; remove index 1
    // first so index 0 stays valid.
    let Allocation {
        handle: scales_handle,
        strides: scales_strides,
    } = tensors.remove(1);
    let Allocation { handle, strides } = tensors.remove(0);

    // Record the byte range of the scales so kernels can locate them.
    // NOTE(review): offsets fall back to 0 when the handle reports none —
    // presumably that means the scales start at the allocation base; confirm.
    let scales = QParamTensor {
        offset_start: scales_handle.offset_start.unwrap_or(0) as usize,
        offset_end: scales_handle.offset_end.unwrap_or(0) as usize,
        shape: scales_shape,
        strides: scales_strides,
        dtype: scales_dtype,
    };
    let qparams = QParams { scales };

    CubeTensor::new_quantized(
        client,
        handle,
        shape,
        device.clone(),
        strides,
        DType::QFloat(scheme),
        qparams,
    )
}
156
157impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
158where
159    R: CubeRuntime,
160    F: FloatElement,
161    I: IntElement,
162    BT: BoolElement,
163{
164    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
165        match data.dtype {
166            DType::QFloat(scheme) => match scheme {
167                QuantScheme {
168                    level: QuantLevel::Tensor | QuantLevel::Block(_),
169                    mode: QuantMode::Symmetric,
170                    value:
171                        QuantValue::Q8F
172                        | QuantValue::Q8S
173                        | QuantValue::Q4F
174                        | QuantValue::Q4S
175                        | QuantValue::Q2F
176                        | QuantValue::Q2S
177                        | QuantValue::E4M3
178                        | QuantValue::E5M2
179                        | QuantValue::E2M1,
180                    ..
181                } => {
182                    // TensorData quantized representation is the same, with multiple quantized values
183                    // packed into u32 and quantization parameters appended to the bytes
184                    new_qtensor_optimized(data.bytes, data.shape.clone(), scheme, device)
185                }
186            },
187            _ => panic!(
188                "Invalid dtype (expected DType::QFloat, got {:?})",
189                data.dtype
190            ),
191        }
192    }
193
194    // TODO: quantize_dynamic (we can compute min-max on the fly and scale, especially when not per-tensor)
195
196    fn quantize(
197        tensor: FloatTensor<Self>,
198        scheme: &QuantScheme,
199        qparams: QuantizationParametersPrimitive<Self>,
200    ) -> QuantizedTensor<Self> {
201        kernel::quantization::quantize(tensor, scheme, qparams.scales)
202    }
203
204    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
205        kernel::quantization::dequantize(tensor, FloatElem::<Self>::dtype())
206    }
207
208    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
209        tensor.device.clone()
210    }
211
212    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
213        super::to_device(tensor, device)
214    }
215
216    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
217        super::q_reshape(tensor, shape)
218    }
219
220    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
221        if tensor.qparams.is_none() {
222            return into_data(tensor).await;
223        }
224
225        let (shape, dtype) = (tensor.shape.dims.clone(), tensor.dtype);
226        let (values, params) = tensor.quantized_handles().unwrap();
227
228        let mut data_values = into_data(values).await?;
229        let data_params = into_data(params).await?;
230
231        data_values.bytes.extend_from_byte_slice(&data_params.bytes);
232
233        Ok(TensorData {
234            bytes: data_values.bytes,
235            shape,
236            dtype,
237        })
238    }
239
240    fn q_swap_dims(
241        tensor: QuantizedTensor<Self>,
242        dim1: usize,
243        dim2: usize,
244    ) -> QuantizedTensor<Self> {
245        swap_dims(tensor, dim1, dim2)
246    }
247
248    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
249        permute(tensor, axes)
250    }
251
252    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
253        unimplemented!()
254    }
255
256    fn q_gather(
257        _dim: usize,
258        _tensor: QuantizedTensor<Self>,
259        _indices: IntTensor<Self>,
260    ) -> QuantizedTensor<Self> {
261        unimplemented!()
262    }
263
264    fn q_select(
265        _tensor: QuantizedTensor<Self>,
266        _dim: usize,
267        _indices: IntTensor<Self>,
268    ) -> QuantizedTensor<Self> {
269        unimplemented!()
270    }
271
272    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
273        unimplemented!()
274    }
275
276    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
277        unimplemented!()
278    }
279
280    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
281        let (propagation, scheme) = match (&lhs, &rhs) {
282            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),
283            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),
284            _ => unreachable!(),
285        };
286
287        // Inherit precision for mixed inputs, default to `FloatElem` for fully quantized.
288        let out_dtype = match (&lhs, &rhs) {
289            (TensorPrimitive::Float(lhs), _) => lhs.dtype,
290            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,
291            _ => F::dtype(),
292        };
293
294        let (_lhs_dtype, lhs) = match lhs {
295            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),
296            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),
297        };
298        let (_rhs_dtype, rhs) = match rhs {
299            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),
300            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),
301        };
302
303        let out =
304            kernel::matmul::matmul(lhs, rhs, None, MatmulStrategy::default(), out_dtype).unwrap();
305
306        match propagation {
307            QuantPropagation::Propagate => {
308                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))
309            }
310            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
311        }
312    }
313}