1use crate::prelude_dev::*;
4
/// A tensor that can be mutated: either owned outright, a mutable view
/// borrowed for `'a`, or a mutable view paired with an owned tensor.
///
/// The accessors implemented for this type (`view`, `view_mut`, `into_owned`,
/// `to_owned`) operate on the owned tensor whenever one is present, i.e. for
/// both `Owned` and the second field of `ToBeCloned`.
pub enum TensorMutable<'a, T, B, D>
where
    B: DeviceRawAPI<T>,
    D: DimAPI,
{
    /// Owned tensor data.
    Owned(Tensor<T, B, D>),
    /// Mutable view borrowed for lifetime `'a`.
    Mut(TensorMut<'a, T, B, D>),
    /// A mutable view (first field) together with an owned tensor (second
    /// field). Reads/writes go through the owned tensor; `clone_to_mut`
    /// assigns the owned tensor into the view and turns this into `Mut(view)`.
    ToBeCloned(TensorMut<'a, T, B, D>, Tensor<T, B, D>),
}
24
25impl<T, B, D> TensorViewAPI<T, B, D> for TensorMutable<'_, T, B, D>
26where
27 D: DimAPI,
28 B: DeviceAPI<T>,
29{
30 fn view(&self) -> TensorView<'_, T, B, D> {
31 match self {
32 TensorMutable::Owned(t) => t.view(),
33 TensorMutable::Mut(t) => t.view(),
34 TensorMutable::ToBeCloned(_, t) => t.view(),
35 }
36 }
37}
38
39impl<T, B, D> TensorViewMutAPI<T, B, D> for TensorMutable<'_, T, B, D>
40where
41 D: DimAPI,
42 B: DeviceAPI<T>,
43{
44 fn view_mut(&mut self) -> TensorViewMut<'_, T, B, D> {
45 match self {
46 TensorMutable::Owned(t) => t.view_mut(),
47 TensorMutable::Mut(t) => t.view_mut(),
48 TensorMutable::ToBeCloned(_, t) => t.view_mut(),
49 }
50 }
51}
52
53impl<T, B, D> TensorIntoOwnedAPI<T, B, D> for TensorMutable<'_, T, B, D>
54where
55 T: Clone,
56 D: DimAPI,
57 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
58 B::Raw: Clone,
59{
60 fn into_owned(self) -> Tensor<T, B, D> {
61 match self {
62 TensorMutable::Owned(t) => t,
63 TensorMutable::Mut(t) => t.into_owned(),
64 TensorMutable::ToBeCloned(_, t) => t,
65 }
66 }
67}
68
69impl<T, B, D> TensorMutable<'_, T, B, D>
70where
71 T: Clone,
72 D: DimAPI,
73 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
74{
75 pub fn clone_to_mut(self) -> Self {
76 match self {
77 TensorMutable::ToBeCloned(mut arr_view, arr_owned) => {
78 arr_view.assign(&arr_owned);
79 TensorMutable::Mut(arr_view)
80 },
81 _ => self,
82 }
83 }
84
85 pub fn into_reverse_axes(self) -> Self {
86 match self {
87 TensorMutable::Owned(t) => TensorMutable::Owned(t.into_reverse_axes()),
88 TensorMutable::Mut(t) => TensorMutable::Mut(t.into_reverse_axes()),
89 TensorMutable::ToBeCloned(t, t_owned) => {
90 TensorMutable::ToBeCloned(t.into_reverse_axes(), t_owned.into_reverse_axes())
91 },
92 }
93 }
94
95 pub fn f_prefer(&self) -> bool {
96 self.view().f_prefer()
97 }
98
99 pub fn c_prefer(&self) -> bool {
100 self.view().c_prefer()
101 }
102
103 pub fn f_contig(&self) -> bool {
104 self.view().f_contig()
105 }
106
107 pub fn c_contig(&self) -> bool {
108 self.view().c_contig()
109 }
110}
111
112impl<T, B, D> TensorMutable<'_, T, B, D>
113where
114 T: Clone,
115 D: DimAPI,
116 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
117 B::Raw: Clone,
118{
119 pub fn to_owned(&self) -> Tensor<T, B, D> {
120 match self {
121 TensorMutable::Owned(t) => t.to_owned(),
122 TensorMutable::Mut(t) => t.to_owned(),
123 TensorMutable::ToBeCloned(_, t) => t.to_owned(),
124 }
125 }
126}
127
128impl<'a, T, B, D> TensorMutable<'a, T, B, D>
129where
130 T: Clone,
131 D: DimAPI,
132 B: DeviceAPI<T> + DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
133{
134 pub fn into_dim_f<D2>(self) -> Result<TensorMutable<'a, T, B, D2>>
135 where
136 D: DimIntoAPI<D2>,
137 D2: DimAPI,
138 {
139 match self {
140 TensorMutable::Owned(t) => Ok(TensorMutable::Owned(t.into_dim_f()?)),
141 TensorMutable::Mut(t) => Ok(TensorMutable::Mut(t.into_dim_f()?)),
142 TensorMutable::ToBeCloned(t, t_owned) => {
143 Ok(TensorMutable::ToBeCloned(t.into_dim_f()?, t_owned.into_dim_f()?))
144 },
145 }
146 }
147
148 pub fn into_dim<D2>(self) -> TensorMutable<'a, T, B, D2>
149 where
150 D: DimIntoAPI<D2>,
151 D2: DimAPI,
152 {
153 self.into_dim_f().unwrap()
154 }
155}