1use std::ops::Range;
2
3use burn_tensor::{
4 DType, Device, Shape, TensorData,
5 ops::{FloatTensor, FloatTensorOps, IntTensor, QTensorOps, QuantizedTensor},
6 quantization::{
7 QTensorPrimitive, QuantizationMode, QuantizationParametersPrimitive, QuantizationScheme,
8 QuantizationType,
9 },
10};
11use cubecl::{
12 Feature, Runtime,
13 client::ComputeClient,
14 ir::{Elem, IntKind},
15};
16
17use crate::{
18 CubeBackend, CubeRuntime, FloatElement, IntElement,
19 element::BoolElement,
20 kernel::{self, matmul::MatmulStrategy},
21 tensor::CubeTensor,
22};
23
24use super::{permute, swap_dims};
25
26fn new_qtensor<R: CubeRuntime, S: Into<Shape>>(
28 data: &[u8],
29 shape: S,
30 scheme: QuantizationScheme,
31 device: &R::Device,
32) -> CubeTensor<R> {
33 let client = R::client(device);
34 let buffer = client.create(data);
35
36 CubeTensor::new_contiguous(
37 client,
38 device.clone(),
39 shape.into(),
40 buffer,
41 DType::QFloat(scheme),
42 )
43}
44
45impl<R, F, I, BT> QTensorOps<Self> for CubeBackend<R, F, I, BT>
46where
47 R: CubeRuntime,
48 F: FloatElement,
49 I: IntElement,
50 BT: BoolElement,
51{
52 fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
53 match data.dtype {
54 DType::QFloat(scheme) => match scheme {
55 QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
56 new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device)
59 }
60 },
61 _ => panic!(
62 "Invalid dtype (expected DType::QFloat, got {:?})",
63 data.dtype
64 ),
65 }
66 }
67
68 fn quantize(
71 tensor: FloatTensor<Self>,
72 scheme: &QuantizationScheme,
73 qparams: QuantizationParametersPrimitive<Self>,
74 ) -> QuantizedTensor<Self> {
75 kernel::quantization::quantize::<R, F, I>(tensor, scheme, qparams.scale)
76 }
77
78 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
79 kernel::quantization::dequantize::<R, F>(tensor)
80 }
81
82 fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
83 tensor.device.clone()
84 }
85
86 fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
87 super::to_device(tensor, device)
88 }
89
90 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
91 super::reshape(tensor, shape)
92 }
93
94 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
95 let tensor = kernel::into_contiguous(tensor);
96 let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
97
98 TensorData::from_bytes(bytes, tensor.shape, tensor.dtype)
100 }
101
102 fn q_swap_dims(
103 tensor: QuantizedTensor<Self>,
104 dim1: usize,
105 dim2: usize,
106 ) -> QuantizedTensor<Self> {
107 swap_dims(tensor, dim1, dim2)
108 }
109
110 fn q_permute(tensor: QuantizedTensor<Self>, axes: &[usize]) -> QuantizedTensor<Self> {
111 permute(tensor, axes)
112 }
113
114 fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
115 unimplemented!()
116 }
117
118 fn q_gather(
119 _dim: usize,
120 _tensor: QuantizedTensor<Self>,
121 _indices: IntTensor<Self>,
122 ) -> QuantizedTensor<Self> {
123 unimplemented!()
124 }
125
126 fn q_select(
127 _tensor: QuantizedTensor<Self>,
128 _dim: usize,
129 _indices: IntTensor<Self>,
130 ) -> QuantizedTensor<Self> {
131 unimplemented!()
132 }
133
134 fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
135 unimplemented!()
136 }
137
138 fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
139 unimplemented!()
140 }
141
142 fn q_matmul(lhs: QuantizedTensor<Self>, rhs: QuantizedTensor<Self>) -> QuantizedTensor<Self> {
143 if features_enabled::<R>(&lhs.client)
144 && both_matches_symmetric_qint8(lhs.scheme(), rhs.scheme())
145 {
146 let out =
147 kernel::matmul::q_matmul(lhs.clone(), rhs.clone(), None, MatmulStrategy::default());
148 if let Ok(out) = out {
149 return out;
150 }
151 }
152
153 let t1_f = <Self>::dequantize(lhs);
155 let t2_f = <Self>::dequantize(rhs);
156 Self::float_matmul(t1_f, t2_f)
157 }
158}
159
160fn both_matches_symmetric_qint8(lhs: &QuantizationScheme, rhs: &QuantizationScheme) -> bool {
161 [lhs, rhs].iter().all(|scheme| {
162 matches!(
163 scheme,
164 QuantizationScheme::PerTensor(QuantizationMode::Symmetric, QuantizationType::QInt8),
165 )
166 })
167}
168
169fn features_enabled<R: Runtime>(client: &ComputeClient<R::Server, R::Channel>) -> bool {
170 client
171 .properties()
172 .feature_enabled(Feature::Type(Elem::Int(IntKind::I8)))
173 && client
174 .properties()
175 .feature_enabled(Feature::DynamicLineSize)
176}