// rstsr_core/tensorbase.rs
use crate::prelude_dev::*;
/// Marker trait implemented by all tensor base types in this crate.
pub trait TensorBaseAPI {}
/// Core tensor object: a storage back-end paired with an indexing layout.
///
/// - `S`: storage type (data container together with its device).
/// - `D`: dimensionality type implementing [`DimAPI`].
pub struct TensorBase<S, D>
where
    D: DimAPI,
{
    // Underlying data storage (device + data container).
    pub(crate) storage: S,
    // Layout (shape, stride, offset) describing how `storage` is indexed.
    pub(crate) layout: Layout<D>,
}
/// Owned tensor: the storage owns its underlying data.
pub type Tensor<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataOwned<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
/// Immutable view: the storage borrows data for lifetime `'a`.
pub type TensorView<'a, T, B = DeviceCpu, D = IxD> =
    TensorBase<Storage<DataRef<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
/// Mutable view: the storage mutably borrows data for lifetime `'a`.
pub type TensorViewMut<'a, T, B = DeviceCpu, D = IxD> =
    TensorBase<Storage<DataMut<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
/// Copy-on-write tensor: data is either owned or borrowed (see `is_owned`/`is_ref`).
pub type TensorCow<'a, T, B = DeviceCpu, D = IxD> =
    TensorBase<Storage<DataCow<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
/// Shared tensor: data held in a `DataArc` container (presumably `Arc`-backed — confirm).
pub type TensorArc<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataArc<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
/// Reference tensor: holds either an immutable or a mutable borrow (see `DataReference`).
pub type TensorReference<'a, T, B = DeviceCpu, D = IxD> =
    TensorBase<Storage<DataReference<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
