omni_dev/voice/
features.rs1use std::f32::consts::PI;
20
21use anyhow::{bail, Result};
22use rustfft::num_complex::Complex32;
23use rustfft::FftPlanner;
24
/// Sample rate, in Hz, that the framing in this module is computed for.
pub const SAMPLE_RATE: u32 = 16_000;

/// Length of one analysis frame, in milliseconds.
pub const FRAME_LENGTH_MS: f32 = 25.0;

/// Hop between the starts of consecutive frames, in milliseconds.
pub const FRAME_SHIFT_MS: f32 = 10.0;

/// Number of mel filters; `build_mel_filterbank` rejects any other count.
pub const NUM_MEL_BINS: usize = 80;

/// FFT length; each windowed frame is zero-padded up to this many points.
pub const FFT_SIZE: usize = 512;

/// Lower edge of the mel filterbank, in Hz.
pub const LOW_FREQ_HZ: f32 = 20.0;

/// Upper edge of the mel filterbank, in Hz: the Nyquist frequency of
/// [`SAMPLE_RATE`].
pub const HIGH_FREQ_HZ: f32 = SAMPLE_RATE as f32 / 2.0;

/// Pre-emphasis coefficient applied to each frame before windowing.
pub const PREEMPHASIS: f32 = 0.97;

