pub use ::ndarray::{
s, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, Data, DataMut, Dim, Dimension, Ix1, Ix2,
Ix3, Ix4, Ix5, Ix6, IxDyn, OwnedRepr, RemoveAxis, ScalarOperand, ShapeBuilder, ShapeError,
SliceInfo, ViewRepr, Zip,
};
pub use ::ndarray::{arr1, arr2, array};
pub use ::ndarray::{ArcArray1, ArcArray2};
pub use ::ndarray::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD};
pub use ::ndarray::{
ArrayView0, ArrayView1, ArrayView2, ArrayView3, ArrayView4, ArrayView5, ArrayView6, ArrayViewD,
};
pub use ::ndarray::{
ArrayViewMut0, ArrayViewMut1, ArrayViewMut2, ArrayViewMut3, ArrayViewMut4, ArrayViewMut5,
ArrayViewMut6, ArrayViewMutD,
};
pub mod indexing;
pub mod stats;
pub mod matrix;
pub mod manipulation;
pub mod reduction;
pub mod preprocessing;
pub mod elementwise;
#[cfg(feature = "random")]
pub mod random;
/// Reshapes a 2-D view into a new owned 2-D array of shape `(rows, cols)`.
///
/// Elements are read in the view's logical row-major order and written back in
/// row-major order, so the result does not depend on the input's memory layout
/// (e.g. a transposed view is read as it appears, not as it is stored).
///
/// # Errors
///
/// Returns an error if `rows * cols` overflows `usize`, differs from
/// `array.len()`, or describes a shape too large to allocate.
#[allow(dead_code)]
pub fn reshape_2d<T>(
    array: ArrayView<T, Ix2>,
    shape: (usize, usize),
) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Default,
{
    let (rows, cols) = shape;
    // `checked_mul` avoids a debug-mode overflow panic and, in release mode, a
    // wrapped product that could spuriously equal `array.len()`.
    if rows.checked_mul(cols) != Some(array.len()) {
        return Err("New shape dimensions must match the total number of elements");
    }
    let elements: Vec<T> = array.iter().cloned().collect();
    // The element count was checked above; this only fails for shapes whose
    // non-zero axis product exceeds `isize::MAX`.
    Array::from_shape_vec(shape, elements)
        .map_err(|_| "New shape dimensions must match the total number of elements")
}
/// Stacks equally-shaped 2-D views end to end along an existing axis.
///
/// With `axis == 0` the views are placed one below another (result shape
/// `(rows * n, cols)`); with `axis == 1` they are placed side by side (result
/// shape `(rows, cols * n)`). Every view must have exactly the same shape.
///
/// # Errors
///
/// Returns an error if `arrays` is empty, if any shape differs from the first,
/// or if `axis` is not 0 or 1 (checked in that order).
#[allow(dead_code)]
pub fn stack_2d<T>(arrays: &[ArrayView<T, Ix2>], axis: usize) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Default,
{
    let first = arrays.first().ok_or("No arrays provided for stacking")?;
    if arrays[1..].iter().any(|other| other.shape() != first.shape()) {
        return Err("All arrays must have the same shape for stacking");
    }
    let (rows, cols) = first.dim();
    let count = arrays.len();
    let (new_rows, new_cols) = match axis {
        0 => (rows * count, cols),
        1 => (rows, cols * count),
        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
    };
    let mut result = Array::<T, Ix2>::default((new_rows, new_cols));
    for (k, view) in arrays.iter().enumerate() {
        // Block `k` starts `k` whole blocks along the stacking axis.
        let (row_off, col_off) = if axis == 0 { (k * rows, 0) } else { (0, k * cols) };
        for ((r, c), value) in view.indexed_iter() {
            result[[row_off + r, col_off + c]] = value.clone();
        }
    }
    Ok(result)
}
/// Returns an owned copy of the transpose of a 2-D view.
///
/// Reversing the two axes of the view swaps rows and columns without copying
/// any data; `to_owned` then materializes that transposed view.
#[allow(dead_code)]
pub fn transpose_2d<T>(array: ArrayView<T, Ix2>) -> Array<T, Ix2>
where
    T: Clone,
{
    let swapped = array.reversed_axes();
    swapped.to_owned()
}
/// Splits a 2-D view into owned pieces at the given split points along `axis`.
///
/// `indices` need not be sorted, and duplicates produce empty pieces. With `k`
/// indices sorted as `i0 <= … <= i{k-1}`, the `k + 1` pieces cover
/// `[0, i0)`, `[i0, i1)`, …, `[i{k-1}, len)` along `axis`. An empty `indices`
/// slice returns a single copy of the whole array.
///
/// # Errors
///
/// Returns an error if `axis` is not 0 or 1, or if any index is `>=` the
/// length of `axis`.
#[allow(dead_code)]
pub fn split_2d<T>(
    array: ArrayView<T, Ix2>,
    indices: &[usize],
    axis: usize,
) -> Result<Vec<Array<T, Ix2>>, &'static str>
where
    T: Clone + Default,
{
    // Validate the axis first so an invalid axis is always reported as such,
    // even when `indices` is empty or would be checked against the wrong axis.
    if axis > 1 {
        return Err("Axis must be 0 or 1 for 2D arrays");
    }
    if indices.is_empty() {
        return Ok(vec![array.to_owned()]);
    }
    let (rows, cols) = array.dim();
    let axis_len = if axis == 0 { rows } else { cols };
    if indices.iter().any(|&idx| idx >= axis_len) {
        return Err("Split index out of bounds");
    }
    let mut cuts = indices.to_vec();
    cuts.sort_unstable();
    // Piece boundaries: [0, cuts...] paired with [cuts..., axis_len].
    let starts = std::iter::once(0).chain(cuts.iter().copied());
    let ends = cuts.iter().copied().chain(std::iter::once(axis_len));
    let pieces = starts
        .zip(ends)
        .map(|(start, end)| {
            let span = end - start;
            if axis == 0 {
                Array::from_shape_fn((span, cols), |(r, c)| array[[start + r, c]].clone())
            } else {
                Array::from_shape_fn((rows, span), |(r, c)| array[[r, start + c]].clone())
            }
        })
        .collect();
    Ok(pieces)
}
/// Gathers rows (`axis == 0`) or columns (`axis == 1`) of a 2-D view by index.
///
/// Indices may repeat and appear in any order. The output follows the order of
/// `indices`, so it has shape `(indices.len(), cols)` for `axis == 0` and
/// `(rows, indices.len())` for `axis == 1`.
///
/// # Errors
///
/// Returns an error if `axis` is not 0 or 1, or if any index is out of bounds
/// for `axis`.
#[allow(dead_code)]
pub fn take_2d<T>(
    array: ArrayView<T, Ix2>,
    indices: ArrayView<usize, Ix1>,
    axis: usize,
) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Default,
{
    // Check the axis before the indices; otherwise an invalid axis could be
    // misreported as an out-of-bounds index.
    if axis > 1 {
        return Err("Axis must be 0 or 1 for 2D arrays");
    }
    let (rows, cols) = array.dim();
    let axis_len = if axis == 0 { rows } else { cols };
    if indices.iter().any(|&idx| idx >= axis_len) {
        return Err("Index out of bounds");
    }
    let shape = if axis == 0 { (indices.len(), cols) } else { (rows, indices.len()) };
    let result = Array::from_shape_fn(shape, |(i, j)| {
        if axis == 0 {
            array[[indices[i], j]].clone()
        } else {
            array[[i, indices[j]]].clone()
        }
    });
    Ok(result)
}
/// Collects, in row-major order, every element of `array` whose mask entry is `true`.
///
/// # Errors
///
/// Returns an error when `mask` and `array` have different shapes.
#[allow(dead_code)]
pub fn mask_select<T>(
    array: ArrayView<T, Ix2>,
    mask: ArrayView<bool, Ix2>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Default,
{
    if array.shape() == mask.shape() {
        // Both views iterate in logical row-major order, so zipping pairs each
        // element with its own mask entry.
        let kept: Vec<T> = array
            .iter()
            .zip(mask.iter())
            .filter_map(|(value, &keep)| keep.then(|| value.clone()))
            .collect();
        Ok(Array::from(kept))
    } else {
        Err("Mask shape must match array shape")
    }
}
/// Gathers `array[[row_indices[i], col_indices[i]]]` for every position `i`.
///
/// # Errors
///
/// Returns an error if the two index arrays differ in length, if any row index
/// is out of bounds, or if any column index is out of bounds (checked in that order).
#[allow(dead_code)]
pub fn fancy_index_2d<T>(
    array: ArrayView<T, Ix2>,
    row_indices: ArrayView<usize, Ix1>,
    col_indices: ArrayView<usize, Ix1>,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Default,
{
    if row_indices.len() != col_indices.len() {
        return Err("Row and column index arrays must have the same length");
    }
    let (n_rows, n_cols) = array.dim();
    if row_indices.iter().any(|&r| r >= n_rows) {
        return Err("Row index out of bounds");
    }
    if col_indices.iter().any(|&c| c >= n_cols) {
        return Err("Column index out of bounds");
    }
    let gathered: Vec<T> = row_indices
        .iter()
        .zip(col_indices.iter())
        .map(|(&r, &c)| array[[r, c]].clone())
        .collect();
    Ok(Array::from(gathered))
}
/// Returns, in row-major order, the elements of `array` for which `condition` holds.
///
/// Builds a boolean mask by applying `condition` to every element and then
/// delegates to `mask_select`. The mask comes from mapping over `array`, so
/// it has the same shape.
#[allow(dead_code)]
pub fn where_condition<T, F>(
    array: ArrayView<T, Ix2>,
    condition: F,
) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Default,
    F: Fn(&T) -> bool,
{
    let selected = array.map(|element| condition(element));
    mask_select(array, selected.view())
}
/// Reports whether two shapes are broadcast-compatible under NumPy-style rules.
///
/// The shapes are aligned at their trailing axes, and a missing leading axis
/// counts as length 1. Every aligned pair must be equal or contain a 1.
#[allow(dead_code)]
pub fn is_broadcast_compatible(shape1: &[usize], shape2: &[usize]) -> bool {
    let ndim = shape1.len().max(shape2.len());
    // Length of `shape` at position `k` counted from the last axis; 1 when absent.
    let from_end = |shape: &[usize], k: usize| -> usize {
        shape.len().checked_sub(k + 1).map_or(1, |i| shape[i])
    };
    (0..ndim).all(|k| {
        let (a, b) = (from_end(shape1, k), from_end(shape2, k));
        a == b || a == 1 || b == 1
    })
}
/// Computes the broadcast shape of two shapes under NumPy-style rules.
///
/// The shapes are aligned at their trailing axes, and a missing leading axis
/// counts as length 1. For each aligned pair, equal lengths are kept, and a
/// length of 1 stretches to match the other. This includes zero-length axes:
/// `[0]` with `[1]` gives `[0]`.
///
/// Returns `None` when some aligned pair differs and neither length is 1.
#[allow(dead_code)]
pub fn broadcastshape(shape1: &[usize], shape2: &[usize]) -> Option<Vec<usize>> {
    let max_dim = shape1.len().max(shape2.len());
    let get_dim = |shape: &[usize], i: usize| -> usize {
        let offset = max_dim - shape.len();
        if i < offset {
            1
        } else {
            shape[i - offset]
        }
    };
    // Compatibility check and result construction in one pass; collecting into
    // `Option<Vec<_>>` short-circuits on the first incompatible axis.
    (0..max_dim)
        .map(|i| {
            let (dim1, dim2) = (get_dim(shape1, i), get_dim(shape2, i));
            // Take the non-1 length rather than `max`, which wrongly turns a
            // broadcast of 0 against 1 into 1.
            if dim1 == dim2 || dim2 == 1 {
                Some(dim1)
            } else if dim1 == 1 {
                Some(dim2)
            } else {
                None
            }
        })
        .collect()
}
/// Repeats a 1-D view to build a 2-D array.
///
/// With `axis == 0` the view is copied into each of `repeats` rows, giving
/// shape `(repeats, len)`. With `axis == 1` it is copied into each of
/// `repeats` columns, giving shape `(len, repeats)`.
///
/// # Errors
///
/// Returns an error if `axis` is not 0 or 1.
#[allow(dead_code)]
pub fn broadcast_1d_to_2d<T>(
    array: ArrayView<T, Ix1>,
    repeats: usize,
    axis: usize,
) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Default,
{
    let n = array.len();
    let shape = match axis {
        0 => (repeats, n),
        1 => (n, repeats),
        _ => return Err("Axis must be 0 or 1"),
    };
    let tiled = Array::from_shape_fn(shape, |(i, j)| {
        // Along axis 0 the source element is chosen by the column; along axis 1 by the row.
        let source = if axis == 0 { j } else { i };
        array[source].clone()
    });
    Ok(tiled)
}
/// Applies `op` element-wise between every row of `a` and the 1-D array `b`.
///
/// `b` is broadcast along axis 0, so `result[[i, j]] = op(&a[[i, j]], &b[j])`,
/// evaluated in row-major order.
///
/// # Errors
///
/// Returns an error when `b.len()` differs from the number of columns of `a`.
#[allow(dead_code)]
pub fn broadcast_apply<T, R, F>(
    a: ArrayView<T, Ix2>,
    b: ArrayView<T, Ix1>,
    op: F,
) -> Result<Array<R, Ix2>, &'static str>
where
    T: Clone + Default,
    R: Clone + Default,
    F: Fn(&T, &T) -> R,
{
    let (n_rows, n_cols) = a.dim();
    if b.len() == n_cols {
        Ok(Array::from_shape_fn((n_rows, n_cols), |(i, j)| {
            op(&a[[i, j]], &b[j])
        }))
    } else {
        Err("Arrays are not broadcast compatible")
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the 2-D helpers: happy paths plus every documented error path.
    use super::*;
    use crate::ndarray::array;

    #[test]
    fn test_reshape_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = reshape_2d(a.view(), (4, 1)).expect("Operation failed");
        assert_eq!(b.shape(), &[4, 1]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 0]], 3);
        assert_eq!(b[[3, 0]], 4);
        let result = reshape_2d(a.view(), (3, 1));
        assert!(result.is_err());
    }

    #[test]
    fn test_reshape_2d_uses_logical_order() {
        let a = array![[1, 2], [3, 4]];
        let row = reshape_2d(a.view(), (1, 4)).expect("Operation failed");
        assert_eq!(row, array![[1, 2, 3, 4]]);
        // A transposed (non-standard-layout) view is read as it appears.
        let t = reshape_2d(a.t(), (1, 4)).expect("Operation failed");
        assert_eq!(t, array![[1, 3, 2, 4]]);
    }

    #[test]
    fn test_stack_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];
        let c = stack_2d(&[a.view(), b.view()], 0).expect("Operation failed");
        assert_eq!(c.shape(), &[4, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[1, 0]], 3);
        assert_eq!(c[[2, 0]], 5);
        assert_eq!(c[[3, 0]], 7);
        let d = stack_2d(&[a.view(), b.view()], 1).expect("Operation failed");
        assert_eq!(d.shape(), &[2, 4]);
        assert_eq!(d[[0, 0]], 1);
        assert_eq!(d[[0, 1]], 2);
        assert_eq!(d[[0, 2]], 5);
        assert_eq!(d[[0, 3]], 6);
    }

    #[test]
    fn test_stack_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6, 7]];
        let empty: [ArrayView2<i32>; 0] = [];
        assert!(stack_2d(&empty, 0).is_err());
        assert!(stack_2d(&[a.view(), b.view()], 0).is_err());
        assert!(stack_2d(&[a.view(), a.view()], 2).is_err());
    }

    #[test]
    fn test_transpose_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = transpose_2d(a.view());
        assert_eq!(b.shape(), &[3, 2]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 4);
        assert_eq!(b[[1, 0]], 2);
        assert_eq!(b[[2, 1]], 6);
    }

    #[test]
    fn test_split_2d() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
        let result = split_2d(a.view(), &[2], 1).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[2, 2]);
        assert_eq!(result[0][[0, 0]], 1);
        assert_eq!(result[0][[0, 1]], 2);
        assert_eq!(result[0][[1, 0]], 5);
        assert_eq!(result[0][[1, 1]], 6);
        assert_eq!(result[1].shape(), &[2, 2]);
        assert_eq!(result[1][[0, 0]], 3);
        assert_eq!(result[1][[0, 1]], 4);
        assert_eq!(result[1][[1, 0]], 7);
        assert_eq!(result[1][[1, 1]], 8);
        let result = split_2d(a.view(), &[1], 0).expect("Operation failed");
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].shape(), &[1, 4]);
        assert_eq!(result[1].shape(), &[1, 4]);
    }

    #[test]
    fn test_split_2d_unsorted_and_errors() {
        let a = array![[1, 2, 3, 4], [5, 6, 7, 8]];
        let pieces = split_2d(a.view(), &[3, 1], 1).expect("Operation failed");
        assert_eq!(pieces.len(), 3);
        assert_eq!(pieces[0], array![[1], [5]]);
        assert_eq!(pieces[1], array![[2, 3], [6, 7]]);
        assert_eq!(pieces[2], array![[4], [8]]);
        assert!(split_2d(a.view(), &[2], 0).is_err());
        assert!(split_2d(a.view(), &[1], 2).is_err());
    }

    #[test]
    fn test_take_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let indices = array![0, 2];
        let result = take_2d(a.view(), indices.view(), 1).expect("Operation failed");
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 1]], 3);
        assert_eq!(result[[1, 0]], 4);
        assert_eq!(result[[1, 1]], 6);
    }

    #[test]
    fn test_take_2d_rows_and_errors() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let rows = array![1, 0, 1];
        let picked = take_2d(a.view(), rows.view(), 0).expect("Operation failed");
        assert_eq!(picked, array![[4, 5, 6], [1, 2, 3], [4, 5, 6]]);
        let out_of_bounds = array![2];
        assert!(take_2d(a.view(), out_of_bounds.view(), 0).is_err());
        let first = array![0];
        assert!(take_2d(a.view(), first.view(), 2).is_err());
    }

    #[test]
    fn test_mask_select() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false, true], [false, true, false]];
        let result = mask_select(a.view(), mask.view()).expect("Operation failed");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 3);
        assert_eq!(result[2], 5);
    }

    #[test]
    fn test_mask_select_shape_mismatch() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let mask = array![[true, false], [false, true]];
        assert!(mask_select(a.view(), mask.view()).is_err());
    }

    #[test]
    fn test_fancy_index_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let row_indices = array![0, 2];
        let col_indices = array![0, 1];
        let result = fancy_index_2d(a.view(), row_indices.view(), col_indices.view())
            .expect("Operation failed");
        assert_eq!(result.shape(), &[2]);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 8);
    }

    #[test]
    fn test_fancy_index_2d_errors() {
        let a = array![[1, 2], [3, 4]];
        let two = array![0, 1];
        let zero = array![0];
        let big = array![2];
        assert!(fancy_index_2d(a.view(), two.view(), zero.view()).is_err());
        assert!(fancy_index_2d(a.view(), big.view(), zero.view()).is_err());
        assert!(fancy_index_2d(a.view(), zero.view(), big.view()).is_err());
    }

    #[test]
    fn test_where_condition() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let result = where_condition(a.view(), |&x| x > 3).expect("Operation failed");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result[0], 4);
        assert_eq!(result[1], 5);
        assert_eq!(result[2], 6);
    }

    #[test]
    fn test_broadcast_shapes() {
        assert!(is_broadcast_compatible(&[2, 3], &[3]));
        assert!(is_broadcast_compatible(&[4, 1, 3], &[5, 3]));
        assert!(!is_broadcast_compatible(&[2, 3], &[2]));
        assert_eq!(broadcastshape(&[2, 1], &[1, 3]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[3], &[2, 1]), Some(vec![2, 3]));
        assert_eq!(broadcastshape(&[2, 3], &[4]), None);
    }

    #[test]
    fn test_broadcast_1d_to_2d() {
        let a = array![1, 2, 3];
        let b = broadcast_1d_to_2d(a.view(), 2, 0).expect("Operation failed");
        assert_eq!(b.shape(), &[2, 3]);
        assert_eq!(b[[0, 0]], 1);
        assert_eq!(b[[0, 1]], 2);
        assert_eq!(b[[1, 0]], 1);
        assert_eq!(b[[1, 2]], 3);
        let c = broadcast_1d_to_2d(a.view(), 2, 1).expect("Operation failed");
        assert_eq!(c.shape(), &[3, 2]);
        assert_eq!(c[[0, 0]], 1);
        assert_eq!(c[[0, 1]], 1);
        assert_eq!(c[[1, 0]], 2);
        assert_eq!(c[[2, 1]], 3);
    }

    #[test]
    fn test_broadcast_apply() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let b = array![10, 20, 30];
        let result = broadcast_apply(a.view(), b.view(), |x, y| x + y).expect("Operation failed");
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 11);
        assert_eq!(result[[0, 1]], 22);
        assert_eq!(result[[0, 2]], 33);
        assert_eq!(result[[1, 0]], 14);
        assert_eq!(result[[1, 1]], 25);
        assert_eq!(result[[1, 2]], 36);
        let result = broadcast_apply(a.view(), b.view(), |x, y| x * y).expect("Operation failed");
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 10);
        assert_eq!(result[[0, 1]], 40);
        assert_eq!(result[[0, 2]], 90);
        assert_eq!(result[[1, 0]], 40);
        assert_eq!(result[[1, 1]], 100);
        assert_eq!(result[[1, 2]], 180);
    }

    #[test]
    fn test_broadcast_errors() {
        let a = array![1, 2, 3];
        assert!(broadcast_1d_to_2d(a.view(), 2, 2).is_err());
        let m = array![[1, 2, 3], [4, 5, 6]];
        let short = array![1, 2];
        assert!(broadcast_apply(m.view(), short.view(), |x, y| x + y).is_err());
    }
}