arr_rs/numeric/operations/create_from.rs
1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 numeric::prelude::*,
5 validators::prelude::*,
6};
7
/// `ArrayTrait` - Array Create functions
///
/// Functions that derive a new array from an existing one:
/// diagonal (`diag`, `diagflat`), triangular (`tril`, `triu`) and Vandermonde (`vander`) constructions.
pub trait ArrayCreateFrom<N: Numeric> where Array<N>: Sized + Clone {

    // matrices

    /// Extract a diagonal or construct a diagonal array
    ///
    /// As the examples below show, a 1-D input is placed on the `k`-th diagonal of a new
    /// square 2-D array, while for a 2-D input the `k`-th diagonal is returned as a 1-D array.
    ///
    /// # Arguments
    ///
    /// * `k` - chosen diagonal. optional, defaults to 0.
    ///   positive values address diagonals above the main one, negative values below it
    ///
    /// # Examples
    ///
    /// ```
    /// use arr_rs::prelude::*;
    ///
    /// let expected = array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap();
    /// let arr = Array::diag(&array!(i32, [1, 2, 3, 4]).unwrap(), None).unwrap();
    /// assert_eq!(expected, arr);
    ///
    /// let expected = array!(i32, [1, 2, 3, 4]).unwrap();
    /// let arr = Array::diag(&array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap(), None).unwrap();
    /// assert_eq!(expected, arr);
    ///
    /// let expected = array!(i32, [0, 0, 0]).unwrap();
    /// let arr = Array::diag(&array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap(), Some(1)).unwrap();
    /// assert_eq!(expected, arr);
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an `ArrayError` if the operation fails; the `Array<N>` implementation
    /// validates that the input is 1-D or 2-D before doing any work.
    fn diag(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;

    /// Construct a diagonal array for flattened input
    ///
    /// The input array (`self`) is flattened first, then used as the `k`-th diagonal
    /// of the resulting 2-D array.
    ///
    /// # Arguments
    ///
    /// * `k` - chosen diagonal. optional, defaults to 0
    ///
    /// # Examples
    ///
    /// ```
    /// use arr_rs::prelude::*;
    ///
    /// let expected = array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap();
    /// let arr = Array::diagflat(&array!(i32, [1, 2, 3, 4]).unwrap(), None).unwrap();
    /// assert_eq!(expected, arr);
    ///
    /// let expected = array!(i32, [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]]).unwrap();
    /// let arr = Array::diagflat(&array!(i32, [[1, 2], [3, 4]]).unwrap(), None).unwrap();
    /// assert_eq!(expected, arr);
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an `ArrayError` if flattening the input or building the diagonal array fails.
    fn diagflat(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;

    /// Return a copy of an array with elements above the k-th diagonal zeroed.
    /// For arrays with ndim exceeding 2, tril will apply to the final two axes.
    ///
    /// # Arguments
    ///
    /// * `k` - chosen diagonal. optional, defaults to 0
    ///
    /// # Examples
    ///
    /// ```
    /// use arr_rs::prelude::*;
    ///
    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 4]).unwrap();
    /// let expected = array!(i32, [[1, 0, 0, 0], [5, 6, 0, 0]]).unwrap();
    /// assert_eq!(expected, arr.tril(None).unwrap());
    ///
    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 2, 2]).unwrap();
    /// let expected = array!(i32, [[[1, 0], [3, 4]], [[5, 0], [7, 8]]]).unwrap();
    /// assert_eq!(expected, arr.tril(None).unwrap());
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an `ArrayError` if the resulting array cannot be constructed.
    fn tril(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;

    /// Return a copy of an array with elements below the k-th diagonal zeroed.
    /// For arrays with ndim exceeding 2, triu will apply to the final two axes.
    ///
    /// # Arguments
    ///
    /// * `k` - chosen diagonal. optional, defaults to 0
    ///
    /// # Examples
    ///
    /// ```
    /// use arr_rs::prelude::*;
    ///
    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 4]).unwrap();
    /// let expected = array!(i32, [[1, 2, 3, 4], [0, 6, 7, 8]]).unwrap();
    /// assert_eq!(expected, arr.triu(None).unwrap());
    ///
    /// let arr = array_arange!(i32, 1, 8).reshape(&[2, 2, 2]).unwrap();
    /// let expected = array!(i32, [[[1, 2], [0, 4]], [[5, 6], [0, 8]]]).unwrap();
    /// assert_eq!(expected, arr.triu(None).unwrap());
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an `ArrayError` if the resulting array cannot be constructed.
    fn triu(&self, k: Option<isize>) -> Result<Array<N>, ArrayError>;

    /// Generate a Vandermonde matrix
    ///
    /// Each row of the output holds the powers of one element of the 1-D input.
    ///
    /// # Arguments
    ///
    /// * `n` - number of columns in the output. optional, by default square array is returned
    /// * `increasing` - order of the powers of the columns. optional, defaults to false.
    ///   if true, the powers increase from left to right, if false, they are reversed.
    ///
    /// # Examples
    ///
    /// ```
    /// use arr_rs::prelude::*;
    ///
    /// let arr = array!(i32, [1, 2, 3, 4]).unwrap();
    /// let expected = array!(i32, [[1, 1, 1, 1], [8, 4, 2, 1], [27, 9, 3, 1], [64, 16, 4, 1]]).unwrap();
    /// assert_eq!(expected, arr.vander(None, Some(false)).unwrap());
    ///
    /// let arr = array!(i32, [1, 2, 3, 4]).unwrap();
    /// let expected = array!(i32, [[1, 1, 1, 1], [1, 2, 4, 8], [1, 3, 9, 27], [1, 4, 16, 64]]).unwrap();
    /// assert_eq!(expected, arr.vander(None, Some(true)).unwrap());
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an `ArrayError` if the operation fails; the `Array<N>` implementation
    /// validates that the input is 1-D before doing any work.
    fn vander(&self, n: Option<usize>, increasing: Option<bool>) -> Result<Self, ArrayError> where Self: Sized;
}
147
148impl <N: Numeric> ArrayCreateFrom<N> for Array<N> {
149
150 // ==== matrices
151
152 fn diag(&self, k: Option<isize>) -> Result<Self, ArrayError> {
153
154 fn diag_1d<N: Numeric>(data: &Array<N>, k: isize) -> Result<Array<N>, ArrayError> {
155 let size = data.get_shape()?[0];
156 let abs_k = k.unsigned_abs();
157 let new_shape = vec![size + abs_k, size + abs_k];
158 let data_elements = data.get_elements()?;
159 let elements = (0..new_shape[0] * new_shape[1])
160 .map(|idx| {
161 let (i, j) = (idx / new_shape[1], idx % new_shape[1]);
162 if k >= 0 && j == i + k.to_usize() {
163 if i < size { data_elements[i] }
164 else { N::zero() }
165 } else if k < 0 && i == j + abs_k {
166 if j < size { data_elements[j] }
167 else { N::zero() }
168 } else {
169 N::zero()
170 }
171 })
172 .collect();
173
174 Array::new(elements, new_shape)
175 }
176
177 fn diag_2d<N: Numeric>(data: &Array<N>, k: isize) -> Result<Array<N>, ArrayError> {
178 let rows = data.get_shape()?[0];
179 let cols = data.get_shape()?[1];
180 let (start_row, start_col) =
181 if k >= 0 { (0, k.to_usize()) }
182 else { ((-k).to_usize(), 0) };
183
184 let data_elements = data.get_elements()?;
185 let elements = (start_row..rows)
186 .zip(start_col..cols)
187 .map(|(i, j)| data_elements[i * cols + j])
188 .collect::<Vec<N>>();
189
190 Array::new(elements.clone(), vec![elements.len()])
191 }
192
193 self.is_dim_supported(&[1, 2])?;
194
195 let k = k.unwrap_or(0);
196 if self.ndim()? == 1 { diag_1d(self, k) }
197 else { diag_2d(self, k) }
198 }
199
200 fn diagflat(&self, k: Option<isize>) -> Result<Self, ArrayError> {
201 self.ravel()?.diag(k)
202 }
203
204 fn tril(&self, k: Option<isize>) -> Result<Self, ArrayError> {
205 let k = k.unwrap_or(0);
206 self.apply_triangular(k, |j, i, k| j > i + k)
207 }
208
209 fn triu(&self, k: Option<isize>) -> Result<Self, ArrayError> {
210 let k = k.unwrap_or(0);
211 self.apply_triangular(k, |j, i, k| j < i + k)
212 }
213
214 fn vander(&self, n: Option<usize>, increasing: Option<bool>) -> Result<Self, ArrayError> {
215 self.is_dim_supported(&[1])?;
216
217 let size = self.shape[0];
218 let increasing = increasing.unwrap_or(false);
219 let n_columns = n.unwrap_or(size);
220 let mut elements = Vec::with_capacity(size * n_columns);
221
222 for item in self {
223 for i in 0..n_columns {
224 let power = if increasing { i } else { n_columns - i - 1 }.to_f64();
225 elements.push(N::from(item.to_f64().powf(power)));
226 }
227 }
228
229 Self::new(elements, vec![size, n_columns])
230 }
231}
232
233impl <N: Numeric> Array<N> {
234
235 fn apply_triangular<F>(&self, k: isize, compare: F) -> Result<Self, ArrayError>
236 where F: Fn(isize, isize, isize) -> bool {
237 let last_dim = self.shape.len() - 1;
238 let second_last_dim = self.shape.len() - 2;
239 let chunk_size = self.shape[last_dim] * self.shape[second_last_dim];
240
241 let elements = self.elements
242 .chunks(chunk_size)
243 .flat_map(|chunk| {
244 chunk
245 .iter()
246 .enumerate()
247 .map(|(idx, &value)| {
248 let i = (idx / self.shape[last_dim]) % self.shape[second_last_dim];
249 let j = idx % self.shape[last_dim];
250 if compare(j.to_isize(), i.to_isize(), k) { N::zero() } else { value }
251 })
252 })
253 .collect();
254
255 Self::new(elements, self.shape.clone())
256 }
257}