burn_tch/tensor.rs

use crate::{LibTorchDevice, TchElement};
use burn_tensor::{
    quantization::{
        AffineQuantization, QTensorPrimitive, QuantizationScheme, QuantizationStrategy,
        QuantizationType, SymmetricQuantization,
    },
    DType, Shape, TensorData, TensorMetadata,
};
use libc::c_void;
use std::sync::Arc;

/// A reference to a tensor storage.
///
/// We manually (and unsafely) implement `Sync` and `Send`, so even if we could use `Rc`,
/// it wouldn't be safe.
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;

/// The storage of a tensor.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// When a tensor is a partial view of another tensor.
    View {
        /// Storage reference for the whole buffer.
        buffer_ref: StorageRef,
        /// Storage reference for the partial buffer.
        view_ref: StorageRef,
    },
    /// When a tensor uses all of its buffer.
    Owned {
        /// Storage reference for the whole buffer.
        buffer_ref: StorageRef,
    },
}

impl Storage {
    /// Check if the storage can be used in place.
    pub fn can_mut(&self) -> bool {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
            Storage::Owned {
                buffer_ref: start_ref,
            } => Arc::strong_count(start_ref) == 1,
        }
    }

    /// Get the whole buffer reference.
    pub fn buffer_ref(&self) -> &StorageRef {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref: _,
            } => start_ref,
            Storage::Owned {
                buffer_ref: start_ref,
            } => start_ref,
        }
    }
}
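
// A minimal sketch, not part of the original file, illustrating the reference
// counting behind `Storage::can_mut`: an `Owned` storage with a single `Arc`
// owner may be mutated in place, but once its buffer reference is cloned it
// becomes read-only. The module and test names are illustrative only.
#[cfg(test)]
mod storage_can_mut_sketch {
    use super::*;

    #[test]
    fn owned_storage_is_mutable_until_shared() {
        #[allow(clippy::arc_with_non_send_sync)]
        let storage = Storage::Owned {
            buffer_ref: Arc::new(std::ptr::null_mut()),
        };
        assert!(storage.can_mut());

        // Cloning the storage clones the inner `Arc`, so the strong count rises
        // above one and in-place mutation is no longer allowed.
        let shared = storage.clone();
        assert!(!storage.can_mut());
        assert!(!shared.can_mut());
    }
}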

/// A tensor using the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
    /// Handle to the tensor. Call methods on this field.
    pub tensor: tch::Tensor,

    /// The tensor's storage.
    pub storage: Storage,
}

impl TensorMetadata for TchTensor {
    fn dtype(&self) -> DType {
        match self.tensor.kind() {
            tch::Kind::Uint8 => DType::U8,
            tch::Kind::Int8 => DType::I8,
            tch::Kind::Int16 => DType::I16,
            tch::Kind::Int => DType::I32,
            tch::Kind::Int64 => DType::I64,
            tch::Kind::Half => DType::F16,
            tch::Kind::Float => DType::F32,
            tch::Kind::Double => DType::F64,
            tch::Kind::Bool => DType::Bool,
            tch::Kind::QUInt8 => DType::U8,
            tch::Kind::BFloat16 => DType::BF16,
            // Complex and remaining quantization types are not valid/implemented.
            _ => unimplemented!(),
        }
    }

    fn shape(&self) -> Shape {
        Shape::from(self.tensor.size())
    }
}
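
// A minimal sketch, not part of the original file, of the `TensorMetadata`
// implementation above: the dtype and shape are read directly from the
// underlying `tch::Tensor`. Like the tests at the bottom of this file, it
// assumes a working LibTorch installation.
#[cfg(test)]
mod metadata_sketch {
    use super::*;

    #[test]
    fn dtype_and_shape_mirror_the_tch_tensor() {
        let tensor = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]));

        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.shape().dims, vec![4]);
    }
}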

impl TchTensor {
    /// Create a new tensor.
    ///
    /// Note that if the tensor was created from an operation that may reuse the same tensor
    /// storage as the parent, you should use [from_existing](TchTensor::from_existing)
    /// instead.
    pub fn new(tensor: tch::Tensor) -> Self {
        #[allow(clippy::arc_with_non_send_sync)]
        let storage = Storage::Owned {
            buffer_ref: Arc::new(tensor.data_ptr()),
        };

        Self { tensor, storage }
    }

    /// Create a tensor resulting from an operation executed on a parent tensor.
    ///
    /// If the child tensor shares the same storage as its parent, the storage is cloned,
    /// effectively tracking how many tensors point to the same memory space.
    pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage_child = tensor.data_ptr();
        let mut is_a_new_tensor = true;

        match &storage_parent {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => {
                if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
            Storage::Owned {
                buffer_ref: start_ref,
            } => {
                if storage_child == *start_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
        };

        let storage = match is_a_new_tensor {
            true => Storage::Owned {
                #[allow(clippy::arc_with_non_send_sync)]
                buffer_ref: Arc::new(storage_child),
            },
            false => storage_parent.clone(),
        };

        Self { tensor, storage }
    }

    /// Create a tensor that uses a part of its parent tensor, such as the result of slice or
    /// narrow.
    pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage = Storage::View {
            buffer_ref: storage_parent.buffer_ref().clone(),
            #[allow(clippy::arc_with_non_send_sync)]
            view_ref: Arc::new(tensor.data_ptr()),
        };
        Self { tensor, storage }
    }
}
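
// A minimal sketch, not part of the original file, of the storage tracking
// performed by `from_existing`: a `shallow_clone` shares the parent's data
// pointer, so the parent storage is reused and the parent can no longer be
// mutated in place. The module and test names are illustrative only.
#[cfg(test)]
mod storage_tracking_sketch {
    use super::*;

    #[test]
    fn from_existing_detects_shared_storage() {
        let parent = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0]));
        let child =
            TchTensor::from_existing(parent.tensor.shallow_clone(), parent.storage.clone());

        // Same data pointer: the storages compare equal and the buffer is now shared,
        // so in-place mutation of the parent is disallowed.
        assert_eq!(parent.storage, child.storage);
        assert!(!parent.can_mut());
    }
}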

// This is safe since we don't use autodiff from LibTorch.
// Also, atomic reference counting is used to know if the tensor's data can be reused.
// If there are multiple references to the same tensor, it becomes read-only.
unsafe impl Send for TchTensor {}
unsafe impl Sync for TchTensor {}

impl TchTensor {
    /// Checks if the tensor can be mutated in-place.
    ///
    /// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
    /// and the storage can be mutated.
    pub fn can_mut(&self) -> bool {
        let stride_contains_zero = self.tensor.stride().iter().any(|&s| s == 0);

        !stride_contains_zero && self.storage.can_mut()
    }

    /// Executes an operation on a tensor if the data can be reused.
    pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
        &mut self,
        func: F,
    ) -> Option<TchTensor> {
        if !self.can_mut() {
            return None;
        }

        let data = self.storage.clone();
        Some(TchTensor::from_existing(func(&mut self.tensor), data))
    }

    /// Executes a unary operation, reusing the tensor data if possible.
    pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
    where
        FOwn: Fn(tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor) -> tch::Tensor,
    {
        if !self.can_mut() {
            return TchTensor::from_existing(fref(&self.tensor), self.storage);
        }

        TchTensor::from_existing(fown(self.tensor), self.storage)
    }

    /// Executes a binary operation, reusing the tensor data if possible.
    pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
        mut lhs: Self,
        mut rhs: Self,
        flmut: FLMut,
        frmut: FRMut,
        fref: FRef,
    ) -> TchTensor
    where
        FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
        FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
    {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();

        // Both lhs and rhs are expected to have the same rank.
        let d_out = lhs_shape.num_dims();
        let mut out_shape = Shape::from(vec![1usize; d_out]);

        for i in 0..d_out {
            out_shape.dims[i] = usize::max(lhs_shape.dims[i], rhs_shape.dims[i]);
        }

        let num_elements_out = out_shape.num_elements();

        // Attempt to mutate the lhs tensor.
        if lhs_shape.num_elements() == num_elements_out {
            if let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor)) {
                return output;
            }
        }

        // Attempt to mutate the rhs tensor.
        if rhs_shape.num_elements() == num_elements_out {
            if let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs)) {
                return output;
            }
        }

        let storage = lhs.storage;
        let tensor = fref(&lhs.tensor, &rhs.tensor);

        TchTensor::from_existing(tensor, storage)
    }
}
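
// A minimal sketch, not part of the original file, of `unary_ops`: the owned
// closure (here the in-place `neg_`) runs when the tensor uniquely owns its
// storage, while the by-reference closure (`neg`) is the fallback once the
// storage is shared. The module and test names are illustrative only.
#[cfg(test)]
mod unary_ops_sketch {
    use super::*;

    #[test]
    fn unary_ops_reuses_storage_when_possible() {
        let tensor = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0]));
        assert!(tensor.can_mut());

        // Uniquely owned: the in-place path is taken and the storage is reused.
        let negated = tensor.unary_ops(|mut t| t.neg_(), |t| t.neg());
        assert_eq!(negated.tensor.double_value(&[0]), -1.0);

        // A cloned tensor shares its storage, so only the by-reference path applies.
        let shared = negated.clone();
        assert!(!negated.can_mut());
        assert!(!shared.can_mut());
    }
}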

impl Clone for TchTensor {
    fn clone(&self) -> Self {
        Self {
            tensor: self.tensor.shallow_clone(),
            storage: self.storage.clone(),
        }
    }
}

/// A shape that can be used by LibTorch.
#[derive(Debug)]
pub struct TchShape {
    /// The shape's dimensions.
    pub dims: Vec<i64>,
}

impl From<Shape> for TchShape {
    fn from(shape: Shape) -> Self {
        TchShape {
            dims: shape.dims.into_iter().map(|d| d as i64).collect(),
        }
    }
}

impl From<&[usize]> for TchShape {
    fn from(shape: &[usize]) -> Self {
        TchShape {
            dims: shape.iter().map(|d| *d as i64).collect(),
        }
    }
}
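
// A minimal sketch, not part of the original file: converting a burn `Shape`
// into the `i64` dimensions expected by LibTorch.
#[cfg(test)]
mod tch_shape_sketch {
    use super::*;

    #[test]
    fn converts_usize_dims_to_i64() {
        let shape_tch = TchShape::from(Shape::new([2, 3]));

        assert_eq!(shape_tch.dims, vec![2i64, 3]);
    }
}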

impl TchTensor {
    /// Creates a new tensor from the given data on the given device.
    ///
    /// # Arguments
    ///
    /// * `data` - The tensor's data.
    /// * `device` - The device on which the tensor will be allocated.
    ///
    /// # Returns
    ///
    /// A new tensor.
    pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
        let shape_tch = TchShape::from(data.shape.as_slice());
        let tensor =
            tch::Tensor::from_slice(data.convert::<E>().as_slice::<E>().unwrap()).to(device);
        let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);

        Self::new(tensor)
    }
}
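
// A minimal sketch, not part of the original file, of `from_data`: the data is
// copied to the requested device and cast to the element kind, here `f32` on
// the CPU. Requires a working LibTorch installation.
#[cfg(test)]
mod from_data_sketch {
    use super::*;

    #[test]
    fn builds_a_cpu_float_tensor_from_data() {
        let tensor =
            TchTensor::from_data::<f32>(TensorData::from([1.0, 2.0, 3.0]), tch::Device::Cpu);

        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.shape().dims, vec![3]);
    }
}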

impl TchTensor {
    /// Creates an empty tensor from a shape and a device.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// A new empty tensor.
    pub fn empty<E: tch::kind::Element>(shape: Shape, device: LibTorchDevice) -> Self {
        let shape_tch = TchShape::from(shape);
        let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));

        Self::new(tensor)
    }
}
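
// A minimal sketch, not part of the original file, of `empty`: the tensor is
// allocated (uninitialized) with the requested shape and element kind on the
// given device. Requires a working LibTorch installation.
#[cfg(test)]
mod empty_sketch {
    use super::*;

    #[test]
    fn allocates_with_the_requested_shape() {
        let tensor = TchTensor::empty::<f32>(Shape::new([2, 3]), LibTorchDevice::Cpu);

        assert_eq!(tensor.dtype(), DType::F32);
        assert_eq!(tensor.shape().dims, vec![2, 3]);
    }
}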

/// A quantized tensor for the tch backend.
#[derive(Clone, Debug)]
pub struct TchQTensor {
    /// The quantized tensor.
    pub qtensor: TchTensor,
    /// The quantization scheme.
    pub scheme: QuantizationScheme,
}

impl TchQTensor {
    /// Returns the quantization strategy, including quantization parameters, for the given tensor.
    pub fn strategy(&self) -> QuantizationStrategy {
        match &self.scheme {
            QuantizationScheme::PerTensorAffine(dtype) => match dtype {
                QuantizationType::QInt8 => {
                    let scale = self.qtensor.tensor.q_scale();
                    let offset = self.qtensor.tensor.q_zero_point();
                    QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(
                        scale as f32,
                        offset as i8,
                    ))
                }
            },
            QuantizationScheme::PerTensorSymmetric(dtype) => match dtype {
                QuantizationType::QInt8 => {
                    let scale = self.qtensor.tensor.q_scale();
                    QuantizationStrategy::PerTensorSymmetricInt8(SymmetricQuantization::init(
                        scale as f32,
                    ))
                }
            },
        }
    }
}

impl TensorMetadata for TchQTensor {
    fn dtype(&self) -> DType {
        DType::QFloat(self.scheme)
    }

    fn shape(&self) -> Shape {
        self.qtensor.shape()
    }
}

impl QTensorPrimitive for TchQTensor {
    fn scheme(&self) -> &QuantizationScheme {
        &self.scheme
    }
}

#[cfg(test)]
mod tests {
    use crate::LibTorch;

    use super::*;
    use burn_tensor::ops::QTensorOps;
    use burn_tensor::quantization::QuantizationParametersPrimitive;
    use burn_tensor::{Distribution, Tensor, TensorPrimitive};
    use rand::prelude::StdRng;
    use rand::SeedableRng;

    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 1>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut StdRng::from_entropy(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 2>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_not_update_inplace_after_reshape() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0);

        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }

    #[test]
    fn should_not_update_inplace_after_slice() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);

        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }

    #[test]
    fn should_support_qtensor_strategy() {
        let tensor =
            TchTensor::from_data::<f32>(TensorData::from([-1.8, -1.0, 0.0, 0.5]), tch::Device::Cpu);
        let scheme = QuantizationScheme::PerTensorAffine(QuantizationType::QInt8);
        let qparams = QuantizationParametersPrimitive::<LibTorch<f32, i8>> {
            scale: TchTensor::from_data::<f32>(TensorData::from([0.009_019_608]), tch::Device::Cpu),
            offset: Some(TchTensor::from_data::<i8>(
                TensorData::from([72]),
                tch::Device::Cpu,
            )),
        };
        let qtensor: TchQTensor = LibTorch::quantize(tensor, &scheme, qparams);

        assert_eq!(qtensor.scheme(), &scheme);
        assert_eq!(
            qtensor.strategy(),
            QuantizationStrategy::PerTensorAffineInt8(AffineQuantization::init(0.009_019_608, 72))
        );
    }
}