use smol_str::format_smolstr;
use crate::{
Array,
audio::dsp::{hann_window, mel_filter_bank},
error::{ArithmeticOverflowPayload, Error, InvariantViolationPayload, OutOfRangePayload, Result},
ops::{
arithmetic::{abs, add, divide, log10, maximum, multiply, square},
fft::{FftNorm, rfft},
linalg_basic::matmul,
reduction,
shape::as_strided,
},
};
/// Streaming log-mel spectrogram extractor.
///
/// Audio is fed in arbitrary-sized chunks through `process`. Samples that do not yet fill a
/// complete `n_fft`-sample frame are carried over between calls, and `flush` emits whatever
/// remains at end of stream. `reset` returns the extractor to its freshly constructed state.
#[derive(Debug)]
pub struct IncrementalMelSpectrogram {
    /// Frame length in samples; also the FFT size passed to `rfft`. Validated non-zero and even.
    n_fft: usize,
    /// Distance in samples between the starts of consecutive frames (`1..=n_fft`).
    hop_length: usize,
    /// `n_fft - hop_length`: how many samples two consecutive frames share.
    overlap_size: usize,
    /// Hann window of length `n_fft`, multiplied into every frame before the FFT.
    window: Array,
    /// Mel filter bank from `mel_filter_bank`; transposed before the power spectrum is
    /// projected onto it.
    filters: Array,
    /// Samples retained for the next `process`/`flush` call.
    overlap_buffer: Vec<f32>,
    /// True until the first non-empty chunk has been consumed; that chunk receives the
    /// `n_fft / 2`-sample left padding.
    is_first_chunk: bool,
    /// Largest log10 mel value seen since construction/reset; values below `max - 8.0` are clamped.
    running_log_max: f32,
    /// Cumulative number of frames emitted by `process` and `flush` since construction/reset.
    total_frames: usize,
    /// Test-only: number of upcoming `flush` calls that fail with a scripted error.
    #[cfg(test)]
    pub(crate) flush_err_inject_count: usize,
}
impl IncrementalMelSpectrogram {
    /// Creates a streaming log-mel extractor.
    ///
    /// * `sample_rate` – forwarded to [`mel_filter_bank`].
    /// * `n_fft` – frame length and FFT size; must be non-zero and even.
    /// * `hop_length` – distance in samples between consecutive frame starts; must be
    ///   non-zero and no larger than `n_fft`.
    /// * `n_mels` – number of mel filters requested from [`mel_filter_bank`].
    ///
    /// # Errors
    ///
    /// * [`Error::InvariantViolation`] if `n_fft` or `hop_length` is zero.
    /// * [`Error::OutOfRange`] if `n_fft` is odd or `hop_length > n_fft`.
    /// * Any error returned by [`hann_window`] or [`mel_filter_bank`].
    pub fn new(sample_rate: u32, n_fft: usize, hop_length: usize, n_mels: usize) -> Result<Self> {
        if n_fft == 0 {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "IncrementalMelSpectrogram::new: n_fft",
                "must be > 0",
            )));
        }
        if !n_fft.is_multiple_of(2) {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "IncrementalMelSpectrogram::new: n_fft",
                "must be even (odd n_fft is unsupported because the one-sided rfft is not invertible)",
                format_smolstr!("{n_fft}"),
            )));
        }
        if hop_length == 0 {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "IncrementalMelSpectrogram::new: hop_length",
                "must be > 0",
            )));
        }
        if hop_length > n_fft {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "IncrementalMelSpectrogram::new: hop_length",
                "must be <= n_fft (overlap-save framing requires `n_fft - hop_length >= 0`)",
                format_smolstr!("hop_length={hop_length}, n_fft={n_fft}"),
            )));
        }
        let window = hann_window(n_fft)?;
        let filters = mel_filter_bank(n_mels, n_fft, sample_rate, 0.0, None)?;
        Ok(Self {
            n_fft,
            hop_length,
            overlap_size: n_fft - hop_length,
            window,
            filters,
            overlap_buffer: Vec::new(),
            is_first_chunk: true,
            running_log_max: f32::NEG_INFINITY,
            total_frames: 0,
            #[cfg(test)]
            flush_err_inject_count: 0,
        })
    }

    /// Cumulative number of frames emitted by [`process`](Self::process) and
    /// [`flush`](Self::flush) since construction or the last [`reset`](Self::reset).
    pub fn total_frames(&self) -> usize {
        self.total_frames
    }

    #[doc(hidden)]
    #[cfg(test)]
    pub(crate) fn overlap_buffer_len(&self) -> usize {
        self.overlap_buffer.len()
    }

    /// Feeds the next chunk of samples and returns the log-mel frames that became complete.
    ///
    /// The first non-empty chunk after [`new`](Self::new) or [`reset`](Self::reset) is
    /// left-padded with `n_fft / 2` samples. Samples that do not yet complete a frame are
    /// carried over to the next call or to [`flush`](Self::flush).
    ///
    /// Returns `Ok(None)` when `samples` is empty or no full frame is available yet. Otherwise
    /// returns a 2-D array with one row per emitted frame; the tests assert the second axis
    /// is `n_mels`.
    ///
    /// # Errors
    ///
    /// Propagates failures from the mel pipeline (size conversions and array operations).
    /// On error the extractor's state is left unchanged, so the same chunk may be retried.
    pub fn process(&mut self, samples: &[f32]) -> Result<Option<Array>> {
        if samples.is_empty() {
            return Ok(None);
        }
        let mut signal = self.build_signal_for_chunk(samples);
        let num_frames = self.frame_count(signal.len());
        if num_frames == 0 {
            // Not enough data for a frame yet: keep everything (including any first-chunk
            // padding) and wait for more samples.
            self.overlap_buffer = signal;
            self.is_first_chunk = false;
            return Ok(None);
        }
        let mel = self.compute_mel(&signal, num_frames)?;
        // Commit state only after compute_mel succeeded, so an error leaves the extractor
        // exactly as it was before this call (the same ordering flush() uses).
        let consumed = (num_frames - 1) * self.hop_length + self.n_fft;
        // The next frame starts `overlap_size` samples before the end of the last emitted one,
        // i.e. at `num_frames * hop_length`.
        let next_frame_start = consumed - self.overlap_size;
        self.overlap_buffer = signal.split_off(next_frame_start);
        self.is_first_chunk = false;
        self.total_frames += num_frames;
        Ok(Some(mel))
    }

    /// Emits the frames remaining in the carry-over buffer at end of stream.
    ///
    /// The buffered samples are right-padded by reflecting up to `n_fft / 2` of them around
    /// the final sample, mirroring the left padding of the first chunk. If the result is still
    /// shorter than `n_fft`, it is zero-padded to exactly one frame. Returns `Ok(None)` when
    /// nothing is buffered.
    ///
    /// The buffer is cleared only on success. `is_first_chunk` is not reset here; call
    /// [`reset`](Self::reset) before starting an unrelated stream.
    ///
    /// # Errors
    ///
    /// Propagates failures from the mel pipeline. On error the buffered samples and counters
    /// are left unchanged.
    pub fn flush(&mut self) -> Result<Option<Array>> {
        #[cfg(test)]
        if self.flush_err_inject_count > 0 {
            self.flush_err_inject_count -= 1;
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "IncrementalMelSpectrogram::flush",
                "scripted test injection",
            )));
        }
        if self.overlap_buffer.is_empty() {
            return Ok(None);
        }
        let buffered = self.overlap_buffer.len();
        // Reflect only real samples. The last sample is the mirror axis and is not repeated,
        // so at most `buffered - 1` samples are available.
        let reflect_len = (self.n_fft / 2).min(buffered - 1);
        let mut signal = Vec::with_capacity((buffered + reflect_len).max(self.n_fft));
        signal.extend_from_slice(&self.overlap_buffer);
        // Appends x[L-2], x[L-3], ..., x[L-1-reflect_len].
        signal.extend(
            self.overlap_buffer[buffered - 1 - reflect_len..buffered - 1]
                .iter()
                .rev()
                .copied(),
        );
        // Zero-pad only after reflecting, so the reflection never mirrors artificial zeros.
        if signal.len() < self.n_fft {
            signal.resize(self.n_fft, 0.0);
        }
        // signal.len() >= n_fft here, so at least one frame is available.
        let num_frames = self.frame_count(signal.len());
        let mel = self.compute_mel(&signal, num_frames)?;
        self.overlap_buffer.clear();
        self.total_frames += num_frames;
        Ok(Some(mel))
    }

    /// Clears all streaming state so the next [`process`](Self::process) call starts a fresh
    /// session (left padding, running maximum and frame counter all start over).
    pub fn reset(&mut self) {
        self.overlap_buffer.clear();
        self.is_first_chunk = true;
        self.running_log_max = f32::NEG_INFINITY;
        self.total_frames = 0;
    }

    /// Number of full `n_fft`-sample frames, spaced `hop_length` apart, that fit in `len` samples.
    fn frame_count(&self, len: usize) -> usize {
        len.checked_sub(self.n_fft)
            .map_or(0, |extra| extra / self.hop_length + 1)
    }

    /// Builds the analysis signal for `samples` without mutating `self`.
    ///
    /// For the first chunk of a session, `n_fft / 2` padding samples are prepended: the
    /// reflection of `samples[1..]` around `samples[0]` when at least two samples exist,
    /// otherwise `samples[0]` repeated. Later chunks are prefixed with the carry-over buffer.
    fn build_signal_for_chunk(&self, samples: &[f32]) -> Vec<f32> {
        if !self.is_first_chunk {
            let mut signal = Vec::with_capacity(self.overlap_buffer.len() + samples.len());
            signal.extend_from_slice(&self.overlap_buffer);
            signal.extend_from_slice(samples);
            return signal;
        }
        let pad_size = self.n_fft / 2;
        let reflect_len = pad_size.min(samples.len().saturating_sub(1));
        let mut signal = Vec::with_capacity(pad_size + samples.len());
        if reflect_len == 0 {
            let fill = samples.first().copied().unwrap_or(0.0);
            signal.resize(pad_size, fill);
        } else {
            // NOTE: when the chunk has fewer than pad_size + 1 samples, the reflected block is
            // repeated verbatim (not re-reflected) until pad_size samples are filled.
            signal.extend(
                samples[1..=reflect_len]
                    .iter()
                    .rev()
                    .copied()
                    .cycle()
                    .take(pad_size),
            );
        }
        signal.extend_from_slice(samples);
        signal
    }

    /// Converts the first `num_frames` hop-spaced frames of `signal` into normalized log-mel rows.
    ///
    /// Pipeline:
    /// 1. Frame the signal and multiply each frame by the window.
    /// 2. Take the `rfft`, square its magnitude, and project onto the transposed filters.
    /// 3. Clamp at `1e-10`, take `log10`, and clamp at `running_log_max - 8.0`.
    /// 4. Scale as `(x + 4) / 4`.
    ///
    /// `running_log_max` is updated only after every fallible step has succeeded.
    ///
    /// # Errors
    ///
    /// * [`Error::ArithmeticOverflow`] if `n_fft`, `num_frames` or `signal.len()` exceeds `i32`.
    /// * [`Error::InvariantViolation`] if `num_frames` is zero or the frames would extend
    ///   past the end of `signal`.
    /// * Any error from the underlying array operations.
    fn compute_mel(&mut self, signal: &[f32], num_frames: usize) -> Result<Array> {
        let n_fft_i32 = i32::try_from(self.n_fft).map_err(|_| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "IncrementalMelSpectrogram: n_fft does not fit i32",
                "i32",
                [("n_fft", self.n_fft as u64)],
            ))
        })?;
        let num_frames_i32 = i32::try_from(num_frames).map_err(|_| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "IncrementalMelSpectrogram: num_frames does not fit i32",
                "i32",
                [("num_frames", num_frames as u64)],
            ))
        })?;
        let signal_len_i32 = i32::try_from(signal.len()).map_err(|_| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "IncrementalMelSpectrogram: signal length does not fit i32",
                "i32",
                [("signal_len", signal.len() as u64)],
            ))
        })?;
        // The strided view below is only sound if every frame lies inside `signal`.
        let frames_fit = num_frames
            .checked_sub(1)
            .and_then(|last| last.checked_mul(self.hop_length))
            .and_then(|last_start| last_start.checked_add(self.n_fft))
            .is_some_and(|end| end <= signal.len());
        if !frames_fit {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "IncrementalMelSpectrogram::compute_mel: num_frames",
                "must be >= 1 with (num_frames - 1) * hop_length + n_fft <= signal length",
            )));
        }
        // Lossless: hop_length <= n_fft (enforced in `new`) and n_fft fits in i32 (checked above).
        let hop_stride = self.hop_length as i64;

        let signal_array = Array::from_slice::<f32>(signal, &[signal_len_i32])?;
        // SAFETY: `frames_fit` guarantees (num_frames - 1) * hop_length + n_fft <= signal.len().
        // With row stride `hop_length`, column stride 1 and offset 0, the largest flat index
        // addressed is (num_frames - 1) * hop_length + n_fft - 1 < signal.len(), so the view
        // stays inside `signal_array`. This assumes `as_strided` addresses element
        // `offset + row * strides[0] + col * strides[1]` of its input.
        // NOTE(review): confirm against `as_strided`'s contract.
        let frames_stacked = unsafe {
            as_strided(
                &signal_array,
                &[num_frames_i32, n_fft_i32],
                &[hop_stride, 1],
                0,
            )?
        };
        let windowed = multiply(&frames_stacked, &self.window)?;
        let fft = rfft(&windowed, n_fft_i32, 1, FftNorm::Backward)?;
        let magnitudes = square(&abs(&fft)?)?;
        let filters_t = self.filters.transpose()?;
        let mut mel = matmul(&magnitudes, &filters_t)?;
        let floor = Array::from_slice::<f32>(&[1e-10], &[1i32])?;
        mel = maximum(&mel, &floor)?;
        mel = log10(&mel)?;
        let mut chunk_max_arr = reduction::max(&mel, false)?;
        let chunk_max: f32 = chunk_max_arr.item::<f32>()?;
        // Explicit comparison (not f32::max) so a NaN chunk max never replaces the running max.
        let running_max = if chunk_max > self.running_log_max {
            chunk_max
        } else {
            self.running_log_max
        };
        let floor_log_arr = Array::from_slice::<f32>(&[running_max - 8.0], &[1i32])?;
        mel = maximum(&mel, &floor_log_arr)?;
        let four = Array::from_slice::<f32>(&[4.0_f32], &[1i32])?;
        mel = divide(&add(&mel, &four)?, &four)?;
        self.running_log_max = running_max;
        Ok(mel)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extractor shared by the tests: 16 kHz, `n_fft = 32`, `hop_length = 16`, 8 mel bins.
    fn make_extractor() -> IncrementalMelSpectrogram {
        IncrementalMelSpectrogram::new(16_000, 32, 16, 8).unwrap()
    }

    /// Deterministic test signal: `len` samples of `sin(i * step)`.
    fn sine_wave(len: usize, step: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 * step).sin()).collect()
    }

    #[test]
    fn new_rejects_zero_n_fft() {
        let err = IncrementalMelSpectrogram::new(16_000, 0, 16, 8).unwrap_err();
        let Error::InvariantViolation(ref payload) = err else {
            panic!("expected Error::InvariantViolation for n_fft = 0");
        };
        assert!(payload.context().contains("n_fft"));
        assert!(payload.requirement().contains("must be > 0"));
    }

    #[test]
    fn new_rejects_odd_n_fft() {
        let err = IncrementalMelSpectrogram::new(16_000, 33, 16, 8).unwrap_err();
        let Error::OutOfRange(ref payload) = err else {
            panic!("expected Error::OutOfRange for odd n_fft");
        };
        assert!(payload.context().contains("n_fft"));
        assert!(payload.requirement().contains("must be even"));
    }

    #[test]
    fn new_rejects_zero_hop_length() {
        let err = IncrementalMelSpectrogram::new(16_000, 32, 0, 8).unwrap_err();
        let Error::InvariantViolation(ref payload) = err else {
            panic!("expected Error::InvariantViolation for hop_length = 0");
        };
        assert!(payload.context().contains("hop_length"));
        assert!(payload.requirement().contains("must be > 0"));
    }

    #[test]
    fn new_rejects_hop_larger_than_n_fft() {
        let err = IncrementalMelSpectrogram::new(16_000, 32, 64, 8).unwrap_err();
        let Error::OutOfRange(ref payload) = err else {
            panic!("expected Error::OutOfRange for hop_length > n_fft");
        };
        assert!(payload.context().contains("hop_length"));
        assert!(payload.requirement().contains("<= n_fft"));
    }

    #[test]
    fn process_empty_input_returns_none() {
        let mut extractor = make_extractor();
        let emitted = extractor.process(&[]).unwrap();
        assert!(emitted.is_none());
        assert_eq!(extractor.total_frames(), 0);
    }

    #[test]
    fn process_emits_mel_frames_with_expected_shape() {
        let mut extractor = make_extractor();
        let audio = sine_wave(128, 0.01);
        let mut frames = extractor.process(&audio).unwrap().expect("expected frames");
        let dims = frames.shape();
        assert_eq!(dims.len(), 2, "expected 2-D output, got shape {dims:?}");
        assert_eq!(dims[1], 8, "n_mels axis should be 8, got {dims:?}");
        assert!(dims[0] > 0, "num_frames axis must be > 0, got {dims:?}");
        assert_eq!(extractor.total_frames(), dims[0]);
        let values = frames.to_vec::<f32>().unwrap();
        if let Some(bad) = values.iter().find(|v| !v.is_finite()) {
            panic!("non-finite mel value: {bad}");
        }
    }

    #[test]
    fn streaming_then_flush_consumes_all_samples_deterministically() {
        let mut extractor = make_extractor();
        let audio = sine_wave(256, 0.005);
        let (head, tail) = audio.split_at(128);
        let _ = extractor.process(head).unwrap();
        let _ = extractor.process(tail).unwrap();
        let frames_before_flush = extractor.total_frames();
        let flushed = extractor.flush().unwrap();
        let frames_after_flush = extractor.total_frames();
        match flushed {
            Some(_) => assert!(
                frames_after_flush > frames_before_flush,
                "flush should emit at least 1 frame"
            ),
            None => assert_eq!(frames_after_flush, frames_before_flush),
        }
    }

    #[test]
    fn reset_clears_state_so_second_session_starts_fresh() {
        let mut extractor = make_extractor();
        let audio = sine_wave(128, 0.01);

        let _ = extractor.process(&audio).unwrap();
        let first_session_frames = extractor.total_frames();
        assert!(first_session_frames > 0);

        extractor.reset();
        assert_eq!(extractor.total_frames(), 0);

        let _ = extractor.process(&audio).unwrap();
        let second_session_frames = extractor.total_frames();
        assert_eq!(second_session_frames, first_session_frames);
    }
}