use crate::sample::{cast, Sample};
use alloc::vec;
use alloc::vec::Vec;
use core::f64::consts::PI;
#[cfg(not(feature = "std"))]
use num_traits::Float;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Which mel-frequency formula [`hz_to_mel`] and [`mel_to_hz`] use.
///
/// Note: variant order is part of the (optional) serde representation, so do
/// not reorder variants casually.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MelScale {
    /// Slaney scale: linear below 1000 Hz, logarithmic at and above it.
    /// This is the default.
    #[default]
    Slaney,
    /// HTK scale: `mel = 2595 * log10(1 + hz / 700)`.
    Htk,
}
/// Break frequency (Hz) of the Slaney scale: the mapping is linear below this
/// frequency and logarithmic at and above it.
const SLANEY_MIN_LOG_HZ: f64 = 1000.0;
/// Hertz per mel in the linear (sub-1000 Hz) region of the Slaney scale.
const SLANEY_F_SP: f64 = 200.0 / 3.0;
/// Converts a frequency in hertz to mels on the given `scale`.
///
/// * [`MelScale::Htk`]: `2595 * log10(1 + hz / 700)`.
/// * [`MelScale::Slaney`]: `hz / (200 / 3)` below 1000 Hz; from 1000 Hz up,
///   `15 + ln(hz / 1000) / (ln(6.4) / 27)`.
///
/// For a given `scale` this is the inverse of [`mel_to_hz`].
#[must_use]
pub fn hz_to_mel(hz: f64, scale: MelScale) -> f64 {
    match scale {
        // Logarithmic part of the Slaney scale, at and above the break frequency.
        MelScale::Slaney if hz >= SLANEY_MIN_LOG_HZ => {
            let break_mel = SLANEY_MIN_LOG_HZ / SLANEY_F_SP;
            let log_step = 6.4f64.ln() / 27.0;
            break_mel + (hz / SLANEY_MIN_LOG_HZ).ln() / log_step
        }
        // Linear part of the Slaney scale (NaN input also lands here).
        MelScale::Slaney => hz / SLANEY_F_SP,
        MelScale::Htk => 2595.0 * (1.0 + hz / 700.0).log10(),
    }
}
/// Converts a value in mels back to hertz on the given `scale`.
///
/// * [`MelScale::Htk`]: `700 * (10^(mel / 2595) - 1)`.
/// * [`MelScale::Slaney`]: `mel * (200 / 3)` below 15 mels; from 15 mels up,
///   `1000 * exp((ln(6.4) / 27) * (mel - 15))`.
///
/// For a given `scale` this is the inverse of [`hz_to_mel`].
#[must_use]
pub fn mel_to_hz(mel: f64, scale: MelScale) -> f64 {
    // Mel value at which the Slaney scale switches from linear to logarithmic.
    let slaney_break_mel = SLANEY_MIN_LOG_HZ / SLANEY_F_SP;
    match scale {
        MelScale::Htk => 700.0 * (10.0f64.powf(mel / 2595.0) - 1.0),
        MelScale::Slaney if mel >= slaney_break_mel => {
            let log_step = 6.4f64.ln() / 27.0;
            SLANEY_MIN_LOG_HZ * (log_step * (mel - slaney_break_mel)).exp()
        }
        // Linear part of the Slaney scale (NaN input also lands here).
        MelScale::Slaney => mel * SLANEY_F_SP,
    }
}
/// Precomputed bank of triangular mel filters over FFT bins.
///
/// Built by [`MelFilterBank::new`]; applied with [`MelFilterBank::transform`]
/// or [`MelFilterBank::transform_into`].
#[derive(Debug, Clone, PartialEq)]
pub struct MelFilterBank<T> {
    /// Row-major `n_mels x n_freqs` weight matrix: row `m` starts at
    /// `m * n_freqs` and holds filter `m`'s weight for every bin.
    weights: Vec<T>,
    /// Number of filters, i.e. the length of a transform's output.
    n_mels: usize,
    /// Number of FFT bins per filter (`n_fft / 2 + 1`), i.e. the required
    /// input length.
    n_freqs: usize,
}
impl<T: Sample> MelFilterBank<T> {
    /// Builds `n_mels` triangular filters over the `n_fft / 2 + 1`
    /// non-negative frequency bins of an `n_fft`-point real FFT sampled at
    /// `sample_rate` Hz.
    ///
    /// The `n_mels + 2` filter edges are spaced uniformly on the mel axis
    /// (using `scale`) between `fmin` and `fmax`, both in Hz. Filter `m` rises
    /// from edge `m` to a peak at edge `m + 1` and falls back to zero at edge
    /// `m + 2`. Each triangle is scaled by `2 / (upper_edge - lower_edge)`,
    /// giving it unit area in Hz (Slaney-style normalisation).
    ///
    /// A filter whose upper edge does not lie strictly above its lower edge
    /// (for example every filter when `fmin == fmax`) gets an all-zero row
    /// instead of NaN weights.
    #[must_use]
    pub fn new(
        n_mels: usize,
        n_fft: usize,
        sample_rate: f64,
        fmin: f64,
        fmax: f64,
        scale: MelScale,
    ) -> Self {
        let n_freqs = n_fft / 2 + 1;
        let mut weights = vec![T::zero(); n_mels * n_freqs];
        // Frequency in Hz of each FFT bin `k`.
        let fft_freqs: Vec<f64> = (0..n_freqs)
            .map(|k| k as f64 * sample_rate / n_fft as f64)
            .collect();
        let mel_min = hz_to_mel(fmin, scale);
        let mel_max = hz_to_mel(fmax, scale);
        // `n_mels + 2` edges in Hz; filter `m` uses edges `m`, `m + 1`, `m + 2`.
        let hz_points: Vec<f64> = (0..n_mels + 2)
            .map(|i| {
                let mel = mel_min + (mel_max - mel_min) * i as f64 / (n_mels + 1) as f64;
                mel_to_hz(mel, scale)
            })
            .collect();
        // `n_freqs >= 1`, so `chunks_exact_mut` yields exactly `n_mels` rows,
        // one per 3-edge window of `hz_points`.
        for (row, edges) in weights.chunks_exact_mut(n_freqs).zip(hz_points.windows(3)) {
            let (lower_edge, center, upper_edge) = (edges[0], edges[1], edges[2]);
            let lower_width = center - lower_edge;
            let upper_width = upper_edge - center;
            let span = upper_edge - lower_edge;
            // Without this guard a zero span gives `2.0 / 0.0 = inf`, and the
            // `0.0 * inf` below would turn the whole row into NaN.
            let enorm = if span > 0.0 { 2.0 / span } else { 0.0 };
            for (slot, &f) in row.iter_mut().zip(&fft_freqs) {
                // Rising and falling slopes; a zero-width side contributes 0.
                let lower = if lower_width > 0.0 {
                    (f - lower_edge) / lower_width
                } else {
                    0.0
                };
                let upper = if upper_width > 0.0 {
                    (upper_edge - f) / upper_width
                } else {
                    0.0
                };
                let w = lower.min(upper).max(0.0) * enorm;
                *slot = cast(w);
            }
        }
        Self {
            weights,
            n_mels,
            n_freqs,
        }
    }

