burn_tch/
tensor.rs

use crate::{LibTorchDevice, TchElement};
use burn_tensor::{DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};
use libc::c_void;
use std::sync::Arc;

/// A reference to a tensor storage.
///
/// `Send` and `Sync` are implemented manually and unsafely, so even though an `Rc` would
/// otherwise suffice, it would not be safe to use here.
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;

/// The storage of a tensor.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// When a tensor is a partial view of another tensor.
    View {
        /// Storage reference for the whole buffer.
        buffer_ref: StorageRef,
        /// Storage reference for the partial buffer.
        view_ref: StorageRef,
    },
    /// When a tensor uses its entire buffer.
    Owned {
        /// Storage reference for the whole buffer.
        buffer_ref: StorageRef,
    },
}

impl Storage {
    /// Checks if the storage can be mutated in place.
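    ///
    /// A minimal sketch (not from the original docs): cloning the storage bumps the `Arc`
    /// strong count, so in-place mutation can no longer be assumed to be safe.
    ///
    /// ```ignore
    /// let storage = Storage::Owned {
    ///     buffer_ref: Arc::new(std::ptr::null_mut()),
    /// };
    /// assert!(storage.can_mut());
    ///
    /// let alias = storage.clone(); // two references to the same buffer
    /// assert!(!storage.can_mut());
    /// ```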
    pub fn can_mut(&self) -> bool {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
            Storage::Owned {
                buffer_ref: start_ref,
            } => Arc::strong_count(start_ref) == 1,
        }
    }

    /// Get the whole buffer reference.
    pub fn buffer_ref(&self) -> &StorageRef {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref: _,
            } => start_ref,
            Storage::Owned {
                buffer_ref: start_ref,
            } => start_ref,
        }
    }
}

/// A tensor using the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
    /// Handle to the tensor. Call methods on this field.
    pub tensor: tch::Tensor,

    /// The tensor's storage
    pub storage: Storage,
}

impl TensorMetadata for TchTensor {
    fn dtype(&self) -> DType {
        match self.tensor.kind() {
            tch::Kind::Uint8 => DType::U8,
            tch::Kind::Int8 => DType::I8,
            tch::Kind::Int16 => DType::I16,
            tch::Kind::Int => DType::I32,
            tch::Kind::Int64 => DType::I64,
            tch::Kind::Half => DType::F16,
            tch::Kind::Float => DType::F32,
            tch::Kind::Double => DType::F64,
            tch::Kind::Bool => DType::Bool,
            tch::Kind::BFloat16 => DType::BF16,
            // Complex and quantization types are not valid/implemented.
            _ => unimplemented!(),
        }
    }

    fn shape(&self) -> Shape {
        Shape::from(self.tensor.size())
    }

    fn rank(&self) -> usize {
        self.tensor.dim()
    }
}

impl burn_tensor::quantization::QTensorPrimitive for TchTensor {
    fn scheme(&self) -> &burn_tensor::quantization::QuantScheme {
        unimplemented!("Quantization is not supported")
    }
}

pub(crate) trait IntoKind {
    fn into_kind(self) -> tch::Kind;
}

impl IntoKind for FloatDType {
    fn into_kind(self) -> tch::Kind {
        match self {
            FloatDType::F64 => tch::Kind::Double,
            FloatDType::F32 => tch::Kind::Float,
            FloatDType::Flex32 => tch::Kind::Float,
            FloatDType::F16 => tch::Kind::Half,
            FloatDType::BF16 => tch::Kind::BFloat16,
        }
    }
}

impl IntoKind for IntDType {
    fn into_kind(self) -> tch::Kind {
        match self {
            IntDType::I64 => tch::Kind::Int64,
            IntDType::I32 => tch::Kind::Int,
            IntDType::I16 => tch::Kind::Int16,
            IntDType::I8 => tch::Kind::Int8,
            IntDType::U8 => tch::Kind::Uint8,
            other => panic!("Unsupported dtype {other:?}"),
        }
    }
}

impl TchTensor {
    /// Create a new tensor.
    ///
    /// Note that if the tensor was created from an operation that may reuse the same tensor
    /// storage as the parent, you should use [from_existing](TchTensor::from_existing)
    /// instead.
    pub fn new(tensor: tch::Tensor) -> Self {
        #[allow(clippy::arc_with_non_send_sync)]
        let storage = Storage::Owned {
            buffer_ref: Arc::new(tensor.data_ptr()),
        };

        Self { tensor, storage }
    }

    /// Create a tensor from the result of an operation executed on a parent tensor.
    ///
    /// If the child tensor shares the same storage as its parent, the parent's storage
    /// reference is cloned, effectively tracking how many tensors point to the same memory.
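    ///
    /// A minimal sketch (not from the original docs): reshaping a contiguous tensor returns a
    /// view over the same allocation, so the child ends up sharing the parent's storage.
    ///
    /// ```ignore
    /// let parent = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]));
    /// let child = TchTensor::from_existing(parent.tensor.reshape(vec![2, 2]), parent.storage.clone());
    /// ```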
    pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage_child = tensor.data_ptr();
        let mut is_a_new_tensor = true;

        match &storage_parent {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => {
                if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
            Storage::Owned {
                buffer_ref: start_ref,
            } => {
                if storage_child == *start_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
        };

        let storage = match is_a_new_tensor {
            true => Storage::Owned {
                #[allow(clippy::arc_with_non_send_sync)]
                buffer_ref: Arc::new(storage_child),
            },
            false => storage_parent.clone(),
        };

        Self { tensor, storage }
    }

    /// Create a tensor that uses only part of its parent tensor's buffer, such as the result
    /// of a slice or narrow operation.
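    ///
    /// A minimal sketch (not from the original docs), using `narrow` to view the first half of
    /// a 1-D tensor while keeping a reference to the parent's buffer:
    ///
    /// ```ignore
    /// let parent = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]));
    /// let view = TchTensor::partial(parent.tensor.narrow(0, 0, 2), parent.storage.clone());
    /// ```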
    pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage = Storage::View {
            buffer_ref: storage_parent.buffer_ref().clone(),
            #[allow(clippy::arc_with_non_send_sync)]
            view_ref: Arc::new(tensor.data_ptr()),
        };
        Self { tensor, storage }
    }
}

// This is safe since we don't use autodiff from LibTorch.
// Also, atomic reference counting is used to know if the tensor's data can be reused.
// If there are multiple references to the same tensor, it becomes read-only.
unsafe impl Send for TchTensor {}
unsafe impl Sync for TchTensor {}

impl TchTensor {
    /// Checks if the tensor can be mutated in-place.
    ///
    /// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
    /// and the storage can be mutated.
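    ///
    /// A minimal sketch (not from the original docs): a broadcast view produced by `expand`
    /// has a zero stride, so it must not be mutated in place.
    ///
    /// ```ignore
    /// let broadcast = TchTensor::new(tch::Tensor::from_slice(&[1.0f32]).expand(vec![4], true));
    /// assert!(!broadcast.can_mut());
    /// ```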
    pub fn can_mut(&self) -> bool {
        let stride_contains_zero = self.tensor.stride().contains(&0);

        !stride_contains_zero && self.storage.can_mut()
    }

    /// Executes an operation on a tensor if the data can be reused.
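    ///
    /// A minimal sketch (not from the original docs; `relu_` stands in for any in-place
    /// LibTorch op): the closure runs only when the storage is not shared.
    ///
    /// ```ignore
    /// let mut tensor = TchTensor::new(tch::Tensor::from_slice(&[-1.0f32, 2.0]));
    /// let output = tensor.mut_ops(|t| t.relu_());
    /// assert!(output.is_some());
    /// ```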
    pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
        &mut self,
        func: F,
    ) -> Option<TchTensor> {
        if !self.can_mut() {
            return None;
        }

        let data = self.storage.clone();
        Some(TchTensor::from_existing(func(&mut self.tensor), data))
    }

    /// Executes a unary operation, reusing the tensor data if possible.
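    ///
    /// A minimal sketch (not from the original docs; `relu`/`relu_` stand in for any unary
    /// LibTorch op): `fown` is used when the data can be reused, `fref` otherwise.
    ///
    /// ```ignore
    /// let tensor = TchTensor::new(tch::Tensor::from_slice(&[-1.0f32, 2.0]));
    /// let output = tensor.unary_ops(|mut t| t.relu_(), |t| t.relu());
    /// ```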
    pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
    where
        FOwn: Fn(tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor) -> tch::Tensor,
    {
        if !self.can_mut() {
            return TchTensor::from_existing(fref(&self.tensor), self.storage);
        }

        TchTensor::from_existing(fown(self.tensor), self.storage)
    }

    /// Executes a binary operation, reusing the tensor data if possible.
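    ///
    /// A minimal sketch (not from the original docs), wiring up element-wise addition with the
    /// fallible `tch` methods; the commutative in-place variants let either operand be reused:
    ///
    /// ```ignore
    /// let output = TchTensor::binary_ops_tensor(
    ///     lhs,
    ///     rhs,
    ///     |lhs, rhs| lhs.f_add_(rhs).unwrap(),
    ///     |lhs, rhs| rhs.f_add_(lhs).unwrap(),
    ///     |lhs, rhs| lhs.f_add(rhs).unwrap(),
    /// );
    /// ```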
    pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
        mut lhs: Self,
        mut rhs: Self,
        flmut: FLMut,
        frmut: FRMut,
        fref: FRef,
    ) -> TchTensor
    where
        FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
        FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
    {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();

        // Both lhs and rhs are expected to have the same rank
        let d_out = lhs_shape.num_dims();
        let mut out_shape = Shape::from(vec![1usize; d_out]);

        for i in 0..d_out {
            out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);
        }

        let num_elements_out = out_shape.num_elements();

        // Attempt to mutate lhs tensor
        if lhs_shape.num_elements() == num_elements_out
            && let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))
        {
            return output;
        }

        // Attempt to mutate rhs tensor
        if rhs_shape.num_elements() == num_elements_out
            && let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))
        {
            return output;
        }

        let storage = lhs.storage;
        let tensor = fref(&lhs.tensor, &rhs.tensor);

        TchTensor::from_existing(tensor, storage)
    }
}

impl Clone for TchTensor {
    fn clone(&self) -> Self {
        Self {
            tensor: self.tensor.shallow_clone(),
            storage: self.storage.clone(),
        }
    }
}

/// A shape that can be used by LibTorch.
#[derive(Debug)]
pub struct TchShape {
    /// The shape's dimensions.
    pub dims: Vec<i64>,
}

impl From<Shape> for TchShape {
    fn from(shape: Shape) -> Self {
        TchShape {
            dims: shape.into_iter().map(|d| d as i64).collect(),
        }
    }
}

impl From<&[usize]> for TchShape {
    fn from(shape: &[usize]) -> Self {
        TchShape {
            dims: shape.iter().map(|d| *d as i64).collect(),
        }
    }
}

impl TchTensor {
    /// Creates a new tensor from data and a device.
    ///
    /// # Arguments
    ///
    /// * `data` - The tensor's data.
    /// * `device` - The device on which the tensor will be allocated.
    ///
    /// # Returns
    ///
    /// A new tensor.
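    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs), assuming a CPU device:
    ///
    /// ```ignore
    /// let data = TensorData::from([[1.0f32, 2.0], [3.0, 4.0]]);
    /// let tensor = TchTensor::from_data::<f32>(data, tch::Device::Cpu);
    /// ```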
    pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
        let shape_tch = TchShape::from(data.shape.as_slice());
        let tensor = tch::Tensor::from_slice(data.as_slice::<E>().unwrap()).to(device);
        let tensor = tensor.reshape(shape_tch.dims).to_kind(E::KIND);

        Self::new(tensor)
    }
}

impl TchTensor {
    /// Creates an empty tensor from a shape and a device.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// A new empty tensor.
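    ///
    /// # Example
    ///
    /// A minimal sketch (not from the original docs), assuming the CPU device variant:
    ///
    /// ```ignore
    /// let tensor = TchTensor::empty::<f32>(Shape::new([2, 3]), LibTorchDevice::Cpu);
    /// ```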
    pub fn empty<E: tch::kind::Element>(shape: Shape, device: LibTorchDevice) -> Self {
        let shape_tch = TchShape::from(shape);
        let tensor = tch::Tensor::empty(shape_tch.dims, (E::KIND, device.into()));

        Self::new(tensor)
    }
}

#[cfg(test)]
mod tests {
    use crate::LibTorch;

    use super::*;
    use burn_tensor::{Distribution, Tensor, TensorPrimitive};
    use rand::SeedableRng;
    use rand::prelude::StdRng;

    #[test]
    fn should_support_into_and_from_data_1d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([3]),
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 1>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_support_into_and_from_data_2d() {
        let data_expected = TensorData::random::<f32, _, _>(
            Shape::new([2, 3]),
            Distribution::Default,
            &mut StdRng::from_os_rng(),
        );
        let tensor = TchTensor::from_data::<f32>(data_expected.clone(), tch::Device::Cpu);

        let data_actual =
            Tensor::<LibTorch<f32>, 2>::from_primitive(TensorPrimitive::Float(tensor)).into_data();

        assert_eq!(data_expected, data_actual);
    }

    #[test]
    fn should_not_update_inplace_after_reshape() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.reshape([1, 2]).add_scalar(2.0);

        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }

    #[test]
    fn should_not_update_inplace_after_slice() {
        let tensor_1 = Tensor::<LibTorch<f32>, 1>::from_floats([4.0, 4.0], &Default::default());
        let tensor_2 = tensor_1.clone();

        let tensor_3 = tensor_2.slice([0..2]).add_scalar(2.0);

        assert_ne!(
            tensor_3.to_data().as_slice::<f32>().unwrap(),
            tensor_1.to_data().as_slice::<f32>().unwrap()
        );
    }
}

unsafe extern "C" {
    /// Dummy function to get CUDA to link properly
    pub fn dummy_cuda_dependency();
}

#[used]
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];