use ndarray::Array2;
use rustfft::{num_complex::Complex, FftPlanner};
use std::f32::consts::PI;
/// Window function applied to each analysis frame before its FFT.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WindowType {
/// Symmetric Hamming window: `0.54 - 0.46 * cos(2*pi*i / (N - 1))`.
Hamming,
/// Periodic Hann window: `0.5 * (1 - cos(2*pi*i / N))`.
Hann,
}
/// Parameters for [`compute_mel`].
///
/// `pre_emphasis` selects the algorithm: `Some(coeff)` produces pre-emphasized
/// log mel filterbank energies, `None` a plain log mel spectrogram. Some fields
/// are only read by one of those two paths (noted per field).
#[derive(Debug, Clone)]
pub struct MelConfig {
/// Sample rate of the input signal in Hz; used to map filter edge
/// frequencies to FFT bin positions.
pub sample_rate: u32,
/// Number of triangular mel filters, i.e. the number of output columns.
pub num_mels: usize,
/// Frame length in samples. The filterbank path zero-pads each frame to the
/// next power of two for its FFT; the mel-spectrogram path uses an FFT of
/// exactly `n_fft` points.
pub n_fft: usize,
/// Number of samples between the starts of consecutive frames.
pub hop_length: usize,
/// Window applied to each frame.
pub window: WindowType,
/// Lower edge of the mel filterbank in Hz.
pub f_min: f32,
/// Upper edge of the mel filterbank in Hz; `None` means `sample_rate / 2`.
pub f_max: Option<f32>,
/// Pre-emphasis coefficient. `Some` selects the filterbank path; `None`
/// selects the mel-spectrogram path, which applies no pre-emphasis.
pub pre_emphasis: Option<f32>,
/// Filterbank path only: if `true`, emit only frames lying entirely inside
/// the signal; if `false`, emit `ceil(len / hop_length)` frames and treat
/// samples past the end of the signal as zero.
pub snip_edges: bool,
/// Filterbank path only: if `false`, samples are multiplied by 32768 before
/// processing. NOTE(review): presumably this lifts `[-1, 1]` input to 16-bit
/// PCM scale — confirm. The mel-spectrogram path does not read this flag.
pub normalize_samples: bool,
}
impl Default for MelConfig {
    /// Returns a 16 kHz, 80-band configuration: 400-sample frames every
    /// 160 samples, Hamming window, 20 Hz lower edge, Nyquist upper edge,
    /// pre-emphasis of 0.97, edge snipping and normalized input samples.
    fn default() -> Self {
        // At 16 kHz these are 25 ms frames advanced by 10 ms.
        let (n_fft, hop_length) = (400, 160);
        Self {
            sample_rate: 16_000,
            n_fft,
            hop_length,
            window: WindowType::Hamming,
            num_mels: 80,
            f_min: 20.0,
            f_max: None,
            pre_emphasis: Some(0.97),
            snip_edges: true,
            normalize_samples: true,
        }
    }
}
/// Computes log mel features for `samples` as described by `config`.
///
/// The algorithm is chosen by `config.pre_emphasis`: `Some(_)` routes to the
/// pre-emphasized filterbank path, `None` to the plain log mel spectrogram.
/// When `config.f_max` is `None`, the filterbank's upper edge is half the
/// sample rate.
///
/// The result has one row per frame and `config.num_mels` columns.
pub fn compute_mel(samples: &[f32], config: &MelConfig) -> Array2<f32> {
    let sample_rate = config.sample_rate as f32;
    let nyquist = sample_rate / 2.0;
    let upper_edge = config.f_max.unwrap_or(nyquist);
    match config.pre_emphasis {
        Some(_) => compute_fbank(samples, config, sample_rate, upper_edge),
        None => compute_mel_spectrogram(samples, config, sample_rate, upper_edge),
    }
}
/// Computes log mel filterbank energies with pre-emphasis.
///
/// [`compute_mel`] takes this branch when `config.pre_emphasis` is set.
///
/// Frames are `config.n_fft` samples long and start every `config.hop_length`
/// samples. With `config.snip_edges`, only frames lying entirely inside
/// `samples` are produced. Otherwise `ceil(len / hop_length)` frames are
/// produced, and the part of a frame past the end of the signal is zero.
///
/// Each frame is processed in these steps:
/// 1. Pre-emphasize it; the first sample is scaled by `1 - coeff`.
/// 2. Apply `config.window`.
/// 3. Zero-pad to the next power of two and run an FFT.
/// 4. Project the power spectrum onto `config.num_mels` triangular mel filters.
/// 5. Store the natural log of each filter energy, floored at `1e-10`.
///
/// If `config.normalize_samples` is `false`, samples are multiplied by `32768`
/// before framing.
///
/// NOTE(review): the conventions (`snip_edges`, first-sample pre-emphasis)
/// look modelled on Kaldi's fbank. However, there is no DC-offset removal or
/// dithering, and `snip_edges == false` framing is not centered. Confirm the
/// intended reference implementation.
///
/// Returns an array of shape `(num_frames, config.num_mels)` in standard layout.
///
/// # Panics
///
/// Panics if `config.n_fft` or `config.hop_length` is zero.
fn compute_fbank(samples: &[f32], config: &MelConfig, sr: f32, f_max: f32) -> Array2<f32> {
    let frame_length = config.n_fft;
    let frame_shift = config.hop_length;
    // Without these checks the failures are an opaque divide-by-zero or an
    // out-of-bounds index on `frame[0]`.
    assert!(frame_length > 0, "MelConfig::n_fft must be greater than zero");
    assert!(frame_shift > 0, "MelConfig::hop_length must be greater than zero");
    // `compute_mel` only routes here when `pre_emphasis` is `Some`; 0.97 is a fallback.
    let pre_emphasis_coeff = config.pre_emphasis.unwrap_or(0.97);
    // Applied while copying each frame rather than cloning the whole signal.
    let scale: f32 = if config.normalize_samples { 1.0 } else { 32768.0 };

    let num_frames = if config.snip_edges {
        if samples.len() < frame_length {
            0
        } else {
            1 + (samples.len() - frame_length) / frame_shift
        }
    } else {
        (samples.len() + frame_shift - 1) / frame_shift
    };
    if num_frames == 0 {
        return Array2::zeros((0, config.num_mels));
    }

    let fft_size = frame_length.next_power_of_two();
    let num_fft_bins = fft_size / 2 + 1;
    let window = make_window(config.window, frame_length);
    let mel_banks = mel_filterbank(config.num_mels, fft_size, sr, config.f_min, f_max);
    let mut planner = FftPlanner::new();
    let fft = planner.plan_fft_forward(fft_size);

    // Scratch buffers reused for every frame instead of reallocated per frame.
    let mut frame = vec![0.0f32; frame_length];
    let mut fft_buf = vec![Complex::new(0.0f32, 0.0); fft_size];
    let mut power_spectrum = vec![0.0f32; num_fft_bins];
    let mut features = Array2::zeros((num_frames, config.num_mels));

    for i in 0..num_frames {
        let start = i * frame_shift;
        // Every counted frame has `start < samples.len()`, so this slice is valid.
        // It is shorter than a frame only when `snip_edges` is false.
        let available = &samples[start..samples.len().min(start + frame_length)];
        for (dst, &s) in frame.iter_mut().zip(available) {
            *dst = s * scale;
        }
        frame[available.len()..].fill(0.0);

        // Pre-emphasis runs back to front so each step reads the
        // not-yet-modified previous sample.
        for j in (1..frame_length).rev() {
            frame[j] -= pre_emphasis_coeff * frame[j - 1];
        }
        frame[0] *= 1.0 - pre_emphasis_coeff;

        // Window into the FFT buffer. The tail past `frame_length` is zero
        // padding and must be reset because the previous FFT ran in place.
        for ((dst, &x), &w) in fft_buf.iter_mut().zip(&frame).zip(&window) {
            *dst = Complex::new(x * w, 0.0);
        }
        fft_buf[frame_length..].fill(Complex::new(0.0, 0.0));
        fft.process(&mut fft_buf);

        for (p, c) in power_spectrum.iter_mut().zip(&fft_buf[..num_fft_bins]) {
            *p = c.norm_sqr();
        }

        for m in 0..config.num_mels {
            let mut energy: f32 = mel_banks
                .row(m)
                .iter()
                .zip(power_spectrum.iter())
                .map(|(&w, &p)| w * p)
                .sum();
            // Floor before the log so silent frames give a finite value.
            if energy < 1.0e-10 {
                energy = 1.0e-10;
            }
            features[[i, m]] = energy.ln();
        }
    }
    features
}
/// Computes a log mel spectrogram without pre-emphasis.
///
/// [`compute_mel`] takes this branch when `config.pre_emphasis` is `None`.
///
/// Frames of exactly `config.n_fft` samples are taken every `config.hop_length`
/// samples with no padding. Trailing samples that do not fill a whole frame are
/// dropped.
///
/// Each frame is processed in these steps:
/// 1. Apply `config.window`.
/// 2. Run an `n_fft`-point FFT.
/// 3. Project the power spectrum onto `config.num_mels` triangular mel filters.
/// 4. Clamp each filter energy to `[1e-9, 1e9]` and take the natural log.
///
/// NOTE(review): `config.snip_edges` and `config.normalize_samples` are not
/// read on this path — confirm that is intended.
///
/// Returns an array of shape `(num_frames, config.num_mels)` in standard
/// (row-major) layout, matching the filterbank path. The array is empty when
/// `samples` is shorter than one frame.
///
/// # Panics
///
/// Panics if `config.hop_length` is zero.
fn compute_mel_spectrogram(
    samples: &[f32],
    config: &MelConfig,
    sr: f32,
    f_max: f32,
) -> Array2<f32> {
    let n_fft = config.n_fft;
    let hop_length = config.hop_length;
    assert!(hop_length > 0, "MelConfig::hop_length must be greater than zero");
    if samples.len() < n_fft {
        return Array2::zeros((0, config.num_mels));
    }
    let n_frames = (samples.len() - n_fft) / hop_length + 1;
    let freq_bins = n_fft / 2 + 1;
    let window = make_window(config.window, n_fft);
    let filterbank = mel_filterbank(config.num_mels, n_fft, sr, config.f_min, f_max);
    let mut planner = FftPlanner::new();
    let fft = planner.plan_fft_forward(n_fft);

    // Laid out as (freq_bins, n_frames) so the filterbank projection is one
    // matrix product.
    let mut power_spec = Array2::<f32>::zeros((freq_bins, n_frames));
    // Reused for every frame; fully overwritten before each FFT.
    let mut fft_buf = vec![Complex::new(0.0f32, 0.0); n_fft];
    for frame_idx in 0..n_frames {
        let start = frame_idx * hop_length;
        let frame = &samples[start..start + n_fft];
        for ((dst, &x), &w) in fft_buf.iter_mut().zip(frame).zip(&window) {
            *dst = Complex::new(x * w, 0.0);
        }
        fft.process(&mut fft_buf);
        for (bin, val) in fft_buf.iter().take(freq_bins).enumerate() {
            power_spec[[bin, frame_idx]] = val.norm_sqr();
        }
    }

    // Shape (num_mels, n_frames).
    let mel = filterbank.dot(&power_spec);
    // Build the transposed output directly in standard layout.
    // `mel.t().to_owned()` would keep column-major strides, so `as_slice()`
    // would return `None` and `into_raw_vec()` would yield mel-major data.
    Array2::from_shape_fn((n_frames, config.num_mels), |(t, m)| {
        mel[[m, t]].clamp(1e-9, 1e9).ln()
    })
}
/// Builds `length` window coefficients of the requested type.
///
/// * `Hamming` uses the symmetric form `0.54 - 0.46 * cos(2*pi*i / (N - 1))`.
/// * `Hann` uses the periodic form `0.5 * (1 - cos(2*pi*i / N))`.
///
/// A one-sample window is `[1.0]` for both types, as with NumPy's `hamming(1)`
/// and PyTorch's `hann_window(1)`. `length == 0` yields an empty vector.
fn make_window(window_type: WindowType, length: usize) -> Vec<f32> {
    // With N == 1, the symmetric Hamming formula divides 0 by 0 (NaN).
    // The periodic Hann formula gives 0.0, which would zero every frame.
    if length == 1 {
        return vec![1.0];
    }
    let n = length as f32;
    match window_type {
        WindowType::Hamming => (0..length)
            .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f32 / (n - 1.0)).cos())
            .collect(),
        WindowType::Hann => (0..length)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n).cos()))
            .collect(),
    }
}
/// Builds a `(num_mels, fft_size / 2 + 1)` matrix of triangular mel filters.
///
/// First, `num_mels + 2` edge points are spaced uniformly on the mel scale
/// between `low_freq` and `high_freq` (Hz). Each edge is then mapped to the
/// fractional FFT-bin position `hz * fft_size / sample_rate`.
///
/// Filter `m` rises linearly from edge `m` to a peak of 1.0 at edge `m + 1`,
/// then falls back to zero at edge `m + 2`. The triangles are linear in bin
/// space and are not area-normalized.
fn mel_filterbank(
    num_mels: usize,
    fft_size: usize,
    sample_rate: f32,
    low_freq: f32,
    high_freq: f32,
) -> Array2<f32> {
    let mel_low = hz_to_mel(low_freq);
    let mel_span = hz_to_mel(high_freq) - mel_low;
    // `num_mels + 2` edges enclose `num_mels + 1` equal mel intervals.
    let intervals = (num_mels + 1) as f32;

    // Fractional FFT-bin position of every filter edge.
    let edges: Vec<f32> = (0..num_mels + 2)
        .map(|i| {
            let mel = mel_low + mel_span * i as f32 / intervals;
            mel_to_hz(mel) * fft_size as f32 / sample_rate
        })
        .collect();

    Array2::from_shape_fn((num_mels, fft_size / 2 + 1), |(m, k)| {
        let (left, center, right) = (edges[m], edges[m + 1], edges[m + 2]);
        let bin = k as f32;
        if bin > left && bin < center {
            (bin - left) / (center - left)
        } else if bin >= center && bin < right {
            (right - bin) / (right - center)
        } else {
            0.0
        }
    })
}
/// Converts a frequency in Hz to mels using `1127 * ln(1 + hz / 700)`.
fn hz_to_mel(hz: f32) -> f32 {
    let ratio = hz / 700.0;
    (1.0 + ratio).ln() * 1127.0
}
/// Converts a mel value back to Hz using `700 * (exp(mel / 1127) - 1)`,
/// the inverse of `hz_to_mel`.
fn mel_to_hz(mel: f32) -> f32 {
    let growth = (mel / 1127.0).exp() - 1.0;
    growth * 700.0
}