1use crate::{LibTorchDevice, TchElement};
2use burn_tensor::{
3 quantization::{
4 AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
5 QuantizationType, SymmetricQuantization,
6 },
7 DType, Shape, TensorData, TensorMetadata,
8};
9use libc::c_void;
10use std::sync::Arc;
11
/// Shared handle to a tensor's underlying libtorch data pointer.
///
/// The `Arc` is used purely for reference counting so uniqueness can be
/// checked before mutating in place; in this module the wrapped raw pointer
/// is only compared for identity and never dereferenced.
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;
17
/// Bookkeeping for a tensor's underlying data buffer.
///
/// Used to decide whether an in-place operation is safe (see
/// [`Storage::can_mut`]).
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// The tensor is a view into another tensor's buffer.
    View {
        /// Pointer to the start of the full (parent) buffer.
        buffer_ref: StorageRef,
        /// Data pointer of the view itself.
        view_ref: StorageRef,
    },
    /// The tensor owns its buffer.
    Owned {
        /// Pointer to the start of the owned buffer.
        buffer_ref: StorageRef,
    },
}
34
35impl Storage {
36 pub fn can_mut(&self) -> bool {
38 match self {
39 Storage::View {
40 buffer_ref: start_ref,
41 view_ref,
42 } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
43 Storage::Owned {
44 buffer_ref: start_ref,
45 } => Arc::strong_count(start_ref) == 1,
46 }
47 }
48
49 pub fn buffer_ref(&self) -> &StorageRef {
51 match self {
52 Storage::View {
53 buffer_ref: start_ref,
54 view_ref: _,
55 } => start_ref,
56 Storage::Owned {
57 buffer_ref: start_ref,
58 } => start_ref,
59 }
60 }
61}
62
/// A tensor backed by libtorch through the `tch` bindings.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
    /// Handle to the underlying libtorch tensor.
    pub tensor: tch::Tensor,

    /// Storage bookkeeping used to track sharing and decide whether
    /// in-place operations are allowed.
    pub storage: Storage,
}
72
73impl TensorMetadata for TchTensor {
74 fn dtype(&self) -> DType {
75 match self.tensor.kind() {
76 tch::Kind::Uint8 => DType::U8,
77 tch::Kind::Int8 => DType::I8,
78 tch::Kind::Int16 => DType::I16,
79 tch::Kind::Int => DType::I32,
80 tch::Kind::Int64 => DType::I64,
81 tch::Kind::Half => DType::F16,
82 tch::Kind::Float => DType::F32,
83 tch::Kind::Double => DType::F64,
84 tch::Kind::Bool => DType::Bool,
85 tch::Kind::QUInt8 => DType::U8,
86 tch::Kind::BFloat16 => DType::BF16,
87 _ => unimplemented!(),
89 }
90 }
91
92 fn shape(&self) -> Shape {
93 Shape::from(self.tensor.size())
94 }
95}
96
97impl TchTensor {
98 pub fn new(tensor: tch::Tensor) -> Self {
104 #[allow(clippy::arc_with_non_send_sync)]
105 let storage = Storage::Owned {
106 buffer_ref: Arc::new(tensor.data_ptr()),
107 };
108
109 Self { tensor, storage }
110 }
111
112 pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
117 let storage_child = tensor.data_ptr();
118 let mut is_a_new_tensor = true;
119
120 match &storage_parent {
121 Storage::View {
122 buffer_ref: start_ref,
123 view_ref,
124 } => {
125 if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
126 is_a_new_tensor = false;
127 }
128 }
129 Storage::Owned {
130 buffer_ref: start_ref,
131 } => {
132 if storage_child == *start_ref.as_ref() {
133 is_a_new_tensor = false;
134 }
135 }
136 };
137
138 let storage = match is_a_new_tensor {
139 true => Storage::Owned {
140 #[allow(clippy::arc_with_non_send_sync)]
141 buffer_ref: Arc::new(storage_child),
142 },
143 false => storage_parent.clone(),
144 };
145
146 Self { tensor, storage }
147 }
148
149 pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
151 let storage = Storage::View {
152 buffer_ref: storage_parent.buffer_ref().clone(),
153 #[allow(clippy::arc_with_non_send_sync)]
154 view_ref: Arc::new(tensor.data_ptr()),
155 };
156 Self { tensor, storage }
157 }
158}
159
// SAFETY: `TchTensor` is `!Send`/`!Sync` by default because `Storage` holds a
// raw `*mut c_void`. That pointer is only used for identity comparison and
// `Arc` reference counting in this module, never dereferenced. This also
// assumes the libtorch tensor handle itself is safe to move and share across
// threads — NOTE(review): confirm against upstream tch/libtorch guarantees.
unsafe impl Send for TchTensor {}
unsafe impl Sync for TchTensor {}
165
166impl TchTensor {
167 pub fn can_mut(&self) -> bool {
172 let stride_contains_zero = self.tensor.stride().iter().any(|&s| s == 0);
173
174 !stride_contains_zero && self.storage.can_mut()
175 }
176
177 pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
179 &mut self,
180 func: F,
181 ) -> Option<TchTensor> {
182 if !self.can_mut() {
183 return None;
184 }
185
186 let data = self.storage.clone();
187 Some(TchTensor::from_existing(func(&mut self.tensor), data))
188 }
189
190 pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
192 where
193 FOwn: Fn(tch::Tensor) -> tch::Tensor,
194 FRef: Fn(&tch::Tensor) -> tch::Tensor,
195 {
196 if !self.can_mut() {
197 return TchTensor::from_existing(fref(&self.tensor), self.storage);
198 }
199
200 TchTensor::from_existing(fown(self.tensor), self.storage)
201 }
202
203 pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
205 mut lhs: Self,
206 mut rhs: Self,
207 flmut: FLMut,
208 frmut: FRMut,
209 fref: FRef,
210 ) -> TchTensor
211 where
212 FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
213 FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
214 FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
215 {
216 let lhs_shape = lhs.shape();
217 let rhs_shape = rhs.shape();
218
219 let d_out = lhs_shape.num_dims();
221 let mut out_shape = Shape::from(vec![1usize; d_out]);
222
223 for i in 0..d_out {
224 out_shape.dims[i] = usize::max(lhs_shape.dims[i], rhs_shape.dims[i]);
225 }
226
227 let num_elements_out = out_shape.num_elements();
228
229 if lhs_shape.num_elements() == num_elements_out {
231 if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
232 return output;
233 }
234 }
235
236 if rhs_shape.num_elements() == num_elements_out {
238 if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
239 return output;
240 }
241 }
242
243 let storage = lhs.storage;
244 let tensor = fref(&lhs.tensor, &rhs.tensor);
245
246 TchTensor::from_existing(tensor, storage)
247 }
248}
249
250impl Clone for TchTensor {
251 fn clone(&self) -> Self {
252 Self {
253 tensor: self.tensor.shallow_clone(),
254 storage: self.storage.clone(),
255 }
256 }
257}
258
/// A tensor shape expressed with `i64` dimensions, as expected by the tch
/// (libtorch) API.
#[derive(Debug)]
pub struct TchShape {
    /// The size of each dimension.
    pub dims: Vec<i64>,
}
265
266impl From<Shape> for TchShape {
267 fn from(shape: Shape) -> Self {
268 TchShape {
269 dims: shape.dims.into_iter().map(|d| d as i64).collect(),
270 }
271 }
272}
273
274impl From<&[usize]> for TchShape {
275 fn from(shape: &[usize]) -> Self {
276 TchShape {
277 dims: shape.iter().map(|d| *d as i64).collect(),
278 }
279 }
280}
281
282impl TchTensor {
283 pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
294 let shape_tch = TchShape::from(data.shape.as_slice());
295 let tensor =
296 tch::Tensor::from_slice(data.convert::<E>().as_slice::<E>().unwrap()).to(device);
297 let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);
298
299 Self::new(tensor)
300 }
301}
302
303impl TchTensor {
304 pub fn empty<E: tch::kind::Element>(shape: Shape, device: LibTorchDevice) -> Self {
315 let shape_tch = TchShape::from(shape);
316 let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));
317
318 Self::new(tensor)
319 }
320}
321
/// A quantized tensor for the tch backend.
#[derive(Clone, Debug)]
pub struct TchQTensor {
    /// The underlying (libtorch-quantized) tensor.
    pub qtensor: TchTensor,
    /// The quantization scheme used to produce it.
    pub scheme: QuantizationScheme,
}
330
331impl TchQTensor {
332 pub fn strategy(&self) -> QuantizationStrategy {
334 match &self.scheme {
335 QuantizationScheme::PerTensorAffine(dtype) => match dtype {
336 QuantizationType::QInt8 => {
337 let scale = self.qtensor.tensor.q_scale();
338 let offset = self.qtensor.tensor.q_zero_point();
339 QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
340 scale as f32,
341 offset as i8,
342 ))
343 }
344 },
345 QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
346 QuantizationType::QInt8 => {
347 let scale = self.qtensor.tensor.q_scale();
348 QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
349 scale as f32,
350 ))
351 }
352 },
353 }
354 }
355}
356
impl TensorMetadata for TchQTensor {
    /// Quantized tensors report a `QFloat` dtype carrying their scheme.
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    /// Logical shape, delegated to the underlying quantized tensor.
    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }
}
366
impl QTensorPrimitive for TchQTensor {
    /// The quantization scheme this tensor was quantized with.
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }
}
372
#[cfg(test)]
mod tests {
    use crate::LibTorch;