/// Fully generic alias over any data container `R`.
pub type TensorAny<R, T, B, D> = TensorBase<Storage<R, T, B>, D>;
// Shorter names re-exported for API ergonomics.
pub use TensorView as TensorRef;
pub use TensorViewMut as TensorMut;
// Every `TensorBase`, regardless of storage, is a tensor in the `TensorBaseAPI` sense.
impl<R, D> TensorBaseAPI for TensorBase<R, D> where D: DimAPI {}
29/// Basic definitions for tensor object.
30impl<S, D> TensorBase<S, D>
31where
32    D: DimAPI,
33{
34    /// Initialize tensor object.
35    ///
36    /// # Safety
37    ///
38    /// This function will not check whether data meets the standard of
39    /// [Storage<T, B>], or whether layout may exceed pointer bounds of data.
40    pub unsafe fn new_unchecked(storage: S, layout: Layout<D>) -> Self {
41        Self { storage, layout }
42    }
43
44    #[inline]
45    pub fn storage(&self) -> &S {
46        &self.storage
47    }
48
49    #[inline]
50    pub fn storage_mut(&mut self) -> &mut S {
51        &mut self.storage
52    }
53
54    pub fn layout(&self) -> &Layout<D> {
55        &self.layout
56    }
57
58    #[inline]
59    pub fn shape(&self) -> &D {
60        self.layout().shape()
61    }
62
63    #[inline]
64    pub fn stride(&self) -> &D::Stride {
65        self.layout().stride()
66    }
67
68    #[inline]
69    pub fn offset(&self) -> usize {
70        self.layout().offset()
71    }
72
73    #[inline]
74    pub fn ndim(&self) -> usize {
75        self.layout().ndim()
76    }
77
78    #[inline]
79    pub fn size(&self) -> usize {
80        self.layout().size()
81    }
82
83    #[inline]
84    pub fn into_data(self) -> S {
85        self.storage
86    }
87
88    #[inline]
89    pub fn into_raw_parts(self) -> (S, Layout<D>) {
90        (self.storage, self.layout)
91    }
92
93    #[inline]
94    pub fn c_contig(&self) -> bool {
95        self.layout().c_contig()
96    }
97
98    #[inline]
99    pub fn f_contig(&self) -> bool {
100        self.layout().f_contig()
101    }
102
103    #[inline]
104    pub fn c_prefer(&self) -> bool {
105        self.layout().c_prefer()
106    }
107
108    #[inline]
109    pub fn f_prefer(&self) -> bool {
110        self.layout().f_prefer()
111    }
112}
113
114impl<R, T, B, D> TensorAny<R, T, B, D>
115where
116    R: DataAPI<Data = B::Raw>,
117    D: DimAPI,
118    B: DeviceAPI<T>,
119{
120    pub fn new_f(storage: Storage<R, T, B>, layout: Layout<D>) -> Result<Self> {
121        // check stride sanity
122        layout.check_strides(true)?;
123
124        // check pointer exceed
125        let len_data = storage.len();
126        let (_, idx_max) = layout.bounds_index()?;
127        rstsr_pattern!(idx_max, ..=len_data, ValueOutOfRange)?;
128        return Ok(Self { storage, layout });
129    }
130
131    pub fn new(storage: Storage<R, T, B>, layout: Layout<D>) -> Self {
132        Self::new_f(storage, layout).rstsr_unwrap()
133    }
134
135    pub fn device(&self) -> &B {
136        self.storage().device()
137    }
138
139    pub fn device_mut(&mut self) -> &mut B {
140        self.storage_mut().device_mut()
141    }
142
143    pub fn data(&self) -> &R {
144        self.storage().data()
145    }
146
147    pub fn data_mut(&mut self) -> &mut R {
148        self.storage_mut().data_mut()
149    }
150
151    pub fn raw(&self) -> &B::Raw {
152        self.storage().data().raw()
153    }
154
155    pub fn raw_mut(&mut self) -> &mut B::Raw
156    where
157        R: DataMutAPI<Data = B::Raw>,
158    {
159        self.storage_mut().data_mut().raw_mut()
160    }
161}
162
163impl<T, B, D> TensorCow<'_, T, B, D>
164where
165    B: DeviceAPI<T>,
166    D: DimAPI,
167{
168    pub fn is_owned(&self) -> bool {
169        self.data().is_owned()
170    }
171
172    pub fn is_ref(&self) -> bool {
173        self.data().is_ref()
174    }
175}
176
/* #region TensorReference */
179impl<T, B, D> TensorReference<'_, T, B, D>
180where
181    B: DeviceAPI<T>,
182    D: DimAPI,
183{
184    pub fn is_ref(&self) -> bool {
185        self.data().is_ref()
186    }
187
188    pub fn is_mut(&self) -> bool {
189        self.data().is_mut()
190    }
191}
192
193impl<'a, T, B, D> From<TensorView<'a, T, B, D>> for TensorReference<'a, T, B, D>
194where
195    B: DeviceAPI<T>,
196    D: DimAPI,
197{
198    fn from(tensor: TensorView<'a, T, B, D>) -> Self {
199        let (storage, layout) = tensor.into_raw_parts();
200        let (data, device) = storage.into_raw_parts();
201        let data = DataReference::Ref(data);
202        let storage = Storage::new(data, device);
203        TensorReference::new(storage, layout)
204    }
205}
206
207impl<'a, T, B, D> From<TensorViewMut<'a, T, B, D>> for TensorReference<'a, T, B, D>
208where
209    B: DeviceAPI<T>,
210    D: DimAPI,
211{
212    fn from(tensor: TensorViewMut<'a, T, B, D>) -> Self {
213        let (storage, layout) = tensor.into_raw_parts();
214        let (data, device) = storage.into_raw_parts();
215        let data = DataReference::Mut(data);
216        let storage = Storage::new(data, device);
217        TensorReference::new(storage, layout)
218    }
219}
220
221impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorView<'a, T, B, D>
222where
223    B: DeviceAPI<T>,
224    D: DimAPI,
225{
226    fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
227        let (storage, layout) = tensor.into_raw_parts();
228        let (data, device) = storage.into_raw_parts();
229        let data = match data {
230            DataReference::Ref(data) => data,
231            DataReference::Mut(_) => {
232                rstsr_raise!(RuntimeError, "cannot convert to TensorView if data is mutable").rstsr_unwrap()
233            },
234        };
235        let storage = Storage::new(data, device);
236        TensorView::new(storage, layout)
237    }
238}
239
240impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorMut<'a, T, B, D>
241where
242    B: DeviceAPI<T>,
243    D: DimAPI,
244{
245    fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
246        let (storage, layout) = tensor.into_raw_parts();
247        let (data, device) = storage.into_raw_parts();
248        let data = match data {
249            DataReference::Mut(data) => data,
250            DataReference::Ref(_) => {
251                rstsr_raise!(RuntimeError, "cannot convert to TensorMut if data is immutable").rstsr_unwrap()
252            },
253        };
254        let storage = Storage::new(data, device);
255        TensorViewMut::new(storage, layout)
256    }
257}
258
/* #endregion */
// SAFETY: `TensorBase` is just `storage` plus `layout`; sending it across
// threads is sound when the storage `R` is `Send`.
// NOTE(review): this also presumes `Layout<D>` is `Send` for every `D: DimAPI`
// — confirm, since the bound here only constrains `R`.
unsafe impl<R, D> Send for TensorBase<R, D>
where
    D: DimAPI,
    R: Send,
{
}
// SAFETY: shared references to `TensorBase` only expose `storage` and `layout`;
// sharing across threads is sound when the storage `R` is `Sync`.
// NOTE(review): this also presumes `Layout<D>` is `Sync` for every `D: DimAPI`
// — confirm, since the bound here only constrains `R`.
unsafe impl<R, D> Sync for TensorBase<R, D>
where
    D: DimAPI,
    R: Sync,
{
}