use ::ndarray::{Array, ArrayView, Axis, Dimension, ShapeError};
/// Returns a new owned array holding the elements of `array` arranged in `shape`.
///
/// Elements are read from `array` in logical row-major order, i.e. the order
/// produced by `array.iter()`, independent of the view's memory layout. They
/// are written into the result in row-major (C) order, so any column-major
/// flag carried by `shape` is ignored.
///
/// # Errors
///
/// Returns an error if the number of elements implied by `shape` differs from
/// `array.len()`.
#[allow(dead_code)]
pub fn reshape<D1, D2, T>(
    array: ArrayView<T, D1>,
    shape: D2,
) -> Result<Array<T, D2::Dim>, &'static str>
where
    D1: Dimension,
    D2: crate::ndarray::ShapeBuilder,
    T: Clone,
{
    // `set_f(false)` turns any `ShapeBuilder` (tuple, array, `Dim`, ...) into a
    // row-major `Shape`, matching the row-major order of `array.iter()` below.
    let dim = shape.set_f(false);
    if dim.size() != array.len() {
        return Err("New shape dimensions must match the total number of elements");
    }
    let elements: Vec<T> = array.iter().cloned().collect();
    Array::from_shape_vec(dim, elements).map_err(|_| "Failed to reshape array")
}
/// Joins equally shaped arrays end to end along an existing axis.
///
/// No new axis is inserted; this is a concatenation. The result has the same
/// number of dimensions as the inputs. Its length along `axis` is
/// `arrays.len() * arrays[0].len_of(axis)`, and `arrays[i]` occupies the
/// `i`-th consecutive block of that length along `axis`. Inputs of any
/// dimensionality are supported.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// * `arrays` is empty.
/// * The inputs do not all share the same shape.
/// * `axis` is not a valid axis of the inputs.
/// * The stacked length along `axis` overflows `usize`.
#[allow(dead_code)]
pub fn stack<D, T>(arrays: &[ArrayView<T, D>], axis: Axis) -> Result<Array<T, D>, &'static str>
where
    D: Dimension,
    T: Clone + Default,
{
    let first = arrays.first().ok_or("No arrays provided for stacking")?;
    if arrays[1..].iter().any(|array| array.shape() != first.shape()) {
        return Err("All arrays must have the same shape for stacking");
    }
    let axis_idx = axis.index();
    if axis_idx >= first.ndim() {
        return Err("Axis index out of bounds");
    }

    let block_len = first.len_of(axis);
    let mut new_shape = first.raw_dim();
    new_shape[axis_idx] = block_len
        .checked_mul(arrays.len())
        .ok_or("Stacked axis length overflows usize")?;

    // Each input gets its own block: input `i` starts at `i * block_len`.
    // `slice_axis_mut` + `assign` work for any rank, so no per-rank index
    // arithmetic is needed.
    let mut result: Array<T, D> = Array::default(new_shape);
    for (i, array) in arrays.iter().enumerate() {
        let start = i * block_len;
        result
            .slice_axis_mut(axis, (start..start + block_len).into())
            .assign(array);
    }
    Ok(result)
}
/// Returns an owned copy of `array` with axes `axis1` and `axis2` interchanged.
///
/// Swapping an axis with itself yields an unchanged copy.
///
/// # Errors
///
/// Returns an error if either axis index is `>= array.ndim()`.
#[allow(dead_code)]
pub fn swapaxes<D, T>(
    array: ArrayView<T, D>,
    axis1: usize,
    axis2: usize,
) -> Result<Array<T, D>, &'static str>
where
    D: Dimension,
    T: Clone,
{
    let ndim = array.ndim();
    if axis1 >= ndim || axis2 >= ndim {
        return Err("Axis indices out of bounds");
    }
    // `swap_axes` only exchanges the view's shape/stride entries and works for
    // any `D`. By contrast, `permuted_axes(Vec<usize>)` only type-checks when
    // `D` is `IxDyn`.
    let mut view = array;
    view.swap_axes(axis1, axis2);
    Ok(view.to_owned())
}
/// Splits `array` into consecutive owned pieces along `axis`.
///
/// Each value in `indices` is a position along `axis` at which a new piece
/// begins. The indices are sorted first, so their order does not matter. For
/// `k` indices the result has `k + 1` pieces covering `[0, i0)`, `[i0, i1)`, …,
/// `[i_{k-1}, len)`. Duplicate indices, or an index of `0`, produce empty
/// pieces. With no indices the result is a single copy of `array`.
///
/// # Errors
///
/// * `"Axis index out of bounds"` if `axis` is not an axis of `array`.
/// * `"Split index out of bounds"` if any index is `>= array.len_of(axis)`.
#[allow(dead_code)]
pub fn split<D, T>(
    array: ArrayView<T, D>,
    indices: &[usize],
    axis: Axis,
) -> Result<Vec<Array<T, D>>, &'static str>
where
    D: Dimension,
    T: Clone,
{
    // Validate the axis first so an invalid axis is rejected even when
    // `indices` is empty.
    if axis.index() >= array.ndim() {
        return Err("Axis index out of bounds");
    }
    let axis_len = array.len_of(axis);
    if indices.iter().any(|&idx| idx >= axis_len) {
        return Err("Split index out of bounds");
    }

    let mut sorted_indices = indices.to_vec();
    sorted_indices.sort_unstable();

    // Piece boundaries: 0, every split index in ascending order, then the axis
    // length. Each adjacent pair [start, end) is one piece.
    let boundaries: Vec<usize> = std::iter::once(0)
        .chain(sorted_indices)
        .chain(std::iter::once(axis_len))
        .collect();

    Ok(boundaries
        .windows(2)
        .map(|bounds| {
            array
                .slice_axis(axis, (bounds[0]..bounds[1]).into())
                .to_owned()
        })
        .collect())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    #[test]
    fn test_reshape() {
        let a = array![[1, 2], [3, 4]];
        let b = reshape(a.view(), (4, 1)).expect("Operation failed");
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[3, 0]], 4);
        let result = reshape(a.view(), (3, 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_reshape_reads_logical_order() {
        // A transposed view is not row-major in memory; reshape must still
        // follow the view's logical element order.
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = reshape(a.t(), 6_usize).expect("6 elements fit a length-6 shape");
        assert_eq!(b, array![1, 4, 2, 5, 3, 6]);
    }

    #[test]
    fn test_stack() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];
        let c = stack(&[a.view(), b.view()], Axis(0)).expect("Operation failed");
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[2, 1]], 6);
        let d = stack(&[a.view(), b.view()], Axis(1)).expect("Operation failed");
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d[[0, 0]], 1);
        assert_eq!(d[[0, 3]], 6);
    }

    #[test]
    fn test_stack_places_every_input() {
        let a = array![[1, 2]];
        let b = array![[3, 4]];
        let c = array![[5, 6]];
        let rows = stack(&[a.view(), b.view(), c.view()], Axis(0)).expect("same-shaped inputs");
        assert_eq!(rows, array![[1, 2], [3, 4], [5, 6]]);
        let cols = stack(&[a.view(), b.view(), c.view()], Axis(1)).expect("same-shaped inputs");
        assert_eq!(cols, array![[1, 2, 3, 4, 5, 6]]);
    }

    #[test]
    fn test_stack_rejects_invalid_input() {
        assert!(stack::<ndarray::Ix2, i32>(&[], Axis(0)).is_err());
        let a = array![[1, 2], [3, 4]];
        let b = array![[1, 2, 3]];
        assert!(stack(&[a.view(), b.view()], Axis(0)).is_err());
        assert!(stack(&[a.view(), a.view()], Axis(2)).is_err());
    }

    #[test]
    fn test_swapaxes() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = swapaxes(a.view(), 0, 1).expect("Operation failed");
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 4);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[2, 1]], 6);
    }

    #[test]
    fn test_swapaxes_edge_cases() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        assert!(swapaxes(a.view(), 0, 2).is_err());
        assert!(swapaxes(a.view(), 2, 0).is_err());
        let same = swapaxes(a.view(), 1, 1).expect("swapping an axis with itself");
        assert_eq!(same, a);
    }

    #[test]
    fn test_split() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
        let result = split(a.view(), &[2], Axis(1)).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 2]);
        assert_eq!(result[0][[0, 0]], 1);
        assert_eq!(result[0][[1, 1]], 6);
        assert_eq!(result[1].shape(), &[2, 2]);
        assert_eq!(result[1][[0, 0]], 3);
        assert_eq!(result[1][[1, 1]], 8);
        let result = split(a.view(), &[1], Axis(0)).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[1, 4]);
        assert_eq!(result[1].shape(), &[1, 4]);
    }

    #[test]
    fn test_split_unsorted_empty_and_errors() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
        let parts = split(a.view(), &[3, 1], Axis(1)).expect("indices within bounds");
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], array![[1], [5]]);
        assert_eq!(parts[1], array![[2, 3], [6, 7]]);
        assert_eq!(parts[2], array![[4], [8]]);

        let whole = split(a.view(), &[], Axis(0)).expect("no split indices");
        assert_eq!(whole, vec![a.clone()]);

        assert!(split(a.view(), &[4], Axis(1)).is_err());
        assert!(split(a.view(), &[1], Axis(2)).is_err());
    }
}