use crate::error::Result;
use mlx_rs::Array;
/// Parameters describing a mel filterbank, consumed by `mel_filterbank`.
#[derive(Debug, Clone)]
pub struct MelConfig {
/// Sampling rate of the audio; half of it is the upper band edge when `fmax` is `None`.
pub sr: i32,
/// FFT size; the filterbank covers `n_fft / 2 + 1` frequency bins.
pub n_fft: i32,
/// Number of mel bands, i.e. rows of the generated filterbank.
pub n_mels: i32,
/// Lowest band edge, converted with `hz_to_mel` (so expressed in Hz).
pub fmin: f32,
/// Highest band edge in Hz; `None` falls back to `sr / 2`.
pub fmax: Option<f32>,
/// `true` selects the HTK mel formula, `false` the piecewise linear/log scale.
pub htk: bool,
/// Per-band scaling applied to the triangular filter weights.
pub norm: MelNorm,
}
/// Normalization applied to each triangular mel filter.
///
/// `PartialEq`/`Eq` are derived so variants can be compared with `==` and
/// used in `assert_eq!`, rather than only through pattern matching.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MelNorm {
    /// Leave the triangular weights unscaled.
    None,
    /// Scale each filter by `2 / (upper_hz - lower_hz)` of its band.
    /// This is the default variant.
    #[default]
    Slaney,
}
impl Default for MelConfig {
fn default() -> Self {
Self {
sr: 24000,
n_fft: 1024,
n_mels: 100,
fmin: 0.0,
fmax: None,
htk: false,
norm: MelNorm::Slaney,
}
}
}
/// Converts a frequency in Hz to a mel-scale value.
///
/// * `htk == true`: uses `2595 * log10(1 + freq / 700)`.
/// * `htk == false`: piecewise scale that is linear below 1000 Hz
///   (`200 / 3` Hz per mel) and logarithmic from 1000 Hz upward, with a
///   step of `ln(6.4) / 27` per mel.
pub fn hz_to_mel(freq: f32, htk: bool) -> f32 {
    if htk {
        return 2595.0 * (1.0 + freq / 700.0).log10();
    }

    // Constants of the piecewise (non-HTK) scale.
    const LINEAR_ORIGIN_HZ: f32 = 0.0;
    const HZ_PER_MEL: f32 = 200.0 / 3.0;
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = (BREAK_HZ - LINEAR_ORIGIN_HZ) / HZ_PER_MEL;
    let log_step = 6.4_f32.ln() / 27.0;

    if freq >= BREAK_HZ {
        // Logarithmic region, anchored at the mel value of the break point.
        BREAK_MEL + (freq / BREAK_HZ).ln() / log_step
    } else {
        // Linear region.
        (freq - LINEAR_ORIGIN_HZ) / HZ_PER_MEL
    }
}
/// Converts a mel-scale value back to a frequency in Hz.
///
/// Inverse of `hz_to_mel` for the same `htk` flag:
/// * `htk == true`: `700 * (10^(mel / 2595) - 1)`.
/// * `htk == false`: linear (`200 / 3` Hz per mel) below the mel value of
///   1000 Hz, exponential with step `ln(6.4) / 27` at and above it.
pub fn mel_to_hz(mel: f32, htk: bool) -> f32 {
    match htk {
        true => 700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0),
        false => {
            // Constants of the piecewise (non-HTK) scale.
            const LINEAR_ORIGIN_HZ: f32 = 0.0;
            const HZ_PER_MEL: f32 = 200.0 / 3.0;
            const BREAK_HZ: f32 = 1000.0;
            const BREAK_MEL: f32 = (BREAK_HZ - LINEAR_ORIGIN_HZ) / HZ_PER_MEL;
            let log_step = 6.4_f32.ln() / 27.0;

            if mel >= BREAK_MEL {
                // Exponential region above the break point.
                let mels_above_break = mel - BREAK_MEL;
                BREAK_HZ * (mels_above_break * log_step).exp()
            } else {
                // Linear region.
                LINEAR_ORIGIN_HZ + HZ_PER_MEL * mel
            }
        }
    }
}
/// Builds a triangular mel filterbank for `config`.
///
/// `n_mels + 2` band edges are spaced evenly on the mel scale between
/// `fmin` and `fmax` (or `sr / 2` when `fmax` is `None`), mapped back to Hz
/// and then to fractional FFT-bin positions (`n_fft * hz / sr`). Band `m`
/// is a triangle rising from edge `m` to edge `m + 1` and falling to edge
/// `m + 2`, sampled at every integer bin `0..n_fft / 2 + 1`.
///
/// With `MelNorm::Slaney` each band is additionally multiplied by
/// `2 / (edge_hz[m + 2] - edge_hz[m])`.
///
/// Returns an array of shape `[n_mels, n_fft / 2 + 1]` built from a
/// row-major buffer (one row per band).
pub fn mel_filterbank(config: &MelConfig) -> Result<Array> {
    let use_htk = config.htk;
    let upper_hz = match config.fmax {
        Some(hz) => hz,
        None => config.sr as f32 / 2.0,
    };
    let n_freqs = config.n_fft / 2 + 1;

    let mel_lo = hz_to_mel(config.fmin, use_htk);
    let mel_hi = hz_to_mel(upper_hz, use_htk);
    let n_edges = config.n_mels + 2;

    // Band edges in Hz, evenly spaced on the mel axis.
    let edges_hz: Vec<f32> = (0..n_edges)
        .map(|i| {
            let mel = mel_lo + (mel_hi - mel_lo) * i as f32 / (n_edges - 1) as f32;
            mel_to_hz(mel, use_htk)
        })
        .collect();
    // The same edges expressed as fractional FFT-bin positions.
    let edges_bin: Vec<f32> = edges_hz
        .iter()
        .map(|&hz| config.n_fft as f32 * hz / config.sr as f32)
        .collect();

    let cols = n_freqs as usize;
    let mut weights = vec![0.0f32; (config.n_mels * n_freqs) as usize];

    for band in 0..config.n_mels as usize {
        let left = edges_bin[band];
        let center = edges_bin[band + 1];
        let right = edges_bin[band + 2];

        // Per-band scale factor; `None` leaves the triangle unscaled.
        let scale = if matches!(config.norm, MelNorm::Slaney) {
            Some(2.0 / (edges_hz[band + 2] - edges_hz[band] + 1e-10))
        } else {
            None
        };

        for k in 0..cols {
            let bin = k as f32;
            let on_rising_edge = bin >= left && bin <= center;
            let on_falling_edge = bin >= center && bin <= right;

            // The rising edge takes precedence when `bin == center`.
            let raw = if on_rising_edge {
                (bin - left) / (center - left + 1e-10)
            } else if on_falling_edge {
                (right - bin) / (right - center + 1e-10)
            } else {
                0.0
            };

            weights[band * cols + k] = match scale {
                Some(factor) => raw * factor,
                None => raw,
            };
        }
    }

    Ok(Array::from_slice(&weights, &[config.n_mels, n_freqs]))
}
/// Computes a mel-scaled magnitude spectrogram of `audio`.
///
/// The signal is passed through `super::stft` and `super::stft_magnitude`,
/// and the resulting magnitudes are projected through the filterbank
/// produced by `mel_filterbank(config)` (shape `[n_mels, n_fft / 2 + 1]`).
///
/// * When the magnitude array is 2-D, the result is `mel_fb.matmul(magnitude)`.
/// * Otherwise the magnitude is treated as 3-D: its last two axes are swapped,
///   it is multiplied by the transposed filterbank, and the last two axes are
///   swapped back.
///
/// # Errors
///
/// Propagates any error from the STFT, the magnitude computation, filterbank
/// construction, or the array operations.
pub fn mel_spectrogram(
    audio: &Array,
    config: &MelConfig,
    stft_config: &super::StftConfig,
) -> Result<Array> {
    let stft_out = super::stft(audio, stft_config)?;
    let magnitude = super::stft_magnitude(&stft_out)?;
    let mel_fb = mel_filterbank(config)?;

    if magnitude.ndim() == 2 {
        // presumably (n_freqs, frames) -> (n_mels, frames) — TODO confirm
        // against `super::stft_magnitude`.
        return Ok(mel_fb.matmul(&magnitude)?);
    }

    // presumably (batch, n_freqs, frames) — TODO confirm. Move the frequency
    // axis last so it contracts against the transposed filterbank, then
    // restore the original axis order.
    let mag_t = magnitude.transpose_axes(&[0, 2, 1])?;
    let mel_fb_t = mel_fb.transpose_axes(&[1, 0])?;
    let mel_spec = mag_t.matmul(&mel_fb_t)?;
    Ok(mel_spec.transpose_axes(&[0, 2, 1])?)
}
/// Applies an element-wise logarithm to `mel_spec` after flooring it.
///
/// Every element is first raised to at least `clip_val` (default `1e-5`)
/// via `mlx_rs::ops::maximum`, then `Array::log` is taken of the result.
///
/// # Errors
///
/// Propagates any error from the underlying array operations.
pub fn log_mel_spectrogram(mel_spec: &Array, clip_val: Option<f32>) -> Result<Array> {
    let floor = match clip_val {
        Some(value) => value,
        None => 1e-5,
    };
    let floored = mlx_rs::ops::maximum(mel_spec, &Array::from_f32(floor))?;
    let logged = floored.log()?;
    Ok(logged)
}
/// Compresses the dynamic range of `mel_spec` as
/// `log(1 + c * max(mel_spec, clip_val)) / ln(1 + c)`.
///
/// * `c` — compression factor, `1.0` when `None`.
/// * `clip_val` — lower bound applied to every element before compression,
///   `1e-5` when `None`.
///
/// The divisor is `(1.0 + c).ln()`, which is `0` for `c == 0`.
///
/// # Errors
///
/// Propagates any error from the underlying array operations.
pub fn dynamic_range_compression(
    mel_spec: &Array,
    c: Option<f32>,
    clip_val: Option<f32>,
) -> Result<Array> {
    let gain = c.unwrap_or(1.0);
    let floor = clip_val.unwrap_or(1e-5);

    // max(mel_spec, floor) * gain
    let floored = mlx_rs::ops::maximum(mel_spec, &Array::from_f32(floor))?;
    let scaled = floored.multiply(&Array::from_f32(gain))?;

    // log(1 + scaled), then divide by the scalar divisor ln(1 + gain).
    let compressed = Array::from_f32(1.0).add(&scaled)?.log()?;
    let normalizer = Array::from_f32((1.0 + gain).ln());
    Ok(compressed.divide(normalizer)?)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// On the HTK scale 1000 Hz maps to roughly 1000 mel.
    #[test]
    fn test_hz_to_mel_htk() {
        let mel = hz_to_mel(1000.0, true);
        assert!((mel - 1000.0).abs() < 1.0);
    }

    /// On the non-HTK scale 1000 Hz is the linear/log break point (~15 mel).
    #[test]
    fn test_hz_to_mel_slaney() {
        let mel = hz_to_mel(1000.0, false);
        assert!(mel > 10.0 && mel < 20.0);
    }

    /// `mel_to_hz` inverts `hz_to_mel` in the logarithmic region.
    #[test]
    fn test_mel_to_hz_roundtrip() {
        let freq = 2000.0;
        let mel = hz_to_mel(freq, false);
        let freq_back = mel_to_hz(mel, false);
        assert!((freq - freq_back).abs() < 1.0);
    }

    /// The filterbank has one row per mel band and `n_fft / 2 + 1` columns.
    #[test]
    fn test_mel_filterbank_shape() {
        let config = MelConfig {
            sr: 24000,
            n_fft: 1024,
            n_mels: 80,
            fmin: 0.0,
            fmax: Some(12000.0),
            htk: false,
            norm: MelNorm::Slaney,
        };
        let fb = mel_filterbank(&config).unwrap();
        fb.eval().unwrap();
        assert_eq!(fb.shape(), &[80, 513]);
    }

    /// Every band must carry some weight: each row sum is finite and positive.
    #[test]
    fn test_mel_filterbank_values() {
        let config = MelConfig {
            sr: 16000,
            n_fft: 512,
            n_mels: 40,
            fmin: 0.0,
            fmax: Some(8000.0),
            htk: false,
            norm: MelNorm::Slaney,
        };
        let fb = mel_filterbank(&config).unwrap();
        fb.eval().unwrap();
        let row_sums = fb.sum_axis(1, None).unwrap();
        row_sums.eval().unwrap();

        // Summing over the frequency axis leaves one value per mel band.
        assert_eq!(row_sums.shape(), &[40]);
        for (band, &total) in row_sums.as_slice::<f32>().iter().enumerate() {
            assert!(
                total.is_finite() && total > 0.0,
                "mel band {band} has row sum {total}"
            );
        }
    }
}