#[cfg(not(feature = "std"))]
use alloc::{vec, vec::Vec};
#[cfg(feature = "std")]
use std::vec;
use core::fmt;
use num_traits::Float;
/// Selects which Hz <-> mel conversion formula is used.
///
/// The default variant is [`MelScale::Slaney`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MelScale {
/// HTK-style formula: `mel = 2595 * log10(1 + hz / 700)`.
Htk,
/// Slaney-style scale: linear below 1000 Hz and logarithmic at and above it.
#[default]
Slaney,
}
/// Selects how the triangular mel filters are normalised.
///
/// The default variant is [`MelNorm::Slaney`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum MelNorm {
/// Leave the triangular weights unscaled.
None,
/// Scale every filter by `2 / (upper_edge_hz - lower_edge_hz)`.
#[default]
Slaney,
}
/// Parameters used to build a mel filterbank and the spectrogram processors.
#[derive(Debug, Clone, PartialEq)]
pub struct MelConfig<T: Float> {
/// Number of mel filters to generate.
pub n_mels: usize,
/// Lower edge of the first filter, in Hz.
pub fmin: T,
/// Upper edge of the last filter, in Hz; `None` means `sample_rate / 2`.
pub fmax: Option<T>,
/// Formula used to convert between Hz and mel.
pub mel_scale: MelScale,
/// Normalisation applied to each filter's weights.
pub norm: MelNorm,
/// When `true` the spectrogram processors feed `re^2 + im^2` into the
/// filterbank; otherwise they feed the magnitudes.
pub use_power: bool,
}
impl<T: Float> Default for MelConfig<T> {
    /// Builds the default configuration: 80 mel bands starting at zero Hz,
    /// no explicit upper frequency, the default scale and normalisation,
    /// and power-spectrum input enabled.
    fn default() -> Self {
        let fmin = T::zero();
        Self {
            n_mels: 80,
            fmin,
            fmax: None,
            mel_scale: Default::default(),
            norm: Default::default(),
            use_power: true,
        }
    }
}
/// Converts a frequency in Hz to mels with the HTK formula
/// `2595 * log10(1 + hz / 700)`.
pub fn hz_to_mel_htk<T: Float>(hz: T) -> T {
    let scale = T::from(2595.0).unwrap();
    let corner_hz = T::from(700.0).unwrap();
    let ratio = hz / corner_hz;
    let shifted = T::one() + ratio;
    scale * shifted.log10()
}
/// Converts a mel value back to Hz with the HTK formula
/// `700 * (10^(mel / 2595) - 1)`.
pub fn mel_to_hz_htk<T: Float>(mel: T) -> T {
    let corner_hz = T::from(700.0).unwrap();
    let scale = T::from(2595.0).unwrap();
    let ten = T::from(10.0).unwrap();
    let exponent = mel / scale;
    corner_hz * (ten.powf(exponent) - T::one())
}
/// Converts a frequency in Hz to mels on the Slaney scale.
///
/// Below 1000 Hz the mapping is linear with a slope of `3 / 200` mel per Hz;
/// at and above 1000 Hz it is logarithmic with a step of `ln(6.4) / 27`.
pub fn hz_to_mel_slaney<T: Float>(hz: T) -> T {
    let origin_hz = T::zero();
    let hz_per_mel = T::from(200.0 / 3.0).unwrap();
    let knee_hz = T::from(1000.0).unwrap();

    if hz >= knee_hz {
        // Logarithmic region: continue from the mel value reached at the knee.
        let knee_mel = (knee_hz - origin_hz) / hz_per_mel;
        let log_step = T::from(6.4).unwrap().ln() / T::from(27.0).unwrap();
        return knee_mel + ((hz / knee_hz).ln() / log_step);
    }

    // Linear region.
    (hz - origin_hz) / hz_per_mel
}
/// Converts a Slaney-scale mel value back to Hz.
///
/// Values below the mel value that corresponds to 1000 Hz map linearly
/// (`200 / 3` Hz per mel); values at or above it map exponentially with a
/// step of `ln(6.4) / 27`.
pub fn mel_to_hz_slaney<T: Float>(mel: T) -> T {
    let origin_hz = T::zero();
    let hz_per_mel = T::from(200.0 / 3.0).unwrap();
    let knee_hz = T::from(1000.0).unwrap();
    let knee_mel = (knee_hz - origin_hz) / hz_per_mel;

    if mel >= knee_mel {
        // Exponential region above the knee.
        let log_step = T::from(6.4).unwrap().ln() / T::from(27.0).unwrap();
        return knee_hz * (log_step * (mel - knee_mel)).exp();
    }

    // Linear region.
    origin_hz + hz_per_mel * mel
}
pub fn hz_to_mel<T: Float>(hz: T, scale: MelScale) -> T {
match scale {
MelScale::Htk => hz_to_mel_htk(hz),
MelScale::Slaney => hz_to_mel_slaney(hz),
}
}
pub fn mel_to_hz<T: Float>(mel: T, scale: MelScale) -> T {
match scale {
MelScale::Htk => mel_to_hz_htk(mel),
MelScale::Slaney => mel_to_hz_slaney(mel),
}
}
/// Sparse bank of triangular mel filters, built by `MelFilterbank::new`.
#[derive(Debug, Clone, PartialEq)]
pub struct MelFilterbank<T: Float> {
/// Number of mel filters requested by the configuration.
pub n_mels: usize,
/// Number of input frequency bins, computed as `n_fft / 2 + 1`.
pub n_freqs: usize,
/// Sample rate that was passed to the constructor.
pub sample_rate: T,
/// One entry per filter: the `(bin index, weight)` pairs whose weight is
/// strictly positive.
pub weights: Vec<Vec<(usize, T)>>,
}
impl<T: Float + fmt::Debug> MelFilterbank<T> {
    /// Builds a triangular mel filterbank.
    ///
    /// `n_mels + 2` edge points are spaced evenly on the mel axis between
    /// `config.fmin` and `config.fmax` (or `sample_rate / 2` when `fmax` is
    /// `None`), converted back to Hz and then to fractional FFT bin positions
    /// (`hz * n_fft / sample_rate`). Filter `i` rises from edge `i` to edge
    /// `i + 1` and falls to edge `i + 2`. Only strictly positive weights are
    /// stored.
    ///
    /// With [`MelNorm::Slaney`] every weight of a filter is multiplied by
    /// `2 / (upper_edge_hz - lower_edge_hz)`.
    ///
    /// # Panics
    /// Panics if a numeric constant or index cannot be converted into `T`.
    pub fn new(sample_rate: T, n_fft: usize, config: &MelConfig<T>) -> Self {
        let two = T::from(2.0).unwrap();
        let scale = config.mel_scale;
        let n_freqs = n_fft / 2 + 1;

        let fmax = config.fmax.unwrap_or(sample_rate / two);
        let mel_lo = hz_to_mel(config.fmin, scale);
        let mel_hi = hz_to_mel(fmax, scale);

        // n_mels triangles need n_mels + 2 edge points.
        let n_points = config.n_mels + 2;
        let mel_step = (mel_hi - mel_lo) / T::from(n_points - 1).unwrap();

        let hz_points: Vec<T> = (0..n_points)
            .map(|k| mel_to_hz(mel_lo + T::from(k).unwrap() * mel_step, scale))
            .collect();

        let fft_size = T::from(n_fft).unwrap();
        let fft_bins: Vec<T> = hz_points
            .iter()
            .map(|&hz| hz * fft_size / sample_rate)
            .collect();

        // Each window of three consecutive edges describes one triangle.
        let weights: Vec<Vec<(usize, T)>> = fft_bins
            .windows(3)
            .zip(hz_points.windows(3))
            .map(|(bins, edges_hz)| {
                let (left, center, right) = (bins[0], bins[1], bins[2]);
                let first_bin = left.floor().to_usize().unwrap_or(0);
                let last_bin = right.ceil().to_usize().unwrap_or(n_freqs).min(n_freqs);

                let mut taps: Vec<(usize, T)> = (first_bin..last_bin)
                    .filter_map(|bin| {
                        let pos = T::from(bin).unwrap();
                        let weight = if pos < center {
                            // Rising slope; a degenerate slope contributes nothing.
                            if center > left {
                                (pos - left) / (center - left)
                            } else {
                                T::zero()
                            }
                        } else if right > center {
                            // Falling slope.
                            (right - pos) / (right - center)
                        } else {
                            T::zero()
                        };
                        (weight > T::zero()).then_some((bin, weight))
                    })
                    .collect();

                if matches!(config.norm, MelNorm::Slaney) {
                    let enorm = two / (edges_hz[2] - edges_hz[0]);
                    for tap in taps.iter_mut() {
                        tap.1 = tap.1 * enorm;
                    }
                }
                taps
            })
            .collect();

        Self {
            n_mels: config.n_mels,
            n_freqs,
            sample_rate,
            weights,
        }
    }

