#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use crate::dimension;
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;
/// Concatenate arrays along an existing axis.
///
/// All arrays must have the same number of dimensions and identical lengths on
/// every axis except `axis`. The result's length along `axis` is the sum of the
/// inputs' lengths along it. Elements are cloned into a newly allocated array.
///
/// ***Errors*** if:
/// - `arrays` is empty (`ErrorKind::Unsupported`);
/// - `axis` is not less than the number of dimensions (`ErrorKind::OutOfBounds`);
/// - the arrays' dimensionalities differ, or their lengths differ on any axis
///   other than `axis` (`ErrorKind::IncompatibleShape`);
/// - the summed length along `axis` overflows `usize` (`ErrorKind::Overflow`),
///   or the result's shape is rejected by `size_of_shape_checked`.
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
    A: Clone,
    D: RemoveAxis,
{
    if arrays.is_empty() {
        return Err(from_kind(ErrorKind::Unsupported));
    }
    let mut res_dim = arrays[0].raw_dim();
    let ndim = res_dim.ndim();
    if axis.index() >= ndim {
        return Err(from_kind(ErrorKind::OutOfBounds));
    }
    let common_dim = res_dim.remove_axis(axis);
    // Compare dimensionality first. With a dynamic dimension type an input may
    // have fewer axes than the first array. Calling `remove_axis` on it with an
    // out-of-range axis can panic instead of reporting an error.
    if arrays
        .iter()
        .any(|a| a.ndim() != ndim || a.raw_dim().remove_axis(axis) != common_dim)
    {
        return Err(from_kind(ErrorKind::IncompatibleShape));
    }
    // Checked sum: each input's length along `axis` is only bounded on its own,
    // so adding several of them together can exceed `usize::MAX`.
    let stacked_dim = arrays
        .iter()
        .try_fold(0usize, |acc, a| acc.checked_add(a.len_of(axis)))
        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;
    res_dim.set_axis(axis, stacked_dim);
    let new_len = dimension::size_of_shape_checked(&res_dim)?;
    // Start with length 0 along `axis` and grow it with one `append` per input.
    // Capacity for the full result is reserved up front.
    res_dim.set_axis(axis, 0);
    // SAFETY: `res_dim` has length 0 along `axis`, so it describes zero elements,
    // matching the empty vector. The full shape passed `size_of_shape_checked`
    // above, and zeroing one axis cannot make it larger.
    let mut res = unsafe {
        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
    };
    for array in arrays {
        res.append(axis, array.clone())?;
    }
    debug_assert_eq!(res.len_of(axis), stacked_dim);
    Ok(res)
}
/// Stack arrays along a new axis inserted at position `axis`.
///
/// Every array must have exactly the same shape. The result has one more
/// dimension than the inputs, and its length along `axis` equals the number of
/// arrays. Elements are cloned into a newly allocated array.
///
/// ***Errors*** if:
/// - `arrays` is empty (`ErrorKind::Unsupported`);
/// - `axis` is greater than the inputs' number of dimensions (`ErrorKind::OutOfBounds`);
/// - the arrays do not all share the same shape (`ErrorKind::IncompatibleShape`);
/// - the result's shape is rejected by `size_of_shape_checked`.
pub fn stack<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D::Larger>, ShapeError>
where
    A: Clone,
    D: Dimension,
    D::Larger: RemoveAxis,
{
    let first = match arrays.first() {
        Some(view) => view,
        None => return Err(from_kind(ErrorKind::Unsupported)),
    };
    let shared_shape = first.raw_dim();
    // A new axis may be inserted at any index in `0..=ndim`, so `ndim` itself is valid.
    if axis.index() > shared_shape.ndim() {
        return Err(from_kind(ErrorKind::OutOfBounds));
    }
    let mut out_shape = shared_shape.insert_axis(axis);
    let shapes_agree = arrays.iter().all(|view| view.raw_dim() == shared_shape);
    if !shapes_agree {
        return Err(from_kind(ErrorKind::IncompatibleShape));
    }

    let count = arrays.len();
    out_shape.set_axis(axis, count);
    let capacity = dimension::size_of_shape_checked(&out_shape)?;

    // Begin with zero entries along the new axis; each append adds one slice.
    out_shape.set_axis(axis, 0);
    // SAFETY: `out_shape` has length 0 along `axis`, so it describes zero
    // elements, which matches the empty vector handed over here.
    let mut stacked = unsafe {
        Array::from_shape_vec_unchecked(out_shape, Vec::with_capacity(capacity))
    };
    arrays
        .iter()
        .try_for_each(|view| stacked.append(axis, view.clone().insert_axis(axis)))?;

    debug_assert_eq!(stacked.len_of(axis), count);
    Ok(stacked)
}
/// Stack one or more arrays along a new axis.
///
/// Each argument is borrowed and converted with `ArrayView::from`, then passed
/// to the `stack` function. A trailing comma is accepted.
///
/// ***Panics*** if `stack` returns an error (the result is `unwrap`ped).
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::stack($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
/// Concatenate one or more arrays along an existing axis.
///
/// Each argument is borrowed and converted with `ArrayView::from`, then passed
/// to the `concatenate` function. A trailing comma is accepted.
///
/// ***Panics*** if `concatenate` returns an error (the result is `unwrap`ped).
#[macro_export]
macro_rules! concatenate {
    ($axis:expr, $( $array:expr ),+ $(,)?) => {
        $crate::concatenate($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}