//! Tensor ownership conversion.
//!
//! Source: `rstsr_core/tensor/ownership_conversion.rs`
use crate::prelude_dev::*;
2
3/* #region basic conversion */
4
5/// Methods for tensor ownership conversion.
6impl<R, T, B, D> TensorAny<R, T, B, D>
7where
8    D: DimAPI,
9    B: DeviceAPI<T>,
10    R: DataAPI<Data = B::Raw>,
11{
12    /// Get a view of tensor.
13    pub fn view(&self) -> TensorView<'_, T, B, D> {
14        let layout = self.layout().clone();
15        let data = self.data().as_ref();
16        let storage = Storage::new(data, self.device().clone());
17        unsafe { TensorBase::new_unchecked(storage, layout) }
18    }
19
20    /// Get a mutable view of tensor.
21    pub fn view_mut(&mut self) -> TensorMut<'_, T, B, D>
22    where
23        R: DataMutAPI,
24    {
25        let device = self.device().clone();
26        let layout = self.layout().clone();
27        let data = self.data_mut().as_mut();
28        let storage = Storage::new(data, device);
29        unsafe { TensorBase::new_unchecked(storage, layout) }
30    }
31
32    /// Convert current tensor into copy-on-write.
33    pub fn into_cow<'a>(self) -> TensorCow<'a, T, B, D>
34    where
35        R: DataIntoCowAPI<'a>,
36    {
37        let (storage, layout) = self.into_raw_parts();
38        let (data, device) = storage.into_raw_parts();
39        let storage = Storage::new(data.into_cow(), device);
40        unsafe { TensorBase::new_unchecked(storage, layout) }
41    }
42
43    /// Convert tensor into owned tensor.
44    ///
45    /// Data is either moved or fully cloned.
46    /// Layout is not involved; i.e. all underlying data is moved or cloned
47    /// without changing layout.
48    ///
49    /// # See also
50    ///
51    /// [`Tensor::into_owned`] keep data in some conditions, otherwise clone.
52    /// This function can avoid cases where data memory bulk is large, but
53    /// tensor view is small.
54    pub fn into_owned_keep_layout(self) -> Tensor<T, B, D>
55    where
56        R::Data: Clone,
57        R: DataCloneAPI,
58    {
59        let (storage, layout) = self.into_raw_parts();
60        let (data, device) = storage.into_raw_parts();
61        let storage = Storage::new(data.into_owned(), device);
62        unsafe { TensorBase::new_unchecked(storage, layout) }
63    }
64
65    /// Convert tensor into shared tensor.
66    ///
67    /// Data is either moved or cloned.
68    /// Layout is not involved; i.e. all underlying data is moved or cloned
69    /// without changing layout.
70    ///
71    /// # See also
72    ///
73    /// [`Tensor::into_shared`] keep data in some conditions, otherwise clone.
74    /// This function can avoid cases where data memory bulk is large, but
75    /// tensor view is small.
76    pub fn into_shared_keep_layout(self) -> TensorArc<T, B, D>
77    where
78        R::Data: Clone,
79        R: DataCloneAPI,
80    {
81        let (storage, layout) = self.into_raw_parts();
82        let (data, device) = storage.into_raw_parts();
83        let storage = Storage::new(data.into_shared(), device);
84        unsafe { TensorBase::new_unchecked(storage, layout) }
85    }
86}
87
88impl<R, T, B, D> TensorAny<R, T, B, D>
89where
90    R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
91    R::Data: Clone,
92    D: DimAPI,
93    T: Clone,
94    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
95{
96    pub fn into_owned(self) -> Tensor<T, B, D> {
97        let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
98        if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
99            return self.into_owned_keep_layout();
100        } else {
101            return asarray((&self, TensorIterOrder::K));
102        }
103    }
104
105    pub fn into_shared(self) -> TensorArc<T, B, D> {
106        let (idx_min, idx_max) = self.layout().bounds_index().rstsr_unwrap();
107        if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
108            return self.into_shared_keep_layout();
109        } else {
110            return asarray((&self, TensorIterOrder::K)).into_shared();
111        }
112    }
113
114    pub fn to_owned(&self) -> Tensor<T, B, D> {
115        self.view().into_owned()
116    }
117}
118
/// Cloning an owned tensor copies the visible elements via `to_owned`.
impl<T, B, D> Clone for Tensor<T, B, D>
where
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
{
    fn clone(&self) -> Self {
        // Delegates to the inherent `to_owned` (view + into_owned).
        self.to_owned()
    }
}
130
131impl<T, B, D> Clone for TensorCow<'_, T, B, D>
132where
133    T: Clone,
134    D: DimAPI,
135    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
136    <B as DeviceRawAPI<T>>::Raw: Clone,
137{
138    fn clone(&self) -> Self {
139        let tsr_owned = self.to_owned();
140        let (storage, layout) = tsr_owned.into_raw_parts();
141        let (data, device) = storage.into_raw_parts();
142        let data = data.into_cow();
143        let storage = Storage::new(data, device);
144        unsafe { TensorBase::new_unchecked(storage, layout) }
145    }
146}
147
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = B::Raw> + DataForceMutAPI<B::Raw>,
    B: DeviceAPI<T>,
    D: DimAPI,
{
    /// Obtain a mutable view through a shared reference.
    ///
    /// # Safety
    ///
    /// This function is highly unsafe, as it entirely bypasses Rust's lifetime
    /// and borrowing rules.
    pub unsafe fn force_mut(&self) -> TensorMut<'_, T, B, D> {
        let layout = self.layout().clone();
        // Hands out mutable access through `&self`; the caller must guarantee
        // no aliasing writes occur (see # Safety above).
        let data = self.data().force_mut();
        let storage = Storage::new(data, self.device().clone());
        TensorBase::new_unchecked(storage, layout)
    }
}
165
166/* #endregion */
167
168/* #region to_raw */
169
impl<R, T, B, D> TensorAny<R, T, B, D>
where
    R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
{
    /// Copy the visible elements of a 1-D tensor into a fresh raw buffer
    /// (fallible variant).
    ///
    /// The copy is layout-aware: elements are gathered according to this
    /// tensor's 1-D layout into a new contiguous buffer of `size` elements.
    pub fn to_raw_f(&self) -> Result<<B as DeviceRawAPI<T>>::Raw> {
        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
        let device = self.device();
        let layout = self.layout().to_dim::<Ix1>()?;
        let size = layout.size();
        // Allocate uninitialized storage, fill it from `self`, then assert
        // initialization — the order of these three steps is essential.
        let mut new_storage = device.uninit_impl(size)?;
        device.assign_uninit(new_storage.raw_mut(), &[size].c(), self.raw(), &layout)?;
        // SAFETY: `assign_uninit` above wrote into the new storage through a
        // `[size]` c-layout, which presumably covers all elements — the
        // `assume_init` contract rests on that (NOTE(review): confirm
        // `assign_uninit` initializes every element).
        let storage = unsafe { B::assume_init_impl(new_storage) }?;
        let (data, _) = storage.into_raw_parts();
        Ok(data.into_raw())
    }

    /// Infallible wrapper over [`to_raw_f`](Self::to_raw_f); panics on error.
    pub fn to_vec(&self) -> <B as DeviceRawAPI<T>>::Raw {
        self.to_raw_f().rstsr_unwrap()
    }
}
193
194impl<T, B, D> Tensor<T, B, D>
195where
196    T: Clone,
197    D: DimAPI,
198    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
199{
200    pub fn into_raw_f(self) -> Result<<B as DeviceRawAPI<T>>::Raw> {
201        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
202        let layout = self.layout();
203        let (idx_min, idx_max) = layout.bounds_index()?;
204        if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
205            let (storage, _) = self.into_raw_parts();
206            let (data, _) = storage.into_raw_parts();
207            return Ok(data.into_raw());
208        } else {
209            return self.to_raw_f();
210        }
211    }
212
213    pub fn into_raw(self) -> <B as DeviceRawAPI<T>>::Raw {
214        self.into_raw_f().rstsr_unwrap()
215    }
216}
217
218impl<T, B, D> Tensor<T, B, D>
219where
220    T: Clone,
221    D: DimAPI,
222    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
223    <B as DeviceRawAPI<T>>::Raw: Clone,
224{
225    pub fn into_vec_f(self) -> Result<Vec<T>> {
226        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
227        let layout = self.layout();
228        let (idx_min, idx_max) = layout.bounds_index()?;
229        if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
230            let (storage, _) = self.into_raw_parts();
231            storage.into_cpu_vec()
232        } else {
233            let data = self.to_raw_f()?;
234            let storage = Storage::new(DataOwned::from(data), self.device().clone());
235            storage.into_cpu_vec()
236        }
237    }
238
239    pub fn into_vec(self) -> Vec<T> {
240        self.into_vec_f().rstsr_unwrap()
241    }
242}
243
244/* #endregion */
245
246/* #region to_scalar */
247
248impl<R, T, B, D> TensorAny<R, T, B, D>
249where
250    R: DataCloneAPI<Data = B::Raw>,
251    B::Raw: Clone,
252    T: Clone,
253    D: DimAPI,
254    B: DeviceAPI<T>,
255{
256    pub fn to_scalar_f(&self) -> Result<T> {
257        let layout = self.layout();
258        rstsr_assert_eq!(layout.size(), 1, InvalidLayout)?;
259        let storage = self.storage();
260        let vec = storage.to_cpu_vec()?;
261        Ok(vec[0].clone())
262    }
263
264    pub fn to_scalar(&self) -> T {
265        self.to_scalar_f().rstsr_unwrap()
266    }
267}
268
269/* #endregion */
270
271/* #region as_ptr */
272
273impl<R, T, B, D> TensorAny<R, T, B, D>
274where
275    R: DataAPI<Data = B::Raw>,
276    D: DimAPI,
277    B: DeviceAPI<T, Raw = Vec<T>>,
278{
279    pub fn as_ptr(&self) -> *const T {
280        unsafe { self.raw().as_ptr().add(self.layout().offset()) }
281    }
282
283    pub fn as_mut_ptr(&mut self) -> *mut T
284    where
285        R: DataMutAPI,
286    {
287        unsafe { self.raw_mut().as_mut_ptr().add(self.layout().offset()) }
288    }
289}
290
291/* #endregion */
292
293/* #region view API */
294
/// Trait API for obtaining an immutable tensor view.
pub trait TensorViewAPI
where
    Self::Dim: DimAPI,
    Self::Backend: DeviceAPI<Self::Type>,
{
    /// Element type of the tensor.
    type Type;
    /// Device backend of the tensor.
    type Backend;
    /// Dimensionality type of the tensor.
    type Dim;
    /// Get a view of tensor.
    fn view(&self) -> TensorView<'_, Self::Type, Self::Backend, Self::Dim>;
}
306
307impl<R, T, B, D> TensorViewAPI for TensorAny<R, T, B, D>
308where
309    D: DimAPI,
310    R: DataAPI<Data = B::Raw>,
311    B: DeviceAPI<T>,
312{
313    type Type = T;
314    type Backend = B;
315    type Dim = D;
316
317    fn view(&self) -> TensorView<'_, T, B, D> {
318        let data = self.data().as_ref();
319        let storage = Storage::new(data, self.device().clone());
320        let layout = self.layout().clone();
321        unsafe { TensorBase::new_unchecked(storage, layout) }
322    }
323}
324
/// `&TensorAny` delegates view creation to the owned-receiver implementation.
impl<R, T, B, D> TensorViewAPI for &TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view(&self) -> TensorView<'_, T, B, D> {
        // `*self` is `&TensorAny`; the inherent `view` takes `&self`.
        TensorAny::view(*self)
    }
}
339
/// `&mut TensorAny` also yields an immutable view via the inherent method.
impl<R, T, B, D> TensorViewAPI for &mut TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view(&self) -> TensorView<'_, T, B, D> {
        // `*self` is `&mut TensorAny`, which coerces to `&TensorAny`.
        TensorAny::view(*self)
    }
}
354
/// Trait API for obtaining a mutable tensor view.
pub trait TensorViewMutAPI
where
    Self::Dim: DimAPI,
    Self::Backend: DeviceAPI<Self::Type>,
{
    /// Element type of the tensor.
    type Type;
    /// Device backend of the tensor.
    type Backend;
    /// Dimensionality type of the tensor.
    type Dim;

    /// Get a mutable view of tensor.
    fn view_mut(&mut self) -> TensorMut<'_, Self::Type, Self::Backend, Self::Dim>;
}
367
368impl<R, T, B, D> TensorViewMutAPI for TensorAny<R, T, B, D>
369where
370    D: DimAPI,
371    R: DataMutAPI<Data = B::Raw>,
372    B: DeviceAPI<T>,
373{
374    type Type = T;
375    type Backend = B;
376    type Dim = D;
377
378    fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
379        let device = self.device().clone();
380        let layout = self.layout().clone();
381        let data = self.data_mut().as_mut();
382        let storage = Storage::new(data, device);
383        unsafe { TensorBase::new_unchecked(storage, layout) }
384    }
385}
386
/// `&mut TensorAny` delegates to the owned-receiver implementation.
impl<R, T, B, D> TensorViewMutAPI for &mut TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataMutAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
{
    type Type = T;
    type Backend = B;
    type Dim = D;

    fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
        // Reborrow the inner `&mut TensorAny` and use its implementation.
        (*self).view_mut()
    }
}
401
/// Trait API for converting a tensor of any ownership into an owned tensor.
pub trait TensorIntoOwnedAPI<T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
{
    /// Convert tensor into owned tensor.
    ///
    /// Data is either moved or fully cloned.
    /// Layout is not involved; i.e. all underlying data is moved or cloned
    /// without changing layout.
    fn into_owned(self) -> Tensor<T, B, D>;
}
414
impl<R, T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorAny<R, T, B, D>
where
    R: DataCloneAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
    <B as DeviceRawAPI<T>>::Raw: Clone,
    T: Clone,
    D: DimAPI,
    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
{
    fn into_owned(self) -> Tensor<T, B, D> {
        // UFCS disambiguates the inherent `into_owned` from this trait method.
        TensorAny::into_owned(self)
    }
}
427
428/* #endregion */
429
430/* #region tensor prop for computation */
431
/// Marker trait for types that can act as an immutable tensor reference in
/// computation APIs (references and views).
pub trait TensorRefAPI<'l>: TensorViewAPI {}
impl<'l, R, T, B, D> TensorRefAPI<'l> for &'l TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    Self: TensorViewAPI,
{
}
impl<'l, T, B, D> TensorRefAPI<'l> for TensorView<'l, T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
    Self: TensorViewAPI,
{
}
448
/// Marker trait for types that can act as a mutable tensor reference in
/// computation APIs (mutable references and mutable views).
pub trait TensorRefMutAPI<'l>: TensorViewAPI {}
impl<'l, R, T, B, D> TensorRefMutAPI<'l> for &mut TensorAny<R, T, B, D>
where
    D: DimAPI,
    R: DataMutAPI<Data = B::Raw>,
    B: DeviceAPI<T>,
    Self: TensorViewMutAPI,
{
}
impl<'l, T, B, D> TensorRefMutAPI<'l> for TensorMut<'l, T, B, D>
where
    D: DimAPI,
    B: DeviceAPI<T>,
    Self: TensorViewMutAPI,
{
}
465
466/* #endregion */
467
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_into_cow() {
        let mut a = arange(3);
        let ptr_a = a.raw().as_ptr();

        // cow from a mutable view
        let a_mut = a.view_mut();
        let a_cow = a_mut.into_cow();
        println!("{a_cow:?}");

        // cow from an immutable view
        let a_ref = a.view();
        let a_cow = a_ref.into_cow();
        println!("{a_cow:?}");

        // cow from the owned tensor: the buffer must be reused, not cloned,
        // so the data pointer is unchanged.
        let a_cow = a.into_cow();
        println!("{a_cow:?}");
        let ptr_a_cow = a_cow.raw().as_ptr();
        assert_eq!(ptr_a, ptr_a_cow);
    }

    // Benchmark-style check of `force_mut`; ignored in normal test runs.
    #[test]
    #[ignore]
    fn test_force_mut() {
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let time = std::time::Instant::now();
            for i in 0..n {
                let a_view = a.slice(i);
                // SAFETY (test): each iteration touches a distinct row slice.
                let mut a_mut = unsafe { a_view.force_mut() };
                a_mut *= i as f64 / 2048.0;
            }
            println!("Elapsed time {:?}", time.elapsed());
        }
        println!("{a:16.10}");
    }

    // Parallel variant of the benchmark above; requires the `rayon` feature.
    #[test]
    #[ignore]
    #[cfg(feature = "rayon")]
    fn test_force_mut_par() {
        use rayon::prelude::*;
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let time = std::time::Instant::now();
            (0..n).into_par_iter().for_each(|i| {
                let a_view = a.slice(i);
                // SAFETY (test): each parallel task writes a distinct row.
                let mut a_mut = unsafe { a_view.force_mut() };
                a_mut *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", time.elapsed());
        }
        println!("{a:16.10}");
    }
}