blas_array2/util/
util_ndarray.rs

1use crate::util::*;
2use ndarray::prelude::*;
3
4#[derive(Debug)]
5pub enum ArrayOut<'a, F, D>
6where
7    D: Dimension,
8{
9    ViewMut(ArrayViewMut<'a, F, D>),
10    Owned(Array<F, D>),
11    ToBeCloned(ArrayViewMut<'a, F, D>, Array<F, D>),
12}
13
14impl<F, D> ArrayOut<'_, F, D>
15where
16    F: Clone,
17    D: Dimension,
18{
19    pub fn view(&self) -> ArrayView<'_, F, D> {
20        match self {
21            Self::ViewMut(arr) => arr.view(),
22            Self::Owned(arr) => arr.view(),
23            Self::ToBeCloned(_, arr) => arr.view(),
24        }
25    }
26
27    pub fn view_mut(&mut self) -> ArrayViewMut<'_, F, D> {
28        match self {
29            Self::ViewMut(arr) => arr.view_mut(),
30            Self::Owned(arr) => arr.view_mut(),
31            Self::ToBeCloned(_, arr) => arr.view_mut(),
32        }
33    }
34
35    pub fn into_owned(self) -> Array<F, D> {
36        match self {
37            Self::ViewMut(arr) => arr.to_owned(),
38            Self::Owned(arr) => arr,
39            Self::ToBeCloned(mut arr_view, arr_owned) => {
40                arr_view.assign(&arr_owned);
41                arr_owned
42            },
43        }
44    }
45
46    pub fn is_view_mut(&mut self) -> bool {
47        match self {
48            Self::ViewMut(_) => true,
49            Self::Owned(_) => false,
50            Self::ToBeCloned(_, _) => true,
51        }
52    }
53
54    pub fn is_owned(&mut self) -> bool {
55        match self {
56            Self::ViewMut(_) => false,
57            Self::Owned(_) => true,
58            Self::ToBeCloned(_, _) => false,
59        }
60    }
61
62    pub fn clone_to_view_mut(self) -> Self {
63        match self {
64            ArrayOut::ToBeCloned(mut arr_view, arr_owned) => {
65                arr_view.assign(&arr_owned);
66                ArrayOut::ViewMut(arr_view)
67            },
68            _ => self,
69        }
70    }
71
72    pub fn reversed_axes(self) -> Self {
73        match self {
74            ArrayOut::ViewMut(arr) => ArrayOut::ViewMut(arr.reversed_axes()),
75            ArrayOut::Owned(arr) => ArrayOut::Owned(arr.reversed_axes()),
76            ArrayOut::ToBeCloned(mut arr_view, arr_owned) => {
77                arr_view.assign(&arr_owned);
78                ArrayOut::ViewMut(arr_view.reversed_axes())
79            },
80        }
81    }
82
83    pub fn get_data_mut_ptr(&mut self) -> *mut F {
84        match self {
85            Self::ViewMut(arr) => arr.as_mut_ptr(),
86            Self::Owned(arr) => arr.as_mut_ptr(),
87            Self::ToBeCloned(_, arr) => arr.as_mut_ptr(),
88        }
89    }
90}
91
92pub type ArrayOut1<'a, F> = ArrayOut<'a, F, Ix1>;
93pub type ArrayOut2<'a, F> = ArrayOut<'a, F, Ix2>;
94pub type ArrayOut3<'a, F> = ArrayOut<'a, F, Ix3>;
95
96/* #endregion */
97
98/* #region Strides */
99
100#[inline]
101pub fn get_layout_array2<F>(arr: &ArrayView2<F>) -> BLASLayout {
102    // Note that this only shows order of matrix (dimension information)
103    // not c/f-contiguous (memory layout)
104    // So some sequential (both c/f-contiguous) cases may be considered as only row or col major
105    // Examples:
106    // RowMajor     ==>   shape=[1, 4], strides=[0, 1], layout=CFcf (0xf)
107    // ColMajor     ==>   shape=[4, 1], strides=[1, 0], layout=CFcf (0xf)
108    // Sequential   ==>   shape=[1, 1], strides=[0, 0], layout=CFcf (0xf)
109    // NonContig    ==>   shape=[4, 1], strides=[10, 0], layout=Custom (0x0)
110    let (d0, d1) = arr.dim();
111    let [s0, s1] = arr.strides().try_into().unwrap();
112    if d0 == 0 || d1 == 0 {
113        // empty array
114        return BLASLayout::Sequential;
115    } else if d0 == 1 && d1 == 1 {
116        // one element
117        return BLASLayout::Sequential;
118    } else if s1 == 1 {
119        // row-major
120        return BLASRowMajor;
121    } else if s0 == 1 {
122        // col-major
123        return BLASColMajor;
124    } else {
125        // non-contiguous
126        return BLASLayout::NonContiguous;
127    }
128}
129
130/* #endregion */
131
132/* #region flip */
133
134pub(crate) fn flip_trans_fpref<'a, F>(
135    trans: BLASTranspose,
136    view: &'a ArrayView2<F>,
137    view_t: &'a ArrayView2<F>,
138    hermi: bool,
139) -> Result<(BLASTranspose, CowArray<'a, F, Ix2>), BLASError>
140where
141    F: BLASFloat,
142{
143    if view.is_fpref() {
144        return Ok((trans, view.to_col_layout()?));
145    } else {
146        match trans {
147            BLASNoTrans => Ok((
148                trans.flip(hermi)?,
149                match hermi {
150                    false => view_t.to_col_layout()?,
151                    true => {
152                        blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
153                        CowArray::from(view.mapv(F::conj).reversed_axes())
154                    },
155                },
156            )),
157            BLASTrans => Ok((trans.flip(hermi)?, view_t.to_col_layout()?)),
158            BLASConjTrans => Ok((trans.flip(hermi)?, {
159                blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
160                CowArray::from(view.mapv(F::conj).reversed_axes())
161            })),
162            _ => blas_invalid!(trans),
163        }
164    }
165}
166
167pub(crate) fn flip_trans_cpref<'a, F>(
168    trans: BLASTranspose,
169    view: &'a ArrayView2<F>,
170    view_t: &'a ArrayView2<F>,
171    hermi: bool,
172) -> Result<(BLASTranspose, CowArray<'a, F, Ix2>), BLASError>
173where
174    F: BLASFloat,
175{
176    if view.is_cpref() {
177        return Ok((trans, view.to_row_layout()?));
178    } else {
179        match trans {
180            BLASNoTrans => Ok((
181                trans.flip(hermi)?,
182                match hermi {
183                    false => view_t.to_row_layout()?,
184                    true => {
185                        blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
186                        CowArray::from(view_t.mapv(F::conj))
187                    },
188                },
189            )),
190            BLASTrans => Ok((trans.flip(hermi)?, view_t.to_row_layout()?)),
191            BLASConjTrans => Ok((trans.flip(hermi)?, {
192                blas_warn_layout_clone!(view_t, "Perform element-wise conjugate to matrix")?;
193                CowArray::from(view_t.mapv(F::conj))
194            })),
195            _ => blas_invalid!(trans),
196        }
197    }
198}
199
200/* #endregion */
201
202/* #region contiguous preference */
203
204pub(crate) trait LayoutPref {
205    fn is_fpref(&self) -> bool;
206    fn is_cpref(&self) -> bool;
207}
208
209impl<A> LayoutPref for ArrayView2<'_, A> {
210    fn is_fpref(&self) -> bool {
211        get_layout_array2(self).is_fpref()
212    }
213
214    fn is_cpref(&self) -> bool {
215        get_layout_array2(self).is_cpref()
216    }
217}
218
219/* #endregion */
220
221/* #region warn on clone */
222
223pub trait ToLayoutCowArray2<A> {
224    fn to_row_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError>;
225    fn to_col_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError>;
226}
227
228impl<A> ToLayoutCowArray2<A> for ArrayView2<'_, A>
229where
230    A: Clone,
231{
232    fn to_row_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError> {
233        if self.is_cpref() {
234            Ok(CowArray::from(self))
235        } else {
236            blas_warn_layout_clone!(self)?;
237            let owned = self.into_owned();
238            Ok(CowArray::from(owned))
239        }
240    }
241
242    fn to_col_layout(&self) -> Result<CowArray<'_, A, Ix2>, BLASError> {
243        if self.is_fpref() {
244            Ok(CowArray::from(self))
245        } else {
246            blas_warn_layout_clone!(self)?;
247            let owned = self.t().into_owned().reversed_axes();
248            Ok(CowArray::from(owned))
249        }
250    }
251}
252
253pub trait ToLayoutCowArray1<A> {
254    fn to_seq_layout(&self) -> Result<CowArray<'_, A, Ix1>, BLASError>;
255}
256
257impl<A> ToLayoutCowArray1<A> for ArrayView1<'_, A>
258where
259    A: Clone,
260{
261    fn to_seq_layout(&self) -> Result<CowArray<'_, A, Ix1>, BLASError> {
262        let cow = self.as_standard_layout();
263        if cow.is_owned() {
264            blas_warn_layout_clone!(self)?;
265        }
266        Ok(cow)
267    }
268}
269
270/* #endregion */