use mlxrs::Array;
/// `mean` over a 2x2 array filled with ones must read back as the scalar `1.0`.
#[test]
fn mean_of_2x2_ones_yields_1() {
    let ones = Array::ones::<f32>(&(2, 2)).unwrap();
    let mut averaged = ones.mean(false).unwrap();
    let scalar = averaged.item::<f32>().unwrap();
    assert_eq!(scalar, 1.0);
}
/// Averaging `[[1, 2], [3, 4]]` over axis 0 must give the column means `[2, 3]`.
#[test]
fn mean_axes_of_2x2_along_axis0() {
    let data = [1.0_f32, 2.0, 3.0, 4.0];
    let matrix = Array::from_slice(&data, &(2, 2)).unwrap();
    let mut column_means = matrix.mean_axes(&[0], false).unwrap();
    let values = column_means.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 3.0]);
}
/// With an empty axis list, `mean_axes` on an f32 input must leave both the
/// shape and the element values untouched.
#[test]
fn mean_axes_empty_is_identity_for_float() {
    let matrix = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let mut unchanged = matrix.mean_axes(&[], false).unwrap();

    // Shape first, then contents.
    assert_eq!(unchanged.shape(), vec![2, 2]);
    let values = unchanged.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
}
/// For an i32 input, `mean_axes` with no axes must report the same output
/// dtype as the free-function `mean`, and that dtype must be `F32`.
#[test]
fn mean_axes_empty_promotes_int_to_float() {
    let ints = Array::from_slice(&[1_i32, 2, 3, 4], &(2, 2)).unwrap();
    // Sanity check: the input really is integer-typed before reducing.
    assert_eq!(ints.dtype().unwrap(), mlxrs::Dtype::I32);

    let via_empty_axes = ints.mean_axes(&[], false).unwrap();
    let via_full = mlxrs::ops::reduction::mean(&ints, false).unwrap();

    let empty_axes_dtype = via_empty_axes.dtype().unwrap();
    let full_dtype = via_full.dtype().unwrap();
    assert_eq!(
        empty_axes_dtype, full_dtype,
        "empty-axes and full-reduction must agree on output dtype",
    );

    let promoted = via_empty_axes.dtype().unwrap();
    assert_eq!(promoted, mlxrs::Dtype::F32, "mean of int promotes to f32");
}
/// `max` over `arange(0, 5, 1)` must read back as `4.0`.
#[test]
fn max_of_arange_yields_last() {
    let range = Array::arange::<f32>(0.0, 5.0, 1.0).unwrap();
    let mut largest = range.max(false).unwrap();
    let scalar = largest.item::<f32>().unwrap();
    assert_eq!(scalar, 4.0);
}
/// Taking the max of `[[1, 2], [3, 4]]` over axis 1 must give the row maxima `[2, 4]`.
#[test]
fn max_axes_of_2x2_along_axis1() {
    let data = [1.0_f32, 2.0, 3.0, 4.0];
    let matrix = Array::from_slice(&data, &(2, 2)).unwrap();
    let mut row_maxima = matrix.max_axes(&[1], false).unwrap();
    let values = row_maxima.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 4.0]);
}
/// Passing `true` as the second argument of `max_axes` must keep the reduced
/// axis, so a 2x2 input reduced over axis 1 has shape `[2, 1]`.
#[test]
fn max_axes_keepdims_preserves_axis() {
    let matrix = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let reduced = matrix.max_axes(&[1], true).unwrap();
    assert_eq!(reduced.shape(), vec![2, 1]);
}
/// On a zero-size input, both `max_axes` and `min_axes` called with an empty
/// axis list must return `Err(Error::MlxOp(_))` whose op kind is `Pool`.
#[test]
fn max_axes_empty_on_zero_size_errors() {
    let empty = Array::from_slice::<f32>(&[], &[0i32]).unwrap();
    assert_eq!(empty.size(), 0);

    // max first; min is only attempted once the max check has passed.
    let r_max = mlxrs::ops::reduction::max_axes(&empty, &[], false);
    let max_is_pool_err = matches!(
        &r_max,
        Err(mlxrs::Error::MlxOp(p)) if matches!(p.op(), mlxrs::error::MlxOpKind::Pool)
    );
    assert!(
        max_is_pool_err,
        "expected Err(MlxOp(Pool)) for max_axes(zero_size, &[]), got {r_max:?}",
    );

    let r_min = mlxrs::ops::reduction::min_axes(&empty, &[], false);
    let min_is_pool_err = matches!(
        &r_min,
        Err(mlxrs::Error::MlxOp(p)) if matches!(p.op(), mlxrs::error::MlxOpKind::Pool)
    );
    assert!(
        min_is_pool_err,
        "expected Err(MlxOp(Pool)) for min_axes(zero_size, &[]), got {r_min:?}",
    );
}
/// On a non-empty 2x2 input, `max_axes` with an empty axis list must leave
/// both the shape and the element values untouched.
#[test]
fn max_axes_empty_on_non_zero_size_is_identity() {
    let matrix = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let mut unchanged = matrix.max_axes(&[], false).unwrap();

    assert_eq!(unchanged.shape(), vec![2, 2]);
    let values = unchanged.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
}
/// `min` over `arange(0, 5, 1)` must read back as `0.0`.
#[test]
fn min_of_arange_yields_first() {
    let range = Array::arange::<f32>(0.0, 5.0, 1.0).unwrap();
    let mut smallest = range.min(false).unwrap();
    let scalar = smallest.item::<f32>().unwrap();
    assert_eq!(scalar, 0.0);
}
/// Taking the min of `[[1, 2], [3, 4]]` over axis 0 must give the column minima `[1, 2]`.
#[test]
fn min_axes_of_2x2_along_axis0() {
    let data = [1.0_f32, 2.0, 3.0, 4.0];
    let matrix = Array::from_slice(&data, &(2, 2)).unwrap();
    let mut column_minima = matrix.min_axes(&[0], false).unwrap();
    let values = column_minima.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 2.0]);
}
/// `prod` over a 2x2 array filled with `2.0` must read back as `16.0`.
#[test]
fn prod_of_2x2_twos_yields_16() {
    let twos = Array::full::<f32>(&(2, 2), 2.0).unwrap();
    let mut product = twos.prod(false).unwrap();
    let scalar = product.item::<f32>().unwrap();
    assert_eq!(scalar, 16.0);
}
/// Multiplying `[[1, 2], [3, 4]]` over axis 1 must give the row products `[2, 12]`.
#[test]
fn prod_axes_of_2x2_along_axis1() {
    let data = [1.0_f32, 2.0, 3.0, 4.0];
    let matrix = Array::from_slice(&data, &(2, 2)).unwrap();
    let mut row_products = matrix.prod_axes(&[1], false).unwrap();
    let values = row_products.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 12.0]);
}
/// `prod_axes` with an empty axis list must hand back an array with the
/// input's shape and element values.
#[test]
fn prod_axes_empty_returns_clone() {
    let matrix = Array::from_slice(&[1.0_f32, 2.0, 3.0, 4.0], &(2, 2)).unwrap();
    let mut copy = matrix.prod_axes(&[], false).unwrap();

    assert_eq!(copy.shape(), vec![2, 2]);
    let values = copy.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
}
/// For the three unsorted values `[3, 1, 2]`, `median` must read back as `2.0`.
#[test]
fn median_odd_count_is_middle_element() {
    let unsorted = [3.0_f32, 1.0, 2.0];
    let samples = Array::from_slice(&unsorted, &(3,)).unwrap();
    let mut middle = samples.median(false).unwrap();
    let scalar = middle.item::<f32>().unwrap();
    assert_eq!(scalar, 2.0);
}
/// For the four unsorted values `[4, 1, 3, 2]`, `median` must read back as
/// `2.5`, the average of the two middle values 2 and 3.
#[test]
fn median_even_count_averages_two_midpoints() {
    let unsorted = [4.0_f32, 1.0, 3.0, 2.0];
    let samples = Array::from_slice(&unsorted, &(4,)).unwrap();
    let mut middle = samples.median(false).unwrap();
    let scalar = middle.item::<f32>().unwrap();
    assert_eq!(scalar, 2.5);
}
/// `median` of the i32 values `[1, 2]` must report dtype `F32` and read back
/// as `1.5`.
#[test]
fn median_promotes_int_input_to_float() {
    let ints = Array::from_slice(&[1_i32, 2], &(2,)).unwrap();
    // Sanity check: the input really is integer-typed before reducing.
    assert_eq!(ints.dtype().unwrap(), mlxrs::Dtype::I32);

    let mut middle = ints.median(false).unwrap();
    let out_dtype = middle.dtype().unwrap();
    assert_eq!(out_dtype, mlxrs::Dtype::F32, "median promotes int to f32");

    let scalar = middle.item::<f32>().unwrap();
    assert_eq!(scalar, 1.5);
}
/// Row-wise median of a 2x4 input with the reduced axis kept: the result must
/// have shape `[2, 1]` and hold `2.5` and `25.0`.
#[test]
fn median_axes_over_axis1_with_keepdims() {
    let rows = [1.0_f32, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    let matrix = Array::from_slice(&rows, &(2, 4)).unwrap();

    // Mutable from the start because `to_vec` is called on it below.
    let mut row_medians = matrix.median_axes(&[1], true).unwrap();
    assert_eq!(row_medians.shape(), vec![2, 1]);

    let values = row_medians.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.5, 25.0]);
}
/// Column-wise median of a 3x2 input without keeping the reduced axis: the
/// result must have shape `[2]` and hold `2.0` and `7.0`.
#[test]
fn median_axes_over_axis0_no_keepdims() {
    let cells = [1.0_f32, 5.0, 3.0, 9.0, 2.0, 7.0];
    let matrix = Array::from_slice(&cells, &(3, 2)).unwrap();

    let column_medians = matrix.median_axes(&[0], false).unwrap();
    assert_eq!(column_medians.shape(), vec![2]);

    // The values are read from the output of `contiguous`, not from the raw
    // reduction result. NOTE(review): presumably this gives `to_vec` a dense
    // layout — confirm against the `contiguous` documentation.
    let mut dense = mlxrs::ops::shape::contiguous(&column_medians, false).unwrap();
    let values = dense.to_vec::<f32>().unwrap();
    assert_eq!(values, vec![2.0, 7.0]);
}
/// `median_axes` with an empty axis list must fail with `Error::EmptyInput`
/// whose context string is `"median_axes: axes"`.
#[test]
fn median_axes_empty_is_rejected() {
    let ints = Array::from_slice(&[1_i32, 2, 3, 4], &(2, 2)).unwrap();
    let outcome = ints.median_axes(&[], false);
    match outcome {
        Err(mlxrs::Error::EmptyInput(payload)) => {
            assert_eq!(payload.context(), "median_axes: axes");
        }
        // Any other result, success or a different error, fails the test.
        other => panic!("expected EmptyInput for empty axes, got {other:?}"),
    }
}
/// `median` of a rank-0 f32 array holding `5.0` must read back as `5.0`.
#[test]
fn median_scalar_rank0_is_identity() {
    // An empty shape slice yields a rank-0 array, as the `ndim` check confirms.
    let scalar_array = Array::from_slice::<f32>(&[5.0], &[0i32; 0]).unwrap();
    assert_eq!(scalar_array.ndim(), 0);

    let mut middle = scalar_array.median(false).unwrap();
    let value = middle.item::<f32>().unwrap();
    assert_eq!(value, 5.0);
}
/// `median` of a rank-0 i32 array holding `7` must report dtype `F32` and
/// read back as `7.0`.
#[test]
fn median_scalar_rank0_promotes_int() {
    // An empty shape slice yields a rank-0 array, as the `ndim` check confirms.
    let scalar_array = Array::from_slice::<i32>(&[7], &[0i32; 0]).unwrap();
    assert_eq!(scalar_array.ndim(), 0);

    let mut middle = scalar_array.median(false).unwrap();
    let out_dtype = middle.dtype().unwrap();
    assert_eq!(out_dtype, mlxrs::Dtype::F32);

    let value = middle.item::<f32>().unwrap();
    assert_eq!(value, 7.0);
}
/// The `Array::mean` method and the free function `ops::reduction::mean`
/// must produce the same scalar for the same 2x2 array of ones.
#[test]
fn mean_freefn_parity_with_method() {
    let ones = Array::ones::<f32>(&(2, 2)).unwrap();
    let mut via_method = ones.mean(false).unwrap();
    let mut via_free_fn = mlxrs::ops::reduction::mean(&ones, false).unwrap();

    // Read the method result first, then the free-function result.
    let method_value = via_method.item::<f32>().unwrap();
    let free_fn_value = via_free_fn.item::<f32>().unwrap();
    assert_eq!(method_value, free_fn_value);
}