use mlxrs::{Array, Dtype, ops};
/// `argmin` with `None` as the axis argument over `arange(0, 5)` selects
/// index 0, where the smallest value (0.0) sits.
#[test]
fn argmin_arange_5_yields_0() {
    let input = Array::arange::<f32>(0.0, 5.0, 1.0).unwrap();
    let mut smallest_at = ops::misc::argmin(&input, None, false).unwrap();
    let index: u32 = smallest_at.item::<u32>().unwrap();
    assert_eq!(index, 0);
}
/// Row-wise `argmin` (axis 1) on a 2x3 matrix returns one index per row.
#[test]
fn argmin_axis_2x3_yields_per_row_index() {
    // Row 0 = [5, 1, 3] -> min at 1; row 1 = [2, 4, 0] -> min at 2.
    let values = [5.0_f32, 1.0, 3.0, 2.0, 4.0, 0.0];
    let matrix = Array::from_slice(&values, &(2, 3)).unwrap();
    let mut mins = matrix.argmin(Some(1), false).unwrap();
    assert_eq!(mins.shape(), vec![2]);
    let expected: Vec<u32> = vec![1, 2];
    assert_eq!(mins.to_vec::<u32>().unwrap(), expected);
}
/// `cumsum` along axis 0 of `arange(0, 5)` produces the running totals,
/// each one including the current element.
#[test]
fn cumsum_arange_5_yields_running_total() {
    let input = Array::arange::<f32>(0.0, 5.0, 1.0).unwrap();
    let mut totals = ops::misc::cumsum(&input, 0, false, true).unwrap();
    let expected = vec![0.0_f32, 1.0, 3.0, 6.0, 10.0];
    assert_eq!(totals.to_vec::<f32>().unwrap(), expected);
}
/// Method-form `cumprod` along axis 0 of `[1, 2, 3, 4]` yields the
/// factorials `1!, 2!, 3!, 4!`.
#[test]
fn cumprod_method_arange_1_to_4_yields_factorials() {
    let input = Array::arange::<f32>(1.0, 5.0, 1.0).unwrap();
    let mut products = input.cumprod(0, false, true).unwrap();
    let factorials = vec![1.0_f32, 2.0, 6.0, 24.0];
    assert_eq!(products.to_vec::<f32>().unwrap(), factorials);
}
/// `cummax` tracks the largest value seen so far along axis 0.
#[test]
fn cummax_running_maximum() {
    let series = [1.0_f32, 3.0, 2.0, 5.0, 4.0];
    let input = Array::from_slice(&series, &[5]).unwrap();
    let mut running = ops::misc::cummax(&input, 0, false, true).unwrap();
    let expected = vec![1.0_f32, 3.0, 3.0, 5.0, 5.0];
    assert_eq!(running.to_vec::<f32>().unwrap(), expected);
}
/// Method-form `cummin` tracks the smallest value seen so far along axis 0.
#[test]
fn cummin_running_minimum() {
    let series = [4.0_f32, 2.0, 5.0, 1.0, 3.0];
    let input = Array::from_slice(&series, &[5]).unwrap();
    let mut running = input.cummin(0, false, true).unwrap();
    let expected = vec![4.0_f32, 2.0, 2.0, 1.0, 1.0];
    assert_eq!(running.to_vec::<f32>().unwrap(), expected);
}
/// `sort` on an unsorted 1-D array returns the values in ascending order,
/// keeping duplicates.
#[test]
fn sort_unsorted_1d_yields_ascending() {
    let unsorted = [3.0_f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0];
    let input = Array::from_slice(&unsorted, &[7]).unwrap();
    let mut sorted = ops::misc::sort(&input).unwrap();
    let ascending = vec![1.0_f32, 1.0, 2.0, 3.0, 4.0, 5.0, 9.0];
    assert_eq!(sorted.to_vec::<f32>().unwrap(), ascending);
}
/// `sort_axis(1)` sorts each row of a 2x3 matrix independently.
///
/// With this fixture the concatenated sorted rows equal a fully flattened
/// sort (`sort_no_axis_flattens_2d` expects the same values with shape `[6]`).
/// The values alone therefore cannot tell the two apart; the shape check
/// proves the result kept its rows.
#[test]
fn sort_axis_2x3_per_row() {
    let a = Array::from_slice(&[3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0], &(2, 3)).unwrap();
    let mut r = a.sort_axis(1).unwrap();
    assert_eq!(r.shape(), vec![2, 3]);
    assert_eq!(
        r.to_vec::<f32>().unwrap(),
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    );
}
/// `argsort` of `[3, 1, 2]` returns the `u32` permutation `[1, 2, 0]`.
#[test]
fn argsort_yields_index_permutation_u32() {
    let input = Array::from_slice(&[3.0_f32, 1.0, 2.0], &[3]).unwrap();
    let mut order = ops::misc::argsort(&input).unwrap();
    assert_eq!(order.dtype().unwrap(), Dtype::U32);
    let expected: Vec<u32> = vec![1, 2, 0];
    assert_eq!(order.to_vec::<u32>().unwrap(), expected);
}
/// Free-function `sort` flattens a 2x3 input into a sorted 1-D array of six
/// elements.
#[test]
fn sort_no_axis_flattens_2d() {
    let grid = [3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0];
    let input = Array::from_slice(&grid, &(2, 3)).unwrap();
    let mut flat = ops::misc::sort(&input).unwrap();
    assert_eq!(flat.shape(), vec![6]);
    let expected = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(flat.to_vec::<f32>().unwrap(), expected);
}
/// Free-function `argsort` flattens a 2x3 input; the returned indices are
/// positions in the flattened data.
#[test]
fn argsort_no_axis_flattens_2d() {
    let grid = [3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0];
    let input = Array::from_slice(&grid, &(2, 3)).unwrap();
    let mut order = ops::misc::argsort(&input).unwrap();
    assert_eq!(order.shape(), vec![6]);
    assert_eq!(order.dtype().unwrap(), Dtype::U32);
    // Flattened values 3 1 2 6 4 5 -> ascending picks positions 1 2 0 4 5 3.
    let expected: Vec<u32> = vec![1, 2, 0, 4, 5, 3];
    assert_eq!(order.to_vec::<u32>().unwrap(), expected);
}
/// Free-function `topk` with k = 3 flattens the 2x3 input and returns its
/// three largest values; the test does not rely on their order.
#[test]
fn topk_no_axis_flattens_2d() {
    let grid = [3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0];
    let input = Array::from_slice(&grid, &(2, 3)).unwrap();
    let mut top = ops::misc::topk(&input, 3).unwrap();
    assert_eq!(top.shape(), vec![3]);
    let mut picked = top.to_vec::<f32>().unwrap();
    // Normalise order before comparing.
    picked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    assert_eq!(picked, vec![4.0, 5.0, 6.0]);
}
/// Free-function `partition` with kth = 3 flattens the 2x3 input and places
/// the 4th-smallest value (4.0) at index 3. No larger value may come before
/// it and no smaller value after it.
#[test]
fn partition_no_axis_flattens_2d() {
    let a = Array::from_slice(&[3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0], &(2, 3)).unwrap();
    let mut r = ops::misc::partition(&a, 3).unwrap();
    assert_eq!(r.shape(), vec![6]);
    let v = r.to_vec::<f32>().unwrap();
    assert_eq!(v[3], 4.0);
    for x in &v[..3] {
        assert!(*x <= 4.0, "lower side must be ≤ pivot, got {x}");
    }
    for x in &v[4..] {
        assert!(*x >= 4.0, "upper side must be ≥ pivot, got {x}");
    }
    // The ordering checks above would also accept output that duplicates or
    // drops elements (e.g. all 4.0); a partition must rearrange the input.
    let mut sorted = v;
    sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
    assert_eq!(sorted, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
/// `argsort_axis(1)` returns per-row `u32` indices that restart at 0 in
/// every row.
#[test]
fn argsort_axis_per_row_u32() {
    let grid = [3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0];
    let input = Array::from_slice(&grid, &(2, 3)).unwrap();
    let mut order = input.argsort_axis(1).unwrap();
    assert_eq!(order.dtype().unwrap(), Dtype::U32);
    // Both rows share the same relative order: middle, last, first.
    let expected: Vec<u32> = vec![1, 2, 0, 1, 2, 0];
    assert_eq!(order.to_vec::<u32>().unwrap(), expected);
}
/// `topk` with k = 3 on a 1-D array returns the three largest values. The
/// test sorts them first so it does not depend on their order.
#[test]
fn topk_returns_k_largest_unsorted() {
    let series = [1.0_f32, 5.0, 2.0, 4.0, 3.0];
    let input = Array::from_slice(&series, &[5]).unwrap();
    let mut top = ops::misc::topk(&input, 3).unwrap();
    let mut picked = top.to_vec::<f32>().unwrap();
    picked.sort_by(|lhs, rhs| lhs.partial_cmp(rhs).unwrap());
    assert_eq!(picked, vec![3.0, 4.0, 5.0]);
}
/// `topk_axis(2, 1)` keeps the two largest values of each row of a 2x4
/// matrix. Per-row sums identify them without depending on their order.
#[test]
fn topk_axis_per_row_largest_two() {
    let grid = [1.0_f32, 5.0, 2.0, 4.0, 8.0, 6.0, 7.0, 3.0];
    let input = Array::from_slice(&grid, &(2, 4)).unwrap();
    let top = input.topk_axis(2, 1).unwrap();
    assert_eq!(top.shape(), vec![2, 2]);
    // Row 0 keeps {5, 4} -> 9; row 1 keeps {8, 7} -> 15.
    let mut sums = top.sum_axes(&[1], false).unwrap();
    let expected = vec![9.0_f32, 15.0];
    assert_eq!(sums.to_vec::<f32>().unwrap(), expected);
}
/// `partition` with kth = 2 on a 1-D array puts the 3rd-smallest value (2.0)
/// at index 2. Nothing larger may come before it and nothing smaller after it.
#[test]
fn partition_kth_element_is_in_position() {
    let a = Array::from_slice(&[3.0_f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0], &[8]).unwrap();
    let mut r = ops::misc::partition(&a, 2).unwrap();
    let v = r.to_vec::<f32>().unwrap();
    assert_eq!(v[2], 2.0);
    for x in &v[..2] {
        assert!(*x <= 2.0, "lower side must be ≤ pivot, got {x}");
    }
    for x in &v[3..] {
        assert!(*x >= 2.0, "upper side must be ≥ pivot, got {x}");
    }
    // The ordering checks above would also accept output that duplicates or
    // drops elements; a partition must rearrange the input.
    let mut sorted = v;
    sorted.sort_by(|x, y| x.partial_cmp(y).unwrap());
    assert_eq!(sorted, vec![1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 9.0]);
}
/// Method-form `partition_axis`. Arguments are taken to be `(kth, axis)`,
/// matching how `argpartition_axis` is exercised in this file. TODO: confirm
/// against the method signature.
#[test]
fn partition_axis_method_form() {
    let data = [3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0];
    let a = Array::from_slice(&data, &(2, 3)).unwrap();
    // kth = 1 along axis 1: each row's middle slot holds its median.
    let mut r = a.partition_axis(1, 1).unwrap();
    assert_eq!(r.shape(), vec![2, 3]);
    let v = r.to_vec::<f32>().unwrap();
    assert_eq!(v[1], 2.0);
    assert_eq!(v[4], 5.0);

    // Because kth == axis above, swapped arguments would go unnoticed. Use
    // kth = 0 along axis 1 so each row's minimum must come first.
    let b = Array::from_slice(&data, &(2, 3)).unwrap();
    let mut r0 = b.partition_axis(0, 1).unwrap();
    let v0 = r0.to_vec::<f32>().unwrap();
    assert_eq!((v0[0], v0[3]), (1.0, 4.0));
}
/// `clip_with_scalar` clamps every element into `[-1, 1]`.
#[test]
fn clip_with_scalar_clamps_into_range() {
    let raw = [-2.0_f32, -1.0, 0.0, 1.0, 2.0];
    let input = Array::from_slice(&raw, &[5]).unwrap();
    let mut clamped = ops::misc::clip_with_scalar(&input, -1.0, 1.0).unwrap();
    let expected = vec![-1.0_f32, -1.0, 0.0, 1.0, 1.0];
    assert_eq!(clamped.to_vec::<f32>().unwrap(), expected);
}
/// `clip` with single-element array bounds (1.0 and 3.0) clamps every
/// element of `arange(0, 5)`.
#[test]
fn clip_with_array_bounds() {
    let input = Array::arange::<f32>(0.0, 5.0, 1.0).unwrap();
    let lower = Array::full::<f32>(&[1], 1.0).unwrap();
    let upper = Array::full::<f32>(&[1], 3.0).unwrap();
    let mut clamped = ops::misc::clip(&input, &lower, &upper).unwrap();
    let expected = vec![1.0_f32, 1.0, 2.0, 3.0, 3.0];
    assert_eq!(clamped.to_vec::<f32>().unwrap(), expected);
}
/// Method-form `clip_with_scalar` must agree with the free function
/// `ops::misc::clip_with_scalar` for the same input and bounds, and both
/// must clamp `arange(0, 5)` into `[1, 3]`.
#[test]
fn clip_method_form_matches_freefn() {
    let a = Array::arange::<f32>(0.0, 5.0, 1.0).unwrap();
    // Reference result from the free function, which borrows `a`.
    let mut via_fn = ops::misc::clip_with_scalar(&a, 1.0, 3.0).unwrap();
    let mut r = a.clip_with_scalar(1.0, 3.0).unwrap();
    let method_values = r.to_vec::<f32>().unwrap();
    assert_eq!(method_values, vec![1.0, 1.0, 2.0, 3.0, 3.0]);
    assert_eq!(method_values, via_fn.to_vec::<f32>().unwrap());
}
/// `ones_like` copies the shape and dtype of its argument and fills with 1.0.
#[test]
fn ones_like_inherits_shape_and_dtype() {
    let template = Array::zeros::<f32>(&(2, 3)).unwrap();
    let mut filled = ops::misc::ones_like(&template).unwrap();
    assert_eq!(filled.shape(), vec![2, 3]);
    assert_eq!(filled.dtype().unwrap(), Dtype::F32);
    let values = filled.to_vec::<f32>().unwrap();
    assert!(values.into_iter().all(|x| x == 1.0));
}
/// Method-form `zeros_like` keeps the 2x2 shape and fills with 0.0.
#[test]
fn zeros_like_method_form() {
    let template = Array::ones::<f32>(&(2, 2)).unwrap();
    let mut zeroed = template.zeros_like().unwrap();
    assert_eq!(zeroed.shape(), vec![2, 2]);
    let values = zeroed.to_vec::<f32>().unwrap();
    assert!(values.into_iter().all(|x| x == 0.0));
}
/// `full_like` keeps the 2x2 shape and fills every element with the given
/// value.
#[test]
fn full_like_fills_with_value() {
    let template = Array::zeros::<f32>(&(2, 2)).unwrap();
    let mut filled = ops::misc::full_like(&template, 7.5).unwrap();
    assert_eq!(filled.shape(), vec![2, 2]);
    let values = filled.to_vec::<f32>().unwrap();
    assert!(values.into_iter().all(|x| x == 7.5));
}
/// Method-form `full_like` on a length-3 array yields three copies of the
/// fill value.
#[test]
fn full_like_method_form() {
    let template = Array::ones::<f32>(&(3,)).unwrap();
    let mut filled = template.full_like(2.5).unwrap();
    assert_eq!(filled.to_vec::<f32>().unwrap(), vec![2.5_f32; 3]);
}
/// Casting `f32` to `i32` with `astype` truncates toward zero.
///
/// Negative inputs are included because for non-negative values truncation
/// and flooring agree, so positives alone cannot confirm the rounding
/// direction. The expected values follow round-toward-zero semantics;
/// NOTE(review): confirm this holds on every backend.
#[test]
fn astype_f32_to_i32_truncates() {
    let a = Array::from_slice(&[0.0_f32, 1.5, 2.9, -1.5, -2.9], &[5]).unwrap();
    let mut r = ops::misc::astype(&a, Dtype::I32).unwrap();
    assert_eq!(r.dtype().unwrap(), Dtype::I32);
    assert_eq!(r.to_vec::<i32>().unwrap(), vec![0, 1, 2, -1, -2]);
}
/// Method-form `astype` converts `f32` ones into `u32` ones.
#[test]
fn astype_method_form_changes_dtype() {
    let input = Array::ones::<f32>(&[3]).unwrap();
    let mut cast = input.astype(Dtype::U32).unwrap();
    assert_eq!(cast.dtype().unwrap(), Dtype::U32);
    assert_eq!(cast.to_vec::<u32>().unwrap(), vec![1_u32; 3]);
}
/// Reinterpreting a negative `i32` as `u32` via `view` must keep the raw bits
/// unchanged rather than converting the numeric value.
#[test]
fn view_i32_to_u32_preserves_bit_pattern() {
    let bits: u32 = 0xF0FF_FFFF;
    // Same bit pattern, read as a signed integer.
    let as_signed = i32::from_ne_bytes(bits.to_ne_bytes());
    assert!(
        as_signed < 0,
        "fixture must be a negative i32 to exercise the sign bit"
    );
    let src = Array::from_slice(&[as_signed], &[1]).unwrap();
    let mut reinterpreted = ops::misc::view(&src, Dtype::U32).unwrap();
    assert_eq!(reinterpreted.dtype().unwrap(), Dtype::U32);
    assert_eq!(reinterpreted.to_vec::<u32>().unwrap(), vec![bits]);
}
/// `view` between same-width dtypes (i32 -> u32) keeps the 2x2 shape, and
/// non-negative values read back unchanged.
#[test]
fn view_same_width_preserves_shape() {
    let ints = [1_i32, 2, 3, 4];
    let input = Array::from_slice(&ints, &(2, 2)).unwrap();
    let mut viewed = ops::misc::view(&input, Dtype::U32).unwrap();
    assert_eq!(viewed.shape(), vec![2, 2]);
    assert_eq!(viewed.dtype().unwrap(), Dtype::U32);
    let expected: Vec<u32> = vec![1, 2, 3, 4];
    assert_eq!(viewed.to_vec::<u32>().unwrap(), expected);
}
/// Method-form `argpartition(0)` on `[3, 1, 2]` puts the index of the minimum
/// (1) first. The result must still be a permutation of `0..3`.
#[test]
fn argpartition_method_form() {
    let a = Array::from_slice(&[3.0_f32, 1.0, 2.0], &[3]).unwrap();
    let mut r = a.argpartition(0).unwrap();
    assert_eq!(r.dtype().unwrap(), Dtype::U32);
    assert_eq!(r.shape(), vec![3]);
    let mut v = r.to_vec::<u32>().unwrap();
    assert_eq!(v[0], 1);
    // The order of the remaining indices is not fixed; check only that every
    // index appears exactly once.
    v.sort_unstable();
    assert_eq!(v, vec![0, 1, 2]);
}
/// `argpartition_axis(0, 1)`, read here as kth = 0 along axis 1, puts each
/// row's minimum first, so the first index of each row points at that row's
/// minimum.
#[test]
fn argpartition_axis_method_form() {
    let grid = [3.0_f32, 1.0, 2.0, 6.0, 4.0, 5.0];
    let input = Array::from_slice(&grid, &(2, 3)).unwrap();
    let mut idx = input.argpartition_axis(0, 1).unwrap();
    assert_eq!(idx.dtype().unwrap(), Dtype::U32);
    assert_eq!(idx.shape(), vec![2, 3]);
    let flat = idx.to_vec::<u32>().unwrap();
    // Row 0's minimum (1.0) and row 1's minimum (4.0) both sit at column 1.
    let (row0_first, row1_first) = (flat[0], flat[3]);
    assert_eq!((row0_first, row1_first), (1, 1));
}
/// `softmax_axis(1, false)` over two equal logits yields a uniform
/// distribution: 0.5 for each entry.
///
/// The result is compared with a small tolerance rather than exact equality,
/// because it goes through `exp` and a division whose last-bit rounding may
/// differ between backends.
#[test]
fn softmax_axis_method_form() {
    let a = Array::from_slice(&[0.0_f32, 0.0], &(1, 2)).unwrap();
    let mut r = a.softmax_axis(1, false).unwrap();
    assert_eq!(r.shape(), vec![1, 2]);
    let v = r.to_vec::<f32>().unwrap();
    assert_eq!(v.len(), 2);
    for &p in &v {
        assert!((p - 0.5).abs() <= 1e-6, "expected 0.5, got {p}");
    }
}
/// On the forward pass `stop_gradient` returns its input unchanged: same
/// shape, dtype and values.
#[test]
fn stop_gradient_is_forward_identity() {
    let values = [1.5_f32, -2.0, 3.25, 0.0];
    let input = Array::from_slice(&values, &[2, 2]).unwrap();
    let mut detached = input.stop_gradient().unwrap();
    assert_eq!(detached.shape(), &[2, 2]);
    assert_eq!(detached.dtype().unwrap(), Dtype::F32);
    assert_eq!(detached.to_vec::<f32>().unwrap(), values.to_vec());
}