Skip to main content

omni_dev/voice/
features.rs

//! Kaldi-style FBANK (log-mel filterbank) feature extraction.
//!
//! Produces the 80-dim features that the wespeaker speaker-embedding
//! ONNX model in [`crate::voice::speaker`] consumes. Parameters match
//! sherpa-onnx's `kaldi-native-fbank` defaults for that model:
//!
//! - 16 kHz mono input
//! - 25 ms window (400 samples) / 10 ms hop (160 samples)
//! - Pre-emphasis `0.97`
//! - Hamming window
//! - 80 mel bins, low-freq 20 Hz, high-freq 8000 Hz
//! - HTK/Kaldi mel scale: `1127 · ln(1 + f/700)`
//! - Log applied to mel energies (`ln(energy + 1e-10)`)
//! - Mean normalisation (CMN): the per-bin mean over all frames of the
//!   supplied window is subtracted from every frame
//!
//! The implementation is pure-Rust DSP; the only non-std DSP dependency
//! is `rustfft`, used for the 512-point FFT of each zero-padded frame.
18
19use std::f32::consts::PI;
20
21use anyhow::{bail, Result};
22use rustfft::num_complex::Complex32;
23use rustfft::FftPlanner;
24
/// Sample rate, in Hz, that every routine in this module assumes
/// (wespeaker is trained on 16 kHz audio).
pub const SAMPLE_RATE: u32 = 16_000;

/// Duration of one analysis frame, in milliseconds.
pub const FRAME_LENGTH_MS: f32 = 25.0;

/// Hop between the starts of consecutive frames, in milliseconds.
pub const FRAME_SHIFT_MS: f32 = 10.0;

/// Number of mel filters, and therefore the width of each feature row.
pub const NUM_MEL_BINS: usize = 80;

/// FFT length: the frame (400 samples) is zero-padded up to the next
/// power of two, 512.
pub const FFT_SIZE: usize = 512;

/// Lower edge of the mel filterbank, in Hz (Kaldi default).
pub const LOW_FREQ_HZ: f32 = 20.0;

/// Upper edge of the mel filterbank, in Hz: the Nyquist frequency of
/// [`SAMPLE_RATE`].
pub const HIGH_FREQ_HZ: f32 = SAMPLE_RATE as f32 / 2.0;

/// Pre-emphasis coefficient; the filter `x[n] - PREEMPHASIS · x[n-1]`
/// boosts high frequencies before windowing.
pub const PREEMPHASIS: f32 = 0.97;

