1use crate::util::*;
2use ndarray::prelude::*;
3
4#[derive(Debug)]
5pub enum ArrayOut<'a, F, D>
6where
7 D: Dimension,
8{
9 ViewMut(ArrayViewMut<'a, F, D>),
10 Owned(Array<F, D>),
11 ToBeCloned(ArrayViewMut<'a, F, D>, Array<F, D>),
12}
13
14impl<F, D> ArrayOut<'_, F, D>
15where
16 F: Clone,
17 D: Dimension,
18{
19 pub fn view(&self) -> ArrayView<'_, F, D> {
20 match self {
21 Self::ViewMut(arr) => arr.view(),
22 Self::Owned(arr) => arr.view(),
23 Self::ToBeCloned(_, arr) => arr.view(),
24 }
25 }
26
27 pub fn view_mut(&mut self) -> ArrayViewMut<'_, F, D> {
28 match self {
29 Self::ViewMut(arr) => arr.view_mut(),
30 Self::Owned(arr) => arr.view_mut(),
31 Self::ToBeCloned(_, arr) => arr.view_mut(),
32 }
33 }
34
35 pub fn into_owned(self) -> Array<F, D> {
36 match self {
37 Self::ViewMut(arr) => arr.to_owned(),
38 Self::Owned(arr) => arr,
39 Self::ToBeCloned(mut arr_view, arr_owned) => {
40 arr_view.assign(&arr_owned);
41 arr_owned
42 },
43 }
44 }
45
46 pub fn is_view_mut(&mut self) -> bool {
47 match self {
48 Self::ViewMut(_) => true,
49 Self::Owned(_) => false,
50 Self::ToBeCloned(_, _) => true,
51 }
52 }
53
54 pub fn is_owned(&mut self) -> bool {
55 match self {
56 Self::ViewMut(_) => false,
57 Self::Owned(_) => true,
58 Self::ToBeCloned(_, _) => false,
59 }
60 }
61
62 pub fn clone_to_view_mut(self) -> Self {
63 match self {
64 ArrayOut::ToBeCloned(mut arr_view, arr_owned) => {
65 arr_view.assign(&arr_owned);
66 ArrayOut::ViewMut(arr_view)
67 },
68 _ => self,
69 }
70 }
71
72 pub fn reversed_axes(self) -> Self {
73 match self {
74 ArrayOut::ViewMut(arr) => ArrayOut::ViewMut(arr.reversed_axes()),
75 ArrayOut::Owned(arr) => ArrayOut::Owned(arr.reversed_axes()),
76 ArrayOut::ToBeCloned(mut arr_view, arr_owned) => {
77 arr_view.assign(&arr_owned);
78 ArrayOut::ViewMut(arr_view.reversed_axes())
79 },
80 }
81 }
82
83 pub fn get_data_mut_ptr(&mut self) -> *mut F {
84 match self {
85 Self::ViewMut(arr) => arr.as_mut_ptr(),
86 Self::Owned(arr) => arr.as_mut_ptr(),
87 Self::ToBeCloned(_, arr) => arr.as_mut_ptr(),
88 }
89 }
90}
91
92pub type ArrayOut1<'a, F> = ArrayOut<'a, F, Ix1>;
93pub type ArrayOut2<'a, F> = ArrayOut<'a, F, Ix2>;
94pub type ArrayOut3<'a, F> = ArrayOut<'a, F, Ix3>;
95
96#[inline]
101pub fn get_layout_array2<F>(arr: &ArrayView2<F>) -> BLASLayout {
102 let (d0, d1) = arr.dim();
111 let [s0, s1] = arr.strides().try_into().unwrap();
112 if d0 == 0 || d1 == 0 {
113 return BLASLayout::Sequential;
115 } else if d0 == 1 && d1 == 1 {
116 return BLASLayout::Sequential;
118 } else if s1 == 1 {
119 return BLASRowMajor;
121 } else if s0 == 1 {
122 return BLASColMajor;
124 } else {
125 return BLASLayout::NonContiguous;
127 }
128}
129
130pub(crate) fn flip_trans_fpref<'a, F>(
135 trans: BLASTranspose,
136 view: &'a ArrayView2<F>,
137 view_t: &'a ArrayView2<F>,
138 hermi: bool,
139) -> Result<(BLASTranspose, CowArray<'a, F, Ix2>), BLASError>
140where
141 F: BLASFloat,
142{
143 if view.is_fpref() {
144 return Ok((trans, view.to_col_layout()?));
145 } else {
146 match trans {
147 BLASNoTrans => Ok((
148 trans.flip(hermi)?,
149 match hermi {
150 false => view_t.to_col_layout()?,
151 true => {
152 blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
153 CowArray::from(view.mapv(F::conj).reversed_axes())
154 },
155 },
156 )),
157 BLASTrans => Ok((trans.flip(hermi)?, view_t.to_col_layout()?)),
158 BLASConjTrans => Ok((trans.flip(hermi)?, {
159 blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
160 CowArray::from(view.mapv(F::conj).reversed_axes())
161 })),
162 _ => blas_invalid!(trans),
163 }
164 }
165}
166
167pub(crate) fn flip_trans_cpref<'a, F>(
168 trans: BLASTranspose,
169 view: &'a ArrayView2<F>,
170 view_t: &'a ArrayView2<F>,
171 hermi: bool,
172) -> Result<(BLASTranspose, CowArray<'a, F, Ix2>), BLASError>
173where
174 F: BLASFloat,
175{
176 if view.is_cpref() {
177 return Ok((trans, view.to_row_layout()?));
178 } else {
179 match trans {
180 BLASNoTrans => Ok((
181 trans.flip(hermi)?,
182 match hermi {
183 false => view_t.to_row_layout()?,
184 true => {
185 blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
186 CowArray::from(view_t.mapv(F::conj))
187 },
188 },
189 )),
190 BLASTrans => Ok((trans.flip(hermi)?, view_t.to_row_layout()?)),
191 BLASConjTrans => Ok((trans.flip(hermi)?, {
192 blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
193 CowArray::from(view_t.mapv(F::conj))
194 })),
195 _ => blas_invalid!(trans),
196 }
197 }
198}
199
200pub(crate) trait LayoutPref {
205 fn is_fpref(&self) -> bool;
206 fn is_cpref(&self) -> bool;
207}
208
209impl<A> LayoutPref for ArrayView2<'_, A> {
210 fn is_fpref(&self) -> bool {
211 get_layout_array2(self).is_fpref()
212 }
213
214 fn is_cpref(&self) -> bool {
215 get_layout_array2(self).is_cpref()
216 }
217}
218
219pub trait ToLayoutCowArray2<A> {
224 fn to_row_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError>;
225 fn to_col_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError>;
226}
227
228impl<A> ToLayoutCowArray2<A> for ArrayView2<'_, A>
229where
230 A: Clone,
231{
232 fn to_row_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError> {
233 if self.is_cpref() {
234 Ok(CowArray::from(self))
235 } else {
236 blas_warn_layout_clone!(self)?;
237 let owned = self.into_owned();
238 Ok(CowArray::from(owned))
239 }
240 }
241
242 fn to_col_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError> {
243 if self.is_fpref() {
244 Ok(CowArray::from(self))
245 } else {
246 blas_warn_layout_clone!(self)?;
247 let owned = self.t().into_owned().reversed_axes();
248 Ok(CowArray::from(owned))
249 }
250 }
251}
252
253pub trait ToLayoutCowArray1<A> {
254 fn to_seq_layout(&self) -> Result<CowArray<'_, A, Ix1>, BLASError>;
255}
256
257impl<A> ToLayoutCowArray1<A> for ArrayView1<'_, A>
258where
259 A: Clone,
260{
261 fn to_seq_layout(&self) -> Result<CowArray<'_, A, Ix1>, BLASError> {
262 let cow = self.as_standard_layout();
263 if cow.is_owned() {
264 blas_warn_layout_clone!(self)?;
265 }
266 Ok(cow)
267 }
268}
269
270