1use crate::prelude_dev::*;
2
3impl<R, T, B, D> TensorAny<R, T, B, D>
7where
8 D: DimAPI,
9 B: DeviceAPI<T>,
10 R: DataAPI<Data = B::Raw>,
11{
12 pub fn view(&self) -> TensorView<'_, T, B, D> {
14 let layout = self.layout().clone();
15 let data = self.data().as_ref();
16 let storage = Storage::new(data, self.device().clone());
17 unsafe { TensorBase::new_unchecked(storage, layout) }
18 }
19
20 pub fn view_mut(&mut self) -> TensorMut<'_, T, B, D>
22 where
23 R: DataMutAPI,
24 {
25 let device = self.device().clone();
26 let layout = self.layout().clone();
27 let data = self.data_mut().as_mut();
28 let storage = Storage::new(data, device);
29 unsafe { TensorBase::new_unchecked(storage, layout) }
30 }
31
32 pub fn into_cow<'a>(self) -> TensorCow<'a, T, B, D>
34 where
35 R: DataIntoCowAPI<'a>,
36 {
37 let (storage, layout) = self.into_raw_parts();
38 let (data, device) = storage.into_raw_parts();
39 let storage = Storage::new(data.into_cow(), device);
40 unsafe { TensorBase::new_unchecked(storage, layout) }
41 }
42
43 pub fn into_owned_keep_layout(self) -> Tensor<T, B, D>
55 where
56 R::Data: Clone,
57 R: DataCloneAPI,
58 {
59 let (storage, layout) = self.into_raw_parts();
60 let (data, device) = storage.into_raw_parts();
61 let storage = Storage::new(data.into_owned(), device);
62 unsafe { TensorBase::new_unchecked(storage, layout) }
63 }
64
65 pub fn into_shared_keep_layout(self) -> TensorArc<T, B, D>
77 where
78 R::Data: Clone,
79 R: DataCloneAPI,
80 {
81 let (storage, layout) = self.into_raw_parts();
82 let (data, device) = storage.into_raw_parts();
83 let storage = Storage::new(data.into_shared(), device);
84 unsafe { TensorBase::new_unchecked(storage, layout) }
85 }
86}
87
88impl<R, T, B, D> TensorAny<R, T, B, D>
89where
90 R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
91 R::Data: Clone,
92 D: DimAPI,
93 T: Clone,
94 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
95{
96 pub fn into_owned(self) -> Tensor<T, B, D> {
97 let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
98 if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
99 return self.into_owned_keep_layout();
100 } else {
101 return asarray((&self, TensorIterOrder::K));
102 }
103 }
104
105 pub fn into_shared(self) -> TensorArc<T, B, D> {
106 let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
107 if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
108 return self.into_shared_keep_layout();
109 } else {
110 return asarray((&self, TensorIterOrder::K)).into_shared();
111 }
112 }
113
114 pub fn to_owned(&self) -> Tensor<T, B, D> {
115 self.view().into_owned()
116 }
117}
118
119impl<T, B, D> Clone for Tensor<T, B, D>
120where
121 T: Clone,
122 D: DimAPI,
123 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
124 <B as DeviceRawAPI<T>>::Raw: Clone,
125{
126 fn clone(&self) -> Self {
127 self.to_owned()
128 }
129}
130
131impl<T, B, D> Clone for TensorCow<'_, T, B, D>
132where
133 T: Clone,
134 D: DimAPI,
135 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
136 <B as DeviceRawAPI<T>>::Raw: Clone,
137{
138 fn clone(&self) -> Self {
139 let tsr_owned = self.to_owned();
140 let (storage, layout) = tsr_owned.into_raw_parts();
141 let (data, device) = storage.into_raw_parts();
142 let data = data.into_cow();
143 let storage = Storage::new(data, device);
144 unsafe { TensorBase::new_unchecked(storage, layout) }
145 }
146}
147
148impl<R, T, B, D> TensorAny<R, T, B, D>
149where
150 R: DataAPI<Data = B::Raw> + DataForceMutAPI<B::Raw>,
151 B: DeviceAPI<T>,
152 D: DimAPI,
153{
154 pub unsafe fn force_mut(&self) -> TensorMut<'_, T, B, D> {
159 let layout = self.layout().clone();
160 let data = self.data().force_mut();
161 let storage = Storage::new(data, self.device().clone());
162 TensorBase::new_unchecked(storage, layout)
163 }
164}
165
166impl<R, T, B, D> TensorAny<R, T, B, D>
171where
172 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
173 T: Clone,
174 D: DimAPI,
175 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
176{
177 pub fn to_raw_f(&self) -> Result<<B as DeviceRawAPI<T>>::Raw> {
178 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
179 let device = self.device();
180 let layout = self.layout().to_dim::<Ix1>()?;
181 let size = layout.size();
182 let mut new_storage = device.uninit_impl(size)?;
183 device.assign_uninit(new_storage.raw_mut(), &[size].c(), self.raw(), &layout)?;
184 let storage = unsafe { B::assume_init_impl(new_storage) }?;
185 let (data, _) = storage.into_raw_parts();
186 Ok(data.into_raw())
187 }
188
189 pub fn to_vec(&self) -> <B as DeviceRawAPI<T>>::Raw {
190 self.to_raw_f().rstsr_unwrap()
191 }
192}
193
194impl<T, B, D> Tensor<T, B, D>
195where
196 T: Clone,
197 D: DimAPI,
198 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
199{
200 pub fn into_raw_f(self) -> Result<<B as DeviceRawAPI<T>>::Raw> {
201 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
202 let layout = self.layout();
203 let (idx_min, idx_max) = layout.bounds_index()?;
204 if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
205 let (storage, _) = self.into_raw_parts();
206 let (data, _) = storage.into_raw_parts();
207 return Ok(data.into_raw());
208 } else {
209 return self.to_raw_f();
210 }
211 }
212
213 pub fn into_raw(self) -> <B as DeviceRawAPI<T>>::Raw {
214 self.into_raw_f().rstsr_unwrap()
215 }
216}
217
218impl<T, B, D> Tensor<T, B, D>
219where
220 T: Clone,
221 D: DimAPI,
222 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
223 <B as DeviceRawAPI<T>>::Raw: Clone,
224{
225 pub fn into_vec_f(self) -> Result<Vec<T>> {
226 rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
227 let layout = self.layout();
228 let (idx_min, idx_max) = layout.bounds_index()?;
229 if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
230 let (storage, _) = self.into_raw_parts();
231 storage.into_cpu_vec()
232 } else {
233 let data = self.to_raw_f()?;
234 let storage = Storage::new(DataOwned::from(data), self.device().clone());
235 storage.into_cpu_vec()
236 }
237 }
238
239 pub fn into_vec(self) -> Vec<T> {
240 self.into_vec_f().rstsr_unwrap()
241 }
242}
243
244impl<R, T, B, D> TensorAny<R, T, B, D>
249where
250 R: DataCloneAPI<Data = B::Raw>,
251 B::Raw: Clone,
252 T: Clone,
253 D: DimAPI,
254 B: DeviceAPI<T>,
255{
256 pub fn to_scalar_f(&self) -> Result<T> {
257 let layout = self.layout();
258 rstsr_assert_eq!(layout.size(), 1, InvalidLayout)?;
259 let storage = self.storage();
260 let vec = storage.to_cpu_vec()?;
261 Ok(vec[0].clone())
262 }
263
264 pub fn to_scalar(&self) -> T {
265 self.to_scalar_f().rstsr_unwrap()
266 }
267}
268
269impl<R, T, B, D> TensorAny<R, T, B, D>
274where
275 R: DataAPI<Data = B::Raw>,
276 D: DimAPI,
277 B: DeviceAPI<T, Raw = Vec<T>>,
278{
279 pub fn as_ptr(&self) -> *const T {
280 unsafe { self.raw().as_ptr().add(self.layout().offset()) }
281 }
282
283 pub fn as_mut_ptr(&mut self) -> *mut T
284 where
285 R: DataMutAPI,
286 {
287 unsafe { self.raw_mut().as_mut_ptr().add(self.layout().offset()) }
288 }
289}
290
291pub trait TensorViewAPI
296where
297 Self::Dim: DimAPI,
298 Self::Backend: DeviceAPI<Self::Type>,
299{
300 type Type;
301 type Backend;
302 type Dim;
303 fn view(&self) -> TensorView<'_, Self::Type, Self::Backend, Self::Dim>;
305}
306
307impl<R, T, B, D> TensorViewAPI for TensorAny<R, T, B, D>
308where
309 D: DimAPI,
310 R: DataAPI<Data = B::Raw>,
311 B: DeviceAPI<T>,
312{
313 type Type = T;
314 type Backend = B;
315 type Dim = D;
316
317 fn view(&self) -> TensorView<'_, T, B, D> {
318 let data = self.data().as_ref();
319 let storage = Storage::new(data, self.device().clone());
320 let layout = self.layout().clone();
321 unsafe { TensorBase::new_unchecked(storage, layout) }
322 }
323}
324
325impl<R, T, B, D> TensorViewAPI for &TensorAny<R, T, B, D>
326where
327 D: DimAPI,
328 R: DataAPI<Data = B::Raw>,
329 B: DeviceAPI<T>,
330{
331 type Type = T;
332 type Backend = B;
333 type Dim = D;
334
335 fn view(&self) -> TensorView<'_, T, B, D> {
336 TensorAny::view(*self)
337 }
338}
339
340impl<R, T, B, D> TensorViewAPI for &mut TensorAny<R, T, B, D>
341where
342 D: DimAPI,
343 R: DataAPI<Data = B::Raw>,
344 B: DeviceAPI<T>,
345{
346 type Type = T;
347 type Backend = B;
348 type Dim = D;
349
350 fn view(&self) -> TensorView<'_, T, B, D> {
351 TensorAny::view(*self)
352 }
353}
354
355pub trait TensorViewMutAPI
356where
357 Self::Dim: DimAPI,
358 Self::Backend: DeviceAPI<Self::Type>,
359{
360 type Type;
361 type Backend;
362 type Dim;
363
364 fn view_mut(&mut self) -> TensorMut<'_, Self::Type, Self::Backend, Self::Dim>;
366}
367
368impl<R, T, B, D> TensorViewMutAPI for TensorAny<R, T, B, D>
369where
370 D: DimAPI,
371 R: DataMutAPI<Data = B::Raw>,
372 B: DeviceAPI<T>,
373{
374 type Type = T;
375 type Backend = B;
376 type Dim = D;
377
378 fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
379 let device = self.device().clone();
380 let layout = self.layout().clone();
381 let data = self.data_mut().as_mut();
382 let storage = Storage::new(data, device);
383 unsafe { TensorBase::new_unchecked(storage, layout) }
384 }
385}
386
387impl<R, T, B, D> TensorViewMutAPI for &mut TensorAny<R, T, B, D>
388where
389 D: DimAPI,
390 R: DataMutAPI<Data = B::Raw>,
391 B: DeviceAPI<T>,
392{
393 type Type = T;
394 type Backend = B;
395 type Dim = D;
396
397 fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
398 (*self).view_mut()
399 }
400}
401
402pub trait TensorIntoOwnedAPI<T, B, D>
403where
404 D: DimAPI,
405 B: DeviceAPI<T>,
406{
407 fn into_owned(self) -> Tensor<T, B, D>;
413}
414
415impl<R, T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorAny<R, T, B, D>
416where
417 R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
418 <B as DeviceRawAPI<T>>::Raw: Clone,
419 T: Clone,
420 D: DimAPI,
421 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
422{
423 fn into_owned(self) -> Tensor<T, B, D> {
424 TensorAny::into_owned(self)
425 }
426}
427
428pub trait TensorRefAPI<'l>: TensorViewAPI {}
433impl<'l, R, T, B, D> TensorRefAPI<'l> for &'l TensorAny<R, T, B, D>
434where
435 D: DimAPI,
436 R: DataAPI<Data = B::Raw>,
437 B: DeviceAPI<T>,
438 Self: TensorViewAPI,
439{
440}
441impl<'l, T, B, D> TensorRefAPI<'l> for TensorView<'l, T, B, D>
442where
443 D: DimAPI,
444 B: DeviceAPI<T>,
445 Self: TensorViewAPI,
446{
447}
448
449pub trait TensorRefMutAPI<'l>: TensorViewAPI {}
450impl<'l, R, T, B, D> TensorRefMutAPI<'l> for &mut TensorAny<R, T, B, D>
451where
452 D: DimAPI,
453 R: DataMutAPI<Data = B::Raw>,
454 B: DeviceAPI<T>,
455 Self: TensorViewMutAPI,
456{
457}
458impl<'l, T, B, D> TensorRefMutAPI<'l> for TensorMut<'l, T, B, D>
459where
460 D: DimAPI,
461 B: DeviceAPI<T>,
462 Self: TensorViewMutAPI,
463{
464}
465
#[cfg(test)]
mod test {
    use super::*;