/// Small positive floor added to mel energies before taking `ln`, so the
/// log stays finite when an energy is (near) zero.
const EPSILON: f32 = 1e-10;
52
53/// Builds the triangular mel filterbank as a dense `[num_bins][fft_size/2 + 1]`
54/// matrix indexed `[mel_bin][fft_bin]`.
55///
56/// The same matrix is reused across many windows by
57/// [`crate::voice::speaker::WespeakerEmbedder`]; construction cost is
58/// paid once per process.
59pub fn build_mel_filterbank(
60    num_bins: usize,
61    fft_size: usize,
62    sample_rate: u32,
63) -> Result<Vec<Vec<f32>>> {
64    if num_bins != NUM_MEL_BINS {
65        bail!("wespeaker FBANK requires {NUM_MEL_BINS} mel bins, got {num_bins}");
66    }
67    let mel_low = hz_to_mel(LOW_FREQ_HZ);
68    let mel_high = hz_to_mel(HIGH_FREQ_HZ);
69    let num_points = num_bins + 2;
70    let mel_points: Vec<f32> = (0..num_points)
71        .map(|i| (mel_high - mel_low).mul_add(i as f32 / (num_points - 1) as f32, mel_low))
72        .collect();
73    let hz_points: Vec<f32> = mel_points.iter().copied().map(mel_to_hz).collect();
74
75    let num_bins_fft = fft_size / 2 + 1;
76    let bin_hz: Vec<f32> = (0..num_bins_fft)
77        .map(|k| k as f32 * sample_rate as f32 / fft_size as f32)
78        .collect();
79
80    let mut filters = vec![vec![0f32; num_bins_fft]; num_bins];
81    for m in 0..num_bins {
82        let left = hz_points[m];
83        let centre = hz_points[m + 1];
84        let right = hz_points[m + 2];
85        for (k, &f) in bin_hz.iter().enumerate() {
86            let w = if f < left || f > right {
87                0.0
88            } else if f <= centre {
89                (f - left) / (centre - left).max(EPSILON)
90            } else {
91                (right - f) / (right - centre).max(EPSILON)
92            };
93            filters[m][k] = w.max(0.0);
94        }
95    }
96    Ok(filters)
97}
98
/// Converts a frequency in Hz to mels with `1127 · ln(1 + f/700)`, the
/// HTK-style mel scale that Kaldi uses.
fn hz_to_mel(hz: f32) -> f32 {
    const MEL_SCALE: f32 = 1127.0;
    const CORNER_HZ: f32 = 700.0;

    let ratio = hz / CORNER_HZ;
    MEL_SCALE * ratio.ln_1p()
}
103
/// Converts mels back to Hz with `700 · (exp(m / 1127) − 1)`; the inverse
/// of [`hz_to_mel`].
fn mel_to_hz(mel: f32) -> f32 {
    const MEL_SCALE: f32 = 1127.0;
    const CORNER_HZ: f32 = 700.0;

    let exponent = mel / MEL_SCALE;
    CORNER_HZ * exponent.exp_m1()
}
108
/// Hamming window of length `n`: `w[i] = 0.54 − 0.46 · cos(2πi / (n − 1))`.
///
/// Returns an empty vector for `n == 0` and `[1.0]` for `n == 1`. The
/// single-point case is special-cased because the general formula
/// divides by `n − 1` and would otherwise produce `NaN`.
fn hamming(n: usize) -> Vec<f32> {
    if n == 1 {
        return vec![1.0];
    }
    // For `n == 0` the range below is empty, so `denom` is never used.
    let denom = n.saturating_sub(1) as f32;
    (0..n)
        .map(|i| (-0.46_f32).mul_add((2.0 * PI * i as f32 / denom).cos(), 0.54))
        .collect()
}
115
116/// Computes 80-dim Kaldi-style FBANK features for the supplied 16 kHz
117/// mono floating-point PCM window. Returns one feature row per frame,
118/// already cepstral-mean-normalised across the window.
119///
120/// Errors if `pcm` is shorter than one frame (400 samples ≈ 25 ms).
121pub fn compute_fbank(pcm: &[f32], mel_filters: &[Vec<f32>]) -> Result<Vec<Vec<f32>>> {
122    let frame_length = ((FRAME_LENGTH_MS / 1000.0) * SAMPLE_RATE as f32) as usize;
123    let frame_shift = ((FRAME_SHIFT_MS / 1000.0) * SAMPLE_RATE as f32) as usize;
124    if pcm.len() < frame_length {
125        bail!(
126            "PCM window has {} samples; need at least {} (one 25 ms frame at 16 kHz)",
127            pcm.len(),
128            frame_length
129        );
130    }
131    let num_frames = 1 + (pcm.len() - frame_length) / frame_shift;
132    let window = hamming(frame_length);
133    let mut planner = FftPlanner::<f32>::new();
134    let fft = planner.plan_fft_forward(FFT_SIZE);
135
136    let mut feats: Vec<Vec<f32>> = Vec::with_capacity(num_frames);
137    let mut scratch = vec![Complex32::new(0.0, 0.0); FFT_SIZE];
138    let mut emph = vec![0f32; frame_length];
139
140    for i in 0..num_frames {
141        let start = i * frame_shift;
142        let frame = &pcm[start..start + frame_length];
143
144        // Pre-emphasis: x[n] - 0.97 · x[n-1]. The kaldi-native-fbank
145        // default uses x[0] itself for the n=0 history (no buffer of
146        // prior frames carried forward), and that's what we mirror
147        // here.
148        emph[0] = (-PREEMPHASIS).mul_add(frame[0], frame[0]);
149        for n in 1..frame_length {
150            emph[n] = (-PREEMPHASIS).mul_add(frame[n - 1], frame[n]);
151        }
152
153        // Hamming window, zero-pad to FFT_SIZE.
154        for n in 0..FFT_SIZE {
155            let v = if n < frame_length {
156                emph[n] * window[n]
157            } else {
158                0.0
159            };
160            scratch[n] = Complex32::new(v, 0.0);
161        }
162        fft.process(&mut scratch);
163
164        // Power spectrum, mel projection, log.
165        let num_bins_fft = FFT_SIZE / 2 + 1;
166        let mut mel = vec![0f32; mel_filters.len()];
167        for (m, filter) in mel_filters.iter().enumerate() {
168            let mut energy = 0f32;
169            for (k, &w) in filter.iter().enumerate().take(num_bins_fft) {
170                let c = scratch[k];
171                let power = c.re.mul_add(c.re, c.im * c.im);
172                energy = w.mul_add(power, energy);
173            }
174            mel[m] = (energy + EPSILON).ln();
175        }
176        feats.push(mel);
177    }
178
179    // Cepstral mean normalisation across all frames of this window.
180    let n = feats.len() as f32;
181    let mut mean = vec![0f32; mel_filters.len()];
182    for frame in &feats {
183        for (m, &v) in frame.iter().enumerate() {
184            mean[m] += v;
185        }
186    }
187    for v in &mut mean {
188        *v /= n;
189    }
190    for frame in &mut feats {
191        for (m, v) in frame.iter_mut().enumerate() {
192            *v -= mean[m];
193        }
194    }
195    Ok(feats)
196}
197
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal we want here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Generates `secs` seconds of a `freq_hz` sine wave at unit amplitude.
    fn sine(freq_hz: f32, secs: f32) -> Vec<f32> {
        let n = (secs * SAMPLE_RATE as f32) as usize;
        (0..n)
            .map(|i| (2.0 * PI * freq_hz * i as f32 / SAMPLE_RATE as f32).sin())
            .collect()
    }

    #[test]
    fn build_mel_filterbank_returns_80_by_257_matrix() {
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        assert_eq!(filters.len(), NUM_MEL_BINS);
        assert_eq!(filters[0].len(), FFT_SIZE / 2 + 1);
    }

    #[test]
    fn build_mel_filterbank_rejects_non_80_bins() {
        let err = build_mel_filterbank(64, FFT_SIZE, SAMPLE_RATE).unwrap_err();
        assert!(err.to_string().contains("80 mel bins"), "got: {err}");
    }

    #[test]
    fn build_mel_filterbank_filters_are_non_negative() {
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        for (m, filter) in filters.iter().enumerate() {
            for (k, &w) in filter.iter().enumerate() {
                assert!(w >= 0.0, "filter[{m}][{k}] = {w} is negative");
            }
        }
    }

    #[test]
    fn build_mel_filterbank_has_no_empty_filters() {
        // Even the narrowest (lowest) filter is wider than one FFT bin
        // (31.25 Hz), so every filter must pick up at least one bin. An
        // all-zero filter would turn into a constant `ln(EPSILON)` feature.
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        for (m, filter) in filters.iter().enumerate() {
            assert!(
                filter.iter().any(|&w| w > 0.0),
                "mel filter {m} has no positive weight"
            );
        }
    }

    #[test]
    fn compute_fbank_frame_count_matches_formula() {
        let pcm = sine(1_000.0, 0.5); // 0.5 s → 8000 samples
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        let feats = compute_fbank(&pcm, &filters).unwrap();
        let frame_length = 400;
        let frame_shift = 160;
        let expected = 1 + (pcm.len() - frame_length) / frame_shift;
        assert_eq!(feats.len(), expected);
    }

    #[test]
    fn compute_fbank_emits_80_dim_frames() {
        let pcm = sine(1_000.0, 0.5);
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        let feats = compute_fbank(&pcm, &filters).unwrap();
        for (i, frame) in feats.iter().enumerate() {
            assert_eq!(frame.len(), NUM_MEL_BINS, "frame {i}: {}", frame.len());
        }
    }

    #[test]
    fn compute_fbank_errors_on_too_short_pcm() {
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        let err = compute_fbank(&[0.0_f32; 100], &filters).unwrap_err();
        assert!(err.to_string().contains("at least"), "got: {err}");
    }

    #[test]
    fn compute_fbank_accepts_exactly_one_frame() {
        // Boundary of the length check: 399 samples is one short, 400 is
        // exactly one frame. With a single frame the per-bin mean equals
        // the frame itself, so normalisation leaves all zeros.
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        assert!(compute_fbank(&[0.25_f32; 399], &filters).is_err());

        let feats = compute_fbank(&[0.25_f32; 400], &filters).unwrap();
        assert_eq!(feats.len(), 1);
        assert_eq!(feats[0].len(), NUM_MEL_BINS);
        assert!(feats[0].iter().all(|&v| v == 0.0), "got: {:?}", feats[0]);
    }

    #[test]
    fn mel_filter_centres_are_monotonically_increasing_in_hz() {
        // Structural sanity check on the filterbank rather than on the
        // post-CMN feature output: the peak FFT bin of filter m+1 must
        // not lie below that of filter m. The check is non-decreasing
        // (`>=`) because neighbouring narrow filters may peak on the same
        // FFT bin. (Post-CMN argmax over a stationary tone is dominated
        // by noise — CMN subtracts the mean-per-bin so all bins are close
        // to zero, and the integration test in #805's
        // `voice_enroll_speaker_test` is the real validation against a
        // trained model.)
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        let mut prev_centre_bin = 0usize;
        for (m, filter) in filters.iter().enumerate() {
            let centre_bin = filter
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .map(|(i, _)| i)
                .unwrap();
            assert!(
                centre_bin >= prev_centre_bin,
                "mel filter {m} centre bin {centre_bin} < previous {prev_centre_bin}"
            );
            prev_centre_bin = centre_bin;
        }
    }

    #[test]
    fn compute_fbank_cmn_zeros_mean_per_bin() {
        // Mean normalisation subtracts the per-bin mean, so the
        // post-normalisation mean of each bin should be ≈ 0 across all
        // frames. Pin this invariant — it's what wespeaker's input
        // distribution assumes.
        let pcm = sine(1_000.0, 0.5);
        let filters = build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap();
        let feats = compute_fbank(&pcm, &filters).unwrap();
        let nf = feats.len() as f32;
        for m in 0..NUM_MEL_BINS {
            let mean: f32 = feats.iter().map(|f| f[m]).sum::<f32>() / nf;
            assert!(
                mean.abs() < 1e-3,
                "bin {m} post-CMN mean = {mean} (expected ~0)"
            );
        }
    }

    #[test]
    fn hz_mel_round_trip() {
        for &hz in &[20.0_f32, 200.0, 1000.0, 4000.0, 8000.0] {
            let back = mel_to_hz(hz_to_mel(hz));
            assert!((back - hz).abs() < 1e-2, "{hz} -> {back}");
        }
    }
}