use alloc::vec::Vec;
use crate::dimension;
use crate::error::{from_kind, ErrorKind, ShapeError};
use crate::imp_prelude::*;
/// Stack arrays along a new axis.
///
/// This is the current name for [`stack_new_axis`]; the call is forwarded to
/// it unchanged, so the result has one more dimension than the inputs and
/// its length along `axis` equals `arrays.len()`.
///
/// # Errors
///
/// Returns whatever [`stack_new_axis`] returns: an error if `arrays` is
/// empty, if `axis` is greater than the number of input dimensions, if the
/// input shapes differ, or if the resulting shape is too large.
// The implementation still lives under the deprecated name, so the
// deprecation lint is silenced for this forwarding function only.
#[allow(deprecated)]
pub fn stack<A, D>(
    axis: Axis,
    arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
    A: Clone,
    D: Dimension,
    D::Larger: RemoveAxis,
{
    stack_new_axis(axis, arrays)
}
/// Concatenate arrays along the given, already existing, axis.
///
/// Every array must have the same shape once `axis` is removed; the lengths
/// along `axis` may differ and are added up in the result.
///
/// # Errors
///
/// - `ErrorKind::Unsupported` if `arrays` is empty.
/// - `ErrorKind::OutOfBounds` if `axis` is not an axis of the first array.
/// - `ErrorKind::IncompatibleShape` if the arrays disagree in any axis other
///   than `axis`.
/// - `ErrorKind::Overflow` if the summed length along `axis` does not fit in
///   a `usize`.
/// - Any error reported by the shape size check or by `append` is
///   propagated unchanged.
pub fn concatenate<A, D>(axis: Axis, arrays: &[ArrayView<A, D>]) -> Result<Array<A, D>, ShapeError>
where
    A: Clone,
    D: RemoveAxis,
{
    if arrays.is_empty() {
        return Err(from_kind(ErrorKind::Unsupported));
    }
    let mut res_dim = arrays[0].raw_dim();
    if axis.index() >= res_dim.ndim() {
        return Err(from_kind(ErrorKind::OutOfBounds));
    }
    let common_dim = res_dim.remove_axis(axis);
    if arrays
        .iter()
        .any(|a| a.raw_dim().remove_axis(axis) != common_dim)
    {
        return Err(from_kind(ErrorKind::IncompatibleShape));
    }

    // Views do not own their elements, so many large views can be passed at
    // once; add the axis lengths with overflow checking instead of letting
    // the sum panic (debug builds) or wrap around (release builds).
    let stacked_dim = arrays
        .iter()
        .try_fold(0usize, |acc, a| acc.checked_add(a.len_of(axis)))
        .ok_or_else(|| from_kind(ErrorKind::Overflow))?;
    res_dim.set_axis(axis, stacked_dim);
    let new_len = dimension::size_of_shape_checked(&res_dim)?;

    // Start from an empty array (length 0 along `axis`) whose backing vector
    // already has room for the final element count, then append each input.
    res_dim.set_axis(axis, 0);
    let mut res = unsafe {
        // SAFETY: `res_dim` has length 0 along `axis`, so the shape describes
        // zero elements, which matches the empty vector handed over here.
        Array::from_shape_vec_unchecked(res_dim, Vec::with_capacity(new_len))
    };

    for array in arrays {
        res.append(axis, array.clone())?;
    }
    debug_assert_eq!(res.len_of(axis), stacked_dim);
    Ok(res)
}
/// Stack arrays along a newly inserted axis.
///
/// All input arrays must have exactly the same shape. The result has one
/// more dimension than the inputs, and its length along `axis` equals
/// `arrays.len()`.
///
/// # Errors
///
/// - `ErrorKind::Unsupported` if `arrays` is empty.
/// - `ErrorKind::OutOfBounds` if `axis` is greater than the number of
///   dimensions of the first array.
/// - `ErrorKind::IncompatibleShape` if any array's shape differs from the
///   first one.
/// - Any error reported by the shape size check or by `append` is
///   propagated unchanged.
#[deprecated(note = "Use under the name stack instead.", since = "0.15.0")]
pub fn stack_new_axis<A, D>(
    axis: Axis,
    arrays: &[ArrayView<A, D>],
) -> Result<Array<A, D::Larger>, ShapeError>
where
    A: Clone,
    D: Dimension,
    D::Larger: RemoveAxis,
{
    let first = match arrays.first() {
        Some(view) => view,
        None => return Err(from_kind(ErrorKind::Unsupported)),
    };
    let input_shape = first.raw_dim();
    // `axis == ndim` is allowed: the new axis may be placed after the last one.
    if axis.index() > input_shape.ndim() {
        return Err(from_kind(ErrorKind::OutOfBounds));
    }
    let mut out_shape = input_shape.insert_axis(axis);
    if !arrays.iter().all(|view| view.raw_dim() == input_shape) {
        return Err(from_kind(ErrorKind::IncompatibleShape));
    }

    // Validate the final shape first so the backing vector can be allocated
    // once with its final capacity.
    let count = arrays.len();
    out_shape.set_axis(axis, count);
    let capacity = dimension::size_of_shape_checked(&out_shape)?;

    // Begin with length 0 along `axis`; each input is appended below.
    out_shape.set_axis(axis, 0);
    let storage = Vec::with_capacity(capacity);
    // SAFETY: `out_shape` has length 0 along `axis`, so the shape describes
    // zero elements, which matches the empty `storage` vector.
    let mut stacked = unsafe { Array::from_shape_vec_unchecked(out_shape, storage) };

    arrays
        .iter()
        .try_for_each(|view| stacked.append(axis, view.clone().insert_axis(axis)))?;
    debug_assert_eq!(stacked.len_of(axis), count);
    Ok(stacked)
}
/// Stack arrays along a new axis.
///
/// Expands to a call of the crate-level `stack` function, wrapping each
/// array argument `a` as `ArrayView::from(&a)`. The first argument is the
/// axis; at least one array is required and a trailing comma is accepted.
///
/// ***Panics*** if the `stack` function returns an error, because the result
/// is unwrapped.
#[macro_export]
macro_rules! stack {
    ($axis:expr, $( $array:expr ),+ $(,)? ) => {
        $crate::stack($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
/// Concatenate arrays along the given axis.
///
/// Expands to a call of the crate-level `concatenate` function, wrapping
/// each array argument `a` as `ArrayView::from(&a)`. The first argument is
/// the axis; at least one array is required and a trailing comma is
/// accepted.
///
/// ***Panics*** if the `concatenate` function returns an error, because the
/// result is unwrapped.
#[macro_export]
macro_rules! concatenate {
    ($axis:expr, $( $array:expr ),+ $(,)? ) => {
        $crate::concatenate($axis, &[ $( $crate::ArrayView::from(&$array) ),+ ]).unwrap()
    };
}
/// Stack arrays along a new axis (deprecated spelling of the `stack!` macro).
///
/// Expands to a call of the crate-level `stack_new_axis` function, wrapping
/// each array argument `a` as `ArrayView::from(&a)`. The first argument is
/// the axis; at least one array is required.
///
/// ***Panics*** if the `stack_new_axis` function returns an error, because
/// the result is unwrapped.
#[macro_export]
#[deprecated(note = "Use under the name stack instead.", since = "0.15.0")]
macro_rules! stack_new_axis {
    ($ax:expr, $( $arr:expr ),+ ) => {
        $crate::stack_new_axis($ax, &[ $( $crate::ArrayView::from(&$arr) ),+ ]).unwrap()
    };
}