use super::*;
use crate::Dtype;
/// Absolute tolerance for element-wise `f32` comparisons in this module.
const F32_TOL: f32 = 1e-5;
/// Makes `a` contiguous and copies its elements out as a flat `Vec<f32>`.
///
/// Panics (via `unwrap`) if the contiguous copy or the `f32` read-back fails.
fn to_vec(a: &Array) -> Vec<f32> {
    let dense = crate::ops::shape::contiguous(a, false).unwrap();
    dense.to_vec::<f32>().unwrap()
}
/// Compares `mel_scale_kaldi` with the reference closed form `1127 * ln(1 + hz / 700)` at three
/// points: 0 Hz (-> 0), 700 Hz (-> `1127 * ln 2`) and 1 kHz (-> `1127 * ln(17 / 7)`).
#[test]
fn mel_scale_kaldi_matches_reference_formula() {
    let zero_hz = mel_scale_kaldi(0.0);
    assert!(zero_hz.abs() < F32_TOL);

    let (v_700, want_700) = (mel_scale_kaldi(700.0), 1127.0 * 2.0_f32.ln());
    assert!(
        (v_700 - want_700).abs() < 1e-3,
        "mel(700): got {v_700}, want {want_700}"
    );

    let (v_1000, want_1000) = (mel_scale_kaldi(1000.0), 1127.0 * (17.0_f32 / 7.0).ln());
    assert!(
        (v_1000 - want_1000).abs() < 1e-3,
        "mel(1000): got {v_1000}, want {want_1000}"
    );
}
/// `inverse_mel_scale_kaldi(mel_scale_kaldi(hz))` must recover `hz` to a relative tolerance.
#[test]
fn mel_scale_kaldi_inverse_round_trips() {
    const PROBES_HZ: [f32; 7] = [0.0, 100.0, 700.0, 1000.0, 4000.0, 8000.0, 16000.0];
    for hz in PROBES_HZ {
        let mel = mel_scale_kaldi(hz);
        let back = inverse_mel_scale_kaldi(mel);
        // The `+ 1.0` keeps the tolerance non-zero at 0 Hz.
        let tolerance = (hz.abs() + 1.0) * 1e-5;
        assert!(
            (back - hz).abs() < tolerance,
            "round-trip(hz={hz}): mel={mel}, back={back}"
        );
    }
}
/// An 80-mel bank over a 512-point padded FFT is an `[80, 256]` f32 matrix plus 80 centres.
#[test]
fn mel_banks_kaldi_shape() {
    let (weights, centre_freqs) = get_mel_banks_kaldi(80, 512, 16_000.0, 20.0, 0.0).unwrap();
    assert_eq!(weights.shape(), vec![80, 256]);
    assert_eq!(centre_freqs.shape(), vec![80]);
    assert_eq!(weights.dtype().unwrap(), Dtype::F32);
}
/// With a 0 Hz low cut, each of the 40 filters must carry strictly positive total weight.
#[test]
fn mel_banks_kaldi_rows_sum_positive() {
    let (weights, _) = get_mel_banks_kaldi(40, 512, 16_000.0, 0.0, 0.0).unwrap();
    let flat = to_vec(&weights);
    // A 512-point FFT gives 256 columns per filter row (see `mel_banks_kaldi_shape`).
    let cols = 256;
    for m in 0..40 {
        let row_sum: f32 = flat[m * cols..(m + 1) * cols].iter().copied().sum();
        assert!(
            row_sum > 0.0,
            "mel bin {m} integrates to {row_sum}, expected > 0"
        );
    }
}
/// Filter centre frequencies must be strictly increasing and lie strictly between the 20 Hz
/// low cut and the 8 kHz Nyquist limit of 16 kHz audio.
#[test]
fn mel_banks_kaldi_center_freqs_monotone_increasing() {
    let (_, centers) = get_mel_banks_kaldi(40, 512, 16_000.0, 20.0, 0.0).unwrap();
    let c = to_vec(&centers);
    for w in c.windows(2) {
        assert!(
            w[1] > w[0],
            "center freqs must be monotone increasing: {} not > {}",
            w[1],
            w[0]
        );
    }
    assert!(c[0] > 20.0, "first center {} should exceed low_freq", c[0]);
    let last = c[c.len() - 1];
    assert!(
        last < 8000.0,
        "last center {} should be under Nyquist 8000",
        last
    );
}
/// Each malformed argument combination is rejected with `OutOfRange`.
#[test]
fn mel_banks_kaldi_rejects_invalid_args() {
    // (num_mels, n_fft_padded, sample_freq, low_freq, high_freq)
    let bad_args = [
        // Too few mel bins.
        (3, 512, 16_000.0, 0.0, 0.0),
        // FFT size that is not a power of two.
        (40, 513, 16_000.0, 0.0, 0.0),
        // Zero sample rate.
        (40, 512, 0.0, 0.0, 0.0),
        // Low cut above Nyquist.
        (40, 512, 16_000.0, 9000.0, 0.0),
        // Low cut above a Nyquist-relative (negative) high cut.
        (40, 512, 16_000.0, 9000.0, -100.0),
    ];
    for (num_mels, n_fft_padded, sample_freq, low_freq, high_freq) in bad_args {
        assert!(matches!(
            get_mel_banks_kaldi(num_mels, n_fft_padded, sample_freq, low_freq, high_freq),
            Err(Error::OutOfRange(_))
        ));
    }
}
/// `next_power_of_2` rounds up to a power of two, mapping 0 to 1.
#[test]
fn next_power_of_2_smoke() {
    let cases = [(0, 1), (1, 1), (2, 2), (3, 4), (400, 512), (1920, 2048)];
    for (input, expected) in cases {
        assert_eq!(next_power_of_2(input), expected);
    }
}
/// Generates `n_samples` of a unit-amplitude sine at `freq` Hz sampled at `sample_rate` Hz.
///
/// The phase `2*pi*freq*n/sample_rate` is evaluated in `f64`, and only each final sample is
/// narrowed to `f32`. Computing it in `f32` with the `f32` approximation of pi drifts by
/// roughly 1e-4 rad by sample 16 000 for a 1 kHz tone at 16 kHz.
fn sine_wave(freq: f32, sample_rate: u32, n_samples: usize) -> Vec<f32> {
    let omega = 2.0 * std::f64::consts::PI * f64::from(freq) / f64::from(sample_rate);
    (0..n_samples)
        .map(|n| (omega * n as f64).sin() as f32)
        .collect()
}
/// With `snip_edges = true`, 1 s of 16 kHz audio cut into 400-sample frames with a 160-sample
/// shift gives `1 + (16000 - 400) / 160 = 98` frames of 40 f32 mel features.
#[test]
fn compute_fbank_kaldi_output_shape() {
    let tone = sine_wave(1000.0, 16_000, 16_000);
    let wave = Array::from_slice::<f32>(&tone, &[16_000_i32]).unwrap();
    let feats = compute_fbank_kaldi(
        &wave, 16_000, 400, 160, 40, KaldiWindow::Hamming, 0.97, 0.0, true, 20.0, 0.0, None,
    )
    .unwrap();
    assert_eq!(feats.shape(), vec![98, 40]);
    assert_eq!(feats.dtype().unwrap(), Dtype::F32);
}
/// With `snip_edges = false` the frame count is `n / shift` rounded to the nearest integer,
/// and every feature must be finite.
#[test]
fn compute_fbank_kaldi_snip_edges_false_frame_count_and_finite() {
    let tone = sine_wave(1000.0, 16_000, 16_000);
    let wave = Array::from_slice::<f32>(&tone, &[16_000_i32]).unwrap();
    let out_false = compute_fbank_kaldi(
        &wave, 16_000, 400, 160, 40, KaldiWindow::Hamming, 0.97, 0.0, false, 20.0, 0.0, None,
    )
    .unwrap();

    let m_false: usize = (16_000 + 160 / 2) / 160;
    assert_eq!(
        out_false.shape(),
        vec![m_false, 40],
        "snip_edges=false frame count"
    );
    assert_eq!(m_false, 100);

    let all_finite = to_vec(&out_false).iter().all(|x| x.is_finite());
    assert!(all_finite, "snip_edges=false features must all be finite");
}
/// A steady 1 kHz tone must put its per-frame mel-energy peak on, or next to, the filter whose
/// centre frequency is closest to 1 kHz in at least 90% of the steady-state frames.
#[test]
fn compute_fbank_kaldi_known_signal_peaks_near_1khz() {
    let samples = sine_wave(1000.0, 16_000, 16_000);
    let x = Array::from_slice::<f32>(&samples, &[16_000_i32]).unwrap();
    let out = compute_fbank_kaldi(
        &x, 16_000, 400, 160, 80, KaldiWindow::Hamming, 0.97, 0.0, true, 20.0, 0.0, None,
    )
    .unwrap();
    let shape = out.shape();
    assert_eq!(shape.len(), 2);
    let num_frames = shape[0] as usize;
    let num_mels = shape[1] as usize;
    let v = to_vec(&out);
    // Rule out NaN/inf up front so the argmax below is meaningful.
    assert!(
        v.iter().all(|x| x.is_finite()),
        "fbank features must all be finite"
    );

    // Same bank parameters as the fbank call above. 512 is `next_power_of_2(400)`, presumably
    // the padded FFT size compute_fbank_kaldi derives from the 400-sample frame.
    let (_, centers) = get_mel_banks_kaldi(80, 512, 16_000.0, 20.0, 0.0).unwrap();
    let c = to_vec(&centers);
    let (closest_bin, _) = c
        .iter()
        .enumerate()
        .min_by(|(_, a), (_, b)| (*a - 1000.0).abs().total_cmp(&(*b - 1000.0).abs()))
        .unwrap();

    let mut hits = 0;
    let mut tries = 0;
    // Skip two frames at each edge so only steady-state frames are scored.
    for f in 2..num_frames.saturating_sub(2) {
        let row = &v[f * num_mels..(f + 1) * num_mels];
        let (argmax_bin, _) = row
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| (**a).total_cmp(*b))
            .unwrap();
        if argmax_bin.abs_diff(closest_bin) <= 1 {
            hits += 1;
        }
        tries += 1;
    }
    // Without this check, zero steady-state frames would satisfy the 90% check vacuously.
    assert!(
        tries > 0,
        "expected steady-state frames to inspect, got {num_frames} frames"
    );
    assert!(
        hits >= (tries * 9) / 10,
        "expected >= 90% of steady-state frames to peak near 1 kHz mel bin {closest_bin}: \
         got {hits}/{tries}"
    );
}
/// Pure silence must produce finite features that all sit at `ln(1e-8)` (within 1e-3).
#[test]
fn compute_fbank_kaldi_silence_is_finite() {
    let zeros = vec![0.0_f32; 4_000];
    let silence = Array::from_slice::<f32>(&zeros, &[4_000_i32]).unwrap();
    let feats = compute_fbank_kaldi(
        &silence, 16_000, 400, 160, 40, KaldiWindow::Hamming, 0.97, 0.0, true, 20.0, 0.0, None,
    )
    .unwrap();
    let v = to_vec(&feats);
    assert!(!v.is_empty());
    let want = (1e-8_f32).ln();
    for (i, x) in v.into_iter().enumerate() {
        assert!(x.is_finite(), "silence[{i}] = {x}: must be finite");
        assert!(
            (x - want).abs() < 1e-3,
            "silence[{i}] = {x}: must be log(1e-8) = {want}"
        );
    }
}
/// 100 samples cannot fill a single 400-sample frame, so the result is an empty `[0, 40]`.
#[test]
fn compute_fbank_kaldi_short_input_returns_empty() {
    let short = vec![0.0_f32; 100];
    let wave = Array::from_slice::<f32>(&short, &[100_i32]).unwrap();
    let feats = compute_fbank_kaldi(
        &wave, 16_000, 400, 160, 40, KaldiWindow::Hamming, 0.97, 0.0, true, 20.0, 0.0, None,
    )
    .unwrap();
    assert_eq!(feats.shape(), vec![0, 40]);
}
/// Every pair of the four window shapes must yield measurably different features.
#[test]
fn compute_fbank_kaldi_window_variants_differ() {
    let samples = sine_wave(1000.0, 16_000, 4_000);
    let x = Array::from_slice::<f32>(&samples, &[4_000_i32]).unwrap();
    let windows = [
        KaldiWindow::Hamming,
        KaldiWindow::Hanning,
        KaldiWindow::Povey,
        KaldiWindow::Rectangular,
    ];
    let feats = windows.map(|wt| {
        let f = compute_fbank_kaldi(
            &x, 16_000, 400, 160, 40, wt, 0.97, 0.0, true, 20.0, 0.0, None,
        )
        .unwrap();
        to_vec(&f)
    });
    for (i, lhs) in feats.iter().enumerate() {
        for (j, rhs) in feats.iter().enumerate().skip(i + 1) {
            let max_diff = lhs
                .iter()
                .zip(rhs)
                .map(|(a, b)| (a - b).abs())
                .fold(0.0_f32, f32::max);
            assert!(
                max_diff > 1e-4,
                "window variants {i} and {j} produced identical fbank features (max diff {max_diff})"
            );
        }
    }
}
/// Malformed inputs map to their specific error variants: a 2-D waveform, a zero sample rate,
/// a zero frame shift, negative dither, dither with no key, and a preemphasis of 1.5.
#[test]
fn compute_fbank_kaldi_rejects_invalid_args() {
    let zeros = vec![0.0_f32; 4_000];
    let x = Array::from_slice::<f32>(&zeros, &[4_000_i32]).unwrap();
    let two_d = Array::zeros::<f32>(&[2_i32, 100_i32]).unwrap();
    // Only the waveform, sample rate, frame shift, preemphasis and dither vary between cases;
    // everything else stays at the configuration used throughout these tests.
    let fbank = |wave: &Array, sample_rate, frame_shift, preemphasis, dither| {
        compute_fbank_kaldi(
            wave,
            sample_rate,
            400,
            frame_shift,
            40,
            KaldiWindow::Hamming,
            preemphasis,
            dither,
            true,
            20.0,
            0.0,
            None,
        )
    };

    // 2-D waveform.
    assert!(matches!(
        fbank(&two_d, 16_000, 160, 0.97, 0.0),
        Err(Error::RankMismatch(_))
    ));
    // Zero sample rate.
    assert!(matches!(
        fbank(&x, 0, 160, 0.97, 0.0),
        Err(Error::InvariantViolation(_))
    ));
    // Zero frame shift.
    assert!(matches!(
        fbank(&x, 16_000, 0, 0.97, 0.0),
        Err(Error::InvariantViolation(_))
    ));
    // Negative dither.
    assert!(matches!(
        fbank(&x, 16_000, 160, 0.97, -1.0),
        Err(Error::OutOfRange(_))
    ));
    // Positive dither with no RNG key supplied.
    assert!(matches!(
        fbank(&x, 16_000, 160, 0.97, 0.5),
        Err(Error::InvariantViolation(_))
    ));
    // Preemphasis coefficient of 1.5.
    assert!(matches!(
        fbank(&x, 16_000, 160, 1.5, 0.0),
        Err(Error::OutOfRange(_))
    ));
}
/// On a linear ramp, preemphasis 0.97 must move some feature by more than 1e-2 relative to
/// preemphasis 0.0.
#[test]
fn compute_fbank_kaldi_preemphasis_is_applied() {
    let ramp: Vec<f32> = (0..4_000).map(|i| (i as f32) / 4_000.0).collect();
    let x = Array::from_slice::<f32>(&ramp, &[4_000_i32]).unwrap();
    let fbank_with = |preemphasis| {
        let feats = compute_fbank_kaldi(
            &x,
            16_000,
            400,
            160,
            40,
            KaldiWindow::Hamming,
            preemphasis,
            0.0,
            true,
            20.0,
            0.0,
            None,
        )
        .unwrap();
        to_vec(&feats)
    };
    let no_pe = fbank_with(0.0);
    let with_pe = fbank_with(0.97);
    let max_diff = no_pe
        .iter()
        .zip(&with_pe)
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f32, f32::max);
    assert!(
        max_diff > 1e-2,
        "preemphasis=0.97 must change the fbank features vs preemphasis=0.0 (max diff {max_diff})"
    );
}
/// Keyed dither is reproducible. Two keys built from the same seed give identical features, and
/// a different seed gives visibly different features.
#[test]
fn compute_fbank_kaldi_dither_keyed_is_deterministic() {
    let samples = sine_wave(440.0, 16_000, 4_000);
    let x = Array::from_slice::<f32>(&samples, &[4_000_i32]).unwrap();
    let key_a = ops::random::key(0xA5A5_A5A5).unwrap();
    let key_b = ops::random::key(0x5A5A_5A5A).unwrap();
    // A separate key object from the same seed, so determinism is not just reuse of `key_a`.
    let key_a_again = ops::random::key(0xA5A5_A5A5).unwrap();

    // Evaluated in order: key_a, key_a_again, key_b.
    let [feats_a, feats_a2, feats_b] = [&key_a, &key_a_again, &key_b].map(|key| {
        to_vec(
            &compute_fbank_kaldi(
                &x,
                16_000,
                400,
                160,
                40,
                KaldiWindow::Hamming,
                0.97,
                0.1,
                true,
                20.0,
                0.0,
                Some(key),
            )
            .unwrap(),
        )
    });

    assert_eq!(feats_a.len(), feats_a2.len());
    for (i, (a, a2)) in feats_a.iter().zip(feats_a2.iter()).enumerate() {
        assert!(
            (a - a2).abs() < 1e-5,
            "same key must produce identical output at [{i}]: {a} vs {a2}"
        );
    }

    // Without this check, `zip` would silently truncate a length mismatch.
    assert_eq!(feats_a.len(), feats_b.len());
    let max_diff = feats_a
        .iter()
        .zip(feats_b.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0_f32, f32::max);
    assert!(
        max_diff > 1e-3,
        "different keys must produce different features (max diff {max_diff})"
    );
}
/// 2-sample frames with a 1-sample shift over 128 samples give 127 frames. With 2^20 mels per
/// frame, the output must be refused with an error naming the output element count.
#[test]
fn compute_fbank_kaldi_output_element_cap_rejects_large_matmul() {
    let zeros = vec![0.0_f32; 128];
    let x = Array::from_slice::<f32>(&zeros, &[128_i32]).unwrap();
    let num_mels = 1 << 20;
    let err = compute_fbank_kaldi(
        &x, 16_000, 2, 1, num_mels, KaldiWindow::Rectangular, 0.0, 0.0, true, 0.0, 0.0, None,
    )
    .expect_err("expected output-element cap to reject pathological num_mels");
    let msg = format!("{err:?}");
    assert!(
        msg.contains("output element count"),
        "expected error to mention the output-element cap, got: {msg}"
    );
}
/// Features from a sliced sub-range of a larger array must match features from an owned copy
/// of the same samples.
#[test]
fn compute_fbank_kaldi_sliced_waveform_matches_contiguous() {
    let full = sine_wave(1_000.0, 16_000, 18_000);
    let full_arr = Array::from_slice::<f32>(&full, &[18_000_i32]).unwrap();
    let sliced = full_arr.slice(&[1_000], &[17_000], &[1]).unwrap();
    assert_eq!(sliced.shape(), vec![16_000]);
    let contig = Array::from_slice::<f32>(&full[1_000..17_000], &[16_000_i32]).unwrap();

    let fbank = |wave: &Array| {
        let feats = compute_fbank_kaldi(
            wave, 16_000, 400, 160, 40, KaldiWindow::Hamming, 0.97, 0.0, true, 20.0, 0.0,
            None,
        )
        .unwrap();
        to_vec(&feats)
    };
    let from_sliced = fbank(&sliced);
    let from_contig = fbank(&contig);

    assert_eq!(from_sliced.len(), from_contig.len());
    for (i, (a, b)) in from_sliced.iter().zip(&from_contig).enumerate() {
        assert!(
            (a - b).abs() < 1e-3,
            "sliced[{i}] = {a} vs contig[{i}] = {b}: must match within 1e-3"
        );
    }
}
/// A 4000-sample broadcast of the scalar 0.5 must give the same features as an owned buffer of
/// 4000 copies of 0.5.
#[test]
fn compute_fbank_kaldi_broadcasted_scalar_waveform_matches_contiguous() {
    let one = Array::from_slice::<f32>(&[0.5_f32], &[1_i32]).unwrap();
    let bcast = crate::ops::shape::broadcast_to(&one, &[4_000_i32]).unwrap();
    assert_eq!(bcast.shape(), vec![4_000]);
    let constant_buf = vec![0.5_f32; 4_000];
    let contig = Array::from_slice::<f32>(&constant_buf, &[4_000_i32]).unwrap();

    let fbank = |wave: &Array| {
        let feats = compute_fbank_kaldi(
            wave, 16_000, 400, 160, 40, KaldiWindow::Hamming, 0.97, 0.0, true, 20.0, 0.0,
            None,
        )
        .unwrap();
        to_vec(&feats)
    };
    let from_bcast = fbank(&bcast);
    let from_contig = fbank(&contig);

    assert_eq!(from_bcast.len(), from_contig.len());
    for (i, (a, b)) in from_bcast.iter().zip(&from_contig).enumerate() {
        assert!(
            (a - b).abs() < 1e-3,
            "bcast[{i}] = {a} vs contig[{i}] = {b}: must match within 1e-3"
        );
    }
}
/// Pins the Kaldi preemphasis convention `y[0] = c[0] * (1 - p)` on a tiny mean-removed
/// signal. It also shows that this convention differs observably from the mlx-audio convention
/// `y[0] = c[0]`, then runs the same signal through `compute_fbank_kaldi`.
///
/// NOTE(review): the end-to-end call only checks that the output is finite; it does not
/// compare the fbank values against the Kaldi closed form.
#[test]
fn compute_fbank_kaldi_preemphasis_first_sample_matches_kaldi() {
    let input = [2.0_f32, 1.0, 0.5, 0.25];
    let mean = (input[0] + input[1] + input[2] + input[3]) / 4.0;
    let centered: Vec<f32> = input.iter().map(|x| x - mean).collect();
    let p = 0.5_f32;
    // Shared recurrence `y[n] = c[n] - p * c[n - 1]` for n >= 1; the caller picks `y[0]`.
    let preemphasize = |first: f32| {
        let mut y = [first, 0.0_f32, 0.0, 0.0];
        for (n, slot) in y.iter_mut().enumerate().skip(1) {
            *slot = centered[n] - p * centered[n - 1];
        }
        y
    };

    let kaldi_sum: f32 = preemphasize(centered[0] * (1.0 - p)).iter().sum();
    assert!(
        (kaldi_sum - (-0.875)).abs() < 1e-5,
        "Kaldi closed-form check: y-sum = {kaldi_sum}, want -0.875"
    );

    let mlx_audio_sum: f32 = preemphasize(centered[0]).iter().sum();
    assert!(
        (mlx_audio_sum - (-0.34375)).abs() < 1e-5,
        "mlx-audio closed-form sentinel: y-sum = {mlx_audio_sum}, want -0.34375 \
         (this assertion exists to prove the Kaldi vs mlx-audio distinction is observable)"
    );

    // A single 4-sample rectangular frame with 4 mels.
    let x = Array::from_slice::<f32>(&input, &[4_i32]).unwrap();
    let out = compute_fbank_kaldi(
        &x, 16_000, 4, 4, 4, KaldiWindow::Rectangular, p, 0.0, true, 0.0, 0.0, None,
    )
    .unwrap();
    assert_eq!(out.shape(), vec![1, 4]);
    for (i, val) in to_vec(&out).into_iter().enumerate() {
        assert!(
            val.is_finite(),
            "compute_fbank_kaldi[{i}] = {val}: must be finite under Kaldi preemphasis"
        );
    }
}
/// A 100 million-sample broadcast view must be rejected by the `samples_len` cap.
#[test]
fn compute_fbank_kaldi_samples_len_cap_rejects_huge_broadcast() {
    let one = Array::from_slice::<f32>(&[0.5_f32], &[1_i32]).unwrap();
    let bcast = crate::ops::shape::broadcast_to(&one, &[100_000_000_i32]).unwrap();
    assert_eq!(bcast.shape(), vec![100_000_000]);
    let err = compute_fbank_kaldi(
        &bcast,
        16_000,
        2,
        50_000_000,
        1,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        true,
        0.0,
        0.0,
        None,
    )
    .expect_err(
        "expected samples_len cap to reject a 100 Mi broadcasted waveform \
         BEFORE `contiguous` would materialize the logical extent",
    );
    let msg = format!("{err:?}");
    let names_cap = msg.contains("samples_len") && msg.contains("MAX_DECODED_SAMPLES");
    assert!(
        names_cap,
        "expected error to mention samples_len cap + MAX_DECODED_SAMPLES, got: {msg}"
    );
}
/// 64 Mi mels with a 2-point padded FFT pass the unpadded-bank cap exactly. Padding doubles
/// the operand, so the padded mel-bank cap must then reject it.
#[test]
fn compute_fbank_kaldi_padded_mel_bank_cap_rejects_doubled_operand() {
    let samples = vec![0.0_f32; 2];
    let x = Array::from_slice::<f32>(&samples, &[2_i32]).unwrap();
    let num_mels = 64 * 1024 * 1024;
    let err = compute_fbank_kaldi(
        &x, 16_000, 2, 1, num_mels, KaldiWindow::Rectangular, 0.0, 0.0, true, 0.0, 0.0, None,
    )
    .expect_err(
        "expected padded-mel-bank cap to reject 64 Mi mels with n_fft_padded=2 \
         (unpadded bank passes at-cap, padded operand doubles to 128 Mi)",
    );
    let msg = format!("{err:?}");
    assert!(
        msg.contains("padded mel-bank element count"),
        "expected error to mention the padded mel-bank cap, got: {msg}"
    );
}
/// A `MAX_FBANK_WORK`-sample waveform with `snip_edges = false` must be refused by the
/// reflect-buffer cap before reflect padding doubles it.
#[test]
fn compute_fbank_kaldi_snip_edges_false_reflect_buffer_cap_rejects_doubled_waveform() {
    let samples_len = MAX_FBANK_WORK;
    let len_i32 = i32::try_from(samples_len).unwrap();
    let x = Array::zeros::<f32>(&[len_i32]).unwrap();
    let err = compute_fbank_kaldi(
        &x, 16_000, 2, 4, 4, KaldiWindow::Rectangular, 0.0, 0.0, false, 0.0, 0.0, None,
    )
    .expect_err(
        "expected the reflect-buffer cap to reject a 64 Mi snip_edges=false \
         waveform BEFORE the reflect bookends double it to ~128 Mi",
    );
    let msg = format!("{err:?}");
    let names_cap = msg.contains("reflect-padded buffer length") && msg.contains("work cap");
    assert!(
        names_cap,
        "expected error to mention the reflect-padded buffer cap, got: {msg}"
    );
}
/// Window 2^20 with shift 2^20 - 2 is the pad == 1 branch. The true reflected buffer there is
/// `2 * n`, which exceeds the cap even though `n + 2` does not.
#[test]
fn compute_fbank_kaldi_snip_edges_false_reflect_buffer_cap_rejects_pad_one_undercount() {
    let samples_len = MAX_FBANK_WORK - 2;
    let len_i32 = i32::try_from(samples_len).unwrap();
    let x = Array::zeros::<f32>(&[len_i32]).unwrap();
    let err = compute_fbank_kaldi(
        &x,
        16_000,
        1_048_576,
        1_048_574,
        4,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        false,
        0.0,
        0.0,
        None,
    )
    .expect_err(
        "expected the per-branch reflect-buffer cap to reject a pad==1 waveform \
         whose true 2*n reflected buffer exceeds the cap (n + 2 is within it)",
    );
    let msg = format!("{err:?}");
    let names_cap = msg.contains("reflect-padded buffer length") && msg.contains("work cap");
    assert!(
        names_cap,
        "expected error to mention the reflect-padded buffer cap, got: {msg}"
    );
}
/// The `snip_edges = false` framer refuses a `MAX_FBANK_WORK`-sample waveform whose reflected
/// copy would exceed the cap, yet still frames a 10-sample signal normally.
#[test]
fn strided_no_snip_edges_rejects_oversized_reflect_buffer() {
    let n = MAX_FBANK_WORK;
    let oversized = Array::zeros::<f32>(&[i32::try_from(n).unwrap()]).unwrap();
    let num_frames = (n + 4 / 2) / 4;
    let err = strided_frames_no_snip_edges(&oversized, 2, 4, num_frames)
        .expect_err("expected the reflect-buffer cap to reject a doubled 64 Mi waveform");
    let msg = format!("{err:?}");
    assert!(
        msg.contains("reflect-padded buffer length"),
        "expected a reflect-padded buffer cap error, got: {msg}"
    );

    // Ordinary inputs are unaffected by the cap.
    let ramp: Vec<f32> = (0..10).map(|v| v as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[10]).unwrap();
    let m = (10 + 2 / 2) / 2;
    let ok = strided_frames_no_snip_edges(&x, 4, 2, m).unwrap();
    assert_eq!(
        ok.shape(),
        vec![5, 4],
        "normal snip_edges=false framing still works"
    );
}
/// In the pad == 1 branch the reflected buffer is `2 * n`. The cap must use that figure, not
/// the smaller `n + 2 * pad` estimate.
#[test]
fn strided_no_snip_edges_pad_one_rejects_undercounted_reflect_buffer() {
    let n = MAX_FBANK_WORK - 2;
    // Preconditions: the naive estimate fits under the cap, the true size does not.
    assert!(
        n + 2 <= MAX_FBANK_WORK,
        "the bug's `n + 2*pad` estimate must be within the cap"
    );
    assert!(
        n.checked_mul(2).unwrap() > MAX_FBANK_WORK,
        "the actual `2*n` pad==1 reflected buffer must exceed the cap"
    );

    let big = Array::zeros::<f32>(&[i32::try_from(n).unwrap()]).unwrap();
    let num_frames = (n + 2 / 2) / 2;
    let err = strided_frames_no_snip_edges(&big, 4, 2, num_frames).expect_err(
        "expected the per-branch cap to reject a pad==1 waveform whose true 2*n \
         reflected buffer exceeds the cap (even though n + 2 is within it)",
    );
    let msg = format!("{err:?}");
    let names_cap = msg.contains("reflect-padded buffer length") && msg.contains("work cap");
    assert!(
        names_cap,
        "expected a reflect-padded buffer cap error, got: {msg}"
    );
}
/// Window 4 with shift 2 over 8 samples (pad == 1) gives four frames. The first and last frames
/// mirror one sample at each edge.
#[test]
fn strided_no_snip_edges_pad_one_small_input_correct_frames() {
    let ramp: Vec<f32> = (0..8).map(|v| v as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[8]).unwrap();
    let m = (8 + 2 / 2) / 2;
    let frames = strided_frames_no_snip_edges(&x, 4, 2, m).unwrap();
    assert_eq!(frames.shape(), vec![4, 4]);
    let expected = [
        [1.0_f32, 0.0, 1.0, 2.0],
        [1.0, 2.0, 3.0, 4.0],
        [3.0, 4.0, 5.0, 6.0],
        [5.0, 6.0, 7.0, 7.0],
    ];
    assert_eq!(
        to_vec_2d(&frames, 4, 4),
        expected,
        "pad==1 small-input snip_edges=false frames mismatch"
    );
}
/// Copies `a`'s elements, in contiguous order, into `rows` owned rows of `cols` values each.
///
/// Panics if `a` does not hold exactly `rows * cols` elements.
fn to_vec_2d(a: &Array, rows: usize, cols: usize) -> Vec<Vec<f32>> {
    // `to_vec` already materializes a contiguous copy; a separate `contiguous` pass here would
    // only duplicate that work.
    let flat = to_vec(a);
    assert_eq!(flat.len(), rows * cols, "to_vec_2d shape mismatch");
    (0..rows)
        .map(|r| flat[r * cols..(r + 1) * cols].to_vec())
        .collect()
}
/// Checks hand-computed deltas for `win_length = 5` with edge padding. The normaliser there is
/// `2 * (1^2 + 2^2) = 10`.
#[test]
fn compute_deltas_kaldi_win5_edge_matches_reference() {
    let x = Array::from_slice::<f32>(
        &[1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0, 1.0, 0.0, 0.0],
        &[2, 5],
    )
    .unwrap();
    let d = compute_deltas_kaldi(&x, 5, DeltaPadMode::Edge).unwrap();
    assert_eq!(d.shape(), vec![2, 5]);
    let got = to_vec_2d(&d, 2, 5);
    let want = [[0.5_f32, 0.8, 1.0, 0.8, 0.5], [0.2, 0.1, 0.0, -0.1, -0.2]];
    for (r, (got_row, want_row)) in got.iter().zip(want.iter()).enumerate() {
        for (c, (&g, &w)) in got_row.iter().zip(want_row.iter()).enumerate() {
            assert!(
                (g - w).abs() < F32_TOL,
                "delta[{r}][{c}]: got {}, want {}",
                g,
                w
            );
        }
    }
}
/// Checks hand-computed deltas for `win_length = 3` with zero (constant) padding. The
/// normaliser there is `2 * 1^2 = 2`.
#[test]
fn compute_deltas_kaldi_win3_constant_matches_reference() {
    let x =
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 5]).unwrap();
    let d = compute_deltas_kaldi(&x, 3, DeltaPadMode::Constant).unwrap();
    // `to_vec_2d` only checks the element count, so a transposed result needs this check.
    assert_eq!(d.shape(), vec![2, 5]);
    let got = to_vec_2d(&d, 2, 5);
    let want = [[1.0_f32, 1.0, 1.0, 1.0, -2.0], [0.0, 0.5, 0.0, -0.5, 0.0]];
    for r in 0..2 {
        for c in 0..5 {
            assert!(
                (got[r][c] - want[r][c]).abs() < F32_TOL,
                "delta[{r}][{c}]: got {}, want {}",
                got[r][c],
                want[r][c]
            );
        }
    }
}
/// A unit-slope ramp has delta 1 wherever the 5-tap window stays inside the signal.
#[test]
fn compute_deltas_kaldi_1d_ramp_interior_is_unit_slope() {
    let ramp: Vec<f32> = (0..8).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[8]).unwrap();
    let d = compute_deltas_kaldi(&x, 5, DeltaPadMode::Edge).unwrap();
    assert_eq!(d.shape(), vec![8]);
    // Indices 2..6 are at least two samples away from either edge.
    for (i, &g) in to_vec(&d).iter().enumerate().skip(2).take(4) {
        assert!((g - 1.0).abs() < F32_TOL, "ramp delta[{i}]: got {g}, want 1.0");
    }
}
/// The second difference of a ramp vanishes away from the edges that padding disturbs.
#[test]
fn compute_deltas_kaldi_delta_delta_is_zero_for_ramp_interior() {
    let ramp: Vec<f32> = (0..12).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[12]).unwrap();
    let first = compute_deltas_kaldi(&x, 3, DeltaPadMode::Edge).unwrap();
    let second = compute_deltas_kaldi(&first, 3, DeltaPadMode::Edge).unwrap();
    for (i, &g) in to_vec(&second).iter().enumerate().skip(2).take(8) {
        assert!(g.abs() < F32_TOL, "ramp delta-delta[{i}]: got {g}, want ~0");
    }
}
/// Even window lengths (2 and 4) are rejected as out of range.
#[test]
fn compute_deltas_kaldi_rejects_invalid_win_length() {
    let x = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap();
    for win_length in [2, 4] {
        assert!(matches!(
            compute_deltas_kaldi(&x, win_length, DeltaPadMode::Edge),
            Err(Error::OutOfRange(_))
        ));
    }
}
/// A huge odd `win_length` on a 1-element input is capped in both pad modes. A normal window
/// still works and gives a zero delta.
#[test]
fn compute_deltas_kaldi_rejects_huge_win_length_on_tiny_input() {
    let x = Array::from_slice::<f32>(&[1.0], &[1]).unwrap();
    let huge = 4_000_001_usize;
    assert!(!huge.is_multiple_of(2), "win_length must be odd");
    for mode in [DeltaPadMode::Edge, DeltaPadMode::Constant] {
        let capped = matches!(
            compute_deltas_kaldi(&x, huge, mode),
            Err(Error::CapExceeded(_))
        );
        assert!(
            capped,
            "huge win_length on a 1-element input must be rejected ({mode:?})"
        );
    }

    let ok = compute_deltas_kaldi(&x, 5, DeltaPadMode::Edge).unwrap();
    assert_eq!(ok.shape(), vec![1], "shape preserved for the tiny input");
    let only = to_vec(&ok)[0];
    assert!(
        only.abs() < F32_TOL,
        "single-value edge-padded delta should be 0, got {}",
        only
    );
}
/// The unpadded input already sits at `MAX_FBANK_WORK` elements, so padding it for
/// `win_length = 5` must trip the cap.
#[test]
fn compute_deltas_kaldi_rejects_padded_work_over_cap() {
    let num_features = MAX_FBANK_WORK;
    let rows = i32::try_from(num_features).unwrap();
    let x = Array::zeros::<f32>(&[rows, 1]).unwrap();
    let result = compute_deltas_kaldi(&x, 5, DeltaPadMode::Edge);
    assert!(
        matches!(result, Err(Error::CapExceeded(_))),
        "padded work exceeding the cap must be rejected before allocating"
    );
}
/// The inputs pass both the `win_length` cap and the padded-length cap, yet
/// `len * (win_length - 1)` exceeds `MAX_DELTA_WORK`, so the cumulative-work cap must reject
/// them.
#[test]
fn compute_deltas_kaldi_rejects_cumulative_work_over_cap() {
    let win_length = 1023_usize;
    // Half-width of the window; `len + 2 * n` below is the padded length.
    let n = (win_length - 1) / 2;
    // Chosen so the padded length lands exactly on MAX_FBANK_WORK.
    let len = MAX_FBANK_WORK - 2 * n;

    assert!(!win_length.is_multiple_of(2), "win_length must be odd");
    assert!(
        win_length <= MAX_DELTA_WIN_LENGTH,
        "win_length must clear the win_length cap so the cumulative cap is reached"
    );
    assert!(len <= MAX_FBANK_WORK, "total must pass the total cap");
    assert_eq!(
        len + 2 * n,
        MAX_FBANK_WORK,
        "padded_work is at-cap (passes)"
    );
    assert!(
        len.checked_mul(win_length - 1).unwrap() > MAX_DELTA_WORK,
        "delta_work must exceed the cumulative-work cap"
    );

    let x = Array::zeros::<f32>(&[i32::try_from(len).unwrap()]).unwrap();
    let err = compute_deltas_kaldi(&x, win_length, DeltaPadMode::Edge).expect_err(
        "expected the cumulative-work cap to reject total * (win_length - 1) \
         BEFORE the per-offset accumulation loop",
    );
    let msg = format!("{err:?}");
    let names_cap = msg.contains("accumulation work") && msg.contains("work cap");
    assert!(
        names_cap,
        "expected the cumulative accumulation-work cap error, got: {msg}"
    );

    // Ordinary inputs still pass.
    let small =
        Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0, 5.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 5]).unwrap();
    let ok = compute_deltas_kaldi(&small, 5, DeltaPadMode::Edge)
        .expect("a normal win_length=5 must pass the cumulative-work cap");
    assert_eq!(
        ok.shape(),
        vec![2, 5],
        "normal win_length=5 deltas still work"
    );
}
/// Window 4 with shift 2 over 10 samples gives five frames. The first and last frames pull in
/// mirrored edge samples.
#[test]
fn strided_no_snip_edges_win4_shift2_boundary_values() {
    let ramp: Vec<f32> = (0..10).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[10]).unwrap();
    let m = (10 + 2 / 2) / 2;
    let frames = strided_frames_no_snip_edges(&x, 4, 2, m).unwrap();
    assert_eq!(frames.shape(), vec![5, 4]);
    let expected = [
        [1.0_f32, 0.0, 1.0, 2.0],
        [1.0, 2.0, 3.0, 4.0],
        [3.0, 4.0, 5.0, 6.0],
        [5.0, 6.0, 7.0, 8.0],
        [7.0, 8.0, 9.0, 9.0],
    ];
    assert_eq!(
        to_vec_2d(&frames, 5, 4),
        expected,
        "snip_edges=false win4 shift2 frames mismatch"
    );
}
/// Window 6 with shift 2 mirrors two samples at each edge.
#[test]
fn strided_no_snip_edges_win6_shift2_left_reflect_bookend() {
    let ramp: Vec<f32> = (0..10).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[10]).unwrap();
    let frames = strided_frames_no_snip_edges(&x, 6, 2, 5).unwrap();
    assert_eq!(frames.shape(), vec![5, 6]);
    let rows = to_vec_2d(&frames, 5, 6);
    let (first, last) = (&rows[0], &rows[4]);
    assert_eq!(
        *first,
        vec![2.0, 1.0, 0.0, 1.0, 2.0, 3.0],
        "left reflect bookend (pad=2) mismatch"
    );
    assert_eq!(
        *last,
        vec![6.0, 7.0, 8.0, 9.0, 9.0, 8.0],
        "right reflect bookend (pad=2) mismatch"
    );
}
/// When window equals shift there is no left pad. Only the last frame mirrors its tail.
#[test]
fn strided_no_snip_edges_pad_zero_path() {
    let ramp: Vec<f32> = (0..10).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[10]).unwrap();
    let m = (10 + 4 / 2) / 4;
    let frames = strided_frames_no_snip_edges(&x, 4, 4, m).unwrap();
    assert_eq!(frames.shape(), vec![3, 4]);
    let expected = [
        [0.0_f32, 1.0, 2.0, 3.0],
        [4.0, 5.0, 6.0, 7.0],
        [8.0, 9.0, 9.0, 8.0],
    ];
    assert_eq!(
        to_vec_2d(&frames, 3, 4),
        expected,
        "snip_edges=false pad<=0 path frames mismatch"
    );
}
/// On 10 samples with window 4 and shift 2, `snip_edges = false` gives one more frame than
/// `snip_edges = true`.
#[test]
fn strided_no_snip_edges_produces_extra_frame_vs_snip_true() {
    let ramp: Vec<f32> = (0..10).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[10]).unwrap();
    // snip_edges=true: whole frames only. snip_edges=false: n / shift rounded to nearest.
    let m_true = 1 + (10 - 4) / 2;
    let m_false = (10 + 2 / 2) / 2;
    assert_eq!(
        m_false,
        m_true + 1,
        "snip=false should yield one extra frame"
    );
    let snipped = strided_frames_snip_edges(&x, 4, 2, m_true).unwrap();
    let padded = strided_frames_no_snip_edges(&x, 4, 2, m_false).unwrap();
    assert_eq!(snipped.shape(), vec![4, 4]);
    assert_eq!(padded.shape(), vec![5, 4]);
}
/// 8-sample frames over a 5-sample signal must be rejected as an overread / too-short signal,
/// with a diagnostic that mentions the reflect pad.
#[test]
fn strided_no_snip_edges_rejects_degenerate_overread() {
    let ramp: Vec<f32> = (0..5).map(|n| n as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[5]).unwrap();
    let err = strided_frames_no_snip_edges(&x, 8, 2, 3)
        .expect_err("expected degenerate overread to be rejected");
    let payload = match &err {
        Error::OutOfRange(p) => p,
        _ => panic!("expected OutOfRange overread/short-signal error, got: {err:?}"),
    };
    let mentions_reflect_pad =
        payload.context().contains("reflect-pad") || payload.requirement().contains("reflect-pad");
    assert!(
        mentions_reflect_pad,
        "expected an overread/short-signal error referencing reflect-pad, got: context={:?}, requirement={:?}",
        payload.context(),
        payload.requirement()
    );
}
/// `as_str` and `Display` both give the lowercase canonical name, and the default window is
/// Hamming.
#[test]
fn kaldi_window_as_str_and_display_are_canonical() {
    let cases = [
        (KaldiWindow::Hamming, "hamming"),
        (KaldiWindow::Hanning, "hanning"),
        (KaldiWindow::Povey, "povey"),
        (KaldiWindow::Rectangular, "rectangular"),
    ];
    for (w, s) in cases {
        assert_eq!(w.as_str(), s, "{w:?}.as_str()");
        assert_eq!(w.to_string(), s, "{w:?} Display");
    }
    assert_eq!(KaldiWindow::default(), KaldiWindow::Hamming);
}
/// Each `is_*` predicate holds for exactly one `KaldiWindow` variant. All 16
/// variant/predicate combinations are checked.
#[test]
fn kaldi_window_variant_predicates() {
    assert!(KaldiWindow::Hamming.is_hamming());
    assert!(!KaldiWindow::Hamming.is_hanning());
    assert!(!KaldiWindow::Hamming.is_povey());
    assert!(!KaldiWindow::Hamming.is_rectangular());

    assert!(!KaldiWindow::Hanning.is_hamming());
    assert!(KaldiWindow::Hanning.is_hanning());
    assert!(!KaldiWindow::Hanning.is_povey());
    assert!(!KaldiWindow::Hanning.is_rectangular());

    assert!(!KaldiWindow::Povey.is_hamming());
    assert!(!KaldiWindow::Povey.is_hanning());
    assert!(KaldiWindow::Povey.is_povey());
    assert!(!KaldiWindow::Povey.is_rectangular());

    assert!(!KaldiWindow::Rectangular.is_hamming());
    assert!(!KaldiWindow::Rectangular.is_hanning());
    assert!(!KaldiWindow::Rectangular.is_povey());
    assert!(KaldiWindow::Rectangular.is_rectangular());
}
/// `DeltaPadMode` names, `Display`, predicates and default. Each predicate is checked against
/// both variants.
#[test]
fn delta_pad_mode_as_str_and_display_and_predicates() {
    assert_eq!(DeltaPadMode::Edge.as_str(), "edge");
    assert_eq!(DeltaPadMode::Constant.as_str(), "constant");
    assert_eq!(DeltaPadMode::Edge.to_string(), "edge");
    assert_eq!(DeltaPadMode::Constant.to_string(), "constant");
    assert!(DeltaPadMode::Edge.is_edge());
    assert!(!DeltaPadMode::Edge.is_constant());
    assert!(DeltaPadMode::Constant.is_constant());
    assert!(!DeltaPadMode::Constant.is_edge());
    assert_eq!(DeltaPadMode::default(), DeltaPadMode::Edge);
}
/// A positive `high_freq` is used as an absolute upper edge. The top centre must stay below
/// 4 kHz but still land well above 1.5 kHz.
#[test]
fn mel_banks_kaldi_positive_high_freq_used_verbatim() {
    let (_, centers) = get_mel_banks_kaldi(40, 512, 16_000.0, 20.0, 4000.0).unwrap();
    let c = to_vec(&centers);
    let top = c[c.len() - 1];
    assert!(
        top < 4000.0,
        "explicit high_freq=4000 must cap the top center below 4000, got {}",
        top
    );
    assert!(
        top > 1500.0,
        "top center {} should approach the explicit high_freq",
        top
    );
}
/// A 9 kHz high cut exceeds the 8 kHz Nyquist of 16 kHz audio and must hit the `high_freq`
/// range check.
#[test]
fn mel_banks_kaldi_rejects_high_freq_above_nyquist() {
    let err = get_mel_banks_kaldi(40, 512, 16_000.0, 20.0, 9000.0)
        .expect_err("high_freq above nyquist must be rejected");
    let p = match &err {
        Error::OutOfRange(p) => p,
        _ => panic!("expected OutOfRange, got {err:?}"),
    };
    assert_eq!(
        p.context(),
        "get_mel_banks_kaldi: high_freq",
        "expected the high_freq range branch, got context {:?}",
        p.context()
    );
}
/// Both cut-offs are individually in range, but `low_freq >= high_freq` must still be rejected.
#[test]
fn mel_banks_kaldi_rejects_low_freq_ge_high_freq_both_valid() {
    let err = get_mel_banks_kaldi(40, 512, 16_000.0, 4000.0, 2000.0)
        .expect_err("low_freq >= high_freq must be rejected");
    let p = match &err {
        Error::OutOfRange(p) => p,
        _ => panic!("expected OutOfRange, got {err:?}"),
    };
    assert_eq!(
        p.context(),
        "get_mel_banks_kaldi: low_freq",
        "expected the low>=high branch context"
    );
    assert_eq!(p.requirement(), "must be < high_freq");
}
/// A zero FFT size is rejected with the `n_fft_padded` context.
#[test]
fn mel_banks_kaldi_rejects_zero_n_fft_padded() {
    let err = get_mel_banks_kaldi(40, 0, 16_000.0, 20.0, 0.0)
        .expect_err("zero n_fft_padded must be rejected");
    match &err {
        Error::OutOfRange(p) => assert_eq!(p.context(), "get_mel_banks_kaldi: n_fft_padded"),
        _ => panic!("expected OutOfRange, got {err:?}"),
    }
}
/// NaN, +inf and negative sample rates are all rejected as out of range.
#[test]
fn mel_banks_kaldi_rejects_nonfinite_sample_freq() {
    let bad_rates = [f32::NAN, f32::INFINITY, -1.0_f32];
    for bad in bad_rates {
        let err = get_mel_banks_kaldi(40, 512, bad, 20.0, 0.0)
            .expect_err("non-finite/negative sample_freq must be rejected");
        let is_out_of_range = matches!(err, Error::OutOfRange(_));
        assert!(
            is_out_of_range,
            "sample_freq={bad} expected OutOfRange, got {err:?}"
        );
    }
}
/// 4096 mels over a 65 536-point FFT exceed `MAX_FBANK_WORK`. The error must name that cap and
/// report an observed size above it.
#[test]
fn mel_banks_kaldi_cap_rejects_oversized_bank() {
    let err = get_mel_banks_kaldi(4096, 65_536, 16_000.0, 20.0, 0.0)
        .expect_err("oversized bank_len must be capped");
    let p = match &err {
        Error::CapExceeded(p) => p,
        _ => panic!("expected CapExceeded, got {err:?}"),
    };
    assert_eq!(p.cap_name(), "MAX_FBANK_WORK");
    let (observed, cap) = (p.observed(), p.cap());
    assert!(
        observed > cap,
        "observed {} must exceed cap {}",
        observed,
        cap
    );
}
/// Window sizes 0 and 1 are rejected with the `win_size` context.
#[test]
fn build_kaldi_window_rejects_too_small() {
    for bad in [0_usize, 1] {
        let err =
            build_kaldi_window(KaldiWindow::Hamming, bad).expect_err("win_size < 2 must be rejected");
        match &err {
            Error::OutOfRange(p) => assert_eq!(p.context(), "build_kaldi_window: win_size"),
            _ => panic!("expected OutOfRange, got {err:?}"),
        }
    }
}
/// The 4-point Povey window is expected to be `[0, 0.75^0.85, 0.75^0.85, 0]`.
#[test]
fn build_kaldi_window_povey_closed_form() {
    let w = build_kaldi_window(KaldiWindow::Povey, 4).unwrap();
    assert_eq!(w.shape(), vec![4]);
    let interior = 0.75_f32.powf(0.85);
    let want = [0.0_f32, interior, interior, 0.0];
    let got = to_vec(&w);
    for (i, (&g, &e)) in got.iter().zip(&want).enumerate() {
        assert!((g - e).abs() < F32_TOL, "povey[{i}]: got {g}, want {e}");
    }
}
/// The rectangular window is all ones.
#[test]
fn build_kaldi_window_rectangular_is_all_ones() {
    let window = build_kaldi_window(KaldiWindow::Rectangular, 5).unwrap();
    assert_eq!(to_vec(&window), vec![1.0_f32; 5]);
}
/// The 5-point Hanning window is expected to be `[0, 0.5, 1, 0.5, 0]`.
#[test]
fn build_kaldi_window_hanning_closed_form() {
    let window = build_kaldi_window(KaldiWindow::Hanning, 5).unwrap();
    let want = [0.0_f32, 0.5, 1.0, 0.5, 0.0];
    let got = to_vec(&window);
    for (i, (&g, &e)) in got.iter().zip(&want).enumerate() {
        assert!((g - e).abs() < F32_TOL, "hanning[{i}]: got {g}, want {e}");
    }
}
/// Reversal covers every element, including the one at index 0.
#[test]
fn reverse_1d_basic_reverses_including_index_zero() {
    let forward = Array::from_slice::<f32>(&[10.0, 20.0, 30.0, 40.0], &[4]).unwrap();
    let reversed = reverse_1d(&forward).unwrap();
    assert_eq!(reversed.shape(), vec![4]);
    assert_eq!(to_vec(&reversed), vec![40.0, 30.0, 20.0, 10.0]);
}
/// A single-element array reverses to itself.
#[test]
fn reverse_1d_single_element() {
    let single = Array::from_slice::<f32>(&[7.0], &[1]).unwrap();
    assert_eq!(to_vec(&reverse_1d(&single).unwrap()), vec![7.0]);
}
/// A 2-D input is rejected with `RankMismatch`, and the error reports the observed rank.
#[test]
fn reverse_1d_rejects_non_1d() {
    let matrix = Array::zeros::<f32>(&[2_i32, 3_i32]).unwrap();
    let err = reverse_1d(&matrix).expect_err("2-D input must be rejected");
    let p = match &err {
        Error::RankMismatch(p) => p,
        _ => panic!("expected RankMismatch, got {err:?}"),
    };
    assert_eq!(p.actual(), 2, "observed rank");
    assert_eq!(p.context(), "reverse_1d: expected 1-D input");
}
/// A zero-length input is rejected with `EmptyInput`.
#[test]
fn reverse_1d_rejects_empty() {
    let empty = Array::zeros::<f32>(&[0_i32]).unwrap();
    let err = reverse_1d(&empty).expect_err("empty input must be rejected");
    let is_empty_input = matches!(err, Error::EmptyInput(_));
    assert!(is_empty_input, "expected EmptyInput, got {err:?}");
}
/// Three frames of window 4 and shift 2 over 8 samples start at offsets 0, 2 and 4.
#[test]
fn strided_frames_snip_edges_basic_frames() {
    let ramp: Vec<f32> = (0..8).map(|v| v as f32).collect();
    let x = Array::from_slice::<f32>(&ramp, &[8]).unwrap();
    let frames = strided_frames_snip_edges(&x, 4, 2, 3).unwrap();
    assert_eq!(frames.shape(), vec![3, 4]);
    let expected = [
        [0.0_f32, 1.0, 2.0, 3.0],
        [2.0, 3.0, 4.0, 5.0],
        [4.0, 5.0, 6.0, 7.0],
    ];
    assert_eq!(
        to_vec_2d(&frames, 3, 4),
        expected,
        "snip_edges=true framing mismatch"
    );
}
#[test]
fn strided_frames_snip_edges_rejects_overread() {
    // Three frames of 4 with hop 2 reach (3 - 1) * 2 + 4 = 8 samples, but
    // only 5 exist, so the call must fail with an overread diagnostic.
    let samples: Vec<f32> = (0..5).map(|v| v as f32).collect();
    let waveform = Array::from_slice::<f32>(&samples, &[5]).unwrap();
    let err = strided_frames_snip_edges(&waveform, 4, 2, 3)
        .expect_err("overreading the waveform must be rejected");
    let Error::OutOfRange(range) = &err else {
        panic!("expected OutOfRange, got {err:?}");
    };
    let value = range.value();
    let has_diagnostic = ["last_index=8", "waveform_len=5"]
        .iter()
        .copied()
        .all(|needle| value.contains(needle));
    assert!(
        has_diagnostic,
        "expected the overread diagnostic, got value {:?}",
        value
    );
}
#[test]
fn strided_no_snip_edges_rejects_non_1d() {
    // A 2x4 "waveform" must be rejected with a RankMismatch reporting rank 2.
    let batch = Array::zeros::<f32>(&[2_i32, 4_i32]).unwrap();
    let err = strided_frames_no_snip_edges(&batch, 4, 2, 2)
        .expect_err("2-D waveform must be rejected");
    match &err {
        Error::RankMismatch(p) => {
            assert_eq!(
                p.context(),
                "strided_frames_no_snip_edges: expected 1-D waveform"
            );
            assert_eq!(p.actual(), 2);
        }
        other => panic!("expected RankMismatch, got {other:?}"),
    }
}
#[test]
fn strided_no_snip_edges_zero_frames_returns_empty() {
    // Asking for zero frames succeeds and yields an empty [0, 0] array.
    let samples: Vec<f32> = (0..4).map(|v| v as f32).collect();
    let waveform = Array::from_slice::<f32>(&samples, &[4]).unwrap();
    let frames = strided_frames_no_snip_edges(&waveform, 2, 2, 0).unwrap();
    assert_eq!(frames.shape(), vec![0, 0]);
}
#[test]
fn strided_no_snip_edges_pad_gt1_rejects_signal_shorter_than_pad() {
    // win_size=8 / win_inc=2 is expected to produce a reflect pad of 3
    // (reported as `pad=3`); a 3-sample signal cannot provide pad + 1 samples.
    let samples: Vec<f32> = (0..3).map(|v| v as f32).collect();
    let waveform = Array::from_slice::<f32>(&samples, &[3]).unwrap();
    let err = strided_frames_no_snip_edges(&waveform, 8, 2, 2)
        .expect_err("signal shorter than the reflect pad must be rejected");
    let Error::OutOfRange(range) = &err else {
        panic!("expected OutOfRange, got {err:?}");
    };
    assert_eq!(range.requirement(), "must be >= pad + 1");
    let value = range.value();
    assert!(
        value.contains("n=3") && value.contains("pad=3"),
        "expected n/pad diagnostic, got {:?}",
        value
    );
}
#[test]
fn strided_no_snip_edges_rejects_abs_pad_gt_n() {
    // A hop (200) far larger than the window (2) yields a pad whose magnitude
    // (reported as `abs_pad=99`) exceeds the 5-sample signal.
    let samples: Vec<f32> = (0..5).map(|v| v as f32).collect();
    let waveform = Array::from_slice::<f32>(&samples, &[5]).unwrap();
    let err = strided_frames_no_snip_edges(&waveform, 2, 200, 1)
        .expect_err("|pad| larger than the signal must be rejected");
    let Error::OutOfRange(range) = &err else {
        panic!("expected OutOfRange, got {err:?}");
    };
    assert_eq!(range.requirement(), "must be <= waveform len");
    let value = range.value();
    assert!(
        value.contains("abs_pad=99") && value.contains("n=5"),
        "expected abs_pad/n diagnostic, got {:?}",
        value
    );
}
#[test]
fn compute_fbank_kaldi_rejects_win_len_below_two() {
    // A one-sample analysis window is rejected as OutOfRange on win_len.
    let silence = Array::from_slice::<f32>(&[0.0_f32; 8], &[8_i32]).unwrap();
    let win_len = 1;
    let win_inc = 1;
    let num_mels = 4;
    let result = compute_fbank_kaldi(
        &silence,
        16_000,
        win_len,
        win_inc,
        num_mels,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        true,
        0.0,
        0.0,
        None,
    );
    let err = result.expect_err("win_len < 2 must be rejected");
    match &err {
        Error::OutOfRange(p) => assert_eq!(p.context(), "compute_fbank_kaldi: win_len"),
        other => panic!("expected OutOfRange, got {other:?}"),
    }
}
#[test]
fn compute_fbank_kaldi_rejects_win_len_over_decoded_samples_cap() {
    // A window one sample longer than MAX_DECODED_SAMPLES must trip that cap.
    let cap = crate::audio::io::MAX_DECODED_SAMPLES;
    let silence = Array::from_slice::<f32>(&[0.0_f32; 8], &[8_i32]).unwrap();
    let win_len = cap + 1;
    let win_inc = 160;
    let result = compute_fbank_kaldi(
        &silence,
        16_000,
        win_len,
        win_inc,
        4,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        true,
        0.0,
        0.0,
        None,
    );
    let err = result.expect_err("win_len over MAX_DECODED_SAMPLES must be rejected");
    let Error::CapExceeded(exceeded) = &err else {
        panic!("expected CapExceeded, got {err:?}");
    };
    assert_eq!(exceeded.cap_name(), "MAX_DECODED_SAMPLES");
    assert_eq!(exceeded.context(), "compute_fbank_kaldi: win_len exceeds cap");
}
#[test]
fn compute_fbank_kaldi_snip_edges_false_zero_frames_returns_empty() {
    // With snip_edges disabled, a single-sample input yields zero frames:
    // the result is an empty [0, num_mels] matrix rather than an error.
    let one_sample = Array::from_slice::<f32>(&[0.5_f32], &[1_i32]).unwrap();
    let win_len = 2;
    let win_inc = 4;
    let num_mels = 7;
    let snip_edges = false;
    let features = compute_fbank_kaldi(
        &one_sample,
        16_000,
        win_len,
        win_inc,
        num_mels,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        snip_edges,
        0.0,
        0.0,
        None,
    )
    .unwrap();
    assert_eq!(features.shape(), vec![0, 7]);
}
#[test]
fn compute_deltas_kaldi_rejects_rank_zero_scalar() {
    // A rank-0 array has no time axis, so deltas cannot be computed.
    let no_dims: [i32; 0] = [];
    let scalar = Array::full::<f32>(&no_dims, 3.0).unwrap();
    assert!(scalar.shape().is_empty(), "scalar must be rank-0");
    let err = compute_deltas_kaldi(&scalar, 5, DeltaPadMode::Edge)
        .expect_err("rank-0 specgram must be rejected");
    let Error::RankMismatch(mismatch) = &err else {
        panic!("expected RankMismatch, got {err:?}");
    };
    assert_eq!(mismatch.actual(), 0, "rank-0 reported");
    assert_eq!(
        mismatch.context(),
        "compute_deltas_kaldi: specgram must have rank >= 1 (a time axis)"
    );
}
#[test]
fn compute_deltas_kaldi_time_zero_returns_empty_same_shape() {
    // An empty trailing (time) axis is passed through: the output keeps [2, 0].
    let empty_time = Array::zeros::<f32>(&[2_i32, 0_i32]).unwrap();
    let deltas = compute_deltas_kaldi(&empty_time, 5, DeltaPadMode::Edge).unwrap();
    assert_eq!(deltas.shape(), vec![2, 0], "time==0 preserves the empty shape");
}
#[test]
fn compute_deltas_kaldi_rejects_total_over_cap() {
    // A 1-D input with one element more than MAX_FBANK_WORK must trip the cap.
    let len = i32::try_from(MAX_FBANK_WORK + 1).unwrap();
    let oversized = Array::zeros::<f32>(&[len]).unwrap();
    let err = compute_deltas_kaldi(&oversized, 5, DeltaPadMode::Edge)
        .expect_err("total over MAX_FBANK_WORK must be rejected");
    let Error::CapExceeded(exceeded) = &err else {
        panic!("expected CapExceeded, got {err:?}");
    };
    assert_eq!(exceeded.cap_name(), "MAX_FBANK_WORK");
    assert_eq!(
        exceeded.context(),
        "compute_deltas_kaldi: element count exceeds work cap"
    );
}
#[test]
fn compute_deltas_kaldi_constant_mode_1d_ramp_interior_unit_slope() {
    // The delta of the ramp 0, 1, ..., 8 should be 1.0 everywhere away from the
    // edges. Only indices 2..7 are checked, keeping clear of the constant-padded
    // boundary samples (presumably 2 on each side for a window of 5).
    let ramp: Vec<f32> = (0..9).map(|n| n as f32).collect();
    let signal = Array::from_slice::<f32>(&ramp, &[9]).unwrap();
    let deltas = compute_deltas_kaldi(&signal, 5, DeltaPadMode::Constant).unwrap();
    assert_eq!(deltas.shape(), vec![9]);
    let got = to_vec(&deltas);
    for (i, g) in got.iter().copied().enumerate().skip(2).take(5) {
        assert!(
            (g - 1.0).abs() < F32_TOL,
            "constant-pad ramp delta[{i}]: got {g}, want 1.0"
        );
    }
}
#[test]
#[cfg(target_pointer_width = "64")]
fn mel_banks_kaldi_num_bins_times_fft_bins_overflows_usize() {
    // The test is limited to 64-bit targets because `1_usize << 40` does not fit
    // a 32-bit usize. There, rustc rejects the constant shift through the
    // deny-by-default `arithmetic_overflow` lint, which would break the whole
    // test build.
    let num_bins = 1_usize << 40;
    let n_fft_padded = 1_usize << 26;
    // Same halving the shape test relies on (512 padded -> 256 columns).
    let num_fft_bins = n_fft_padded / 2;
    // 2^40 * 2^25 = 2^65, which cannot be represented in a 64-bit usize.
    assert!(
        num_bins.checked_mul(num_fft_bins).is_none(),
        "test premise: num_bins * num_fft_bins must overflow usize"
    );
    let err = get_mel_banks_kaldi(num_bins, n_fft_padded, 16_000.0, 20.0, 0.0)
        .expect_err("num_bins * num_fft_bins overflow must be rejected");
    let Error::ArithmeticOverflow(p) = &err else {
        panic!("expected ArithmeticOverflow, got {err:?}");
    };
    assert_eq!(p.op_type(), "usize");
    assert!(
        p.context()
            .contains("get_mel_banks_kaldi: num_bins * num_fft_bins"),
        "expected the num_bins*num_fft_bins overflow context, got {:?}",
        p.context()
    );
    let names: Vec<&str> = p.operands().iter().map(|(n, _)| *n).collect();
    assert!(
        names.contains(&"num_bins") && names.contains(&"num_fft_bins"),
        "expected num_bins + num_fft_bins operands, got {names:?}"
    );
}
#[test]
fn build_kaldi_window_rejects_win_size_over_i32() {
    // One past i32::MAX cannot be used as an array dimension and must be
    // rejected as OutOfRange with an "fit in i32" requirement.
    let win_size = i32::MAX as usize + 1;
    assert!(i32::try_from(win_size).is_err(), "test premise: > i32::MAX");
    let err = build_kaldi_window(KaldiWindow::Hamming, win_size)
        .expect_err("win_size over i32::MAX must be rejected");
    let Error::OutOfRange(range) = &err else {
        panic!("expected OutOfRange, got {err:?}");
    };
    assert_eq!(range.context(), "build_kaldi_window: win_size");
    let requirement = range.requirement();
    assert!(
        requirement.contains("fit in i32"),
        "expected the i32-fit requirement, got {:?}",
        requirement
    );
}
#[test]
fn strided_frames_snip_edges_last_index_overflows_usize() {
    // Two frames with hop usize::MAX and window 10 have a reachable range of
    // 1 * usize::MAX + 10, which overflows before any memory is touched.
    let waveform = Array::from_slice::<f32>(&[0.0_f32; 4], &[4_i32]).unwrap();
    let reach = 1_usize
        .checked_mul(usize::MAX)
        .and_then(|span| span.checked_add(10));
    assert!(
        reach.is_none(),
        "test premise: (num_frames-1)*win_inc + win_size must overflow usize"
    );
    let err = strided_frames_snip_edges(&waveform, 10, usize::MAX, 2)
        .expect_err("reachable-index overflow must be rejected");
    let Error::ArithmeticOverflow(overflow) = &err else {
        panic!("expected ArithmeticOverflow, got {err:?}");
    };
    assert_eq!(overflow.op_type(), "usize");
    let context = overflow.context();
    assert!(
        context.contains("strided_frames_snip_edges: reachable element range"),
        "expected the reachable-range overflow context, got {:?}",
        context
    );
}
#[test]
fn strided_frames_snip_edges_rejects_num_frames_over_i32() {
    // A zero hop keeps the reachable range tiny, so the only failing check is
    // that num_frames (i32::MAX + 1) does not fit in i32.
    let waveform = Array::from_slice::<f32>(&[0.0_f32; 5], &[5_i32]).unwrap();
    let too_many_frames = i32::MAX as usize + 1;
    let err = strided_frames_snip_edges(&waveform, 4, 0, too_many_frames)
        .expect_err("num_frames over i32::MAX must be rejected");
    match &err {
        Error::OutOfRange(p) => {
            assert_eq!(p.context(), "strided_frames_snip_edges: num_frames");
        }
        other => panic!("expected OutOfRange, got {other:?}"),
    }
}
#[test]
#[cfg(target_pointer_width = "64")]
fn strided_frames_snip_edges_rejects_win_inc_over_i64() {
    // A usize above i64::MAX only exists on 64-bit targets. On 32-bit,
    // `i64::MAX as usize` truncates to usize::MAX and the `+ 1` overflows, so
    // the premise cannot be constructed there and the test is gated instead.
    let x = Array::from_slice::<f32>(&[0.0_f32; 8], &[8_i32]).unwrap();
    let win_inc = (i64::MAX as usize) + 1;
    assert!(i64::try_from(win_inc).is_err(), "test premise: > i64::MAX");
    let err = strided_frames_snip_edges(&x, 4, win_inc, 1)
        .expect_err("win_inc over i64::MAX must be rejected");
    let Error::OutOfRange(p) = &err else {
        panic!("expected OutOfRange, got {err:?}");
    };
    assert_eq!(p.context(), "strided_frames_snip_edges: win_inc");
    assert!(
        p.requirement().contains("fit in i64"),
        "expected the i64-fit requirement, got {:?}",
        p.requirement()
    );
}
#[test]
fn strided_no_snip_edges_last_index_overflows_usize() {
    // usize::MAX frames at hop 2 need (usize::MAX - 1) * 2 reachable elements,
    // which overflows usize and must surface as ArithmeticOverflow.
    let samples: Vec<f32> = (0..10).map(|v| v as f32).collect();
    let waveform = Array::from_slice::<f32>(&samples, &[10_i32]).unwrap();
    let span = (usize::MAX - 1).checked_mul(2);
    assert!(
        span.is_none(),
        "test premise: (num_frames-1)*win_inc must overflow usize"
    );
    let err = strided_frames_no_snip_edges(&waveform, 4, 2, usize::MAX)
        .expect_err("reachable-index overflow must be rejected");
    let Error::ArithmeticOverflow(overflow) = &err else {
        panic!("expected ArithmeticOverflow, got {err:?}");
    };
    assert_eq!(overflow.op_type(), "usize");
    let context = overflow.context();
    assert!(
        context.contains("strided_frames_no_snip_edges: reachable element range"),
        "expected the reachable-range overflow context, got {:?}",
        context
    );
}
#[test]
fn strided_no_snip_edges_rejects_num_frames_over_i32() {
    // With a zero hop, the only failing check is that num_frames exceeds i32.
    let samples: Vec<f32> = (0..10).map(|v| v as f32).collect();
    let waveform = Array::from_slice::<f32>(&samples, &[10_i32]).unwrap();
    let too_many_frames = i32::MAX as usize + 1;
    let err = strided_frames_no_snip_edges(&waveform, 4, 0, too_many_frames)
        .expect_err("num_frames over i32::MAX must be rejected");
    match &err {
        Error::OutOfRange(p) => {
            assert_eq!(p.context(), "strided_frames_no_snip_edges: num_frames");
        }
        other => panic!("expected OutOfRange, got {other:?}"),
    }
}
#[test]
fn compute_fbank_kaldi_rejects_num_mels_over_i32() {
    // num_mels one past i32::MAX cannot become an array dimension.
    let silence = Array::from_slice::<f32>(&[0.0_f32; 8], &[8_i32]).unwrap();
    let too_many_mels = i32::MAX as usize + 1;
    let result = compute_fbank_kaldi(
        &silence,
        16_000,
        4,
        2,
        too_many_mels,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        true,
        0.0,
        0.0,
        None,
    );
    let err = result.expect_err("num_mels over i32::MAX must be rejected");
    let Error::OutOfRange(range) = &err else {
        panic!("expected OutOfRange, got {err:?}");
    };
    assert_eq!(range.context(), "compute_fbank_kaldi: num_mels");
}
#[test]
fn compute_fbank_kaldi_frame_work_cap_rejects_large_framed_matrix() {
    // Framing geometry: a long signal, a 1000-sample window and unit hop yield
    // many frames, while only 4 mel bins keep the output matrix small. Each
    // value is named once and feeds both the premises and the call, so they
    // cannot drift apart.
    let samples_len: usize = 200_000;
    let win_len: usize = 1000;
    let win_inc: usize = 1;
    let num_mels: usize = 4;
    // Padded FFT length for this window. It is assumed to be the next power of
    // two (1024, the value this test previously hard-coded); TODO confirm
    // against compute_fbank_kaldi.
    let n_fft_padded = win_len.next_power_of_two();
    // Kaldi snip_edges=true frame count.
    let num_frames = 1 + (samples_len - win_len) / win_inc;
    assert!(
        num_frames * n_fft_padded > MAX_FBANK_WORK,
        "test premise: frame_work must exceed the cap"
    );
    assert!(
        num_frames * num_mels <= MAX_FBANK_WORK,
        "test premise: output_elems must stay under the cap so frame_work is binding"
    );
    let x = Array::zeros::<f32>(&[i32::try_from(samples_len).unwrap()]).unwrap();
    let err = compute_fbank_kaldi(
        &x,
        16_000,
        win_len,
        win_inc,
        num_mels,
        KaldiWindow::Rectangular,
        0.0,
        0.0,
        true,
        0.0,
        0.0,
        None,
    )
    .expect_err("frame_work over MAX_FBANK_WORK must be rejected");
    let Error::CapExceeded(p) = &err else {
        panic!("expected CapExceeded, got {err:?}");
    };
    assert_eq!(p.cap_name(), "MAX_FBANK_WORK");
    assert!(
        p.context().contains("frame work"),
        "expected the frame-work cap context, got {:?}",
        p.context()
    );
    assert!(
        p.observed() > p.cap(),
        "observed {} must exceed cap {}",
        p.observed(),
        p.cap()
    );
}