rstsr_core/tensor/
tensor_mutable.rs

1//! Mutable tensor (either owned or mutable reference).
2
3use crate::prelude_dev::*;
4
5/// Mutable tensor (either owned or mutable reference).
6///
7/// This is mostly used for inplace operations as output.
8///
9/// It is designed similar to `TensorCow`.
10/// However, if inplace operation is not convenient because of the layout
11/// contiguous not fulfilled, a `ToBeCloned` variant is provided, where an owned
12/// tensor with contiguous layout is generated. This is not the same to
13/// `TensorCow`, where it only involves ownership conversion, but not layout
14/// difference between two tensors. So this is defined as a special type.
15pub enum TensorMutable<'a, T, B, D>
16where
17    B: DeviceRawAPI<T>,
18    D: DimAPI,
19{
20    Owned(Tensor<T, B, D>),
21    Mut(TensorMut<'a, T, B, D>),
22    ToBeCloned(TensorMut<'a, T, B, D>, Tensor<T, B, D>),
23}
24
25impl<T, B, D> TensorViewAPI<T, B, D> for TensorMutable<'_, T, B, D>
26where
27    D: DimAPI,
28    B: DeviceAPI<T>,
29{
30    fn view(&self) -> TensorView<'_, T, B, D> {
31        match self {
32            TensorMutable::Owned(t) => t.view(),
33            TensorMutable::Mut(t) => t.view(),
34            TensorMutable::ToBeCloned(_, t) => t.view(),
35        }
36    }
37}
38
39impl<T, B, D> TensorViewMutAPI<T, B, D> for TensorMutable<'_, T, B, D>
40where
41    D: DimAPI,
42    B: DeviceAPI<T>,
43{
44    fn view_mut(&mut self) -> TensorViewMut<'_, T, B, D> {
45        match self {
46            TensorMutable::Owned(t) => t.view_mut(),
47            TensorMutable::Mut(t) => t.view_mut(),
48            TensorMutable::ToBeCloned(_, t) => t.view_mut(),
49        }
50    }
51}
52
53impl<T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorMutable<'_, T, B, D>
54where
55    T: Clone,
56    D: DimAPI,
57    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
58    B::Raw: Clone,
59{
60    fn into_owned(self) -> Tensor<T, B, D> {
61        match self {
62            TensorMutable::Owned(t) => t,
63            TensorMutable::Mut(t) => t.into_owned(),
64            TensorMutable::ToBeCloned(_, t) => t,
65        }
66    }
67}
68
69impl<T, B, D> TensorMutable<'_, T, B, D>
70where
71    T: Clone,
72    D: DimAPI,
73    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
74{
75    pub fn clone_to_mut(self) -> Self {
76        match self {
77            TensorMutable::ToBeCloned(mut arr_view, arr_owned) => {
78                arr_view.assign(&arr_owned);
79                TensorMutable::Mut(arr_view)
80            },
81            _ => self,
82        }
83    }
84
85    pub fn into_reverse_axes(self) -> Self {
86        match self {
87            TensorMutable::Owned(t) => TensorMutable::Owned(t.into_reverse_axes()),
88            TensorMutable::Mut(t) => TensorMutable::Mut(t.into_reverse_axes()),
89            TensorMutable::ToBeCloned(t, t_owned) => {
90                TensorMutable::ToBeCloned(t.into_reverse_axes(), t_owned.into_reverse_axes())
91            },
92        }
93    }
94
95    pub fn f_prefer(&self) -> bool {
96        self.view().f_prefer()
97    }
98
99    pub fn c_prefer(&self) -> bool {
100        self.view().c_prefer()
101    }
102
103    pub fn f_contig(&self) -> bool {
104        self.view().f_contig()
105    }
106
107    pub fn c_contig(&self) -> bool {
108        self.view().c_contig()
109    }
110}
111
112impl<T, B, D> TensorMutable<'_, T, B, D>
113where
114    T: Clone,
115    D: DimAPI,
116    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
117    B::Raw: Clone,
118{
119    pub fn to_owned(&self) -> Tensor<T, B, D> {
120        match self {
121            TensorMutable::Owned(t) => t.to_owned(),
122            TensorMutable::Mut(t) => t.to_owned(),
123            TensorMutable::ToBeCloned(_, t) => t.to_owned(),
124        }
125    }
126}
127
128impl<'a, T, B, D> TensorMutable<'a, T, B, D>
129where
130    T: Clone,
131    D: DimAPI,
132    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
133{
134    pub fn into_dim_f<D2>(self) -> Result<TensorMutable<'a, T, B, D2>>
135    where
136        D: DimIntoAPI<D2>,
137        D2: DimAPI,
138    {
139        match self {
140            TensorMutable::Owned(t) => Ok(TensorMutable::Owned(t.into_dim_f()?)),
141            TensorMutable::Mut(t) => Ok(TensorMutable::Mut(t.into_dim_f()?)),
142            TensorMutable::ToBeCloned(t, t_owned) => {
143                Ok(TensorMutable::ToBeCloned(t.into_dim_f()?, t_owned.into_dim_f()?))
144            },
145        }
146    }
147
148    pub fn into_dim<D2>(self) -> TensorMutable<'a, T, B, D2>
149    where
150        D: DimIntoAPI<D2>,
151        D2: DimAPI,
152    {
153        self.into_dim_f().unwrap()
154    }
155}