// ferrum_models/mel.rs

//! Mel spectrogram computation matching Python whisper exactly.
//!
//! Uses rustfft for the STFT (matching `torch.stft` with `center=True`),
//! a pre-computed mel filterbank, and identical normalization.
5
6use rustfft::{num_complex::Complex, FftPlanner};
7
/// STFT window length in samples (25 ms at the 16 kHz input rate).
const N_FFT: usize = 400;
/// Distance in samples between the starts of consecutive STFT frames
/// (10 ms at the 16 kHz input rate).
const HOP_LENGTH: usize = 160;
10
11/// Compute log-mel spectrogram matching Python whisper.audio.log_mel_spectrogram.
12///
13/// - `pcm`: audio samples (f32, 16kHz mono)
14/// - `n_mels`: 80 or 128
15/// - `mel_filters`: pre-loaded filter bank, shape [n_mels, N_FFT/2 + 1], row-major
16///
17/// Returns flat Vec<f32> in [n_mels, n_frames] layout (row-major per mel bin).
18pub fn log_mel_spectrogram(pcm: &[f32], n_mels: usize, mel_filters: &[f32]) -> Vec<f32> {
19    let n_fft_half = N_FFT / 2 + 1; // 201
20
21    // Step 1: Reflect-pad (center=True in torch.stft)
22    let padded = reflect_pad(pcm, N_FFT / 2);
23
24    // Step 2: STFT → magnitudes squared
25    let magnitudes = stft_magnitudes_squared(&padded);
26    // magnitudes: [n_fft_half, n_frames_raw]
27    let n_frames_raw = magnitudes.len() / n_fft_half;
28    // Drop last frame (matching Python: stft[..., :-1])
29    let n_frames = n_frames_raw - 1;
30
31    // Step 3: mel_filters[n_mels, n_fft_half] @ magnitudes[n_fft_half, n_frames]
32    let mut mel_spec = vec![0f32; n_mels * n_frames];
33    for m in 0..n_mels {
34        for t in 0..n_frames {
35            let mut sum = 0f32;
36            for f in 0..n_fft_half {
37                sum += mel_filters[m * n_fft_half + f] * magnitudes[f * n_frames_raw + t];
38            }
39            mel_spec[m * n_frames + t] = sum;
40        }
41    }
42
43    // Step 4: log10(clamp(x, 1e-10))
44    for v in &mut mel_spec {
45        *v = v.max(1e-10).log10();
46    }
47
48    // Step 5: max(x, global_max - 8.0)
49    let global_max = mel_spec.iter().copied().fold(f32::NEG_INFINITY, f32::max);
50    let clamp_min = global_max - 8.0;
51    for v in &mut mel_spec {
52        *v = v.max(clamp_min);
53    }
54
55    // Step 6: (x + 4.0) / 4.0
56    for v in &mut mel_spec {
57        *v = (*v + 4.0) / 4.0;
58    }
59
60    mel_spec
61}
62
/// Reflect-pad `signal` by `pad` samples on both sides (matching torch center=True).
///
/// The mirror excludes the edge sample itself, so `[1, 2, 3]` padded by 2
/// becomes `[3, 2, 1, 2, 3, 2, 1]`.
///
/// Edge cases:
/// - If `pad >= signal.len()` a true reflection does not exist. Mirror indices
///   that fall outside the signal are clamped (to the last sample on the left
///   side, to the first sample on the right side) instead of panicking.
/// - An empty `signal` yields `2 * pad` zeros.
fn reflect_pad(signal: &[f32], pad: usize) -> Vec<f32> {
    let n = signal.len();
    if n == 0 {
        // Nothing to mirror: pad with silence rather than underflowing `n - 1`.
        return vec![0.0; 2 * pad];
    }
    let last = n - 1;

    let mut out = Vec::with_capacity(n + 2 * pad);
    // Left reflect: signal[pad], signal[pad-1], ..., signal[1]
    out.extend((1..=pad).rev().map(|i| signal[i.min(last)]));
    out.extend_from_slice(signal);
    // Right reflect: signal[n-2], signal[n-3], ..., signal[n-1-pad]
    out.extend((1..=pad).map(|i| signal[last.saturating_sub(i)]));
    out
}
78
79/// STFT with Hann window, returning magnitudes squared.
80/// Returns [n_fft_half, n_frames] in column-major (frequency × time).
81fn stft_magnitudes_squared(padded: &[f32]) -> Vec<f32> {
82    let n_fft_half = N_FFT / 2 + 1;
83    let n_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1;
84
85    // Hann window
86    let hann: Vec<f32> = (0..N_FFT)
87        .map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / N_FFT as f32).cos()))
88        .collect();
89
90    let mut planner = FftPlanner::<f32>::new();
91    let fft = planner.plan_fft_forward(N_FFT);
92
93    // Output: [n_fft_half, n_frames] column-major
94    let mut magnitudes = vec![0f32; n_fft_half * n_frames];
95
96    let mut buffer = vec![Complex::new(0f32, 0f32); N_FFT];
97
98    for t in 0..n_frames {
99        let offset = t * HOP_LENGTH;
100        // Apply window
101        for i in 0..N_FFT {
102            buffer[i] = Complex::new(padded[offset + i] * hann[i], 0.0);
103        }
104        // FFT
105        fft.process(&mut buffer);
106        // Magnitude squared for first n_fft_half bins
107        for f in 0..n_fft_half {
108            magnitudes[f * n_frames + t] = buffer[f].norm_sqr();
109        }
110    }
111
112    magnitudes
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    /// Samples are mirrored around each edge without repeating the edge sample.
    #[test]
    fn test_reflect_pad() {
        let signal: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let padded = reflect_pad(&signal, 2);
        // Left: signal[2], signal[1] = 3.0, 2.0
        // Right: signal[3], signal[2] = 4.0, 3.0
        assert_eq!(padded, vec![3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]);
        // A pad of zero must hand the signal back untouched.
        assert_eq!(reflect_pad(&signal, 0), signal);
    }

    /// One second of silence produces 100 frames per mel bin, all at the
    /// normalized floor value.
    #[test]
    fn test_mel_shape() {
        const N_MELS: usize = 80;
        // 1 second of silence at 16kHz
        let pcm = vec![0.0f32; 16000];
        let filters = vec![0.0f32; N_MELS * (N_FFT / 2 + 1)];
        let mel = log_mel_spectrogram(&pcm, N_MELS, &filters);
        // 16000 / 160 = 100 frames. Compare the exact length so a truncating
        // division cannot hide a ragged buffer.
        assert_eq!(mel.len(), N_MELS * 100);
        // Zero power everywhere: log10(1e-10) = -10, then (-10 + 4) / 4 = -1.5
        assert!(mel.iter().all(|&v| (v + 1.5).abs() < 1e-5));
    }
}