#![cfg(feature = "audio")]
use std::{
fs::{self, File},
io::Write,
path::PathBuf,
process,
};
use mlxrs::audio::io::{load_audio, resample_linear, save_wav};
/// Bytes of `fixtures/audio_tone.mp3`, embedded into the test binary at compile time.
const FIXTURE_MP3: &[u8] = include_bytes!("fixtures/audio_tone.mp3");
/// Bytes of `fixtures/audio_tone.flac`, embedded into the test binary at compile time.
const FIXTURE_FLAC: &[u8] = include_bytes!("fixtures/audio_tone.flac");
/// Bytes of `fixtures/audio_tone.ogg`, embedded into the test binary at compile time.
const FIXTURE_OGG_VORBIS: &[u8] = include_bytes!("fixtures/audio_tone.ogg");
/// Sample rate (Hz) every fixture is expected to decode at; checked by `assert_decoded_tone`.
const FIXTURE_SR: u32 = 8000;
/// Nominal decoded length of the fixtures, in samples. `assert_decoded_tone` accepts anything
/// from half to 2.5x this value; the exact-count FLAC test requires exactly this many samples.
const FIXTURE_NOMINAL_SAMPLES: usize = 2000;
/// Returns a per-process `.wav` path inside the system temp directory.
///
/// The file name is `mlxrs_audio_io_<pid>_<name>.wav`. Nothing is created on disk.
fn temp_wav(name: &str) -> PathBuf {
    let file_name = format!("mlxrs_audio_io_{}_{}.wav", process::id(), name);
    std::env::temp_dir().join(file_name)
}
/// Returns a per-process path with an arbitrary extension inside the system temp directory.
///
/// The file name is `mlxrs_audio_io_<pid>_<name>.<ext>`. Nothing is created on disk.
fn temp_path(name: &str, ext: &str) -> PathBuf {
    let file_name = format!("mlxrs_audio_io_{}_{}.{}", process::id(), name, ext);
    std::env::temp_dir().join(file_name)
}
/// Creates (or truncates) `path` and writes `bytes` to it, panicking on any I/O error.
fn write_fixture(path: &std::path::Path, bytes: &[u8]) {
    File::create(path)
        .and_then(|mut file| {
            file.write_all(bytes)?;
            file.flush()
        })
        .unwrap();
}
/// Sanity-checks a decoded fixture.
///
/// Panics (prefixing each message with `fmt`) unless, in this order:
/// 1. `sr` equals `FIXTURE_SR`;
/// 2. the sample count lies within `[FIXTURE_NOMINAL_SAMPLES / 2, FIXTURE_NOMINAL_SAMPLES * 5 / 2]`;
/// 3. every sample is finite;
/// 4. at least one sample has magnitude above `1e-4`;
/// 5. every sample lies within `[-1, 1]`.
fn assert_decoded_tone(samples: &[f32], sr: u32, fmt: &str) {
    assert_eq!(sr, FIXTURE_SR, "{fmt}: sample rate mismatch");

    let plausible_len = (FIXTURE_NOMINAL_SAMPLES / 2)..=(FIXTURE_NOMINAL_SAMPLES * 5 / 2);
    assert!(
        plausible_len.contains(&samples.len()),
        "{fmt}: decoded {} samples, expected ~{FIXTURE_NOMINAL_SAMPLES}",
        samples.len()
    );

    let all_finite = samples.iter().all(|s| s.is_finite());
    assert!(all_finite, "{fmt}: decoded a non-finite sample");

    let has_signal = samples.iter().map(|s| s.abs()).any(|mag| mag > 1e-4);
    assert!(
        has_signal,
        "{fmt}: decoded an all-(near-)zero buffer (codec likely not wired)"
    );

    let in_unit_range = samples.iter().all(|s| s.abs() <= 1.0);
    assert!(in_unit_range, "{fmt}: decoded a sample outside [-1, 1]");
}
/// Writes `samples` to `path` as a 16-bit PCM RIFF/WAVE file with a 44-byte header.
///
/// `samples` is written verbatim as little-endian `i16`s; for multi-channel data the caller
/// supplies already-interleaved samples and `channels` only affects the header fields.
///
/// # Panics
///
/// Panics if any header field would overflow its on-disk width (block align: `u16`; data size,
/// RIFF size and byte rate: `u32`), or on any I/O error. Checked arithmetic is used so that an
/// oversized input cannot silently produce a wrapped, inconsistent header in release builds.
fn write_pcm16_wav(path: &std::path::Path, samples: &[i16], sample_rate: u32, channels: u16) {
    const HEADER_LEN: usize = 44;
    const BITS_PER_SAMPLE: u16 = 16;
    const BYTES_PER_SAMPLE: u16 = BITS_PER_SAMPLE / 8;

    let block_align: u16 = channels
        .checked_mul(BYTES_PER_SAMPLE)
        .expect("channel count overflows the 16-bit block-align field");
    let payload_len: usize = samples
        .len()
        .checked_mul(usize::from(BYTES_PER_SAMPLE))
        .filter(|&n| n <= u32::MAX as usize)
        .expect("PCM payload does not fit the 32-bit RIFF data-size field");
    // Lossless: `payload_len` was just bounded by `u32::MAX`.
    let data_size = payload_len as u32;
    let file_size_minus_8: u32 = 36u32
        .checked_add(data_size)
        .expect("RIFF chunk size overflows u32");
    let byte_rate: u32 = sample_rate
        .checked_mul(u32::from(block_align))
        .expect("byte rate overflows u32");

    // Assemble the whole file in memory so it reaches the OS in a single write instead of one
    // unbuffered syscall per sample.
    let mut wav: Vec<u8> = Vec::with_capacity(HEADER_LEN + payload_len);
    wav.extend_from_slice(b"RIFF");
    wav.extend_from_slice(&file_size_minus_8.to_le_bytes());
    wav.extend_from_slice(b"WAVE");
    wav.extend_from_slice(b"fmt ");
    wav.extend_from_slice(&16u32.to_le_bytes()); // size of the fmt chunk body
    wav.extend_from_slice(&1u16.to_le_bytes()); // format tag 1 = integer PCM
    wav.extend_from_slice(&channels.to_le_bytes());
    wav.extend_from_slice(&sample_rate.to_le_bytes());
    wav.extend_from_slice(&byte_rate.to_le_bytes());
    wav.extend_from_slice(&block_align.to_le_bytes());
    wav.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());
    wav.extend_from_slice(b"data");
    wav.extend_from_slice(&data_size.to_le_bytes());
    debug_assert_eq!(wav.len(), HEADER_LEN);
    for s in samples {
        wav.extend_from_slice(&s.to_le_bytes());
    }

    let mut f = File::create(path).unwrap();
    f.write_all(&wav).unwrap();
    f.flush().unwrap();
}
/// Saves a 32-point ramp over `[-1, 1)` at 16 kHz and checks that loading it back returns the
/// same rate, the same length, and samples within `1/16384` of the originals.
#[test]
fn wav_round_trip_preserves_samples_within_quantization() {
    let path = temp_wav("round_trip");

    let mut samples: Vec<f32> = Vec::with_capacity(32);
    for i in 0..32 {
        samples.push(((i as f32) / 16.0 - 1.0).clamp(-1.0, 1.0));
    }

    save_wav(&path, &samples, 16_000).unwrap();
    let (got, sr) = load_audio(&path).unwrap();

    assert_eq!(sr, 16_000);
    assert_eq!(got.len(), samples.len());
    got.iter().zip(samples.iter()).for_each(|(g, w)| {
        assert!(
            (g - w).abs() <= 1.0 / 16_384.0,
            "round-trip diff too large: got={g} want={w}"
        );
    });

    let _ = fs::remove_file(&path);
}
/// A file saved at 44.1 kHz must report 44.1 kHz when loaded back.
#[test]
fn wav_round_trip_preserves_sample_rate_44100() {
    let path = temp_wav("sr44100");
    let silence = [0.0_f32; 8];
    save_wav(&path, &silence, 44_100).unwrap();

    let (_, reported_rate) = load_audio(&path).unwrap();
    assert_eq!(reported_rate, 44_100);

    let _ = fs::remove_file(&path);
}
/// Samples outside `[-1, 1]` are expected to be clipped on save rather than wrapped:
/// `+2` reads back as roughly `+1`, `-3` as roughly `-1`, and `0` stays roughly `0`.
#[test]
fn save_clips_out_of_range_samples() {
    let path = temp_wav("clip");
    save_wav(&path, &[2.0_f32, -3.0, 0.0], 8_000).unwrap();

    let (decoded, _) = load_audio(&path).unwrap();
    assert_eq!(decoded.len(), 3);

    let (high, low, zero) = (decoded[0], decoded[1], decoded[2]);
    assert!(high > 0.999, "+2 should clip to ~+1, got {}", high);
    assert!(low < -0.999, "-3 should clip to ~-1, got {}", low);
    assert!(
        zero.abs() < 1.0 / 16_384.0,
        "0.0 should stay ~0, got {}",
        zero
    );

    let _ = fs::remove_file(&path);
}
/// A NaN sample must make `save_wav` fail with `Error::LayerKeyed` and leave the bytes already
/// at the destination path untouched.
#[test]
fn save_rejects_nan_before_touching_destination() {
    const MARKER: &[u8; 9] = b"PRESERVED";
    let path = temp_wav("nan");
    fs::write(&path, MARKER).unwrap();

    let outcome = save_wav(&path, &[0.0_f32, f32::NAN, 0.0], 8_000);
    assert!(matches!(outcome, Err(mlxrs::Error::LayerKeyed(_))));

    let on_disk = fs::read(&path).unwrap();
    assert_eq!(on_disk, MARKER);

    let _ = fs::remove_file(&path);
}
/// A sample rate of zero must make `save_wav` fail with `Error::InvariantViolation` and leave
/// the bytes already at the destination path untouched.
#[test]
fn save_rejects_zero_sample_rate() {
    const MARKER: &[u8; 9] = b"PRESERVED";
    let path = temp_wav("zero_sr");
    fs::write(&path, MARKER).unwrap();

    let outcome = save_wav(&path, &[0.0_f32, 0.5, -0.5], 0);
    assert!(matches!(outcome, Err(mlxrs::Error::InvariantViolation(_))));

    let on_disk = fs::read(&path).unwrap();
    assert_eq!(on_disk, MARKER);

    let _ = fs::remove_file(&path);
}
/// A hand-written two-channel PCM16 WAV must be rejected by `load_audio` with
/// `Error::OutOfRange`.
#[test]
fn load_wav_rejects_multichannel() {
    let path = temp_wav("stereo");
    // Four stereo frames, left/right interleaved.
    let frames: [i16; 8] = [100, -100, 200, -200, 300, -300, 400, -400];
    write_pcm16_wav(&path, &frames, 16_000, 2);

    let outcome = load_audio(&path);
    assert!(matches!(outcome, Err(mlxrs::Error::OutOfRange(_))));

    let _ = fs::remove_file(&path);
}
/// Resampling with identical source and target rates returns the input unchanged.
#[test]
fn resample_passthrough_at_equal_rates() {
    let input: Vec<f32> = [0.1_f32, 0.2, 0.3, 0.4, 0.5].to_vec();
    let output = resample_linear(&input, 16_000, 16_000).unwrap();
    assert_eq!(output, input);
}
/// Doubling the rate (8 kHz to 16 kHz) doubles the sample count, and the first interpolated
/// point sits halfway between its two neighbours.
#[test]
fn resample_upsample_doubles_length() {
    let upsampled = resample_linear(&[0.0_f32, 1.0, 0.0, 1.0], 8_000, 16_000).unwrap();
    assert_eq!(upsampled.len(), 8);

    let (first, second) = (upsampled[0], upsampled[1]);
    assert!((first - 0.0).abs() < 1e-6, "got[0]={}", first);
    assert!((second - 0.5).abs() < 1e-6, "got[1]={}", second);
}
/// Halving the rate (16 kHz to 8 kHz) of an 8-sample input yields 4 samples.
#[test]
fn resample_downsample_halves_length() {
    let triangle = [0.0_f32, 0.25, 0.5, 0.75, 1.0, 0.75, 0.5, 0.25];
    let halved = resample_linear(&triangle, 16_000, 8_000).unwrap();
    assert_eq!(halved.len(), 4);
}
/// A source rate of zero is rejected with `Error::InvariantViolation`.
#[test]
fn resample_rejects_zero_from_rate() {
    let input = [0.0_f32, 1.0];
    let outcome = resample_linear(&input, 0, 16_000);
    assert!(matches!(outcome, Err(mlxrs::Error::InvariantViolation(_))));
}
/// A target rate of zero is rejected with `Error::InvariantViolation`.
#[test]
fn resample_rejects_zero_to_rate() {
    let input = [0.0_f32, 1.0];
    let outcome = resample_linear(&input, 16_000, 0);
    assert!(matches!(outcome, Err(mlxrs::Error::InvariantViolation(_))));
}
/// Resampling an empty slice succeeds and yields an empty result.
#[test]
fn resample_empty_input_returns_empty() {
    let nothing: &[f32] = &[];
    let resampled = resample_linear(nothing, 8_000, 16_000).unwrap();
    assert!(resampled.is_empty());
}
/// A rate ratio of `1 : u32::MAX` would produce an enormous output; the resampler must refuse
/// with `Error::CapExceeded`.
#[test]
fn resample_rejects_oversized_output_cap() {
    let input = [0.5_f32, -0.5];
    let outcome = resample_linear(&input, 1, u32::MAX);
    assert!(matches!(outcome, Err(mlxrs::Error::CapExceeded(_))));
}
/// Loading a path that does not exist fails with `Error::FileIo`.
#[test]
fn load_wav_missing_file_returns_backend_error() {
    let path = temp_wav("missing");
    // Best-effort: make sure a leftover from an earlier run is not in the way.
    let _ = fs::remove_file(&path);

    let outcome = load_audio(&path);
    assert!(matches!(outcome, Err(mlxrs::Error::FileIo(_))));
}
/// After a successful save, the destination directory must contain no sibling whose name
/// starts with the destination file name and ends in `.tmp`, and the saved file must load.
#[test]
fn save_wav_atomic_no_tempfile_remains_after_successful_save() {
    let path = temp_wav("no_temp_remains");
    let samples = [0.1_f32, -0.1, 0.2, -0.2, 0.3];
    save_wav(&path, &samples, 8_000).unwrap();

    let final_name = path.file_name().unwrap().to_string_lossy().into_owned();
    let parent = path.parent().unwrap();
    let stray_tempfiles: Vec<PathBuf> = fs::read_dir(parent)
        .unwrap()
        .flatten()
        .filter(|entry| {
            let name = entry.file_name().to_string_lossy().into_owned();
            name.starts_with(&final_name) && name.ends_with(".tmp")
        })
        .map(|entry| entry.path())
        .collect();
    assert!(
        stray_tempfiles.is_empty(),
        "found stray tempfile(s) after successful save: {stray_tempfiles:?}"
    );

    let (reloaded, _) = load_audio(&path).unwrap();
    assert_eq!(reloaded.len(), samples.len());

    let _ = fs::remove_file(&path);
}
/// Saving over an existing non-WAV file replaces it: the result loads with the expected rate
/// and length, and its first four bytes are the `RIFF` magic.
#[test]
fn save_wav_atomically_replaces_existing_file() {
    let path = temp_wav("replaces_existing");
    fs::write(&path, b"OLD_MARKER_DATA_THAT_IS_NOT_A_WAV").unwrap();

    let samples = [0.0_f32, 0.5, -0.5, 0.25];
    save_wav(&path, &samples, 16_000).unwrap();

    let (reloaded, rate) = load_audio(&path).unwrap();
    assert_eq!(rate, 16_000);
    assert_eq!(reloaded.len(), samples.len());

    let on_disk = fs::read(&path).unwrap();
    assert_eq!(&on_disk[..4], b"RIFF", "destination is not a fresh WAV file");

    let _ = fs::remove_file(&path);
}
/// Writes a mono PCM16 WAV whose header declares 32 samples while only 8 follow, and checks
/// that `load_audio` rejects it with `Error::Parse` or `Error::LengthMismatch`.
#[test]
fn load_wav_rejects_truncated_wav() {
    const BYTES_PER_SAMPLE: u16 = 2;
    const CHANNELS: u16 = 1;
    const SAMPLE_RATE: u32 = 16_000;
    const DECLARED_SAMPLES: u32 = 32;
    const ACTUAL_SAMPLES: i16 = 8;

    let path = temp_wav("truncated");
    // The header lies: it advertises more payload than the file actually carries.
    let declared_bytes: u32 = DECLARED_SAMPLES * u32::from(BYTES_PER_SAMPLE);
    let byte_rate: u32 = SAMPLE_RATE * u32::from(CHANNELS) * u32::from(BYTES_PER_SAMPLE);

    let mut wav: Vec<u8> = Vec::with_capacity(44 + 2 * ACTUAL_SAMPLES as usize);
    wav.extend_from_slice(b"RIFF");
    wav.extend_from_slice(&(36u32 + declared_bytes).to_le_bytes());
    wav.extend_from_slice(b"WAVE");
    wav.extend_from_slice(b"fmt ");
    wav.extend_from_slice(&16u32.to_le_bytes());
    wav.extend_from_slice(&1u16.to_le_bytes());
    wav.extend_from_slice(&CHANNELS.to_le_bytes());
    wav.extend_from_slice(&SAMPLE_RATE.to_le_bytes());
    wav.extend_from_slice(&byte_rate.to_le_bytes());
    wav.extend_from_slice(&(CHANNELS * BYTES_PER_SAMPLE).to_le_bytes());
    wav.extend_from_slice(&(BYTES_PER_SAMPLE * 8).to_le_bytes());
    wav.extend_from_slice(b"data");
    wav.extend_from_slice(&declared_bytes.to_le_bytes());
    for i in 0..ACTUAL_SAMPLES {
        wav.extend_from_slice(&i.to_le_bytes());
    }

    let mut f = File::create(&path).unwrap();
    f.write_all(&wav).unwrap();
    f.flush().unwrap();
    drop(f);

    let r = load_audio(&path);
    let rejected = matches!(
        r,
        Err(mlxrs::Error::Parse(_) | mlxrs::Error::LengthMismatch(_))
    );
    assert!(rejected, "load_wav must reject a truncated WAV; got {r:?}");

    let _ = fs::remove_file(&path);
}
/// A sample rate of `u32::MAX` must make `save_wav` fail with `Error::OutOfRange` and leave the
/// bytes already at the destination path untouched.
#[test]
fn save_wav_rejects_sample_rate_exceeding_byte_rate_u32_ceiling() {
    const MARKER: &[u8; 9] = b"PRESERVED";
    let path = temp_wav("sr_overflow");
    fs::write(&path, MARKER).unwrap();

    let outcome = save_wav(&path, &[0.0_f32], u32::MAX);
    assert!(matches!(outcome, Err(mlxrs::Error::OutOfRange(_))));

    let on_disk = fs::read(&path).unwrap();
    assert_eq!(on_disk, MARKER);

    let _ = fs::remove_file(&path);
}
/// On Unix, overwriting an existing destination must keep its permission bits: a file
/// pre-set to `0600` must still be `0600` after `save_wav`.
#[cfg(unix)]
#[test]
fn save_wav_preserves_existing_destination_mode_bits() {
    use std::os::unix::fs::PermissionsExt;

    // Lower nine permission bits of the file at `p`.
    let mode_bits = |p: &PathBuf| fs::metadata(p).unwrap().permissions().mode() & 0o777;

    let path = temp_wav("preserve_perms");
    fs::write(&path, b"prior content (does not need to be a valid WAV)").unwrap();
    fs::set_permissions(&path, fs::Permissions::from_mode(0o600)).unwrap();

    let pre_mode = mode_bits(&path);
    assert_eq!(pre_mode, 0o600, "test precondition: pre-set mode is 0600");

    save_wav(&path, &[0.0_f32, 0.5, -0.5], 16_000).unwrap();

    let post_mode = mode_bits(&path);
    assert_eq!(
        post_mode, 0o600,
        "post-save mode bits drifted: pre={pre_mode:o} post={post_mode:o}"
    );

    let _ = fs::remove_file(&path);
}
/// Hand-writes a three-sample mono 24-bit PCM WAV (largest positive value, zero, and
/// `-8_388_607`) and checks that `load_audio` returns them scaled by `1 / 2^23`.
#[test]
fn load_wav_decodes_24bit_pcm_mono_wav() {
    const BITS_PER_SAMPLE: u16 = 24;
    const BYTES_PER_SAMPLE: u16 = BITS_PER_SAMPLE / 8;
    const CHANNELS: u16 = 1;
    const SAMPLE_RATE: u32 = 16_000;
    const N_SAMPLES: u32 = 3;

    let path = temp_wav("pcm24");
    let data_size: u32 = N_SAMPLES * u32::from(BYTES_PER_SAMPLE);
    let byte_rate: u32 = SAMPLE_RATE * u32::from(CHANNELS) * u32::from(BYTES_PER_SAMPLE);

    let mut wav: Vec<u8> = Vec::with_capacity(44 + 9);
    wav.extend_from_slice(b"RIFF");
    wav.extend_from_slice(&(36u32 + data_size).to_le_bytes());
    wav.extend_from_slice(b"WAVE");
    wav.extend_from_slice(b"fmt ");
    wav.extend_from_slice(&16u32.to_le_bytes());
    wav.extend_from_slice(&1u16.to_le_bytes());
    wav.extend_from_slice(&CHANNELS.to_le_bytes());
    wav.extend_from_slice(&SAMPLE_RATE.to_le_bytes());
    wav.extend_from_slice(&byte_rate.to_le_bytes());
    wav.extend_from_slice(&(CHANNELS * BYTES_PER_SAMPLE).to_le_bytes());
    wav.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());
    wav.extend_from_slice(b"data");
    wav.extend_from_slice(&data_size.to_le_bytes());
    // Little-endian 24-bit payload: 0x7fffff, 0x000000, 0x800001.
    wav.extend_from_slice(&[0xff, 0xff, 0x7f]);
    wav.extend_from_slice(&[0x00, 0x00, 0x00]);
    wav.extend_from_slice(&[0x01, 0x00, 0x80]);

    let mut f = File::create(&path).unwrap();
    f.write_all(&wav).unwrap();
    f.flush().unwrap();
    drop(f);

    let (got, sr) = load_audio(&path).unwrap();
    assert_eq!(sr, 16_000);
    assert_eq!(got.len(), 3);

    let expected = [8_388_607.0 / 8_388_608.0, 0.0, -8_388_607.0 / 8_388_608.0];
    got.iter().zip(expected.iter()).for_each(|(g, w)| {
        assert!(
            (g - w).abs() < 1.0 / (1u64 << 22) as f32,
            "24-bit decode mismatch: got={g} want={w}"
        );
    });

    let _ = fs::remove_file(&path);
}
/// Saves a 64-point ramp over `[-1, 1)` at 22.05 kHz and checks that `load_audio` returns the
/// same rate, the same length, and samples within `1/16384` of the originals.
#[test]
fn load_wav_via_symphonia_roundtrip_matches_save_wav_output() {
    let path = temp_wav("symphonia_roundtrip");

    let mut samples: Vec<f32> = Vec::with_capacity(64);
    for i in 0..64 {
        samples.push(((i as f32) / 32.0 - 1.0).clamp(-1.0, 1.0));
    }

    save_wav(&path, &samples, 22_050).unwrap();
    let (got, sr) = load_audio(&path).unwrap();

    assert_eq!(sr, 22_050);
    assert_eq!(got.len(), samples.len());
    got.iter().zip(samples.iter()).for_each(|(g, w)| {
        assert!(
            (g - w).abs() <= 1.0 / 16_384.0,
            "symphonia round-trip diff too large: got={g} want={w}"
        );
    });

    let _ = fs::remove_file(&path);
}
/// The embedded MP3 fixture decodes to a plausible tone (see `assert_decoded_tone`).
#[test]
fn load_audio_decodes_mp3() {
    let fixture_path = temp_path("mp3_decode", "mp3");
    write_fixture(&fixture_path, FIXTURE_MP3);

    let (pcm, rate) = load_audio(&fixture_path).unwrap();
    assert_decoded_tone(&pcm, rate, "mp3");

    let _ = fs::remove_file(&fixture_path);
}
/// The embedded FLAC fixture decodes to a plausible tone (see `assert_decoded_tone`).
#[test]
fn load_audio_decodes_flac() {
    let fixture_path = temp_path("flac_decode", "flac");
    write_fixture(&fixture_path, FIXTURE_FLAC);

    let (pcm, rate) = load_audio(&fixture_path).unwrap();
    assert_decoded_tone(&pcm, rate, "flac");

    let _ = fs::remove_file(&fixture_path);
}
/// The FLAC fixture must decode to exactly `FIXTURE_NOMINAL_SAMPLES` samples at `FIXTURE_SR`,
/// not merely to a count within the loose tolerance used by `assert_decoded_tone`.
#[test]
fn load_audio_flac_decodes_exact_streaminfo_sample_count() {
    let fixture_path = temp_path("flac_exact", "flac");
    write_fixture(&fixture_path, FIXTURE_FLAC);

    let (pcm, rate) = load_audio(&fixture_path).unwrap();
    assert_eq!(rate, FIXTURE_SR, "flac: sample rate mismatch");

    let decoded_len = pcm.len();
    assert_eq!(
        decoded_len, FIXTURE_NOMINAL_SAMPLES,
        "flac: exact STREAMINFO total must decode to exactly {FIXTURE_NOMINAL_SAMPLES} samples"
    );

    let _ = fs::remove_file(&fixture_path);
}
/// A FLAC fixture cut to half its byte length must be rejected with `Error::Parse` or
/// `Error::LengthMismatch`.
#[test]
fn load_audio_rejects_truncated_flac() {
    let path = temp_path("flac_truncated", "flac");
    let first_half = &FIXTURE_FLAC[..FIXTURE_FLAC.len() / 2];
    write_fixture(&path, first_half);

    let r = load_audio(&path);
    let rejected = matches!(
        r,
        Err(mlxrs::Error::Parse(_) | mlxrs::Error::LengthMismatch(_))
    );
    assert!(
        rejected,
        "truncated FLAC must be rejected (count mismatch / corruption), got {r:?}"
    );

    let _ = fs::remove_file(&path);
}
/// The embedded Ogg/Vorbis fixture decodes to a plausible tone (see `assert_decoded_tone`).
#[test]
fn load_audio_decodes_ogg_vorbis() {
    let fixture_path = temp_path("ogg_decode", "ogg");
    write_fixture(&fixture_path, FIXTURE_OGG_VORBIS);

    let (pcm, rate) = load_audio(&fixture_path).unwrap();
    assert_decoded_tone(&pcm, rate, "ogg/vorbis");

    let _ = fs::remove_file(&fixture_path);
}
/// Fixtures stored under a misleading extension (MP3 bytes as `.flac`, FLAC bytes as `.wav`)
/// must still decode correctly.
#[test]
fn load_audio_autodetects_format_ignoring_extension() {
    // (path stem, misleading extension, real payload, label used in failure messages)
    let cases: [(&str, &str, &[u8], &str); 2] = [
        ("mislabeled_mp3", "flac", FIXTURE_MP3, "mp3-labeled-flac"),
        ("mislabeled_flac", "wav", FIXTURE_FLAC, "flac-labeled-wav"),
    ];

    for (stem, wrong_ext, payload, label) in cases {
        let path = temp_path(stem, wrong_ext);
        write_fixture(&path, payload);
        let (samples, sr) = load_audio(&path).unwrap();
        assert_decoded_tone(&samples, sr, label);
        let _ = fs::remove_file(&path);
    }
}
/// Fixtures stored under a path with no extension at all must still decode correctly.
#[test]
fn load_audio_autodetects_format_with_no_extension() {
    /// Writes `bytes` to an extension-less temp path, decodes it, and validates the tone.
    fn decode_without_extension(name: &str, bytes: &[u8], fmt: &str) {
        let file_name = format!("mlxrs_audio_io_{}_{}", process::id(), name);
        let p = std::env::temp_dir().join(file_name);
        write_fixture(&p, bytes);
        let (samples, sr) = load_audio(&p).unwrap();
        assert_decoded_tone(&samples, sr, fmt);
        let _ = fs::remove_file(&p);
    }

    decode_without_extension("noext_mp3", FIXTURE_MP3, "mp3");
    decode_without_extension("noext_flac", FIXTURE_FLAC, "flac");
    decode_without_extension("noext_ogg", FIXTURE_OGG_VORBIS, "ogg/vorbis");
}
/// 512 bytes of a repeating `0..251` pattern stored as `.opus` are not a recognisable audio
/// stream; `load_audio` must answer with `Error::Parse`.
#[test]
fn load_audio_rejects_unsupported_opus_like_garbage() {
    let path = temp_path("unsupported", "opus");

    let mut garbage: Vec<u8> = Vec::with_capacity(512);
    for i in 0..512u32 {
        garbage.push((i % 251) as u8);
    }
    write_fixture(&path, &garbage);

    let r = load_audio(&path);
    let is_parse_error = matches!(r, Err(mlxrs::Error::Parse(_)));
    assert!(
        is_parse_error,
        "unsupported/garbage input must return Parse error, got {r:?}"
    );

    let _ = fs::remove_file(&path);
}
/// Each compressed fixture cut to two fifths of its length must either be rejected with
/// `Error::Parse` / `Error::LengthMismatch`, or decode to a finite buffer no longer than
/// `MAX_DECODED_SAMPLES`. Any other error variant fails the test.
#[test]
fn load_audio_truncated_compressed_is_bounded_and_recoverable() {
    let cases: [(&str, &[u8], &str); 3] = [
        ("trunc_mp3", FIXTURE_MP3, "mp3"),
        ("trunc_flac", FIXTURE_FLAC, "flac"),
        ("trunc_ogg", FIXTURE_OGG_VORBIS, "ogg"),
    ];

    for (name, bytes, ext) in cases {
        let path = temp_path(name, ext);
        let cut = (bytes.len() * 2) / 5;
        write_fixture(&path, &bytes[..cut]);

        match load_audio(&path) {
            Ok((samples, _)) => {
                let decoded_len = samples.len();
                assert!(
                    decoded_len <= mlxrs::audio::io::MAX_DECODED_SAMPLES,
                    "{ext}: truncated decode returned {} samples (> cap)",
                    decoded_len
                );
                let all_finite = samples.iter().all(|s| s.is_finite());
                assert!(
                    all_finite,
                    "{ext}: truncated decode returned a non-finite sample"
                );
            }
            // A clean rejection is an acceptable outcome.
            Err(mlxrs::Error::Parse(_) | mlxrs::Error::LengthMismatch(_)) => {}
            Err(other) => panic!("{ext}: unexpected error variant: {other:?}"),
        }

        let _ = fs::remove_file(&path);
    }
}
/// Decodes two WAV files of different lengths into the same caller-owned buffer and checks
/// that each call reports the right rate, leaves exactly the decoded samples in the buffer,
/// and never shrinks the buffer's capacity.
#[test]
fn load_audio_into_reuses_buffer_across_calls() {
    use mlxrs::audio::io::load_audio_into;

    /// Compares decoded samples against the originals within roughly one PCM16 step.
    fn assert_samples_close(got: &[f32], want: &[f32]) {
        for (i, (g, e)) in got.iter().zip(want.iter()).enumerate() {
            assert!(
                (g - e).abs() < 1.1 / 32767.0,
                "load_into[{i}]: got {g}, want {e}"
            );
        }
    }

    let path1 = temp_wav("load_into_reuse_1");
    let path2 = temp_wav("load_into_reuse_2");
    let s1: Vec<f32> = (0..1000).map(|i| (i as f32 / 500.0).sin() * 0.5).collect();
    let s2: Vec<f32> = (0..500).map(|i| (i as f32 / 100.0).cos() * 0.25).collect();
    save_wav(&path1, &s1, 16_000).unwrap();
    save_wav(&path2, &s2, 16_000).unwrap();

    let mut scratch: Vec<f32> = Vec::with_capacity(2000);
    let cap_before = scratch.capacity();

    // First decode: the longer file.
    let sr1 = load_audio_into(&path1, &mut scratch).unwrap();
    assert_eq!(sr1, 16_000);
    assert_eq!(scratch.len(), s1.len());
    assert_samples_close(&scratch[..], &s1[..]);

    // Second decode: the shorter file, into the same buffer.
    let sr2 = load_audio_into(&path2, &mut scratch).unwrap();
    assert_eq!(sr2, 16_000);
    assert_eq!(scratch.len(), s2.len());
    assert!(
        scratch.capacity() >= cap_before,
        "buffer reuse must not shrink capacity: {} < {cap_before}",
        scratch.capacity()
    );
    assert_samples_close(&scratch[..], &s2[..]);

    for path in [&path1, &path2] {
        let _ = fs::remove_file(path);
    }
}
/// A 100-sample WAV loaded with a cap of 50 samples must fail with `Error::CapExceeded`.
#[test]
fn load_audio_with_cap_rejects_oversized_wav_at_header_stage() {
    use mlxrs::{audio::io::load_audio_with_cap, error::Error};

    let path = temp_wav("with_cap_oversized");
    let mut s: Vec<f32> = Vec::with_capacity(100);
    for i in 0..100 {
        s.push((i as f32 / 50.0).sin() * 0.3);
    }
    save_wav(&path, &s, 16_000).unwrap();

    let r = load_audio_with_cap(&path, 50);
    let cap_exceeded = matches!(r, Err(Error::CapExceeded(_)));
    assert!(
        cap_exceeded,
        "over-cap WAV header must reject with CapExceeded, got {r:?}"
    );

    let _ = fs::remove_file(&path);
}
/// When the file (200 samples) fits under the cap (1000), `load_audio_with_cap` must return
/// exactly what `load_audio` returns.
#[test]
fn load_audio_with_cap_undersized_decodes_identically() {
    use mlxrs::audio::io::load_audio_with_cap;

    let path = temp_wav("with_cap_undersized");
    let mut s: Vec<f32> = Vec::with_capacity(200);
    for i in 0..200 {
        s.push((i as f32 / 100.0).sin() * 0.25);
    }
    save_wav(&path, &s, 16_000).unwrap();

    let (capped_samples, capped_rate) = load_audio_with_cap(&path, 1000).unwrap();
    let (plain_samples, plain_rate) = load_audio(&path).unwrap();
    assert_eq!(capped_rate, plain_rate);
    assert_eq!(capped_samples, plain_samples);

    let _ = fs::remove_file(&path);
}
/// `load_audio` must return the same samples as `load_audio_with_cap` called with either
/// `MAX_DECODED_SAMPLES` or `usize::MAX` as the cap.
#[test]
fn load_audio_equivalent_to_with_cap_at_max() {
    use mlxrs::audio::io::{MAX_DECODED_SAMPLES, load_audio_with_cap};

    let path = temp_wav("with_cap_at_max");
    let mut s: Vec<f32> = Vec::with_capacity(256);
    for i in 0..256 {
        s.push((i as f32 / 128.0).cos() * 0.2);
    }
    save_wav(&path, &s, 16_000).unwrap();

    let (plain, _) = load_audio(&path).unwrap();
    let (at_max, _) = load_audio_with_cap(&path, MAX_DECODED_SAMPLES).unwrap();
    let (at_usize_max, _) = load_audio_with_cap(&path, usize::MAX).unwrap();

    assert_eq!(plain, at_max);
    assert_eq!(plain, at_usize_max);

    let _ = fs::remove_file(&path);
}
/// `save_wav_into` must keep (never shrink) the capacity of its caller-owned scratch buffer
/// across calls, and must write a file byte-identical to the one `save_wav` writes for the
/// same input.
#[test]
fn save_wav_into_reuses_scratch_buffer() {
    use mlxrs::audio::io::save_wav_into;

    let path_into = temp_wav("save_into_reuse");
    let path_plain = temp_wav("save_plain_reuse");
    let long_signal: Vec<f32> = (0..800).map(|i| (i as f32 / 50.0).sin() * 0.6).collect();
    let short_signal: Vec<f32> = (0..400).map(|i| (i as f32 / 25.0).cos() * 0.3).collect();
    let mut scratch: Vec<i16> = Vec::new();

    // First write grows the scratch buffer to hold the longer signal.
    save_wav_into(&path_into, &long_signal, 16_000, &mut scratch).unwrap();
    let cap_after_first = scratch.capacity();
    assert!(
        cap_after_first >= long_signal.len(),
        "scratch did not retain capacity"
    );

    // Second, smaller write must not give that capacity back.
    save_wav_into(&path_into, &short_signal, 16_000, &mut scratch).unwrap();
    assert!(
        scratch.capacity() >= cap_after_first,
        "scratch shrank on smaller write: {} < {cap_after_first}",
        scratch.capacity()
    );

    save_wav(&path_plain, &short_signal, 16_000).unwrap();
    let into_bytes = fs::read(&path_into).unwrap();
    let plain_bytes = fs::read(&path_plain).unwrap();
    assert_eq!(
        into_bytes, plain_bytes,
        "save_wav_into must produce byte-identical WAV vs save_wav"
    );

    for path in [&path_into, &path_plain] {
        let _ = fs::remove_file(path);
    }
}
/// Checks that `I16Quantizer::quantize_into` and the free function
/// `mlxrs::simd::audio::quantize::f32_to_i16_quantize` produce identical `i16` output for the
/// same `f32` input, and that out-of-range inputs saturate to `i16::MIN` / `i16::MAX`.
#[test]
fn i16_quantizer_matches_simd_dispatcher() {
use core::mem::MaybeUninit;
use mlxrs::audio::io::{I16Quantizer, Quantizer};
// Covers both clipping directions (+/-1.5), the exact endpoints, zero, and a small value.
let src: Vec<f32> = vec![
-1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 0.001, ];
// Both destinations are allocated with room for `src.len()` elements but have length 0;
// the quantizers write into the uninitialized spare capacity.
let mut dst_a: Vec<i16> = Vec::with_capacity(src.len());
let mut dst_b: Vec<i16> = Vec::with_capacity(src.len());
let spare_a: &mut [MaybeUninit<i16>] = dst_a.spare_capacity_mut();
I16Quantizer.quantize_into(&mut spare_a[..src.len()], &src);
// SAFETY: `dst_a` was created with capacity >= `src.len()`. This additionally assumes
// `quantize_into` initializes every element of the slice it was given.
// NOTE(review): confirm that against the `Quantizer` contract.
unsafe { dst_a.set_len(src.len()) };
let spare_b: &mut [MaybeUninit<i16>] = dst_b.spare_capacity_mut();
mlxrs::simd::audio::quantize::f32_to_i16_quantize(&mut spare_b[..src.len()], &src);
// SAFETY: `dst_b` was created with capacity >= `src.len()`. This additionally assumes
// `f32_to_i16_quantize` initializes every element of the slice it was given.
// NOTE(review): confirm that against the function's documentation.
unsafe { dst_b.set_len(src.len()) };
assert_eq!(
dst_a, dst_b,
"I16Quantizer must produce identical output to the SIMD dispatcher"
);
// Index 0 is the -1.5 input and index 6 is the +1.5 input.
assert_eq!(dst_a[0], -32768, "-1.5 should clip to -32768 (i16::MIN)");
assert_eq!(
dst_a[6], 32767,
"+1.5 should clip to +32767 (i16::MAX via saturating narrow)"
);
}
/// Exercises the duration cap at its boundary on an 8000-sample, 8 kHz WAV: a cap of exactly
/// one second must decode every sample and report the file's rate, while a cap one frame
/// shorter must fail with `Error::CapExceeded`.
#[test]
fn load_audio_with_max_seconds_unified_probe_no_toctou() {
    use mlxrs::audio::io::load_audio_with_max_seconds;
    use mlxrs::error::Error;

    let path = temp_wav("max_seconds_unified_no_toctou");
    let sr = 8000_u32;
    let mut samples: Vec<f32> = Vec::with_capacity(8000);
    for i in 0..8000 {
        samples.push((i as f32 * 0.001).sin() * 0.4);
    }
    save_wav(&path, &samples, sr).unwrap();

    // Cap == file length: everything must come through.
    let (got_samples, got_sr) = load_audio_with_max_seconds(&path, 1.0).unwrap();
    assert_eq!(
        got_sr, sr,
        "returned sr must equal the file's actual sr (not a stale probe's)"
    );
    assert_eq!(
        got_samples.len(),
        samples.len(),
        "exact-boundary decode (cap == header_len) must include every sample"
    );

    // Cap == file length minus one frame: must be refused.
    let r = load_audio_with_max_seconds(&path, 7999.0 / 8000.0);
    let cap_exceeded = matches!(r, Err(Error::CapExceeded(_)));
    assert!(
        cap_exceeded,
        "cap one frame below header must reject; got {r:?}"
    );

    let _ = fs::remove_file(&path);
}
/// Loading the MP3 fixture with a 0.01-second cap must fail with `Error::BoundedDecode`.
/// `Error::CapExceeded` is singled out with its own message; any other outcome also fails.
#[test]
fn load_audio_with_max_seconds_mp3_genuinely_over_cap_rejects() {
    use mlxrs::{audio::io::load_audio_with_max_seconds, error::Error};

    let path = temp_path("mp3_over_cap_rejects", "mp3");
    write_fixture(&path, FIXTURE_MP3);

    let outcome = load_audio_with_max_seconds(&path, 0.01);
    match outcome {
        // The only acceptable outcome: rejected while decoding.
        Err(Error::BoundedDecode(_)) => {}
        Err(Error::CapExceeded(ref p)) => panic!(
            "MP3 over-cap rejection came from the upfront \
             `header_len > cap` path (CapExceeded against `{}`); estimate-count formats \
             must reject mid-decode via BoundedDecode",
            p.cap_name()
        ),
        ref other => {
            panic!("MP3 genuinely over cap must reject with Error::BoundedDecode; got {other:?}")
        }
    }

    let _ = fs::remove_file(&path);
}
/// Structural guard: the unified load worker must open the file exactly once.
///
/// Scans the text of `../src/audio/io.rs` (embedded at compile time) and checks that:
/// * the comment-stripped body of `load_audio_into_unified` contains exactly one `File::open`;
/// * that body never mentions `probe_source_sample_rate`;
/// * the comment-stripped body of `load_audio_with_max_seconds` contains no `File::open`.
///
/// A function body that cannot be located makes the test panic for BOTH functions. This
/// matters for the second check in particular: an empty extracted body would otherwise satisfy
/// the "zero opens" assertion vacuously.
///
/// Limitation: the brace matcher runs on the raw source text, so an unbalanced `{` or `}` inside
/// a comment or string literal of the inspected function would confuse the extraction.
#[test]
fn load_audio_into_unified_has_single_file_open() {
    /// Returns the brace-delimited body (outer `{` and `}` included) of the first function in
    /// `src` whose signature starts with `sig`. `name` is only used in panic messages.
    fn fn_body<'s>(src: &'s str, sig: &str, name: &str) -> &'s str {
        let sig_pos = src
            .find(sig)
            .unwrap_or_else(|| panic!("io.rs must define `{sig}`"));
        let body_start = sig_pos
            + src[sig_pos..]
                .find('{')
                .unwrap_or_else(|| panic!("`{name}` signature must be followed by `{{`"));

        let mut depth: i32 = 0;
        for (i, &b) in src.as_bytes().iter().enumerate().skip(body_start) {
            match b {
                b'{' => depth += 1,
                b'}' => {
                    depth -= 1;
                    if depth == 0 {
                        // `}` is ASCII, so `i` is always a valid inclusive slice end.
                        return &src[body_start..=i];
                    }
                }
                _ => {}
            }
        }
        panic!(
            "structural test could not locate the closing `}}` of \
             `{name}` — io.rs source layout may have changed; \
             update the brace-matcher in this test."
        );
    }

    /// Counts occurrences of `File::open` in `code` that are not the tail of a longer
    /// identifier (i.e. not directly preceded by an ASCII alphanumeric or `_`).
    fn count_file_opens(code: &str) -> usize {
        code.match_indices("File::open")
            .filter(|(idx, _)| {
                *idx == 0 || {
                    let prev = code.as_bytes()[*idx - 1];
                    !(prev.is_ascii_alphanumeric() || prev == b'_')
                }
            })
            .count()
    }

    let src = include_str!("../src/audio/io.rs");

    // --- The unified worker: exactly one open, no separate probe helper. ---
    let body = fn_body(
        src,
        "fn load_audio_into_unified(",
        "load_audio_into_unified",
    );
    assert!(
        body.len() > 1000,
        "load_audio_into_unified body looks suspiciously short ({} bytes); \
         structural test extracted the wrong region",
        body.len()
    );
    let body_no_comments = strip_comments(body);
    let file_open_count = count_file_opens(&body_no_comments);
    assert_eq!(
        file_open_count, 1,
        "STRUCTURAL regression: \
         `load_audio_into_unified` body must contain EXACTLY ONE \
         `File::open` (the unified probe+decode handle); found \
         {file_open_count} in the comment-stripped body. A reintroduced \
         second open is the TOCTOU defect this test guards against. \
         Comment-stripped body was:\n{body_no_comments}"
    );
    assert!(
        !body_no_comments.contains("probe_source_sample_rate"),
        "STRUCTURAL regression: \
         `load_audio_into_unified` body must NOT reference a \
         `probe_source_sample_rate` helper — such a helper performs a \
         SECOND `File::open` and reintroduces the TOCTOU double-open. \
         Comment-stripped body was:\n{body_no_comments}"
    );

    // --- The public wrapper: must not open the file itself. ---
    let pub_body = fn_body(
        src,
        "pub fn load_audio_with_max_seconds(",
        "load_audio_with_max_seconds",
    );
    let pub_body_no_comments = strip_comments(pub_body);
    let pub_file_open_count = count_file_opens(&pub_body_no_comments);
    assert_eq!(
        pub_file_open_count, 0,
        "STRUCTURAL regression: \
         `load_audio_with_max_seconds` must delegate to the unified \
         worker without performing its own `File::open` (found \
         {pub_file_open_count} direct opens). The unified worker is the \
         SOLE owner of the load-path file handle. \
         Comment-stripped body was:\n{pub_body_no_comments}"
    );
}
/// Returns `src` with Rust comments removed and everything else copied through verbatim.
///
/// * `// ...` line comments are dropped up to, but not including, the terminating newline.
/// * `/* ... */` block comments are dropped, honouring nesting.
/// * `"..."` string literals are copied unchanged, so comment markers inside them survive.
///   A backslash escape is copied together with the complete character that follows it, which
///   keeps the scanner on a UTF-8 character boundary even when that character is multi-byte.
/// * Character literals such as `'"'` or `'\''` are copied as a unit so that a quote inside one
///   is not mistaken for the start of a string. A lone `'` (lifetime or loop label) is copied
///   as an ordinary character.
///
/// Limitation: raw strings with hash delimiters (`r#"..."#`) are not recognised as such; they
/// are scanned as a plain string that ends at the first unescaped `"`.
fn strip_comments(src: &str) -> String {
    /// If a character literal starts at `bytes[start]` (which must be `'`), returns the index
    /// one past its closing quote; returns `None` for lifetimes, labels and malformed input.
    fn char_literal_end(bytes: &[u8], start: usize) -> Option<usize> {
        let first = *bytes.get(start + 1)?;
        if first == b'\\' {
            // Escaped literal (`'\n'`, `'\''`, `'\u{1F600}'`): skip the backslash and the byte
            // it escapes, then look for the closing quote on the same line.
            let mut j = start + 3;
            while j < bytes.len() && bytes[j] != b'\n' {
                if bytes[j] == b'\'' {
                    return Some(j + 1);
                }
                j += 1;
            }
            return None;
        }
        if first == b'\'' {
            // `''` is not a literal.
            return None;
        }
        // Plain literal: exactly one (possibly multi-byte) character, then the closing quote.
        let ch_end = utf8_char_end(bytes, start + 1);
        if bytes.get(ch_end) == Some(&b'\'') {
            Some(ch_end + 1)
        } else {
            None
        }
    }

    let bytes = src.as_bytes();
    let mut out = String::with_capacity(src.len());
    let mut i = 0;
    while i < bytes.len() {
        let b = bytes[i];
        let next = bytes.get(i + 1).copied();

        // Line comment: skip to end of line, keeping the newline itself.
        if b == b'/' && next == Some(b'/') {
            i += 2;
            while i < bytes.len() && bytes[i] != b'\n' {
                i += 1;
            }
            continue;
        }

        // Block comment: skip to the matching `*/`, tracking nesting depth.
        if b == b'/' && next == Some(b'*') {
            i += 2;
            let mut depth: u32 = 1;
            while i < bytes.len() && depth > 0 {
                match (bytes[i], bytes.get(i + 1).copied()) {
                    (b'/', Some(b'*')) => {
                        depth += 1;
                        i += 2;
                    }
                    (b'*', Some(b'/')) => {
                        depth -= 1;
                        i += 2;
                    }
                    _ => i += 1,
                }
            }
            continue;
        }

        // String literal: copy through to the closing quote, honouring backslash escapes.
        if b == b'"' {
            out.push('"');
            i += 1;
            while i < bytes.len() {
                let c = bytes[i];
                if c == b'"' {
                    out.push('"');
                    i += 1;
                    break;
                }
                let start = i;
                if c == b'\\' && i + 1 < bytes.len() {
                    // Step over the backslash so the escaped character is copied whole.
                    i += 1;
                }
                i = utf8_char_end(bytes, i);
                out.push_str(&src[start..i]);
            }
            continue;
        }

        // Character literal: copy as a unit so an embedded `"` or `/` is not reinterpreted.
        if b == b'\'' {
            if let Some(end) = char_literal_end(bytes, i) {
                out.push_str(&src[i..end]);
                i = end;
                continue;
            }
        }

        // Anything else: copy one whole character.
        let ch_end = utf8_char_end(bytes, i);
        out.push_str(&src[i..ch_end]);
        i = ch_end;
    }
    out
}
/// Returns the index one past the UTF-8 character that starts at `bytes[i]`, clamped to
/// `bytes.len()`.
///
/// The width is derived from the lead byte alone. A continuation byte (`0x80..=0xbf`) is
/// treated as a one-byte character so a caller that lands mid-character still advances.
///
/// Panics if `i` is out of bounds.
fn utf8_char_end(bytes: &[u8], i: usize) -> usize {
    let width = match bytes[i] {
        0x00..=0xbf => 1, // ASCII, or a stray continuation byte
        0xc0..=0xdf => 2,
        0xe0..=0xef => 3,
        _ => 4,
    };
    bytes.len().min(i + width)
}