1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 extensions::prelude::*,
5 validators::prelude::*,
6};
7use crate::prelude::Numeric;
8
9pub trait ArrayAxis<T: ArrayElement> where Array<T>: Sized + Clone {
11
12 fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, f: F) -> Result<Array<S>, ArrayError>
23 where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError>;
24
25 fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError>;
50
51 fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Array<T>, ArrayError>;
72
73 fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Array<T>, ArrayError>;
94
95 fn swapaxes(&self, axis: isize, start: isize) -> Result<Array<T>, ArrayError>;
116
117 fn expand_dims(&self, axes: Vec<isize>) -> Result<Array<T>, ArrayError>;
137
138 fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Array<T>, ArrayError>;
159}
160
161impl <T: ArrayElement> ArrayAxis<T> for Array<T> {
162
163 fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, mut f: F) -> Result<Array<S>, ArrayError>
164 where F: FnMut(&Self) -> Result<Array<S>, ArrayError> {
165 self.axis_in_bounds(axis)?;
166 let parts = self.get_shape()?.remove_at(axis).into_iter().product();
167 let array = self.moveaxis(vec![axis.to_isize()], vec![self.ndim()?.to_isize()])?;
168 let partial = array
169 .ravel()
170 .split(parts, None)?.into_iter()
171 .map(|arr| f(&arr))
172 .collect::<Vec<Result<Array<S>, _>>>()
173 .has_error()?.into_iter()
174 .map(Result::unwrap)
175 .collect::<Vec<Array<S>>>();
176 let partial_len = partial[0].len()?;
177 let partial = partial.into_iter().flatten().collect::<Array<S>>();
178
179 let new_shape = array.get_shape()?.update_at(self.ndim()? - 1, partial_len);
180 let partial = partial.reshape(&new_shape);
181 if axis == 0 { partial.rollaxis((self.ndim()? - 1).to_isize(), None) }
182 else { partial.moveaxis(vec![axis.to_isize()], vec![(self.ndim()? - 1).to_isize()]) }
183 }
184
185 fn transpose(&self, axes: Option<Vec<isize>>) -> Result<Self, ArrayError> {
186
187 fn transpose_recursive<T: ArrayElement>(
188 input: &[T], input_shape: &[usize],
189 output: &mut [T], output_shape: &[usize],
190 current_indices: &mut [usize], current_dim: usize,
191 axes: &Option<Vec<usize>>) {
192 if current_dim < input_shape.len() - 1 {
193 (0..input_shape[current_dim]).for_each(|i| {
194 current_indices[current_dim] = i;
195 transpose_recursive(input, input_shape, output, output_shape, current_indices, current_dim + 1, axes);
196 });
197 } else {
198 (0..input_shape[current_dim]).for_each(|i| {
199 current_indices[current_dim] = i;
200 let input_index = input_shape.iter().enumerate().fold(0, |acc, (dim, size)| { acc * size + current_indices[dim] });
201 let output_indices = axes.as_ref().map_or_else(
202 || current_indices.iter().rev().copied().collect::<Vec<usize>>(),
203 |axes| axes.iter().map(|&ax| current_indices[ax]).collect::<Vec<usize>>());
204 let output_index = output_shape.iter().enumerate().fold(0, |acc, (dim, size)| { acc * size + output_indices[dim] });
205 output[output_index] = input[input_index].clone();
206 });
207 }
208 }
209
210 let axes = axes.map(|axes| axes.iter()
211 .map(|i| self.normalize_axis(*i))
212 .collect::<Vec<usize>>());
213 let mut new_elements = vec![T::zero(); self.elements.len()];
214 let new_shape: Vec<usize> = axes.clone().map_or_else(
215 || self.shape.clone().into_iter().rev().collect(),
216 |axes| axes.into_iter().map(|ax| self.shape[ax]).collect());
217
218 transpose_recursive(
219 &self.elements, &self.shape,
220 &mut new_elements, &new_shape,
221 &mut vec![0; self.shape.len()], 0,
222 &axes
223 );
224
225 Self::new(new_elements, new_shape)
226 }
227
228 fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Result<Self, ArrayError> {
229 source.is_unique()?;
230 source.len().is_equal(&destination.len())?;
231 let source = source.iter().map(|i| self.normalize_axis(*i)).collect::<Vec<usize>>();
232 let destination = destination.iter().map(|i| self.normalize_axis(*i)).collect::<Vec<usize>>();
233 source.is_unique()?;
234 destination.is_unique()?;
235
236 let mut order = (0..self.ndim()?)
237 .filter(|f| !source.contains(f))
238 .collect::<Vec<usize>>();
239
240 destination.into_iter()
241 .zip(source)
242 .sorted()
243 .for_each(|(d, s)| order.insert(d.min(order.len()), s));
244
245 self.transpose(Some(order.iter().map(Numeric::to_isize).collect()))
246 }
247
248 fn rollaxis(&self, axis: isize, start: Option<isize>) -> Result<Self, ArrayError> {
249 let axis = self.normalize_axis(axis);
250 let start = start.map_or(0, |ax| self.normalize_axis(ax));
251
252 let mut new_axes = (0..self.ndim()?).collect::<Vec<usize>>();
253 let axis_to_move = new_axes.remove(axis);
254 new_axes.insert(start, axis_to_move);
255
256 self.transpose(Some(new_axes.iter().map(|&i| i.to_isize()).collect()))
257 }
258
259 fn swapaxes(&self, axis_1: isize, axis_2: isize) -> Result<Self, ArrayError> {
260 let axis_1 = self.normalize_axis(axis_1);
261 let axis_2 = self.normalize_axis(axis_2);
262
263 let new_axes = (0..self.ndim()?)
264 .collect::<Vec<usize>>()
265 .swap_ext(axis_1, axis_2);
266
267 self.transpose(Some(new_axes.iter().map(|&i| i.to_isize()).collect()))
268 }
269
270 fn expand_dims(&self, axes: Vec<isize>) -> Result<Self, ArrayError> {
271 let axes = axes.iter()
272 .map(|&i| self.normalize_axis_dim(i, axes.len()))
273 .sorted()
274 .collect::<Vec<usize>>();
275 let mut new_shape = self.get_shape()?;
276
277 for item in axes { new_shape.insert(item, 1) }
278 self.reshape(&new_shape)
279 }
280
281 fn squeeze(&self, axes: Option<Vec<isize>>) -> Result<Self, ArrayError> {
282 if let Some(axes) = axes {
283 let axes = axes.iter()
284 .map(|&i| self.normalize_axis(i))
285 .sorted()
286 .rev()
287 .collect::<Vec<usize>>();
288 let mut new_shape = self.get_shape()?;
289
290 if axes.iter().any(|a| new_shape[*a] != 1) {
291 Err(ArrayError::SqueezeShapeOfAxisMustBeOne)
292 } else {
293 for item in axes { new_shape.remove(item); }
294 self.reshape(&new_shape)
295 }
296 }
297 else {
298 self.reshape(&self.get_shape()?.into_iter().filter(|&i| i != 1).collect::<Vec<usize>>())
299 }
300 }
301}
302
303impl <T: ArrayElement> ArrayAxis<T> for Result<Array<T>, ArrayError> {
304
305 fn apply_along_axis<S: ArrayElement, F>(&self, axis: usize, f: F) -> Result<Array<S>, ArrayError>
306 where F: FnMut(&Array<T>) -> Result<Array<S>, ArrayError> {
307 self.clone()?.apply_along_axis(axis, f)
308 }
309
310 fn transpose(&self, axes: Option<Vec<isize>>) -> Self {
311 self.clone()?.transpose(axes)
312 }
313
314 fn moveaxis(&self, source: Vec<isize>, destination: Vec<isize>) -> Self {
315 self.clone()?.moveaxis(source, destination)
316 }
317
318 fn rollaxis(&self, axis: isize, start: Option<isize>) -> Self {
319 self.clone()?.rollaxis(axis, start)
320 }
321
322 fn swapaxes(&self, axis: isize, start: isize) -> Self {
323 self.clone()?.swapaxes(axis, start)
324 }
325
326 fn expand_dims(&self, axes: Vec<isize>) -> Self {
327 self.clone()?.expand_dims(axes)
328 }
329
330 fn squeeze(&self, axes: Option<Vec<isize>>) -> Self {
331 self.clone()?.squeeze(axes)
332 }
333}