// rstsr_core/tensor/tensor_mutable.rs

//! Mutable tensor (either owned or mutable reference).

use crate::prelude_dev::*;

5/// Mutable tensor (either owned or mutable reference).
6///
7/// This is mostly used for inplace operations as output.
8///
9/// It is designed similar to `TensorCow`.
10/// However, if inplace operation is not convenient because of the layout
11/// contiguous not fulfilled, a `ToBeCloned` variant is provided, where an owned
12/// tensor with contiguous layout is generated. This is not the same to
13/// `TensorCow`, where it only involves ownership conversion, but not layout
14/// difference between two tensors. So this is defined as a special type.
15pub enum TensorMutable<'a, T, B, D>
16where
17    B: DeviceRawAPI<T>,
18    D: DimAPI,
19{
20    Owned(Tensor<T, B, D>),
21    Mut(TensorMut<'a, T, B, D>),
22    ToBeCloned(TensorMut<'a, T, B, D>, Tensor<T, B, D>),
23}
24
25impl<T, B, D> TensorViewAPI for TensorMutable<'_, T, B, D>
26where
27    D: DimAPI,
28    B: DeviceAPI<T>,
29{
30    type Type = T;
31    type Backend = B;
32    type Dim = D;
33
34    fn view(&self) -> TensorView<'_, T, B, D> {
35        match self {
36            TensorMutable::Owned(t) => t.view(),
37            TensorMutable::Mut(t) => t.view(),
38            TensorMutable::ToBeCloned(_, t) => t.view(),
39        }
40    }
41}
42
43impl<T, B, D> TensorViewMutAPI for TensorMutable<'_, T, B, D>
44where
45    D: DimAPI,
46    B: DeviceAPI<T>,
47{
48    type Type = T;
49    type Backend = B;
50    type Dim = D;
51
52    fn view_mut(&mut self) -> TensorViewMut<'_, T, B, D> {
53        match self {
54            TensorMutable::Owned(t) => t.view_mut(),
55            TensorMutable::Mut(t) => t.view_mut(),
56            TensorMutable::ToBeCloned(_, t) => t.view_mut(),
57        }
58    }
59}
60
61impl<T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorMutable<'_, T, B, D>
62where
63    T: Clone,
64    D: DimAPI,
65    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
66    <B as DeviceRawAPI<T>>::Raw: Clone,
67{
68    fn into_owned(self) -> Tensor<T, B, D> {
69        match self {
70            TensorMutable::Owned(t) => t,
71            TensorMutable::Mut(t) => t.into_owned(),
72            TensorMutable::ToBeCloned(_, t) => t,
73        }
74    }
75}
76
77impl<T, B, D> TensorMutable<'_, T, B, D>
78where
79    T: Clone,
80    D: DimAPI,
81    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
82{
83    pub fn clone_to_mut(self) -> Self {
84        match self {
85            TensorMutable::ToBeCloned(mut arr_view, arr_owned) => {
86                arr_view.assign(&arr_owned);
87                TensorMutable::Mut(arr_view)
88            },
89            _ => self,
90        }
91    }
92
93    pub fn into_reverse_axes(self) -> Self {
94        match self {
95            TensorMutable::Owned(t) => TensorMutable::Owned(t.into_reverse_axes()),
96            TensorMutable::Mut(t) => TensorMutable::Mut(t.into_reverse_axes()),
97            TensorMutable::ToBeCloned(t, t_owned) => {
98                TensorMutable::ToBeCloned(t.into_reverse_axes(), t_owned.into_reverse_axes())
99            },
100        }
101    }
102
103    pub fn f_prefer(&self) -> bool {
104        self.view().f_prefer()
105    }
106
107    pub fn c_prefer(&self) -> bool {
108        self.view().c_prefer()
109    }
110
111    pub fn f_contig(&self) -> bool {
112        self.view().f_contig()
113    }
114
115    pub fn c_contig(&self) -> bool {
116        self.view().c_contig()
117    }
118}
119
120impl<T, B, D> TensorMutable<'_, T, B, D>
121where
122    T: Clone,
123    D: DimAPI,
124    B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
125    <B as DeviceRawAPI<T>>::Raw: Clone,
126{
127    pub fn to_owned(&self) -> Tensor<T, B, D> {
128        match self {
129            TensorMutable::Owned(t) => t.to_owned(),
130            TensorMutable::Mut(t) => t.to_owned(),
131            TensorMutable::ToBeCloned(_, t) => t.to_owned(),
132        }
133    }
134}
135
136impl<'a, T, B, D> TensorMutable<'a, T, B, D>
137where
138    T: Clone,
139    D: DimAPI,
140    B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
141{
142    pub fn into_dim_f<D2>(self) -> Result<TensorMutable<'a, T, B, D2>>
143    where
144        D: DimIntoAPI<D2>,
145        D2: DimAPI,
146    {
147        match self {
148            TensorMutable::Owned(t) => Ok(TensorMutable::Owned(t.into_dim_f()?)),
149            TensorMutable::Mut(t) => Ok(TensorMutable::Mut(t.into_dim_f()?)),
150            TensorMutable::ToBeCloned(t, t_owned) => {
151                Ok(TensorMutable::ToBeCloned(t.into_dim_f()?, t_owned.into_dim_f()?))
152            },
153        }
154    }
155
156    pub fn into_dim<D2>(self) -> TensorMutable<'a, T, B, D2>
157    where
158        D: DimIntoAPI<D2>,
159        D2: DimAPI,
160    {
161        self.into_dim_f().rstsr_unwrap()
162    }
163}