use crate::{
core::prelude::*,
errors::prelude::*,
extensions::prelude::*,
validators::prelude::*,
};
/// Axis-manipulation operations (transposition, moving, swapping, inserting and
/// removing axes) for [`Array`].
///
/// Axis arguments of type `isize` are normalized by the implementation (for `Array<T>`
/// via `normalize_axis` / `normalize_axis_dim`); presumably negative values count from
/// the end — TODO confirm against those helpers.
pub trait ArrayAxis<T: ArrayElement> where Array<T>: Sized + Clone {
    /// Applies `f` to every 1-D slice taken along `axis` and assembles the results into a
    /// new array whose length along `axis` is the length of `f`'s output.
    ///
    /// # Errors
    /// Fails if `axis` is out of bounds, if any call of `f` fails, or if the outputs of
    /// `f` cannot be reshaped into a single array.
    fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, f: F) -> Result<Array<S>, ArrayError>
    where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError>;

    /// Permutes the axes: axis `k` of the result is axis `axes[k]` of the input.
    /// With `None` the axis order is reversed.
    fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError>;

    /// Moves the axes at positions `source` to positions `destination`; all other axes
    /// keep their relative order.
    ///
    /// # Errors
    /// Fails if `source` and `destination` differ in length or contain duplicates.
    fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Array<T>, ArrayError>;

    /// Moves `axis` so that it is placed at position `start` (default `0`).
    fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Array<T>, ArrayError>;

    /// Interchanges the two axes `axis_1` and `axis_2`.
    fn swapaxes(&self, axis_1: isize, axis_2: isize) -> Result<Array<T>, ArrayError>;

    /// Inserts a new axis of length 1 at every position in `axes`; positions refer to the
    /// shape of the result.
    fn expand_dims(&self, axes: Vec<isize>) -> Result<Array<T>, ArrayError>;

    /// Removes axes of length 1 — all of them when `axes` is `None`, otherwise only the
    /// listed ones.
    ///
    /// # Errors
    /// With explicit `axes`, fails with `ArrayError::SqueezeShapeOfAxisMustBeOne` if any
    /// listed axis does not have length 1.
    fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError>;
}
impl <T: ArrayElement> ArrayAxis<T> for Array<T> {
    /// Applies `f` to every 1-D slice of `self` taken along `axis`.
    ///
    /// `axis` is moved to the last position so that each slice becomes a contiguous run
    /// of the raveled data; the data is split into one chunk per slice, `f` is applied to
    /// each chunk, the results are concatenated and the last axis is moved back to
    /// position `axis`. The result therefore has the shape of `self` with the length of
    /// `axis` replaced by the length of `f`'s output.
    ///
    /// # Errors
    /// Fails if `axis` is out of bounds, if any call of `f` fails, or if the outputs of
    /// `f` cannot be reshaped back (e.g. they have differing lengths).
    fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, mut f: F) -> Result<Array<S>, ArrayError>
    where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError> {
        self.axis_in_bounds(axis)?;
        let last = self.ndim()? - 1;
        // One slice per combination of indices over the remaining axes.
        let parts = self.get_shape()?.remove_at(axis).into_iter().product();
        let array = self.moveaxis(vec![axis as isize], vec![last as isize])?;
        let partial = array
            .ravel()
            .split(parts, None)?.into_iter()
            .map(|arr| f(&arr))
            .collect::<Vec<Result<Array<S>, _>>>()
            // presumably `has_error` returns Err if any element is Err, making the
            // following unwrap safe — confirm against its definition.
            .has_error()?.into_iter()
            .map(|arr| arr.unwrap())
            .collect::<Vec<Array<S>>>();
        // NOTE(review): indexing panics if there are no slices at all (another axis of
        // length 0) — confirm what `split` does in that case.
        let partial_len = partial[0].len()?;
        let partial = partial.into_iter().flatten().collect::<Array<S>>();
        let new_shape = array.get_shape()?.update_at(last, partial_len);
        // Inverse of the initial move: bring the last axis back to position `axis`.
        // (Moving `axis` to the end again is only an inverse for adjacent positions.)
        partial.reshape(&new_shape).moveaxis(vec![last as isize], vec![axis as isize])
    }

    /// Permutes the axes of the array.
    ///
    /// With `Some(axes)`, axis `k` of the result is axis `axes[k]` of `self` (entries are
    /// normalized with `normalize_axis`); with `None` the axis order is reversed.
    ///
    /// # Errors
    /// Fails if `axes` does not have exactly one entry per dimension, contains
    /// duplicates, or if the resulting array cannot be constructed.
    fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Self, ArrayError> {
        let ndim = self.shape.len();
        let axes = match axes {
            Some(axes) => {
                let axes = axes.iter()
                    .map(|&i| self.normalize_axis(i))
                    .collect::<Vec<usize>>();
                // Must be a permutation, otherwise the gather below would drop or repeat
                // elements, or index out of bounds.
                axes.len().is_equal(&ndim)?;
                axes.is_unique()?;
                axes
            }
            None => (0 .. ndim).rev().collect(),
        };

        // Row-major strides of the source data.
        let mut strides = vec![1usize; ndim];
        for dim in (0 .. ndim.saturating_sub(1)).rev() {
            strides[dim] = strides[dim + 1] * self.shape[dim + 1];
        }
        let new_shape = axes.iter().map(|&ax| self.shape[ax]).collect::<Vec<usize>>();
        let new_strides = axes.iter().map(|&ax| strides[ax]).collect::<Vec<usize>>();

        // Walk the result in row-major order with an odometer over `new_shape`, gathering
        // each element from its offset in the source.
        let mut index = vec![0usize; ndim];
        let mut new_elements = Vec::with_capacity(self.elements.len());
        for _ in 0 .. self.elements.len() {
            let offset = index.iter()
                .zip(&new_strides)
                .map(|(i, s)| i * s)
                .sum::<usize>();
            new_elements.push(self.elements[offset].clone());
            for dim in (0 .. ndim).rev() {
                index[dim] += 1;
                if index[dim] < new_shape[dim] { break; }
                index[dim] = 0;
            }
        }
        Self::new(new_elements, new_shape)
    }

    /// Moves the axes at `source` to the positions in `destination`; the remaining axes
    /// keep their relative order.
    ///
    /// # Errors
    /// Fails if `source` and `destination` differ in length, or if either contains
    /// duplicates (checked before and after normalization for `source`, after
    /// normalization for `destination`).
    fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Self, ArrayError> {
        source.is_unique()?;
        source.len().is_equal(&destination.len())?;
        let source = source.iter().map(|i| self.normalize_axis(*i)).collect::<Vec<usize>>();
        let destination = destination.iter().map(|i| self.normalize_axis(*i)).collect::<Vec<usize>>();
        source.is_unique()?;
        destination.is_unique()?;
        // Axes that stay put, in their original order.
        let mut order = (0 .. self.ndim()?)
            .filter(|f| !source.contains(f))
            .collect::<Vec<usize>>();
        // Insert moved axes by ascending destination so earlier inserts don't shift later
        // targets; the clamp guards against destinations past the current end.
        destination.into_iter()
            .zip(source)
            .sorted()
            .for_each(|(d, s)| order.insert(d.min(order.len()), s));
        let order = order.iter().map(|i| *i as isize).collect();
        self.transpose(Some(order))
    }

    /// Moves `axis` so that it is placed at index `start` (default `0`) of the result.
    ///
    /// NOTE(review): unlike NumPy's `rollaxis`, `start` is not decremented when
    /// `axis < start`; confirm this is the intended contract.
    fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Self, ArrayError> {
        let axis = self.normalize_axis(axis);
        let start = start.map_or(0, |ax| self.normalize_axis(ax));
        let mut new_axes = (0 .. self.ndim()?).collect::<Vec<usize>>();
        let axis_to_move = new_axes.remove(axis);
        // Clamp like `moveaxis`, so a `start` one past the last axis moves the axis to the
        // end instead of panicking in `Vec::insert`.
        new_axes.insert(start.min(new_axes.len()), axis_to_move);
        self.transpose(Some(new_axes.iter().map(|&i| i as isize).collect()))
    }

    /// Interchanges axes `axis_1` and `axis_2`.
    fn swapaxes(&self, axis_1: isize, axis_2: isize) -> Result<Self, ArrayError> {
        let axis_1 = self.normalize_axis(axis_1);
        let axis_2 = self.normalize_axis(axis_2);
        let new_axes = (0 .. self.ndim()?)
            .collect::<Vec<usize>>()
            .swap_ext(axis_1, axis_2);
        self.transpose(Some(new_axes.iter().map(|&i| i as isize).collect()))
    }

    /// Inserts a new axis of length 1 at every position in `axes`; positions refer to the
    /// shape of the result (they are normalized with `normalize_axis_dim` against the
    /// number of added axes).
    ///
    /// # Errors
    /// Fails if two entries normalize to the same position, or if the reshape fails.
    fn expand_dims(&self, axes: Vec<isize>) -> Result<Self, ArrayError> {
        let axes = axes.iter()
            .map(|&i| self.normalize_axis_dim(i, axes.len()))
            .sorted()
            .collect::<Vec<usize>>();
        // A repeated position would silently insert more axes than requested.
        axes.is_unique()?;
        let mut new_shape = self.get_shape()?;
        // Ascending insertion makes every entry a position in the final shape.
        for item in axes { new_shape.insert(item, 1) }
        self.reshape(&new_shape)
    }

    /// Removes axes of length 1: all of them when `axes` is `None`, otherwise only the
    /// listed ones.
    ///
    /// # Errors
    /// With explicit `axes`, fails with `ArrayError::SqueezeShapeOfAxisMustBeOne` if any
    /// listed axis does not have length 1, and fails if an axis is listed twice.
    fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Self, ArrayError> {
        match axes {
            None => self.reshape(&self.get_shape()?
                .into_iter()
                .filter(|&i| i != 1)
                .collect::<Vec<usize>>()),
            Some(axes) => {
                // Descending order keeps the not-yet-removed positions valid.
                let axes = axes.iter()
                    .map(|&i| self.normalize_axis(i))
                    .sorted()
                    .rev()
                    .collect::<Vec<usize>>();
                let mut new_shape = self.get_shape()?;
                if axes.iter().any(|&a| new_shape[a] != 1) {
                    return Err(ArrayError::SqueezeShapeOfAxisMustBeOne);
                }
                // Removing the same position twice would drop a neighbouring axis.
                axes.is_unique()?;
                for item in axes { new_shape.remove(item); }
                self.reshape(&new_shape)
            }
        }
    }
}
/// Lets axis operations be chained on fallible results: an `Err` is propagated
/// unchanged, an `Ok` array is operated on by reference (no copy of its data is made).
impl <T: ArrayElement> ArrayAxis<T> for Result<Array<T>, ArrayError> {
    fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, f: F) -> Result<Array<S>, ArrayError>
    where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.apply_along_axis(axis, f)
    }

    fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.transpose(axes)
    }

    fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Array<T>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.moveaxis(source, destination)
    }

    fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Array<T>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.rollaxis(axis, start)
    }

    fn swapaxes(&self, axis_1: isize, axis_2: isize) -> Result<Array<T>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.swapaxes(axis_1, axis_2)
    }

    fn expand_dims(&self, axes: Vec<isize>) -> Result<Array<T>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.expand_dims(axes)
    }

    fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError> {
        self.as_ref().map_err(|e| e.clone())?.squeeze(axes)
    }
}