hpt_traits/
tensor.rs

1use hpt_common::{
2    layout::layout::Layout, shape::shape::Shape, strides::strides::Strides, utils::pointer::Pointer,
3};
4use hpt_types::{
5    dtype::TypeCommon,
6    into_scalar::Cast,
7    type_promote::{
8        FloatOutBinary, FloatOutBinaryPromote, FloatOutUnary, FloatOutUnaryPromote, NormalOut,
9        NormalOutPromote, NormalOutUnary,
10    },
11};
12use std::fmt::Debug;
13use std::fmt::Display;
14
15/// A trait for getting information of a Tensor
16pub trait TensorInfo<T> {
17    /// Returns a pointer to the tensor's first data.
18    #[track_caller]
19    fn ptr(&self) -> Pointer<T>;
20
21    /// Returns the size of the tensor based on the shape
22    #[track_caller]
23    fn size(&self) -> usize;
24
25    /// Returns the shape of the tensor.
26    #[track_caller]
27    fn shape(&self) -> &Shape;
28
29    /// Returns the strides of the tensor.
30    #[track_caller]
31    fn strides(&self) -> &Strides;
32
33    /// Returns the layout of the tensor. Layout contains shape and strides.
34    #[track_caller]
35    fn layout(&self) -> &Layout;
36    /// Returns the root tensor, if any.
37    ///
38    /// if the tensor is a view, it will return the root tensor. Otherwise, it will return None.
39    #[track_caller]
40    fn parent(&self) -> Option<Pointer<T>>;
41
42    /// Returns the number of dimensions of the tensor.
43    #[track_caller]
44    fn ndim(&self) -> usize;
45
46    /// Returns whether the tensor is contiguous in memory. View or transpose tensors are not contiguous.
47    #[track_caller]
48    fn is_contiguous(&self) -> bool;
49
50    /// Returns the data type memory size in bytes.
51    #[track_caller]
52    fn elsize() -> usize {
53        size_of::<T>()
54    }
55}
56
/// A trait for types that can be treated like a tensor by exposing their data
/// as a flat slice.
pub trait TensorLike<T>: Sized {
    /// Returns the tensor's data as a raw slice.
    ///
    /// # Note
    ///
    /// The slice is returned as-is, without regard to the tensor's shape or
    /// strides. Iterating over it for a view tensor may therefore yield
    /// unexpected results.
    fn as_raw(&self) -> &[T];

    /// Returns the tensor's data as a mutable raw slice.
    ///
    /// # Note
    ///
    /// The slice is returned as-is, without regard to the tensor's shape or
    /// strides. Iterating over or writing through it for a view tensor may
    /// therefore yield unexpected results.
    fn as_raw_mut(&mut self) -> &mut [T];

    /// Returns the size in bytes of a single element of type `T`.
    fn elsize() -> usize {
        // Fully qualified so the trait does not depend on `size_of` being in the
        // prelude (only true since Rust 1.80).
        std::mem::size_of::<T>()
    }
}
82
83/// Common bounds for primitive types
84pub trait CommonBounds
85where
86    <Self as TypeCommon>::Vec: Send + Sync + Copy,
87    Self: Sync
88        + Send
89        + Clone
90        + Copy
91        + TypeCommon
92        + 'static
93        + Display
94        + Debug
95        + Cast<Self>
96        + NormalOut<Self, Output = Self>
97        + FloatOutUnary
98        + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
99        + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
100        + FloatOutBinary<Self>
101        + NormalOut<
102            <Self as FloatOutBinary<Self>>::Output,
103            Output = <Self as FloatOutBinary<Self>>::Output,
104        >
105        + NormalOutUnary
106        + FloatOutUnaryPromote
107        + FloatOutBinaryPromote
108        + NormalOutPromote
109        + Cast<f64>,
110{
111}
112impl<T> CommonBounds for T
113where
114    <Self as TypeCommon>::Vec: Send + Sync + Copy,
115    Self: Sync
116        + Send
117        + Clone
118        + Copy
119        + TypeCommon
120        + 'static
121        + Display
122        + Debug
123        + Cast<Self>
124        + NormalOut<Self, Output = Self>
125        + FloatOutUnary
126        + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
127        + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
128        + FloatOutBinary<Self>
129        + FloatOutBinary<
130            <Self as FloatOutBinary<Self>>::Output,
131            Output = <Self as FloatOutBinary<Self>>::Output,
132        >
133        + NormalOut<
134            <Self as FloatOutBinary<Self>>::Output,
135            Output = <Self as FloatOutBinary<Self>>::Output,
136        >
137        + NormalOutUnary
138        + FloatOutUnaryPromote
139        + FloatOutBinaryPromote
140        + NormalOutPromote
141        + Cast<f64>,
142{
143}