    /// Number of filters (length of a transform's output).
    #[must_use]
    pub fn n_mels(&self) -> usize {
        self.n_mels
    }

    /// Number of FFT bins each filter spans (required input length).
    #[must_use]
    pub fn n_freqs(&self) -> usize {
        self.n_freqs
    }

    /// Row-major `n_mels x n_freqs` weight matrix; row `m` is filter `m`.
    #[must_use]
    pub fn weights(&self) -> &[T] {
        &self.weights
    }

    /// Applies every filter to `power` (one value per FFT bin) and writes one
    /// weighted sum per filter into `out`.
    ///
    /// # Panics
    ///
    /// Panics if `power.len() != self.n_freqs()` or `out.len() != self.n_mels()`.
    pub fn transform_into(&self, power: &[T], out: &mut [T]) {
        assert_eq!(power.len(), self.n_freqs, "mel input length mismatch");
        assert_eq!(out.len(), self.n_mels, "mel output length mismatch");
        for (m, slot) in out.iter_mut().enumerate() {
            let row = &self.weights[m * self.n_freqs..(m + 1) * self.n_freqs];
            *slot = row
                .iter()
                .zip(power)
                .fold(T::zero(), |acc, (&w, &p)| acc + w * p);
        }
    }

    /// Allocating variant of [`MelFilterBank::transform_into`]; returns a
    /// vector of `n_mels` filter outputs.
    ///
    /// # Panics
    ///
    /// Panics if `power.len() != self.n_freqs()`.
    #[must_use]
    pub fn transform(&self, power: &[T]) -> Vec<T> {
        let mut out = vec![T::zero(); self.n_mels];
        self.transform_into(power, &mut out);
        out
    }
}
/// Precomputed DCT-II mapping `n_in` input samples to `n_out` coefficients.
///
/// Built by [`DctII::new`]; applied with [`DctII::transform`] or
/// [`DctII::transform_into`].
#[derive(Debug, Clone, PartialEq)]
pub struct DctII<T> {
    /// Row-major `n_out x n_in` basis matrix; row `k` starts at `k * n_in`
    /// and holds the scaled cosine for coefficient `k`.
    basis: Vec<T>,
    /// Required input length.
    n_in: usize,
    /// Number of output coefficients.
    n_out: usize,
}
impl<T: Sample> DctII<T> {
    /// Precomputes the DCT-II basis for `n_in` inputs and `n_out` outputs.
    ///
    /// Coefficient `k` is `sum_m x[m] * 2 * s_k * cos(pi * k * (2m + 1) / (2 * n_in))`,
    /// with `s_0 = sqrt(1 / (4 * n_in))` and `s_k = sqrt(1 / (2 * n_in))` for
    /// `k > 0`. For `n_out <= n_in` this gives the first `n_out` outputs of an
    /// orthonormal DCT-II (SciPy `dct(x, type=2, norm="ortho")`).
    #[must_use]
    pub fn new(n_in: usize, n_out: usize) -> Self {
        let len = n_in as f64;
        let basis: Vec<T> = (0..n_out)
            .flat_map(move |k| {
                // The DC row gets an extra factor of 1/sqrt(2).
                let denom = if k == 0 { 4.0 * len } else { 2.0 * len };
                let scale = (1.0 / denom).sqrt();
                (0..n_in).map(move |m| -> T {
                    let angle = PI * k as f64 * (2.0 * m as f64 + 1.0) / (2.0 * len);
                    cast(2.0 * scale * angle.cos())
                })
            })
            .collect();
        Self { basis, n_in, n_out }
    }

