1use crate::prelude_dev::*;
2
3impl<R, T, B, D> TensorAny<R, T, B, D>
7where
8 D: DimAPI,
9 B: DeviceAPI<T>,
10 R: DataAPI<Data = B::Raw>,
11{
12 pub fn view(&self) -> TensorView<'_, T, B, D> {
14 let layout = self.layout().clone();
15 let data = self.data().as_ref();
16 let storage = Storage::new(data, self.device().clone());
17 unsafe { TensorBase::new_unchecked(storage, layout) }
18 }
19
20 pub fn view_mut(&mut self) -> TensorMut<'_, T, B, D>
22 where
23 R: DataMutAPI,
24 {
25 let device = self.device().clone();
26 let layout = self.layout().clone();
27 let data = self.data_mut().as_mut();
28 let storage = Storage::new(data, device);
29 unsafe { TensorBase::new_unchecked(storage, layout) }
30 }
31
32 pub fn into_cow<'a>(self) -> TensorCow<'a, T, B, D>
34 where
35 R: DataIntoCowAPI<'a>,
36 {
37 let (storage, layout) = self.into_raw_parts();
38 let (data, device) = storage.into_raw_parts();
39 let storage = Storage::new(data.into_cow(), device);
40 unsafe { TensorBase::new_unchecked(storage, layout) }
41 }
42
43 pub fn into_owned_keep_layout(self) -> Tensor<T, B, D>
55 where
56 R::Data: Clone,
57 R: DataCloneAPI,
58 {
59 let (storage, layout) = self.into_raw_parts();
60 let (data, device) = storage.into_raw_parts();
61 let storage = Storage::new(data.into_owned(), device);
62 unsafe { TensorBase::new_unchecked(storage, layout) }
63 }
64
65 pub fn into_shared_keep_layout(self) -> TensorArc<T, B, D>
77 where
78 R::Data: Clone,
79 R: DataCloneAPI,
80 {
81 let (storage, layout) = self.into_raw_parts();
82 let (data, device) = storage.into_raw_parts();
83 let storage = Storage::new(data.into_shared(), device);
84 unsafe { TensorBase::new_unchecked(storage, layout) }
85 }
86}
87
88impl<R, T, B, D> TensorAny<R, T, B, D>
89where
90 R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
91 R::Data: Clone,
92 D: DimAPI,
93 T: Clone,
94 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
95{
96 pub fn into_owned(self) -> Tensor<T, B, D> {
97 let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
98 if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
99 return self.into_owned_keep_layout();
100 } else {
101 return asarray((&self, TensorIterOrder::K));
102 }
103 }
104
105 pub fn into_shared(self) -> TensorArc<T, B, D> {
106 let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
107 if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
108 return self.into_shared_keep_layout();
109 } else {
110 return asarray((&self, TensorIterOrder::K)).into_shared();
111 }
112 }
113
114 pub fn to_owned(&self) -> Tensor<T, B, D> {
115 self.view().into_owned()
116 }
117}
118
119impl<T, B, D> Clone for Tensor<T, B, D>
120where
121 T: Clone,
122 D: DimAPI,
123 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
124 <B as DeviceRawAPI<T>>::Raw: Clone,
125{
126 fn clone(&self) -> Self {
127 self.to_owned()
128 }
129}
130
131impl<T, B, D> Clone for TensorCow<'_, T, B, D>
132where
133 T: Clone,
134 D: DimAPI,
135 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
136 <B as DeviceRawAPI<T>>::Raw: Clone,
137{
138 fn clone(&self) -> Self {
139 let tsr_owned = self.to_owned();
140 let (storage, layout) = tsr_owned.into_raw_parts();
141 let (data, device) = storage.into_raw_parts();
142 let data = data.into_cow();
143 let storage = Storage::new(data, device);
144 unsafe { TensorBase::new_unchecked(storage, layout) }
145 }
146}
147
148impl<R, T, B, D> TensorAny<R, T, B, D>
149where
150 R: DataAPI<Data = B::Raw> + DataForceMutAPI<B::Raw>,
151 B: DeviceAPI<T>,
152 D: DimAPI,
153{
154 pub unsafe fn force_mut(&self) -> TensorMut<'_, T, B, D> {
159 let layout = self.layout().clone();
160 let data = self.data().force_mut();
161 let storage = Storage::new(data, self.device().clone());
162 TensorBase::new_unchecked(storage, layout)
163 }
164}
165
166impl<R, T, B, D> TensorAny<R, T, B, D>
171where
172 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
173 T: Clone,
174 D: DimAPI,
175 B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
176{
177 pub fn to_raw_f(&self) -> Result<Vec<T>> {
178 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
179 let device = self.device();
180 let layout = self.layout().to_dim::<Ix1>()?;
181 let size = layout.size();
182 let mut new_storage = device.uninit_impl(size)?;
183 device.assign_uninit(new_storage.raw_mut(), &[size].c(), self.raw(), &layout)?;
184 let storage = unsafe { B::assume_init_impl(new_storage) }?;
185 let (data, _) = storage.into_raw_parts();
186 Ok(data.into_raw())
187 }
188
189 pub fn to_vec(&self) -> Vec<T> {
190 self.to_raw_f().rstsr_unwrap()
191 }
192}
193
194impl<T, B, D> Tensor<T, B, D>
195where
196 T: Clone,
197 D: DimAPI,
198 B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
199{
200 pub fn into_vec_f(self) -> Result<Vec<T>> {
201 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
202 let layout = self.layout();
203 let (idx_min, idx_max) = layout.bounds_index()?;
204 if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
205 let (storage, _) = self.into_raw_parts();
206 let (data, _) = storage.into_raw_parts();
207 return Ok(data.into_raw());
208 } else {
209 return self.to_raw_f();
210 }
211 }
212
213 pub fn into_vec(self) -> Vec<T> {
214 self.into_vec_f().rstsr_unwrap()
215 }
216}
217
218impl<R, T, B, D> TensorAny<R, T, B, D>
223where
224 R: DataCloneAPI<Data = B::Raw>,
225 B::Raw: Clone,
226 T: Clone,
227 D: DimAPI,
228 B: DeviceAPI<T>,
229{
230 pub fn to_scalar_f(&self) -> Result<T> {
231 let layout = self.layout();
232 rstsr_assert_eq!(layout.size(), 1, InvalidLayout)?;
233 let storage = self.storage();
234 let vec = storage.to_cpu_vec()?;
235 Ok(vec[0].clone())
236 }
237
238 pub fn to_scalar(&self) -> T {
239 self.to_scalar_f().rstsr_unwrap()
240 }
241}
242
243impl<R, T, B, D> TensorAny<R, T, B, D>
248where
249 R: DataAPI<Data = B::Raw>,
250 D: DimAPI,
251 B: DeviceAPI<T, Raw = Vec<T>>,
252{
253 pub fn as_ptr(&self) -> *const T {
254 unsafe { self.raw().as_ptr().add(self.layout().offset()) }
255 }
256
257 pub fn as_mut_ptr(&mut self) -> *mut T
258 where
259 R: DataMutAPI,
260 {
261 unsafe { self.raw_mut().as_mut_ptr().add(self.layout().offset()) }
262 }
263}
264
265pub trait TensorViewAPI
270where
271 Self::Dim: DimAPI,
272 Self::Backend: DeviceAPI<Self::Type>,
273{
274 type Type;
275 type Backend;
276 type Dim;
277 fn view(&self) -> TensorView<'_, Self::Type, Self::Backend, Self::Dim>;
279}
280
281impl<R, T, B, D> TensorViewAPI for TensorAny<R, T, B, D>
282where
283 D: DimAPI,
284 R: DataAPI<Data = B::Raw>,
285 B: DeviceAPI<T>,
286{
287 type Type = T;
288 type Backend = B;
289 type Dim = D;
290
291 fn view(&self) -> TensorView<'_, T, B, D> {
292 let data = self.data().as_ref();
293 let storage = Storage::new(data, self.device().clone());
294 let layout = self.layout().clone();
295 unsafe { TensorBase::new_unchecked(storage, layout) }
296 }
297}
298
299impl<R, T, B, D> TensorViewAPI for &TensorAny<R, T, B, D>
300where
301 D: DimAPI,
302 R: DataAPI<Data = B::Raw>,
303 B: DeviceAPI<T>,
304{
305 type Type = T;
306 type Backend = B;
307 type Dim = D;
308
309 fn view(&self) -> TensorView<'_, T, B, D> {
310 TensorAny::view(*self)
311 }
312}
313
314impl<R, T, B, D> TensorViewAPI for &mut TensorAny<R, T, B, D>
315where
316 D: DimAPI,
317 R: DataAPI<Data = B::Raw>,
318 B: DeviceAPI<T>,
319{
320 type Type = T;
321 type Backend = B;
322 type Dim = D;
323
324 fn view(&self) -> TensorView<'_, T, B, D> {
325 TensorAny::view(*self)
326 }
327}
328
329pub trait TensorViewMutAPI
330where
331 Self::Dim: DimAPI,
332 Self::Backend: DeviceAPI<Self::Type>,
333{
334 type Type;
335 type Backend;
336 type Dim;
337
338 fn view_mut(&mut self) -> TensorMut<'_, Self::Type, Self::Backend, Self::Dim>;
340}
341
342impl<R, T, B, D> TensorViewMutAPI for TensorAny<R, T, B, D>
343where
344 D: DimAPI,
345 R: DataMutAPI<Data = B::Raw>,
346 B: DeviceAPI<T>,
347{
348 type Type = T;
349 type Backend = B;
350 type Dim = D;
351
352 fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
353 let device = self.device().clone();
354 let layout = self.layout().clone();
355 let data = self.data_mut().as_mut();
356 let storage = Storage::new(data, device);
357 unsafe { TensorBase::new_unchecked(storage, layout) }
358 }
359}
360
361impl<R, T, B, D> TensorViewMutAPI for &mut TensorAny<R, T, B, D>
362where
363 D: DimAPI,
364 R: DataMutAPI<Data = B::Raw>,
365 B: DeviceAPI<T>,
366{
367 type Type = T;
368 type Backend = B;
369 type Dim = D;
370
371 fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
372 (*self).view_mut()
373 }
374}
375
376pub trait TensorIntoOwnedAPI<T, B, D>
377where
378 D: DimAPI,
379 B: DeviceAPI<T>,
380{
381 fn into_owned(self) -> Tensor<T, B, D>;
387}
388
389impl<R, T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorAny<R, T, B, D>
390where
391 R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
392 <B as DeviceRawAPI<T>>::Raw: Clone,
393 T: Clone,
394 D: DimAPI,
395 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
396{
397 fn into_owned(self) -> Tensor<T, B, D> {
398 TensorAny::into_owned(self)
399 }
400}
401
402pub trait TensorRefAPI<'l>: TensorViewAPI {}
407impl<'l, R, T, B, D> TensorRefAPI<'l> for &'l TensorAny<R, T, B, D>
408where
409 D: DimAPI,
410 R: DataAPI<Data = B::Raw>,
411 B: DeviceAPI<T>,
412 Self: TensorViewAPI,
413{
414}
415impl<'l, T, B, D> TensorRefAPI<'l> for TensorView<'l, T, B, D>
416where
417 D: DimAPI,
418 B: DeviceAPI<T>,
419 Self: TensorViewAPI,
420{
421}
422
423pub trait TensorRefMutAPI<'l>: TensorViewAPI {}
424impl<'l, R, T, B, D> TensorRefMutAPI<'l> for &mut TensorAny<R, T, B, D>
425where
426 D: DimAPI,
427 R: DataMutAPI<Data = B::Raw>,
428 B: DeviceAPI<T>,
429 Self: TensorViewMutAPI,
430{
431}
432impl<'l, T, B, D> TensorRefMutAPI<'l> for TensorMut<'l, T, B, D>
433where
434 D: DimAPI,
435 B: DeviceAPI<T>,
436 Self: TensorViewMutAPI,
437{
438}
439
#[cfg(test)]
mod test {
    use super::*;