    /// Applies every filter to `magnitudes` and returns one value per mel band.
    ///
    /// # Panics
    /// Panics if `magnitudes.len()` differs from `n_freqs`.
    pub fn apply(&self, magnitudes: &[T]) -> Vec<T> {
        assert_eq!(
            magnitudes.len(),
            self.n_freqs,
            "Magnitude spectrum length mismatch"
        );
        let mut mel_mags = vec![T::zero(); self.n_mels];
        for (mel_idx, filter) in self.weights.iter().enumerate() {
            mel_mags[mel_idx] = filter
                .iter()
                .fold(T::zero(), |acc, &(bin, weight)| acc + magnitudes[bin] * weight);
        }
        mel_mags
    }

    /// Same as [`MelFilterbank::apply`]; provided for callers that pass a
    /// power spectrum.
    #[inline]
    pub fn apply_power(&self, power: &[T]) -> Vec<T> {
        Self::apply(self, power)
    }
}
/// Mel spectrogram stored as one flat buffer.
#[derive(Debug, Clone, PartialEq)]
pub struct MelSpectrum<T: Float> {
/// Number of frames stored in `data`.
pub num_frames: usize,
/// Number of values per frame.
pub n_mels: usize,
/// Frame-major storage: the value for `(frame, mel_bin)` lives at index
/// `frame * n_mels + mel_bin`.
pub data: Vec<T>,
}
impl<T: Float> MelSpectrum<T> {
    /// Creates a zero-filled spectrogram with `num_frames` frames of `n_mels`
    /// values each.
    pub fn new(num_frames: usize, n_mels: usize) -> Self {
        Self {
            num_frames,
            n_mels,
            data: vec![T::zero(); num_frames * n_mels],
        }
    }