/// Small positive floor that keeps logarithms and triangle-slope divisors
/// away from zero.
const EPSILON: f32 = 1e-10;
52
53pub fn build_mel_filterbank(
60 num_bins: usize,
61 fft_size: usize,
62 sample_rate: u32,
63) -> Result<Vec<Vec<f32>>> {
64 if num_bins != NUM_MEL_BINS {
65 bail!("wespeaker FBANK requires {NUM_MEL_BINS} mel bins, got {num_bins}");
66 }
67 let mel_low = hz_to_mel(LOW_FREQ_HZ);
68 let mel_high = hz_to_mel(HIGH_FREQ_HZ);
69 let num_points = num_bins + 2;
70 let mel_points: Vec<f32> = (0..num_points)
71 .map(|i| (mel_high - mel_low).mul_add(i as f32 / (num_points - 1) as f32, mel_low))
72 .collect();
73 let hz_points: Vec<f32> = mel_points.iter().copied().map(mel_to_hz).collect();
74
75 let num_bins_fft = fft_size / 2 + 1;
76 let bin_hz: Vec<f32> = (0..num_bins_fft)
77 .map(|k| k as f32 * sample_rate as f32 / fft_size as f32)
78 .collect();
79
80 let mut filters = vec![vec![0f32; num_bins_fft]; num_bins];
81 for m in 0..num_bins {
82 let left = hz_points[m];
83 let centre = hz_points[m + 1];
84 let right = hz_points[m + 2];
85 for (k, &f) in bin_hz.iter().enumerate() {
86 let w = if f < left || f > right {
87 0.0
88 } else if f <= centre {
89 (f - left) / (centre - left).max(EPSILON)
90 } else {
91 (right - f) / (right - centre).max(EPSILON)
92 };
93 filters[m][k] = w.max(0.0);
94 }
95 }
96 Ok(filters)
97}
98
/// Converts a frequency in Hz to the mel scale: `1127 * ln(1 + hz / 700)`.
fn hz_to_mel(hz: f32) -> f32 {
    let ratio = hz / 700.0;
    ratio.ln_1p() * 1127.0
}
103
/// Converts a mel-scale value back to Hz: `700 * (exp(mel / 1127) - 1)`.
fn mel_to_hz(mel: f32) -> f32 {
    let scaled = mel / 1127.0;
    scaled.exp_m1() * 700.0
}
108
/// Returns an `n`-point symmetric Hamming window,
/// `0.54 - 0.46 * cos(2 * pi * i / (n - 1))`.
///
/// A zero-length request yields an empty vector. A one-point window is
/// `[1.0]`: the general formula would divide zero by zero there and
/// produce `NaN`.
fn hamming(n: usize) -> Vec<f32> {
    match n {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            let denom = (n - 1) as f32;
            (0..n)
                .map(|i| (-0.46_f32).mul_add((2.0 * PI * i as f32 / denom).cos(), 0.54))
                .collect()
        }
    }
}
115
116pub fn compute_fbank(pcm: &[f32], mel_filters: &[Vec<f32>]) -> Result<Vec<Vec<f32>>> {
122 let frame_length = ((FRAME_LENGTH_MS / 1000.0) * SAMPLE_RATE as f32) as usize;
123 let frame_shift = ((FRAME_SHIFT_MS / 1000.0) * SAMPLE_RATE as f32) as usize;
124 if pcm.len() < frame_length {
125 bail!(
126 "PCM window has {} samples; need at least {} (one 25 ms frame at 16 kHz)",
127 pcm.len(),
128 frame_length
129 );
130 }
131 let num_frames = 1 + (pcm.len() - frame_length) / frame_shift;
132 let window = hamming(frame_length);
133 let mut planner = FftPlanner::<f32>::new();
134 let fft = planner.plan_fft_forward(FFT_SIZE);
135
136 let mut feats: Vec<Vec<f32>> = Vec::with_capacity(num_frames);
137 let mut scratch = vec![Complex32::new(0.0, 0.0); FFT_SIZE];
138 let mut emph = vec![0f32; frame_length];
139
140 for i in 0..num_frames {
141 let start = i * frame_shift;
142 let frame = &pcm[start..start + frame_length];
143
144 emph[0] = (-PREEMPHASIS).mul_add(frame[0], frame[0]);
149 for n in 1..frame_length {
150 emph[n] = (-PREEMPHASIS).mul_add(frame[n - 1], frame[n]);
151 }
152
153 for n in 0..FFT_SIZE {
155 let v = if n < frame_length {
156 emph[n] * window[n]
157 } else {
158 0.0
159 };
160 scratch[n] = Complex32::new(v, 0.0);
161 }
162 fft.process(&mut scratch);
163
164 let num_bins_fft = FFT_SIZE / 2 + 1;
166 let mut mel = vec![0f32; mel_filters.len()];
167 for (m, filter) in mel_filters.iter().enumerate() {
168 let mut energy = 0f32;
169 for (k, &w) in filter.iter().enumerate().take(num_bins_fft) {
170 let c = scratch[k];
171 let power = c.re.mul_add(c.re, c.im * c.im);
172 energy = w.mul_add(power, energy);
173 }
174 mel[m] = (energy + EPSILON).ln();
175 }
176 feats.push(mel);
177 }
178
179 let n = feats.len() as f32;
181 let mut mean = vec![0f32; mel_filters.len()];
182 for frame in &feats {
183 for (m, &v) in frame.iter().enumerate() {
184 mean[m] += v;
185 }
186 }
187 for v in &mut mean {
188 *v /= n;
189 }
190 for frame in &mut feats {
191 for (m, v) in frame.iter_mut().enumerate() {
192 *v -= mean[m];
193 }
194 }
195 Ok(feats)
196}
197
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Generates `secs` seconds of a unit-amplitude sine at `freq_hz`,
    /// sampled at `SAMPLE_RATE`.
    fn sine(freq_hz: f32, secs: f32) -> Vec<f32> {
        let sample_count = (secs * SAMPLE_RATE as f32) as usize;
        let mut samples = Vec::with_capacity(sample_count);
        for i in 0..sample_count {
            samples.push((2.0 * PI * freq_hz * i as f32 / SAMPLE_RATE as f32).sin());
        }
        samples
    }

    /// Builds the filterbank with the module's default parameters.
    fn default_filters() -> Vec<Vec<f32>> {
        build_mel_filterbank(NUM_MEL_BINS, FFT_SIZE, SAMPLE_RATE).unwrap()
    }

    #[test]
    fn build_mel_filterbank_returns_80_by_257_matrix() {
        let bank = default_filters();
        assert_eq!(bank.len(), NUM_MEL_BINS);
        assert_eq!(bank[0].len(), FFT_SIZE / 2 + 1);
    }

    #[test]
    fn build_mel_filterbank_rejects_non_80_bins() {
        let err = build_mel_filterbank(64, FFT_SIZE, SAMPLE_RATE).unwrap_err();
        assert!(err.to_string().contains("80 mel bins"), "got: {err}");
    }

    #[test]
    fn build_mel_filterbank_filters_are_non_negative() {
        let bank = default_filters();
        for (m, row) in bank.iter().enumerate() {
            for (k, &w) in row.iter().enumerate() {
                assert!(w >= 0.0, "filter[{m}][{k}] = {w} is negative");
            }
        }
    }

    #[test]
    fn compute_fbank_frame_count_matches_formula() {
        let pcm = sine(1_000.0, 0.5);
        let feats = compute_fbank(&pcm, &default_filters()).unwrap();
        // 25 ms frames and 10 ms hop at 16 kHz.
        let frame_length = 400;
        let frame_shift = 160;
        let expected = 1 + (pcm.len() - frame_length) / frame_shift;
        assert_eq!(feats.len(), expected);
    }

    #[test]
    fn compute_fbank_emits_80_dim_frames() {
        let pcm = sine(1_000.0, 0.5);
        let feats = compute_fbank(&pcm, &default_filters()).unwrap();
        for (i, frame) in feats.iter().enumerate() {
            assert_eq!(frame.len(), NUM_MEL_BINS, "frame {i}: {}", frame.len());
        }
    }

    #[test]
    fn compute_fbank_errors_on_too_short_pcm() {
        let too_short = [0.0; 100];
        let err = compute_fbank(&too_short, &default_filters()).unwrap_err();
        assert!(err.to_string().contains("at least"), "got: {err}");
    }

    #[test]
    fn mel_filter_centres_are_monotonically_increasing_in_hz() {
        let bank = default_filters();
        // The peak bin of each filter must never move to a lower FFT bin
        // than the peak of the filter before it.
        let mut prev_centre_bin = 0usize;
        for (m, row) in bank.iter().enumerate() {
            let (centre_bin, _) = row
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.total_cmp(b))
                .unwrap();
            assert!(
                centre_bin >= prev_centre_bin,
                "mel filter {m} centre bin {centre_bin} < previous {prev_centre_bin}"
            );
            prev_centre_bin = centre_bin;
        }
    }

    #[test]
    fn compute_fbank_cmn_zeros_mean_per_bin() {
        let pcm = sine(1_000.0, 0.5);
        let feats = compute_fbank(&pcm, &default_filters()).unwrap();
        let frame_count = feats.len() as f32;
        for m in 0..NUM_MEL_BINS {
            let total: f32 = feats.iter().map(|row| row[m]).sum();
            let mean = total / frame_count;
            assert!(
                mean.abs() < 1e-3,
                "bin {m} post-CMN mean = {mean} (expected ~0)"
            );
        }
    }

    #[test]
    fn hz_mel_round_trip() {
        for hz in [20.0_f32, 200.0, 1000.0, 4000.0, 8000.0] {
            let back = mel_to_hz(hz_to_mel(hz));
            assert!((back - hz).abs() < 1e-2, "{hz} -> {back}");
        }
    }
}