arr_rs/numeric/operations/
create_from.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    numeric::prelude::*,
5    validators::prelude::*,
6};
7
8/// `ArrayTrait` - Array Create functions
9pub trait ArrayCreateFrom<N: Numeric> where Array<N>: Sized + Clone {
10
11    // matrices
12
13    /// Extract a diagonal or construct a diagonal array
14    ///
15    /// # Arguments
16    ///
17    /// * `k` - chosen diagonal. optional, defaults to 0
18    ///
19    /// # Examples
20    ///
21    /// ```
22    /// use arr_rs::prelude::*;
23    ///
24    /// let expected = array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap();
25    /// let arr = Array::diag(&array!(i32, [1, 2, 3, 4]).unwrap(), None).unwrap();
26    /// assert_eq!(expected, arr);
27    ///
28    /// let expected = array!(i32, [1, 2, 3, 4]).unwrap();
29    /// let arr = Array::diag(&array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap(), None).unwrap();
30    /// assert_eq!(expected, arr);
31    ///
32    /// let expected = array!(i32, [0, 0, 0]).unwrap();
33    /// let arr = Array::diag(&array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap(), Some(1)).unwrap();
34    /// assert_eq!(expected, arr);
35    /// ```
36    ///
37    /// # Errors
38    ///
39    /// may returns `ArrayError`
40    fn diag(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;
41
42    /// Construct a diagonal array for flattened input
43    ///
44    /// # Arguments
45    ///
46    /// * `data` - input array
47    /// * `k` - chosen diagonal. optional, defaults to 0
48    ///
49    /// # Examples
50    ///
51    /// ```
52    /// use arr_rs::prelude::*;
53    ///
54    /// let expected = array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap();
55    /// let arr = Array::diagflat(&array!(i32, [1, 2, 3, 4]).unwrap(), None).unwrap();
56    /// assert_eq!(expected, arr);
57    ///
58    /// let expected = array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap();
59    /// let arr = Array::diagflat(&array!(i32, [[1, 2], [3, 4]]).unwrap(), None).unwrap();
60    /// assert_eq!(expected, arr);
61    /// ```
62    ///
63    /// # Errors
64    ///
65    /// may returns `ArrayError`
66    fn diagflat(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;
67
68    /// Return a copy of an array with elements above the k-th diagonal zeroed.
69    /// For arrays with ndim exceeding 2, tril will apply to the final two axes.
70    ///
71    /// # Arguments
72    ///
73    /// * `k` - chosen diagonal. optional, defaults to 0
74    ///
75    /// # Examples
76    ///
77    /// ```
78    /// use arr_rs::prelude::*;
79    ///
80    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 4]).unwrap();
81    /// let expected = array!(i32, [[1, 0, 0, 0], [5, 6, 0, 0]]).unwrap();
82    /// assert_eq!(expected, arr.tril(None).unwrap());
83    ///
84    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 2, 2]).unwrap();
85    /// let expected = array!(i32, [[[1, 0], [3, 4]], [[5, 0], [7, 8]]]).unwrap();
86    /// assert_eq!(expected, arr.tril(None).unwrap());
87    /// ```
88    ///
89    /// # Errors
90    ///
91    /// may returns `ArrayError`
92    fn tril(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;
93
94    /// Return a copy of an array with elements below the k-th diagonal zeroed.
95    /// For arrays with ndim exceeding 2, triu will apply to the final two axes.
96    ///
97    /// # Arguments
98    ///
99    /// * `k` - chosen diagonal. optional, defaults to 0
100    ///
101    /// # Examples
102    ///
103    /// ```
104    /// use arr_rs::prelude::*;
105    ///
106    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 4]).unwrap();
107    /// let expected = array!(i32, [[1, 2, 3, 4], [0, 6, 7, 8]]).unwrap();
108    /// assert_eq!(expected, arr.triu(None).unwrap());
109    ///
110    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 2, 2]).unwrap();
111    /// let expected = array!(i32, [[[1, 2], [0, 4]], [[5, 6], [0, 8]]]).unwrap();
112    /// assert_eq!(expected, arr.triu(None).unwrap());
113    /// ```
114    ///
115    /// # Errors
116    ///
117    /// may returns `ArrayError`
118    fn triu(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;
119
120    /// Generate a Vandermonde matrix
121    ///
122    /// # Arguments
123    ///
124    /// * `n` - number of columns in the output. optional, by default square array is returned
125    /// * `increasing` - order of the powers of the columns. optional, defaults to false
126    /// if true, the powers increase from left to right, if false, they are reversed.
127    ///
128    /// # Examples
129    ///
130    /// ```
131    /// use arr_rs::prelude::*;
132    ///
133    /// let arr = array!(i32, [1, 2, 3, 4]).unwrap();
134    /// let expected = array!(i32, [[1, 1, 1, 1], [8, 4, 2, 1], [27, 9, 3, 1], [64, 16, 4, 1]]).unwrap();
135    /// assert_eq!(expected, arr.vander(None, Some(false)).unwrap());
136    ///
137    /// let arr = array!(i32, [1, 2, 3, 4]).unwrap();
138    /// let expected = array!(i32, [[1, 1, 1, 1], [1, 2, 4, 8], [1, 3, 9, 27], [1, 4, 16, 64]]).unwrap();
139    /// assert_eq!(expected, arr.vander(None, Some(true)).unwrap());
140    /// ```
141    ///
142    /// # Errors
143    ///
144    /// may returns `ArrayError`
145    fn vander(&self, n: Option<usize>, increasing: Option<bool>) -> Result<Self, ArrayError> where Self: Sized;
146}
147
148impl <N: Numeric> ArrayCreateFrom<N> for Array<N> {
149
150    // ==== matrices
151
152    fn diag(&self, k: Option<isize>) -> Result<Self, ArrayError> {
153
154        fn diag_1d<N: Numeric>(data: &Array<N>, k: isize) -> Result<Array<N>, ArrayError> {
155            let size = data.get_shape()?[0];
156            let abs_k = k.unsigned_abs();
157            let new_shape = vec![size + abs_k, size + abs_k];
158            let data_elements = data.get_elements()?;
159            let elements = (0..new_shape[0] * new_shape[1])
160                .map(|idx| {
161                    let (i, j) = (idx / new_shape[1], idx % new_shape[1]);
162                    if k >= 0 && j == i + k.to_usize() {
163                        if i < size { data_elements[i] }
164                        else { N::zero() }
165                    } else if k < 0 && i == j + abs_k {
166                        if j < size { data_elements[j] }
167                        else { N::zero() }
168                    } else {
169                        N::zero()
170                    }
171                })
172                .collect();
173
174            Array::new(elements, new_shape)
175        }
176
177        fn diag_2d<N: Numeric>(data: &Array<N>, k: isize) -> Result<Array<N>, ArrayError> {
178            let rows = data.get_shape()?[0];
179            let cols = data.get_shape()?[1];
180            let (start_row, start_col) =
181                if k >= 0 { (0, k.to_usize()) }
182                else { ((-k).to_usize(), 0) };
183
184            let data_elements = data.get_elements()?;
185            let elements = (start_row..rows)
186                .zip(start_col..cols)
187                .map(|(i, j)| data_elements[i * cols + j])
188                .collect::<Vec<N>>();
189
190            Array::new(elements.clone(), vec![elements.len()])
191        }
192
193        self.is_dim_supported(&[1, 2])?;
194
195        let k = k.unwrap_or(0);
196        if self.ndim()? == 1 { diag_1d(self, k) }
197        else { diag_2d(self, k) }
198    }
199
200    fn diagflat(&self, k: Option<isize>) -> Result<Self, ArrayError> {
201        self.ravel()?.diag(k)
202    }
203
204    fn tril(&self, k: Option<isize>) -> Result<Self, ArrayError> {
205        let k = k.unwrap_or(0);
206        self.apply_triangular(k, |j, i, k| j > i + k)
207    }
208
209    fn triu(&self, k: Option<isize>) -> Result<Self, ArrayError> {
210        let k = k.unwrap_or(0);
211        self.apply_triangular(k, |j, i, k| j < i + k)
212    }
213
214    fn vander(&self, n: Option<usize>, increasing: Option<bool>) -> Result<Self, ArrayError> {
215        self.is_dim_supported(&[1])?;
216
217        let size = self.shape[0];
218        let increasing = increasing.unwrap_or(false);
219        let n_columns = n.unwrap_or(size);
220        let mut elements = Vec::with_capacity(size * n_columns);
221
222        for item in self {
223            for i in 0..n_columns {
224                let power = if increasing { i } else { n_columns - i - 1 }.to_f64();
225                elements.push(N::from(item.to_f64().powf(power)));
226            }
227        }
228
229        Self::new(elements, vec![size, n_columns])
230    }
231}
232
233impl <N: Numeric> Array<N> {
234
235    fn apply_triangular<F>(&self, k: isize, compare: F) -> Result<Self, ArrayError>
236        where F: Fn(isize, isize, isize) -> bool {
237        let last_dim = self.shape.len() - 1;
238        let second_last_dim = self.shape.len() - 2;
239        let chunk_size = self.shape[last_dim] * self.shape[second_last_dim];
240
241        let elements = self.elements
242            .chunks(chunk_size)
243            .flat_map(|chunk| {
244                chunk
245                    .iter()
246                    .enumerate()
247                    .map(|(idx, &value)| {
248                        let i = (idx / self.shape[last_dim]) % self.shape[second_last_dim];
249                        let j = idx % self.shape[last_dim];
250                        if compare(j.to_isize(), i.to_isize(), k) { N::zero() } else { value }
251                    })
252            })
253            .collect();
254
255        Self::new(elements, self.shape.clone())
256    }
257}