use super::*;
/// Pins the lowercase spelling of every `ResizeFilter` variant and checks that
/// the `Display` rendering matches that same spelling.
#[test]
fn resize_filter_as_str_all_variants() {
    let expected = [
        (ResizeFilter::Nearest, "nearest"),
        (ResizeFilter::Bilinear, "bilinear"),
        (ResizeFilter::Bicubic, "bicubic"),
        (ResizeFilter::Lanczos3, "lanczos3"),
    ];
    for (filter, name) in expected {
        // Display first (borrows), then as_str, so both must agree with `name`.
        assert_eq!(filter.to_string(), name);
        assert_eq!(filter.as_str(), name);
    }
}
/// Pins the lowercase spelling of every `ColorOrder` variant and checks that
/// the `Display` rendering matches that same spelling.
#[test]
fn color_order_as_str_all_variants() {
    let expected = [(ColorOrder::Rgb, "rgb"), (ColorOrder::Bgr, "bgr")];
    for (order, name) in expected {
        assert_eq!(order.to_string(), name);
        assert_eq!(order.as_str(), name);
    }
}
/// Pins the lowercase spelling of every `Layout` variant and checks that the
/// `Display` rendering matches that same spelling.
#[test]
fn layout_as_str_all_variants() {
    let expected = [
        (Layout::Hwc, "hwc"),
        (Layout::Chw, "chw"),
        (Layout::Bchw, "bchw"),
    ];
    for (layout, name) in expected {
        assert_eq!(layout.to_string(), name);
        assert_eq!(layout.as_str(), name);
    }
}
/// With a rank-1 target, `make_channel_broadcast` must hand back the channel
/// values as a plain `(3,)` array, preserving each value to within 1e-6.
#[test]
fn make_channel_broadcast_rank1_returns_unreshaped_channel_vector() {
    let mut a = make_channel_broadcast(&[0.485, 0.456, 0.406], 1, Dtype::F32)
        .expect("rank-1 channel broadcast must succeed");
    assert_eq!(
        a.shape(),
        vec![3],
        "ndim<=1 returns the unreshaped (3,) array"
    );
    let v: Vec<f32> = a.to_vec().expect("materialize (3,) channel vector");
    let expected = [0.485_f32, 0.456, 0.406];
    assert_eq!(v.len(), expected.len());
    // Compare with a tolerance: the values round-trip through f32 storage.
    for (got, want) in v.iter().zip(expected) {
        assert!((got - want).abs() < 1e-6);
    }
}
/// With a rank-2 target, the three channel values must come back reshaped to
/// `[1, 3]`.
#[test]
fn make_channel_broadcast_rank2_reshapes_to_leading_singleton() {
    let broadcast = make_channel_broadcast(&[1.0, 2.0, 3.0], 2, Dtype::F32)
        .expect("rank-2 channel broadcast must succeed");
    let shape = broadcast.shape();
    assert_eq!(shape, vec![1, 3], "ndim==2 reshapes to [1, 3]");
}
/// A requested rank of 17 must be refused with `Error::CapExceeded` naming the
/// `MAX_NDIM` cap (16) and reporting the offending value.
#[test]
fn make_channel_broadcast_rejects_ndim_over_max() {
    let payload = match make_channel_broadcast(&[0.0, 0.0, 0.0], 17, Dtype::F32)
        .expect_err("ndim > MAX_NDIM (16) must be rejected")
    {
        Error::CapExceeded(p) => p,
        other => panic!("expected CapExceeded(MAX_NDIM), got {other:?}"),
    };
    assert_eq!(payload.cap_name(), "MAX_NDIM");
    assert_eq!(payload.cap(), 16, "cap is MAX_NDIM = 16");
    assert_eq!(payload.observed(), 17, "offending ndim is 17");
}
/// `normalize` must reject a rank-0 (scalar) array with `Error::RankMismatch`
/// rather than reading a trailing dimension that does not exist.
#[test]
fn normalize_rejects_rank0_scalar_input() {
    let scalar = Array::from_slice(&[42.0_f32], &[0i32; 0]).expect("0-d scalar array");
    assert_eq!(scalar.ndim(), 0, "from_slice with an empty shape is rank-0");

    let mismatch = match normalize(&scalar, &[0.0; 3], &[1.0; 3])
        .expect_err("rank-0 scalar input must be rejected before the trailing-dim read")
    {
        Error::RankMismatch(p) => p,
        other => panic!("expected RankMismatch on rank-0 input, got {other:?}"),
    };
    assert_eq!(mismatch.actual(), 0, "observed rank is 0");
    let context = mismatch.context();
    assert!(
        context.contains("normalize"),
        "RankMismatch must name normalize; got: {}",
        context
    );
}
/// A `[4, 5, 3]` input with `patch_size = 2` passes the H check (4 % 2 == 0)
/// but fails on W (5 % 2 != 0). The error must come from the W-axis arm.
#[test]
fn patchify_w_not_divisible_errors_on_width_axis() {
    let input = Array::from_slice(&[0.0_f32; 4 * 5 * 3], &(4usize, 5, 3)).expect("[4,5,3] array");
    let divisibility = match patchify(&input, 2)
        .expect_err("W=5 not divisible by patch_size=2 must error")
    {
        Error::DivisibilityConstraint(p) => p,
        other => panic!("expected DivisibilityConstraint on the W axis, got {other:?}"),
    };
    assert_eq!(divisibility.name_dividend(), "W", "must be the W-axis arm, not H");
    assert_eq!(divisibility.dividend(), 5);
    assert_eq!(divisibility.divisor(), 2);
    let context = divisibility.context();
    assert!(
        context.contains("W by patch_size"),
        "context must name the W divisibility constraint; got: {}",
        context
    );
}
/// Feeds `load_image` a `.png`-named file containing non-PNG bytes. Decoding
/// must fail with `Error::Parse`, and its context must name `load_image`.
#[test]
fn load_image_corrupt_png_returns_parse_error() {
    let dir = std::env::temp_dir().join(format!("mlxrs-vlm-image-parse-{}", std::process::id()));
    std::fs::create_dir_all(&dir).expect("create temp dir");
    let path = dir.join("corrupt.png");
    std::fs::write(&path, b"this is definitely not a png file at all").expect("write garbage");

    // Capture the result and clean up *before* asserting on it. If the decode
    // unexpectedly succeeded, `expect_err` would panic, and with cleanup placed
    // after it the fixture file and directory would leak into the temp dir.
    let result = load_image(&path);
    // Best-effort cleanup: a leftover temp entry must not fail the test.
    let _ = std::fs::remove_file(&path);
    let _ = std::fs::remove_dir(&dir);

    let err = result.expect_err("a corrupt PNG must fail to decode");
    match err {
        Error::Parse(p) => assert!(
            p.context().contains("load_image"),
            "Parse error must name load_image; got context: {}",
            p.context()
        ),
        other => panic!("expected Error::Parse from the decode path, got {other:?}"),
    }
}
/// `rotate_buf` must report an overflowing `w * h * channels` element count as
/// a typed `Error::ArithmeticOverflow`, carrying every operand. On 64-bit
/// targets `u32::MAX * u32::MAX` alone still fits, so `channels = 2` is what
/// pushes the product past `usize::MAX`.
#[test]
fn rotate_buf_element_count_overflow_is_typed_arithmetic_error() {
    let overflow = match rotate_buf::<u8>(&[], u32::MAX, u32::MAX, 2, RotateKind::Rotate90)
        .expect_err("w * h * channels must overflow usize for u32::MAX dims and channels=2")
    {
        Error::ArithmeticOverflow(p) => p,
        other => panic!("expected ArithmeticOverflow from rotate_buf, got {other:?}"),
    };
    assert_eq!(
        overflow.context(),
        "rotate_buf: elements (w * h * channels)",
        "context must name the rotate_buf element-count expression"
    );
    assert_eq!(overflow.op_type(), "usize", "the overflowing result type is usize");

    let operands = overflow.operands();
    assert_eq!(operands.len(), 3, "w, h, and channels operands are all carried");
    let dim = u64::from(u32::MAX);
    let expected = [("w", dim), ("h", dim), ("channels", 2)];
    for (idx, want) in expected.into_iter().enumerate() {
        assert_eq!(operands[idx], want);
    }
}
/// The element-count overflow check must fire the same way for every
/// `RotateKind` exercised here. The context string and trailing `channels`
/// operand must not depend on the rotation.
#[test]
fn rotate_buf_element_count_overflow_independent_of_rotation_kind() {
    let kinds = [
        RotateKind::Rotate90,
        RotateKind::Rotate270,
        RotateKind::Rotate90FlipH,
        RotateKind::Rotate270FlipH,
    ];
    for kind in kinds {
        let overflow = match rotate_buf::<f32>(&[], u32::MAX, u32::MAX, 4, kind)
            .expect_err("w * h * 4 must overflow usize for u32::MAX dims regardless of rotation")
        {
            Error::ArithmeticOverflow(p) => p,
            other => panic!("expected ArithmeticOverflow for {kind:?}, got {other:?}"),
        };
        assert_eq!(overflow.context(), "rotate_buf: elements (w * h * channels)");
        assert_eq!(overflow.operands().last(), Some(&("channels", 4)));
    }
}
/// `rotate_buf_u8` must report an overflowing `w * h * channels` element count
/// as a typed `Error::ArithmeticOverflow`. The context string must name
/// `rotate_buf_u8`, and all three operands must be carried.
#[test]
fn rotate_buf_u8_element_count_overflow_is_typed_arithmetic_error() {
    let overflow = match rotate_buf_u8(&[], u32::MAX, u32::MAX, 4, RotateKind::Rotate90)
        .expect_err("w * h * 4 must overflow usize for u32::MAX dims")
    {
        Error::ArithmeticOverflow(p) => p,
        other => panic!("expected ArithmeticOverflow from rotate_buf_u8, got {other:?}"),
    };
    assert_eq!(
        overflow.context(),
        "rotate_buf_u8: elements (w * h * channels)",
        "context must name the rotate_buf_u8 element-count expression"
    );
    assert_eq!(overflow.op_type(), "usize");

    let operands = overflow.operands();
    assert_eq!(operands.len(), 3);
    let dim = u64::from(u32::MAX);
    let expected = [("w", dim), ("h", dim), ("channels", 4)];
    for (idx, want) in expected.into_iter().enumerate() {
        assert_eq!(operands[idx], want);
    }
}
/// Rotates a 2-wide, 1-tall, single-channel image by `Rotate90`. The output
/// must keep two elements, and the expected order `[7, 9]` pins the crate's
/// `Rotate90` pixel mapping for this shape.
#[test]
fn rotate_buf_rotate90_tiny_luma_permutes_into_swapped_extent() {
    let pixels = [7u8, 9];
    let rotated = rotate_buf::<u8>(&pixels, 2, 1, 1, RotateKind::Rotate90)
        .expect("2x1 rotate90 well under cap");
    assert_eq!(rotated.len(), 2, "1*2*channels(1) destination element count");
    assert_eq!(rotated, vec![7, 9], "Rotate90 of a single row stacks the pixels");
}