// rstsr_core/tensor/ownership_conversion.rs

1use crate::prelude_dev::*;
2
3/* #region basic conversion */
4
5/// Methods for tensor ownership conversion.
6impl<R, T, B, D> TensorAny<R, T, B, D>
7where
8    D: DimAPI,
9    B: DeviceAPI<T>,
10    R: DataAPI<Data = B::Raw>,
11{
12    /// Get a view of tensor.
13    pub fn view(&self) -> TensorView<'_, T, B, D> {
14        let layout = self.layout().clone();
15        let data = self.data().as_ref();
16        let storage = Storage::new(data, self.device().clone());
17        unsafe { TensorBase::new_unchecked(storage, layout) }
18    }
19
20    /// Get a mutable view of tensor.
21    pub fn view_mut(&mut self) -> TensorMut<'_, T, B, D>
22    where
23        R: DataMutAPI,
24    {
25        let device = self.device().clone();
26        let layout = self.layout().clone();
27        let data = self.data_mut().as_mut();
28        let storage = Storage::new(data, device);
29        unsafe { TensorBase::new_unchecked(storage, layout) }
30    }
31
32    /// Convert current tensor into copy-on-write.
33    pub fn into_cow<'a>(self) -> TensorCow<'a, T, B, D>
34    where
35        R: DataIntoCowAPI<'a>,
36    {
37        let (storage, layout) = self.into_raw_parts();
38        let (data, device) = storage.into_raw_parts();
39        let storage = Storage::new(data.into_cow(), device);
40        unsafe { TensorBase::new_unchecked(storage, layout) }
41    }
42
43    /// Convert tensor into owned tensor.
44    ///
45    /// Data is either moved or fully cloned.
46    /// Layout is not involved; i.e. all underlying data is moved or cloned
47    /// without changing layout.
48    ///
49    /// # See also
50    ///
51    /// [`Tensor::into_owned`] keep data in some conditions, otherwise clone.
52    /// This function can avoid cases where data memory bulk is large, but
53    /// tensor view is small.
54    pub fn into_owned_keep_layout(self) -> Tensor<T, B, D>
55    where
56        R::Data: Clone,
57        R: DataCloneAPI,
58    {
59        let (storage, layout) = self.into_raw_parts();
60        let (data, device) = storage.into_raw_parts();
61        let storage = Storage::new(data.into_owned(), device);
62        unsafe { TensorBase::new_unchecked(storage, layout) }
63    }
64
65    /// Convert tensor into shared tensor.
66    ///
67    /// Data is either moved or cloned.
68    /// Layout is not involved; i.e. all underlying data is moved or cloned
69    /// without changing layout.
70    ///
71    /// # See also
72    ///
73    /// [`Tensor::into_shared`] keep data in some conditions, otherwise clone.
74    /// This function can avoid cases where data memory bulk is large, but
75    /// tensor view is small.
76    pub fn into_shared_keep_layout(self) -> TensorArc<T, B, D>
77    where
78        R::Data: Clone,
79        R: DataCloneAPI,
80    {
81        let (storage, layout) = self.into_raw_parts();
82        let (data, device) = storage.into_raw_parts();
83        let storage = Storage::new(data.into_shared(), device);
84        unsafe { TensorBase::new_unchecked(storage, layout) }
85    }
86}
87
88impl<R, T, B, D> TensorAny<R, T, B, D>
89where
90    R: DataCloneAPI<Data = B::Raw>,
91    R::Data: Clone,
92    D: DimAPI,
93    T: Clone,
94    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
95{
96    pub fn into_owned(self) -> Tensor<T, B, D> {
97        let (idx_min, idx_max) = self.layout().bounds_index().unwrap();
98        if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
99            return self.into_owned_keep_layout();
100        } else {
101            return asarray((&self, TensorIterOrder::K));
102        }
103    }
104
105    pub fn into_shared(self) -> TensorArc<T, B, D> {
106        let (idx_min, idx_max) = self.layout().bounds_index().unwrap();
107        if idx_min == 0 && idx_max == self.storage().len() && idx_max == self.layout().size() {
108            return self.into_shared_keep_layout();
109        } else {
110            return asarray((&self, TensorIterOrder::K)).into_shared();
111        }
112    }
113
114    pub fn to_owned(&self) -> Tensor<T, B, D> {
115        self.view().into_owned()
116    }
117}
118
119impl<T, B, D> Clone for Tensor<T, B, D>
120where
121    T: Clone,
122    D: DimAPI,
123    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
124    B::Raw: Clone,
125{
126    fn clone(&self) -> Self {
127        self.to_owned()
128    }
129}
130
131impl<T, B, D> Clone for TensorCow<'_, T, B, D>
132where
133    T: Clone,
134    D: DimAPI,
135    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
136    B::Raw: Clone,
137{
138    fn clone(&self) -> Self {
139        let tsr_owned = self.to_owned();
140        let (storage, layout) = tsr_owned.into_raw_parts();
141        let (data, device) = storage.into_raw_parts();
142        let data = data.into_cow();
143        let storage = Storage::new(data, device);
144        unsafe { TensorBase::new_unchecked(storage, layout) }
145    }
146}
147
148impl<R, T, B, D> TensorAny<R, T, B, D>
149where
150    R: DataAPI<Data = B::Raw> + DataForceMutAPI<B::Raw>,
151    B: DeviceAPI<T>,
152    D: DimAPI,
153{
154    /// # Safety
155    ///
156    /// This function is highly unsafe, as it entirely bypasses Rust's lifetime
157    /// and borrowing rules.
158    pub unsafe fn force_mut(&self) -> TensorMut<'_, T, B, D> {
159        let layout = self.layout().clone();
160        let data = self.data().force_mut();
161        let storage = Storage::new(data, self.device().clone());
162        TensorBase::new_unchecked(storage, layout)
163    }
164}
165
166/* #endregion */
167
168/* #region to_raw */
169
170impl<R, T, B, D> TensorAny<R, T, B, D>
171where
172    R: DataAPI<Data = B::Raw>,
173    T: Clone,
174    D: DimAPI,
175    B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
176{
177    pub fn to_raw_f(&self) -> Result<Vec<T>> {
178        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
179        let device = self.device();
180        let layout = self.layout().to_dim::<Ix1>()?;
181        let size = layout.size();
182        let mut new_storage = unsafe { device.empty_impl(size)? };
183        device.assign(new_storage.raw_mut(), &[size].c(), self.raw(), &layout)?;
184        let (data, _) = new_storage.into_raw_parts();
185        Ok(data.into_raw())
186    }
187
188    pub fn to_vec(&self) -> Vec<T> {
189        self.to_raw_f().unwrap()
190    }
191}
192
193impl<T, B, D> Tensor<T, B, D>
194where
195    T: Clone,
196    D: DimAPI,
197    B: DeviceAPI<T, Raw = Vec<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, Ix1>,
198{
199    pub fn into_vec_f(self) -> Result<Vec<T>> {
200        rstsr_assert_eq!(self.ndim(), 1, InvalidLayout, "to_vec currently only support 1-D tensor")?;
201        let layout = self.layout();
202        let (idx_min, idx_max) = layout.bounds_index()?;
203        if idx_min == 0 && idx_max == self.storage().len() && idx_max == layout.size() && layout.stride()[0] > 0 {
204            let (storage, _) = self.into_raw_parts();
205            let (data, _) = storage.into_raw_parts();
206            return Ok(data.into_raw());
207        } else {
208            return self.to_raw_f();
209        }
210    }
211
212    pub fn into_vec(self) -> Vec<T> {
213        self.into_vec_f().unwrap()
214    }
215}
216
217/* #endregion */
218
219/* #region to_scalar */
220
221impl<R, T, B, D> TensorAny<R, T, B, D>
222where
223    R: DataCloneAPI<Data = B::Raw>,
224    B::Raw: Clone,
225    T: Clone,
226    D: DimAPI,
227    B: DeviceAPI<T>,
228{
229    pub fn to_scalar_f(&self) -> Result<T> {
230        let layout = self.layout();
231        rstsr_assert_eq!(layout.size(), 1, InvalidLayout)?;
232        let storage = self.storage();
233        let vec = storage.to_cpu_vec()?;
234        Ok(vec[0].clone())
235    }
236
237    pub fn to_scalar(&self) -> T {
238        self.to_scalar_f().unwrap()
239    }
240}
241
242/* #endregion */
243
244/* #region as_ptr */
245
246impl<R, T, B, D> TensorAny<R, T, B, D>
247where
248    R: DataAPI<Data = B::Raw>,
249    D: DimAPI,
250    B: DeviceAPI<T, Raw = Vec<T>>,
251{
252    pub fn as_ptr(&self) -> *const T {
253        unsafe { self.raw().as_ptr().add(self.layout().offset()) }
254    }
255
256    pub fn as_mut_ptr(&mut self) -> *mut T
257    where
258        R: DataMutAPI,
259    {
260        unsafe { self.raw_mut().as_mut_ptr().add(self.layout().offset()) }
261    }
262}
263
264/* #endregion */
265
266/* #region view API */
267
268pub trait TensorViewAPI<T, B, D>
269where
270    D: DimAPI,
271    B: DeviceAPI<T>,
272{
273    /// Get a view of tensor.
274    fn view(&self) -> TensorView<'_, T, B, D>;
275}
276
277impl<R, T, B, D> TensorViewAPI<T, B, D> for TensorAny<R, T, B, D>
278where
279    D: DimAPI,
280    R: DataAPI<Data = B::Raw>,
281    B: DeviceAPI<T>,
282{
283    fn view(&self) -> TensorView<'_, T, B, D> {
284        let data = self.data().as_ref();
285        let storage = Storage::new(data, self.device().clone());
286        let layout = self.layout().clone();
287        unsafe { TensorBase::new_unchecked(storage, layout) }
288    }
289}
290
291impl<R, T, B, D> TensorViewAPI<T, B, D> for &TensorAny<R, T, B, D>
292where
293    D: DimAPI,
294    R: DataAPI<Data = B::Raw>,
295    B: DeviceAPI<T>,
296{
297    fn view(&self) -> TensorView<'_, T, B, D> {
298        (*self).view()
299    }
300}
301
302pub trait TensorViewMutAPI<T, B, D>
303where
304    D: DimAPI,
305    B: DeviceAPI<T>,
306{
307    /// Get a mutable view of tensor.
308    fn view_mut(&mut self) -> TensorMut<'_, T, B, D>;
309}
310
311impl<R, T, B, D> TensorViewMutAPI<T, B, D> for TensorAny<R, T, B, D>
312where
313    D: DimAPI,
314    R: DataMutAPI<Data = B::Raw>,
315    B: DeviceAPI<T>,
316{
317    fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
318        let device = self.device().clone();
319        let layout = self.layout().clone();
320        let data = self.data_mut().as_mut();
321        let storage = Storage::new(data, device);
322        unsafe { TensorBase::new_unchecked(storage, layout) }
323    }
324}
325
326impl<R, T, B, D> TensorViewMutAPI<T, B, D> for &mut TensorAny<R, T, B, D>
327where
328    D: DimAPI,
329    R: DataMutAPI<Data = B::Raw>,
330    B: DeviceAPI<T>,
331{
332    fn view_mut(&mut self) -> TensorMut<'_, T, B, D> {
333        (*self).view_mut()
334    }
335}
336
337pub trait TensorIntoOwnedAPI<T, B, D>
338where
339    D: DimAPI,
340    B: DeviceAPI<T>,
341{
342    /// Convert tensor into owned tensor.
343    ///
344    /// Data is either moved or fully cloned.
345    /// Layout is not involved; i.e. all underlying data is moved or cloned
346    /// without changing layout.
347    fn into_owned(self) -> Tensor<T, B, D>;
348}
349
350impl<R, T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorAny<R, T, B, D>
351where
352    R: DataCloneAPI<Data = B::Raw>,
353    B::Raw: Clone,
354    T: Clone,
355    D: DimAPI,
356    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
357{
358    fn into_owned(self) -> Tensor<T, B, D> {
359        TensorAny::into_owned(self)
360    }
361}
362
363/* #endregion */
364
365/* #region tensor prop for computation */
366
367pub trait TensorRefAPI {}
368impl<R, T, B, D> TensorRefAPI for &TensorAny<R, T, B, D>
369where
370    D: DimAPI,
371    R: DataAPI<Data = B::Raw>,
372    B: DeviceAPI<T>,
373    Self: TensorViewAPI<T, B, D>,
374{
375}
376impl<T, B, D> TensorRefAPI for TensorView<'_, T, B, D>
377where
378    D: DimAPI,
379    B: DeviceAPI<T>,
380    Self: TensorViewAPI<T, B, D>,
381{
382}
383
384pub trait TensorRefMutAPI {}
385impl<R, T, B, D> TensorRefMutAPI for &mut TensorAny<R, T, B, D>
386where
387    D: DimAPI,
388    R: DataMutAPI<Data = B::Raw>,
389    B: DeviceAPI<T>,
390    Self: TensorViewMutAPI<T, B, D>,
391{
392}
393impl<T, B, D> TensorRefMutAPI for TensorMut<'_, T, B, D>
394where
395    D: DimAPI,
396    B: DeviceAPI<T>,
397    Self: TensorViewMutAPI<T, B, D>,
398{
399}
400
401/* #endregion */
402
#[cfg(test)]
mod test {
    use super::*;

