rstsr_core/
tensorbase.rs

1use crate::prelude_dev::*;
2
/// Marker trait for tensor objects.
///
/// It declares no items; it is implemented for every [`TensorBase`] whose
/// dimension type implements [`DimAPI`].
pub trait TensorBaseAPI {}
4
5pub struct TensorBase<S, D>
6where
7    D: DimAPI,
8{
9    pub(crate) storage: S,
10    pub(crate) layout: Layout<D>,
11}
12
13pub type Tensor<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataOwned<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
14pub type TensorView<'a, T, B = DeviceCpu, D = IxD> =
15    TensorBase<Storage<DataRef<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
16pub type TensorViewMut<'a, T, B = DeviceCpu, D = IxD> =
17    TensorBase<Storage<DataMut<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
18pub type TensorCow<'a, T, B = DeviceCpu, D = IxD> =
19    TensorBase<Storage<DataCow<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
20pub type TensorArc<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataArc<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
21pub type TensorReference<'a, T, B = DeviceCpu, D = IxD> =
22    TensorBase<Storage<DataReference<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
23pub type TensorAny<R, T, B, D> = TensorBase<Storage<R, T, B>, D>;
24pub use TensorView as TensorRef;
25pub use TensorViewMut as TensorMut;
26
27impl<R, D> TensorBaseAPI for TensorBase<R, D> where D: DimAPI {}
28
29/// Basic definitions for tensor object.
30impl<S, D> TensorBase<S, D>
31where
32    D: DimAPI,
33{
34    /// Initialize tensor object.
35    ///
36    /// # Safety
37    ///
38    /// This function will not check whether data meets the standard of
39    /// [Storage<T, B>], or whether layout may exceed pointer bounds of data.
40    pub unsafe fn new_unchecked(storage: S, layout: Layout<D>) -> Self {
41        Self { storage, layout }
42    }
43
44    #[inline]
45    pub fn storage(&self) -> &S {
46        &self.storage
47    }
48
49    #[inline]
50    pub fn storage_mut(&mut self) -> &mut S {
51        &mut self.storage
52    }
53
54    pub fn layout(&self) -> &Layout<D> {
55        &self.layout
56    }
57
58    #[inline]
59    pub fn shape(&self) -> &D {
60        self.layout().shape()
61    }
62
63    #[inline]
64    pub fn stride(&self) -> &D::Stride {
65        self.layout().stride()
66    }
67
68    #[inline]
69    pub fn offset(&self) -> usize {
70        self.layout().offset()
71    }
72
73    #[inline]
74    pub fn ndim(&self) -> usize {
75        self.layout().ndim()
76    }
77
78    #[inline]
79    pub fn size(&self) -> usize {
80        self.layout().size()
81    }
82
83    #[inline]
84    pub fn into_data(self) -> S {
85        self.storage
86    }
87
88    #[inline]
89    pub fn into_raw_parts(self) -> (S, Layout<D>) {
90        (self.storage, self.layout)
91    }
92
93    #[inline]
94    pub fn c_contig(&self) -> bool {
95        self.layout().c_contig()
96    }
97
98    #[inline]
99    pub fn f_contig(&self) -> bool {
100        self.layout().f_contig()
101    }
102
103    #[inline]
104    pub fn c_prefer(&self) -> bool {
105        self.layout().c_prefer()
106    }
107
108    #[inline]
109    pub fn f_prefer(&self) -> bool {
110        self.layout().f_prefer()
111    }
112}
113
114impl<R, T, B, D> TensorAny<R, T, B, D>
115where
116    R: DataAPI<Data = B::Raw>,
117    D: DimAPI,
118    B: DeviceAPI<T>,
119{
120    pub fn new_f(storage: Storage<R, T, B>, layout: Layout<D>) -> Result<Self> {
121        // check stride sanity
122        layout.check_strides()?;
123
124        // check pointer exceed
125        let len_data = storage.len();
126        let (_, idx_max) = layout.bounds_index()?;
127        rstsr_pattern!(idx_max, ..=len_data, ValueOutOfRange)?;
128        return Ok(Self { storage, layout });
129    }
130
131    pub fn new(storage: Storage<R, T, B>, layout: Layout<D>) -> Self {
132        Self::new_f(storage, layout).rstsr_unwrap()
133    }
134
135    pub fn device(&self) -> &B {
136        self.storage().device()
137    }
138
139    pub fn data(&self) -> &R {
140        self.storage().data()
141    }
142
143    pub fn data_mut(&mut self) -> &mut R {
144        self.storage_mut().data_mut()
145    }
146
147    pub fn raw(&self) -> &B::Raw {
148        self.storage().data().raw()
149    }
150
151    pub fn raw_mut(&mut self) -> &mut B::Raw
152    where
153        R: DataMutAPI<Data = B::Raw>,
154    {
155        self.storage_mut().data_mut().raw_mut()
156    }
157}
158
159impl<T, B, D> TensorCow<'_, T, B, D>
160where
161    B: DeviceAPI<T>,
162    D: DimAPI,
163{
164    pub fn is_owned(&self) -> bool {
165        self.data().is_owned()
166    }
167
168    pub fn is_ref(&self) -> bool {
169        self.data().is_ref()
170    }
171}
172
173/* #region TensorReference */
174
175impl<T, B, D> TensorReference<'_, T, B, D>
176where
177    B: DeviceAPI<T>,
178    D: DimAPI,
179{
180    pub fn is_ref(&self) -> bool {
181        self.data().is_ref()
182    }
183
184    pub fn is_mut(&self) -> bool {
185        self.data().is_mut()
186    }
187}
188
189impl<'a, T, B, D> From<TensorView<'a, T, B, D>> for TensorReference<'a, T, B, D>
190where
191    B: DeviceAPI<T>,
192    D: DimAPI,
193{
194    fn from(tensor: TensorView<'a, T, B, D>) -> Self {
195        let (storage, layout) = tensor.into_raw_parts();
196        let (data, device) = storage.into_raw_parts();
197        let data = DataReference::Ref(data);
198        let storage = Storage::new(data, device);
199        TensorReference::new(storage, layout)
200    }
201}
202
203impl<'a, T, B, D> From<TensorViewMut<'a, T, B, D>> for TensorReference<'a, T, B, D>
204where
205    B: DeviceAPI<T>,
206    D: DimAPI,
207{
208    fn from(tensor: TensorViewMut<'a, T, B, D>) -> Self {
209        let (storage, layout) = tensor.into_raw_parts();
210        let (data, device) = storage.into_raw_parts();
211        let data = DataReference::Mut(data);
212        let storage = Storage::new(data, device);
213        TensorReference::new(storage, layout)
214    }
215}
216
217impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorView<'a, T, B, D>
218where
219    B: DeviceAPI<T>,
220    D: DimAPI,
221{
222    fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
223        let (storage, layout) = tensor.into_raw_parts();
224        let (data, device) = storage.into_raw_parts();
225        let data = match data {
226            DataReference::Ref(data) => data,
227            DataReference::Mut(_) => {
228                rstsr_raise!(RuntimeError, "cannot convert to TensorView if data is mutable").rstsr_unwrap()
229            },
230        };
231        let storage = Storage::new(data, device);
232        TensorView::new(storage, layout)
233    }
234}
235
236impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorMut<'a, T, B, D>
237where
238    B: DeviceAPI<T>,
239    D: DimAPI,
240{
241    fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
242        let (storage, layout) = tensor.into_raw_parts();
243        let (data, device) = storage.into_raw_parts();
244        let data = match data {
245            DataReference::Mut(data) => data,
246            DataReference::Ref(_) => {
247                rstsr_raise!(RuntimeError, "cannot convert to TensorMut if data is immutable").rstsr_unwrap()
248            },
249        };
250        let storage = Storage::new(data, device);
251        TensorViewMut::new(storage, layout)
252    }
253}
254
255/* #endregion */
256
257unsafe impl<R, D> Send for TensorBase<R, D>
258where
259    D: DimAPI,
260    R: Send,
261{
262}
263
264unsafe impl<R, D> Sync for TensorBase<R, D>
265where
266    D: DimAPI,
267    R: Sync,
268{
269}