burn_cubecl/ops/qtensor.rs

use std::ops::Range;

use burn_tensor::{
    DType, Device, Shape, TensorData,
    ops::{FloatTensor, FloatTensorOps, IntTensor, QTensorOps, QuantizedTensor},
    quantization::{
        QTensorPrimitive, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme,
        QuantizationType,
    },
};
use cubecl::{
    Feature, Runtime,
    client::ComputeClient,
    ir::{Elem, IntKind},
};

use crate::{
    CubeBackend, CubeRuntime, FloatElement, IntElement,
    element::BoolElement,
    kernel::{self, matmul::MatmulStrategy},
    tensor::CubeTensor,
};

use super::{permute, swap_dims};

/// Create a quantized tensor with packed values (u32).
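///
/// The raw bytes are copied to a device buffer as-is; the handle is then
/// wrapped as a contiguous tensor whose dtype records the quantization scheme.
///
/// A minimal usage sketch (hypothetical values; assumes an `R: CubeRuntime`
/// and a `device` in scope, with four i8 values packed into one u32 and an
/// f32 scale appended, mirroring the `TensorData` layout):
///
/// ```ignore
/// let scheme = QuantizationScheme::PerTensor(
///     QuantizationMode::Symmetric,
///     QuantizationType::QInt8,
/// );
/// let bytes = [0x0403_0201u32.to_le_bytes(), 0.1f32.to_le_bytes()].concat();
/// let tensor = new_qtensor::<R, _>(&bytes, [4], scheme, &device);
/// ```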
fn new_qtensor<R: CubeRuntime, S: Into<Shape>>(
    data: &[u8],
    shape: S,
    scheme: QuantizationScheme,
    device: &R::Device,
) -> CubeTensor<R> {
    let client = R::client(device);
    let buffer = client.create(data);

    CubeTensor::new_contiguous(
        client,
        device.clone(),
        shape.into(),
        buffer,
        DType::QFloat(scheme),
    )
}

impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
    fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
        match data.dtype {
            DType::QFloat(scheme) => match scheme {
                QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
                    // The `TensorData` quantized representation matches ours: multiple
                    // quantized values are packed into each u32 and the quantization
                    // parameters are appended to the bytes.
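                    // For example (hypothetical layout): a shape `[4]` QInt8 tensor
                    // stores its four i8 values in a single u32 word, followed by the
                    // scale bytes.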
                    new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device)
                }
            },
            _ => panic!(
                "Invalid dtype (expected DType::QFloat, got {:?})",
                data.dtype
            ),
        }
    }

    // TODO: quantize_dynamic (we can compute the min-max and scale on the fly,
    // especially when not per-tensor)

    fn quantize(
        tensor: FloatTensor<Self>,
        scheme: &QuantizationScheme,
        qparams: QuantizationParametersPrimitive<Self>,
    ) -> QuantizedTensor<Self> {
        kernel::quantization::quantize::<R, F, I>(tensor, scheme, qparams.scale)
    }

    fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
        kernel::quantization::dequantize::<R, F>(tensor)
    }

    fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
        tensor.device.clone()
    }

    fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
        super::to_device(tensor, device)
    }

    fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
        super::reshape(tensor, shape)
    }

    async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
        let tensor = kernel::into_contiguous(tensor);
        let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;

        // We use the same internal representation
        TensorData::from_bytes(bytes, tensor.shape, tensor.dtype)
    }

    fn q_swap_dims(
        tensor: QuantizedTensor<Self>,
        dim1: usize,
        dim2: usize,
    ) -> QuantizedTensor<Self> {
        swap_dims(tensor, dim1, dim2)
    }

    fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
        permute(tensor, axes)
    }

    fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_gather(
        _dim: usize,
        _tensor: QuantizedTensor<Self>,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_select(
        _tensor: QuantizedTensor<Self>,
        _dim: usize,
        _indices: IntTensor<Self>,
    ) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
        unimplemented!()
    }

    fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
        unimplemented!()
    }

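    // Fast path: try the fused quantized matmul kernel when both operands use a
    // per-tensor symmetric QInt8 scheme and the runtime features allow it;
    // otherwise dequantize and use the float matmul.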
    fn q_matmul(lhs: QuantizedTensor<Self>, rhs: QuantizedTensor<Self>) -> QuantizedTensor<Self> {
        if features_enabled::<R>(&lhs.client)
            && both_match_symmetric_qint8(lhs.scheme(), rhs.scheme())
        {
            let out =
                kernel::matmul::q_matmul(lhs.clone(), rhs.clone(), None, MatmulStrategy::default());
            if let Ok(out) = out {
                return out;
            }
        }

        // If the quantized matmul fails, fall back to the dequantize-then-matmul pattern.
        let t1_f = <Self>::dequantize(lhs);
        let t2_f = <Self>::dequantize(rhs);
        Self::float_matmul(t1_f, t2_f)
    }
}

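/// Returns `true` only when both schemes are per-tensor symmetric QInt8, the
/// only combination the fused quantized matmul kernel currently handles.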
fn both_match_symmetric_qint8(lhs: &QuantizationScheme, rhs: &QuantizationScheme) -> bool {
    [lhs, rhs].iter().all(|scheme| {
        matches!(
            scheme,
            QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8),
        )
    })
}

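/// Checks that the runtime supports `i8` as an element type and dynamic line
/// sizes; `q_matmul` only attempts the fused quantized kernel when both
/// features are enabled.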
fn features_enabled<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
    client
        .properties()
        .feature_enabled(Feature::Type(Elem::Int(IntKind::I8)))
        && client
            .properties()
            .feature_enabled(Feature::DynamicLineSize)
}