use numrs2::array::Array;
use numrs2::array_ops::*;
/// Concatenation of 1-D and 2-D arrays along a single existing axis.
///
/// Checks both the resulting shape and the exact placement of every element.
/// For the axis-1 case, each element is checked by position. A "value is
/// present somewhere" check would also accept a wrongly ordered result.
#[test]
fn test_concatenate_single_axis() {
    // 1-D along axis 0: `b` is appended after `a`.
    let a = Array::from_vec(vec![1, 2, 3]);
    let b = Array::from_vec(vec![4, 5, 6]);
    let c = concatenate(&[&a, &b], 0).unwrap();
    assert_eq!(c.shape(), vec![6]);
    assert_eq!(c.to_vec(), vec![1, 2, 3, 4, 5, 6]);

    // 2-D along axis 0: the rows of `e` go below the rows of `d`.
    let d = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);
    let e = Array::from_vec(vec![5, 6, 7, 8]).reshape(&[2, 2]);
    let f = concatenate(&[&d, &e], 0).unwrap();
    assert_eq!(f.shape(), vec![4, 2]);
    assert_eq!(f.to_vec(), vec![1, 2, 3, 4, 5, 6, 7, 8]);

    // 2-D along axis 1: the columns of `e` go to the right of `d`, giving
    // [[1, 2, 5, 6], [3, 4, 7, 8]]. Check every position, so a transposed or
    // interleaved result cannot slip through.
    let g = concatenate(&[&d, &e], 1).unwrap();
    assert_eq!(g.shape(), vec![2, 4]);
    let expected = [[1, 2, 5, 6], [3, 4, 7, 8]];
    for (i, row) in expected.iter().enumerate() {
        for (j, &value) in row.iter().enumerate() {
            assert_eq!(g.get(&[i, j]).unwrap(), value, "g[{}, {}]", i, j);
        }
    }
}
/// Concatenation with the axis given as a one-element `vec![0]`.
///
/// Despite the name, only axis 0 is exercised. The test requires the
/// `vec![0]` spelling to produce the same array as the plain axis `0`, then
/// checks every element of the result.
/// NOTE(review): the meaning of a multi-element axis vector is not covered
/// here; confirm against `concatenate`'s signature before extending.
#[test]
fn test_concatenate_multiple_axes() {
    let a = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);
    let b = Array::from_vec(vec![5, 6, 7, 8]).reshape(&[2, 2]);
    let scalar_axis = concatenate(&[&a, &b], 0).unwrap();
    let shortcut = concatenate(&[&a, &b], vec![0]).unwrap();

    // Both ways of naming axis 0 must agree on shape and contents.
    assert_eq!(shortcut.shape(), scalar_axis.shape());
    assert_eq!(shortcut.to_vec(), scalar_axis.to_vec());

    // Rows of `b` are stacked below the rows of `a`.
    assert_eq!(shortcut.shape(), vec![4, 2]);
    let expected = [[1, 2], [3, 4], [5, 6], [7, 8]];
    for (i, row) in expected.iter().enumerate() {
        for (j, &value) in row.iter().enumerate() {
            assert_eq!(shortcut.get(&[i, j]).unwrap(), value, "shortcut[{}, {}]", i, j);
        }
    }
}
/// `stack` joins two equal-length 1-D arrays along a new axis.
///
/// With axis 0 each input becomes a row; with axis 1 each input becomes a
/// column.
#[test]
fn test_stack() {
    let first = Array::from_vec(vec![1, 2, 3]);
    let second = Array::from_vec(vec![4, 5, 6]);

    // New leading axis: one row per input.
    let rows = stack(&[&first, &second], 0).unwrap();
    assert_eq!(rows.shape(), vec![2, 3]);
    assert_eq!(rows.to_vec(), vec![1, 2, 3, 4, 5, 6]);

    // New trailing axis: one column per input, checked row by row.
    let cols = stack(&[&first, &second], 1).unwrap();
    assert_eq!(cols.shape(), vec![3, 2]);
    let expected = [[1, 4], [2, 5], [3, 6]];
    for (r, pair) in expected.iter().enumerate() {
        for (c, &want) in pair.iter().enumerate() {
            assert_eq!(cols.get(&[r, c]).unwrap(), want);
        }
    }
}
/// Covers the `r_` / `c_` joining shortcuts and the `hsplit` / `vsplit`
/// splitting helpers.
#[test]
fn test_array_creation_operations() {
    let left = Array::from_vec(vec![1, 2, 3]);
    let right = Array::from_vec(vec![4, 5, 6]);

    // r_: the two 1-D inputs joined end to end.
    let joined = r_(&[&left, &right]).unwrap();
    assert_eq!(joined.shape(), vec![6]);
    assert_eq!(joined.to_vec(), vec![1, 2, 3, 4, 5, 6]);

    // c_: each 1-D input becomes one column of a (3, 2) result.
    let columns = c_(&[&left, &right]).unwrap();
    assert_eq!(columns.shape(), vec![3, 2]);
    for (r, pair) in [[1, 4], [2, 5], [3, 6]].iter().enumerate() {
        for (c, &want) in pair.iter().enumerate() {
            assert_eq!(columns.get(&[r, c]).unwrap(), want);
        }
    }

    // hsplit: a (2, 3) array split into three (2, 1) pieces.
    let wide = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
    let h_parts = hsplit(&wide, 3).unwrap();
    assert_eq!(h_parts.len(), 3);
    for part in h_parts.iter() {
        assert_eq!(part.shape(), vec![2, 1]);
    }

    // vsplit: a (3, 2) array split into three (1, 2) pieces.
    let tall = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[3, 2]);
    let v_parts = vsplit(&tall, 3).unwrap();
    assert_eq!(v_parts.len(), 3);
    for part in v_parts.iter() {
        assert_eq!(part.shape(), vec![1, 2]);
    }
}