#![cfg(feature = "vlm")]
use mlxrs::{
Array, Error,
vlm::inputs::{PaddingSide, PrepareInputsOpts, prepare_inputs},
};
/// Text-only input: one three-token sequence and no media payloads.
///
/// Expects `[1, 3]` token ids and attention mask, and every optional payload
/// slot (image, audio, video) to be absent in the output.
#[test]
fn prepare_inputs_text_only_no_payloads() {
    let tokens = [10_u32, 20, 30];
    let batches: &[&[u32]] = &[&tokens];
    let opts = PrepareInputsOpts::default();

    let prepared = prepare_inputs(batches, None, None, None, &opts).unwrap();

    // Batch of one row, three tokens wide.
    assert_eq!(prepared.input_ids_ref().shape(), vec![1, 3]);
    assert_eq!(prepared.attention_mask_ref().shape(), vec![1, 3]);

    // Nothing was supplied, so nothing may appear.
    assert!(prepared.pixel_values_ref().is_none());
    assert!(prepared.input_features_ref().is_none());
    assert!(prepared.pixel_values_videos_ref().is_none());
}
/// Image-only dispatch: supplying just a pixel tensor must populate
/// `pixel_values` (shape preserved) and leave audio and video slots empty.
#[test]
fn prepare_inputs_image_only_dispatch() {
    let tokens = [10_u32, 99, 20];
    let batches: &[&[u32]] = &[&tokens];
    let image_shape = (1usize, 3usize, 4usize, 4usize);
    let pixels = Array::full::<f32>(&image_shape, 0.5).unwrap();
    let opts = PrepareInputsOpts::default();

    let prepared = prepare_inputs(batches, Some(pixels), None, None, &opts).unwrap();

    assert_eq!(prepared.input_ids_ref().shape(), vec![1, 3]);
    assert_eq!(prepared.attention_mask_ref().shape(), vec![1, 3]);

    // The image payload comes back with the shape it went in with.
    let pixel_values = prepared.pixel_values_ref().expect("pixel_values present");
    assert_eq!(pixel_values.shape(), vec![1, 3, 4, 4]);

    assert!(prepared.input_features_ref().is_none());
    assert!(prepared.pixel_values_videos_ref().is_none());
}
/// Audio-only dispatch: supplying just an audio feature tensor must populate
/// `input_features` (shape preserved) and leave image and video slots empty.
#[test]
fn prepare_inputs_audio_only_dispatch() {
    let tokens = [10_u32, 88, 20];
    let batches: &[&[u32]] = &[&tokens];
    let feature_shape = (1usize, 80usize, 100usize);
    let features = Array::full::<f32>(&feature_shape, 0.25).unwrap();
    let opts = PrepareInputsOpts::default();

    let prepared = prepare_inputs(batches, None, Some(features), None, &opts).unwrap();

    // The audio payload comes back with the shape it went in with.
    let input_features = prepared
        .input_features_ref()
        .expect("input_features present");
    assert_eq!(input_features.shape(), vec![1, 80, 100]);

    assert!(prepared.pixel_values_ref().is_none());
    assert!(prepared.pixel_values_videos_ref().is_none());
}
/// Combined dispatch: image and audio payloads supplied together must both
/// appear in the output, while the (unsupplied) video slot stays empty.
#[test]
fn prepare_inputs_combined_image_text_audio() {
    let tokens = [10_u32, 99, 88, 20];
    let batches: &[&[u32]] = &[&tokens];

    let image_shape = (1usize, 3usize, 4usize, 4usize);
    let audio_shape = (1usize, 80usize, 50usize);
    let pixels = Array::full::<f32>(&image_shape, 0.5).unwrap();
    let features = Array::full::<f32>(&audio_shape, 0.25).unwrap();
    let opts = PrepareInputsOpts::default();

    let prepared = prepare_inputs(batches, Some(pixels), Some(features), None, &opts).unwrap();

    assert_eq!(prepared.input_ids_ref().shape(), vec![1, 4]);
    assert!(prepared.pixel_values_ref().is_some());
    assert!(prepared.input_features_ref().is_some());
    assert!(prepared.pixel_values_videos_ref().is_none());
}
/// Video dispatch: supplying just a frame tensor must populate
/// `pixel_values_videos` (shape preserved) and leave image and audio slots
/// empty.
#[test]
fn prepare_inputs_video_dispatch() {
    let tokens = [10_u32, 77, 20];
    let batches: &[&[u32]] = &[&tokens];
    let video_shape = (8usize, 224usize, 224usize, 3usize);
    let frames = Array::full::<f32>(&video_shape, 0.5).unwrap();
    let opts = PrepareInputsOpts::default();

    let prepared = prepare_inputs(batches, None, None, Some(frames), &opts).unwrap();

    // The video payload comes back with the shape it went in with.
    let videos = prepared
        .pixel_values_videos_ref()
        .expect("pixel_values_videos present");
    assert_eq!(videos.shape(), vec![8, 224, 224, 3]);

    assert!(prepared.pixel_values_ref().is_none());
    assert!(prepared.input_features_ref().is_none());
}
/// Left padding: the shorter row is padded at the front with the pad token
/// (`0`) and the generated mask is `false` exactly over those pad positions.
#[test]
fn prepare_inputs_padding_side_left_default() {
    let short_row = [10_u32, 20];
    let long_row = [30_u32, 40, 50, 60];
    let batches: &[&[u32]] = &[&short_row, &long_row];

    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(true)
        .with_padding_side(PaddingSide::Left);

    let mut prepared = prepare_inputs(batches, None, None, None, &opts).unwrap();
    assert_eq!(prepared.input_ids_ref().shape(), vec![2, 4]);

    // Row-major flattening: first four values are row 0, next four are row 1.
    let ids = prepared.input_ids_mut().to_vec::<i32>().unwrap();
    assert_eq!(ids[..4], [0, 0, 10, 20]);
    assert_eq!(ids[4..8], [30, 40, 50, 60]);

    let mask = prepared.attention_mask_mut().to_vec::<bool>().unwrap();
    assert_eq!(mask[..4], [false, false, true, true]);
    assert_eq!(mask[4..8], [true, true, true, true]);
}
/// Right padding: the shorter row is padded at the back with the pad token
/// (`7`) and the generated mask is `false` exactly over those pad positions.
///
/// Mirrors the left-padding test: the batch shape and the mask of the
/// full-length row are checked too, so a regression that only affects the
/// unpadded row under right padding cannot slip through.
#[test]
fn prepare_inputs_padding_side_right_vs_left() {
    let short_row = [10_u32, 20];
    let long_row = [30_u32, 40, 50, 60];
    let batches: &[&[u32]] = &[&short_row, &long_row];

    // A non-zero pad id makes pad positions distinguishable from a
    // zero-initialised buffer.
    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(7)
        .with_padding(true)
        .with_padding_side(PaddingSide::Right);

    let mut out = prepare_inputs(batches, None, None, None, &opts).unwrap();
    assert_eq!(out.input_ids_ref().shape(), vec![2, 4]);

    // Row-major flattening: first four values are row 0, next four are row 1.
    let ids = out.input_ids_mut().to_vec::<i32>().unwrap();
    assert_eq!(&ids[0..4], &[10, 20, 7, 7]);
    assert_eq!(&ids[4..8], &[30, 40, 50, 60]);

    let mask = out.attention_mask_mut().to_vec::<bool>().unwrap();
    assert_eq!(&mask[0..4], &[true, true, false, false]);
    // The longest row needs no padding, so every position is attended.
    assert_eq!(&mask[4..8], &[true, true, true, true]);
}
/// With padding disabled, rows of different lengths must be rejected, and the
/// error text must mention `padding=false`.
#[test]
fn prepare_inputs_padding_disabled_requires_uniform_length() {
    let two_tokens = [10_u32, 20];
    let three_tokens = [30_u32, 40, 50];
    let batches: &[&[u32]] = &[&two_tokens, &three_tokens];

    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(false)
        .with_padding_side(PaddingSide::Left);

    let message = prepare_inputs(batches, None, None, None, &opts)
        .unwrap_err()
        .to_string();
    assert!(
        message.contains("padding=false"),
        "unexpected error: {message}"
    );
}
/// An empty batch list must be rejected, and the error text must mention
/// `empty`.
#[test]
fn prepare_inputs_empty_batches_errors() {
    let no_batches: &[&[u32]] = &[];
    let opts = PrepareInputsOpts::default();

    let message = prepare_inputs(no_batches, None, None, None, &opts)
        .unwrap_err()
        .to_string();
    assert!(message.contains("empty"), "unexpected error: {message}");
}
/// Equal-length rows under default options: ids are passed through unchanged
/// in row-major order and every mask position is `true`.
#[test]
fn prepare_inputs_uniform_no_padding_needed() {
    let first = [10_u32, 20, 30];
    let second = [40_u32, 50, 60];
    let batches: &[&[u32]] = &[&first, &second];
    let opts = PrepareInputsOpts::default();

    let mut prepared = prepare_inputs(batches, None, None, None, &opts).unwrap();
    assert_eq!(prepared.input_ids_ref().shape(), vec![2, 3]);

    let ids = prepared.input_ids_mut().to_vec::<i32>().unwrap();
    assert_eq!(ids, vec![10, 20, 30, 40, 50, 60]);

    // No padding was needed, so no position may be masked out.
    let mask = prepared.attention_mask_mut().to_vec::<bool>().unwrap();
    assert!(!mask.contains(&false));
}
/// A caller-supplied mask that already matches the batch dimensions is used
/// as given, with the first row's trailing `false` entries kept in place.
#[test]
fn prepare_inputs_caller_supplied_attention_mask_overrides_default() {
    // Both rows are already four tokens long, so no padding is added.
    let pre_padded = [10_u32, 20, 0, 0];
    let full = [30_u32, 40, 50, 60];
    let batches: &[&[u32]] = &[&pre_padded, &full];

    let supplied_mask = vec![vec![true, true, false, false], vec![true; 4]];
    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(true)
        .with_padding_side(PaddingSide::Left)
        .with_attention_mask(supplied_mask);

    let mut prepared = prepare_inputs(batches, None, None, None, &opts).unwrap();
    assert_eq!(prepared.attention_mask_ref().shape(), vec![2, 4]);

    let mask = prepared.attention_mask_mut().to_vec::<bool>().unwrap();
    assert_eq!(mask[..4], [true, true, false, false]);
    assert_eq!(mask[4..8], [true, true, true, true]);
}
/// Under left padding, a caller-supplied mask row shorter than the batch
/// width is extended at the front with `false`, and its own values (including
/// a caller-provided `false`) are kept in order after the padding.
#[test]
fn prepare_inputs_caller_mask_left_pads_with_false() {
    let short_row = [10_u32, 20];
    let long_row = [30_u32, 40, 50, 60];
    let batches: &[&[u32]] = &[&short_row, &long_row];

    let supplied_mask = vec![vec![true, false], vec![true; 4]];
    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(true)
        .with_padding_side(PaddingSide::Left)
        .with_attention_mask(supplied_mask);

    let mut prepared = prepare_inputs(batches, None, None, None, &opts).unwrap();
    let mask = prepared.attention_mask_mut().to_vec::<bool>().unwrap();

    // Two pad slots in front, then the caller's `[true, false]`.
    assert_eq!(mask[..4], [false, false, true, false]);
    assert_eq!(mask[4..8], [true, true, true, true]);
}
/// Under right padding, a caller-supplied mask row shorter than the batch
/// width is extended at the back with `false`, and its own values (including
/// a caller-provided `false`) are kept in order before the padding.
#[test]
fn prepare_inputs_caller_mask_right_pads_with_false() {
    let short_row = [10_u32, 20];
    let long_row = [30_u32, 40, 50, 60];
    let batches: &[&[u32]] = &[&short_row, &long_row];

    let supplied_mask = vec![vec![true, false], vec![true; 4]];
    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(true)
        .with_padding_side(PaddingSide::Right)
        .with_attention_mask(supplied_mask);

    let mut prepared = prepare_inputs(batches, None, None, None, &opts).unwrap();
    let mask = prepared.attention_mask_mut().to_vec::<bool>().unwrap();

    // The caller's `[true, false]`, then two pad slots behind it.
    assert_eq!(mask[..4], [true, false, false, false]);
    assert_eq!(mask[4..8], [true, true, true, true]);
}
/// A caller-supplied mask with fewer rows than the batch must fail with
/// `Error::LengthMismatch`, reporting expected = 2 (batch rows) and
/// actual = 1 (mask rows) under an "attention_mask outer" context.
#[test]
fn prepare_inputs_caller_mask_dimension_mismatch_errors() {
    let first = [10_u32, 20];
    let second = [30_u32, 40];
    let batches: &[&[u32]] = &[&first, &second];

    // Only one mask row for a two-row batch.
    let one_row_mask = vec![vec![true, true]];
    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(true)
        .with_padding_side(PaddingSide::Left)
        .with_attention_mask(one_row_mask);

    let err = prepare_inputs(batches, None, None, None, &opts).unwrap_err();
    let Error::LengthMismatch(payload) = &err else {
        panic!("expected LengthMismatch, got: {err:?}");
    };

    assert!(
        payload.context().contains("attention_mask outer"),
        "expected outer-length mismatch context, got: {}",
        payload.context()
    );
    assert_eq!(payload.expected(), 2);
    assert_eq!(payload.actual(), 1);
}
/// A caller-supplied mask row whose length differs from its token row must
/// fail with `Error::LengthMismatch`, reporting expected = 3 (token row
/// length) and actual = 2 (mask row length) under an "attention_mask[i]"
/// context.
#[test]
fn prepare_inputs_caller_mask_inner_dimension_mismatch_errors() {
    let two_tokens = [10_u32, 20];
    let three_tokens = [30_u32, 40, 50];
    let batches: &[&[u32]] = &[&two_tokens, &three_tokens];

    // Row count matches, but the second mask row is one entry short.
    let short_inner_mask = vec![vec![true, true], vec![true, true]];
    let opts = PrepareInputsOpts::new()
        .with_pad_token_id(0)
        .with_padding(true)
        .with_padding_side(PaddingSide::Left)
        .with_attention_mask(short_inner_mask);

    let err = prepare_inputs(batches, None, None, None, &opts).unwrap_err();
    let Error::LengthMismatch(payload) = &err else {
        panic!("expected LengthMismatch, got: {err:?}");
    };

    assert!(
        payload.context().contains("attention_mask[i]"),
        "expected per-row inner-length mismatch context, got: {}",
        payload.context()
    );
    assert_eq!(payload.expected(), 3);
    assert_eq!(payload.actual(), 2);
}
/// `load_video` on two in-memory 4x4 RGB frames, with resizing and
/// normalisation disabled and rescaling enabled, must return an array of
/// shape `[2, 4, 4, 3]`.
#[test]
fn load_video_wraps_vlm_video() {
    use mlxrs::vlm::{
        image::{ColorOrder, ImageProcessorConfig, ResizeFilter},
        inputs::load_video,
    };

    /// Builds a 4x4 RGB frame: red scales with x, green with y, blue is 128.
    fn gradient_frame() -> ::image::DynamicImage {
        let pixels =
            ::image::ImageBuffer::<::image::Rgb<u8>, Vec<u8>>::from_fn(4, 4, |col, row| {
                ::image::Rgb([col as u8 * 50, row as u8 * 50, 128])
            });
        ::image::DynamicImage::ImageRgb8(pixels)
    }

    let frames: Vec<::image::DynamicImage> =
        std::iter::repeat_with(gradient_frame).take(2).collect();

    let processor = ImageProcessorConfig::new()
        .with_size((4, 4))
        .with_mean([0.0, 0.0, 0.0])
        .with_std([1.0, 1.0, 1.0])
        .with_rescale_factor(1.0 / 255.0)
        .with_do_resize(false)
        .with_do_rescale(true)
        .with_do_normalize(false)
        .with_resample(ResizeFilter::Bilinear)
        .with_color_order(ColorOrder::Rgb);

    let video = load_video(&frames, &processor).unwrap();
    assert_eq!(video.shape(), vec![2, 4, 4, 3]);
}
#[cfg(feature = "audio")]
mod audio_glue {
    use mlxrs::{
        Array,
        vlm::inputs::{load_audio_vlm, normalize_audio_features, read_audio},
    };
    use std::io::Write;

