use std::ffi::CString;
use mlxrs::{Array, ops};
/// `transpose` with no explicit axes reverses the dimension order, so a
/// 2x3 matrix comes back as 3x2.
#[test]
fn transpose_2x3_swaps_to_3x2() {
    let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let matrix = Array::from_slice::<f32>(&values, &(2, 3)).unwrap();

    let transposed = matrix.transpose().unwrap();

    assert_eq!(transposed.shape(), vec![3, 2]);
}
/// `transpose_axes` applies an explicit permutation: with `[2, 0, 1]`, the
/// (2, 3, 4) input is expected to come back as (4, 2, 3).
#[test]
fn transpose_axes_3d_permutes() {
    let cube = Array::ones::<f32>(&(2usize, 3, 4)).unwrap();

    let permuted = cube.transpose_axes(&[2, 0, 1]).unwrap();

    assert_eq!(permuted.shape(), vec![4, 2, 3]);
}
/// A 0-d (scalar) array accepts an empty permutation, stays 0-d, and keeps
/// its single value.
#[test]
fn transpose_axes_empty_for_scalar() {
    // An empty shape makes the array a scalar.
    let scalar_shape = [0i32; 0];
    let scalar = Array::from_slice::<f32>(&[7.0], &scalar_shape).unwrap();

    let mut same = scalar.transpose_axes(&[]).unwrap();

    assert_eq!(same.shape(), Vec::<usize>::new());
    let value = same.item::<f32>().unwrap();
    assert_eq!(value, 7.0);
}
/// Inserting size-1 axes at positions 0 and 2 of a (3, 4) array gives
/// (1, 3, 1, 4).
#[test]
fn expand_dims_axes_inserts_dims() {
    let matrix = Array::ones::<f32>(&(3usize, 4)).unwrap();

    let expanded = matrix.expand_dims_axes(&[0, 2]).unwrap();

    assert_eq!(expanded.shape(), vec![1, 3, 1, 4]);
}
/// With no axes to insert, `expand_dims_axes` returns an array with the same
/// shape and the same elements as its input.
#[test]
fn expand_dims_axes_empty_is_clone() {
    let mut source = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();

    let mut expanded = source.expand_dims_axes(&[]).unwrap();

    assert_eq!(expanded.shape(), source.shape());
    let expanded_values = expanded.to_vec::<f32>().unwrap();
    let source_values = source.to_vec::<f32>().unwrap();
    assert_eq!(expanded_values, source_values);
}
/// Squeezing the size-1 axes 0 and 2 of a (1, 3, 1, 4) array leaves (3, 4).
#[test]
fn squeeze_axes_drops_size1() {
    let padded = Array::ones::<f32>(&(1usize, 3, 1, 4)).unwrap();

    let squeezed = padded.squeeze_axes(&[0, 2]).unwrap();

    assert_eq!(squeezed.shape(), vec![3, 4]);
}
/// With no axes to squeeze, `squeeze_axes` returns an array with the same
/// shape and the same elements as its input.
#[test]
fn squeeze_axes_empty_is_clone() {
    let mut source = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();

    let mut squeezed = source.squeeze_axes(&[]).unwrap();

    assert_eq!(squeezed.shape(), source.shape());
    let squeezed_values = squeezed.to_vec::<f32>().unwrap();
    let source_values = source.to_vec::<f32>().unwrap();
    assert_eq!(squeezed_values, source_values);
}
/// Broadcasting a (1, 3) row to (4, 3) repeats it along the first axis.
#[test]
fn broadcast_to_expands_shape() {
    let row = Array::ones::<f32>(&(1usize, 3)).unwrap();

    let grid = row.broadcast_to(&(4usize, 3)).unwrap();

    assert_eq!(grid.shape(), vec![4, 3]);
    // 4 * 3 logical elements, although the source row holds only 3.
    assert_eq!(grid.size(), 12);
}
/// `stack` with no axis argument inserts the new axis at position 0.
///
/// The inputs have distinct element values rather than all ones. Stacking two
/// 2x2 arrays yields shape [2, 2, 2] for every axis choice, so only the
/// element order shows that `b` was placed after `a` along axis 0.
#[test]
fn stack_two_2x2_along_axis0() {
    let a = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &(2, 2)).unwrap();
    let b = Array::from_slice::<f32>(&[4.0, 5.0, 6.0, 7.0], &(2, 2)).unwrap();
    let mut s = ops::shape::stack(&[&a, &b]).unwrap();
    assert_eq!(s.shape(), vec![2, 2, 2]);
    // Along axis 0, `a` fills the first 2x2 slab and `b` the second.
    assert_eq!(
        s.to_vec::<f32>().unwrap(),
        vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
    );
}
/// `stack_axis` with axis 2 puts the new axis last, interleaving the inputs
/// element by element.
///
/// The inputs have distinct element values rather than all ones. Two 2x2
/// inputs give shape [2, 2, 2] for every axis, so a shape-only check would
/// pass even if `axis` were ignored. The element order is what pins axis 2.
#[test]
fn stack_axis_two_2x2_along_axis2() {
    let a = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &(2, 2)).unwrap();
    let b = Array::from_slice::<f32>(&[4.0, 5.0, 6.0, 7.0], &(2, 2)).unwrap();
    let mut s = ops::shape::stack_axis(&[&a, &b], 2).unwrap();
    assert_eq!(s.shape(), vec![2, 2, 2]);
    // s[i][j][0] == a[i][j] and s[i][j][1] == b[i][j], so the two inputs
    // alternate in row-major order.
    assert_eq!(
        s.to_vec::<f32>().unwrap(),
        vec![0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
    );
}
/// The method form `stack_with` stacks the receiver followed by the given
/// arrays: three (2, 2) arrays along axis 0 give (3, 2, 2).
#[test]
fn stack_with_method_form() {
    let make = || Array::ones::<f32>(&(2usize, 2)).unwrap();
    let (first, second, third) = (make(), make(), make());

    let stacked = first.stack_with(&[&second, &third], 0).unwrap();

    assert_eq!(stacked.shape(), vec![3, 2, 2]);
}
/// Both stacking entry points report `EmptyInput` when given no arrays,
/// whether or not an axis is supplied.
#[test]
fn stack_rejects_empty_input() {
    let default_axis = ops::shape::stack(&[]);
    assert!(matches!(default_axis, Err(mlxrs::Error::EmptyInput(_))));

    let explicit_axis = ops::shape::stack_axis(&[], 0);
    assert!(matches!(explicit_axis, Err(mlxrs::Error::EmptyInput(_))));
}
/// Splitting `0..10` at indices 3 and 5 gives three pieces of lengths
/// 3, 2 and 5.
#[test]
fn split_sections_at_indices_yields_three_parts() {
    let range = Array::arange::<f32>(0.0, 10.0, 1.0).unwrap();

    let pieces = range.split_sections(&[3, 5], 0).unwrap();

    assert_eq!(pieces.len(), 3);
    // Cut points 3 and 5 bound the pieces [0, 3), [3, 5) and [5, 10).
    let (head, middle, tail) = (&pieces[0], &pieces[1], &pieces[2]);
    assert_eq!(head.shape(), vec![3]);
    assert_eq!(middle.shape(), vec![2]);
    assert_eq!(tail.shape(), vec![5]);
}
/// With no cut points, `split_sections` returns the whole array as a single
/// piece.
#[test]
fn split_sections_empty_indices_yields_single_part() {
    let range = Array::arange::<f32>(0.0, 4.0, 1.0).unwrap();

    let pieces = range.split_sections(&[], 0).unwrap();

    assert_eq!(pieces.len(), 1);
    let only = &pieces[0];
    assert_eq!(only.shape(), vec![4]);
}
/// Flattening from axis 0 through the last axis (-1) collapses a 2x3 matrix
/// into 6 elements in their original order.
#[test]
fn flatten_2x3_to_6() {
    let matrix = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &(2, 3)).unwrap();

    let mut flat = matrix.flatten(0, -1).unwrap();

    assert_eq!(flat.shape(), vec![6]);
    let elements = flat.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
