arr_rs/core/operations/
axis.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    extensions::prelude::*,
5    validators::prelude::*,
6};
7use crate::prelude::Numeric;
8
9/// `ArrayTrait` - Array Axis functions
10pub trait ArrayAxis<T: ArrayElement> where Array<T>: Sized + Clone {
11
12    /// Applies given function along an axis
13    ///
14    /// # Arguments
15    ///
16    /// * `axis` - axis along which the function will be applied
17    /// * `f` - function to apply
18    ///
19    /// # Errors
20    ///
21    /// may returns `ArrayError`
22    fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, f: F) -> Result<Array<S>, ArrayError>
23        where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError>;
24
25    /// Returns an array with axes transposed
26    ///
27    /// # Arguments
28    ///
29    /// * `axes` - if defined, it's a list of axes to be included in transposition
30    ///
31    /// # Examples
32    ///
33    /// ```
34    /// use arr_rs::prelude::*;
35    ///
36    /// let arr = Array::new(vec![1,2,3,4,5,6,7,8], vec![2, 4]).unwrap();
37    /// assert_eq!(array!(i32, [[1, 5], [2, 6], [3, 7], [4, 8]]), arr.transpose(None));
38    ///
39    /// let arr = Array::new(vec![1,2,3,4,5,6,7,8], vec![4, 2]).unwrap();
40    /// assert_eq!(array!(i32, [[1, 3, 5, 7], [2, 4, 6, 8]]), arr.transpose(None));
41    ///
42    /// let arr = Array::new(vec![1,2,3,4,5,6,7,8], vec![4, 2]).unwrap();
43    /// assert_eq!(array!(i32, [[1, 3, 5, 7], [2, 4, 6, 8]]), arr.transpose(Some(vec![1, 0])));
44    /// ```
45    ///
46    /// # Errors
47    ///
48    /// may returns `ArrayError`
49    fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError>;
50
51    /// Move axes of an array to new positions
52    ///
53    /// # Arguments
54    ///
55    /// * `source` - original positions of the axes to move. must be unique
56    /// * `destination` - destination positions for each of the original axes. must be unique
57    ///
58    /// # Examples
59    ///
60    /// ```
61    /// use arr_rs::prelude::*;
62    ///
63    /// let arr = Array::<i32>::zeros(vec![3, 4, 5]);
64    /// assert_eq!(vec![4, 5, 3], arr.moveaxis(vec![0], vec![2]).get_shape().unwrap());
65    /// assert_eq!(vec![5, 3, 4], arr.moveaxis(vec![2], vec![0]).get_shape().unwrap());
66    /// ```
67    ///
68    /// # Errors
69    ///
70    /// may returns `ArrayError`
71    fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Array<T>, ArrayError>;
72
73    /// Roll the specified axis backwards, until it lies in a given position
74    ///
75    /// # Arguments
76    ///
77    /// * `axis` - the axis to be rolled
78    /// * `start` - start position. optional, defaults to 0
79    ///
80    /// # Examples
81    ///
82    /// ```
83    /// use arr_rs::prelude::*;
84    ///
85    /// let arr = Array::<i32>::zeros(vec![3, 4, 5]);
86    /// assert_eq!(vec![4, 3, 5], arr.rollaxis(1, None).get_shape().unwrap());
87    /// assert_eq!(vec![3, 5, 4], arr.rollaxis(2, Some(1)).get_shape().unwrap());
88    /// ```
89    ///
90    /// # Errors
91    ///
92    /// may returns `ArrayError`
93    fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Array<T>, ArrayError>;
94
95    /// Interchange two axes of an array
96    ///
97    /// # Arguments
98    ///
99    /// * `axis_1` - first axis
100    /// * `axis_1` - second axis
101    ///
102    /// # Examples
103    ///
104    /// ```
105    /// use arr_rs::prelude::*;
106    ///
107    /// let arr = Array::<i32>::zeros(vec![3, 4, 5]);
108    /// assert_eq!(vec![5, 4, 3], arr.swapaxes(0, 2).get_shape().unwrap());
109    /// assert_eq!(vec![3, 5, 4], arr.swapaxes(2, 1).get_shape().unwrap());
110    /// ```
111    ///
112    /// # Errors
113    ///
114    /// may returns `ArrayError`
115    fn swapaxes(&self, axis: isize, start: isize) -> Result<Array<T>, ArrayError>;
116
117    /// Expand the shape of an array
118    ///
119    /// # Arguments
120    ///
121    /// * `axes` - position in the expanded axes where the new axis (or axes) is placed
122    ///
123    /// # Examples
124    ///
125    /// ```
126    /// use arr_rs::prelude::*;
127    ///
128    /// let arr = Array::<i32>::zeros(vec![3, 4, 5]);
129    /// assert_eq!(vec![1, 3, 4, 5], arr.expand_dims(vec![0]).get_shape().unwrap());
130    /// assert_eq!(vec![3, 1, 4, 1, 5], arr.expand_dims(vec![1, 3]).get_shape().unwrap());
131    /// ```
132    ///
133    /// # Errors
134    ///
135    /// may returns `ArrayError`
136    fn expand_dims(&self, axes: Vec<isize>) -> Result<Array<T>, ArrayError>;
137
138    /// Remove axes of length one from array
139    ///
140    /// # Arguments
141    ///
142    /// * `axes` - position of the 10-sized axes to remove. if None, all such axes will be removed
143    ///
144    /// # Examples
145    ///
146    /// ```
147    /// use arr_rs::prelude::*;
148    ///
149    /// let arr = Array::<i32>::zeros(vec![1, 3, 1, 4, 5]);
150    /// assert_eq!(vec![3, 4, 5], arr.squeeze(None).get_shape().unwrap());
151    /// assert_eq!(vec![3, 1, 4, 5], arr.squeeze(Some(vec![0])).get_shape().unwrap());
152    /// assert_eq!(vec![1, 3, 4, 5], arr.squeeze(Some(vec![2])).get_shape().unwrap());
153    /// ```
154    ///
155    /// # Errors
156    ///
157    /// may returns `ArrayError`
158    fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError>;
159}
160
161impl <T: ArrayElement> ArrayAxis<T> for Array<T> {
162
163    fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, mut f: F) -> Result<Array<S>, ArrayError>
164        where F: FnMut(&Self) -> Result<Array<S>, ArrayError> {
165        self.axis_in_bounds(axis)?;
166        let parts = self.get_shape()?.remove_at(axis).into_iter().product();
167        let array = self.moveaxis(vec![axis.to_isize()], vec![self.ndim()?.to_isize()])?;
168        let partial = array
169            .ravel()
170            .split(parts, None)?.into_iter()
171            .map(|arr| f(&arr))
172            .collect::<Vec<Result<Array<S>, _>>>()
173            .has_error()?.into_iter()
174            .map(Result::unwrap)
175            .collect::<Vec<Array<S>>>();
176        let partial_len = partial[0].len()?;
177        let partial = partial.into_iter().flatten().collect::<Array<S>>();
178
179        let new_shape = array.get_shape()?.update_at(self.ndim()? - 1, partial_len);
180        let partial = partial.reshape(&new_shape);
181        if axis == 0 { partial.rollaxis((self.ndim()? - 1).to_isize(), None) }
182        else { partial.moveaxis(vec![axis.to_isize()], vec![(self.ndim()? - 1).to_isize()]) }
183    }
184
185    fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Self, ArrayError> {
186
187        fn transpose_recursive<T: ArrayElement>(
188            input: &[T], input_shape: &[usize],
189            output: &mut [T], output_shape: &[usize],
190            current_indices: &mut [usize], current_dim: usize,
191            axes: &Option<Vec<usize>>) {
192            if current_dim < input_shape.len() - 1 {
193                (0..input_shape[current_dim]).for_each(|i| {
194                    current_indices[current_dim] = i;
195                    transpose_recursive(input, input_shape, output, output_shape, current_indices, current_dim + 1, axes);
196                });
197            } else {
198                (0..input_shape[current_dim]).for_each(|i| {
199                    current_indices[current_dim] = i;
200                    let input_index = input_shape.iter().enumerate().fold(0, |acc, (dim, size)| { acc * size + current_indices[dim] });
201                    let output_indices = axes.as_ref().map_or_else(
202                        || current_indices.iter().rev().copied().collect::<Vec<usize>>(),
203                        |axes| axes.iter().map(|&ax| current_indices[ax]).collect::<Vec<usize>>());
204                    let output_index = output_shape.iter().enumerate().fold(0, |acc, (dim, size)| { acc * size + output_indices[dim] });
205                    output[output_index] = input[input_index].clone();
206                });
207            }
208        }
209
210        let axes = axes.map(|axes| axes.iter()
211            .map(|i| self.normalize_axis(*i))
212            .collect::<Vec<usize>>());
213        let mut new_elements = vec![T::zero(); self.elements.len()];
214        let new_shape: Vec<usize> = axes.clone().map_or_else(
215            || self.shape.clone().into_iter().rev().collect(),
216            |axes| axes.into_iter().map(|ax| self.shape[ax]).collect());
217
218        transpose_recursive(
219            &self.elements, &self.shape,
220            &mut new_elements, &new_shape,
221            &mut vec![0; self.shape.len()], 0,
222            &axes
223        );
224
225        Self::new(new_elements, new_shape)
226    }
227
228    fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Self, ArrayError> {
229        source.is_unique()?;
230        source.len().is_equal(&destination.len())?;
231        let source = source.iter().map(|i| self.normalize_axis(*i)).collect::<Vec<usize>>();
232        let destination = destination.iter().map(|i| self.normalize_axis(*i)).collect::<Vec<usize>>();
233        source.is_unique()?;
234        destination.is_unique()?;
235
236        let mut order = (0..self.ndim()?)
237            .filter(|f| !source.contains(f))
238            .collect::<Vec<usize>>();
239
240        destination.into_iter()
241            .zip(source)
242            .sorted()
243            .for_each(|(d, s)| order.insert(d.min(order.len()), s));
244
245        self.transpose(Some(order.iter().map(Numeric::to_isize).collect()))
246    }
247
248    fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Self, ArrayError> {
249        let axis = self.normalize_axis(axis);
250        let start = start.map_or(0, |ax| self.normalize_axis(ax));
251
252        let mut new_axes = (0..self.ndim()?).collect::<Vec<usize>>();
253        let axis_to_move = new_axes.remove(axis);
254        new_axes.insert(start, axis_to_move);
255
256        self.transpose(Some(new_axes.iter().map(|&i| i.to_isize()).collect()))
257    }
258
259    fn swapaxes(&self, axis_1: isize, axis_2: isize) -> Result<Self, ArrayError> {
260        let axis_1 = self.normalize_axis(axis_1);
261        let axis_2 = self.normalize_axis(axis_2);
262
263        let new_axes = (0..self.ndim()?)
264            .collect::<Vec<usize>>()
265            .swap_ext(axis_1, axis_2);
266
267        self.transpose(Some(new_axes.iter().map(|&i| i.to_isize()).collect()))
268    }
269
270    fn expand_dims(&self, axes: Vec<isize>) -> Result<Self, ArrayError> {
271        let axes = axes.iter()
272            .map(|&i| self.normalize_axis_dim(i, axes.len()))
273            .sorted()
274            .collect::<Vec<usize>>();
275        let mut new_shape = self.get_shape()?;
276
277        for item in axes { new_shape.insert(item, 1) }
278        self.reshape(&new_shape)
279    }
280
281    fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Self, ArrayError> {
282        if let Some(axes) = axes {
283            let axes = axes.iter()
284                .map(|&i| self.normalize_axis(i))
285                .sorted()
286                .rev()
287                .collect::<Vec<usize>>();
288            let mut new_shape = self.get_shape()?;
289
290            if axes.iter().any(|a| new_shape[*a] != 1) {
291                Err(ArrayError::SqueezeShapeOfAxisMustBeOne)
292            } else {
293                for item in axes { new_shape.remove(item); }
294                self.reshape(&new_shape)
295            }
296        }
297        else {
298            self.reshape(&self.get_shape()?.into_iter().filter(|&i| i != 1).collect::<Vec<usize>>())
299        }
300    }
301}
302
303impl <T: ArrayElement> ArrayAxis<T> for Result<Array<T>, ArrayError> {
304
305    fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, f: F) -> Result<Array<S>, ArrayError>
306        where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError> {
307        self.clone()?.apply_along_axis(axis, f)
308    }
309
310    fn transpose(&self, axes: Option<Vec<isize>>) -> Self {
311        self.clone()?.transpose(axes)
312    }
313
314    fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Self {
315        self.clone()?.moveaxis(source, destination)
316    }
317
318    fn rollaxis(&self, axis: isize, start: Option<isize>) -> Self {
319        self.clone()?.rollaxis(axis, start)
320    }
321
322    fn swapaxes(&self, axis: isize, start: isize) -> Self {
323        self.clone()?.swapaxes(axis, start)
324    }
325
326    fn expand_dims(&self, axes: Vec<isize>) -> Self {
327        self.clone()?.expand_dims(axes)
328    }
329
330    fn squeeze(&self, axes: Option<Vec<isize>>) -> Self {
331        self.clone()?.squeeze(axes)
332    }
333}