    /// `into_cow` accepts mutable views, immutable views and owned tensors.
    /// For an owned tensor the buffer must be moved rather than copied.
    #[test]
    fn test_into_cow() {
        let mut a = arange(3);
        let ptr_a = a.raw().as_ptr();

        let view_mut = a.view_mut();
        let cow_from_mut = view_mut.into_cow();
        println!("{cow_from_mut:?}");

        let view_ref = a.view();
        let cow_from_ref = view_ref.into_cow();
        println!("{cow_from_ref:?}");

        let cow_from_owned = a.into_cow();
        println!("{cow_from_owned:?}");
        let ptr_cow = cow_from_owned.raw().as_ptr();
        assert_eq!(ptr_a, ptr_cow);
    }

    /// Timing experiment: scale each row in place through `force_mut`.
    #[test]
    #[ignore]
    fn test_force_mut() {
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let start = std::time::Instant::now();
            for i in 0..n {
                let row = a.slice(i);
                // SAFETY: each iteration writes only row `i`, and nothing
                // else reads or writes that row while `row_mut` is in use.
                let mut row_mut = unsafe { row.force_mut() };
                row_mut *= i as f64 / 2048.0;
            }
            println!("Elapsed time {:?}", start.elapsed());
        }
        println!("{a:16.10}");
    }

    /// Parallel variant of `test_force_mut`: each rayon task writes a distinct row.
    #[test]
    #[ignore]
    #[cfg(feature = "rayon")]
    fn test_force_mut_par() {
        use rayon::prelude::*;
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let start = std::time::Instant::now();
            (0..n).into_par_iter().for_each(|i| {
                let row = a.slice(i);
                // SAFETY: every task writes a different row, so mutable
                // accesses never overlap.
                let mut row_mut = unsafe { row.force_mut() };
                row_mut *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", start.elapsed());
        }
        println!("{a:16.10}");
    }
}