1use crate::prelude_dev::*;
2
/// Marker trait for tensor types.
///
/// It declares no items; in this module it is implemented for every
/// `TensorBase<S, D>` whose dimension type `D` satisfies `DimAPI`.
pub trait TensorBaseAPI {}
4
5pub struct TensorBase<S, D>
6where
7 D: DimAPI,
8{
9 pub(crate) storage: S,
10 pub(crate) layout: Layout<D>,
11}
12
13pub type Tensor<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataOwned<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
14pub type TensorView<'a, T, B = DeviceCpu, D = IxD> =
15 TensorBase<Storage<DataRef<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
16pub type TensorViewMut<'a, T, B = DeviceCpu, D = IxD> =
17 TensorBase<Storage<DataMut<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
18pub type TensorCow<'a, T, B = DeviceCpu, D = IxD> =
19 TensorBase<Storage<DataCow<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
20pub type TensorArc<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataArc<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
21pub type TensorReference<'a, T, B = DeviceCpu, D = IxD> =
22 TensorBase<Storage<DataReference<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
23pub type TensorAny<R, T, B, D> = TensorBase<Storage<R, T, B>, D>;
24pub use TensorView as TensorRef;
25pub use TensorViewMut as TensorMut;
26
27impl<R, D> TensorBaseAPI for TensorBase<R, D> where D: DimAPI {}
28
29impl<S, D> TensorBase<S, D>
31where
32 D: DimAPI,
33{
34 pub unsafe fn new_unchecked(storage: S, layout: Layout<D>) -> Self {
41 Self { storage, layout }
42 }
43
44 #[inline]
45 pub fn storage(&self) -> &S {
46 &self.storage
47 }
48
49 #[inline]
50 pub fn storage_mut(&mut self) -> &mut S {
51 &mut self.storage
52 }
53
54 pub fn layout(&self) -> &Layout<D> {
55 &self.layout
56 }
57
58 #[inline]
59 pub fn shape(&self) -> &D {
60 self.layout().shape()
61 }
62
63 #[inline]
64 pub fn stride(&self) -> &D::Stride {
65 self.layout().stride()
66 }
67
68 #[inline]
69 pub fn offset(&self) -> usize {
70 self.layout().offset()
71 }
72
73 #[inline]
74 pub fn ndim(&self) -> usize {
75 self.layout().ndim()
76 }
77
78 #[inline]
79 pub fn size(&self) -> usize {
80 self.layout().size()
81 }
82
83 #[inline]
84 pub fn into_data(self) -> S {
85 self.storage
86 }
87
88 #[inline]
89 pub fn into_raw_parts(self) -> (S, Layout<D>) {
90 (self.storage, self.layout)
91 }
92
93 #[inline]
94 pub fn c_contig(&self) -> bool {
95 self.layout().c_contig()
96 }
97
98 #[inline]
99 pub fn f_contig(&self) -> bool {
100 self.layout().f_contig()
101 }
102
103 #[inline]
104 pub fn c_prefer(&self) -> bool {
105 self.layout().c_prefer()
106 }
107
108 #[inline]
109 pub fn f_prefer(&self) -> bool {
110 self.layout().f_prefer()
111 }
112}
113
114impl<R, T, B, D> TensorAny<R, T, B, D>
115where
116 R: DataAPI<Data = B::Raw>,
117 D: DimAPI,
118 B: DeviceAPI<T>,
119{
120 pub fn new_f(storage: Storage<R, T, B>, layout: Layout<D>) -> Result<Self> {
121 layout.check_strides(true)?;
123
124 let len_data = storage.len();
126 let (_, idx_max) = layout.bounds_index()?;
127 rstsr_pattern!(idx_max, ..=len_data, ValueOutOfRange)?;
128 return Ok(Self { storage, layout });
129 }
130
131 pub fn new(storage: Storage<R, T, B>, layout: Layout<D>) -> Self {
132 Self::new_f(storage, layout).rstsr_unwrap()
133 }
134
135 pub fn device(&self) -> &B {
136 self.storage().device()
137 }
138
139 pub fn device_mut(&mut self) -> &mut B {
140 self.storage_mut().device_mut()
141 }
142
143 pub fn data(&self) -> &R {
144 self.storage().data()
145 }
146
147 pub fn data_mut(&mut self) -> &mut R {
148 self.storage_mut().data_mut()
149 }
150
151 pub fn raw(&self) -> &B::Raw {
152 self.storage().data().raw()
153 }
154
155 pub fn raw_mut(&mut self) -> &mut B::Raw
156 where
157 R: DataMutAPI<Data = B::Raw>,
158 {
159 self.storage_mut().data_mut().raw_mut()
160 }
161}
162
163impl<T, B, D> TensorCow<'_, T, B, D>
164where
165 B: DeviceAPI<T>,
166 D: DimAPI,
167{
168 pub fn is_owned(&self) -> bool {
169 self.data().is_owned()
170 }
171
172 pub fn is_ref(&self) -> bool {
173 self.data().is_ref()
174 }
175}
176
177impl<T, B, D> TensorReference<'_, T, B, D>
180where
181 B: DeviceAPI<T>,
182 D: DimAPI,
183{
184 pub fn is_ref(&self) -> bool {
185 self.data().is_ref()
186 }
187
188 pub fn is_mut(&self) -> bool {
189 self.data().is_mut()
190 }
191}
192
193impl<'a, T, B, D> From<TensorView<'a, T, B, D>> for TensorReference<'a, T, B, D>
194where
195 B: DeviceAPI<T>,
196 D: DimAPI,
197{
198 fn from(tensor: TensorView<'a, T, B, D>) -> Self {
199 let (storage, layout) = tensor.into_raw_parts();
200 let (data, device) = storage.into_raw_parts();
201 let data = DataReference::Ref(data);
202 let storage = Storage::new(data, device);
203 TensorReference::new(storage, layout)
204 }
205}
206
207impl<'a, T, B, D> From<TensorViewMut<'a, T, B, D>> for TensorReference<'a, T, B, D>
208where
209 B: DeviceAPI<T>,
210 D: DimAPI,
211{
212 fn from(tensor: TensorViewMut<'a, T, B, D>) -> Self {
213 let (storage, layout) = tensor.into_raw_parts();
214 let (data, device) = storage.into_raw_parts();
215 let data = DataReference::Mut(data);
216 let storage = Storage::new(data, device);
217 TensorReference::new(storage, layout)
218 }
219}
220
221impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorView<'a, T, B, D>
222where
223 B: DeviceAPI<T>,
224 D: DimAPI,
225{
226 fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
227 let (storage, layout) = tensor.into_raw_parts();
228 let (data, device) = storage.into_raw_parts();
229 let data = match data {
230 DataReference::Ref(data) => data,
231 DataReference::Mut(_) => {
232 rstsr_raise!(RuntimeError, "cannot convert to TensorView if data is mutable").rstsr_unwrap()
233 },
234 };
235 let storage = Storage::new(data, device);
236 TensorView::new(storage, layout)
237 }
238}
239
240impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorMut<'a, T, B, D>
241where
242 B: DeviceAPI<T>,
243 D: DimAPI,
244{
245 fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
246 let (storage, layout) = tensor.into_raw_parts();
247 let (data, device) = storage.into_raw_parts();
248 let data = match data {
249 DataReference::Mut(data) => data,
250 DataReference::Ref(_) => {
251 rstsr_raise!(RuntimeError, "cannot convert to TensorMut if data is immutable").rstsr_unwrap()
252 },
253 };
254 let storage = Storage::new(data, device);
255 TensorViewMut::new(storage, layout)
256 }
257}
258
259unsafe impl<R, D> Send for TensorBase<R, D>
262where
263 D: DimAPI,
264 R: Send,
265{
266}
267
268unsafe impl<R, D> Sync for TensorBase<R, D>
269where
270 D: DimAPI,
271 R: Sync,
272{
273}