1use crate::prelude_dev::*;
2
/// Marker trait implemented by every [`TensorBase`] instantiation.
///
/// It declares no methods; it only lets generic code name "some tensor type" as a bound.
pub trait TensorBaseAPI {}
4
5pub struct TensorBase<S, D>
6where
7 D: DimAPI,
8{
9 pub(crate) storage: S,
10 pub(crate) layout: Layout<D>,
11}
12
13pub type Tensor<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataOwned<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
14pub type TensorView<'a, T, B = DeviceCpu, D = IxD> =
15 TensorBase<Storage<DataRef<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
16pub type TensorViewMut<'a, T, B = DeviceCpu, D = IxD> =
17 TensorBase<Storage<DataMut<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
18pub type TensorCow<'a, T, B = DeviceCpu, D = IxD> =
19 TensorBase<Storage<DataCow<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
20pub type TensorArc<T, B = DeviceCpu, D = IxD> = TensorBase<Storage<DataArc<<B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
21pub type TensorReference<'a, T, B = DeviceCpu, D = IxD> =
22 TensorBase<Storage<DataReference<'a, <B as DeviceRawAPI<T>>::Raw>, T, B>, D>;
23pub type TensorAny<R, T, B, D> = TensorBase<Storage<R, T, B>, D>;
24pub use TensorView as TensorRef;
25pub use TensorViewMut as TensorMut;
26
27impl<R, D> TensorBaseAPI for TensorBase<R, D> where D: DimAPI {}
28
29impl<S, D> TensorBase<S, D>
31where
32 D: DimAPI,
33{
34 pub unsafe fn new_unchecked(storage: S, layout: Layout<D>) -> Self {
41 Self { storage, layout }
42 }
43
44 #[inline]
45 pub fn storage(&self) -> &S {
46 &self.storage
47 }
48
49 #[inline]
50 pub fn storage_mut(&mut self) -> &mut S {
51 &mut self.storage
52 }
53
54 pub fn layout(&self) -> &Layout<D> {
55 &self.layout
56 }
57
58 #[inline]
59 pub fn shape(&self) -> &D {
60 self.layout().shape()
61 }
62
63 #[inline]
64 pub fn stride(&self) -> &D::Stride {
65 self.layout().stride()
66 }
67
68 #[inline]
69 pub fn offset(&self) -> usize {
70 self.layout().offset()
71 }
72
73 #[inline]
74 pub fn ndim(&self) -> usize {
75 self.layout().ndim()
76 }
77
78 #[inline]
79 pub fn size(&self) -> usize {
80 self.layout().size()
81 }
82
83 #[inline]
84 pub fn into_data(self) -> S {
85 self.storage
86 }
87
88 #[inline]
89 pub fn into_raw_parts(self) -> (S, Layout<D>) {
90 (self.storage, self.layout)
91 }
92
93 #[inline]
94 pub fn c_contig(&self) -> bool {
95 self.layout().c_contig()
96 }
97
98 #[inline]
99 pub fn f_contig(&self) -> bool {
100 self.layout().f_contig()
101 }
102
103 #[inline]
104 pub fn c_prefer(&self) -> bool {
105 self.layout().c_prefer()
106 }
107
108 #[inline]
109 pub fn f_prefer(&self) -> bool {
110 self.layout().f_prefer()
111 }
112}
113
114impl<R, T, B, D> TensorAny<R, T, B, D>
115where
116 R: DataAPI<Data = B::Raw>,
117 D: DimAPI,
118 B: DeviceAPI<T>,
119{
120 pub fn new_f(storage: Storage<R, T, B>, layout: Layout<D>) -> Result<Self> {
121 layout.check_strides()?;
123
124 let len_data = storage.len();
126 let (_, idx_max) = layout.bounds_index()?;
127 rstsr_pattern!(idx_max, ..=len_data, ValueOutOfRange)?;
128 return Ok(Self { storage, layout });
129 }
130
131 pub fn new(storage: Storage<R, T, B>, layout: Layout<D>) -> Self {
132 Self::new_f(storage, layout).rstsr_unwrap()
133 }
134
135 pub fn device(&self) -> &B {
136 self.storage().device()
137 }
138
139 pub fn data(&self) -> &R {
140 self.storage().data()
141 }
142
143 pub fn data_mut(&mut self) -> &mut R {
144 self.storage_mut().data_mut()
145 }
146
147 pub fn raw(&self) -> &B::Raw {
148 self.storage().data().raw()
149 }
150
151 pub fn raw_mut(&mut self) -> &mut B::Raw
152 where
153 R: DataMutAPI<Data = B::Raw>,
154 {
155 self.storage_mut().data_mut().raw_mut()
156 }
157}
158
159impl<T, B, D> TensorCow<'_, T, B, D>
160where
161 B: DeviceAPI<T>,
162 D: DimAPI,
163{
164 pub fn is_owned(&self) -> bool {
165 self.data().is_owned()
166 }
167
168 pub fn is_ref(&self) -> bool {
169 self.data().is_ref()
170 }
171}
172
173impl<T, B, D> TensorReference<'_, T, B, D>
176where
177 B: DeviceAPI<T>,
178 D: DimAPI,
179{
180 pub fn is_ref(&self) -> bool {
181 self.data().is_ref()
182 }
183
184 pub fn is_mut(&self) -> bool {
185 self.data().is_mut()
186 }
187}
188
189impl<'a, T, B, D> From<TensorView<'a, T, B, D>> for TensorReference<'a, T, B, D>
190where
191 B: DeviceAPI<T>,
192 D: DimAPI,
193{
194 fn from(tensor: TensorView<'a, T, B, D>) -> Self {
195 let (storage, layout) = tensor.into_raw_parts();
196 let (data, device) = storage.into_raw_parts();
197 let data = DataReference::Ref(data);
198 let storage = Storage::new(data, device);
199 TensorReference::new(storage, layout)
200 }
201}
202
203impl<'a, T, B, D> From<TensorViewMut<'a, T, B, D>> for TensorReference<'a, T, B, D>
204where
205 B: DeviceAPI<T>,
206 D: DimAPI,
207{
208 fn from(tensor: TensorViewMut<'a, T, B, D>) -> Self {
209 let (storage, layout) = tensor.into_raw_parts();
210 let (data, device) = storage.into_raw_parts();
211 let data = DataReference::Mut(data);
212 let storage = Storage::new(data, device);
213 TensorReference::new(storage, layout)
214 }
215}
216
217impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorView<'a, T, B, D>
218where
219 B: DeviceAPI<T>,
220 D: DimAPI,
221{
222 fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
223 let (storage, layout) = tensor.into_raw_parts();
224 let (data, device) = storage.into_raw_parts();
225 let data = match data {
226 DataReference::Ref(data) => data,
227 DataReference::Mut(_) => {
228 rstsr_raise!(RuntimeError, "cannot convert to TensorView if data is mutable").rstsr_unwrap()
229 },
230 };
231 let storage = Storage::new(data, device);
232 TensorView::new(storage, layout)
233 }
234}
235
236impl<'a, T, B, D> From<TensorReference<'a, T, B, D>> for TensorMut<'a, T, B, D>
237where
238 B: DeviceAPI<T>,
239 D: DimAPI,
240{
241 fn from(tensor: TensorReference<'a, T, B, D>) -> Self {
242 let (storage, layout) = tensor.into_raw_parts();
243 let (data, device) = storage.into_raw_parts();
244 let data = match data {
245 DataReference::Mut(data) => data,
246 DataReference::Ref(_) => {
247 rstsr_raise!(RuntimeError, "cannot convert to TensorMut if data is immutable").rstsr_unwrap()
248 },
249 };
250 let storage = Storage::new(data, device);
251 TensorViewMut::new(storage, layout)
252 }
253}
254
255unsafe impl<R, D> Send for TensorBase<R, D>
258where
259 D: DimAPI,
260 R: Send,
261{
262}
263
264unsafe impl<R, D> Sync for TensorBase<R, D>
265where
266 D: DimAPI,
267 R: Sync,
268{
269}