pub mod extended;
use crate::array::owned::Array;
use crate::dimension::{Dimension, Ix1, IxDyn};
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};
/// Return an owned copy of `a` whose elements, taken in row-major (logical)
/// order, are laid out in `new_shape`.
///
/// # Errors
/// Returns a shape-mismatch error when the product of `new_shape` differs from
/// the element count of `a`, or when ndarray's `to_shape` rejects the target.
pub fn reshape<T: Element, D: Dimension>(
    a: &Array<T, D>,
    new_shape: &[usize],
) -> FerrayResult<Array<T, IxDyn>> {
    let (old_size, new_size) = (a.size(), new_shape.iter().product::<usize>());
    if old_size == new_size {
        // `as_standard_layout` guarantees a C-contiguous buffer before taking ownership.
        let owned = a
            .inner
            .view()
            .into_dyn()
            .to_shape(ndarray::IxDyn(new_shape))
            .map_err(|e| FerrayError::shape_mismatch(e.to_string()))?
            .as_standard_layout()
            .into_owned();
        Ok(Array::from_ndarray(owned))
    } else {
        Err(FerrayError::shape_mismatch(format!(
            "cannot reshape array of size {old_size} into shape {new_shape:?} (size {new_size})",
        )))
    }
}
/// Return a 1-D owned copy of `a` with its elements in row-major order.
///
/// This never returns `Err`; the `Result` keeps the signature uniform with
/// the other shape functions. The two `expect`s guard invariants that hold
/// by construction (a size-preserving 1-D target).
pub fn ravel<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
    let flat = a
        .inner
        .view()
        .into_dyn()
        .to_shape(ndarray::IxDyn(&[a.size()]))
        .expect("1-D reshape always succeeds for a size-preserving target")
        .as_standard_layout()
        .into_owned()
        .into_dimensionality::<ndarray::Ix1>()
        .expect("reshape result has ndim == 1 by construction");
    Ok(Array::from_ndarray(flat))
}
/// Return a flattened 1-D copy of `a`.
///
/// Identical to [`ravel`]: both always produce an owned, row-major copy here.
pub fn flatten<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, Ix1>> {
    let flat = ravel(a)?;
    Ok(flat)
}
/// Remove length-1 axes from `a`, returning an owned dynamic-rank copy.
///
/// * `axis = Some(ax)` removes only axis `ax`, which must exist and have length 1.
/// * `axis = None` removes every length-1 axis.
///
/// NOTE(review): with `None`, an input whose axes are *all* length 1 (and
/// which has at least one axis) yields shape `[1]` rather than NumPy's 0-d
/// `()`. Yet `Some(0)` on a shape-`[1]` input yields a 0-d result. Confirm
/// that this asymmetry is intended.
///
/// # Errors
/// - `axis_out_of_bounds` if `ax >= a.ndim()`.
/// - `invalid_value` if the selected axis does not have length 1.
pub fn squeeze<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axis: Option<usize>,
) -> FerrayResult<Array<T, IxDyn>> {
    let dims = a.shape();
    let kept: Vec<usize> = match axis {
        Some(ax) => {
            if ax >= dims.len() {
                return Err(FerrayError::axis_out_of_bounds(ax, dims.len()));
            }
            let len = dims[ax];
            if len != 1 {
                return Err(FerrayError::invalid_value(format!(
                    "cannot select axis {} with size {} for squeeze (must be 1)",
                    ax, len,
                )));
            }
            let mut remaining = dims.to_vec();
            remaining.remove(ax);
            remaining
        }
        None => {
            let remaining: Vec<usize> = dims.iter().copied().filter(|&s| s != 1).collect();
            // An all-ones (non 0-d) input collapses to a single axis of length 1.
            if remaining.is_empty() && !dims.is_empty() {
                vec![1]
            } else {
                remaining
            }
        }
    };
    let values: Vec<T> = a.iter().cloned().collect();
    Array::from_vec(IxDyn::new(&kept), values)
}
/// Return an owned copy of `a` with a new length-1 axis inserted at `axis`.
///
/// `axis` may equal `a.ndim()`, which appends the new axis at the end.
///
/// # Errors
/// `axis_out_of_bounds(axis, ndim + 1)` when `axis > a.ndim()`.
pub fn expand_dims<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axis: usize,
) -> FerrayResult<Array<T, IxDyn>> {
    let rank = a.ndim();
    if axis <= rank {
        let dims = a.shape();
        let grown: Vec<usize> = dims[..axis]
            .iter()
            .copied()
            .chain(std::iter::once(1))
            .chain(dims[axis..].iter().copied())
            .collect();
        let values: Vec<T> = a.iter().cloned().collect();
        Array::from_vec(IxDyn::new(&grown), values)
    } else {
        Err(FerrayError::axis_out_of_bounds(axis, rank + 1))
    }
}
/// Return an owned, C-contiguous copy of `a` broadcast to `new_shape`.
///
/// # Errors
/// `FerrayError::BroadcastFailure` (carrying both shapes) when ndarray's
/// `broadcast` cannot map `a`'s shape onto `new_shape`.
pub fn broadcast_to<T: Element, D: Dimension>(
    a: &Array<T, D>,
    new_shape: &[usize],
) -> FerrayResult<Array<T, IxDyn>> {
    let source = a.inner.view().into_dyn();
    match source.broadcast(ndarray::IxDyn(new_shape)) {
        // The broadcast view may have zero strides; materialise it in standard layout.
        Some(view) => Ok(Array::from_ndarray(
            view.as_standard_layout().into_owned(),
        )),
        None => Err(FerrayError::BroadcastFailure {
            shape_a: a.shape().to_vec(),
            shape_b: new_shape.to_vec(),
        }),
    }
}
/// Join arrays along an existing axis.
///
/// All inputs must have the same number of dimensions and identical lengths
/// on every axis except `axis`. Their lengths along `axis` may differ.
///
/// Implementation: in row-major order, each input is a sequence of `outer`
/// contiguous slabs, where `outer` is the product of the dims before `axis`.
/// Each slab holds `len_along_axis * inner` elements, where `inner` is the
/// product of the dims after `axis`. The output is slab 0 of every input,
/// then slab 1 of every input, and so on. So each input's iterator is read
/// exactly once, front to back. This relies on `Array::iter` visiting
/// elements in row-major (logical) order, as the rest of this module does.
///
/// # Errors
/// - `invalid_value` if `arrays` is empty.
/// - `axis_out_of_bounds` if `axis >= ndim`.
/// - `shape_mismatch` if the ndims differ or a non-`axis` length differs.
pub fn concatenate<T: Element>(
    arrays: &[Array<T, IxDyn>],
    axis: usize,
) -> FerrayResult<Array<T, IxDyn>> {
    if arrays.is_empty() {
        return Err(FerrayError::invalid_value(
            "concatenate: need at least one array",
        ));
    }
    let ndim = arrays[0].ndim();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    let base_shape = arrays[0].shape();
    let mut total_along_axis = 0usize;
    for arr in arrays {
        if arr.ndim() != ndim {
            return Err(FerrayError::shape_mismatch(format!(
                "all arrays must have same ndim; got {} and {}",
                ndim,
                arr.ndim(),
            )));
        }
        for (i, (&s, &base)) in arr.shape().iter().zip(base_shape.iter()).enumerate() {
            if i != axis && s != base {
                return Err(FerrayError::shape_mismatch(format!(
                    "shape mismatch on axis {i}: {s} vs {base}",
                )));
            }
        }
        total_along_axis += arr.shape()[axis];
    }

    let mut new_shape = base_shape.to_vec();
    new_shape[axis] = total_along_axis;
    let total: usize = new_shape.iter().product();

    // Dims before/after `axis` are identical for every input (validated above).
    let outer: usize = base_shape[..axis].iter().product();
    let inner: usize = base_shape[axis + 1..].iter().product();

    let mut data = Vec::with_capacity(total);
    let mut sources: Vec<_> = arrays.iter().map(|arr| arr.iter()).collect();
    for _ in 0..outer {
        for (arr, src) in arrays.iter().zip(sources.iter_mut()) {
            let slab = arr.shape()[axis] * inner;
            data.extend(src.by_ref().take(slab).cloned());
        }
    }
    Array::from_vec(IxDyn::new(&new_shape), data)
}
/// Join same-shaped arrays along a *new* axis inserted at position `axis`.
///
/// Every input gets a length-1 axis at `axis` (via [`expand_dims`]); the
/// results are then concatenated along that axis.
///
/// # Errors
/// - `invalid_value` if `arrays` is empty.
/// - `axis_out_of_bounds(axis, ndim + 1)` if `axis > ndim`.
/// - `shape_mismatch` if any input's shape differs from the first.
pub fn stack<T: Element>(arrays: &[Array<T, IxDyn>], axis: usize) -> FerrayResult<Array<T, IxDyn>> {
    let (first, rest) = match arrays.split_first() {
        Some(parts) => parts,
        None => return Err(FerrayError::invalid_value("stack: need at least one array")),
    };
    let base_shape = first.shape();
    let ndim = base_shape.len();
    if axis > ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim + 1));
    }
    if let Some(odd) = rest.iter().find(|arr| arr.shape() != base_shape) {
        return Err(FerrayError::shape_mismatch(format!(
            "all input arrays must have the same shape; got {:?} and {:?}",
            base_shape,
            odd.shape(),
        )));
    }
    let expanded = arrays
        .iter()
        .map(|arr| expand_dims(arr, axis))
        .collect::<FerrayResult<Vec<_>>>()?;
    concatenate(&expanded, axis)
}
/// Stack arrays vertically (along axis 0), mirroring NumPy's `vstack`.
///
/// Each input is first promoted to at least 2-D, like NumPy's `atleast_2d`:
/// a 0-d array becomes shape `[1, 1]`, and a 1-D array of length `n` becomes
/// `[1, n]`. Arrays with two or more axes are used as-is. The promoted
/// arrays are then concatenated along axis 0.
///
/// # Errors
/// - `invalid_value` if `arrays` is empty.
/// - Any error from [`concatenate`] when the promoted shapes disagree
///   outside axis 0.
pub fn vstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
    if arrays.is_empty() {
        return Err(FerrayError::invalid_value(
            "vstack: need at least one array",
        ));
    }
    // Fast path: nothing needs promotion, so avoid copying every input.
    if arrays.iter().all(|arr| arr.ndim() >= 2) {
        return concatenate(arrays, 0);
    }
    // Promotion is decided per input. Deciding from the first input alone
    // would silently reshape an `[n, 1]` input to `[1, n]` after a 1-D first
    // input, and would reject 1-D inputs that follow a 2-D one.
    let mut promoted = Vec::with_capacity(arrays.len());
    for arr in arrays {
        let shape = arr.shape();
        let target: Vec<usize> = match shape.len() {
            0 => vec![1, 1],
            1 => vec![1, shape[0]],
            _ => shape.to_vec(),
        };
        promoted.push(reshape(arr, &target)?);
    }
    concatenate(&promoted, 0)
}
/// Stack arrays horizontally.
///
/// As in NumPy, the axis is chosen from the first input: 1-D inputs are
/// joined end to end along axis 0; anything else is joined along axis 1.
///
/// # Errors
/// `invalid_value` if `arrays` is empty, plus any error from [`concatenate`].
pub fn hstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
    match arrays.first() {
        None => Err(FerrayError::invalid_value(
            "hstack: need at least one array",
        )),
        Some(first) => {
            let axis = if first.ndim() == 1 { 0 } else { 1 };
            concatenate(arrays, axis)
        }
    }
}
/// Stack arrays depth-wise (along axis 2), mirroring NumPy's `dstack`.
///
/// Each input is first promoted to at least 3-D, like NumPy's `atleast_3d`:
/// - 0-d → `[1, 1, 1]`
/// - `[n]` → `[1, n, 1]`
/// - `[m, n]` → `[m, n, 1]`
///
/// Arrays with three or more axes are copied unchanged. The results are
/// concatenated along axis 2.
///
/// # Errors
/// `invalid_value` if `arrays` is empty, plus any error from [`reshape`] or
/// [`concatenate`].
pub fn dstack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
    if arrays.is_empty() {
        return Err(FerrayError::invalid_value(
            "dstack: need at least one array",
        ));
    }
    let mut expanded = Vec::with_capacity(arrays.len());
    for arr in arrays {
        let shape = arr.shape();
        let target: Vec<usize> = match shape.len() {
            // 0-d inputs previously fell through unchanged and then failed in
            // `concatenate`; promote them like `atleast_3d` does.
            0 => vec![1, 1, 1],
            1 => vec![1, shape[0], 1],
            2 => vec![shape[0], shape[1], 1],
            _ => shape.to_vec(),
        };
        // Reshaping to the same shape yields an owned, row-major copy.
        expanded.push(reshape(arr, &target)?);
    }
    concatenate(&expanded, 2)
}
/// Stack arrays as columns of a 2-D result.
///
/// If the first input is 1-D, every input must be 1-D with the same length
/// `n`; each becomes an `[n, 1]` column and the columns are joined along
/// axis 1. Otherwise this defers to [`hstack`].
///
/// # Errors
/// - `invalid_value` if `arrays` is empty.
/// - `shape_mismatch` if a 1-D batch mixes ndims or lengths.
pub fn column_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
    let first = match arrays.first() {
        Some(first) => first,
        None => {
            return Err(FerrayError::invalid_value(
                "column_stack: need at least one array",
            ))
        }
    };
    if first.ndim() != 1 {
        return hstack(arrays);
    }
    let rows = first.shape()[0];
    let columns = arrays
        .iter()
        .map(|arr| {
            if arr.ndim() != 1 {
                return Err(FerrayError::shape_mismatch(
                    "column_stack: all inputs must have the same ndim",
                ));
            }
            let len = arr.shape()[0];
            if len != rows {
                return Err(FerrayError::shape_mismatch(format!(
                    "column_stack: 1-D inputs must have the same length; got {} and {}",
                    rows, len,
                )));
            }
            reshape(arr, &[rows, 1])
        })
        .collect::<FerrayResult<Vec<_>>>()?;
    concatenate(&columns, 1)
}
/// Alias for [`vstack`], mirroring NumPy's `row_stack`.
pub fn row_stack<T: Element>(arrays: &[Array<T, IxDyn>]) -> FerrayResult<Array<T, IxDyn>> {
    let stacked = vstack(arrays)?;
    Ok(stacked)
}
/// Assemble an array from a two-level nested list of blocks.
///
/// The blocks in each row are joined with [`hstack`] (a one-block row is
/// just copied). The rows are then joined with [`vstack`]; a single row is
/// returned as-is.
///
/// NOTE(review): for blocks with more than two axes this joins along axes 1
/// and 0, whereas NumPy's `block` joins along the last two axes — confirm
/// whether that matters for callers.
///
/// # Errors
/// `invalid_value` for an empty outer list or an empty row, plus any error
/// from the stacking helpers.
pub fn block<T: Element>(blocks: &[Vec<Array<T, IxDyn>>]) -> FerrayResult<Array<T, IxDyn>> {
    if blocks.is_empty() {
        return Err(FerrayError::invalid_value("block: empty input"));
    }
    let mut rows = blocks
        .iter()
        .map(|row| match row.as_slice() {
            [] => Err(FerrayError::invalid_value("block: empty row")),
            [only] => {
                let data: Vec<T> = only.iter().cloned().collect();
                Array::from_vec(IxDyn::new(only.shape()), data)
            }
            _ => hstack(row),
        })
        .collect::<FerrayResult<Vec<_>>>()?;
    if rows.len() > 1 {
        return vstack(&rows);
    }
    // `blocks` was non-empty, so exactly one row exists here.
    Ok(rows.swap_remove(0))
}
/// Split `a` into `n_sections` equal parts along `axis`.
///
/// # Errors
/// - `axis_out_of_bounds` if `axis >= a.ndim()`.
/// - `invalid_value` if `n_sections == 0` or the axis length is not
///   divisible by `n_sections`.
pub fn split<T: Element>(
    a: &Array<T, IxDyn>,
    n_sections: usize,
    axis: usize,
) -> FerrayResult<Vec<Array<T, IxDyn>>> {
    let dims = a.shape();
    let axis_len = match dims.get(axis) {
        Some(&len) => len,
        None => return Err(FerrayError::axis_out_of_bounds(axis, dims.len())),
    };
    if n_sections == 0 {
        return Err(FerrayError::invalid_value("split: n_sections must be > 0"));
    }
    if axis_len % n_sections != 0 {
        return Err(FerrayError::invalid_value(format!(
            "array of size {axis_len} along axis {axis} cannot be evenly split into {n_sections} sections",
        )));
    }
    let step = axis_len / n_sections;
    // Interior cut points only; `array_split` adds the 0 and `axis_len` ends.
    let cuts: Vec<usize> = (1..n_sections).map(|k| k * step).collect();
    array_split(a, &cuts, axis)
}
/// Split `a` along `axis` at the given cut `indices`, mirroring NumPy's
/// `array_split` with an index list.
///
/// The sections are `[0, i0)`, `[i0, i1)`, …, `[i_last, axis_len)`:
/// - indices larger than the axis length are clamped to it;
/// - an index smaller than its predecessor yields an empty section, as
///   NumPy's slice semantics do. This previously underflowed, panicking in
///   debug builds and allocating enormously in release builds.
///
/// # Errors
/// `axis_out_of_bounds` if `axis >= a.ndim()`.
pub fn array_split<T: Element>(
    a: &Array<T, IxDyn>,
    indices: &[usize],
    axis: usize,
) -> FerrayResult<Vec<Array<T, IxDyn>>> {
    let shape = a.shape();
    let ndim = shape.len();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    let axis_len = shape[axis];
    // Row-major layout: the source is `outer` slabs, one per combination of
    // the dims before `axis`. Each slab holds `axis_len * inner` elements,
    // where `inner` is the product of the dims after `axis`.
    let outer: usize = shape[..axis].iter().product();
    let inner: usize = shape[axis + 1..].iter().product();
    let src_data: Vec<T> = a.iter().cloned().collect();

    let mut bounds = Vec::with_capacity(indices.len() + 2);
    bounds.push(0);
    bounds.extend(indices.iter().map(|&idx| idx.min(axis_len)));
    bounds.push(axis_len);

    let mut result = Vec::with_capacity(bounds.len() - 1);
    for w in bounds.windows(2) {
        let start = w[0];
        // Decreasing cut points produce an empty section instead of underflowing.
        let chunk_len = w[1].saturating_sub(start);
        let mut sub_shape = shape.to_vec();
        sub_shape[axis] = chunk_len;
        let mut sub_data = Vec::with_capacity(outer * chunk_len * inner);
        for o in 0..outer {
            // `start <= axis_len` (clamped), so this range stays inside slab `o`.
            let begin = (o * axis_len + start) * inner;
            sub_data.extend_from_slice(&src_data[begin..begin + chunk_len * inner]);
        }
        result.push(Array::from_vec(IxDyn::new(&sub_shape), sub_data)?);
    }
    Ok(result)
}
/// Split `a` into `n` nearly equal sections along `axis` (NumPy's
/// `array_split` with an integer).
///
/// With `axis_len = q * n + r`, the first `r` sections get `q + 1` elements
/// and the remaining ones get `q`. Sections may be empty when `n > axis_len`.
///
/// # Errors
/// - `invalid_value` if `n == 0`.
/// - `axis_out_of_bounds` if `axis >= a.ndim()`.
pub fn array_split_n<T: Element>(
    a: &Array<T, IxDyn>,
    n: usize,
    axis: usize,
) -> FerrayResult<Vec<Array<T, IxDyn>>> {
    if n == 0 {
        return Err(FerrayError::invalid_value("array_split_n: n must be > 0"));
    }
    let dims = a.shape();
    if axis >= dims.len() {
        return Err(FerrayError::axis_out_of_bounds(axis, dims.len()));
    }
    let axis_len = dims[axis];
    let (base, extra) = (axis_len / n, axis_len % n);
    // Running total of section lengths gives the n - 1 interior cut points.
    let cuts: Vec<usize> = (0..n - 1)
        .scan(0usize, |cum, i| {
            *cum += base + usize::from(i < extra);
            Some(*cum)
        })
        .collect();
    array_split(a, &cuts, axis)
}
/// Split `a` into `n_sections` equal parts along axis 0 (rows); see [`split`].
pub fn vsplit<T: Element>(
    a: &Array<T, IxDyn>,
    n_sections: usize,
) -> FerrayResult<Vec<Array<T, IxDyn>>> {
    const ROW_AXIS: usize = 0;
    split(a, n_sections, ROW_AXIS)
}
/// Split `a` into `n_sections` equal parts horizontally, mirroring NumPy's
/// `hsplit`.
///
/// Arrays with two or more axes are split along axis 1. 1-D arrays have no
/// axis 1, so — as in NumPy — they are split along axis 0 instead;
/// previously every 1-D call failed with an axis-out-of-bounds error.
///
/// # Errors
/// Any error from [`split`] (bad axis for 0-d input, zero or non-dividing
/// `n_sections`).
pub fn hsplit<T: Element>(
    a: &Array<T, IxDyn>,
    n_sections: usize,
) -> FerrayResult<Vec<Array<T, IxDyn>>> {
    let axis = if a.ndim() > 1 { 1 } else { 0 };
    split(a, n_sections, axis)
}
/// Split `a` into `n_sections` equal parts along axis 2 (depth); see [`split`].
pub fn dsplit<T: Element>(
    a: &Array<T, IxDyn>,
    n_sections: usize,
) -> FerrayResult<Vec<Array<T, IxDyn>>> {
    const DEPTH_AXIS: usize = 2;
    split(a, n_sections, DEPTH_AXIS)
}
/// Return an owned copy of `a` with its axes permuted.
///
/// `axes = None` reverses the axis order. `Some(p)` must be a permutation of
/// `0..ndim`.
///
/// # Errors
/// `invalid_value` for a wrong-length or duplicate-containing permutation;
/// `axis_out_of_bounds` for an entry `>= ndim`.
pub fn transpose<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axes: Option<&[usize]>,
) -> FerrayResult<Array<T, IxDyn>> {
    let ndim = a.ndim();
    let perm: Vec<usize> = if let Some(requested) = axes {
        if requested.len() != ndim {
            return Err(FerrayError::invalid_value(format!(
                "axes must have length {} but got {}",
                ndim,
                requested.len(),
            )));
        }
        let mut used = vec![false; ndim];
        for &ax in requested {
            if ax >= ndim {
                return Err(FerrayError::axis_out_of_bounds(ax, ndim));
            }
            // `replace` returns the previous flag: `true` means a repeat.
            if std::mem::replace(&mut used[ax], true) {
                return Err(FerrayError::invalid_value(format!(
                    "duplicate axis {ax} in transpose",
                )));
            }
        }
        requested.to_vec()
    } else {
        (0..ndim).rev().collect()
    };
    let view = a
        .inner
        .view()
        .into_dyn()
        .permuted_axes(ndarray::IxDyn(&perm));
    Ok(Array::from_ndarray(view.as_standard_layout().into_owned()))
}
/// Return an owned copy of `a` with axes `axis1` and `axis2` exchanged.
///
/// # Errors
/// `axis_out_of_bounds` if either axis is `>= a.ndim()` (`axis1` is checked first).
pub fn swapaxes<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axis1: usize,
    axis2: usize,
) -> FerrayResult<Array<T, IxDyn>> {
    let ndim = a.ndim();
    for ax in [axis1, axis2] {
        if ax >= ndim {
            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
        }
    }
    let perm: Vec<usize> = (0..ndim)
        .map(|i| {
            if i == axis1 {
                axis2
            } else if i == axis2 {
                axis1
            } else {
                i
            }
        })
        .collect();
    transpose(a, Some(&perm))
}
/// Return an owned copy of `a` with axis `source` moved to position
/// `destination`; the other axes keep their relative order.
///
/// # Errors
/// `axis_out_of_bounds` if `source` or `destination` is `>= a.ndim()`.
pub fn moveaxis<T: Element, D: Dimension>(
    a: &Array<T, D>,
    source: usize,
    destination: usize,
) -> FerrayResult<Array<T, IxDyn>> {
    let ndim = a.ndim();
    if source >= ndim {
        return Err(FerrayError::axis_out_of_bounds(source, ndim));
    }
    if destination >= ndim {
        return Err(FerrayError::axis_out_of_bounds(destination, ndim));
    }
    let mut order: Vec<usize> = (0..ndim).collect();
    let moved = order.remove(source);
    order.insert(destination, moved);
    transpose(a, Some(&order))
}
/// Roll axis `axis` backwards until it sits before position `start`
/// (NumPy's `rollaxis`).
///
/// As in NumPy, a `start` past `axis` is reduced by one because removing
/// the axis shifts later positions down. If the axis would not move, a plain
/// copy is returned.
///
/// # Errors
/// `axis_out_of_bounds` if `axis >= ndim` or `start > ndim`.
pub fn rollaxis<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axis: usize,
    start: usize,
) -> FerrayResult<Array<T, IxDyn>> {
    let ndim = a.ndim();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    if start > ndim {
        return Err(FerrayError::axis_out_of_bounds(start, ndim + 1));
    }
    let dst = start - usize::from(start > axis);
    if dst != axis {
        return moveaxis(a, axis, dst);
    }
    let values: Vec<T> = a.iter().cloned().collect();
    Array::from_vec(IxDyn::new(a.shape()), values)
}
/// Return an owned copy of `a` with the element order reversed along `axis`.
///
/// # Errors
/// `axis_out_of_bounds` if `axis >= a.ndim()`.
pub fn flip<T: Element, D: Dimension>(
    a: &Array<T, D>,
    axis: usize,
) -> FerrayResult<Array<T, IxDyn>> {
    let dims = a.shape();
    let rank = dims.len();
    if axis >= rank {
        return Err(FerrayError::axis_out_of_bounds(axis, rank));
    }
    let source: Vec<T> = a.iter().cloned().collect();
    // Row-major strides of the collected buffer.
    let mut strides = vec![1usize; rank];
    for d in (1..rank).rev() {
        strides[d - 1] = strides[d] * dims[d];
    }
    // For each output position, decompose into coordinates, mirror the
    // coordinate on `axis`, and recompose into a source offset. An empty
    // array never enters the closure, so a zero stride is never divided by.
    let flipped: Vec<T> = (0..source.len())
        .map(|out_flat| {
            let mut remaining = out_flat;
            let mut src_flat = 0usize;
            for (d, &stride) in strides.iter().enumerate() {
                let coord = remaining / stride;
                remaining %= stride;
                let coord = if d == axis { dims[d] - 1 - coord } else { coord };
                src_flat += coord * stride;
            }
            source[src_flat].clone()
        })
        .collect();
    Array::from_vec(IxDyn::new(dims), flipped)
}
/// Reverse the order of columns (axis 1) of an array with at least 2 axes.
///
/// # Errors
/// `invalid_value` for 0-d or 1-D input.
pub fn fliplr<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
    match a.ndim() {
        0 | 1 => Err(FerrayError::invalid_value(
            "fliplr: array must be at least 2-D",
        )),
        _ => flip(a, 1),
    }
}
/// Reverse the order of rows (axis 0) of an array with at least 1 axis.
///
/// # Errors
/// `invalid_value` for 0-d input.
pub fn flipud<T: Element, D: Dimension>(a: &Array<T, D>) -> FerrayResult<Array<T, IxDyn>> {
    if a.ndim() == 0 {
        Err(FerrayError::invalid_value(
            "flipud: array must be at least 1-D",
        ))
    } else {
        flip(a, 0)
    }
}
/// Rotate `a` by 90° `k` times in the plane of axes 0 and 1
/// (counter-clockwise for positive `k`, as in NumPy's `rot90`).
///
/// `k` is reduced modulo 4, so negative values rotate clockwise.
///
/// # Errors
/// `invalid_value` if `a` has fewer than 2 axes.
pub fn rot90<T: Element, D: Dimension>(a: &Array<T, D>, k: i32) -> FerrayResult<Array<T, IxDyn>> {
    if a.ndim() < 2 {
        return Err(FerrayError::invalid_value(
            "rot90: array must be at least 2-D",
        ));
    }
    // Work on a dynamic-rank copy so every branch returns the same type.
    let values: Vec<T> = a.iter().cloned().collect();
    let m = Array::from_vec(IxDyn::new(a.shape()), values)?;
    match k.rem_euclid(4) {
        0 => Ok(m),
        1 => swapaxes(&flip(&m, 1)?, 0, 1),
        2 => flip(&flip(&m, 0)?, 1),
        3 => flip(&swapaxes(&m, 0, 1)?, 1),
        _ => unreachable!(),
    }
}
/// Cyclically shift the elements of `a` by `shift` positions.
///
/// * `axis = None` shifts the flattened (row-major) sequence and restores the
///   original shape.
/// * `axis = Some(ax)` shifts only along `ax`.
///
/// A positive `shift` moves elements toward higher indices; negative values
/// move them toward lower ones. Empty inputs (or an empty axis) are copied
/// unchanged.
///
/// # Errors
/// `axis_out_of_bounds` if `ax >= a.ndim()`.
pub fn roll<T: Element, D: Dimension>(
    a: &Array<T, D>,
    shift: isize,
    axis: Option<usize>,
) -> FerrayResult<Array<T, IxDyn>> {
    let out_shape = a.shape();
    if let Some(ax) = axis {
        if ax >= out_shape.len() {
            return Err(FerrayError::axis_out_of_bounds(ax, out_shape.len()));
        }
    }
    let source: Vec<T> = a.iter().cloned().collect();
    // The flat case is just a roll along axis 0 of a 1-D view of the data.
    let (dims, ax) = match axis {
        None => (vec![source.len()], 0),
        Some(ax) => (out_shape.to_vec(), ax),
    };
    let len = dims[ax];
    if len == 0 {
        return Array::from_vec(IxDyn::new(out_shape), source);
    }
    // Normalise `shift` into 0..len.
    let k = ((shift % len as isize) + len as isize) as usize % len;
    let mut strides = vec![1usize; dims.len()];
    for d in (1..dims.len()).rev() {
        strides[d - 1] = strides[d] * dims[d];
    }
    let rolled: Vec<T> = (0..source.len())
        .map(|out_flat| {
            let mut remaining = out_flat;
            let mut src_flat = 0usize;
            for (d, &stride) in strides.iter().enumerate() {
                let coord = remaining / stride;
                remaining %= stride;
                let coord = if d == ax { (len + coord - k) % len } else { coord };
                src_flat += coord * stride;
            }
            source[src_flat].clone()
        })
        .collect();
    Array::from_vec(IxDyn::new(out_shape), rolled)
}
/// Concatenate 1-D arrays end to end (like NumPy's `np.r_[a, b, ...]` for
/// 1-D inputs).
///
/// # Errors
/// `invalid_value` if `arrays` is empty.
pub fn r_<T: Element>(arrays: &[Array<T, Ix1>]) -> FerrayResult<Array<T, Ix1>> {
    if arrays.is_empty() {
        return Err(FerrayError::invalid_value("r_: need at least one array"));
    }
    let len = arrays.iter().fold(0usize, |acc, arr| acc + arr.shape()[0]);
    let joined: Vec<T> = arrays.iter().fold(Vec::with_capacity(len), |mut out, arr| {
        out.extend(arr.iter().cloned());
        out
    });
    Array::from_vec(Ix1::new([len]), joined)
}
/// Stack equal-length 1-D arrays as the columns of a 2-D `[n, k]` array
/// (like NumPy's `np.c_[a, b, ...]` for 1-D inputs).
///
/// # Errors
/// - `invalid_value` if `arrays` is empty.
/// - `shape_mismatch` if any input's length differs from the first.
pub fn c_<T: Element>(arrays: &[Array<T, Ix1>]) -> FerrayResult<Array<T, IxDyn>> {
    let first = match arrays.first() {
        Some(first) => first,
        None => return Err(FerrayError::invalid_value("c_: need at least one array")),
    };
    let rows = first.shape()[0];
    if let Some(bad) = arrays[1..].iter().find(|arr| arr.shape()[0] != rows) {
        return Err(FerrayError::shape_mismatch(format!(
            "c_: all 1-D inputs must have the same length; got {} and {}",
            rows,
            bad.shape()[0],
        )));
    }
    let width = arrays.len();
    let columns: Vec<Vec<T>> = arrays
        .iter()
        .map(|arr| arr.iter().cloned().collect())
        .collect();
    // Emit row by row: element r of every column in turn.
    let data: Vec<T> = (0..rows)
        .flat_map(|r| columns.iter().map(move |col| col[r].clone()))
        .collect();
    Array::from_vec(IxDyn::new(&[rows, width]), data)
}
#[cfg(test)]
mod tests {
    // Unit tests for the shape-manipulation functions in this module.
    use super::*;

