use candle_core::{Device, Result, Tensor};
use rustfft::{FftPlanner, num_complex::Complex};
/// Parameters that control framing and mel projection in [`mel_spectrogram`].
#[derive(Debug, Clone)]
pub struct MelSpectrogramConfig {
    /// FFT length; each frame yields `n_fft / 2 + 1` frequency bins.
    pub n_fft: usize,
    /// Number of mel bands (rows of the filterbank).
    pub num_mels: usize,
    /// Sampling rate of the input audio, used to place FFT bins in Hz.
    pub sample_rate: usize,
    /// Step between the starts of consecutive frames, in samples.
    pub hop_size: usize,
    /// Length of the Hann analysis window, in samples.
    pub win_size: usize,
    /// Lower edge of the mel filterbank, in Hz.
    pub fmin: f64,
    /// Upper edge of the mel filterbank, in Hz; `None` means `sample_rate / 2`.
    pub fmax: Option<f64>,
}

impl Default for MelSpectrogramConfig {
    /// 24 kHz audio, 1024-point FFT with a full-length window, a hop of a
    /// quarter of the FFT length, and 128 mel bands spanning 0 Hz up to the
    /// default upper edge.
    fn default() -> Self {
        let n_fft = 1024;
        Self {
            n_fft,
            num_mels: 128,
            sample_rate: 24000,
            hop_size: n_fft / 4,
            win_size: n_fft,
            fmin: 0.0,
            fmax: None,
        }
    }
}
/// Log-compresses `x` element-wise as `ln(max(x, clip_val) * c)`.
///
/// * `x` - tensor to compress.
/// * `c` - multiplier applied after flooring and before the logarithm.
/// * `clip_val` - lower bound applied to every element first, so the
///   logarithm never receives a value below `clip_val * c`.
///
/// # Errors
/// Propagates any error reported by the underlying tensor operations.
pub fn dynamic_range_compression(x: &Tensor, c: f64, clip_val: f64) -> Result<Tensor> {
    // Only the lower bound is meaningful; the upper bound is left open.
    let floored = x.clamp(clip_val, f64::INFINITY)?;
    let scaled = (floored * c)?;
    scaled.log()
}
/// Builds a Hann window of `size` samples: `w[n] = 0.5 * (1 - cos(2*pi*n / size))`.
///
/// The denominator is `size` (not `size - 1`), so the first sample is zero and
/// the last sample is not. Returns an empty vector when `size` is zero.
fn create_hann_window(size: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = 2.0 * std::f32::consts::PI * n as f32 / size as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
pub fn mel_spectrogram(
audio: &Tensor,
config: &MelSpectrogramConfig,
device: &Device,
) -> Result<Tensor> {
let audio = if audio.dims().len() == 1 {
audio.unsqueeze(0)?
} else {
audio.clone()
};
let (batch_size, _num_samples) = audio.dims2()?;
let mel_filterbank = create_mel_filterbank(
config.n_fft,
config.num_mels,
config.sample_rate,
config.fmin,
config.fmax,
);
let n_freqs = config.n_fft / 2 + 1;
let mel_basis_data: Vec<f32> = mel_filterbank.into_iter().flatten().collect();
let mel_basis = Tensor::from_vec(mel_basis_data, (config.num_mels, n_freqs), device)?;
let hann_window = create_hann_window(config.win_size);
let padding = (config.n_fft - config.hop_size) / 2;
let mut mel_specs = Vec::with_capacity(batch_size);
for b in 0..batch_size {
let sample = audio.get(b)?.to_vec1::<f32>()?;
let padded = reflect_pad(&sample, padding, padding);
let stft_result = stft(
&padded,
config.n_fft,
config.hop_size,
config.win_size,
&hann_window,
);
let n_frames = stft_result.len();
let mut magnitude_data = vec![0.0f32; n_freqs * n_frames];
for (frame_idx, frame) in stft_result.iter().enumerate() {
for (freq_idx, &complex_val) in frame.iter().enumerate() {
let mag = (complex_val.re.powi(2) + complex_val.im.powi(2) + 1e-9).sqrt();
magnitude_data[freq_idx * n_frames + frame_idx] = mag;
}
}
let magnitude = Tensor::from_vec(magnitude_data, (n_freqs, n_frames), device)?;
let mel_spec = mel_basis.matmul(&magnitude)?;
mel_specs.push(mel_spec);
}
let stacked = Tensor::stack(&mel_specs.iter().collect::<Vec<_>>(), 0)?;
dynamic_range_compression(&stacked, 1.0, 1e-5)
}
/// Pads `signal` by mirroring it around its first and last samples (the edge
/// samples themselves are not repeated).
///
/// * `pad_left` / `pad_right` - number of samples to add before / after.
///
/// When the requested padding reaches further than the signal can be
/// mirrored, the remaining positions repeat the furthest available sample
/// (the last sample on the left side, the first sample on the right side).
/// An empty signal has nothing to mirror, so the result is zero-filled rather
/// than panicking.
///
/// The returned vector always has `signal.len() + pad_left + pad_right`
/// elements.
fn reflect_pad(signal: &[f32], pad_left: usize, pad_right: usize) -> Vec<f32> {
    let n = signal.len();
    if n == 0 {
        // Nothing to reflect: keep the length contract with silence.
        return vec![0.0; pad_left + pad_right];
    }
    let last = n - 1;

    let mut padded = Vec::with_capacity(n + pad_left + pad_right);
    // Left side: signal[pad_left], ..., signal[1], clamped to the last sample.
    padded.extend((1..=pad_left).rev().map(|i| signal[i.min(last)]));
    padded.extend_from_slice(signal);
    // Right side: signal[n - 2], signal[n - 3], ..., clamped to the first sample.
    padded.extend((0..pad_right).map(|i| signal[last.saturating_sub(i + 1)]));
    padded
}
/// Short-time Fourier transform without implicit centering or padding.
///
/// Frames of `n_fft` samples are taken every `hop_size` samples. The first
/// `win_size` coefficients of `window` are applied to the middle of each frame
/// (the window is conceptually zero-padded on both sides up to `n_fft`), and
/// only the `n_fft / 2 + 1` non-negative-frequency bins are kept.
///
/// Returns one `Vec` of bins per frame. A signal shorter than `n_fft` (or
/// `n_fft == 0`) produces no frames.
///
/// # Panics
/// Panics if `hop_size` is zero, if `win_size` exceeds `n_fft`, or if `window`
/// holds fewer than `win_size` coefficients.
fn stft(
    signal: &[f32],
    n_fft: usize,
    hop_size: usize,
    win_size: usize,
    window: &[f32],
) -> Vec<Vec<Complex<f32>>> {
    assert!(hop_size > 0, "stft: hop_size must be non-zero");
    assert!(
        win_size <= n_fft,
        "stft: win_size ({}) must not exceed n_fft ({})",
        win_size,
        n_fft
    );
    assert!(
        window.len() >= win_size,
        "stft: window has {} coefficients but win_size is {}",
        window.len(),
        win_size
    );
    if n_fft == 0 || signal.len() < n_fft {
        // Not even one complete frame fits.
        return Vec::new();
    }

    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(n_fft);
    let n_freqs = n_fft / 2 + 1;
    let num_frames = (signal.len() - n_fft) / hop_size + 1;
    // Position of the (shorter) window inside the n_fft-long frame.
    let offset = (n_fft - win_size) / 2;
    let taps = &window[..win_size];

    let zero = Complex::new(0.0f32, 0.0);
    // One buffer reused for every frame; the FFT overwrites it in place.
    let mut buffer = vec![zero; n_fft];
    let mut result = Vec::with_capacity(num_frames);
    for frame_idx in 0..num_frames {
        // The windowed region of the frame starts `offset` samples into it, so
        // the samples must be read from the same position in the signal.
        let start = frame_idx * hop_size + offset;
        let segment = &signal[start..start + win_size];

        buffer.fill(zero);
        for ((slot, &sample), &tap) in buffer[offset..offset + win_size]
            .iter_mut()
            .zip(segment)
            .zip(taps)
        {
            *slot = Complex::new(sample * tap, 0.0);
        }
        fft.process(&mut buffer);
        result.push(buffer[..n_freqs].to_vec());
    }
    result
}
/// Builds a bank of `num_mels` triangular mel filters over the
/// `n_fft / 2 + 1` FFT bins.
///
/// The mel scale used here is linear below 1000 Hz (`200 / 3` Hz per mel) and
/// logarithmic above it; band edges are spaced evenly on that scale between
/// `fmin` and `fmax`. Each triangle is scaled by `2 / (f_right - f_left)`.
/// NOTE(review): this looks like librosa's default `filters.mel` behaviour
/// (Slaney scale and normalisation) - confirm against the reference values.
///
/// * `n_fft` - FFT length; bin `i` is placed at `i * sample_rate / n_fft` Hz.
/// * `num_mels` - number of filters (rows) to produce.
/// * `sample_rate` - sampling rate in Hz.
/// * `fmin` - lowest band edge in Hz.
/// * `fmax` - highest band edge in Hz; `None` means `sample_rate / 2`.
///
/// Returns `num_mels` rows of `n_fft / 2 + 1` weights. A band whose edges
/// coincide (for example `fmin == fmax`) yields an all-zero row rather than
/// NaN values.
pub fn create_mel_filterbank(
    n_fft: usize,
    num_mels: usize,
    sample_rate: usize,
    fmin: f64,
    fmax: Option<f64>,
) -> Vec<Vec<f32>> {
    let fmax = fmax.unwrap_or(sample_rate as f64 / 2.0);
    let n_freqs = n_fft / 2 + 1;

    // Mel-scale constants: linear slope, break frequency, and log step.
    let f_sp = 200.0 / 3.0;
    let min_log_hz = 1000.0;
    let min_log_mel = min_log_hz / f_sp;
    let logstep = (6.4_f64).ln() / 27.0;

    let hz_to_mel = |hz: f64| {
        if hz < min_log_hz {
            hz / f_sp
        } else {
            min_log_mel + (hz / min_log_hz).ln() / logstep
        }
    };
    let mel_to_hz = |mel: f64| {
        if mel < min_log_mel {
            mel * f_sp
        } else {
            min_log_hz * ((mel - min_log_mel) * logstep).exp()
        }
    };

    // num_mels + 2 band edges, evenly spaced in mel, expressed in Hz.
    let mel_min = hz_to_mel(fmin);
    let mel_max = hz_to_mel(fmax);
    let hz_points: Vec<f64> = (0..num_mels + 2)
        .map(|i| mel_to_hz(mel_min + (mel_max - mel_min) * i as f64 / (num_mels + 1) as f64))
        .collect();

    let fft_freqs: Vec<f64> = (0..n_freqs)
        .map(|i| i as f64 * sample_rate as f64 / n_fft as f64)
        .collect();

    // Every three consecutive edges (left, centre, right) define one triangle.
    hz_points
        .windows(3)
        .map(|edges| {
            let (f_left, f_center, f_right) = (edges[0], edges[1], edges[2]);
            let lower_diff = f_center - f_left;
            let upper_diff = f_right - f_center;

            // A zero-width band would give an infinite scale, and `0 * inf`
            // would fill the row with NaN; use a zero scale in that case.
            let enorm = match 2.0 / (f_right - f_left) as f32 {
                scale if scale.is_finite() && scale > 0.0 => scale,
                _ => 0.0,
            };

            fft_freqs
                .iter()
                .map(|&freq| {
                    let lower = if lower_diff > 0.0 {
                        (freq - f_left) / lower_diff
                    } else {
                        0.0
                    };
                    let upper = if upper_diff > 0.0 {
                        (f_right - freq) / upper_diff
                    } else {
                        0.0
                    };
                    let weight = lower.min(upper).max(0.0);
                    weight as f32 * enorm
                })
                .collect()
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 4-point window starts at zero and peaks at its midpoint.
    #[test]
    fn test_hann_window() {
        let window = create_hann_window(4);
        assert_eq!(window.len(), 4);
        assert!(window[0].abs() < 1e-6);
        assert!((window[2] - 1.0).abs() < 1e-6);
    }

    /// Reflection mirrors around the edge samples without repeating them.
    #[test]
    fn test_reflect_pad() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let expected = vec![3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0];
        assert_eq!(reflect_pad(&signal, 2, 2), expected);
    }

    /// One row per mel band, one column per non-negative FFT bin.
    #[test]
    fn test_mel_filterbank_shape() {
        let filterbank = create_mel_filterbank(1024, 128, 24000, 0.0, Some(12000.0));
        assert_eq!(filterbank.len(), 128);
        assert_eq!(filterbank[0].len(), 513);
    }

    /// Prints a few rows for manual comparison and checks that the lowest
    /// band covers the first two non-DC bins.
    #[test]
    fn test_mel_filterbank_values() {
        let filterbank = create_mel_filterbank(1024, 128, 24000, 0.0, Some(12000.0));
        println!("Rust filterbank[0, :10]: {:?}", &filterbank[0][..10]);
        println!("Rust filterbank[1, :10]: {:?}", &filterbank[1][..10]);
        println!("Rust filterbank[64, :10]: {:?}", &filterbank[64][..10]);

        let lowest_band = &filterbank[0];
        assert!(lowest_band[1] > 0.0, "filterbank[0][1] should be > 0");
        assert!(lowest_band[2] > 0.0, "filterbank[0][2] should be > 0");
    }

    /// A 1-D waveform produces a batch of one with `num_mels` rows.
    #[test]
    fn test_mel_spectrogram_shape() -> Result<()> {
        let device = Device::Cpu;
        let config = MelSpectrogramConfig {
            n_fft: 1024,
            num_mels: 128,
            sample_rate: 24000,
            hop_size: 256,
            win_size: 1024,
            fmin: 0.0,
            fmax: Some(12000.0),
        };

        // One second of a slow sine at the configured sample rate.
        let audio_data: Vec<f32> = (0..24000).map(|i| (i as f32 * 0.01).sin()).collect();
        let audio = Tensor::from_vec(audio_data, 24000, &device)?;

        let mel = mel_spectrogram(&audio, &config, &device)?;
        let dims = mel.dims();
        assert_eq!(dims[0], 1);
        assert_eq!(dims[1], 128);
        Ok(())
    }
}