use crate::ndarray_ext::reduction;
use crate::numeric::Float;
use crate::simd_ops::SimdUnifiedOps;
use ::ndarray::{Array, ArrayView, Ix1, Ix2};
use num_traits::Zero;
/// Return type for functions that produce a pair of 2-D arrays or a static error message.
///
/// In this file it is returned by `meshgrid` (`(x_grid, y_grid)`) and by `gradient`
/// (`(grad_y, grad_x)`).
pub type GradientResult<T> = Result<(Array<T, Ix2>, Array<T, Ix2>), &'static str>;
/// Returns a new array with the rows and/or columns of `array` reversed.
///
/// * `flip_axis_0` — when `true`, row `i` of the result is row `rows - 1 - i` of the input.
/// * `flipaxis_1` — when `true`, column `j` of the result is column `cols - 1 - j` of the input.
///
/// With both flags `false` the result is an element-wise copy of `array`.
/// The `Zero` bound is part of the existing signature; this implementation does not use it.
#[allow(dead_code)]
pub fn flip_2d<T>(array: ArrayView<T, Ix2>, flip_axis_0: bool, flipaxis_1: bool) -> Array<T, Ix2>
where
    T: Clone + Zero,
{
    let (n_rows, n_cols) = (array.shape()[0], array.shape()[1]);
    // Maps an output index to its source index along one axis.
    let mirror = |idx: usize, len: usize, flip: bool| if flip { len - 1 - idx } else { idx };
    Array::from_shape_fn((n_rows, n_cols), |(r, c)| {
        array[[mirror(r, n_rows, flip_axis_0), mirror(c, n_cols, flipaxis_1)]].clone()
    })
}
/// Cyclically shifts the elements of `array` along both axes.
///
/// Element `(i, j)` of the input ends up at
/// `((i + shift_axis_0) mod rows, (j + shift_axis_1) mod cols)` in the result, so
/// positive shifts move elements towards higher indices and negative shifts move them
/// back. Shifts may exceed the axis length; they are reduced modulo it.
/// A `(0, 0)` shift returns a plain copy.
#[allow(dead_code)]
pub fn roll_2d<T>(
    array: ArrayView<T, Ix2>,
    shift_axis_0: isize,
    shift_axis_1: isize,
) -> Array<T, Ix2>
where
    T: Clone + Zero,
{
    if shift_axis_0 == 0 && shift_axis_1 == 0 {
        return array.to_owned();
    }
    let (n_rows, n_cols) = (array.shape()[0], array.shape()[1]);
    // Normalise a signed shift into `0..len`; a zero-length axis has nothing to roll
    // (and `rem_euclid(0)` would panic).
    let normalise = |shift: isize, len: usize| -> usize {
        if len == 0 {
            0
        } else {
            shift.rem_euclid(len as isize) as usize
        }
    };
    let k0 = normalise(shift_axis_0, n_rows);
    let k1 = normalise(shift_axis_1, n_cols);
    // The closure is never invoked for an empty shape, so the `%` below cannot divide by 0.
    Array::from_shape_fn((n_rows, n_cols), |(r, c)| {
        array[[(r + n_rows - k0) % n_rows, (c + n_cols - k1) % n_cols]].clone()
    })
}
/// Tiles `array` `reps_axis_0` times along axis 0 and `repsaxis_1` times along axis 1,
/// mirroring NumPy's `np.tile(a, (reps_axis_0, repsaxis_1))`.
///
/// The result has shape `(rows * reps_axis_0, cols * repsaxis_1)` and element `(i, j)`
/// equals `array[[i % rows, j % cols]]`, i.e. whole copies of `array` laid side by side.
///
/// A repetition count of 0 produces an empty array whose other dimension is still
/// scaled: tiling a `2 x 2` array with `(0, 3)` gives shape `(0, 6)`.
///
/// The `Default` and `Zero` bounds are not needed by this implementation; they are kept
/// so existing call sites and downstream generic bounds stay valid.
///
/// # Panics
/// Panics if either output dimension overflows `usize`.
#[allow(dead_code)]
pub fn tile_2d<T>(
    array: ArrayView<T, Ix2>,
    reps_axis_0: usize,
    repsaxis_1: usize,
) -> Array<T, Ix2>
where
    T: Clone + Default + Zero,
{
    if reps_axis_0 == 1 && repsaxis_1 == 1 {
        return array.to_owned();
    }
    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    // Checked so a release build fails loudly instead of silently wrapping to a wrong shape.
    let new_rows = rows
        .checked_mul(reps_axis_0)
        .expect("tile_2d: output row count overflows usize");
    let new_cols = cols
        .checked_mul(repsaxis_1)
        .expect("tile_2d: output column count overflows usize");
    // If either output dimension is 0 (zero reps or an empty input axis) the closure is
    // never called, so the `%` by a possibly-zero `rows`/`cols` cannot divide by zero.
    Array::from_shape_fn((new_rows, new_cols), |(i, j)| {
        array[[i % rows, j % cols]].clone()
    })
}
/// Repeats every element of `array` `repeats_axis_0` times along axis 0 and
/// `repeats_axis_1` times along axis 1, so each input element becomes a
/// `repeats_axis_0 x repeats_axis_1` block in the result (like applying NumPy's
/// `np.repeat` along axis 0 and then along axis 1).
///
/// The result has shape `(rows * repeats_axis_0, cols * repeats_axis_1)` and element
/// `(i, j)` equals `array[[i / repeats_axis_0, j / repeats_axis_1]]`.
///
/// A repeat count of 0 produces an empty array whose other dimension is still scaled:
/// repeating a `2 x 2` array with `(2, 0)` gives shape `(4, 0)`.
///
/// The `Default` and `Zero` bounds are not needed by this implementation; they are kept
/// so existing call sites and downstream generic bounds stay valid.
///
/// # Panics
/// Panics if either output dimension overflows `usize`.
#[allow(dead_code)]
pub fn repeat_2d<T>(
    array: ArrayView<T, Ix2>,
    repeats_axis_0: usize,
    repeats_axis_1: usize,
) -> Array<T, Ix2>
where
    T: Clone + Default + Zero,
{
    if repeats_axis_0 == 1 && repeats_axis_1 == 1 {
        return array.to_owned();
    }
    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    // Checked so a release build fails loudly instead of silently wrapping to a wrong shape.
    let new_rows = rows
        .checked_mul(repeats_axis_0)
        .expect("repeat_2d: output row count overflows usize");
    let new_cols = cols
        .checked_mul(repeats_axis_1)
        .expect("repeat_2d: output column count overflows usize");
    // A zero repeat count makes the corresponding output dimension 0, so the closure is
    // never called and the division below never sees a zero divisor.
    Array::from_shape_fn((new_rows, new_cols), |(i, j)| {
        array[[i / repeats_axis_0, j / repeats_axis_1]].clone()
    })
}
/// Returns a copy of `array` with two rows (`axis == 0`) or two columns (`axis == 1`)
/// exchanged. Despite the name, this swaps two *indices* along one axis, not the axes.
///
/// # Errors
/// * `"Axis must be 0 or 1 for 2D arrays"` if `axis > 1`.
/// * `"Indices out of bounds"` if either index is `>=` the length of `axis`.
///
/// Equal indices return an unchanged copy.
#[allow(dead_code)]
pub fn swap_axes_2d<T>(
    array: ArrayView<T, Ix2>,
    index1: usize,
    index2: usize,
    axis: usize,
) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone,
{
    let axis_len = match axis {
        0 | 1 => array.shape()[axis],
        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
    };
    if index1.max(index2) >= axis_len {
        return Err("Indices out of bounds");
    }
    let mut swapped = array.to_owned();
    if index1 != index2 {
        // Walk the other axis and exchange the two entries found at each position.
        let span = array.shape()[1 - axis];
        for k in 0..span {
            if axis == 0 {
                swapped.swap([index1, k], [index2, k]);
            } else {
                swapped.swap([k, index1], [k, index2]);
            }
        }
    }
    Ok(swapped)
}
/// Surrounds `array` with a constant border.
///
/// `pad_width` is `((before_0, after_0), (before_1, after_1))`: the number of rows added
/// above/below and columns added left/right. Every new cell holds `pad_value`; the
/// original data is copied into the interior starting at `(before_0, before_1)`.
#[allow(dead_code)]
pub fn pad_2d<T>(
    array: ArrayView<T, Ix2>,
    pad_width: ((usize, usize), (usize, usize)),
    pad_value: T,
) -> Array<T, Ix2>
where
    T: Clone,
{
    let ((top, bottom), (left, right)) = pad_width;
    let out_shape = (
        array.shape()[0] + top + bottom,
        array.shape()[1] + left + right,
    );
    let mut padded = Array::from_elem(out_shape, pad_value);
    for ((r, c), value) in array.indexed_iter() {
        padded[[r + top, c + left]] = value.clone();
    }
    padded
}
/// Joins a sequence of 2-D views end to end along `axis` (0 = stack rows, 1 = stack columns).
///
/// All inputs must agree on the length of the *other* axis.
///
/// # Errors
/// * `"No arrays provided for concatenation"` if `arrays` is empty.
/// * `"Axis must be 0 or 1 for 2D arrays"` if `axis > 1`.
/// * A row/column mismatch message if some input's non-concatenated dimension differs
///   from the first input's.
#[allow(dead_code)]
pub fn concatenate_2d<T>(
    arrays: &[ArrayView<T, Ix2>],
    axis: usize,
) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Zero,
{
    let first = match arrays.first() {
        Some(view) => view,
        None => return Err("No arrays provided for concatenation"),
    };
    // The axis that must match across inputs, and the error reported when it does not.
    let (fixed, mismatch) = match axis {
        0 => (
            1,
            "All arrays must have the same number of columns for axis=0 concatenation",
        ),
        1 => (
            0,
            "All arrays must have the same number of rows for axis=1 concatenation",
        ),
        _ => return Err("Axis must be 0 or 1 for 2D arrays"),
    };
    let mut out_shape = [first.shape()[0], first.shape()[1]];
    for view in &arrays[1..] {
        if view.shape()[fixed] != first.shape()[fixed] {
            return Err(mismatch);
        }
        out_shape[axis] += view.shape()[axis];
    }
    let mut joined = Array::<T, Ix2>::zeros((out_shape[0], out_shape[1]));
    let mut offset = 0;
    for view in arrays {
        for ((r, c), value) in view.indexed_iter() {
            let dest = if axis == 0 {
                [offset + r, c]
            } else {
                [r, offset + c]
            };
            joined[dest] = value.clone();
        }
        offset += view.shape()[axis];
    }
    Ok(joined)
}
/// Stacks equal-length 1-D views as the rows of a new 2-D array.
///
/// The result has shape `(arrays.len(), len)` and row `i` is a copy of `arrays[i]`.
///
/// # Errors
/// * `"No arrays provided for stacking"` if `arrays` is empty.
/// * `"Arrays must have consistent lengths for stacking"` if the lengths differ.
#[allow(dead_code)]
pub fn vstack_1d<T>(arrays: &[ArrayView<T, Ix1>]) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Zero,
{
    let width = match arrays.first() {
        Some(head) => head.len(),
        None => return Err("No arrays provided for stacking"),
    };
    if arrays.iter().any(|row| row.len() != width) {
        return Err("Arrays must have consistent lengths for stacking");
    }
    Ok(Array::from_shape_fn((arrays.len(), width), |(r, c)| {
        arrays[r][c].clone()
    }))
}
/// Stacks equal-length 1-D views as the columns of a new 2-D array.
///
/// The result has shape `(len, arrays.len())` and column `j` is a copy of `arrays[j]`.
///
/// # Errors
/// * `"No arrays provided for stacking"` if `arrays` is empty.
/// * `"Arrays must have consistent lengths for stacking"` if the lengths differ.
#[allow(dead_code)]
pub fn hstack_1d<T>(arrays: &[ArrayView<T, Ix1>]) -> Result<Array<T, Ix2>, &'static str>
where
    T: Clone + Zero,
{
    let height = match arrays.first() {
        Some(head) => head.len(),
        None => return Err("No arrays provided for stacking"),
    };
    if arrays.iter().any(|column| column.len() != height) {
        return Err("Arrays must have consistent lengths for stacking");
    }
    Ok(Array::from_shape_fn((height, arrays.len()), |(r, c)| {
        arrays[c][r].clone()
    }))
}
/// Drops a length-1 axis from a 2-D view, returning the remaining axis as a 1-D array.
///
/// * `axis == 0` requires exactly one row and returns that row.
/// * `axis == 1` requires exactly one column and returns that column.
///
/// # Errors
/// Returns an error if the chosen axis does not have length 1 (this includes length 0),
/// or if `axis > 1`.
#[allow(dead_code)]
pub fn squeeze_2d<T>(array: ArrayView<T, Ix2>, axis: usize) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Zero,
{
    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    match axis {
        0 if rows == 1 => Ok(array.row(0).to_owned()),
        0 => Err("Cannot squeeze array with more than 1 row along axis 0"),
        1 if cols == 1 => Ok(array.column(0).to_owned()),
        1 => Err("Cannot squeeze array with more than 1 column along axis 1"),
        _ => Err("Axis must be 0 or 1 for 2D arrays"),
    }
}
/// Builds coordinate grids from two 1-D coordinate vectors, with `x` varying along columns
/// and `y` varying along rows.
///
/// Both grids have shape `(y.len(), x.len())`: `x_grid[[i, j]] == x[j]` and
/// `y_grid[[i, j]] == y[i]`.
///
/// # Errors
/// Returns `"Input arrays must not be empty"` if either input has length 0.
#[allow(dead_code)]
pub fn meshgrid<T>(x: ArrayView<T, Ix1>, y: ArrayView<T, Ix1>) -> GradientResult<T>
where
    T: Clone + Zero,
{
    if x.is_empty() || y.is_empty() {
        return Err("Input arrays must not be empty");
    }
    let grid_shape = (y.len(), x.len());
    let x_grid = Array::from_shape_fn(grid_shape, |(_, col)| x[col].clone());
    let y_grid = Array::from_shape_fn(grid_shape, |(row, _)| y[row].clone());
    Ok((x_grid, y_grid))
}
/// Returns the distinct values of a 1-D view in ascending order.
///
/// A stable sort is used, so among elements that compare equal the one that appeared
/// first in the input is the one kept by the subsequent `dedup`.
///
/// # Errors
/// Returns `"Input array must not be empty"` for an empty input.
#[allow(dead_code)]
pub fn unique<T>(array: ArrayView<T, Ix1>) -> Result<Array<T, Ix1>, &'static str>
where
    T: Clone + Ord,
{
    match array.is_empty() {
        true => Err("Input array must not be empty"),
        false => {
            let mut distinct = array.to_vec();
            distinct.sort();
            distinct.dedup();
            Ok(Array::from_vec(distinct))
        }
    }
}
/// Indices of minimum values in a 2-D view.
///
/// * `Some(0)` — one index per column: the row holding that column's minimum.
/// * `Some(1)` — one index per row: the column holding that row's minimum.
/// * `None` — a one-element array with the row-major flat index (`row * cols + col`)
///   of the overall minimum.
///
/// Ties resolve to the first occurrence, because a later element only replaces the
/// running best when it compares strictly less. Values that are unordered under
/// `PartialOrd` (e.g. NaN) never replace the running best, and once such a value is the
/// running best it is never displaced.
///
/// # Errors
/// * `"Input array must not be empty"` if either dimension is 0.
/// * `"Axis must be 0 or 1 for 2D arrays"` for any other `Some(axis)`.
#[allow(dead_code)]
pub fn argmin<T>(
    array: ArrayView<T, Ix2>,
    axis: Option<usize>,
) -> Result<Array<usize, Ix1>, &'static str>
where
    T: Clone + PartialOrd,
{
    // Position (within `values`) of the first element that is strictly below all earlier ones.
    fn first_min_position<'a, V: PartialOrd + 'a>(values: impl Iterator<Item = &'a V>) -> usize {
        let mut indexed = values.enumerate();
        let (mut best_pos, mut best) = match indexed.next() {
            Some(head) => head,
            None => return 0,
        };
        for (pos, candidate) in indexed {
            if candidate < best {
                best_pos = pos;
                best = candidate;
            }
        }
        best_pos
    }

    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    if rows == 0 || cols == 0 {
        return Err("Input array must not be empty");
    }
    match axis {
        // `iter()` walks a 2-D view in logical row-major order, giving flat indices directly.
        None => Ok(Array::from_vec(vec![first_min_position(array.iter())])),
        Some(0) => Ok(Array::from_shape_fn(cols, |j| {
            first_min_position(array.column(j).iter())
        })),
        Some(1) => Ok(Array::from_shape_fn(rows, |i| {
            first_min_position(array.row(i).iter())
        })),
        Some(_) => Err("Axis must be 0 or 1 for 2D arrays"),
    }
}
/// SIMD-accelerated variant of `argmin` for element types supported by
/// `reduction::argmin_simd`.
///
/// Output contract:
/// * `Some(0)` returns per-column row indices.
/// * `Some(1)` returns per-row column indices.
/// * `None` returns a one-element array holding the row-major flat index.
///
/// A lane (a row, a column, or the whole array) is passed to `reduction::argmin_simd`
/// only when it is contiguous in standard layout. Otherwise, and also whenever
/// `reduction::argmin_simd` returns `None`, the index comes from a scalar scan that keeps
/// the first strict minimum. Previously a `None` from the SIMD kernel on a row/column
/// silently produced index 0.
///
/// NOTE(review): tie-breaking and NaN handling on the SIMD path are whatever
/// `reduction::argmin_simd` implements; confirm they match the scalar scan.
///
/// # Errors
/// * `"Input array must not be empty"` if either dimension is 0.
/// * `"Axis must be 0 or 1 for 2D arrays"` for any other `Some(axis)`.
#[allow(dead_code)]
pub fn argmin_simd<T>(
    array: ArrayView<T, Ix2>,
    axis: Option<usize>,
) -> Result<Array<usize, Ix1>, &'static str>
where
    T: Clone + PartialOrd + Float + SimdUnifiedOps,
{
    // Scalar fallback: position of the first element strictly below all earlier ones.
    fn scalar_argmin<'a, V: PartialOrd + 'a>(values: impl Iterator<Item = &'a V>) -> usize {
        let mut indexed = values.enumerate();
        let (mut best_pos, mut best) = match indexed.next() {
            Some(head) => head,
            None => return 0,
        };
        for (pos, candidate) in indexed {
            if candidate < best {
                best_pos = pos;
                best = candidate;
            }
        }
        best_pos
    }

    // Argmin of one 1-D lane: SIMD when contiguous, scalar otherwise or if SIMD declines.
    fn lane_argmin<V>(lane: ArrayView<V, Ix1>) -> usize
    where
        V: Clone + PartialOrd + Float + SimdUnifiedOps,
    {
        let simd_result = if lane.is_standard_layout() {
            reduction::argmin_simd(&lane)
        } else {
            None
        };
        simd_result.unwrap_or_else(|| scalar_argmin(lane.iter()))
    }

    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    if rows == 0 || cols == 0 {
        return Err("Input array must not be empty");
    }
    match axis {
        Some(0) => Ok(Array::from_shape_fn(cols, |j| lane_argmin(array.column(j)))),
        Some(1) => Ok(Array::from_shape_fn(rows, |i| lane_argmin(array.row(i)))),
        Some(_) => Err("Axis must be 0 or 1 for 2D arrays"),
        None => {
            // `as_slice` succeeds only for contiguous standard (row-major) layout, so a
            // position within that slice is already the row-major flat index.
            let simd_result = array.as_slice().and_then(|slice| {
                reduction::argmin_simd(&crate::ndarray::ArrayView1::from(slice))
            });
            let flat = simd_result.unwrap_or_else(|| scalar_argmin(array.iter()));
            Ok(Array::from_vec(vec![flat]))
        }
    }
}
/// Indices of maximum values in a 2-D view.
///
/// * `Some(0)` — one index per column: the row holding that column's maximum.
/// * `Some(1)` — one index per row: the column holding that row's maximum.
/// * `None` — a one-element array with the row-major flat index (`row * cols + col`)
///   of the overall maximum.
///
/// Ties resolve to the first occurrence, because a later element only replaces the
/// running best when it compares strictly greater. Values that are unordered under
/// `PartialOrd` (e.g. NaN) never replace the running best, and once such a value is the
/// running best it is never displaced.
///
/// # Errors
/// * `"Input array must not be empty"` if either dimension is 0.
/// * `"Axis must be 0 or 1 for 2D arrays"` for any other `Some(axis)`.
#[allow(dead_code)]
pub fn argmax<T>(
    array: ArrayView<T, Ix2>,
    axis: Option<usize>,
) -> Result<Array<usize, Ix1>, &'static str>
where
    T: Clone + PartialOrd,
{
    // Position (within `values`) of the first element that is strictly above all earlier ones.
    fn first_max_position<'a, V: PartialOrd + 'a>(values: impl Iterator<Item = &'a V>) -> usize {
        let mut indexed = values.enumerate();
        let (mut best_pos, mut best) = match indexed.next() {
            Some(head) => head,
            None => return 0,
        };
        for (pos, candidate) in indexed {
            if candidate > best {
                best_pos = pos;
                best = candidate;
            }
        }
        best_pos
    }

    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    if rows == 0 || cols == 0 {
        return Err("Input array must not be empty");
    }
    match axis {
        // `iter()` walks a 2-D view in logical row-major order, giving flat indices directly.
        None => Ok(Array::from_vec(vec![first_max_position(array.iter())])),
        Some(0) => Ok(Array::from_shape_fn(cols, |j| {
            first_max_position(array.column(j).iter())
        })),
        Some(1) => Ok(Array::from_shape_fn(rows, |i| {
            first_max_position(array.row(i).iter())
        })),
        Some(_) => Err("Axis must be 0 or 1 for 2D arrays"),
    }
}
/// SIMD-accelerated variant of `argmax` for element types supported by
/// `reduction::argmax_simd`.
///
/// Output contract:
/// * `Some(0)` returns per-column row indices.
/// * `Some(1)` returns per-row column indices.
/// * `None` returns a one-element array holding the row-major flat index.
///
/// A lane (a row, a column, or the whole array) is passed to `reduction::argmax_simd`
/// only when it is contiguous in standard layout. Otherwise, and also whenever
/// `reduction::argmax_simd` returns `None`, the index comes from a scalar scan that keeps
/// the first strict maximum. Previously a `None` from the SIMD kernel on a row/column
/// silently produced index 0.
///
/// NOTE(review): tie-breaking and NaN handling on the SIMD path are whatever
/// `reduction::argmax_simd` implements; confirm they match the scalar scan.
///
/// # Errors
/// * `"Input array must not be empty"` if either dimension is 0.
/// * `"Axis must be 0 or 1 for 2D arrays"` for any other `Some(axis)`.
#[allow(dead_code)]
pub fn argmax_simd<T>(
    array: ArrayView<T, Ix2>,
    axis: Option<usize>,
) -> Result<Array<usize, Ix1>, &'static str>
where
    T: Clone + PartialOrd + Float + SimdUnifiedOps,
{
    // Scalar fallback: position of the first element strictly above all earlier ones.
    fn scalar_argmax<'a, V: PartialOrd + 'a>(values: impl Iterator<Item = &'a V>) -> usize {
        let mut indexed = values.enumerate();
        let (mut best_pos, mut best) = match indexed.next() {
            Some(head) => head,
            None => return 0,
        };
        for (pos, candidate) in indexed {
            if candidate > best {
                best_pos = pos;
                best = candidate;
            }
        }
        best_pos
    }

    // Argmax of one 1-D lane: SIMD when contiguous, scalar otherwise or if SIMD declines.
    fn lane_argmax<V>(lane: ArrayView<V, Ix1>) -> usize
    where
        V: Clone + PartialOrd + Float + SimdUnifiedOps,
    {
        let simd_result = if lane.is_standard_layout() {
            reduction::argmax_simd(&lane)
        } else {
            None
        };
        simd_result.unwrap_or_else(|| scalar_argmax(lane.iter()))
    }

    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    if rows == 0 || cols == 0 {
        return Err("Input array must not be empty");
    }
    match axis {
        Some(0) => Ok(Array::from_shape_fn(cols, |j| lane_argmax(array.column(j)))),
        Some(1) => Ok(Array::from_shape_fn(rows, |i| lane_argmax(array.row(i)))),
        Some(_) => Err("Axis must be 0 or 1 for 2D arrays"),
        None => {
            // `as_slice` succeeds only for contiguous standard (row-major) layout, so a
            // position within that slice is already the row-major flat index.
            let simd_result = array.as_slice().and_then(|slice| {
                reduction::argmax_simd(&crate::ndarray::ArrayView1::from(slice))
            });
            let flat = simd_result.unwrap_or_else(|| scalar_argmax(array.iter()));
            Ok(Array::from_vec(vec![flat]))
        }
    }
}
/// Numerical gradient of a 2-D array along both axes, returned as `(grad_y, grad_x)`.
///
/// `grad_y` differentiates along axis 0 (down the rows) and `grad_x` along axis 1
/// (across the columns); both have the same shape as `array`.
///
/// Differences used:
/// * Interior points use central differences `(f[k+1] - f[k-1]) / (2h)`.
/// * The first point along an axis uses the forward difference `(f[1] - f[0]) / h`.
/// * The last point uses the backward difference `(f[n-1] - f[n-2]) / h`.
/// * An axis of length 1 has no neighbour to difference against, and its gradient is
///   left as all zeros.
///
/// `spacing` is `(dy, dx)`, the sample spacing along axis 0 and axis 1. `None` means `(1, 1)`.
///
/// # Errors
/// Returns `"Input array must not be empty"` if either dimension is 0.
#[allow(dead_code)]
pub fn gradient<T>(array: ArrayView<T, Ix2>, spacing: Option<(T, T)>) -> GradientResult<T>
where
    T: Clone + num_traits::Float,
{
    let (rows, cols) = (array.shape()[0], array.shape()[1]);
    if rows == 0 || cols == 0 {
        return Err("Input array must not be empty");
    }
    let (dy, dx) = spacing.unwrap_or((T::one(), T::one()));

    let grad_y = if rows > 1 {
        let last = rows - 1;
        Array::from_shape_fn((rows, cols), |(i, j)| match i {
            0 => (array[[1, j]] - array[[0, j]]) / dy,
            _ if i == last => (array[[last, j]] - array[[last - 1, j]]) / dy,
            _ => (array[[i + 1, j]] - array[[i - 1, j]]) / (dy + dy),
        })
    } else {
        Array::zeros((rows, cols))
    };

    let grad_x = if cols > 1 {
        let last = cols - 1;
        Array::from_shape_fn((rows, cols), |(i, j)| match j {
            0 => (array[[i, 1]] - array[[i, 0]]) / dx,
            _ if j == last => (array[[i, last]] - array[[i, last - 1]]) / dx,
            _ => (array[[i, j + 1]] - array[[i, j - 1]]) / (dx + dx),
        })
    } else {
        Array::zeros((rows, cols))
    };

    Ok((grad_y, grad_x))
}
#[cfg(test)]
mod tests {
    // Unit tests for the 2-D array helpers above. Every assertion here holds for
    // both zero-repetition shape conventions of `tile_2d`/`repeat_2d` (only emptiness
    // is checked), so the tests pin behavior without depending on that detail.
    use super::*;
    use ::ndarray::array;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_flip_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let flipped_rows = flip_2d(a.view(), true, false);
        assert_eq!(flipped_rows, array![[4, 5, 6], [1, 2, 3]]);
        let flipped_cols = flip_2d(a.view(), false, true);
        assert_eq!(flipped_cols, array![[3, 2, 1], [6, 5, 4]]);
        let flipped_both = flip_2d(a.view(), true, true);
        assert_eq!(flipped_both, array![[6, 5, 4], [3, 2, 1]]);
    }

    #[test]
    fn test_roll_2d() {
        let a = array![[1, 2, 3], [4, 5, 6]];
        let rolled_rows = roll_2d(a.view(), 1, 0);
        assert_eq!(rolled_rows, array![[4, 5, 6], [1, 2, 3]]);
        let rolled_cols = roll_2d(a.view(), 0, 1);
        assert_eq!(rolled_cols, array![[3, 1, 2], [6, 4, 5]]);
        let rolled_neg = roll_2d(a.view(), 0, -1);
        assert_eq!(rolled_neg, array![[2, 3, 1], [5, 6, 4]]);
        let rolled_zero = roll_2d(a.view(), 0, 0);
        assert_eq!(rolled_zero, a);
        // Shifts longer than the axis wrap around; -4 on a length-3 axis is -1.
        let rolled_wrapped = roll_2d(a.view(), 3, -4);
        assert_eq!(rolled_wrapped, array![[5, 6, 4], [2, 3, 1]]);
    }

    #[test]
    fn test_tile_2d() {
        let a = array![[1, 2], [3, 4]];
        let tiled = tile_2d(a.view(), 2, 3);
        assert_eq!(tiled.shape(), &[4, 6]);
        assert_eq!(
            tiled,
            array![
                [1, 2, 1, 2, 1, 2],
                [3, 4, 3, 4, 3, 4],
                [1, 2, 1, 2, 1, 2],
                [3, 4, 3, 4, 3, 4]
            ]
        );
        let tiled_axis_0 = tile_2d(a.view(), 2, 1);
        assert_eq!(tiled_axis_0.shape(), &[4, 2]);
        assert_eq!(tiled_axis_0, array![[1, 2], [3, 4], [1, 2], [3, 4]]);
        let single = array![[5]];
        let tiled_single = tile_2d(single.view(), 2, 2);
        assert_eq!(tiled_single.shape(), &[2, 2]);
        assert_eq!(tiled_single, array![[5, 5], [5, 5]]);
        // Zero repetitions along any axis yields no elements.
        assert_eq!(tile_2d(a.view(), 0, 3).len(), 0);
    }

    #[test]
    fn test_repeat_2d() {
        let a = array![[1, 2], [3, 4]];
        let repeated = repeat_2d(a.view(), 2, 3);
        assert_eq!(repeated.shape(), &[4, 6]);
        assert_eq!(
            repeated,
            array![
                [1, 1, 1, 2, 2, 2],
                [1, 1, 1, 2, 2, 2],
                [3, 3, 3, 4, 4, 4],
                [3, 3, 3, 4, 4, 4]
            ]
        );
        let repeated_axis_1 = repeat_2d(a.view(), 1, 2);
        assert_eq!(repeated_axis_1.shape(), &[2, 4]);
        assert_eq!(repeated_axis_1, array![[1, 1, 2, 2], [3, 3, 4, 4]]);
        // Zero repeats along any axis yields no elements.
        assert_eq!(repeat_2d(a.view(), 2, 0).len(), 0);
    }

    #[test]
    fn test_swap_axes_2d() {
        let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let swapped_rows = swap_axes_2d(a.view(), 0, 2, 0).expect("Operation failed");
        assert_eq!(swapped_rows, array![[7, 8, 9], [4, 5, 6], [1, 2, 3]]);
        let swapped_cols = swap_axes_2d(a.view(), 0, 2, 1).expect("Operation failed");
        assert_eq!(swapped_cols, array![[3, 2, 1], [6, 5, 4], [9, 8, 7]]);
        let swapped_same = swap_axes_2d(a.view(), 1, 1, 0).expect("Operation failed");
        assert_eq!(swapped_same, a);
        assert!(swap_axes_2d(a.view(), 0, 1, 2).is_err());
        assert!(swap_axes_2d(a.view(), 0, 3, 0).is_err());
    }

    #[test]
    fn test_pad_2d() {
        let a = array![[1, 2], [3, 4]];
        let padded_all = pad_2d(a.view(), ((1, 1), (1, 1)), 0);
        assert_eq!(padded_all.shape(), &[4, 4]);
        assert_eq!(
            padded_all,
            array![[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]
        );
        let padded_uneven = pad_2d(a.view(), ((2, 0), (0, 1)), 9);
        assert_eq!(padded_uneven.shape(), &[4, 3]);
        assert_eq!(
            padded_uneven,
            array![[9, 9, 9], [9, 9, 9], [1, 2, 9], [3, 4, 9]]
        );
    }

    #[test]
    fn test_concatenate_2d() {
        let a = array![[1, 2], [3, 4]];
        let b = array![[5, 6], [7, 8]];
        let vertical = concatenate_2d(&[a.view(), b.view()], 0).expect("Operation failed");
        assert_eq!(vertical.shape(), &[4, 2]);
        assert_eq!(vertical, array![[1, 2], [3, 4], [5, 6], [7, 8]]);
        let horizontal = concatenate_2d(&[a.view(), b.view()], 1).expect("Operation failed");
        assert_eq!(horizontal.shape(), &[2, 4]);
        assert_eq!(horizontal, array![[1, 2, 5, 6], [3, 4, 7, 8]]);
        let c = array![[9, 10, 11]];
        assert!(concatenate_2d(&[a.view(), c.view()], 0).is_err());
        let empty: [ArrayView<i32, Ix2>; 0] = [];
        assert!(concatenate_2d(&empty, 0).is_err());
        assert!(concatenate_2d(&[a.view(), b.view()], 2).is_err());
    }

    #[test]
    fn test_vstack_1d() {
        let a = array![1, 2, 3];
        let b = array![4, 5, 6];
        let c = array![7, 8, 9];
        let stacked = vstack_1d(&[a.view(), b.view(), c.view()]).expect("Operation failed");
        assert_eq!(stacked.shape(), &[3, 3]);
        assert_eq!(stacked, array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
        let empty: [ArrayView<i32, Ix1>; 0] = [];
        assert!(vstack_1d(&empty).is_err());
        let d = array![10, 11];
        assert!(vstack_1d(&[a.view(), d.view()]).is_err());
    }

    #[test]
    fn test_hstack_1d() {
        let a = array![1, 2, 3];
        let b = array![4, 5, 6];
        let stacked = hstack_1d(&[a.view(), b.view()]).expect("Operation failed");
        assert_eq!(stacked.shape(), &[3, 2]);
        assert_eq!(stacked, array![[1, 4], [2, 5], [3, 6]]);
        let empty: [ArrayView<i32, Ix1>; 0] = [];
        assert!(hstack_1d(&empty).is_err());
        let c = array![7, 8];
        assert!(hstack_1d(&[a.view(), c.view()]).is_err());
    }

    #[test]
    fn test_squeeze_2d() {
        let a = array![[1, 2, 3]];
        let b = array![[1], [2], [3]];
        let squeezed_a = squeeze_2d(a.view(), 0).expect("Operation failed");
        assert_eq!(squeezed_a.shape(), &[3]);
        assert_eq!(squeezed_a, array![1, 2, 3]);
        let squeezed_b = squeeze_2d(b.view(), 1).expect("Operation failed");
        assert_eq!(squeezed_b.shape(), &[3]);
        assert_eq!(squeezed_b, array![1, 2, 3]);
        let c = array![[1, 2], [3, 4]];
        assert!(squeeze_2d(c.view(), 0).is_err());
        assert!(squeeze_2d(c.view(), 1).is_err());
        assert!(squeeze_2d(a.view(), 2).is_err());
    }

    #[test]
    fn test_meshgrid() {
        let x = array![1, 2, 3];
        let y = array![4, 5];
        let (x_grid, y_grid) = meshgrid(x.view(), y.view()).expect("Operation failed");
        assert_eq!(x_grid.shape(), &[2, 3]);
        assert_eq!(y_grid.shape(), &[2, 3]);
        assert_eq!(x_grid, array![[1, 2, 3], [1, 2, 3]]);
        assert_eq!(y_grid, array![[4, 4, 4], [5, 5, 5]]);
        let empty = array![];
        assert!(meshgrid(x.view(), empty.view()).is_err());
        assert!(meshgrid(empty.view(), y.view()).is_err());
    }

    #[test]
    fn test_unique() {
        let a = array![3, 1, 2, 2, 3, 4, 1];
        let result = unique(a.view()).expect("Operation failed");
        assert_eq!(result, array![1, 2, 3, 4]);
        let empty: Array<i32, Ix1> = array![];
        assert!(unique(empty.view()).is_err());
    }

    #[test]
    fn test_argmin() {
        let a = array![[5, 2, 3], [4, 1, 6]];
        let result = argmin(a.view(), Some(0)).expect("Operation failed");
        assert_eq!(result, array![1, 1, 0]);
        let result = argmin(a.view(), Some(1)).expect("Operation failed");
        assert_eq!(result, array![1, 1]);
        let result = argmin(a.view(), None).expect("Operation failed");
        assert_eq!(result[0], 4);
        assert!(argmin(a.view(), Some(2)).is_err());
        let empty: Array<i32, Ix2> = Array::zeros((0, 0));
        assert!(argmin(empty.view(), Some(0)).is_err());
        // Ties resolve to the first occurrence in row-major order.
        let ties = array![[2, 1, 1], [1, 3, 3]];
        let result = argmin(ties.view(), None).expect("Operation failed");
        assert_eq!(result[0], 1);
    }

    #[test]
    fn test_argmax() {
        let a = array![[5, 2, 3], [4, 1, 6]];
        let result = argmax(a.view(), Some(0)).expect("Operation failed");
        assert_eq!(result, array![0, 0, 1]);
        let result = argmax(a.view(), Some(1)).expect("Operation failed");
        assert_eq!(result, array![0, 2]);
        let result = argmax(a.view(), None).expect("Operation failed");
        assert_eq!(result[0], 5);
        assert!(argmax(a.view(), Some(2)).is_err());
        let empty: Array<i32, Ix2> = Array::zeros((0, 0));
        assert!(argmax(empty.view(), Some(0)).is_err());
    }

    #[test]
    fn test_gradient() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let (grad_y, grad_x) = gradient(a.view(), None).expect("Operation failed");
        assert_eq!(grad_y.shape(), &[2, 3]);
        assert_eq!(grad_x.shape(), &[2, 3]);
        assert_abs_diff_eq!(grad_y[[0, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[0, 1]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[0, 2]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[1, 0]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[1, 1]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[1, 2]], 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[0, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[0, 2]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[1, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[1, 1]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[1, 2]], 1.0, epsilon = 1e-10);
        let (grad_y, grad_x) = gradient(a.view(), Some((2.0, 0.5))).expect("Operation failed");
        assert_abs_diff_eq!(grad_y[[0, 0]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[0, 1]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_y[[0, 2]], 1.5, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[0, 0]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[0, 1]], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(grad_x[[0, 2]], 2.0, epsilon = 1e-10);
        let empty: Array<f32, Ix2> = Array::zeros((0, 0));
        assert!(gradient(empty.view(), None).is_err());
    }
}