    fn dyn_arr(shape: &[usize], data: Vec<f64>) -> Array<f64, IxDyn> {
        Array::from_vec(IxDyn::new(shape), data).unwrap()
    }

    #[test]
    fn test_reshape() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = reshape(&a, &[3, 2]).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_reshape_size_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(reshape(&a, &[2, 4]).is_err());
    }

    #[test]
    fn test_ravel() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = ravel(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
        assert_eq!(b.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_flatten() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flatten(&a).unwrap();
        assert_eq!(b.shape(), &[6]);
    }

    #[test]
    fn test_squeeze() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, None).unwrap();
        assert_eq!(b.shape(), &[3]);
    }

    #[test]
    fn test_squeeze_specific_axis() {
        let a = dyn_arr(&[1, 3, 1], vec![1.0, 2.0, 3.0]);
        let b = squeeze(&a, Some(0)).unwrap();
        assert_eq!(b.shape(), &[3, 1]);
    }

    #[test]
    fn test_squeeze_not_size_1() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(squeeze(&a, Some(0)).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = expand_dims(&a, 0).unwrap();
        assert_eq!(b.shape(), &[1, 3]);
        let c = expand_dims(&a, 1).unwrap();
        assert_eq!(c.shape(), &[3, 1]);
    }

    #[test]
    fn test_expand_dims_oob() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(expand_dims(&a, 3).is_err());
    }

    #[test]
    fn test_broadcast_to() {
        let a = dyn_arr(&[1, 3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[3, 3]).unwrap();
        assert_eq!(b.shape(), &[3, 3]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_broadcast_to_1d_to_2d() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = broadcast_to(&a, &[2, 3]).unwrap();
        assert_eq!(b.shape(), &[2, 3]);
    }

    #[test]
    fn test_broadcast_to_incompatible() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(broadcast_to(&a, &[3]).is_err());
    }

    #[test]
    fn test_broadcast_to_from_non_contiguous_source() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let t = transpose(&a, None).unwrap();
        let b = broadcast_to(&t, &[2, 3, 2]).unwrap();
        assert_eq!(b.shape(), &[2, 3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(&data[..6], &data[6..12]);
    }

    #[test]
    fn test_concatenate() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = dyn_arr(&[2, 3], vec![7.0, 8.0, 9.0, 10.0, 11.0, 12.0]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[4, 3]);
    }

    #[test]
    fn test_concatenate_axis1() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = concatenate(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
        assert_eq!(
            c.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 7.0, 3.0, 4.0, 8.0, 9.0, 10.0]
        );
    }

    // Lengths may differ along the concatenation axis itself.
    #[test]
    fn test_concatenate_unequal_lengths_along_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[3, 3], vec![1.0; 9]);
        let c = concatenate(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[5, 3]);
    }

    // Lengths must match on every other axis.
    #[test]
    fn test_concatenate_shape_mismatch() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        let b = dyn_arr(&[2, 4], vec![1.0; 8]);
        assert!(concatenate(&[a, b], 0).is_err());
    }

    #[test]
    fn test_concatenate_empty() {
        let v: Vec<Array<f64, IxDyn>> = vec![];
        assert!(concatenate(&v, 0).is_err());
    }

    #[test]
    fn test_stack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 0).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_stack_axis1() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = stack(&[a, b], 1).unwrap();
        assert_eq!(c.shape(), &[3, 2]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = vstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 3]);
    }

    #[test]
    fn test_hstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[6]);
    }

    #[test]
    fn test_hstack_2d() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 3], vec![5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);
        let c = hstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 5]);
    }

    #[test]
    fn test_dstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dstack(&[a, b]).unwrap();
        assert_eq!(c.shape(), &[2, 2, 2]);
        assert_eq!(
            c.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 5.0, 2.0, 6.0, 3.0, 7.0, 4.0, 8.0]
        );
    }

    #[test]
    fn test_block() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let c = dyn_arr(&[2, 2], vec![9.0, 10.0, 11.0, 12.0]);
        let d = dyn_arr(&[2, 2], vec![13.0, 14.0, 15.0, 16.0]);
        let result = block(&[vec![a, b], vec![c, d]]).unwrap();
        assert_eq!(result.shape(), &[4, 4]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![
                1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0, 9.0, 10.0, 13.0, 14.0, 11.0, 12.0, 15.0,
                16.0,
            ]
        );
    }

    #[test]
    fn test_split() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = split(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[2]);
    }

    #[test]
    fn test_split_uneven() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(split(&a, 3, 0).is_err());
    }

    #[test]
    fn test_array_split() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 4], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[2]);
        assert_eq!(parts[2].shape(), &[1]);
    }

    // Indices past the end are clamped, producing a trailing empty section.
    #[test]
    fn test_array_split_index_past_end() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let parts = array_split(&a, &[2, 10], 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].shape(), &[2]);
        assert_eq!(parts[1].shape(), &[3]);
        assert_eq!(parts[2].shape(), &[0]);
    }

    #[test]
    fn test_vsplit() {
        let a = dyn_arr(&[4, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = vsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_hsplit() {
        let a = dyn_arr(&[2, 4], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let parts = hsplit(&a, 2).unwrap();
        assert_eq!(parts.len(), 2);
        assert_eq!(parts[0].shape(), &[2, 2]);
    }

    #[test]
    fn test_transpose_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, None).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
    }

    #[test]
    fn test_transpose_explicit() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = transpose(&a, Some(&[1, 0])).unwrap();
        assert_eq!(b.shape(), &[3, 2]);
    }

    #[test]
    fn test_transpose_bad_axes() {
        let a = dyn_arr(&[2, 3], vec![1.0; 6]);
        assert!(transpose(&a, Some(&[0])).is_err());
    }

    #[test]
    fn test_swapaxes() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = swapaxes(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[4, 3, 2]);
    }

    #[test]
    fn test_moveaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = moveaxis(&a, 0, 2).unwrap();
        assert_eq!(b.shape(), &[3, 4, 2]);
    }

    #[test]
    fn test_rollaxis() {
        let a = dyn_arr(&[2, 3, 4], vec![0.0; 24]);
        let b = rollaxis(&a, 2, 0).unwrap();
        assert_eq!(b.shape(), &[4, 2, 3]);
    }

    #[test]
    fn test_flip() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_flip_2d() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flip(&a, 0).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
        let c = flip(&a, 1).unwrap();
        let data2: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data2, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_fliplr() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = fliplr(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 2.0, 1.0, 6.0, 5.0, 4.0]);
    }

    #[test]
    fn test_flipud() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = flipud(&a).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 6.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_fliplr_1d_err() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        assert!(fliplr(&a).is_err());
    }

    #[test]
    fn test_rot90_once() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 1).unwrap();
        assert_eq!(b.shape(), &[2, 2]);
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![2.0, 4.0, 1.0, 3.0]);
    }

    #[test]
    fn test_rot90_twice() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = rot90(&a, 2).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_rot90_four_is_identity() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = rot90(&a, 4).unwrap();
        let data_a: Vec<f64> = a.iter().copied().collect();
        let data_b: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data_a, data_b);
        assert_eq!(a.shape(), b.shape());
    }

    #[test]
    fn test_roll_flat() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, 2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![4.0, 5.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_roll_negative() {
        let a = dyn_arr(&[5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let b = roll(&a, -2, None).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 4.0, 5.0, 1.0, 2.0]);
    }

    #[test]
    fn test_roll_axis() {
        let a = dyn_arr(&[2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let b = roll(&a, 1, Some(1)).unwrap();
        let data: Vec<f64> = b.iter().copied().collect();
        assert_eq!(data, vec![3.0, 1.0, 2.0, 6.0, 4.0, 5.0]);
    }

    #[test]
    fn test_column_stack_1d() {
        let a = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[4], vec![10.0, 20.0, 30.0, 40.0]);
        let c = dyn_arr(&[4], vec![100.0, 200.0, 300.0, 400.0]);
        let result = column_stack(&[a, b, c]).unwrap();
        assert_eq!(result.shape(), &[4, 3]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![
                1.0, 10.0, 100.0, 2.0, 20.0, 200.0, 3.0, 30.0, 300.0, 4.0, 40.0, 400.0,
            ]
        );
    }

    #[test]
    fn test_column_stack_2d_same_as_hstack() {
        let a = dyn_arr(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let b = dyn_arr(&[2, 2], vec![5.0, 6.0, 7.0, 8.0]);
        let result = column_stack(&[a, b]).unwrap();
        assert_eq!(result.shape(), &[2, 4]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_column_stack_length_mismatch() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[4], vec![1.0, 2.0, 3.0, 4.0]);
        assert!(column_stack(&[a, b]).is_err());
    }

    #[test]
    fn test_row_stack_is_vstack() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let b = dyn_arr(&[3], vec![4.0, 5.0, 6.0]);
        let row = row_stack(&[a.clone(), b.clone()]).unwrap();
        let v = vstack(&[a, b]).unwrap();
        assert_eq!(row.shape(), v.shape());
        assert_eq!(
            row.iter().copied().collect::<Vec<_>>(),
            v.iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn test_array_split_n_uneven() {
        let a = dyn_arr(&[7], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        assert_eq!(
            parts[0].iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0]
        );
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![6.0, 7.0]);
    }

    #[test]
    fn test_array_split_n_even() {
        let a = dyn_arr(&[6], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let parts = array_split_n(&a, 3, 0).unwrap();
        assert_eq!(parts.len(), 3);
        for (i, expected) in [vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]
            .iter()
            .enumerate()
        {
            assert_eq!(&parts[i].iter().copied().collect::<Vec<_>>(), expected);
        }
    }

    #[test]
    fn test_array_split_n_more_sections_than_elements() {
        let a = dyn_arr(&[3], vec![1.0, 2.0, 3.0]);
        let parts = array_split_n(&a, 5, 0).unwrap();
        assert_eq!(parts.len(), 5);
        assert_eq!(parts[0].iter().copied().collect::<Vec<_>>(), vec![1.0]);
        assert_eq!(parts[1].iter().copied().collect::<Vec<_>>(), vec![2.0]);
        assert_eq!(parts[2].iter().copied().collect::<Vec<_>>(), vec![3.0]);
        assert_eq!(
            parts[3].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
        assert_eq!(
            parts[4].iter().copied().collect::<Vec<_>>(),
            Vec::<f64>::new()
        );
    }

    #[test]
    fn test_to_dyn_from_typed() {
        use crate::Array;
        use crate::dimension::Ix2;
        let typed =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let dy = typed.to_dyn();
        assert_eq!(dy.shape(), &[2, 3]);
        assert_eq!(
            dy.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_concatenate_typed_via_to_dyn() {
        use crate::Array;
        use crate::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![5.0, 6.0, 7.0, 8.0]).unwrap();
        let result = concatenate(&[a.to_dyn(), b.to_dyn()], 0).unwrap();
        assert_eq!(result.shape(), &[4, 2]);
        assert_eq!(
            result.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
        );
    }

    #[test]
    fn test_r_concatenates_1d() {
        use crate::Array;
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![4, 5]).unwrap();
        let c = Array::<i32, Ix1>::from_vec(Ix1::new([1]), vec![6]).unwrap();
        let r = r_(&[a, b, c]).unwrap();
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5, 6]
        );
    }

    #[test]
    fn test_r_empty_input_errs() {
        let r = r_::<f64>(&[]);
        assert!(r.is_err());
    }

    #[test]
    fn test_c_columns_to_2d() {
        use crate::Array;
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let r = c_(&[a, b]).unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![1, 10, 2, 20, 3, 30],
        );
    }

    #[test]
    fn test_c_length_mismatch_errs() {
        use crate::Array;
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![10, 20]).unwrap();
        assert!(c_(&[a, b]).is_err());
    }
}