burn_jit/ops/qtensor.rs

1use std::ops::Range;
2
3use burn_tensor::{
4    ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
5    quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
6    DType, Device, Shape, TensorData,
7};
8
9use crate::{
10    element::BoolElement, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend,
11    JitRuntime,
12};
13
14/// Create a quantized tensor with packed values (u32).
15fn new_qtensor<R: JitRuntime, S: Into<Shape>>(
16    data: &[u8],
17    shape: S,
18    scheme: QuantizationScheme,
19    device: &R::Device,
20) -> JitTensor<R> {
21    let client = R::client(device);
22    let buffer = client.create(data);
23
24    JitTensor::new_contiguous(
25        client,
26        device.clone(),
27        shape.into(),
28        buffer,
29        DType::QFloat(scheme),
30    )
31}
32
33impl<R, F, I, BT> QTensorOps<Self> for JitBackend<R, F, I, BT>
34where
35    R: JitRuntime,
36    F: FloatElement,
37    I: IntElement,
38    BT: BoolElement,
39{
40    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
41        match data.dtype {
42            DType::QFloat(scheme) => match scheme {
43                QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
44                | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
45                    // TensorData quantized representation is the same, with multiple quantized values
46                    // packed into u32 and quantization parameters appended to the bytes
47                    new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device)
48                }
49            },
50            _ => panic!(
51                "Invalid dtype (expected DType::QFloat, got {:?})",
52                data.dtype
53            ),
54        }
55    }
56
57    fn quantize(
58        tensor: FloatTensor<Self>,
59        scheme: &QuantizationScheme,
60        qparams: QuantizationParametersPrimitive<Self>,
61    ) -> QuantizedTensor<Self> {
62        kernel::quantization::quantize::<R, F, I>(tensor, scheme, qparams.scale, qparams.offset)
63    }
64
65    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
66        kernel::quantization::dequantize::<R, F>(tensor)
67    }
68
69    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
70        tensor.device.clone()
71    }
72
73    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
74        super::to_device(tensor, device)
75    }
76
77    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
78        super::reshape(tensor, shape)
79    }
80
81    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
82        let tensor = kernel::into_contiguous(tensor);
83        let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
84
85        TensorData::from_bytes(bytes, tensor.shape, tensor.dtype)
86    }
87
88    fn q_swap_dims(
89        _tensor: QuantizedTensor<Self>,
90        _dim1: usize,
91        _dim2: usize,
92    ) -> QuantizedTensor<Self> {
93        unimplemented!()
94    }
95
96    fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
97        unimplemented!()
98    }
99
100    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
101        unimplemented!()
102    }
103
104    fn q_gather(
105        _dim: usize,
106        _tensor: QuantizedTensor<Self>,
107        _indices: IntTensor<Self>,
108    ) -> QuantizedTensor<Self> {
109        unimplemented!()
110    }
111
112    fn q_select(
113        _tensor: QuantizedTensor<Self>,
114        _dim: usize,
115        _indices: IntTensor<Self>,
116    ) -> QuantizedTensor<Self> {
117        unimplemented!()
118    }
119
120    fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
121        unimplemented!()
122    }
123
124    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
125        unimplemented!()
126    }
127}