#![cfg(feature = "audio")]
use std::f32::consts::PI;
use mlxrs::{
Array, Dtype,
audio::dsp::{
LogFloor, MelPrecision, WindowPad, hann_window, log_mel_spectrogram, log_mel_spectrogram_with,
mel_filter_bank, mel_filter_bank_with, mel_spectrogram, stft,
},
};
/// Builds the shared test signal: 16 samples of a 1 kHz sine tone sampled at
/// 16 kHz, returned as a one-dimensional `f32` array of shape `[16]`.
fn sine_1khz_16samples() -> Array {
    const SAMPLE_RATE_HZ: f32 = 16_000.0;
    const TONE_HZ: f32 = 1_000.0;
    const LEN: usize = 16;

    let mut samples: Vec<f32> = Vec::with_capacity(LEN);
    for n in 0..LEN {
        // Keep the left-to-right evaluation order so every sample is
        // bit-identical from run to run.
        let phase = 2.0 * PI * TONE_HZ * n as f32 / SAMPLE_RATE_HZ;
        samples.push(phase.sin());
    }
    Array::from_slice::<f32>(&samples, &[16i32]).unwrap()
}
/// Checks that the first and last samples of an 8-point Hann window are zero
/// to within 1e-6.
#[test]
fn hann_window_endpoints_are_zero() {
    let mut window = hann_window(8).unwrap();
    let samples = window.to_vec::<f32>().unwrap();

    let first = samples[0];
    assert!(first.abs() < 1e-6, "first sample should be 0, got {}", first);

    let last = samples[7];
    assert!(last.abs() < 1e-6, "last sample should be 0, got {}", last);
}
/// Checks that a 9-point Hann window reads the same forwards and backwards:
/// each sample in the first half must equal its mirror to within 1e-6.
#[test]
fn hann_window_is_symmetric() {
    let mut window = hann_window(9).unwrap();
    let v = window.to_vec::<f32>().unwrap();

    let half = v.len() / 2;
    // Pair index k (from the front) with len - 1 - k (from the back).
    for (k, mirror) in (0..half).zip((0..v.len()).rev()) {
        let (left, right) = (v[k], v[mirror]);
        assert!(
            (left - right).abs() < 1e-6,
            "asymmetric: v[{k}]={} vs v[{mirror}]={}",
            left,
            right
        );
    }
}
/// Window lengths 0 and 1 must both be rejected with `Error::OutOfRange`.
#[test]
fn hann_window_rejects_n_lt_2() {
    let empty = hann_window(0);
    assert!(matches!(empty, Err(mlxrs::Error::OutOfRange(_))));

    let single = hann_window(1);
    assert!(matches!(single, Err(mlxrs::Error::OutOfRange(_))));
}
/// The middle sample (index 4) of a 9-point Hann window must be 1 to within
/// 1e-5.
#[test]
fn hann_window_center_value_is_one() {
    let mut window = hann_window(9).unwrap();
    let samples = window.to_vec::<f32>().unwrap();

    let center = samples[4];
    assert!(
        (center - 1.0).abs() < 1e-5,
        "center should be ~1, got {}",
        center
    );
}
/// Runs an STFT (n_fft = 8, hop = 4, default window length) over the
/// 16-sample sine and checks the output shape, dtype, and the parameters
/// reported back by the result.
#[test]
fn stft_shape_matches_formula() {
    let signal = sine_1khz_16samples();
    let spectrum = stft(&signal, 8, 4, None, WindowPad::Center).unwrap();

    // Expected shape is [5, 5]; presumably [frames, n_fft / 2 + 1] -- confirm
    // against the `stft` documentation.
    assert_eq!(spectrum.data_ref().shape(), vec![5, 5]);
    assert_eq!(spectrum.data_ref().dtype().unwrap(), Dtype::Complex64);

    assert_eq!(spectrum.n_fft(), 8);
    assert_eq!(spectrum.hop_length(), 4);
    // With `win_length = None` the result reports a window as long as n_fft.
    assert_eq!(spectrum.win_length(), 8);
    assert!(spectrum.center());
}
/// An FFT size of zero must be rejected with `Error::InvariantViolation`.
#[test]
fn stft_rejects_zero_n_fft() {
    let signal = sine_1khz_16samples();
    let outcome = stft(&signal, 0, 4, None, WindowPad::Center);
    assert!(matches!(outcome, Err(mlxrs::Error::InvariantViolation(_))));
}
/// A hop length of zero must be rejected with `Error::InvariantViolation`.
#[test]
fn stft_rejects_zero_hop_length() {
    let signal = sine_1khz_16samples();
    let outcome = stft(&signal, 8, 0, None, WindowPad::Center);
    assert!(matches!(outcome, Err(mlxrs::Error::InvariantViolation(_))));
}
/// A window length (16) larger than the FFT size (8) must be rejected with
/// `Error::OutOfRange`.
#[test]
fn stft_rejects_win_length_greater_than_n_fft() {
    let signal = sine_1khz_16samples();
    let outcome = stft(&signal, 8, 4, Some(16), WindowPad::Center);
    assert!(matches!(outcome, Err(mlxrs::Error::OutOfRange(_))));
}
/// A 5-sample ramp with n_fft = 8 and hop = 4 must still be accepted and
/// yield a `[2, 5]` complex result.
///
/// NOTE(review): per the test name this is the shortest input the padding
/// step accepts for n_fft = 8 -- confirm against the `stft` implementation.
#[test]
fn stft_minimum_valid_input_boundary_padding_to_index_zero() {
    // Ramp 0, 1, 2, 3, 4.
    let ramp: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0, 4.0];
    let signal = Array::from_slice::<f32>(&ramp, &[5i32]).unwrap();

    let spectrum = stft(&signal, 8, 4, None, WindowPad::Center).unwrap();

    assert_eq!(spectrum.data_ref().shape(), vec![2, 5]);
    assert_eq!(spectrum.data_ref().dtype().unwrap(), Dtype::Complex64);
}
/// A 4-sample input with n_fft = 16 must be rejected with
/// `Error::OutOfRange`.
///
/// NOTE(review): per the test name the rejection comes from the reflect
/// padding step needing more samples than are available -- confirm.
#[test]
fn stft_rejects_input_too_short_for_reflect_pad() {
    let samples = vec![0.0_f32, 0.1, 0.2, 0.3];
    let signal = Array::from_slice::<f32>(&samples, &[4i32]).unwrap();

    let outcome = stft(&signal, 16, 8, None, WindowPad::Center);

    assert!(matches!(outcome, Err(mlxrs::Error::OutOfRange(_))));
}
/// A window (4) shorter than the FFT size (8) with `WindowPad::Right` must be
/// accepted: the output keeps the `[5, 5]` shape and reports the requested
/// window length.
#[test]
fn stft_win_length_shorter_than_n_fft_zero_pads_window() {
    let signal = sine_1khz_16samples();
    let spectrum = stft(&signal, 8, 4, Some(4), WindowPad::Right).unwrap();

    assert_eq!(spectrum.data_ref().shape(), vec![5, 5]);
    assert_eq!(spectrum.win_length(), 4);
}
/// An 80-band filter bank for n_fft = 400 must have shape `[80, 201]`
/// (201 = 400 / 2 + 1).
#[test]
fn mel_filter_bank_shape_matches_n_mels_x_n_freqs() {
    let filters = mel_filter_bank(80, 400, 16_000, 0.0, None).unwrap();
    let shape = filters.shape();
    assert_eq!(shape, vec![80, 201]);
}
/// Requesting zero mel bands must fail with `Error::InvariantViolation`.
#[test]
fn mel_filter_bank_rejects_zero_n_mels() {
    let outcome = mel_filter_bank(0, 400, 16_000, 0.0, None);
    assert!(matches!(outcome, Err(mlxrs::Error::InvariantViolation(_))));
}
/// A lower frequency bound (1000 Hz) above the upper bound (500 Hz) must fail
/// with `Error::OutOfRange`.
#[test]
fn mel_filter_bank_rejects_invalid_freq_range() {
    let outcome = mel_filter_bank(40, 400, 16_000, 1000.0, Some(500.0));
    assert!(matches!(outcome, Err(mlxrs::Error::OutOfRange(_))));
}
/// Oversized dimensions must surface as `Error::ArithmeticOverflow` instead
/// of wrapping or panicking.
///
/// Two cases are covered: `n_mels = usize::MAX`, and a pair of individually
/// representable sizes (`2^33` bands, `n_fft = 2^34`).
#[test]
fn mel_filter_bank_rejects_usize_overflow_inputs() {
    let r = mel_filter_bank(usize::MAX, 400, 16_000, 0.0, None);
    assert!(matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))));

    // `1usize << 33` and `1usize << 34` only exist when `usize` is 64 bits
    // wide. On a 32-bit target the shifts are rejected at compile time
    // (deny-by-default `arithmetic_overflow` lint), which would break every
    // test in this file, so this half is compiled for 64-bit targets only.
    #[cfg(target_pointer_width = "64")]
    {
        let big_n_mels = 1usize << 33;
        let big_n_fft = 1usize << 34;
        // NOTE(review): presumably it is the n_mels * (n_fft / 2 + 1) element
        // count that overflows here -- confirm against `mel_filter_bank`.
        let r = mel_filter_bank(big_n_mels, big_n_fft, 16_000, 0.0, Some(8_000.0));
        assert!(matches!(r, Err(mlxrs::Error::ArithmeticOverflow(_))));
    }
}
/// Every weight in a small (8-band, n_fft = 64) filter bank must be >= 0.
#[test]
fn mel_filter_bank_values_are_nonneg() {
    let mut filters = mel_filter_bank(8, 64, 16_000, 0.0, None).unwrap();
    let weights = filters.to_vec::<f32>().unwrap();
    for &v in &weights {
        assert!(v >= 0.0, "negative mel weight: {v}");
    }
}
/// Both precision modes must produce a bank of the same `[80, 201]` shape.
#[test]
fn mel_filter_bank_precise_shape_matches() {
    // Build both banks up front, then compare their shapes.
    let standard =
        mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Standard).unwrap();
    let precise =
        mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Precise).unwrap();

    assert_eq!(standard.shape(), vec![80, 201]);
    assert_eq!(precise.shape(), vec![80, 201]);
}
/// `mel_filter_bank` and `mel_filter_bank_with(..., MelPrecision::Standard)`
/// must return exactly equal weights for the same parameters.
#[test]
fn mel_filter_bank_with_standard_matches_shorthand() {
    let mut via_shorthand = mel_filter_bank(80, 400, 16_000, 0.0, None).unwrap();
    let mut via_with =
        mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Standard).unwrap();

    let shorthand_weights = via_shorthand.to_vec::<f32>().unwrap();
    let with_weights = via_with.to_vec::<f32>().unwrap();

    assert_eq!(
        shorthand_weights, with_weights,
        "Standard precision must match the f32 shorthand bit-for-bit"
    );
}
/// The `Precise` bank must not be bit-identical to the `Standard` one, yet
/// the two must agree to within 1e-4 on every weight.
#[test]
fn mel_filter_bank_precise_differs_but_close() {
    let mut standard_bank =
        mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Standard).unwrap();
    let mut precise_bank =
        mel_filter_bank_with(80, 400, 16_000, 0.0, None, MelPrecision::Precise).unwrap();

    let standard = standard_bank.to_vec::<f32>().unwrap();
    let precise = precise_bank.to_vec::<f32>().unwrap();
    assert_eq!(
        standard.len(),
        precise.len(),
        "shape mismatch between precisions"
    );

    // At least one weight has to differ at the bit level.
    let bit_identical = standard
        .iter()
        .zip(&precise)
        .all(|(a, b)| a.to_bits() == b.to_bits());
    assert!(
        !bit_identical,
        "precise bank must differ from the f32 bank (otherwise the f64 path is a no-op)"
    );

    // Largest element-wise absolute difference between the two banks.
    let mut max_abs = 0.0_f32;
    for (a, b) in standard.iter().zip(&precise) {
        max_abs = max_abs.max((a - b).abs());
    }
    assert!(
        max_abs < 1e-4,
        "precise vs f32 max abs diff {max_abs} exceeds 1e-4"
    );
}
/// Every weight in a small `Precise` filter bank must be >= 0.
#[test]
fn mel_filter_bank_precise_values_are_nonneg() {
    let mut filters =
        mel_filter_bank_with(8, 64, 16_000, 0.0, None, MelPrecision::Precise).unwrap();
    let weights = filters.to_vec::<f32>().unwrap();
    for &v in &weights {
        assert!(v >= 0.0, "negative precise mel weight: {v}");
    }
}
/// The `Precise` path must reject the same bad inputs with the same error
/// variants as asserted for the default path: zero bands gives
/// `InvariantViolation`, an inverted frequency range gives `OutOfRange`.
#[test]
fn mel_filter_bank_precise_rejects_invalid_inputs() {
    let zero_bands = mel_filter_bank_with(0, 400, 16_000, 0.0, None, MelPrecision::Precise);
    assert!(matches!(
        zero_bands,
        Err(mlxrs::Error::InvariantViolation(_))
    ));

    let inverted_range =
        mel_filter_bank_with(40, 400, 16_000, 1000.0, Some(500.0), MelPrecision::Precise);
    assert!(matches!(inverted_range, Err(mlxrs::Error::OutOfRange(_))));
}
/// The mel spectrogram of the sine test signal must have shape `[4, 5]` and
/// contain no negative values.
#[test]
fn mel_spectrogram_is_nonneg_for_real_input() {
    let signal = sine_1khz_16samples();
    let mut mel = mel_spectrogram(&signal, 8, 4, None, 4, 16_000, 0.0, None).unwrap();

    assert_eq!(mel.shape(), vec![4, 5]);

    let values = mel.to_vec::<f32>().unwrap();
    for &v in &values {
        assert!(v >= 0.0, "mel spec must be non-negative, got {v}");
    }
}
/// For an all-zero input the log-mel output must be finite everywhere and
/// equal to `ln(1e-10)` to within 1e-3.
#[test]
fn log_mel_spectrogram_is_finite_for_silence() {
    let silence = Array::zeros::<f32>(&(64usize,)).unwrap();
    let mut log_mel = log_mel_spectrogram(&silence, 16, 8, None, 4, 16_000, 0.0, None).unwrap();
    let values = log_mel.to_vec::<f32>().unwrap();

    // Pass 1: no NaN or infinity anywhere.
    for x in values.iter() {
        assert!(x.is_finite(), "log-mel must be finite (eps floor), got {x}");
    }

    // Pass 2: every entry sits at the logarithm of the 1e-10 floor.
    let expected = (1e-10_f32).ln();
    for x in values.iter() {
        let deviation = (x - expected).abs();
        assert!(
            deviation < 1e-3,
            "silence log-mel should equal ln(eps)={expected}, got {x}"
        );
    }
}
/// The log-mel spectrogram of the sine test signal must contain only finite
/// values.
#[test]
fn log_mel_spectrogram_is_finite_for_sine_input() {
    let signal = sine_1khz_16samples();
    let mut log_mel = log_mel_spectrogram(&signal, 8, 4, None, 4, 16_000, 0.0, None).unwrap();

    let values = log_mel.to_vec::<f32>().unwrap();
    for &v in &values {
        assert!(v.is_finite(), "log-mel must be finite, got {v}");
    }
}
/// `LogFloor::Whisper` and the default `LogFloor` must both resolve to
/// exactly `1e-10`.
#[test]
fn log_floor_whisper_matches_1e_10() {
    let whisper = LogFloor::Whisper.value();
    assert_eq!(whisper, 1e-10_f32);

    let default_floor = LogFloor::default().value();
    assert_eq!(default_floor, 1e-10_f32);
}
/// `LogFloor::Kaldi` must resolve to exactly `1e-8`.
#[test]
fn log_floor_kaldi_matches_mlx_audio_1e_8() {
    let kaldi = LogFloor::Kaldi.value();
    assert_eq!(kaldi, 1e-8_f32);
}
/// `LogFloor::Custom` must map every non-positive or non-finite request to
/// `f32::MIN_POSITIVE`, and pass a valid positive value through unchanged
/// (to within `f32::EPSILON`).
#[test]
fn log_floor_custom_clamps_nonpositive_and_nonfinite_to_min_positive() {
    // Negative, both zeros, NaN, and both infinities are all unusable floors.
    let rejected = [
        -1.0,
        0.0,
        -0.0,
        f32::NAN,
        f32::INFINITY,
        f32::NEG_INFINITY,
    ];
    for requested in rejected {
        assert_eq!(LogFloor::Custom(requested).value(), f32::MIN_POSITIVE);
    }

    let v = LogFloor::Custom(1e-7).value();
    assert!((v - 1e-7).abs() < f32::EPSILON);
}
/// On an all-zero input the Whisper and Kaldi floors must each yield the
/// logarithm of their own floor, with Whisper strictly lower and the gap
/// equal to `ln(1e-8) - ln(1e-10)` to within 1e-3.
#[test]
fn log_mel_spectrogram_whisper_vs_kaldi_differ_at_silence() {
    let silence = Array::zeros::<f32>(&(64usize,)).unwrap();

    let mut whisper_spec = log_mel_spectrogram_with(
        &silence,
        16,
        8,
        None,
        4,
        16_000,
        0.0,
        None,
        LogFloor::Whisper,
    )
    .unwrap();
    let mut kaldi_spec = log_mel_spectrogram_with(
        &silence,
        16,
        8,
        None,
        4,
        16_000,
        0.0,
        None,
        LogFloor::Kaldi,
    )
    .unwrap();

    let whisper_vals = whisper_spec.to_vec::<f32>().unwrap();
    let kaldi_vals = kaldi_spec.to_vec::<f32>().unwrap();
    assert_eq!(
        whisper_vals.len(),
        kaldi_vals.len(),
        "shape mismatch between floors"
    );

    let expected_w = (1e-10_f32).ln();
    let expected_k = (1e-8_f32).ln();
    let expected_delta = expected_k - expected_w;

    for (wi, ki) in whisper_vals.iter().zip(&kaldi_vals) {
        assert!(
            (wi - expected_w).abs() < 1e-3,
            "whisper silence entry should equal ln(1e-10)={expected_w}, got {wi}"
        );
        assert!(
            (ki - expected_k).abs() < 1e-3,
            "kaldi silence entry should equal ln(1e-8)={expected_k}, got {ki}"
        );
        assert!(*wi < *ki, "whisper floor must be more negative than kaldi");
        assert!(
            ((ki - wi) - expected_delta).abs() < 1e-3,
            "delta whisper-kaldi should be ~ln(100)={expected_delta}, got {}",
            ki - wi
        );
    }
}
/// `log_mel_spectrogram` must be bit-for-bit identical to
/// `log_mel_spectrogram_with(..., LogFloor::Whisper)` on the same input.
#[test]
fn log_mel_spectrogram_default_matches_explicit_whisper() {
    let signal = sine_1khz_16samples();

    let mut implicit = log_mel_spectrogram(&signal, 8, 4, None, 4, 16_000, 0.0, None).unwrap();
    let mut explicit = log_mel_spectrogram_with(
        &signal,
        8,
        4,
        None,
        4,
        16_000,
        0.0,
        None,
        LogFloor::Whisper,
    )
    .unwrap();

    let implicit_vals = implicit.to_vec::<f32>().unwrap();
    let explicit_vals = explicit.to_vec::<f32>().unwrap();
    assert_eq!(implicit_vals.len(), explicit_vals.len());

    // Compare raw bit patterns rather than float equality.
    for (i, (x, y)) in implicit_vals.iter().zip(&explicit_vals).enumerate() {
        assert_eq!(
            x.to_bits(),
            y.to_bits(),
            "bit-mismatch at {i}: default={x:?} explicit_whisper={y:?}"
        );
    }
}
/// Source-level regression guard: scans the text of `src/audio/dsp.rs` and
/// checks that the body of `pub fn mel_spectrogram(` calls
/// `mel_filter_bank_cached(` and never calls `mel_filter_bank(` directly.
///
/// The "body" is the text from the function signature up to (not including)
/// the `ops::linalg_basic::matmul(&mel, &power_t)` expression.
#[test]
fn mel_spectrogram_uses_cached_filter_bank() {
    let src = include_str!("../src/audio/dsp.rs");

    // Locate the function, then cut at its closing matmul.
    let start = src
        .find("pub fn mel_spectrogram(")
        .expect("dsp.rs must define `pub fn mel_spectrogram(`");
    let from_signature = &src[start..];
    let end = from_signature
        .find("ops::linalg_basic::matmul(&mel, &power_t)")
        .expect("mel_spectrogram body must terminate with the canonical matmul-return");
    let body = &from_signature[..end];

    assert!(
        body.contains("mel_filter_bank_cached("),
        "regression: `mel_spectrogram` must invoke \
         `mel_filter_bank_cached(...)` (per-thread LRU cache), not the \
         uncached `mel_filter_bank(...)`. Function body was:\n{body}"
    );

    // A hit counts as a direct call only when the byte just before it is not
    // part of an identifier, so names that merely end in `mel_filter_bank`
    // are ignored. A hit at offset 0 has no preceding byte and always counts.
    let is_ident_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    let uncached_calls = body
        .match_indices("mel_filter_bank(")
        .filter(|&(idx, _)| {
            body.as_bytes()[..idx]
                .last()
                .map_or(true, |&prev| !is_ident_byte(prev))
        })
        .count();

    assert_eq!(
        uncached_calls, 0,
        "regression: `mel_spectrogram` body must NOT \
         contain any direct `mel_filter_bank(` call; only the cached \
         variant `mel_filter_bank_cached(` is allowed. Found {uncached_calls} \
         uncached call(s).\nBody:\n{body}"
    );
}