use kaldi_native_fbank::fbank::{FbankComputer, FbankOptions};
use kaldi_native_fbank::online::{FeatureComputer, OnlineFeature};
/// Sample rate, in Hz, that waveforms are fed to the FBANK extractor at.
const SAMPLE_RATE: f32 = 16_000.0;
/// Filterbank feature extractor built on `kaldi_native_fbank`.
///
/// Stores the FBANK configuration together with the per-frame dimension that
/// configuration produces. Each `compute` call clones the configuration into
/// a fresh extractor.
pub struct MelSpectrogram {
    /// Number of values per output frame, as reported by the configured extractor.
    n_mels: usize,
    /// FBANK configuration, cloned for every extractor that is built.
    opts: FbankOptions,
}
impl MelSpectrogram {
    /// Builds an extractor configured by [`phostt_fbank_options`].
    ///
    /// The output dimensionality is probed from a throwaway [`FbankComputer`],
    /// so it always matches what the configured options actually produce.
    ///
    /// # Panics
    ///
    /// Panics if `kaldi_native_fbank` rejects the options returned by
    /// [`phostt_fbank_options`].
    pub fn new() -> Self {
        let opts = phostt_fbank_options();
        let n_mels = FbankComputer::new(opts.clone())
            .expect("FBANK options valid")
            .dim();
        Self { opts, n_mels }
    }

    /// Computes filterbank features for a waveform fed at [`SAMPLE_RATE`].
    ///
    /// Returns `(features, num_frames)`. `features` is a flat, frame-major
    /// buffer of `num_frames * n_mels` values: frame `f` occupies
    /// `features[f * n_mels..(f + 1) * n_mels]`.
    ///
    /// If `samples` is empty, or too short for the extractor to emit any frame,
    /// a single all-zero frame is returned with `num_frames == 1`. Callers
    /// therefore never receive a zero-length feature matrix.
    ///
    /// # Panics
    ///
    /// Panics if the extractor cannot be built from the stored options, or if a
    /// frame reported as ready cannot be retrieved.
    #[must_use]
    pub fn compute(&self, samples: &[f32]) -> (Vec<f32>, usize) {
        if samples.is_empty() {
            return self.zero_frame();
        }

        // Build a fresh extractor per call so each waveform is processed
        // independently of any previous one.
        let computer = FbankComputer::new(self.opts.clone()).expect("FBANK options valid");
        let mut online = OnlineFeature::new(FeatureComputer::Fbank(computer));
        online.accept_waveform(SAMPLE_RATE, samples);
        online.input_finished();

        let num_frames = online.num_frames_ready();
        if num_frames == 0 {
            return self.zero_frame();
        }

        let n_mels = self.n_mels;
        let mut out = vec![0f32; num_frames * n_mels];
        // `out` holds exactly `num_frames` chunks of `n_mels` values, so each
        // chunk is paired with its frame index and needs no manual offset math.
        for (f, dst) in out.chunks_exact_mut(n_mels).enumerate() {
            let frame = online
                .get_frame(f)
                .expect("frame index < num_frames_ready must be retrievable");
            dst.copy_from_slice(&frame[..n_mels]);
        }
        (out, num_frames)
    }

    /// Placeholder result used when no real frame could be produced:
    /// one all-zero frame.
    fn zero_frame(&self) -> (Vec<f32>, usize) {
        (vec![0.0; self.n_mels], 1)
    }
}