    /// Returns the value stored for `mel_bin` of `frame`.
    ///
    /// # Panics
    /// Panics if `frame * n_mels + mel_bin` is outside `data`.
    #[inline]
    pub fn get(&self, frame: usize, mel_bin: usize) -> T {
        self.data[frame * self.n_mels + mel_bin]
    }

    /// Overwrites the value stored for `mel_bin` of `frame`.
    ///
    /// # Panics
    /// Panics if `frame * n_mels + mel_bin` is outside `data`.
    #[inline]
    pub fn set(&mut self, frame: usize, mel_bin: usize, value: T) {
        self.data[frame * self.n_mels + mel_bin] = value;
    }

    /// Borrows the `n_mels` values of one frame.
    ///
    /// # Panics
    /// Panics if the frame's range is outside `data`.
    pub fn frame(&self, frame: usize) -> &[T] {
        let start = frame * self.n_mels;
        &self.data[start..start + self.n_mels]
    }

    /// Mutably borrows the `n_mels` values of one frame.
    ///
    /// # Panics
    /// Panics if the frame's range is outside `data`.
    pub fn frame_mut(&mut self, frame: usize) -> &mut [T] {
        let start = frame * self.n_mels;
        &mut self.data[start..start + self.n_mels]
    }

    /// Returns a copy converted to decibels as `10 * log10(max(value, amin))`,
    /// then raised to at least `peak - top_db`, where `peak` is the largest
    /// converted value.
    ///
    /// `amin` defaults to `1e-10` and `top_db` to `80`.
    ///
    /// NaN inputs stay NaN in the output and are ignored when the peak is
    /// searched, so they no longer cause a panic.
    pub fn to_db(&self, amin: Option<T>, top_db: Option<T>) -> Self {
        let amin = amin.unwrap_or_else(|| T::from(1e-10).unwrap());
        let top_db = top_db.unwrap_or_else(|| T::from(80.0).unwrap());
        let ten = T::from(10.0).unwrap();

        let mut result = self.clone();
        for val in &mut result.data {
            // Explicit comparison (rather than `max`) so NaN is propagated.
            let clamped = if *val < amin { amin } else { *val };
            *val = ten * clamped.log10();
        }

        // `Float::max` skips NaN operands; `partial_cmp(..).unwrap()` would
        // panic on the first NaN it meets.
        let max_db = result
            .data
            .iter()
            .copied()
            .fold(T::neg_infinity(), |peak, v| peak.max(v));

        let threshold = max_db - top_db;
        for val in &mut result.data {
            if *val < threshold {
                *val = threshold;
            }
        }
        result
    }

    /// Replaces every value with `f(frame, mel_bin, value)`, visiting frames
    /// in order and the bins of each frame in order.
    pub fn apply<F>(&mut self, mut f: F)
    where
        F: FnMut(usize, usize, T) -> T,
    {
        for frame in 0..self.num_frames {
            for mel_bin in 0..self.n_mels {
                let val = self.get(frame, mel_bin);
                self.set(frame, mel_bin, f(frame, mel_bin, val));
            }
        }
    }