    /// Writes `samples` to `path` as a mono, 16-bit PCM WAV file at `sr` Hz.
    ///
    /// The whole file is assembled in memory and written with a single call,
    /// rather than one tiny unbuffered write per header field and per sample.
    ///
    /// # Panics
    ///
    /// Panics if the sample data does not fit the 32-bit RIFF size fields or
    /// if the file cannot be created or written.
    fn write_wav(path: &std::path::Path, samples: &[i16], sr: u32) {
        const PCM_FORMAT: u16 = 1;
        const CHANNELS: u16 = 1;
        const BYTES_PER_SAMPLE: u16 = 2;
        const BITS_PER_SAMPLE: u16 = 8 * BYTES_PER_SAMPLE;
        const BLOCK_ALIGN: u16 = CHANNELS * BYTES_PER_SAMPLE;
        const FMT_CHUNK_LEN: u32 = 16;
        // Bytes between the RIFF size field and the sample data:
        // "WAVE" + "fmt " header (8) + fmt body (16) + "data" header (8).
        const HEADER_AFTER_RIFF_SIZE: u32 = 36;

        let data_len = samples.len() * usize::from(BYTES_PER_SAMPLE);
        // Checked conversions instead of a silently truncating `as` cast.
        let data_bytes =
            u32::try_from(data_len).expect("WAV data chunk length must fit in u32");
        let chunk_size = HEADER_AFTER_RIFF_SIZE
            .checked_add(data_bytes)
            .expect("RIFF chunk size must fit in u32");
        let byte_rate = sr * u32::from(BLOCK_ALIGN);

        // 44 = 8-byte RIFF header + the 36 bytes counted above.
        let mut bytes = Vec::with_capacity(44 + data_len);
        bytes.extend_from_slice(b"RIFF");
        bytes.extend_from_slice(&chunk_size.to_le_bytes());
        bytes.extend_from_slice(b"WAVE");
        bytes.extend_from_slice(b"fmt ");
        bytes.extend_from_slice(&FMT_CHUNK_LEN.to_le_bytes());
        bytes.extend_from_slice(&PCM_FORMAT.to_le_bytes());
        bytes.extend_from_slice(&CHANNELS.to_le_bytes());
        bytes.extend_from_slice(&sr.to_le_bytes());
        bytes.extend_from_slice(&byte_rate.to_le_bytes());
        bytes.extend_from_slice(&BLOCK_ALIGN.to_le_bytes());
        bytes.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());
        bytes.extend_from_slice(b"data");
        bytes.extend_from_slice(&data_bytes.to_le_bytes());
        bytes.extend(samples.iter().flat_map(|sample| sample.to_le_bytes()));