impl Default for MelSpectrogram {
    /// Equivalent to [`MelSpectrogram::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Returns the FBANK configuration used for every extractor in this module.
///
/// Starts from `FbankOptions::default()` and overrides three fields:
/// - `frame_opts.dither` is set to `0.0`;
/// - `use_energy` is turned off;
/// - `mel_opts.num_bins` is set to 80.
pub(crate) fn phostt_fbank_options() -> FbankOptions {
    let mut cfg = FbankOptions::default();
    cfg.frame_opts.dither = 0.0;
    cfg.use_energy = false;
    cfg.mel_opts.num_bins = 80;
    cfg
}
/// Copies `num_frames` consecutive frames, starting at `start_frame`, out of
/// `online` into one flat, frame-major buffer of `num_frames * N_MELS` values.
///
/// Only the first `N_MELS` values of each frame are copied.
///
/// # Panics
///
/// Panics if a requested frame cannot be retrieved from `online`, or if it is
/// shorter than `N_MELS`. The requested range is expected to lie within the
/// frames `online` reports as ready.
pub(crate) fn extract_online_frames(
    online: &OnlineFeature,
    start_frame: usize,
    num_frames: usize,
) -> Vec<f32> {
    let mut buffer = Vec::with_capacity(num_frames * super::N_MELS);
    for offset in 0..num_frames {
        let frame = online
            .get_frame(start_frame + offset)
            .expect("frame index < num_frames_ready");
        buffer.extend_from_slice(&frame[..super::N_MELS]);
    }
    buffer
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::inference::N_MELS;

    #[test]
    fn test_default_dim_matches_const() {
        let mel = MelSpectrogram::new();
        assert_eq!(
            mel.n_mels, N_MELS,
            "n_mels must agree with the public constant"
        );
    }

    #[test]
    fn test_silence_returns_finite_features() {
        let mel = MelSpectrogram::new();
        let silence = vec![0.0_f32; 16000];
        let (features, num_frames) = mel.compute(&silence);
        assert!(num_frames > 0, "silence must still produce frames");
        assert_eq!(features.len(), N_MELS * num_frames);
        for &v in &features {
            assert!(v.is_finite(), "expected finite mel value, got {v}");
        }
    }

    #[test]
    fn test_empty_input_returns_single_zero_frame() {
        let mel = MelSpectrogram::new();
        let (features, num_frames) = mel.compute(&[]);
        assert_eq!(num_frames, 1);
        assert_eq!(features, vec![0.0_f32; N_MELS]);
    }

    #[test]
    fn test_too_short_returns_single_zero_frame() {
        let mel = MelSpectrogram::new();
        let samples = vec![0.0_f32; 100];
        let (features, num_frames) = mel.compute(&samples);
        assert_eq!(num_frames, 1);
        assert_eq!(features.len(), N_MELS);
        // The name promises a *zero* frame, so check the values as well as the shape.
        assert!(
            features.iter().all(|&v| v == 0.0),
            "expected an all-zero placeholder frame"
        );
    }

    #[test]
    fn test_sine_wave_has_dynamic_range() {
        let mel = MelSpectrogram::new();
        let samples: Vec<f32> = (0..16000)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 16000.0).sin())
            .collect();
        let (features, num_frames) = mel.compute(&samples);
        assert!(num_frames > 0);
        let max = features.iter().copied().fold(f32::MIN, f32::max);
        let min = features.iter().copied().fold(f32::MAX, f32::min);
        assert!(
            max - min > 1.0,
            "expected wide log-mel range for a 440 Hz tone, got max={max} min={min}"
        );
    }

    #[test]
    fn test_one_second_yields_about_one_hundred_frames() {
        let mel = MelSpectrogram::new();
        let samples = vec![0.0_f32; 16000];
        let (_, num_frames) = mel.compute(&samples);
        assert!(
            (96..=100).contains(&num_frames),
            "expected ~98 frames for 1 s of audio, got {num_frames}"
        );
    }

    #[test]
    fn test_extract_online_frames_matches_compute() {
        // Non-silent, deterministic input (dither is disabled in the options).
        let samples: Vec<f32> = (0..8000).map(|i| (i % 97) as f32 / 97.0 - 0.5).collect();
        let mel = MelSpectrogram::new();
        let (expected, num_frames) = mel.compute(&samples);

        let computer =
            FbankComputer::new(phostt_fbank_options()).expect("FBANK options valid");
        let mut online = OnlineFeature::new(FeatureComputer::Fbank(computer));
        online.accept_waveform(SAMPLE_RATE, samples.as_slice());
        online.input_finished();

        assert_eq!(online.num_frames_ready(), num_frames);
        assert_eq!(extract_online_frames(&online, 0, num_frames), expected);
    }
}