use realfft::{RealFftPlanner, RealToComplex};
use std::sync::Arc;
const SAMPLE_RATE: u32 = 16000;
const FRAME_LENGTH: usize = 400;
pub(crate) const FRAME_SHIFT: usize = 160;
const FFT_SIZE: usize = 512;
const N_BINS: usize = FFT_SIZE / 2 + 1;
const N_MEL: usize = 80;
const LOW_FREQ: f32 = 20.0;
const HIGH_FREQ: f32 = 8000.0;
const PREEMPH_COEFF: f32 = 0.97;
#[inline]
fn mel_scale(freq: f32) -> f32 {
1127.0 * (1.0 + freq / 700.0).ln()
}
#[allow(dead_code)]
#[inline]
fn inverse_mel_scale(mel: f32) -> f32 {
700.0 * ((mel / 1127.0).exp() - 1.0)
}
struct MelFilter {
start_bin: usize,
weights: Vec<f32>,
}
pub(crate) struct FbankExtractor {
fft: Arc<dyn RealToComplex<f32>>,
window: Vec<f32>,
mel_filters: Vec<MelFilter>,
overlap_buffer: Vec<f32>,
first_frame: bool,
fft_input: Vec<f32>,
fft_scratch: Vec<realfft::num_complex::Complex<f32>>,
fft_output: Vec<realfft::num_complex::Complex<f32>>,
power_spectrum: Vec<f32>,
frame_count: usize,
}
impl FbankExtractor {
pub fn new() -> Self {
let mut planner = RealFftPlanner::new();
let fft = planner.plan_fft_forward(FFT_SIZE);
let scratch_len = fft.get_scratch_len();
let window = Self::povey_window();
let mel_filters = Self::compute_mel_filterbank();
Self {
fft,
window,
mel_filters,
overlap_buffer: vec![0.0; FRAME_LENGTH - FRAME_SHIFT], first_frame: true,
fft_input: vec![0.0; FFT_SIZE],
fft_scratch: vec![realfft::num_complex::Complex::new(0.0, 0.0); scratch_len],
fft_output: vec![realfft::num_complex::Complex::new(0.0, 0.0); N_BINS],
power_spectrum: vec![0.0; N_BINS],
frame_count: 0,
}
}
fn povey_window() -> Vec<f32> {
(0..FRAME_LENGTH)
.map(|i| {
let hann = 0.5
- 0.5
* (2.0 * std::f64::consts::PI * i as f64 / (FRAME_LENGTH - 1) as f64).cos();
hann.powf(0.85) as f32
})
.collect()
}
fn compute_mel_filterbank() -> Vec<MelFilter> {
let mel_low = mel_scale(LOW_FREQ);
let mel_high = mel_scale(HIGH_FREQ);
let mel_delta = (mel_high - mel_low) / (N_MEL as f32 + 1.0);
let fft_bin_width = SAMPLE_RATE as f32 / FFT_SIZE as f32;
let mut filters = Vec::with_capacity(N_MEL);
for m in 0..N_MEL {
let left_mel = mel_low + m as f32 * mel_delta;
let center_mel = mel_low + (m + 1) as f32 * mel_delta;
let right_mel = mel_low + (m + 2) as f32 * mel_delta;
let mut start_bin = N_BINS; let mut weights = Vec::new();
for i in 0..N_BINS {
let freq = fft_bin_width * i as f32;
let mel = mel_scale(freq);
if mel > left_mel && mel < right_mel {
let weight = if mel <= center_mel {
(mel - left_mel) / (center_mel - left_mel)
} else {
(right_mel - mel) / (right_mel - center_mel)
};
if start_bin == N_BINS {
start_bin = i;
}
let expected_idx = i - start_bin;
while weights.len() < expected_idx {
weights.push(0.0);
}
weights.push(weight);
}
}
if start_bin == N_BINS {
start_bin = 0;
}
filters.push(MelFilter { start_bin, weights });
}
filters
}
pub fn extract_frame(&mut self, samples: &[f32], output: &mut [f32; N_MEL]) {
debug_assert_eq!(samples.len(), FRAME_SHIFT);
debug_assert!(
!self.first_frame,
"must call extract_frame_full for the first frame"
);
let overlap_len = FRAME_LENGTH - FRAME_SHIFT;
let mut frame = [0.0f32; FRAME_LENGTH];
frame[..overlap_len].copy_from_slice(&self.overlap_buffer);
frame[overlap_len..].copy_from_slice(samples);
self.overlap_buffer.copy_from_slice(&frame[FRAME_SHIFT..]);
self.process_frame(&frame, output);
self.frame_count += 1;
}
pub fn extract_frame_full(
&mut self,
frame_samples: &[f32; FRAME_LENGTH],
output: &mut [f32; N_MEL],
) {
self.overlap_buffer
.copy_from_slice(&frame_samples[FRAME_SHIFT..]);
self.first_frame = false;
self.process_frame(frame_samples, output);
self.frame_count += 1;
}
fn process_frame(&mut self, frame: &[f32], output: &mut [f32; N_MEL]) {
let mut work = [0.0f32; FRAME_LENGTH];
work.copy_from_slice(&frame[..FRAME_LENGTH]);
let mean = work.iter().sum::<f32>() / FRAME_LENGTH as f32;
for s in work.iter_mut() {
*s -= mean;
}
for i in (1..FRAME_LENGTH).rev() {
work[i] -= PREEMPH_COEFF * work[i - 1];
}
work[0] -= PREEMPH_COEFF * work[0];
for (s, &w) in work.iter_mut().zip(self.window.iter()) {
*s *= w;
}
self.fft_input[..FRAME_LENGTH].copy_from_slice(&work);
self.fft_input[FRAME_LENGTH..].fill(0.0);
self.fft
.process_with_scratch(
&mut self.fft_input,
&mut self.fft_output,
&mut self.fft_scratch,
)
.expect("FFT failed");
for (pow, c) in self.power_spectrum.iter_mut().zip(self.fft_output.iter()) {
*pow = c.re * c.re + c.im * c.im;
}
let epsilon = f32::EPSILON; for (m, filter) in self.mel_filters.iter().enumerate() {
let mut energy = 0.0f32;
let spectrum_slice = &self.power_spectrum[filter.start_bin..];
for (w, &p) in filter.weights.iter().zip(spectrum_slice.iter()) {
energy += w * p;
}
output[m] = energy.max(epsilon).ln();
}
}
pub fn reset(&mut self) {
self.overlap_buffer.fill(0.0);
self.first_frame = true;
self.fft_input.fill(0.0);
self.power_spectrum.fill(0.0);
self.frame_count = 0;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn povey_window_shape() {
let window = FbankExtractor::povey_window();
assert_eq!(window.len(), FRAME_LENGTH);
assert!((window[0]).abs() < 1e-10);
assert!((window[FRAME_LENGTH - 1]).abs() < 1e-10);
let mid = window[FRAME_LENGTH / 2];
assert!(mid > 0.9, "window midpoint = {mid}, expected > 0.9");
for i in 0..FRAME_LENGTH / 2 {
let diff = (window[i] - window[FRAME_LENGTH - 1 - i]).abs();
assert!(diff < 1e-6, "asymmetry at {i}: {diff}");
}
}
#[test]
fn mel_filterbank_structure() {
let filters = FbankExtractor::compute_mel_filterbank();
assert_eq!(filters.len(), N_MEL);
for (i, f) in filters.iter().enumerate() {
assert!(!f.weights.is_empty(), "filter {i} has no weights");
}
for i in 1..N_MEL {
assert!(
filters[i].start_bin >= filters[i - 1].start_bin,
"filter {i} start_bin {} < filter {} start_bin {}",
filters[i].start_bin,
i - 1,
filters[i - 1].start_bin
);
}
}
#[test]
fn fbank_matches_python_reference() {
let ref_json = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../testdata/firered_reference/ref_fbank.json"
));
let ref_data: serde_json::Value = serde_json::from_str(ref_json).unwrap();
let ref_fbank: Vec<Vec<f64>> = serde_json::from_value(ref_data["data"].clone()).unwrap();
let ref_shape: Vec<usize> = serde_json::from_value(ref_data["shape"].clone()).unwrap();
assert_eq!(ref_shape[1], N_MEL);
let samples_json = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/../../testdata/firered_reference/ref_samples.json"
));
let samples_data: serde_json::Value = serde_json::from_str(samples_json).unwrap();
let samples: Vec<i16> = serde_json::from_value(samples_data["samples"].clone()).unwrap();
let mut extractor = FbankExtractor::new();
let num_frames = (samples.len() - FRAME_LENGTH) / FRAME_SHIFT + 1;
assert_eq!(num_frames, ref_shape[0]);
let mut max_diff: f64 = 0.0;
for frame_idx in 0..num_frames {
let start = frame_idx * FRAME_SHIFT;
let end = start + FRAME_LENGTH;
let frame_samples: Vec<f32> = samples[start..end].iter().map(|&s| s as f32).collect();
let frame_arr: &[f32; FRAME_LENGTH] = frame_samples.as_slice().try_into().unwrap();
let mut output = [0.0f32; N_MEL];
extractor.extract_frame_full(frame_arr, &mut output);
for bin in 0..N_MEL {
let diff = (output[bin] as f64 - ref_fbank[frame_idx][bin]).abs();
if diff > max_diff {
max_diff = diff;
}
}
}
assert!(
max_diff < 2e-3,
"FBank max diff vs Python reference: {max_diff:.8} (tolerance: 2e-3)"
);
eprintln!("FBank max diff vs Python reference: {max_diff:.8}");
}
}