use std::f32::consts::PI;
use std::sync::Arc;
use rustfft::{Fft, FftPlanner, num_complex::Complex32};
/// Number of triangular filters in the mel filterbank.
const NUM_MEL_FILTERS: usize = 26;
/// Pre-emphasis coefficient `a` in `y[n] = x[n] - a * x[n - 1]`.
const PRE_EMPHASIS: f32 = 0.97;
/// Offset added to every mel band energy before taking `ln`, so empty bands stay finite.
const LOG_FLOOR: f32 = 1.0e-10;
/// Mel-frequency cepstral coefficient (MFCC) extractor for 16-bit PCM audio.
///
/// Everything that depends only on the configuration is built once in
/// [`MfccExtractor::new`] and shared by every [`MfccExtractor::extract`] call:
/// - the FFT plan,
/// - the Hamming window,
/// - the triangular mel filterbank,
/// - the DCT-II basis.
pub struct MfccExtractor {
    /// Input sample rate in Hz; used to place the mel filters.
    sample_rate: u32,
    /// Samples per analysis frame.
    frame_len: usize,
    /// Samples between the starts of consecutive frames.
    hop: usize,
    /// Cepstral coefficients emitted per frame.
    num_coeffs: usize,
    /// FFT length: `frame_len` rounded up to a power of two (frames are zero-padded).
    fft_size: usize,
    /// Forward FFT plan of length `fft_size`.
    fft: Arc<dyn Fft<f32>>,
    /// Hamming window of length `frame_len`.
    window: Vec<f32>,
    /// Mel filters as `(first power-spectrum bin, weights for consecutive bins)`.
    filters: Vec<(usize, Vec<f32>)>,
    /// Unnormalised DCT-II rows for coefficients `1..=num_coeffs`:
    /// `dct_basis[k - 1][n] = cos(PI * k * (n + 0.5) / filters.len())`.
    dct_basis: Vec<Vec<f32>>,
}

impl MfccExtractor {
    /// Creates an extractor for audio sampled at `sample_rate` Hz.
    ///
    /// * `frame_len` – samples per analysis frame. The FFT length is
    ///   `frame_len` rounded up to the next power of two.
    /// * `hop` – samples between the starts of consecutive frames.
    /// * `num_coeffs` – cepstral coefficients per frame. Coefficient 0 (the
    ///   plain sum of the log band energies) is not emitted, so each frame
    ///   holds c1..=c`num_coeffs`.
    ///
    /// # Panics
    ///
    /// Panics if `frame_len` or `hop` is zero.
    /// - A zero hop would make [`extract`](Self::extract) divide by zero.
    /// - A zero-length frame contains no signal to analyse.
    pub fn new(sample_rate: u32, frame_len: usize, hop: usize, num_coeffs: usize) -> Self {
        assert!(frame_len > 0, "MfccExtractor::new: frame_len must be non-zero");
        assert!(hop > 0, "MfccExtractor::new: hop must be non-zero");
        let fft_size = next_pow2(frame_len);
        let mut planner = FftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(fft_size);
        let window = hamming_window(frame_len);
        let filters = build_mel_filterbank(sample_rate, fft_size, NUM_MEL_FILTERS);
        let dct_basis = Self::build_dct_basis(num_coeffs, filters.len());
        Self {
            sample_rate,
            frame_len,
            hop,
            num_coeffs,
            fft_size,
            fft,
            window,
            filters,
            dct_basis,
        }
    }

    /// Precomputes the unnormalised DCT-II cosine rows for `k = 1..=num_coeffs`
    /// over `n_filt` log band energies.
    fn build_dct_basis(num_coeffs: usize, n_filt: usize) -> Vec<Vec<f32>> {
        (1..=num_coeffs)
            .map(|k| {
                (0..n_filt)
                    .map(|n| ((PI * k as f32 * (n as f32 + 0.5)) / n_filt as f32).cos())
                    .collect()
            })
            .collect()
    }

