1use crate::prelude_dev::*;
4
5pub enum TensorMutable<'a, T, B, D>
16where
17 B: DeviceRawAPI<T>,
18 D: DimAPI,
19{
20 Owned(Tensor<T, B, D>),
21 Mut(TensorMut<'a, T, B, D>),
22 ToBeCloned(TensorMut<'a, T, B, D>, Tensor<T, B, D>),
23}
24
25impl<T, B, D> TensorViewAPI for TensorMutable<'_, T, B, D>
26where
27 D: DimAPI,
28 B: DeviceAPI<T>,
29{
30 type Type = T;
31 type Backend = B;
32 type Dim = D;
33
34 fn view(&self) -> TensorView<'_, T, B, D> {
35 match self {
36 TensorMutable::Owned(t) => t.view(),
37 TensorMutable::Mut(t) => t.view(),
38 TensorMutable::ToBeCloned(_, t) => t.view(),
39 }
40 }
41}
42
43impl<T, B, D> TensorViewMutAPI for TensorMutable<'_, T, B, D>
44where
45 D: DimAPI,
46 B: DeviceAPI<T>,
47{
48 type Type = T;
49 type Backend = B;
50 type Dim = D;
51
52 fn view_mut(&mut self) -> TensorViewMut<'_, T, B, D> {
53 match self {
54 TensorMutable::Owned(t) => t.view_mut(),
55 TensorMutable::Mut(t) => t.view_mut(),
56 TensorMutable::ToBeCloned(_, t) => t.view_mut(),
57 }
58 }
59}
60
61impl<T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorMutable<'_, T, B, D>
62where
63 T: Clone,
64 D: DimAPI,
65 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
66 <B as DeviceRawAPI<T>>::Raw: Clone,
67{
68 fn into_owned(self) -> Tensor<T, B, D> {
69 match self {
70 TensorMutable::Owned(t) => t,
71 TensorMutable::Mut(t) => t.into_owned(),
72 TensorMutable::ToBeCloned(_, t) => t,
73 }
74 }
75}
76
77impl<T, B, D> TensorMutable<'_, T, B, D>
78where
79 T: Clone,
80 D: DimAPI,
81 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
82{
83 pub fn clone_to_mut(self) -> Self {
84 match self {
85 TensorMutable::ToBeCloned(mut arr_view, arr_owned) => {
86 arr_view.assign(&arr_owned);
87 TensorMutable::Mut(arr_view)
88 },
89 _ => self,
90 }
91 }
92
93 pub fn into_reverse_axes(self) -> Self {
94 match self {
95 TensorMutable::Owned(t) => TensorMutable::Owned(t.into_reverse_axes()),
96 TensorMutable::Mut(t) => TensorMutable::Mut(t.into_reverse_axes()),
97 TensorMutable::ToBeCloned(t, t_owned) => {
98 TensorMutable::ToBeCloned(t.into_reverse_axes(), t_owned.into_reverse_axes())
99 },
100 }
101 }
102
103 pub fn f_prefer(&self) -> bool {
104 self.view().f_prefer()
105 }
106
107 pub fn c_prefer(&self) -> bool {
108 self.view().c_prefer()
109 }
110
111 pub fn f_contig(&self) -> bool {
112 self.view().f_contig()
113 }
114
115 pub fn c_contig(&self) -> bool {
116 self.view().c_contig()
117 }
118}
119
120impl<T, B, D> TensorMutable<'_, T, B, D>
121where
122 T: Clone,
123 D: DimAPI,
124 B: DeviceAPI<T> + DeviceRawAPI<MaybeUninit<T>> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
125 <B as DeviceRawAPI<T>>::Raw: Clone,
126{
127 pub fn to_owned(&self) -> Tensor<T, B, D> {
128 match self {
129 TensorMutable::Owned(t) => t.to_owned(),
130 TensorMutable::Mut(t) => t.to_owned(),
131 TensorMutable::ToBeCloned(_, t) => t.to_owned(),
132 }
133 }
134}
135
136impl<'a, T, B, D> TensorMutable<'a, T, B, D>
137where
138 T: Clone,
139 D: DimAPI,
140 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
141{
142 pub fn into_dim_f<D2>(self) -> Result<TensorMutable<'a, T, B, D2>>
143 where
144 D: DimIntoAPI<D2>,
145 D2: DimAPI,
146 {
147 match self {
148 TensorMutable::Owned(t) => Ok(TensorMutable::Owned(t.into_dim_f()?)),
149 TensorMutable::Mut(t) => Ok(TensorMutable::Mut(t.into_dim_f()?)),
150 TensorMutable::ToBeCloned(t, t_owned) => {
151 Ok(TensorMutable::ToBeCloned(t.into_dim_f()?, t_owned.into_dim_f()?))
152 },
153 }
154 }
155
156 pub fn into_dim<D2>(self) -> TensorMutable<'a, T, B, D2>
157 where
158 D: DimIntoAPI<D2>,
159 D2: DimAPI,
160 {
161 self.into_dim_f().rstsr_unwrap()
162 }
163}