    /// `into_cow` works on a mutable view, an immutable view and an owned
    /// tensor; for the owned tensor the raw data pointer must be unchanged.
    #[test]
    fn test_into_cow() {
        let mut a = arange(3);
        let ptr_before = a.raw().as_ptr();

        // From a mutable view.
        let a_cow = a.view_mut().into_cow();
        println!("{a_cow:?}");

        // From an immutable view.
        let a_cow = a.view().into_cow();
        println!("{a_cow:?}");

        // From the owned tensor itself.
        let a_cow = a.into_cow();
        println!("{a_cow:?}");
        let ptr_after = a_cow.raw().as_ptr();
        assert_eq!(ptr_before, ptr_after);
    }

    /// Timing run (ignored by default): scale every `a.slice(i)` in place
    /// through `force_mut`, serially.
    #[test]
    #[ignore]
    fn test_force_mut() {
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let start = std::time::Instant::now();
            for i in 0..n {
                let row = a.slice(i);
                let mut row_mut = unsafe { row.force_mut() };
                row_mut *= i as f64 / 2048.0;
            }
            println!("Elapsed time {:?}", start.elapsed());
        }
        println!("{a:16.10}");
    }

    /// Timing run (ignored by default, needs the `rayon` feature): same as
    /// `test_force_mut`, with the loop over `i` run through a rayon parallel
    /// iterator.
    #[test]
    #[ignore]
    #[cfg(feature = "rayon")]
    fn test_force_mut_par() {
        use rayon::prelude::*;
        let n = 4096;
        let a = linspace((0.0, 1.0, n * n)).into_shape((n, n));
        for _ in 0..10 {
            let start = std::time::Instant::now();
            (0..n).into_par_iter().for_each(|i| {
                let row = a.slice(i);
                let mut row_mut = unsafe { row.force_mut() };
                row_mut *= i as f64 / 2048.0;
            });
            println!("Elapsed time {:?}", start.elapsed());
        }
        println!("{a:16.10}");
    }
}