    /// Sample rate in Hz this extractor was configured for.
    pub fn sample_rate(&self) -> u32 {
        self.sample_rate
    }

    /// Samples per analysis frame.
    pub fn frame_len(&self) -> usize {
        self.frame_len
    }

    /// Samples between the starts of consecutive frames.
    pub fn hop(&self) -> usize {
        self.hop
    }

    /// Cepstral coefficients emitted per frame.
    pub fn num_coeffs(&self) -> usize {
        self.num_coeffs
    }

    /// Computes MFCCs for every complete frame of `samples`.
    ///
    /// Processing steps:
    /// 1. The whole signal is pre-emphasised.
    /// 2. It is split into frames of `frame_len` samples, one starting every
    ///    `hop` samples. A trailing partial frame is dropped.
    /// 3. Each frame is Hamming-windowed, zero-padded to `fft_size` and
    ///    transformed.
    /// 4. The power spectrum is passed through the mel filterbank.
    /// 5. Each band energy is offset by `LOG_FLOOR` and its natural log taken.
    /// 6. A DCT-II produces `num_coeffs` coefficients, starting at c1.
    ///
    /// Returns one `Vec` of length `num_coeffs` per frame, or an empty `Vec`
    /// when `samples` is shorter than one frame.
    pub fn extract(&self, samples: &[i16]) -> Vec<Vec<f32>> {
        if samples.len() < self.frame_len {
            return Vec::new();
        }
        let pre = pre_emphasize(samples);
        let n_frames = 1 + (samples.len() - self.frame_len) / self.hop;

        // Buffers reused across frames. `process_with_scratch` avoids the
        // per-call scratch allocation that `Fft::process` performs internally.
        let zero = Complex32::new(0.0, 0.0);
        let mut buf = vec![zero; self.fft_size];
        let mut scratch = vec![zero; self.fft.get_inplace_scratch_len()];
        let mut power = vec![0.0f32; self.fft_size / 2 + 1];
        let mut log_energies: Vec<f32> = Vec::with_capacity(self.filters.len());
        let mut out = Vec::with_capacity(n_frames);

        for f in 0..n_frames {
            let start = f * self.hop;
            let frame = &pre[start..start + self.frame_len];

            // Windowed samples fill the front of the buffer; the tail stays
            // zero, which is the zero-padding up to `fft_size`.
            buf.fill(zero);
            for ((slot, &x), &w) in buf.iter_mut().zip(frame).zip(&self.window) {
                slot.re = x * w;
            }
            self.fft.process_with_scratch(&mut buf, &mut scratch);

            // One-sided power spectrum: bins 0..=fft_size/2.
            for (p, c) in power.iter_mut().zip(&buf) {
                *p = c.re * c.re + c.im * c.im;
            }

            log_energies.clear();
            log_energies.extend(self.filters.iter().map(|(start_bin, weights)| {
                // `zip` stops at the end of `power`, so out-of-range bins are ignored.
                let e = power
                    .iter()
                    .skip(*start_bin)
                    .zip(weights)
                    .fold(0.0f32, |acc, (p, w)| acc + p * w);
                (e + LOG_FLOOR).ln()
            }));

            let coeffs: Vec<f32> = self
                .dct_basis
                .iter()
                .map(|basis| {
                    basis
                        .iter()
                        .zip(&log_energies)
                        .fold(0.0f32, |acc, (b, lg)| acc + lg * b)
                })
                .collect();
            out.push(coeffs);
        }
        out
    }
}
/// Returns the smallest power of two that is `>= n`; `next_pow2(0)` is `1`.
///
/// # Panics
///
/// Panics if that power of two does not fit in `usize`, instead of looping
/// forever on a wrapped-around shift.
fn next_pow2(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next_pow2: next power of two overflows usize")
}
/// Applies the first-order pre-emphasis filter `y[n] = x[n] - PRE_EMPHASIS * x[n - 1]`.
///
/// The sample before the first one is taken to be `0`, so `y[0] == x[0]`.
/// The output has the same length as `samples`.
fn pre_emphasize(samples: &[i16]) -> Vec<f32> {
    samples
        .iter()
        .scan(0i16, |last, &current| {
            let emphasized = current as f32 - PRE_EMPHASIS * *last as f32;
            *last = current;
            Some(emphasized)
        })
        .collect()
}
/// Returns a symmetric Hamming window of length `n`:
/// `w[i] = 0.54 - 0.46 * cos(2 * PI * i / (n - 1))`.
///
/// Lengths 0 and 1 have no defined period, so they yield an empty window and
/// `[1.0]` respectively.
fn hamming_window(n: usize) -> Vec<f32> {
    match n {
        0 | 1 => vec![1.0; n],
        _ => {
            let span = (n - 1) as f32;
            let mut taps = Vec::with_capacity(n);
            for idx in 0..n {
                let phase = 2.0 * PI * idx as f32 / span;
                taps.push(0.54 - 0.46 * phase.cos());
            }
            taps
        }
    }
}
/// Converts a frequency in Hz to mels using `2595 * log10(1 + hz / 700)`.
fn hz_to_mel(hz: f32) -> f32 {
    const MEL_SCALE: f32 = 2595.0;
    const CORNER_HZ: f32 = 700.0;
    let ratio = hz / CORNER_HZ + 1.0;
    MEL_SCALE * ratio.log10()
}
/// Converts mels back to Hz; inverse of `2595 * log10(1 + hz / 700)`.
fn mel_to_hz(mel: f32) -> f32 {
    let linear = 10f32.powf(mel / 2595.0);
    700.0 * (linear - 1.0)
}
fn build_mel_filterbank(
sample_rate: u32,
fft_size: usize,
num_filters: usize,
) -> Vec<(usize, Vec<f32>)> {
let nyquist = sample_rate as f32 / 2.0;
let low_mel = hz_to_mel(0.0);
let high_mel = hz_to_mel(nyquist);
let edges: Vec<f32> = (0..num_filters + 2)
.map(|i| low_mel + (high_mel - low_mel) * (i as f32) / (num_filters + 1) as f32)
.map(mel_to_hz)
.collect();
let n_bins = fft_size / 2 + 1;
let bin_freq = |bin: usize| (bin as f32) * (sample_rate as f32) / (fft_size as f32);
let mut filters = Vec::with_capacity(num_filters);
for i in 0..num_filters {
let left = edges[i];
let center = edges[i + 1];
let right = edges[i + 2];
let mut found: Option<(usize, Vec<f32>)> = None;
for bin in 0..n_bins {
let f = bin_freq(bin);
let w = if f <= left || f >= right {
0.0
} else if f <= center {
(f - left) / (center - left)
} else {
(right - f) / (right - center)
};
match (&mut found, w > 0.0) {
(None, true) => found = Some((bin, vec![w])),
(Some((_, weights)), true) => weights.push(w),
(Some(_), false) => break, (None, false) => continue,
}
}
let (start_bin, weights) = found.unwrap_or_else(|| {
let bin = ((center * fft_size as f32) / (sample_rate as f32)) as usize;
(bin.min(n_bins.saturating_sub(1)), vec![1.0])
});
filters.push((start_bin, weights));
}
filters
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `samples` samples of a sine at `freq_hz`, sampled at 16 kHz,
    /// with peak amplitude 10 000.
    fn sine_16k(samples: usize, freq_hz: f32) -> Vec<i16> {
        (0..samples)
            .map(|n| {
                let t = n as f32 / 16_000.0;
                ((2.0 * PI * freq_hz * t).sin() * 10_000.0) as i16
            })
            .collect()
    }

    #[test]
    fn mfcc_extract_silence_yields_low_energy_coefficients() {
        let extractor = MfccExtractor::new(16_000, 400, 160, 12);
        let silence = vec![0i16; 16_000];
        let frames = extractor.extract(&silence);
        assert!(!frames.is_empty(), "silence should still yield frames");
        for (i, frame) in frames.iter().enumerate() {
            assert_eq!(frame.len(), 12);
            // Every band sits at ln(LOG_FLOOR); the DCT-II of a constant
            // vector is zero for k >= 1, so c1 should be ~0.
            assert!(
                frame[0].abs() < 5.0,
                "silence frame {i} 1st coeff = {} should be near zero",
                frame[0]
            );
        }
    }

    #[test]
    fn mfcc_extract_frame_count_matches_hop_arithmetic() {
        let extractor = MfccExtractor::new(16_000, 400, 160, 12);
        let n_samples = 400 + 160 * 9;
        let pcm = vec![1i16; n_samples];
        let frames = extractor.extract(&pcm);
        let expected = 1 + (n_samples - 400) / 160;
        assert_eq!(frames.len(), expected, "framing arithmetic mismatched");
        let short = vec![1i16; 100];
        assert_eq!(extractor.extract(&short).len(), 0);
    }

    #[test]
    fn mfcc_extract_deterministic() {
        let extractor = MfccExtractor::new(16_000, 400, 160, 12);
        let pcm = sine_16k(8_000, 440.0);
        let a = extractor.extract(&pcm);
        let b = extractor.extract(&pcm);
        assert_eq!(a.len(), b.len());
        for (i, (fa, fb)) in a.iter().zip(b.iter()).enumerate() {
            for (j, (xa, xb)) in fa.iter().zip(fb.iter()).enumerate() {
                assert!(
                    (xa - xb).abs() < 1e-5,
                    "frame {i} coeff {j} drifted: {xa} vs {xb}"
                );
            }
        }
    }

    #[test]
    fn mfcc_sine_distinguishes_from_silence() {
        let extractor = MfccExtractor::new(16_000, 400, 160, 12);
        let silence = vec![0i16; 8_000];
        let tone = sine_16k(8_000, 1000.0);
        let s = extractor.extract(&silence);
        let t = extractor.extract(&tone);
        assert_eq!(s.len(), t.len());
        assert!(!s.is_empty(), "8 000 samples should yield at least one frame");
        assert!(t[0].iter().all(|c| c.is_finite()), "tone MFCCs must be finite");
        let mut d = 0.0f32;
        for (xs, xt) in s[0].iter().zip(t[0].iter()) {
            d += (xs - xt).powi(2);
        }
        d = d.sqrt();
        assert!(
            d > 1.0,
            "silence vs 1 kHz tone should differ in MFCC; got L2 = {d}"
        );
    }

    /// A zero hop can never advance through the signal and must not be
    /// silently accepted.
    #[test]
    #[should_panic]
    fn mfcc_zero_hop_is_rejected() {
        let extractor = MfccExtractor::new(16_000, 400, 0, 12);
        let _ = extractor.extract(&[0i16; 800]);
    }

    #[test]
    fn helpers_match_reference_values() {
        assert_eq!(next_pow2(0), 1);
        assert_eq!(next_pow2(400), 512);
        assert_eq!(next_pow2(512), 512);

        let w = hamming_window(5);
        assert!((w[0] - 0.08).abs() < 1e-6);
        assert!((w[2] - 1.0).abs() < 1e-6);
        assert!((w[0] - w[4]).abs() < 1e-6);

        let pre = pre_emphasize(&[100, 100]);
        assert!((pre[0] - 100.0).abs() < 1e-4);
        assert!((pre[1] - 3.0).abs() < 1e-4);

        for hz in [0.0f32, 300.0, 1000.0, 8000.0] {
            let back = mel_to_hz(hz_to_mel(hz));
            assert!((back - hz).abs() < 0.5, "mel round-trip drifted at {hz} Hz: {back}");
        }
    }
}