rstsr_core/tensor/ownership_conversion.rs

use crate::prelude_dev::*;

/* #region basic conversion */

/// Methods for tensor ownership conversion.
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
    R: DataAPI<Data = B::Raw>,
{
    /// Get a view of tensor.
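    ///
    /// The view borrows the same underlying data; only the layout is cloned.
    ///
    /// # Example
    ///
    /// A minimal sketch using the same constructors as this file's tests
    /// (shown with `ignore`, so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = arange(3);
    /// let v = a.view();
    /// // the view shares the raw buffer with `a`
    /// assert_eq!(v.raw().as_ptr(), a.raw().as_ptr());
    /// ```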
    pub fn view(&self) -> TensorView<'_, T, B, D> {
        let layout = self.layout().clone();
        let data = self.data().as_ref();
        let storage = Storage::new(data, self.device().clone());
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }

    /// Get a mutable view of tensor.
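    ///
    /// # Example
    ///
    /// A minimal sketch (shown with `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut a = linspace((0.0, 1.0, 6));
    /// let mut v = a.view_mut();
    /// // writing through the mutable view modifies `a` in place
    /// v *= 2.0;
    /// ```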
    pub fn view_mut(&mut self) -> TensorMut<'_, T, B, D>
    where
        R: DataMutAPI,
    {
        let device = self.device().clone();
        let layout = self.layout().clone();
        let data = self.data_mut().as_mut();
        let storage = Storage::new(data, device);
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }

    /// Convert current tensor into copy-on-write.
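    ///
    /// # Example
    ///
    /// A minimal sketch mirroring `test_into_cow` below (shown with `ignore`,
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = arange(3);
    /// let ptr_a = a.raw().as_ptr();
    /// let a_cow = a.into_cow();
    /// // converting an owned tensor to copy-on-write does not move the buffer
    /// assert_eq!(a_cow.raw().as_ptr(), ptr_a);
    /// ```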
    pub fn into_cow<'a>(self) -> TensorCow<'a, T, B, D>
    where
        R: DataIntoCowAPI<'a>,
    {
        let (storage, layout) = self.into_raw_parts();
        let (data, device) = storage.into_raw_parts();
        let storage = Storage::new(data.into_cow(), device);
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }

    /// Convert tensor into owned tensor.
    ///
    /// Data is either moved or fully cloned.
    /// The layout is left unchanged; i.e. the whole underlying buffer is moved
    /// or cloned as-is.
    ///
    /// # See also
    ///
    /// [`Tensor::into_owned`] keeps the data where possible and clones
    /// otherwise; it avoids retaining a large data buffer when the tensor view
    /// only covers a small part of it.
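    ///
    /// # Example
    ///
    /// A minimal sketch of the difference from [`Tensor::into_owned`] (shown
    /// with `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = linspace((0.0, 1.0, 16)).into_shape((4, 4));
    /// let v = a.slice(0);
    /// // clones the whole 16-element buffer and keeps the view's layout;
    /// // `into_owned` would copy only the selected elements instead
    /// let owned = v.into_owned_keep_layout();
    /// ```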
    pub fn into_owned_keep_layout(self) -> Tensor<T, B, D>
    where
        R::Data: Clone,
        R: DataCloneAPI,
    {
        let (storage, layout) = self.into_raw_parts();
        let (data, device) = storage.into_raw_parts();
        let storage = Storage::new(data.into_owned(), device);
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }

    /// Convert tensor into shared tensor.
    ///
    /// Data is either moved or fully cloned.
    /// The layout is left unchanged; i.e. the whole underlying buffer is moved
    /// or cloned as-is.
    ///
    /// # See also
    ///
    /// [`Tensor::into_shared`] keeps the data where possible and clones
    /// otherwise; it avoids retaining a large data buffer when the tensor view
    /// only covers a small part of it.
    pub fn into_shared_keep_layout(self) -> TensorArc<T, B, D>
    where
        R::Data: Clone,
        R: DataCloneAPI,
    {
        let (storage, layout) = self.into_raw_parts();
        let (data, device) = storage.into_raw_parts();
        let storage = Storage::new(data.into_shared(), device);
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }
}

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    R::Data: Clone,
    D: DimAPI,
    T: Clone,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
{
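    /// Convert tensor into owned tensor.
    ///
    /// If the tensor covers its whole underlying buffer, the data is moved (or
    /// cloned, depending on ownership) and the layout is kept; otherwise only
    /// the elements visible through the layout are copied into a new buffer.
    ///
    /// # Example
    ///
    /// A minimal sketch (shown with `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = linspace((0.0, 1.0, 16)).into_shape((4, 4));
    /// // the sliced view covers only part of the buffer, so `into_owned`
    /// // copies just those elements instead of keeping all 16 values alive
    /// let row = a.slice(0).into_owned();
    /// ```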
    pub fn into_owned(self) -> Tensor<T, B, D> {
        let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
        if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
            return self.into_owned_keep_layout();
        } else {
            return asarray((&self, TensorIterOrder::K));
        }
    }

    /// Convert tensor into shared (reference-counted) tensor.
    ///
    /// As with [`Tensor::into_owned`], the data is moved when the tensor
    /// covers its whole underlying buffer and copied otherwise.
    pub fn into_shared(self) -> TensorArc<T, B, D> {
        let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
        if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
            return self.into_shared_keep_layout();
        } else {
            return asarray((&self, TensorIterOrder::K)).into_shared();
        }
    }

    /// Clone the tensor's data into a new owned tensor, leaving `self`
    /// untouched.
    pub fn to_owned(&self) -> Tensor<T, B, D> {
        self.view().into_owned()
    }
}

impl<T, B, D> Clone for Tensor<T, B, D>
where
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    fn clone(&self) -> Self {
        self.to_owned()
    }
}

impl<T, B, D> Clone for TensorCow<'_, T, B, D>
where
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    fn clone(&self) -> Self {
        let tsr_owned = self.to_owned();
        let (storage, layout) = tsr_owned.into_raw_parts();
        let (data, device) = storage.into_raw_parts();
        let data = data.into_cow();
        let storage = Storage::new(data, device);
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }
}

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw> + DataForceMutAPI<B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI,
{
    /// # Safety
    ///
    /// This function is highly unsafe, as it entirely bypasses Rust's lifetime
    /// and borrowing rules.
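    ///
    /// # Example
    ///
    /// A minimal sketch mirroring `test_force_mut` below (shown with `ignore`,
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = linspace((0.0, 1.0, 16)).into_shape((4, 4));
    /// let a_view = a.slice(0);
    /// // caller must guarantee no other reads or writes alias this data
    /// let mut a_mut = unsafe { a_view.force_mut() };
    /// a_mut *= 2.0;
    /// ```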
    pub unsafe fn force_mut(&self) -> TensorMut<'_, T, B, D> {
        let layout = self.layout().clone();
        let data = self.data().force_mut();
        let storage = Storage::new(data, self.device().clone());
        TensorBase::new_unchecked(storage, layout)
    }
}

/* #endregion */

/* #region to_raw */

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
{
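    /// Copy the viewed elements into a `Vec<T>`, returning an error on
    /// failure; currently restricted to 1-D tensors.
    ///
    /// # Example
    ///
    /// A minimal sketch (shown with `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = linspace((0.0, 1.0, 3));
    /// let v: Vec<f64> = a.to_raw_f()?;   // or `a.to_vec()` to panic on error
    /// assert_eq!(v.len(), 3);
    /// ```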
    pub fn to_raw_f(&self) -> Result<Vec<T>> {
        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only supports 1-D tensors")?;
        let device = self.device();
        let layout = self.layout().to_dim::<Ix1>()?;
        let size = layout.size();
        let mut new_storage = device.uninit_impl(size)?;
        device.assign_uninit(new_storage.raw_mut(), &[size].c(), self.raw(), &layout)?;
        let storage = unsafe { B::assume_init_impl(new_storage) }?;
        let (data, _) = storage.into_raw_parts();
        Ok(data.into_raw())
    }

    /// Variant of [`Self::to_raw_f`] that panics on error.
    pub fn to_vec(&self) -> Vec<T> {
        self.to_raw_f().rstsr_unwrap()
    }
}

impl<T, B, D> Tensor<T, B, D>
where
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
{
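    /// Consume the tensor and return its data as a `Vec<T>`, returning an
    /// error on failure; currently restricted to 1-D tensors.
    ///
    /// The underlying buffer is returned directly when the layout already
    /// covers it contiguously in forward order; otherwise the viewed elements
    /// are copied.
    ///
    /// # Example
    ///
    /// A minimal sketch (shown with `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = linspace((0.0, 1.0, 3));
    /// let v: Vec<f64> = a.into_vec();
    /// assert_eq!(v.len(), 3);
    /// ```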
    pub fn into_vec_f(self) -> Result<Vec<T>> {
        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only supports 1-D tensors")?;
        let layout = self.layout();
        let (idx_min, idx_max) = layout.bounds_index()?;
        if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
            let (storage, _) = self.into_raw_parts();
            let (data, _) = storage.into_raw_parts();
            return Ok(data.into_raw());
        } else {
            return self.to_raw_f();
        }
    }

    /// Variant of [`Self::into_vec_f`] that panics on error.
    pub fn into_vec(self) -> Vec<T> {
        self.into_vec_f().rstsr_unwrap()
    }
}

/* #endregion */

/* #region to_scalar */

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataCloneAPI<Data = B::Raw>,
    B::Raw: Clone,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T>,
{
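    /// Extract the single element of a size-1 tensor, returning an error if
    /// the tensor holds more than one element.
    ///
    /// # Example
    ///
    /// A minimal sketch (shown with `ignore`, not compiled as a doctest;
    /// assumes `asarray` accepts a plain `Vec` as in other rstsr examples):
    ///
    /// ```ignore
    /// let a = asarray(vec![2.5]);
    /// let x: f64 = a.to_scalar();
    /// assert_eq!(x, 2.5);
    /// ```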
    pub fn to_scalar_f(&self) -> Result<T> {
        let layout = self.layout();
        rstsr_assert_eq!(layout.size(), 1, InvalidLayout)?;
        let storage = self.storage();
        let vec = storage.to_cpu_vec()?;
        // index by the layout offset so sliced size-1 views read the correct element
        Ok(vec[layout.offset()].clone())
    }

    /// Variant of [`Self::to_scalar_f`] that panics on error.
    pub fn to_scalar(&self) -> T {
        self.to_scalar_f().rstsr_unwrap()
    }
}

/* #endregion */

/* #region as_ptr */

impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>>,
{
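    /// Raw pointer to the first element addressed by the tensor's layout
    /// (the buffer pointer advanced by the layout offset).
    ///
    /// # Example
    ///
    /// A minimal sketch (shown with `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = linspace((0.0, 1.0, 3));
    /// let p = a.as_ptr();
    /// // the first addressed element of this 1-D tensor is 0.0
    /// assert_eq!(unsafe { *p }, 0.0);
    /// ```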
    pub fn as_ptr(&self) -> *const T {
        unsafe { self.raw().as_ptr().add(self.layout().offset()) }
    }

    /// Mutable raw pointer to the first element addressed by the tensor's
    /// layout.
    pub fn as_mut_ptr(&mut self) -> *mut T
    where
        R: DataMutAPI,
    {
        unsafe { self.raw_mut().as_mut_ptr().add(self.layout().offset()) }
    }
}

/* #endregion */

/* #region view API */

pub trait TensorViewAPI
where
    Self::Dim: DimAPI,
    Self::Backend: DeviceAPI<Self::Type>,
{
    type Type;
    type Backend;
    type Dim;
    /// Get a view of tensor.
    fn view(&self) -> TensorView<'_, Self::Type, Self::Backend, Self::Dim>;
}

impl<R, T, B, D> TensorViewAPI for TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view(&self) -> TensorView<'_, T, B, D> {
        let data = self.data().as_ref();
        let storage = Storage::new(data, self.device().clone());
        let layout = self.layout().clone();
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }
}

impl<R, T, B, D> TensorViewAPI for &TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view(&self) -> TensorView<'_, T, B, D> {
        TensorAny::view(*self)
    }
}

impl<R, T, B, D> TensorViewAPI for &mut TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view(&self) -> TensorView<'_, T, B, D> {
        TensorAny::view(*self)
    }
}

pub trait TensorViewMutAPI
where
    Self::Dim: DimAPI,
    Self::Backend: DeviceAPI<Self::Type>,
{
    type Type;
    type Backend;
    type Dim;

    /// Get a mutable view of tensor.
    fn view_mut(&mut self) -> TensorMut<'_, Self::Type, Self::Backend, Self::Dim>;
}

impl<R, T, B, D> TensorViewMutAPI for TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataMutAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
        let device = self.device().clone();
        let layout = self.layout().clone();
        let data = self.data_mut().as_mut();
        let storage = Storage::new(data, device);
        unsafe { TensorBase::new_unchecked(storage, layout) }
    }
}

impl<R, T, B, D> TensorViewMutAPI for &mut TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataMutAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
        (*self).view_mut()
    }
}

pub trait TensorIntoOwnedAPI<T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
{
    /// Convert tensor into owned tensor.
    ///
    /// When the tensor covers its whole underlying buffer, the data is moved
    /// or fully cloned with the layout kept; otherwise only the viewed
    /// elements are copied into a new buffer.
    fn into_owned(self) -> Tensor<T, B, D>;
}

impl<R, T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorAny<R, T, B, D>
where
    R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
{
    fn into_owned(self) -> Tensor<T, B, D> {
        TensorAny::into_owned(self)
    }
}

/* #endregion */

/* #region tensor prop for computation */

pub trait TensorRefAPI<'l>: TensorViewAPI {}
impl<'l, R, T, B, D> TensorRefAPI<'l> for &'l TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    Self: TensorViewAPI,
{
}
impl<'l, T, B, D> TensorRefAPI<'l> for TensorView<'l, T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
    Self: TensorViewAPI,
{
}

pub trait TensorRefMutAPI<'l>: TensorViewAPI {}
impl<'l, R, T, B, D> TensorRefMutAPI<'l> for &mut TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataMutAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    Self: TensorViewMutAPI,
{
}
impl<'l, T, B, D> TensorRefMutAPI<'l> for TensorMut<'l, T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
    Self: TensorViewMutAPI,
{
}

/* #endregion */

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_into_cow() {
        let mut a = arange(3);
        let ptr_a = a.raw().as_ptr();

        let a_mut = a.view_mut();
        let a_cow = a_mut.into_cow();
        println!("{a_cow:?}");

        let a_ref = a.view();
        let a_cow = a_ref.into_cow();
        println!("{a_cow:?}");

        let a_cow = a.into_cow();
        println!("{a_cow:?}");
        let ptr_a_cow = a_cow.raw().as_ptr();
        assert_eq!(ptr_a, ptr_a_cow);
    }

    #[test]
    #[ignore]
    fn test_force_mut() {
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let time = std::time::Instant::now();
            for i in 0..n {
                let a_view = a.slice(i);
                let mut a_mut = unsafe { a_view.force_mut() };
                a_mut *= i as f64 / 2048.0;
            }
            println!("Elapsed time {:?}", time.elapsed());
        }
        println!("{a:16.10}");
    }

    #[test]
    #[ignore]
    #[cfg(feature = "rayon")]
    fn test_force_mut_par() {
        use rayon::prelude::*;
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let time = std::time::Instant::now();
            (0..n).into_par_iter().for_each(|i| {
                let a_view = a.slice(i);
                let mut a_mut = unsafe { a_view.force_mut() };
                a_mut *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", time.elapsed());
        }
        println!("{a:16.10}");
    }
}