//! Quantized tensor operations (`QTensorOps`) for the CubeCL backend.
//!
//! File: `burn_cubecl/ops/qtensor.rs`
1use burn_backend::{
2    Bytes, DType, ExecutionError, QTensorPrimitive, Shape, Slice, TensorData, TensorMetadata,
3    TensorPrimitive,
4    ops::QTensorOps,
5    quantization::{
6        QParamTensor, QuantLevel, QuantMode, QuantParam, QuantPropagation, QuantScheme, QuantValue,
7        QuantizationParametersPrimitive, params_shape,
8    },
9    tensor::{Device, FloatTensor, IntTensor, QuantizedTensor},
10};
11use burn_std::{FloatDType, Metadata};
12use cubecl::server::{MemoryLayout, MemoryLayoutDescriptor, MemoryLayoutStrategy};
13use cubecl::{e2m1x2, quant::scheme::QuantStore};
14
15use crate::{
16    CubeBackend, CubeRuntime, FloatElement, IntElement,
17    element::BoolElement,
18    kernel::{self, matmul::MatmulStrategy},
19    tensor::{CubeTensor, QParams},
20};
21
22use super::{into_data, permute, swap_dims};
23
24/// Create a quantized tensor with packed values (u32).
25fn new_qtensor_optimized<R: CubeRuntime>(
26    data: Bytes,
27    shape: impl Into<Shape>,
28    scheme: QuantScheme,
29    device: &R::Device,
30) -> CubeTensor<R> {
31    new_qtensor(data, shape, scheme, device, MemoryLayoutStrategy::Optimized)
32}
33
34/// Create a quantized tensor with packed values (u32).
35fn new_qtensor<R: CubeRuntime>(
36    data: Bytes,
37    shape: impl Into<Shape>,
38    scheme: QuantScheme,
39    device: &R::Device,
40    kind: MemoryLayoutStrategy,
41) -> CubeTensor<R> {
42    new_quantized(shape, scheme, device, Some(data), kind)
43}
44
45/// Create an empty quantized tensor.
46pub fn empty_qtensor_optimized<R: CubeRuntime>(
47    shape: impl Into<Shape>,
48    scheme: QuantScheme,
49    device: &R::Device,
50) -> CubeTensor<R> {
51    empty_qtensor(shape, scheme, device, MemoryLayoutStrategy::Optimized)
52}
53
54/// Create an empty quantized tensor.
55pub fn empty_qtensor<R: CubeRuntime>(
56    shape: impl Into<Shape>,
57    scheme: QuantScheme,
58    device: &R::Device,
59    kind: MemoryLayoutStrategy,
60) -> CubeTensor<R> {
61    new_quantized(shape, scheme, device, None, kind)
62}
63
/// Allocate (and optionally initialize) the value and scale buffers for a
/// quantized tensor, then assemble them into a single [`CubeTensor`].
///
/// `data`, when provided, holds the packed quantized values immediately
/// followed by the quantization parameters (scales); when `None`, the
/// buffers are allocated but left uninitialized.
///
/// # Panics
/// - `PackedU32` storage: if the last dimension is not a multiple of the
///   number of quantized values packed per `u32`.
/// - `Native` storage: for sub-byte quant values (cannot be stored natively).
/// - `PackedNative` storage: for any value other than `E2M1`.
fn new_quantized<R: CubeRuntime>(
    shape: impl Into<Shape>,
    scheme: QuantScheme,
    device: &R::Device,
    data: Option<Bytes>,
    alloc_kind: MemoryLayoutStrategy,
) -> CubeTensor<R> {
    let client = R::client(device);
    let shape: Shape = shape.into();
    // Shape of the value buffer in storage units; differs from the logical
    // shape when several quantized values are packed into one storage element.
    let mut shape_value: Shape = shape.clone();

    let rank = shape.rank();
    let shape_last = shape[rank - 1];
    let num_quants = scheme.num_quants();

    // Byte size of one storage element of the value buffer.
    let data_size = match scheme.store {
        QuantStore::PackedU32(_) => {
            if !shape_last.is_multiple_of(num_quants) {
                panic!("Can't store in u32")
            }
            // Packing shrinks the last dimension by the packing factor.
            shape_value[rank - 1] = shape_last.div_ceil(num_quants);
            size_of::<u32>()
        }
        QuantStore::Native => match scheme.value {
            QuantValue::Q8F | QuantValue::Q8S | QuantValue::E4M3 | QuantValue::E5M2 => {
                size_of::<i8>()
            }
            QuantValue::Q4F
            | QuantValue::Q4S
            | QuantValue::Q2F
            | QuantValue::Q2S
            | QuantValue::E2M1 => {
                panic!("Can't store native sub-byte values")
            }
        },
        QuantStore::PackedNative(_) => match scheme.value {
            QuantValue::E2M1 => size_of::<e2m1x2>(),
            other => panic!("{other:?} doesn't support native packing"),
        },
    };

    let scales_dtype = match scheme.param {
        QuantParam::F32 => DType::F32,
        QuantParam::F16 => DType::F16,
        QuantParam::BF16 => DType::BF16,
        // Represented by U8 and reinterpreted in the kernel
        QuantParam::UE8M0 | QuantParam::UE4M3 => DType::U8,
    };

    // Describe both buffers so they can be created in a single client request.
    let scales_shape = params_shape(&shape, scheme.level);
    let data_desc = MemoryLayoutDescriptor::new(alloc_kind, shape_value.clone(), data_size);
    let scales_desc =
        MemoryLayoutDescriptor::new(alloc_kind, scales_shape.clone(), scales_dtype.size());

    let mut tensors = match data {
        Some(data) => {
            // The first `num_bytes` of `data` are the values; the remainder
            // holds the scales.
            let num_bytes = shape_value.num_elements() * data_size;

            match data.split(num_bytes) {
                // The byte buffer could be split in place at the boundary:
                // hand both halves over without copying.
                Ok((bytes_data, bytes_scales)) => client
                    .create_tensors(vec![(data_desc, bytes_data), (scales_desc, bytes_scales)]),
                // Split failed (the original `Bytes` is returned in the error);
                // fall back to copying from sub-slices.
                Err((data, _)) => client.create_tensors_from_slices(vec![
                    (data_desc, &data[..num_bytes]),
                    (scales_desc, &data[num_bytes..]),
                ]),
            }
        }
        None => client.empty_tensors(vec![data_desc, scales_desc]),
    };
    // `tensors` is [values, scales]; remove the scales first so the index of
    // the values entry stays valid.
    let MemoryLayout {
        memory: scales_handle,
        strides: scales_strides,
    } = tensors.remove(1);
    let MemoryLayout { memory, strides } = tensors.remove(0);

    // Record only the byte range and layout of the scales; the kernel uses
    // these to locate the parameters at runtime.
    let scales = QParamTensor {
        offset_start: scales_handle.offset_start.unwrap_or(0) as usize,
        offset_end: scales_handle.offset_end.unwrap_or(0) as usize,
        metadata: Metadata::new(scales_shape, scales_strides),
        dtype: scales_dtype,
    };
    let qparams = QParams { scales };

    CubeTensor::new_quantized(
        client,
        memory,
        shape,
        device.clone(),
        strides,
        DType::QFloat(scheme),
        qparams,
    )
}
157
158impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
159where
160    R: CubeRuntime,
161    F: FloatElement,
162    I: IntElement,
163    BT: BoolElement,
164{
165    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
166        match data.dtype {
167            DType::QFloat(scheme) => match scheme {
168                QuantScheme {
169                    level: QuantLevel::Tensor | QuantLevel::Block(_),
170                    mode: QuantMode::Symmetric,
171                    value:
172                        QuantValue::Q8F
173                        | QuantValue::Q8S
174                        | QuantValue::Q4F
175                        | QuantValue::Q4S
176                        | QuantValue::Q2F
177                        | QuantValue::Q2S
178                        | QuantValue::E4M3
179                        | QuantValue::E5M2
180                        | QuantValue::E2M1,
181                    ..
182                } => {
183                    // TensorData quantized representation is the same, with multiple quantized values
184                    // packed into u32 and quantization parameters appended to the bytes
185                    new_qtensor_optimized(data.bytes, data.shape.clone(), scheme, device)
186                }
187            },
188            _ => panic!(
189                "Invalid dtype (expected DType::QFloat, got {:?})",
190                data.dtype
191            ),
192        }
193    }
194
195    // TODO: quantize_dynamic (we can compute min-max on the fly and scale, especially when not per-tensor)
196
197    fn quantize(
198        tensor: FloatTensor<Self>,
199        scheme: &QuantScheme,
200        qparams: QuantizationParametersPrimitive<Self>,
201    ) -> QuantizedTensor<Self> {
202        kernel::quantization::quantize(tensor, scheme, qparams.scales)
203    }
204
205    fn dequantize(tensor: QuantizedTensor<Self>, dtype: FloatDType) -> FloatTensor<Self> {
206        kernel::quantization::dequantize(tensor, dtype.into())
207    }
208
209    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
210        tensor.device.clone()
211    }
212
213    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
214        super::to_device(tensor, device)
215    }
216
217    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
218        super::q_reshape(tensor, shape)
219    }
220
221    async fn q_into_data(tensor: QuantizedTensor<Self>) -> Result<TensorData, ExecutionError> {
222        if tensor.qparams.is_none() {
223            return into_data(tensor).await;
224        }
225
226        let (shape, dtype) = (tensor.shape(), tensor.dtype);
227        let (values, params) = tensor.quantized_handles().unwrap();
228
229        let mut data_values = into_data(values).await?;
230        let data_params = into_data(params).await?;
231
232        data_values.bytes.extend_from_byte_slice(&data_params.bytes);
233
234        Ok(TensorData {
235            bytes: data_values.bytes,
236            shape,
237            dtype,
238        })
239    }
240
241    fn q_swap_dims(
242        tensor: QuantizedTensor<Self>,
243        dim1: usize,
244        dim2: usize,
245    ) -> QuantizedTensor<Self> {
246        swap_dims(tensor, dim1, dim2)
247    }
248
249    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
250        permute(tensor, axes)
251    }
252
253    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
254        unimplemented!()
255    }
256
257    fn q_gather(
258        _dim: usize,
259        _tensor: QuantizedTensor<Self>,
260        _indices: IntTensor<Self>,
261    ) -> QuantizedTensor<Self> {
262        unimplemented!()
263    }
264
265    fn q_select(
266        _tensor: QuantizedTensor<Self>,
267        _dim: usize,
268        _indices: IntTensor<Self>,
269    ) -> QuantizedTensor<Self> {
270        unimplemented!()
271    }
272
273    fn q_slice(_tensor: QuantizedTensor<Self>, _slices: &[Slice]) -> QuantizedTensor<Self> {
274        unimplemented!()
275    }
276
277    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
278        unimplemented!()
279    }
280
281    fn q_matmul(lhs: TensorPrimitive<Self>, rhs: TensorPrimitive<Self>) -> TensorPrimitive<Self> {
282        let (propagation, scheme) = match (&lhs, &rhs) {
283            (TensorPrimitive::QFloat(lhs), _) => (lhs.propagation(), *lhs.scheme()),
284            (_, TensorPrimitive::QFloat(rhs)) => (rhs.propagation(), *rhs.scheme()),
285            _ => unreachable!(),
286        };
287
288        // Inherit precision for mixed inputs, default to `FloatElem` for fully quantized.
289        let out_dtype = match (&lhs, &rhs) {
290            (TensorPrimitive::Float(lhs), _) => lhs.dtype,
291            (_, TensorPrimitive::Float(rhs)) => rhs.dtype,
292            _ => F::dtype(),
293        };
294
295        let (_lhs_dtype, lhs) = match lhs {
296            TensorPrimitive::Float(lhs) => (lhs.dtype, lhs),
297            TensorPrimitive::QFloat(lhs) => (out_dtype, lhs),
298        };
299        let (_rhs_dtype, rhs) = match rhs {
300            TensorPrimitive::Float(rhs) => (rhs.dtype, rhs),
301            TensorPrimitive::QFloat(rhs) => (out_dtype, rhs),
302        };
303
304        let out =
305            kernel::matmul::matmul(lhs, rhs, None, MatmulStrategy::default(), out_dtype).unwrap();
306
307        match propagation {
308            QuantPropagation::Propagate => {
309                TensorPrimitive::QFloat(Self::quantize_dynamic(out, &scheme))
310            }
311            QuantPropagation::Inhibit => TensorPrimitive::Float(out),
312        }
313    }
314}