    /// Required input length.
    #[must_use]
    pub fn n_in(&self) -> usize {
        self.n_in
    }

    /// Number of coefficients produced.
    #[must_use]
    pub fn n_out(&self) -> usize {
        self.n_out
    }

    /// Writes the `n_out` DCT-II coefficients of `input` into `out`.
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != self.n_in()` or `out.len() != self.n_out()`.
    pub fn transform_into(&self, input: &[T], out: &mut [T]) {
        assert_eq!(input.len(), self.n_in, "DCT input length mismatch");
        assert_eq!(out.len(), self.n_out, "DCT output length mismatch");
        for (k, coeff) in out.iter_mut().enumerate() {
            let start = k * self.n_in;
            let basis_row = &self.basis[start..start + self.n_in];
            *coeff = basis_row
                .iter()
                .zip(input)
                .fold(T::zero(), |sum, (&b, &x)| sum + b * x);
        }
    }

    /// Allocating variant of [`DctII::transform_into`].
    ///
    /// # Panics
    ///
    /// Panics if `input.len() != self.n_in()`.
    #[must_use]
    pub fn transform(&self, input: &[T]) -> Vec<T> {
        let mut coeffs = vec![T::zero(); self.n_out];
        self.transform_into(input, &mut coeffs);
        coeffs
    }
}