use crate::{
core::prelude::*,
errors::prelude::*,
extensions::prelude::*,
};
/// Tiling operations: building larger arrays by repeating elements of an existing one.
pub trait ArrayTiling<T: ArrayElement>: Sized + Clone {
    /// Repeat elements of the array.
    ///
    /// # Arguments
    ///
    /// * `repeats` - repetition counts, broadcast against the repeated dimension(s)
    /// * `axis` - axis along which to repeat; `None` repeats the flattened elements
    fn repeat(&self, repeats: &[usize], axis: Option<usize>) -> Result<Array<T>, ArrayError>;
}
impl<T: ArrayElement> ArrayTiling<T> for Array<T> {
    /// Repeat elements of the array.
    ///
    /// # Arguments
    ///
    /// * `repeats` - repetition counts. With `Some(axis)` they are broadcast to the
    ///   length of that axis (one count per slice along it). With `None` they are
    ///   broadcast to the shape of the array (one count per element).
    /// * `axis` - axis along which to repeat. With `None` the elements are taken in
    ///   flattened order and the result is a 1-D array.
    ///
    /// # Errors
    ///
    /// Propagates any error from converting or broadcasting `repeats`. Also
    /// propagates errors from the underlying `split` / `reshape` / `moveaxis`
    /// operations.
    ///
    /// # Panics
    ///
    /// Panics if `axis` is `Some(a)` with `a` not less than the number of
    /// dimensions, because the axis length is read by indexing the shape.
    fn repeat(&self, repeats: &[usize], axis: Option<usize>) -> Result<Array<T>, ArrayError> {
        if let Some(axis) = axis {
            let axis_len = self.get_shape()?[axis];
            // One repeat count per slice along `axis`.
            let repeats = repeats
                .to_vec()
                .to_array()?
                .broadcast_to(vec![axis_len])
                .get_elements()?;
            let new_axis_len: usize = repeats.iter().sum();
            let new_shape = self.get_shape()?.update_at(axis, new_axis_len);
            // Concatenating the repeated slices puts the `axis` index outermost
            // in the flat data. So the data is first viewed with dims 0 and `axis`
            // swapped. Then axis 0 is moved back to `axis`, and the result is
            // reshaped to `new_shape`.
            let tmp_shape = new_shape.clone().swap_ext(0, axis);
            let partial = self
                .split(axis_len, Some(axis))?
                .into_iter()
                .zip(&repeats)
                .flat_map(|(slice, &rep)| vec![slice; rep])
                .flatten()
                .collect::<Array<T>>();
            partial
                .reshape(&tmp_shape)
                .moveaxis(vec![0], vec![axis as isize])
                .reshape(&new_shape)
        } else {
            let elements = self.get_elements()?;
            // One repeat count per element, broadcast to the array's own shape.
            let counts = repeats
                .to_vec()
                .to_array()?
                .broadcast_to(self.get_shape()?)
                .get_elements()?;
            let result = elements
                .into_iter()
                .zip(counts)
                .flat_map(|(el, rep)| vec![el; rep])
                .collect();
            Array::flat(result)
        }
    }
}
impl<T: ArrayElement> ArrayTiling<T> for Result<Array<T>, ArrayError> {
    /// Forward `repeat` to the wrapped array so calls can be chained on results.
    ///
    /// An `Err` is passed through unchanged (as a clone of the stored error).
    fn repeat(&self, repeats: &[usize], axis: Option<usize>) -> Result<Array<T>, ArrayError> {
        match self {
            Ok(array) => array.repeat(repeats, axis),
            Err(error) => Err(error.clone()),
        }
    }
}