// svod-model 0.1.0-alpha.3
//
// Pretrained models inference abstraction.
// Documentation
use crate::audio::mel::hann_window;
use crate::audio::{MelConfig, MelSpectrogram};

/// Test-side wrapper that owns the 3-D `f32` array a mel transform writes into.
struct MelOutput {
    /// Backing array; `run_mel` allocates it as `(1, n_mels, n_frames)`.
    data: ndarray::Array3<f32>,
}

impl MelOutput {
    /// Returns the three axis lengths of the backing array, in axis order.
    fn shape(&self) -> (usize, usize, usize) {
        // For a 3-D array `dim()` already yields the `(usize, usize, usize)` pattern.
        self.data.dim()
    }

    /// Borrows the whole array as one flat slice.
    ///
    /// # Panics
    ///
    /// Panics with "contiguous mel buffer" when ndarray cannot expose the data
    /// as a single slice (`as_slice` returned `None`).
    fn as_slice(&self) -> &[f32] {
        let flat = self.data.as_slice();
        flat.expect("contiguous mel buffer")
    }
}

/// Runs a freshly built `MelSpectrogram` over `waveform` and returns the filled buffer.
///
/// The output array is zero-initialised with shape `(1, n_mels, n_frames)`, where both
/// sizes are queried from the transform itself, and is then handed to `forward_into`
/// as a dynamic-dimension mutable view.
fn run_mel(config: &MelConfig, waveform: &[f32]) -> MelOutput {
    let extractor = MelSpectrogram::new(config);

    // Leading axis is fixed at 1; the other two sizes come from the transform.
    let dims = (1, extractor.n_mels(), extractor.num_frames(waveform.len()));
    let mut data = ndarray::Array3::<f32>::zeros(dims);

    {
        // Keep the mutable borrow of `data` confined to this scope.
        let mut target = data.view_mut().into_dyn();
        extractor.forward_into(waveform, &mut target);
    }

    MelOutput { data }
}

/// With `center: true`, 16000 samples at hop 160 must yield a `(1, 64, 101)` output.
#[test]
fn test_mel_spectrogram_shape_center_true() {
    use std::f32::consts::PI;

    let config = MelConfig {
        sample_rate: 16000,
        n_fft: 400,
        hop_length: 160,
        win_length: 400,
        n_mels: 64,
        center: true,
    };

    // 16000 samples of a 440 Hz sine at a 16 kHz sample rate.
    let tone: Vec<f32> = (0..16000)
        .map(|i| (i as f32 * 440.0 * 2.0 * PI / 16000.0).sin())
        .collect();

    // NOTE(review): 101 equals 16000 / 160 + 1 — presumably the centred frame
    // count implemented by `num_frames`; confirm against its definition.
    let dims = run_mel(&config, &tone).shape();
    assert_eq!(dims, (1, 64, 101));
}

/// With `center: false` and a 320-sample FFT, 16000 samples at hop 160 must yield `(1, 64, 99)`.
#[test]
fn test_mel_spectrogram_shape_center_false() {
    use std::f32::consts::PI;

    let config = MelConfig {
        sample_rate: 16000,
        n_fft: 320,
        hop_length: 160,
        win_length: 320,
        n_mels: 64,
        center: false,
    };

    // 16000 samples of a 440 Hz sine at a 16 kHz sample rate.
    let tone: Vec<f32> = (0..16000)
        .map(|i| (i as f32 * 440.0 * 2.0 * PI / 16000.0).sin())
        .collect();

    // NOTE(review): 99 equals (16000 - 320) / 160 + 1 — presumably the
    // un-centred frame count implemented by `num_frames`; confirm.
    let dims = run_mel(&config, &tone).shape();
    assert_eq!(dims, (1, 64, 99));
}

/// Every mel value computed from an all-zero waveform must be finite.
#[test]
fn test_mel_spectrogram_values_finite() {
    let config = MelConfig {
        sample_rate: 16000,
        n_fft: 400,
        hop_length: 160,
        win_length: 400,
        n_mels: 64,
        center: true,
    };

    // 1600 samples of pure silence.
    // NOTE(review): presumably chosen to catch a log-of-zero style blow-up
    // inside the transform — confirm against `MelSpectrogram`.
    let silence: Vec<f32> = vec![0.0; 1600];
    let output = run_mel(&config, &silence);

    output
        .as_slice()
        .iter()
        .for_each(|v| assert!(v.is_finite(), "mel output contains non-finite value: {v}"));
}

/// A 440 Hz tone must put more average mel energy in the low bins than in the high bins.
///
/// Each mel bin's values are averaged over all frames; the mean of the first
/// `LOW_BINS` per-bin averages is then compared with the mean of every bin from
/// `HIGH_START` upwards.
#[test]
fn test_mel_spectrogram_sine_wave() {
    use std::f32::consts::PI;

    /// Number of leading mel bins that form the "low" group.
    const LOW_BINS: usize = 20;
    /// First mel bin of the "high" group (runs to the last bin).
    const HIGH_START: usize = 40;

    /// Arithmetic mean of a slice; the divisor always tracks the slice length.
    fn mean(xs: &[f32]) -> f32 {
        xs.iter().sum::<f32>() / xs.len() as f32
    }

    let config = MelConfig {
        sample_rate: 16000,
        n_fft: 400,
        hop_length: 160,
        win_length: 400,
        n_mels: 64,
        center: true,
    };

    // 16000 samples of a 440 Hz sine at a 16 kHz sample rate.
    let waveform: Vec<f32> = (0..16000)
        .map(|i| (i as f32 * 440.0 * 2.0 * PI / 16000.0).sin())
        .collect();
    let output = run_mel(&config, &waveform);

    let vals = output.as_slice();
    let (_, n_mels, n_frames) = output.shape();

    // Fail with a clear message instead of a NaN comparison or a slice panic.
    assert!(n_frames > 0, "mel output has no frames to average");
    assert!(n_mels > HIGH_START, "need more than {HIGH_START} mel bins, got {n_mels}");

    // The flat buffer is laid out as (1, n_mels, n_frames), so each consecutive
    // run of `n_frames` values belongs to one mel bin.
    let avg_energy: Vec<f32> = vals.chunks_exact(n_frames).take(n_mels).map(mean).collect();

    let lower_avg = mean(&avg_energy[..LOW_BINS]);
    let upper_avg = mean(&avg_energy[HIGH_START..]);
    assert!(
        lower_avg > upper_avg,
        "Expected lower mel bins to have more energy for 440Hz sine: lower={lower_avg:.2}, upper={upper_avg:.2}"
    );
}

/// `hann_window(8, 8)` must reproduce the 8-point periodic Hann window.
///
/// The expected values equal `0.5 * (1 - cos(2 * pi * n / 8))` for `n = 0..8`,
/// rounded to `f32` precision.
/// NOTE(review): the meaning of the two `hann_window` arguments is not visible
/// here — presumably window length and FFT size; confirm against its definition.
#[test]
fn test_hann_window_matches_torch_periodic_default() {
    let window = hann_window(8, 8);
    let expected = [0.0, 0.14644662, 0.5, 0.8535534, 1.0, 0.8535533, 0.5, 0.1464465];

    // `zip` stops at the shorter side, so without this check a truncated or
    // empty window would pass the element-wise comparison vacuously.
    assert_eq!(window.iter().count(), expected.len(), "hann window length mismatch");

    for (got, want) in window.iter().zip(expected) {
        assert!((got - want).abs() < 1e-6, "got {got}, want {want}");
    }
}