1use std::ops::Range;
2
3use burn_tensor::{
4 ops::{FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
5 quantization::{QuantizationParametersPrimitive, QuantizationScheme, QuantizationType},
6 DType, Device, Shape, TensorData,
7};
8
9use crate::{
10 element::BoolElement, kernel, tensor::JitTensor, FloatElement, IntElement, JitBackend,
11 JitRuntime,
12};
13
14fn new_qtensor<R: JitRuntime, S: Into<Shape>>(
16 data: &[u8],
17 shape: S,
18 scheme: QuantizationScheme,
19 device: &R::Device,
20) -> JitTensor<R> {
21 let client = R::client(device);
22 let buffer = client.create(data);
23
24 JitTensor::new_contiguous(
25 client,
26 device.clone(),
27 shape.into(),
28 buffer,
29 DType::QFloat(scheme),
30 )
31}
32
33impl<R, F, I, BT> QTensorOps<Self> for JitBackend<R, F, I, BT>
34where
35 R: JitRuntime,
36 F: FloatElement,
37 I: IntElement,
38 BT: BoolElement,
39{
40 fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
41 match data.dtype {
42 DType::QFloat(scheme) => match scheme {
43 QuantizationScheme::PerTensorAffine(QuantizationType::QInt8)
44 | QuantizationScheme::PerTensorSymmetric(QuantizationType::QInt8) => {
45 new_qtensor(data.as_bytes(), data.shape.clone(), scheme, device)
48 }
49 },
50 _ => panic!(
51 "Invalid dtype (expected DType::QFloat, got {:?})",
52 data.dtype
53 ),
54 }
55 }
56
57 fn quantize(
58 tensor: FloatTensor<Self>,
59 scheme: &QuantizationScheme,
60 qparams: QuantizationParametersPrimitive<Self>,
61 ) -> QuantizedTensor<Self> {
62 kernel::quantization::quantize::<R, F, I>(tensor, scheme, qparams.scale, qparams.offset)
63 }
64
65 fn dequantize(tensor: QuantizedTensor<Self>) -> FloatTensor<Self> {
66 kernel::quantization::dequantize::<R, F>(tensor)
67 }
68
69 fn q_device(tensor: &QuantizedTensor<Self>) -> Device<Self> {
70 tensor.device.clone()
71 }
72
73 fn q_to_device(tensor: QuantizedTensor<Self>, device: &Device<Self>) -> QuantizedTensor<Self> {
74 super::to_device(tensor, device)
75 }
76
77 fn q_reshape(tensor: QuantizedTensor<Self>, shape: Shape) -> QuantizedTensor<Self> {
78 super::reshape(tensor, shape)
79 }
80
81 async fn q_into_data(tensor: QuantizedTensor<Self>) -> TensorData {
82 let tensor = kernel::into_contiguous(tensor);
83 let bytes = tensor.client.read_one_async(tensor.handle.binding()).await;
84
85 TensorData::from_bytes(bytes, tensor.shape, tensor.dtype)
86 }
87
88 fn q_swap_dims(
89 _tensor: QuantizedTensor<Self>,
90 _dim1: usize,
91 _dim2: usize,
92 ) -> QuantizedTensor<Self> {
93 unimplemented!()
94 }
95
96 fn q_permute(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
97 unimplemented!()
98 }
99
100 fn q_flip(_tensor: QuantizedTensor<Self>, _axes: &[usize]) -> QuantizedTensor<Self> {
101 unimplemented!()
102 }
103
104 fn q_gather(
105 _dim: usize,
106 _tensor: QuantizedTensor<Self>,
107 _indices: IntTensor<Self>,
108 ) -> QuantizedTensor<Self> {
109 unimplemented!()
110 }
111
112 fn q_select(
113 _tensor: QuantizedTensor<Self>,
114 _dim: usize,
115 _indices: IntTensor<Self>,
116 ) -> QuantizedTensor<Self> {
117 unimplemented!()
118 }
119
120 fn q_slice(_tensor: QuantizedTensor<Self>, _ranges: &[Range<usize>]) -> QuantizedTensor<Self> {
121 unimplemented!()
122 }
123
124 fn q_expand(_tensor: QuantizedTensor<Self>, _shape: Shape) -> QuantizedTensor<Self> {
125 unimplemented!()
126 }
127}