    use super::*;
    use burn_tensor::ops::QTensorOps;
    use burn_tensor::quantization::QuantizationParametersPrimitive;
    use burn_tensor::{Distribution, Tensor, TensorPrimitive};
    use rand::prelude::StdRng;
    use rand::SeedableRng;

    /// Round-trips random 1-D data through a tch tensor and back.
    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 1>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    /// Round-trips random 2-D data through a tch tensor and back.
    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 2>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    /// A reshape shares storage with its parent; the follow-up `add_scalar`
    /// must therefore NOT run in place, or it would corrupt `tensor_1`.
    #[test]
    fn should_not_update_inplace_after_reshape() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0);

        // tensor_1 must still hold the original values.
        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }

    /// Same as above, but for a slice view instead of a reshape.
    #[test]
    fn should_not_update_inplace_after_slice() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);

        // tensor_1 must still hold the original values.
        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }

    /// Quantizing with known parameters, the strategy read back from the
    /// libtorch tensor must reproduce the scale and zero point.
    #[test]
    fn should_support_qtensor_strategy() {
        let tensor =
            TchTensor::from_data::<f32>(TensorData::from([-1.8, -1.0, 0.0, 0.5]), tch::Device::Cpu);
        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
        let qparams = QuantizationParametersPrimitive::<LibTorch<f32, i8>> {
            scale: TchTensor::from_data::<f32>(TensorData::from([0.009_019_608]), tch::Device::Cpu),
            offset: Some(TchTensor::from_data::<i8>(
                TensorData::from([72]),
                tch::Device::Cpu,
            )),
        };
        let qtensor: TchQTensor = LibTorch::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
        );
    }
}