    /// Computes regression-based time derivatives:
    /// `sum_{n=1..=width} n * (x[t + n] - x[t - n]) / (2 * sum_{n=1..=width} n^2)`.
    ///
    /// Frame indices past either end are clamped to the first or last frame.
    /// `width` defaults to `2`. A `width` of `0` has no neighbours to compare,
    /// so the result is all zeros (previously it was `0 / 0`).
    pub fn delta(&self, width: Option<usize>) -> Self {
        let width = width.unwrap_or(2);
        let mut result = Self::new(self.num_frames, self.n_mels);
        if width == 0 || self.num_frames == 0 {
            return result;
        }

        let last_frame = self.num_frames - 1;
        let sum_sq: usize = (1..=width).map(|n| n * n).sum();
        let denom = T::from(2).unwrap() * T::from(sum_sq).unwrap();

        for t in 0..self.num_frames {
            for mel_bin in 0..self.n_mels {
                let mut acc = T::zero();
                for n in 1..=width {
                    let ahead = self.get((t + n).min(last_frame), mel_bin);
                    let behind = self.get(t.saturating_sub(n), mel_bin);
                    acc = acc + T::from(n).unwrap() * (ahead - behind);
                }
                result.set(t, mel_bin, acc / denom);
            }
        }
        result
    }

    /// Computes second-order derivatives by applying [`MelSpectrum::delta`] twice.
    pub fn delta_delta(&self, width: Option<usize>) -> Self {
        self.delta(width).delta(width)
    }

    /// Returns a spectrogram with `3 * n_mels` values per frame: the original
    /// values, then their deltas, then their delta-deltas.
    pub fn with_deltas(&self, width: Option<usize>) -> Self {
        let delta = self.delta(width);
        let delta_delta = delta.delta(width);
        let n = self.n_mels;

        let mut result = Self::new(self.num_frames, n * 3);
        for t in 0..self.num_frames {
            let row = result.frame_mut(t);
            row[..n].copy_from_slice(self.frame(t));
            row[n..2 * n].copy_from_slice(delta.frame(t));
            row[2 * n..].copy_from_slice(delta_delta.frame(t));
        }
        result
    }
}
/// Converts a whole multi-frame spectrum into a mel spectrogram.
#[derive(Debug, Clone)]
pub struct BatchMelSpectrogram<T: Float> {
// Filterbank built once in the constructor and reused for every frame.
filterbank: MelFilterbank<T>,
// Copied from `MelConfig::use_power`: `true` feeds `re^2 + im^2` into the
// filterbank, `false` feeds the magnitudes.
use_power: bool,
}
impl<T: Float + fmt::Debug> BatchMelSpectrogram<T> {
    /// Builds the processor: constructs the mel filterbank for `sample_rate`
    /// and `n_fft` and remembers `config.use_power`.
    pub fn new(sample_rate: T, n_fft: usize, config: &MelConfig<T>) -> Self {
        Self {
            filterbank: MelFilterbank::new(sample_rate, n_fft, config),
            use_power: config.use_power,
        }
    }

    /// Collects the filterbank input for one frame: `re^2 + im^2` per bin
    /// when `use_power` is set, otherwise the value returned by the
    /// spectrum's `magnitude` accessor.
    fn frame_input(&self, spectrum: &crate::Spectrum<T>, frame_idx: usize) -> Vec<T> {
        let bins = 0..spectrum.freq_bins;
        if self.use_power {
            bins.map(|bin| {
                let re = spectrum.real(frame_idx, bin);
                let im = spectrum.imag(frame_idx, bin);
                re * re + im * im
            })
            .collect()
        } else {
            bins.map(|bin| spectrum.magnitude(frame_idx, bin)).collect()
        }
    }

    /// Runs every frame of `spectrum` through the filterbank and returns the
    /// resulting mel spectrogram (`spectrum.num_frames` frames).
    ///
    /// # Panics
    /// Panics if `spectrum.freq_bins` does not match the filterbank's
    /// `n_freqs`.
    pub fn process(&self, spectrum: &crate::Spectrum<T>) -> MelSpectrum<T> {
        let mut mel_spec = MelSpectrum::new(spectrum.num_frames, self.filterbank.n_mels);
        for frame_idx in 0..spectrum.num_frames {
            let input = self.frame_input(spectrum, frame_idx);
            let mel_frame = self.filterbank.apply(&input);
            mel_spec.frame_mut(frame_idx).copy_from_slice(&mel_frame);
        }
        mel_spec
    }

