#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
use num_traits::Float as NumFloat;
use crate::kernel::Float;
use super::stft::{power_spectrogram, stft as compute_stft};
use super::window::WindowFunction;
#[derive(Clone, Debug)]
pub struct MelConfig {
pub sample_rate: f64,
pub fft_size: usize,
pub hop_size: usize,
pub n_mels: usize,
pub f_min: f64,
pub f_max: f64,
}
impl MelConfig {
pub fn new(sample_rate: f64, fft_size: usize, hop_size: usize, n_mels: usize) -> Self {
Self {
sample_rate,
fft_size,
hop_size,
n_mels,
f_min: 0.0,
f_max: sample_rate / 2.0,
}
}
}
#[inline]
pub fn hz_to_mel<T: Float>(hz: T) -> T {
let base = T::from_f64(700.0);
let factor = T::from_f64(2595.0);
let ln10 = T::from_f64(10.0_f64.ln());
factor * NumFloat::ln(T::ONE + hz / base) / ln10
}
#[inline]
pub fn mel_to_hz<T: Float>(mel: T) -> T {
let base = T::from_f64(700.0);
let factor = T::from_f64(2595.0);
let ten = T::from_f64(10.0);
base * (NumFloat::powf(ten, mel / factor) - T::ONE)
}
pub fn build_mel_filterbank<T: Float>(config: &MelConfig) -> Vec<Vec<T>> {
let n_freq_bins = config.fft_size / 2 + 1;
let n_mels = config.n_mels;
let mel_min = hz_to_mel::<T>(T::from_f64(config.f_min));
let mel_max = hz_to_mel::<T>(T::from_f64(config.f_max));
let n_points = n_mels + 2;
let mel_points: Vec<T> = (0..n_points)
.map(|i| mel_min + (mel_max - mel_min) * T::from_f64(i as f64 / (n_points - 1) as f64))
.collect();
let hz_points: Vec<T> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
let fft_size_plus_one = T::from_f64((config.fft_size + 1) as f64);
let sample_rate = T::from_f64(config.sample_rate);
let bin_indices: Vec<usize> = hz_points
.iter()
.map(|&hz| {
let raw = NumFloat::floor(hz * fft_size_plus_one / sample_rate);
let idx = raw.to_usize().unwrap_or(0);
idx.min(n_freq_bins.saturating_sub(1))
})
.collect();
let mut filterbank = vec![vec![T::ZERO; n_freq_bins]; n_mels];
for m in 0..n_mels {
let left = bin_indices[m];
let center = bin_indices[m + 1];
let right = bin_indices[m + 2];
let width = right.saturating_sub(left);
if width == 0 {
continue;
}
let width_t = T::from_f64(width as f64);
if center > left {
let rise_width = center - left;
for k in left..center.min(n_freq_bins) {
let numerator = T::from_f64((k - left) as f64);
filterbank[m][k] = numerator / width_t;
}
let _ = rise_width; }
if right > center {
for k in center..right.min(n_freq_bins) {
let numerator = T::from_f64((right - k) as f64);
filterbank[m][k] = numerator / width_t;
}
}
}
filterbank
}
pub fn mel_spectrogram<T: Float>(signal: &[T], config: &MelConfig) -> Vec<Vec<T>> {
let n_freq_bins = config.fft_size / 2 + 1;
let filterbank = build_mel_filterbank::<T>(config);
let complex_spectrogram = compute_stft(
signal,
config.fft_size,
config.hop_size,
WindowFunction::Hann,
);
let power_spec = power_spectrogram(&complex_spectrogram);
let epsilon = T::from_f64(1e-10);
power_spec
.iter()
.map(|power_frame| {
let n_bins = power_frame.len().min(n_freq_bins);
(0..config.n_mels)
.map(|m| {
let energy: T = filterbank[m][..n_bins]
.iter()
.zip(power_frame[..n_bins].iter())
.fold(T::ZERO, |acc, (&filt, &pwr)| acc + filt * pwr);
NumFloat::ln(if energy > epsilon { energy } else { epsilon })
})
.collect()
})
.collect()
}
pub fn mfcc<T: Float>(signal: &[T], config: &MelConfig, n_coeffs: usize) -> Vec<Vec<T>> {
let log_mel = mel_spectrogram(signal, config);
let n_mels = config.n_mels;
let n_mels_f = T::from_f64(n_mels as f64);
let pi = <T as Float>::PI;
log_mel
.iter()
.map(|frame| {
(0..n_coeffs)
.map(|k| {
let k_t = T::from_f64(k as f64);
frame.iter().enumerate().fold(T::ZERO, |acc, (m, &val)| {
let m_plus_half = T::from_f64(m as f64) + T::from_f64(0.5);
let angle = pi * k_t * m_plus_half / n_mels_f;
acc + val * NumFloat::cos(angle)
})
})
.collect()
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hz_to_mel_and_back() {
let hz_values = [100.0f64, 440.0, 1000.0, 4000.0, 8000.0];
for &hz in &hz_values {
let mel = hz_to_mel(hz);
let recovered = mel_to_hz(mel);
assert!(
(recovered - hz).abs() < 0.001,
"Round-trip failed for {hz}: got {recovered}"
);
}
}
#[test]
fn test_mel_filterbank_shape() {
let config = MelConfig::new(8000.0, 256, 128, 40);
let fb = build_mel_filterbank::<f64>(&config);
assert_eq!(fb.len(), 40); assert_eq!(fb[0].len(), 256 / 2 + 1); }
#[test]
fn test_mel_filterbank_nonnegative() {
let config = MelConfig::new(8000.0, 256, 128, 40);
let fb = build_mel_filterbank::<f64>(&config);
for row in &fb {
for &v in row {
assert!(v >= 0.0, "Filterbank value should be non-negative: {v}");
}
}
}
#[test]
fn test_mel_spectrogram_shape() {
let sample_rate = 8000.0f64;
let signal: Vec<f64> = (0..8000).map(|i| (f64::from(i) * 0.1).sin()).collect();
let config = MelConfig::new(sample_rate, 256, 128, 40);
let mel_spec = mel_spectrogram(&signal, &config);
assert!(!mel_spec.is_empty(), "Should have at least one frame");
assert_eq!(mel_spec[0].len(), 40);
}
#[test]
fn test_mfcc_shape() {
let sample_rate = 8000.0f64;
let signal: Vec<f64> = (0..8000).map(|i| (f64::from(i) * 0.1).sin()).collect();
let config = MelConfig::new(sample_rate, 256, 128, 40);
let coefficients = mfcc(&signal, &config, 13);
assert!(!coefficients.is_empty());
assert_eq!(coefficients[0].len(), 13);
}
#[test]
fn test_mel_filterbank_sum_of_weights() {
let config = MelConfig::new(16000.0, 512, 256, 40);
let fb = build_mel_filterbank::<f64>(&config);
let nonzero_filters = fb.iter().filter(|row| row.iter().any(|&v| v > 0.0)).count();
assert!(
nonzero_filters >= config.n_mels * 3 / 4,
"Expected most filters to be non-zero, got {nonzero_filters} out of {}",
config.n_mels
);
}
#[test]
fn test_mel_spectrogram_values_finite() {
let sample_rate = 8000.0f64;
let signal: Vec<f64> = (0..4096).map(|i| (f64::from(i) * 0.05).sin()).collect();
let config = MelConfig::new(sample_rate, 256, 128, 40);
let mel_spec = mel_spectrogram(&signal, &config);
for (frame_idx, frame) in mel_spec.iter().enumerate() {
for (mel_idx, &val) in frame.iter().enumerate() {
assert!(
val.is_finite(),
"Non-finite value at frame {frame_idx} mel bin {mel_idx}: {val}"
);
}
}
}
#[test]
fn test_mfcc_values_finite() {
let sample_rate = 8000.0f64;
let signal: Vec<f64> = (0..4096).map(|i| (f64::from(i) * 0.05).sin()).collect();
let config = MelConfig::new(sample_rate, 256, 128, 40);
let coefficients = mfcc(&signal, &config, 13);
for (frame_idx, frame) in coefficients.iter().enumerate() {
for (coeff_idx, &val) in frame.iter().enumerate() {
assert!(
val.is_finite(),
"Non-finite MFCC at frame {frame_idx} coeff {coeff_idx}: {val}"
);
}
}
}
#[test]
fn test_mel_config_defaults() {
let config = MelConfig::new(22050.0, 1024, 512, 80);
assert_eq!(config.f_min, 0.0);
assert!((config.f_max - 11025.0).abs() < 1e-6);
assert_eq!(config.n_mels, 80);
assert_eq!(config.fft_size, 1024);
assert_eq!(config.hop_size, 512);
}
}