1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
use std::ptr::NonNull;

use crate::{
    ffi::{self, DataType, Device},
    utils::is_contiguous,
    ShapeAndStrides,
};

/// DLPack is a data structure that can be used to describe tensor data.
/// In this crate it is represented as a non-null raw pointer to a
/// [`ffi::DLManagedTensor`].
///
/// `NonNull` is `Copy` and has no destructor, so copying or dropping a
/// `DLPack` value never frees the underlying managed tensor; releasing it is
/// left to whichever code takes ownership of the pointer.
pub type DLPack = NonNull<ffi::DLManagedTensor>;

/// Infer [`DataType`] from a generic parameter.
///
/// The method takes no `self`, so generic code can obtain the data type for a
/// type parameter without having a value of that type. Presumably this is
/// implemented for tensor element types — confirm against the impls.
pub trait InferDtype {
    /// Return the [`DataType`] that corresponds to `Self`.
    fn infer_dtype() -> DataType;
}

/// Read access to a tensor's data pointer and metadata.
///
/// Implementors provide the raw accessors (`data_ptr`, `shape`, `strides`,
/// `ndim`, `device`, `dtype`, `byte_offset`). The helpers `num_elements`,
/// `data_size` and `is_contiguous` have default implementations derived from
/// `shape`, `strides` and `dtype`.
pub trait TensorView {
    /// Get untyped data ptr.
    fn data_ptr(&self) -> *mut std::ffi::c_void;
    /// Get shape as slice.
    fn shape(&self) -> &[i64];
    /// Get strides as slice. If strides is None, Tensor is assumed to be
    /// contiguous.
    fn strides(&self) -> Option<&[i64]>;
    /// Number of dimensions of the tensor.
    ///
    /// The default methods of this trait derive their results from `shape()`
    /// and never call `ndim()`. Implementations presumably should keep this
    /// equal to `self.shape().len()`.
    fn ndim(&self) -> usize;
    /// Device on which the tensor data resides.
    fn device(&self) -> Device;
    /// Element data type of the tensor.
    fn dtype(&self) -> DataType;
    /// Byte offset of the first element. This is assumed to follow the DLPack
    /// `byte_offset` convention (relative to `data_ptr`) — confirm in impls.
    fn byte_offset(&self) -> u64;

    /// Get the number of elements in the tensor.
    ///
    /// A zero-dimensional tensor (empty shape) has exactly one element. Any
    /// zero extent makes the tensor empty.
    ///
    /// # Panics
    ///
    /// Panics if any dimension is negative, or if the product of the
    /// dimensions overflows `i64` or does not fit in `usize`. Such shapes
    /// violate the tensor invariants. Silently wrapping them (as a plain `as`
    /// cast would) yields a bogus, possibly huge element count that later
    /// feeds memory-size computations.
    fn num_elements(&self) -> usize {
        // Local import keeps this working on editions where `TryFrom` is not
        // in the prelude.
        use std::convert::TryFrom;

        let shape = self.shape();
        assert!(
            shape.iter().all(|&dim| dim >= 0),
            "tensor shape must not contain negative dimensions: {:?}",
            shape
        );
        // An empty axis empties the whole tensor regardless of the other
        // extents. Checking it first avoids a spurious overflow panic for
        // shapes like [huge, huge, 0].
        if shape.contains(&0) {
            return 0;
        }
        let count = shape
            .iter()
            .try_fold(1i64, |acc, &dim| acc.checked_mul(dim))
            .expect("tensor element count overflows i64");
        usize::try_from(count).expect("tensor element count does not fit in usize")
    }

    /// For given DLTensor, the size of memory required to store the contents of
    /// data is calculated as follows:
    ///
    /// ```c
    /// static inline size_t GetDataSize(const DLTensor* t) {
    ///   size_t size = 1;
    ///   for (tvm_index_t i = 0; i < t->ndim; ++i) {
    ///     size *= t->shape[i];
    ///   }
    ///   size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
    ///   return size;
    /// }
    /// ```
    ///
    /// This returns `num_elements() * dtype().size()`. It assumes
    /// `DataType::size` returns the per-element byte count shown in the C
    /// snippet (confirm against its definition).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TensorView::num_elements`].
    fn data_size(&self) -> usize {
        self.num_elements() * self.dtype().size()
    }

    /// Return true if tensor is contiguous in memory in the order specified by
    /// memory format.
    ///
    /// A tensor without explicit strides (`strides()` returns `None`) is
    /// treated as contiguous. Otherwise the check is delegated to the crate's
    /// `utils::is_contiguous` helper.
    fn is_contiguous(&self) -> bool {
        match self.strides() {
            Some(strides) => is_contiguous(self.shape(), strides),
            None => true,
        }
    }
}

/// User should implement this trait for their tensor.
///
/// It supplies the raw pieces needed to describe the tensor: data pointer,
/// shape/strides, device, dtype and byte offset. Presumably the crate uses
/// them to build a DLPack tensor — confirm against the conversion impls.
pub trait ToTensor {
    /// Untyped pointer to the tensor's data.
    fn data_ptr(&self) -> *mut std::ffi::c_void;
    /// Shape and strides of the tensor.
    ///
    /// NOTE(review): this returns a `ShapeAndStrides`, not an `Option`. By
    /// analogy with `TensorView::strides`, absent strides (if
    /// `ShapeAndStrides` can represent that) presumably mean the tensor must
    /// be contiguous — confirm against the `ShapeAndStrides` definition.
    fn shape_and_strides(&self) -> ShapeAndStrides;
    /// Device on which the tensor data resides.
    fn device(&self) -> Device;
    /// Element data type of the tensor.
    fn dtype(&self) -> DataType;
    /// Byte offset of the first element. This is assumed to follow the DLPack
    /// `byte_offset` convention (relative to `data_ptr`) — confirm.
    fn byte_offset(&self) -> u64;
}

// TODO: add a fallible `try_to_dlpack` counterpart.
// That will probably require defining an error type first.
/// Convert into [`DLPack`](crate::DLPack)
pub trait IntoDLPack {
    /// Consume `self` and return a [`DLPack`] pointer for it.
    ///
    /// NOTE(review): presumably the returned managed tensor's deleter is
    /// responsible for keeping/releasing `self`'s storage — confirm in impls.
    fn into_dlpack(self) -> DLPack;
}

// TODO: add a fallible `try_from_dlpack` counterpart.
/// Make Tensor from [`DLPack`](crate::DLPack)
pub trait FromDLPack {
    // TODO: DLManagedTensor will be deprecated in the future (presumably in
    // favour of DLManagedTensorVersioned from DLPack v1.0 — confirm).
    /// Build `Self` from the given [`DLPack`] pointer.
    fn from_dlpack(dlpack: DLPack) -> Self;
}