        let mut file = std::fs::File::create(path).unwrap();
        file.write_all(&bytes).unwrap();
    }

    /// 1600 samples of a sine wave with amplitude 8000, as 16-bit integers.
    fn sine_samples() -> Vec<i16> {
        (0..1600)
            .map(|i| ((i as f32 * 0.1).sin() * 8000.0) as i16)
            .collect()
    }

    /// A temporary WAV fixture that is deleted when it goes out of scope.
    ///
    /// Cleanup lives in `Drop` so the file is removed even when an assertion
    /// in the owning test panics, instead of only on the success path.
    struct TempWav {
        path: std::path::PathBuf,
    }

    impl TempWav {
        /// Creates `mlxrs_v4_<tag>_<pid>.wav` in the system temp directory
        /// containing the sine fixture sampled at 16 kHz.
        fn sine_16k(tag: &str) -> Self {
            let path = std::env::temp_dir()
                .join(format!("mlxrs_v4_{tag}_{}.wav", std::process::id()));
            // Build the guard before writing so a partially written file is
            // also cleaned up if `write_wav` panics.
            let fixture = Self { path };
            write_wav(&fixture.path, &sine_samples(), 16000);
            fixture
        }
    }

    impl Drop for TempWav {
        fn drop(&mut self) {
            // Best-effort cleanup; a failed removal must not mask the test result.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// `read_audio` must report the file's sample rate and sample count, and
    /// return exactly what `mlxrs::audio::io::load_audio` returns.
    #[test]
    fn read_audio_wraps_load_audio() {
        let wav = TempWav::sine_16k("read_audio");

        let (samples, sr) = read_audio(&wav.path).unwrap();
        assert_eq!(sr, 16000);
        assert_eq!(samples.len(), 1600);
        assert!(samples[0].is_finite());

        let (samples2, sr2) = mlxrs::audio::io::load_audio(&wav.path).unwrap();
        assert_eq!(sr, sr2);
        assert_eq!(samples, samples2);
    }

    /// Requesting the file's own sample rate must leave the sample count
    /// unchanged.
    #[test]
    fn load_audio_vlm_no_resample_when_sr_matches() {
        let wav = TempWav::sine_16k("load_audio_same");

        let out = load_audio_vlm(&wav.path, 16000).unwrap();
        assert_eq!(out.len(), 1600);
    }

    /// Requesting half the file's sample rate must roughly halve the sample
    /// count (a tolerance of two samples is allowed).
    #[test]
    fn load_audio_vlm_resamples_when_sr_differs() {
        let wav = TempWav::sine_16k("load_audio_resamp");

        let out = load_audio_vlm(&wav.path, 8000).unwrap();
        assert!(
            out.len().abs_diff(800) <= 2,
            "expected ~800 samples after downsampling, got {}",
            out.len()
        );
    }

    /// `normalize_audio_features` on `[1..=6]` reshaped to `[2, 3]` must equal
    /// `(x - mean) / (std + 1e-6)`, with mean and population standard
    /// deviation (dividing by N) taken over all six elements.
    #[test]
    fn normalize_audio_features_matches_python_reference() {
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let features = Array::from_slice::<f32>(&input, &(2usize, 3usize)).unwrap();

        let mut normalized = normalize_audio_features(&features).unwrap();
        assert_eq!(normalized.shape(), vec![2, 3]);
        let vals = normalized.to_vec::<f32>().unwrap();
        // Guard the element-wise comparison below against a short result.
        assert_eq!(vals.len(), input.len());

        let mean = 3.5_f32;
        let sq_diff: f32 = input.iter().map(|x| (x - mean).powi(2)).sum();
        let std = (sq_diff / 6.0).sqrt();
        let denom = std + 1e-6_f32;

        for (i, (got, x)) in vals.iter().zip(&input).enumerate() {
            let expected = (x - mean) / denom;
            assert!(
                (got - expected).abs() < 1e-5,
                "vals[{i}]={} expected≈{expected}",
                got
            );
        }
    }
}