    /// Mutable views, immutable views and owned tensors can all be turned into
    /// copy-on-write tensors; the owned conversion must keep the same buffer.
    #[test]
    fn test_into_cow() {
        let mut a = arange(3);
        let buffer_ptr = a.raw().as_ptr();

        // mutable view -> cow
        let a_cow = a.view_mut().into_cow();
        println!("{a_cow:?}");

        // immutable view -> cow
        let a_cow = a.view().into_cow();
        println!("{a_cow:?}");

        // owned tensor -> cow: the buffer is expected to be moved, not copied
        let a_cow = a.into_cow();
        println!("{a_cow:?}");
        assert_eq!(buffer_ptr, a_cow.raw().as_ptr());
    }

    /// Timing check: scales each row in place through `force_mut` on a shared
    /// row view.
    #[test]
    #[ignore]
    fn test_force_mut() {
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let time = std::time::Instant::now();
            for i in 0..n {
                // The view must be bound: `force_mut` borrows it for the
                // lifetime of the returned mutable view.
                let row_view = a.slice(i);
                // SAFETY: each iteration touches only row `i`, and no other view
                // of that row is alive at the same time.
                let mut row = unsafe { row_view.force_mut() };
                row *= i as f64 / 2048.0;
            }
            println!("Elapsed time {:?}", time.elapsed());
        }
        println!("{a:16.10}");
    }

    /// Parallel variant of `test_force_mut`: each rayon task scales its own row.
    #[test]
    #[ignore]
    #[cfg(feature = "rayon")]
    fn test_force_mut_par() {
        use rayon::prelude::*;
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let time = std::time::Instant::now();
            (0..n).into_par_iter().for_each(|i| {
                let row_view = a.slice(i);
                // SAFETY: rows are disjoint, so concurrent tasks never touch
                // the same elements.
                let mut row = unsafe { row_view.force_mut() };
                row *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", time.elapsed());
        }
        println!("{a:16.10}");
    }
}