/// Flattening only axes 1..=2 of a (2, 3, 4) array merges them into one axis
/// of length 12 and leaves axis 0 alone.
#[test]
fn flatten_partial_range() {
    let cube = Array::ones::<f32>(&(2usize, 3, 4)).unwrap();

    let merged = cube.flatten(1, 2).unwrap();

    assert_eq!(merged.shape(), vec![2, 12]);
}
/// Swapping axes 0 and 2 of a (2, 3, 4) array exchanges their lengths,
/// giving (4, 3, 2).
#[test]
fn swapaxes_swaps_axes() {
    let cube = Array::ones::<f32>(&(2usize, 3, 4)).unwrap();

    let swapped = cube.swapaxes(0, 2).unwrap();

    assert_eq!(swapped.shape(), vec![4, 3, 2]);
}
/// Constant-mode `pad` adds `low` entries before and `high` entries after the
/// data on each listed axis, filled with the pad value.
///
/// Checking the element values (not just the length) pins the low/high
/// ordering and the fill value. A pad that swapped low and high would still
/// produce length 6.
#[test]
fn pad_constant_grows_axis() {
    let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();
    let zero = Array::from_slice::<f32>(&[0.0], &[0i32; 0]).unwrap();
    let mode = CString::new("constant").unwrap();
    // axes = [0], low = [2], high = [1]
    let mut p = ops::shape::pad(&a, &[0], &[2], &[1], &zero, &mode).unwrap();
    assert_eq!(p.shape(), vec![6]);
    assert_eq!(
        p.to_vec::<f32>().unwrap(),
        vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0]
    );
}
/// `pad` requires `axes`, `low` and `high` to have equal lengths. Here `high`
/// has two entries while the others have one, which must yield
/// `MultiLengthMismatch` with the pad-specific context.
#[test]
fn pad_rejects_length_mismatch() {
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();
    let fill = Array::from_slice::<f32>(&[0.0], &[0i32; 0]).unwrap();
    let mode = CString::new("constant").unwrap();

    let r = ops::shape::pad(&input, &[0], &[2], &[1, 2], &fill, &mode);

    let is_expected = matches!(
        r,
        Err(mlxrs::Error::MultiLengthMismatch(ref p))
            if p.context() == "pad: axes/low/high"
    );
    assert!(is_expected, "expected Err(MultiLengthMismatch), got {r:?}");
}
/// A negative `low` pad width is rejected with `OutOfRange`. The payload must
/// name the dims validator, state the non-negativity requirement, and echo
/// the offending value.
#[test]
fn pad_rejects_negative_low() {
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();
    let fill = Array::from_slice::<f32>(&[0.0], &[0i32; 0]).unwrap();
    let mode = CString::new("constant").unwrap();

    let outcome = ops::shape::pad(&input, &[0], &[-1], &[1], &fill, &mode);
    let payload = match outcome {
        Err(mlxrs::Error::OutOfRange(p)) => p,
        other => panic!("expected Err(OutOfRange) for negative low, got {other:?}"),
    };

    let context = payload.context();
    assert!(
        context.contains("validate_dims"),
        "context names the validator: {}",
        context
    );
    assert_eq!(payload.requirement(), "must be non-negative");
    let value = payload.value();
    assert!(
        value.contains("-1"),
        "value names the offending dim: {}",
        value
    );
}
/// A negative `high` pad width is rejected with `OutOfRange`. The payload
/// must state the non-negativity requirement and echo the offending value.
#[test]
fn pad_rejects_negative_high() {
    let input = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();
    let fill = Array::from_slice::<f32>(&[0.0], &[0i32; 0]).unwrap();
    let mode = CString::new("constant").unwrap();

    let outcome = ops::shape::pad(&input, &[0], &[1], &[-2], &fill, &mode);
    let payload = match outcome {
        Err(mlxrs::Error::OutOfRange(p)) => p,
        other => panic!("expected Err(OutOfRange) for negative high, got {other:?}"),
    };

    assert_eq!(payload.requirement(), "must be non-negative");
    let value = payload.value();
    assert!(
        value.contains("-2"),
        "value names the offending dim: {}",
        value
    );
}
/// A (2, 2) view with row-major strides [2, 1] over a 4-element buffer reads
/// the buffer back in order.
#[test]
fn as_strided_basic_view() {
    let flat = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();

    // SAFETY: shape (2, 2), strides [2, 1] and offset 0 address at most
    // element 0 + 1*2 + 1*1 = 3, the last valid index of the 4-element source.
    let strided = unsafe { ops::shape::as_strided(&flat, &(2usize, 2), &[2, 1], 0) };
    let mut view = strided.unwrap();

    assert_eq!(view.shape(), vec![2, 2]);
    let elements = view.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![0.0, 1.0, 2.0, 3.0]);
}
/// A length-2, stride-1 view starting at offset 1 exposes elements 1 and 2
/// of the source.
#[test]
fn as_strided_with_offset() {
    let flat = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();

    // SAFETY: shape (2,), stride [1] and offset 1 address elements 1..=2,
    // both inside the 4-element source.
    let strided = unsafe { ops::shape::as_strided(&flat, &(2usize,), &[1], 1) };
    let mut view = strided.unwrap();

    assert_eq!(view.shape(), vec![2]);
    let elements = view.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![1.0, 2.0]);
}
/// `as_strided` also accepts a runtime `&[i32]` shape, not only a tuple, and
/// produces the same (2, 2) view.
#[test]
fn as_strided_accepts_slice_shape() {
    let flat = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();
    let shape: &[i32] = &[2, 2];

    // SAFETY: shape [2, 2], strides [2, 1] and offset 0 address at most
    // element 3, the last valid index of the 4-element source.
    let strided = unsafe { ops::shape::as_strided(&flat, &shape, &[2, 1], 0) };
    let mut view = strided.unwrap();

    assert_eq!(view.shape(), vec![2, 2]);
    let elements = view.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![0.0, 1.0, 2.0, 3.0]);
}
/// A 2-d shape paired with a single stride is rejected with `LengthMismatch`.
/// The payload reports shape length 2 as expected and strides length 1 as
/// actual.
#[test]
fn as_strided_shape_strides_length_mismatch_errors() {
    let flat = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();

    // SAFETY: this test asserts that the wrapper rejects the shape/strides
    // length mismatch and returns an error instead of building a view.
    let outcome = unsafe { ops::shape::as_strided(&flat, &(2usize, 2), &[1i64], 0) };
    let payload = match outcome {
        Err(mlxrs::Error::LengthMismatch(p)) => p,
        other => panic!("expected Err(LengthMismatch), got {other:?}"),
    };

    let context = payload.context();
    assert!(
        context.contains("as_strided") && context.contains("shape length"),
        "context names the as_strided shape-vs-strides check: {}",
        context
    );
    assert_eq!(payload.expected(), 2, "shape length");
    assert_eq!(payload.actual(), 1, "strides length");
}
/// Moving axis 0 of a (2, 3, 4) array to position 2 gives (3, 4, 2).
#[test]
fn moveaxis_moves_first_to_last() {
    let cube = Array::ones::<f32>(&(2usize, 3, 4)).unwrap();

    let moved = cube.moveaxis(0, 2).unwrap();

    assert_eq!(moved.shape(), vec![3, 4, 2]);
}
/// A negative source axis counts from the end: moving axis -1 of (2, 3, 4)
/// to the front gives (4, 2, 3).
#[test]
fn moveaxis_negative_axes() {
    let cube = Array::ones::<f32>(&(2usize, 3, 4)).unwrap();

    let moved = cube.moveaxis(-1, 0).unwrap();

    assert_eq!(moved.shape(), vec![4, 2, 3]);
}
/// The free-function `ops::shape::moveaxis` matches the method form: moving
/// axis 1 of (2, 3, 4) to the front gives (3, 2, 4).
#[test]
fn moveaxis_free_fn_form() {
    let cube = Array::ones::<f32>(&(2usize, 3, 4)).unwrap();

    let moved = ops::shape::moveaxis(&cube, 1, 0).unwrap();

    assert_eq!(moved.shape(), vec![3, 2, 4]);
}
/// With no axis, `roll` shifts the elements as if the array were flattened,
/// then keeps the original (2, 3) shape.
#[test]
fn roll_flattened_right() {
    let grid = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &(2, 3)).unwrap();

    let mut rolled = grid.roll(&[2]).unwrap();

    assert_eq!(rolled.shape(), vec![2, 3]);
    let elements = rolled.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![4.0, 5.0, 0.0, 1.0, 2.0, 3.0]);
}
/// A negative shift rotates to the left: the first element wraps to the end.
#[test]
fn roll_flattened_negative_shift() {
    let line = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &[6i32]).unwrap();

    let mut rolled = ops::shape::roll(&line, &[-1]).unwrap();

    let elements = rolled.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![1.0, 2.0, 3.0, 4.0, 5.0, 0.0]);
}
/// A shift larger than the array wraps around: shifting 4 elements by 6
/// matches shifting by 2.
#[test]
fn roll_shift_larger_than_size_wraps() {
    let line = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();

    let mut rolled = line.roll(&[6]).unwrap();

    let elements = rolled.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![2.0, 3.0, 0.0, 1.0]);
}
/// `roll_axis` shifts along a single axis. Rolling along axis 1 rotates each
/// row's columns. Rolling along axis 0 (applied to the same original array,
/// not the column-rolled result) swaps the two rows.
#[test]
fn roll_axis_rolls_columns_then_rows() {
    let grid = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &(2, 3)).unwrap();

    // Columns: each row shifts right by one.
    let mut by_column = grid.roll_axis(&[1], 1).unwrap();
    assert_eq!(by_column.shape(), vec![2, 3]);
    let column_rolled = by_column.to_vec::<f32>().unwrap();
    assert_eq!(column_rolled, vec![2.0, 0.0, 1.0, 5.0, 3.0, 4.0]);

    // Rows: the second row moves to the top.
    let mut by_row = ops::shape::roll_axis(&grid, &[1], 0).unwrap();
    let row_rolled = by_row.to_vec::<f32>().unwrap();
    assert_eq!(row_rolled, vec![3.0, 4.0, 5.0, 0.0, 1.0, 2.0]);
}
/// Axis -1 refers to the last axis, so on a 2-d array it rolls the columns
/// just like axis 1.
#[test]
fn roll_axis_negative_axis() {
    let grid = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &(2, 3)).unwrap();

    let mut rolled = grid.roll_axis(&[1], -1).unwrap();

    let elements = rolled.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![2.0, 0.0, 1.0, 5.0, 3.0, 4.0]);
}
/// `roll_axes` pairs each shift with an axis. Shifting both rows and columns
/// by one moves the bottom-right element to the top-left.
#[test]
fn roll_axes_multi_shift() {
    let grid = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &(2, 3)).unwrap();

    let mut rolled = grid.roll_axes(&[1, 1], &[0, 1]).unwrap();

    assert_eq!(rolled.shape(), vec![2, 3]);
    let elements = rolled.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![5.0, 3.0, 4.0, 2.0, 0.0, 1.0]);
}
/// `roll_axes` needs exactly one shift per axis. Both too few and too many
/// shifts must be reported as `LengthMismatch`.
#[test]
fn roll_axes_rejects_shift_axes_count_mismatch() {
    let grid = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], &(2, 3)).unwrap();

    // One shift for two axes.
    let too_few = ops::shape::roll_axes(&grid, &[1], &[0, 1]);
    assert!(
        matches!(too_few, Err(mlxrs::Error::LengthMismatch(_))),
        "expected Err(LengthMismatch) for too-few shifts, got {too_few:?}"
    );

    // Three shifts for two axes.
    let too_many = ops::shape::roll_axes(&grid, &[1, 2, 3], &[0, 1]);
    assert!(
        matches!(too_many, Err(mlxrs::Error::LengthMismatch(_))),
        "expected Err(LengthMismatch) for too-many shifts, got {too_many:?}"
    );
}
/// Tiling a 1-d array twice concatenates it with itself.
#[test]
fn tile_1d_doubles() {
    let line = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();

    let mut tiled = line.tile(&[2]).unwrap();

    assert_eq!(tiled.shape(), vec![6]);
    let elements = tiled.to_vec::<f32>().unwrap();
    assert_eq!(elements, [1.0f32, 2.0, 3.0].repeat(2));
}
/// A single repetition leaves the array's shape and contents unchanged.
#[test]
fn tile_reps_of_one_is_identity() {
    let line = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();

    let mut tiled = ops::shape::tile(&line, &[1]).unwrap();

    assert_eq!(tiled.shape(), vec![3]);
    let elements = tiled.to_vec::<f32>().unwrap();
    assert_eq!(elements, vec![1.0, 2.0, 3.0]);
}
/// Reps (2, 1) on a 2x2 matrix repeat it vertically, giving 4x2 with the
/// original block stacked twice.
#[test]
fn tile_multi_reps_2d() {
    let block = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2, 2)).unwrap();

    let mut tiled = block.tile(&[2, 1]).unwrap();

    assert_eq!(tiled.shape(), vec![4, 2]);
    let elements = tiled.to_vec::<f32>().unwrap();
    assert_eq!(elements, [1.0f32, 2.0, 3.0, 4.0].repeat(2));
}
/// When `reps` has more entries than the array has dims, the (2,) input is
/// treated as (1, 2). Reps (2, 2) then give a (2, 4) result.
#[test]
fn tile_reps_longer_than_ndim_prepends_axis() {
    let pair = Array::from_slice::<f32>(&[1.0, 2.0], &[2i32]).unwrap();

    let mut tiled = pair.tile(&[2, 2]).unwrap();

    assert_eq!(tiled.shape(), vec![2, 4]);
    let elements = tiled.to_vec::<f32>().unwrap();
    assert_eq!(elements, [1.0f32, 2.0].repeat(4));
}
/// Shifts whose sum exceeds `i32::MAX` must surface as `ArithmeticOverflow`
/// from both the free function and the method form, not as a panic or a
/// silent wrap.
#[test]
fn roll_multi_shift_sum_overflow_is_typed_error() {
    let line = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();
    // i32::MAX + 1 does not fit in an i32.
    let shifts = [i32::MAX, 1];

    let r = ops::shape::roll(&line, &shifts);
    assert!(
        matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))),
        "expected Err(ArithmeticOverflow) for overflowing shift sum, got {r:?}"
    );

    let via_method = line.roll(&shifts);
    assert!(matches!(via_method, Err(mlxrs::Error::ArithmeticOverflow(_))));
}
/// `roll_axis` with two `i32::MAX` shifts, whose sum overflows i32, reports
/// `ArithmeticOverflow`.
#[test]
fn roll_axis_multi_shift_sum_overflow_is_typed_error() {
    let grid = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &(2, 2)).unwrap();
    let shifts = [i32::MAX, i32::MAX];

    let r = ops::shape::roll_axis(&grid, &shifts, 0);

    assert!(
        matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))),
        "expected Err(ArithmeticOverflow) for overflowing shift sum, got {r:?}"
    );
}
/// A shift of `i32::MIN` is rejected with `OutOfRange` by all three roll
/// entry points. `i32::MIN` cannot be negated within i32, which is presumably
/// why the wrappers refuse it.
#[test]
fn roll_int_min_shift_is_typed_range_error() {
    let line = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();
    let shift = [i32::MIN];

    let r = ops::shape::roll(&line, &shift);
    assert!(
        matches!(r, Err(mlxrs::Error::OutOfRange(_))),
        "roll(i32::MIN) should be OutOfRange, got {r:?}"
    );

    let r_axis = ops::shape::roll_axis(&line, &shift, 0);
    assert!(
        matches!(r_axis, Err(mlxrs::Error::OutOfRange(_))),
        "roll_axis(i32::MIN) should be OutOfRange, got {r_axis:?}"
    );

    let r_axes = ops::shape::roll_axes(&line, &shift, &[0]);
    assert!(
        matches!(r_axes, Err(mlxrs::Error::OutOfRange(_))),
        "roll_axes(i32::MIN) should be OutOfRange, got {r_axes:?}"
    );
}
/// Shifts that individually fit but sum exactly to `i32::MIN` are rejected
/// with `OutOfRange`, the same as a direct `i32::MIN` shift.
#[test]
fn roll_sum_to_int_min_is_typed_range_error() {
    let line = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();
    // Two halves of i32::MIN: the sum does not overflow, it lands on i32::MIN.
    let half = i32::MIN / 2;

    let r = ops::shape::roll(&line, &[half, half]);

    assert!(
        matches!(r, Err(mlxrs::Error::OutOfRange(_))),
        "expected Err(OutOfRange) for shift sum == i32::MIN, got {r:?}"
    );
}
/// Tiling a 2-element axis `i32::MAX` times overflows the output dimension.
/// Both the free function and the method form must report
/// `ArithmeticOverflow`.
#[test]
fn tile_huge_rep_output_dim_overflow_is_typed_error() {
    let pair = Array::from_slice::<f32>(&[1.0, 2.0], &[2i32]).unwrap();
    let reps = [i32::MAX];

    let r = ops::shape::tile(&pair, &reps);
    assert!(
        matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))),
        "expected Err(ArithmeticOverflow) for overflowing tile out dim, got {r:?}"
    );

    let via_method = pair.tile(&reps);
    assert!(matches!(via_method, Err(mlxrs::Error::ArithmeticOverflow(_))));
}
/// Negative repetition counts are rejected with `OutOfRange`. This covers the
/// single-entry case and a negative leading entry that gets prepended as a
/// new axis.
#[test]
fn tile_negative_reps_is_typed_range_error() {
    let line = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();

    let r = ops::shape::tile(&line, &[-1]);
    assert!(
        matches!(r, Err(mlxrs::Error::OutOfRange(_))),
        "expected Err(OutOfRange) for negative reps, got {r:?}"
    );

    let r2 = ops::shape::tile(&line, &[-2, 2]);
    assert!(
        matches!(r2, Err(mlxrs::Error::OutOfRange(_))),
        "expected Err(OutOfRange) for leading negative reps, got {r2:?}"
    );
}
/// Zero repetitions are allowed and produce an empty (length-0) axis.
#[test]
fn tile_zero_reps_yields_empty_dim() {
    let line = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3i32]).unwrap();

    let emptied = ops::shape::tile(&line, &[0]).unwrap();

    assert_eq!(emptied.shape(), vec![0]);
}
/// Reps (2, 3) on a (2,) array (treated as (1, 2)) give (2, 6), i.e. the pair
/// [1, 2] repeated six times. This case has several non-unit reps and must
/// succeed.
#[test]
fn tile_multi_non_unit_reps_pass_intermediate_rank_guard() {
    let pair = Array::from_slice::<f32>(&[1.0, 2.0], &[2i32]).unwrap();

    let mut tiled = ops::shape::tile(&pair, &[2, 3]).unwrap();

    assert_eq!(tiled.shape(), vec![2, 6]);
    let elements = tiled.to_vec::<f32>().unwrap();
    assert_eq!(elements, [1.0f32, 2.0].repeat(6));
}
/// A negative dimension in the `as_strided` shape is rejected with
/// `OutOfRange` by the dims validator, and the payload echoes the offending
/// value.
#[test]
fn as_strided_rejects_negative_dim() {
    let flat = Array::from_slice::<f32>(&[0.0, 1.0, 2.0, 3.0], &[4i32]).unwrap();
    let shape: &[i32] = &[-1, 2];

    // SAFETY: this test asserts that the wrapper rejects the negative dim and
    // returns an error instead of building a view.
    let outcome = unsafe { ops::shape::as_strided(&flat, &shape, &[2i64, 1], 0) };
    let payload = match outcome {
        Err(mlxrs::Error::OutOfRange(p)) => p,
        other => panic!("negative dim must Err(OutOfRange), got {other:?}"),
    };

    let context = payload.context();
    assert!(
        context.contains("validate_dims"),
        "context names the validator: {}",
        context
    );
    assert_eq!(payload.requirement(), "must be non-negative");
    let value = payload.value();
    assert!(
        value.contains("-1"),
        "value names the offending dim: {}",
        value
    );
}