    /// Same as [`BatchMelSpectrogram::process`], followed by
    /// `MelSpectrum::to_db(amin, top_db)`.
    pub fn process_db(
        &self,
        spectrum: &crate::Spectrum<T>,
        amin: Option<T>,
        top_db: Option<T>,
    ) -> MelSpectrum<T> {
        self.process(spectrum).to_db(amin, top_db)
    }

    /// Number of mel bands produced per frame.
    pub fn n_mels(&self) -> usize {
        self.filterbank.n_mels
    }
}
/// Converts spectrum frames to mel frames one at a time.
#[derive(Debug, Clone)]
pub struct StreamingMelSpectrogram<T: Float> {
// Filterbank built once in the constructor and reused for every frame.
filterbank: MelFilterbank<T>,
// Copied from `MelConfig::use_power`: `true` feeds `re^2 + im^2` into the
// filterbank, `false` feeds the frame's magnitudes.
use_power: bool,
}
impl<T: Float + fmt::Debug> StreamingMelSpectrogram<T> {
    /// Builds the processor: constructs the mel filterbank for `sample_rate`
    /// and `n_fft` and remembers `config.use_power`.
    pub fn new(sample_rate: T, n_fft: usize, config: &MelConfig<T>) -> Self {
        Self {
            filterbank: MelFilterbank::new(sample_rate, n_fft, config),
            use_power: config.use_power,
        }
    }

    /// Collects the filterbank input for `frame`: `re^2 + im^2` per entry of
    /// `frame.data` when `use_power` is set, otherwise `frame.magnitudes()`.
    fn frame_input(&self, frame: &crate::SpectrumFrame<T>) -> Vec<T> {
        if !self.use_power {
            return frame.magnitudes();
        }
        frame
            .data
            .iter()
            .map(|c| c.re * c.re + c.im * c.im)
            .collect()
    }

    /// Converts one spectrum frame into a newly allocated vector of `n_mels`
    /// mel values.
    ///
    /// # Panics
    /// Panics if `frame.freq_bins` differs from the filterbank's `n_freqs`.
    pub fn process_frame(&self, frame: &crate::SpectrumFrame<T>) -> Vec<T> {
        assert_eq!(
            frame.freq_bins, self.filterbank.n_freqs,
            "Frequency bins mismatch"
        );
        let input = self.frame_input(frame);
        self.filterbank.apply(&input)
    }

    /// Converts one spectrum frame and writes the mel values into the first
    /// `n_mels` slots of `output`; returns the number of values written.
    ///
    /// # Panics
    /// Panics if `frame.freq_bins` differs from the filterbank's `n_freqs`,
    /// or if `output` holds fewer than `n_mels` elements.
    pub fn process_frame_into(&self, frame: &crate::SpectrumFrame<T>, output: &mut [T]) -> usize {
        let n_mels = self.filterbank.n_mels;
        assert_eq!(
            frame.freq_bins, self.filterbank.n_freqs,
            "Frequency bins mismatch"
        );
        assert!(
            output.len() >= n_mels,
            "Output buffer too small"
        );
        let input = self.frame_input(frame);
        let mel_frame = self.filterbank.apply(&input);
        output[..n_mels].copy_from_slice(&mel_frame);
        n_mels
    }

    /// Number of mel bands produced per frame.
    pub fn n_mels(&self) -> usize {
        self.filterbank.n_mels
    }
}
// Convenience aliases that fix the sample type to `f32` or `f64`.

/// [`MelConfig`] with `f32` values.
pub type MelConfigF32 = MelConfig<f32>;
/// [`MelConfig`] with `f64` values.
pub type MelConfigF64 = MelConfig<f64>;
/// [`MelFilterbank`] with `f32` values.
pub type MelFilterbankF32 = MelFilterbank<f32>;
/// [`MelFilterbank`] with `f64` values.
pub type MelFilterbankF64 = MelFilterbank<f64>;
/// [`MelSpectrum`] with `f32` values.
pub type MelSpectrumF32 = MelSpectrum<f32>;
/// [`MelSpectrum`] with `f64` values.
pub type MelSpectrumF64 = MelSpectrum<f64>;
/// [`BatchMelSpectrogram`] with `f32` values.
pub type BatchMelSpectrogramF32 = BatchMelSpectrogram<f32>;
/// [`BatchMelSpectrogram`] with `f64` values.
pub type BatchMelSpectrogramF64 = BatchMelSpectrogram<f64>;
/// [`StreamingMelSpectrogram`] with `f32` values.
pub type StreamingMelSpectrogramF32 = StreamingMelSpectrogram<f32>;
/// [`StreamingMelSpectrogram`] with `f64` values.
pub type StreamingMelSpectrogramF64 = StreamingMelSpectrogram<f64>;