#![cfg(feature = "onnx")]
use realfft::RealFftPlanner;
const SAMPLE_RATE: f32 = 16000.0;
const N_FFT: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128;
const FMIN: f32 = 0.0;
const FMAX: f32 = 8000.0;
pub fn whisper_mel(samples: &[f32]) -> Vec<f32> {
if samples.is_empty() {
return vec![0.0; N_MELS];
}
let pad = N_FFT / 2;
let padded_len = samples.len() + 2 * pad;
let mut padded = vec![0.0f32; padded_len];
for i in 0..pad {
padded[pad - 1 - i] = samples[(i + 1).min(samples.len() - 1)];
}
padded[pad..pad + samples.len()].copy_from_slice(samples);
for i in 0..pad {
let src = samples.len().saturating_sub(2 + i);
padded[pad + samples.len() + i] = samples[src];
}
let n_frames = (padded_len - N_FFT) / HOP_LENGTH + 1;
let window: Vec<f32> = (0..N_FFT)
.map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / N_FFT as f32).cos()))
.collect();
let filters = mel_filterbank(N_MELS, N_FFT, SAMPLE_RATE, FMIN, FMAX);
let mut planner = RealFftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(N_FFT);
let n_freqs = N_FFT / 2 + 1;
let mut mel_spec = vec![0.0f32; N_MELS * n_frames];
let mut fft_input = vec![0.0f32; N_FFT];
let mut fft_output = fft.make_output_vec();
for frame_idx in 0..n_frames {
let start = frame_idx * HOP_LENGTH;
for i in 0..N_FFT {
fft_input[i] = padded[start + i] * window[i];
}
fft.process(&mut fft_input, &mut fft_output).ok();
let power: Vec<f32> = fft_output
.iter()
.map(|c| c.re * c.re + c.im * c.im)
.collect();
for mel_bin in 0..N_MELS {
let filter_start = mel_bin * n_freqs;
let mut energy = 0.0f32;
for freq_bin in 0..n_freqs {
energy += filters[filter_start + freq_bin] * power[freq_bin];
}
mel_spec[mel_bin * n_frames + frame_idx] = energy;
}
}
let max_val = mel_spec.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let log_offset = 1e-10_f32;
for val in &mut mel_spec {
*val = (*val).max(log_offset).log10();
}
let log_max = max_val.max(log_offset).log10();
let clamp_min = log_max - 8.0;
for val in &mut mel_spec {
*val = (*val).max(clamp_min);
}
for val in &mut mel_spec {
*val = (*val + 4.0) / 4.0;
}
if n_frames > 1 {
let trimmed_frames = n_frames - 1;
let mut trimmed = vec![0.0f32; N_MELS * trimmed_frames];
for mel_bin in 0..N_MELS {
let src_start = mel_bin * n_frames;
let dst_start = mel_bin * trimmed_frames;
trimmed[dst_start..dst_start + trimmed_frames]
.copy_from_slice(&mel_spec[src_start..src_start + trimmed_frames]);
}
return trimmed;
}
mel_spec
}
pub fn mel_frame_count(num_samples: usize) -> usize {
if num_samples == 0 {
return 1;
}
let padded_len = num_samples + N_FFT; let raw_frames = (padded_len - N_FFT) / HOP_LENGTH + 1;
if raw_frames > 1 {
raw_frames - 1
} else {
raw_frames
}
}
fn mel_filterbank(n_mels: usize, n_fft: usize, sr: f32, fmin: f32, fmax: f32) -> Vec<f32> {
let n_freqs = n_fft / 2 + 1;
let mel_min = hz_to_mel(fmin);
let mel_max = hz_to_mel(fmax);
let mel_points: Vec<f32> = (0..=n_mels + 1)
.map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
.collect();
let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
let bin_points: Vec<f32> = hz_points.iter().map(|&hz| hz * n_fft as f32 / sr).collect();
let mut filters = vec![0.0f32; n_mels * n_freqs];
for m in 0..n_mels {
let left = bin_points[m];
let center = bin_points[m + 1];
let right = bin_points[m + 2];
for k in 0..n_freqs {
let freq = k as f32;
if freq >= left && freq <= center && center > left {
filters[m * n_freqs + k] = (freq - left) / (center - left);
} else if freq > center && freq <= right && right > center {
filters[m * n_freqs + k] = (right - freq) / (right - center);
}
}
let width = hz_points[m + 2] - hz_points[m];
if width > 0.0 {
let norm = 2.0 / width;
for k in 0..n_freqs {
filters[m * n_freqs + k] *= norm;
}
}
}
filters
}
fn hz_to_mel(hz: f32) -> f32 {
2595.0 * (1.0 + hz / 700.0).log10()
}
fn mel_to_hz(mel: f32) -> f32 {
700.0 * (10.0_f32.powf(mel / 2595.0) - 1.0)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn mel_output_shape() {
let samples = vec![0.0f32; 16000]; let mel = whisper_mel(&samples);
let n_frames = mel_frame_count(16000);
assert_eq!(mel.len(), N_MELS * n_frames);
}
#[test]
fn mel_empty_input() {
let mel = whisper_mel(&[]);
assert_eq!(mel.len(), N_MELS);
}
#[test]
fn mel_values_finite() {
let samples: Vec<f32> = (0..16000).map(|i| (i as f32 * 0.01).sin()).collect();
let mel = whisper_mel(&samples);
assert!(
mel.iter().all(|v| v.is_finite()),
"mel contains non-finite values"
);
}
#[test]
fn mel_filterbank_shape() {
let filters = mel_filterbank(N_MELS, N_FFT, SAMPLE_RATE, FMIN, FMAX);
assert_eq!(filters.len(), N_MELS * (N_FFT / 2 + 1));
}
#[test]
fn hz_mel_roundtrip() {
for hz in [0.0, 100.0, 1000.0, 4000.0, 8000.0] {
let rt = mel_to_hz(hz_to_mel(hz));
assert!((rt - hz).abs() < 0.1, "roundtrip failed for {hz}: got {rt}");
}
}
}