burn_tch/tensor.rs

use crate::{LibTorchDevice, TchElement};
use burn_backend::{DType, FloatDType, IntDType, Shape, TensorData, TensorMetadata};
use libc::c_void;
use std::sync::Arc;

/// A reference to a tensor storage.
///
/// We implement `Sync` and `Send` manually (and unsafely), so even though an `Rc` could otherwise be used, it would not be safe.
#[allow(clippy::arc_with_non_send_sync)]
pub type StorageRef = Arc<*mut c_void>;

/// The storage of a tensor.
#[derive(PartialEq, Debug, Clone)]
pub enum Storage {
    /// When a tensor is a partial view of another tensor.
    View {
        /// Storage reference for the whole buffer.
        buffer_ref: StorageRef,
        /// Storage reference for the partial buffer.
        view_ref: StorageRef,
    },
    /// When a tensor uses its whole buffer.
    Owned {
        /// Storage reference for the whole buffer.
        buffer_ref: StorageRef,
    },
}

impl Storage {
    /// Checks if the storage can be mutated in place.
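    ///
    /// A minimal sketch of the counting rule (hypothetical, not from the original source; the
    /// null pointer only stands in for a real data pointer):
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// let buffer_ref: StorageRef = Arc::new(std::ptr::null_mut::<libc::c_void>());
    /// let storage = Storage::Owned { buffer_ref: buffer_ref.clone() };
    ///
    /// // Two strong references exist (`buffer_ref` and the one held by `storage`),
    /// // so in-place mutation is not allowed.
    /// assert!(!storage.can_mut());
    ///
    /// drop(buffer_ref);
    /// // Only the storage itself holds a reference now.
    /// assert!(storage.can_mut());
    /// ```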
    pub fn can_mut(&self) -> bool {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => Arc::strong_count(start_ref) == 1 && Arc::strong_count(view_ref) == 1,
            Storage::Owned {
                buffer_ref: start_ref,
            } => Arc::strong_count(start_ref) == 1,
        }
    }

    /// Get the whole buffer reference.
    pub fn buffer_ref(&self) -> &StorageRef {
        match self {
            Storage::View {
                buffer_ref: start_ref,
                view_ref: _,
            } => start_ref,
            Storage::Owned {
                buffer_ref: start_ref,
            } => start_ref,
        }
    }
}

/// A tensor using the tch backend.
#[derive(Debug, PartialEq)]
pub struct TchTensor {
    /// Handle to the tensor. Call methods on this field.
    pub tensor: tch::Tensor,

    /// The tensor's storage.
    pub storage: Storage,
}

impl TensorMetadata for TchTensor {
    fn dtype(&self) -> DType {
        match self.tensor.kind() {
            tch::Kind::Uint8 => DType::U8,
            tch::Kind::Int8 => DType::I8,
            tch::Kind::Int16 => DType::I16,
            tch::Kind::Int => DType::I32,
            tch::Kind::Int64 => DType::I64,
            tch::Kind::Half => DType::F16,
            tch::Kind::Float => DType::F32,
            tch::Kind::Double => DType::F64,
            tch::Kind::Bool => DType::Bool,
            tch::Kind::BFloat16 => DType::BF16,
            // Complex and quantization types are not valid/implemented.
            _ => unimplemented!(),
        }
    }

    fn shape(&self) -> Shape {
        Shape::from(self.tensor.size())
    }

    fn rank(&self) -> usize {
        self.tensor.dim()
    }
}

impl burn_backend::QTensorPrimitive for TchTensor {
    fn scheme(&self) -> &burn_backend::quantization::QuantScheme {
        unimplemented!("Quantization is not supported")
    }
}

impl core::fmt::Display for TchTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.tensor)
    }
}

pub(crate) trait IntoKind {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError>;
    fn into_kind(self) -> tch::Kind
    where
        Self: Sized,
    {
        self.try_into_kind().unwrap()
    }
}

impl IntoKind for IntDType {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
        let dtype: DType = self.into();
        dtype.try_into_kind()
    }
}

impl IntoKind for FloatDType {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
        let dtype: DType = self.into();
        dtype.try_into_kind()
    }
}

impl IntoKind for DType {
    fn try_into_kind(self) -> Result<tch::Kind, tch::TchError> {
        match self {
            DType::F64 => Ok(tch::Kind::Double),
            DType::F32 => Ok(tch::Kind::Float),
            DType::Flex32 => Ok(tch::Kind::Float),
            DType::F16 => Ok(tch::Kind::Half),
            DType::BF16 => Ok(tch::Kind::BFloat16),
            DType::I64 => Ok(tch::Kind::Int64),
            DType::I32 => Ok(tch::Kind::Int),
            DType::I16 => Ok(tch::Kind::Int16),
            DType::I8 => Ok(tch::Kind::Int8),
            DType::U8 => Ok(tch::Kind::Uint8),
            DType::Bool => Ok(tch::Kind::Bool),
            other => Err(tch::TchError::Kind(format!("Unsupported dtype {other:?}"))),
        }
    }
}

impl TchTensor {
    /// Creates a new tensor.
    ///
    /// Note that if the tensor was created from an operation that may reuse the same tensor
    /// storage as the parent, you should use [from_existing](TchTensor::from_existing)
    /// instead.
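    ///
    /// A minimal usage sketch (assumed, not from the original source; requires a working
    /// LibTorch installation):
    ///
    /// ```ignore
    /// let tensor = tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    /// let tensor = TchTensor::new(tensor);
    ///
    /// // A freshly wrapped tensor is the only owner of its buffer.
    /// assert!(tensor.can_mut());
    /// ```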
    pub fn new(tensor: tch::Tensor) -> Self {
        #[allow(clippy::arc_with_non_send_sync)]
        let storage = Storage::Owned {
            buffer_ref: Arc::new(tensor.data_ptr()),
        };

        Self { tensor, storage }
    }

    /// Creates a tensor from an operation executed on a parent tensor.
    ///
    /// If the child tensor shares the same storage as its parent, the parent storage is cloned,
    /// effectively tracking how many tensors point to the same memory space.
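    ///
    /// A hypothetical sketch of the intent (not from the original source): a child tensor that
    /// shares its parent's buffer reuses the parent storage instead of creating a new one.
    ///
    /// ```ignore
    /// let parent = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0]));
    /// let storage = parent.storage.clone();
    ///
    /// // `shallow_clone` points to the same buffer, so the parent storage is reused and
    /// // both tensors now count as references to the same memory.
    /// let child = TchTensor::from_existing(parent.tensor.shallow_clone(), storage);
    /// assert!(!child.can_mut());
    /// ```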
    pub fn from_existing(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage_child = tensor.data_ptr();
        let mut is_a_new_tensor = true;

        match &storage_parent {
            Storage::View {
                buffer_ref: start_ref,
                view_ref,
            } => {
                if storage_child == *start_ref.as_ref() || storage_child == *view_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
            Storage::Owned {
                buffer_ref: start_ref,
            } => {
                if storage_child == *start_ref.as_ref() {
                    is_a_new_tensor = false;
                }
            }
        };

        let storage = match is_a_new_tensor {
            true => Storage::Owned {
                #[allow(clippy::arc_with_non_send_sync)]
                buffer_ref: Arc::new(storage_child),
            },
            false => storage_parent.clone(),
        };

        Self { tensor, storage }
    }

    /// Creates a tensor that uses a part of its parent tensor, such as a slice or narrow.
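    ///
    /// A hypothetical sketch (not from the original source):
    ///
    /// ```ignore
    /// let parent = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]));
    /// let storage = parent.storage.clone();
    ///
    /// // `narrow` returns a view over part of the parent's buffer.
    /// let view = TchTensor::partial(parent.tensor.narrow(0, 1, 2), storage);
    /// ```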
    pub fn partial(tensor: tch::Tensor, storage_parent: Storage) -> Self {
        let storage = Storage::View {
            buffer_ref: storage_parent.buffer_ref().clone(),
            #[allow(clippy::arc_with_non_send_sync)]
            view_ref: Arc::new(tensor.data_ptr()),
        };
        Self { tensor, storage }
    }
}

// This is safe since we don't use autodiff from LibTorch.
// Also, atomic reference counting is used to know if the tensor's data can be reused.
// If there are multiple references to the same tensor, it becomes read-only.
unsafe impl Send for TchTensor {}
unsafe impl Sync for TchTensor {}

impl TchTensor {
    /// Checks if the tensor can be mutated in-place.
    ///
    /// Returns `true` if the tensor's stride does not contain zero (no broadcasting)
    /// and the storage can be mutated.
    pub fn can_mut(&self) -> bool {
        let stride_contains_zero = self.tensor.stride().contains(&0);

        !stride_contains_zero && self.storage.can_mut()
    }

    /// Executes an operation on the tensor if its data can be reused, returning `None` otherwise.
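    ///
    /// A minimal sketch (assumed usage, not from the original source):
    ///
    /// ```ignore
    /// let mut tensor = TchTensor::new(tch::Tensor::from_slice(&[1.0f32, 2.0]));
    /// let rhs = tch::Tensor::from_slice(&[1.0f32, 1.0]);
    ///
    /// // Succeeds because `tensor` is the only owner of its storage.
    /// let output = tensor.mut_ops(|t| t.f_add_(&rhs).unwrap());
    /// assert!(output.is_some());
    /// ```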
    pub fn mut_ops<F: Fn(&mut tch::Tensor) -> tch::Tensor>(
        &mut self,
        func: F,
    ) -> Option<TchTensor> {
        if !self.can_mut() {
            return None;
        }

        let data = self.storage.clone();
        Some(TchTensor::from_existing(func(&mut self.tensor), data))
    }

    /// Executes a unary operation, reusing the tensor data if possible.
    pub fn unary_ops<FOwn, FRef>(self, fown: FOwn, fref: FRef) -> TchTensor
    where
        FOwn: Fn(tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor) -> tch::Tensor,
    {
        if !self.can_mut() {
            return TchTensor::from_existing(fref(&self.tensor), self.storage);
        }

        TchTensor::from_existing(fown(self.tensor), self.storage)
    }

    /// Executes a binary operation, reusing the tensor data if possible.
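    ///
    /// A hypothetical sketch (not from the original source) showing the three closures for an
    /// element-wise addition, where `lhs` and `rhs` are assumed to be existing `TchTensor`s:
    ///
    /// ```ignore
    /// let output = TchTensor::binary_ops_tensor(
    ///     lhs,
    ///     rhs,
    ///     |lhs, rhs| lhs.f_add_(rhs).unwrap(), // mutate lhs in place
    ///     |lhs, rhs| rhs.f_add_(lhs).unwrap(), // mutate rhs in place
    ///     |lhs, rhs| lhs.f_add(rhs).unwrap(),  // allocate a new output
    /// );
    /// ```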
    pub fn binary_ops_tensor<FLMut, FRMut, FRef>(
        mut lhs: Self,
        mut rhs: Self,
        flmut: FLMut,
        frmut: FRMut,
        fref: FRef,
    ) -> TchTensor
    where
        FLMut: Fn(&mut tch::Tensor, &tch::Tensor) -> tch::Tensor,
        FRMut: Fn(&tch::Tensor, &mut tch::Tensor) -> tch::Tensor,
        FRef: Fn(&tch::Tensor, &tch::Tensor) -> tch::Tensor,
    {
        let lhs_shape = lhs.shape();
        let rhs_shape = rhs.shape();

        // Both lhs and rhs are expected to have the same rank
        let d_out = lhs_shape.num_dims();
        let mut out_shape = Shape::from(vec![1usize; d_out]);

        for i in 0..d_out {
            out_shape[i] = usize::max(lhs_shape[i], rhs_shape[i]);
        }

        let num_elements_out = out_shape.num_elements();

        // Attempt to mutate lhs tensor
        if lhs_shape.num_elements() == num_elements_out
            && let Some(output) = lhs.mut_ops(|lhs| flmut(lhs, &rhs.tensor))
        {
            return output;
        }

        // Attempt to mutate rhs tensor
        if rhs_shape.num_elements() == num_elements_out
            && let Some(output) = rhs.mut_ops(|rhs| frmut(&lhs.tensor, rhs))
        {
            return output;
        }

        let storage = lhs.storage;
        let tensor = fref(&lhs.tensor, &rhs.tensor);

        TchTensor::from_existing(tensor, storage)
    }
}

impl Clone for TchTensor {
    fn clone(&self) -> Self {
        Self {
            tensor: self.tensor.shallow_clone(),
            storage: self.storage.clone(),
        }
    }
}

/// A shape that can be used by LibTorch.
#[derive(Debug)]
pub struct TchShape {
    /// The shape's dimensions.
    pub dims: Vec<i64>,
}

impl From<Shape> for TchShape {
    fn from(shape: Shape) -> Self {
        TchShape {
            dims: shape.dims.into_iter().map(|d| d as i64).collect(),
        }
    }
}

impl From<&[usize]> for TchShape {
    fn from(shape: &[usize]) -> Self {
        TchShape {
            dims: shape.iter().map(|d| *d as i64).collect(),
        }
    }
}

impl TchTensor {
    /// Creates a new tensor from data and a device.
    ///
    /// # Arguments
    ///
    /// * `data` - The tensor's data.
    /// * `device` - The device on which the tensor will be allocated.
    ///
    /// # Returns
    ///
    /// A new tensor.
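    ///
    /// # Examples
    ///
    /// A minimal sketch (assumed usage, not from the original source):
    ///
    /// ```ignore
    /// let data = TensorData::from([1.0f32, 2.0, 3.0]);
    /// let tensor = TchTensor::from_data::<f32>(data, tch::Device::Cpu);
    /// ```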
    pub fn from_data<E: TchElement>(data: TensorData, device: tch::Device) -> Self {
        let shape_tch = TchShape::from(data.shape.as_slice());
        let tensor =
            tch::Tensor::from_data_size(&data.bytes, &shape_tch.dims, E::kind()).to(device);

        Self::new(tensor)
    }
}

impl TchTensor {
    /// Creates an empty tensor from a shape and a device.
    ///
    /// # Arguments
    ///
    /// * `shape` - The shape of the tensor.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// A new empty tensor.
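    ///
    /// # Examples
    ///
    /// A minimal sketch (assumed usage, not from the original source):
    ///
    /// ```ignore
    /// let tensor = TchTensor::empty::<f32>(Shape::from(vec![2usize, 3]), LibTorchDevice::Cpu);
    /// ```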
    pub fn empty<E: TchElement>(shape: Shape, device: LibTorchDevice) -> Self {
        let shape_tch = TchShape::from(shape);
        let tensor = tch::Tensor::empty(shape_tch.dims, (E::kind(), device.into()));

        Self::new(tensor)
    }
}

// Adapted from `tch` to use patched `T::kind()` instead of `T::KIND` which is incorrect for bf16.
// TODO: remove when fixed in `tch` release (https://github.com/LaurentMazare/tch-rs/pull/996).
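//
// A hypothetical usage sketch (not from the original source), for a 1-D tensor:
//
//     let values: Vec<f32> = Vec::try_from(&tch_tensor)?;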
impl<T: TchElement + Copy> TryFrom<&TchTensor> for Vec<T> {
    type Error = tch::TchError;
    fn try_from(tensor: &TchTensor) -> Result<Self, Self::Error> {
        let tensor = &tensor.tensor;
        let size = tensor.size();
        if size.len() != 1 {
            Err(tch::TchError::Convert(format!(
                "Attempting to convert a Tensor with {} dimensions to flat vector",
                size.len()
            )))?;
        }
        let numel = size[0] as usize;
        let mut vec = vec![T::ZERO; numel];
        // Adapted to use patched `T::kind()` instead
        // TODO: tensor.f_to_kind(T::KIND)?.f_copy_data(&mut vec, numel)?;
        f_copy_data(&mut tensor.f_to_kind(T::kind())?, &mut vec, numel)?;
        Ok(vec)
    }
}

unsafe fn ptr_to_string(ptr: *mut libc::c_char) -> Option<String> {
    if !ptr.is_null() {
        unsafe {
            let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
            libc::free(ptr as *mut libc::c_void);
            Some(str)
        }
    } else {
        None
    }
}

/// Copies `numel` elements from `tensor` to `dst`.
fn f_copy_data<T: TchElement>(
    tensor: &mut tch::Tensor,
    dst: &mut [T],
    numel: usize,
) -> Result<(), tch::TchError> {
    if T::kind() != tensor.f_kind()? {
        return Err(tch::TchError::Kind(format!(
            "incoherent elt kind, {:?} != {:?}",
            tensor.f_kind(),
            T::kind()
        )));
    }
    if dst.len() < numel {
        return Err(tch::TchError::Shape(format!("slice len < {numel}")));
    }

    unsafe {
        torch_sys::at_copy_data(
            tensor.as_mut_ptr(),
            dst.as_mut_ptr() as *const c_void,
            numel,
            T::kind().elt_size_in_bytes(),
        );
        match ptr_to_string(torch_sys::get_and_reset_last_err()) {
            None => Ok(()),
            Some(c_error) => Err(tch::TchError::Torch(c_error)),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_backend::ops::FloatTensorOps;
    use burn_backend::{Backend, quantization::QuantScheme, read_sync};

    type B = crate::LibTorch<f32>;

    #[test]
    fn should_have_bf16_kind() {
        let data = TensorData::from([4.0, 4.0]);
        let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
        let tensor_2 = B::float_cast(tensor_1, DType::BF16.into());

        assert_eq!(tensor_2.tensor.kind(), tch::Kind::BFloat16);

        let out = read_sync(B::float_into_data(tensor_2)).unwrap();

        out.assert_eq(&TensorData::from([4.0, 4.0]), false);
    }

    #[test]
    fn should_support_dtypes() {
        let device = Default::default();

        assert!(B::supports_dtype(&device, DType::F64));
        assert!(B::supports_dtype(&device, DType::F32));
        assert!(B::supports_dtype(&device, DType::Flex32));
        assert!(B::supports_dtype(&device, DType::F16));
        assert!(B::supports_dtype(&device, DType::BF16));
        assert!(B::supports_dtype(&device, DType::I64));
        assert!(B::supports_dtype(&device, DType::I32));
        assert!(B::supports_dtype(&device, DType::I16));
        assert!(B::supports_dtype(&device, DType::I8));
        assert!(B::supports_dtype(&device, DType::U8));
        assert!(B::supports_dtype(&device, DType::Bool));

        assert!(!B::supports_dtype(&device, DType::U64));
        assert!(!B::supports_dtype(&device, DType::U32));
        assert!(!B::supports_dtype(&device, DType::U16));
        assert!(!B::supports_dtype(
            &device,
            DType::QFloat(QuantScheme::default())
        ));
    }

    #[test]
    fn should_support_from_bf16() {
        let data = TensorData::from([[1.0], [1.]]).convert_dtype(DType::BF16);
        let tensor_1: TchTensor = B::float_from_data(data, &Default::default());
        let data = TensorData::from([[2.0], [2.]]).convert_dtype(DType::BF16);
        let tensor_2 = B::float_from_data(data, &Default::default());

        let tensor_3 = B::float_add(tensor_1, tensor_2);

        assert_eq!(tensor_3.tensor.kind(), tch::Kind::BFloat16);

        let out = read_sync(B::float_into_data(tensor_3)).unwrap();

        out.assert_eq(&TensorData::from([[3.0], [3.0]]), false);
    }
}

unsafe extern "C" {
    /// Dummy function to get CUDA to link properly.
    pub fn dummy_cuda_dependency();
}

#[used]
static INIT_ARRAY: [unsafe extern "C" fn(); 1] = [dummy_cuda_dependency];