use std::num::NonZeroUsize;
use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
use num_complex::Complex;
use crate::{SpectrogramError, SpectrogramResult, nzu};
/// Configuration for an ERB-spaced filterbank: the number of filters and the
/// frequency range they cover.
///
/// All fields are private. Values are read back through the accessor methods
/// and are normally created through the validating constructor `ErbParams::new`.
///
/// NOTE(review): with the `serde` feature enabled the derived `Deserialize`
/// carries no serde attributes, so it presumably fills the fields directly
/// without running the checks in `ErbParams::new` — confirm that deserialized
/// values are validated before use.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ErbParams {
/// Number of filters in the bank.
n_filters: NonZeroUsize,
/// Lower bound of the covered frequency range (passed to `hz_to_erb`, i.e. in Hz).
f_min: f64,
/// Upper bound of the covered frequency range (passed to `hz_to_erb`, i.e. in Hz).
f_max: f64,
}
/// Alternative name for [`ErbParams`]; both names denote the same type.
pub type GammatoneParams = ErbParams;
impl ErbParams {
    /// Creates a validated parameter set.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when any of the following holds:
    ///
    /// - `n_filters` is smaller than 2,
    /// - `f_min` is NaN, infinite or negative,
    /// - `f_max` is not strictly greater than `f_min`,
    /// - `f_max` is NaN or infinite.
    #[inline]
    pub fn new(n_filters: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
        if n_filters < nzu!(2) {
            return Err(SpectrogramError::invalid_input(
                "n_filters must be >= 2 (single filter would cause division by zero)",
            ));
        }
        // `is_finite` is false for NaN as well as for +/-infinity. A plain
        // `f_min < 0.0 || f_min.is_infinite()` would let NaN through because
        // every ordered comparison with NaN is false.
        if !f_min.is_finite() || f_min < 0.0 {
            return Err(SpectrogramError::invalid_input(
                "f_min must be finite and >= 0",
            ));
        }
        if f_max <= f_min {
            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
        }
        // Reaching this point with a non-finite `f_max` means it is NaN (the
        // comparison above was false) or +infinity; neither is a usable bound.
        if !f_max.is_finite() {
            return Err(SpectrogramError::invalid_input("f_max must be finite"));
        }
        Ok(Self {
            n_filters,
            f_min,
            f_max,
        })
    }

    /// Builds a parameter set without running any of the checks in [`Self::new`].
    ///
    /// # Safety
    ///
    /// This function performs no validation at all. The caller must pass values
    /// that [`Self::new`] would accept (`n_filters >= 2`, finite `f_min >= 0`,
    /// finite `f_max > f_min`); otherwise code relying on those invariants may
    /// produce meaningless results.
    pub(crate) const unsafe fn new_unchecked(
        n_filters: NonZeroUsize,
        f_min: f64,
        f_max: f64,
    ) -> Self {
        Self {
            n_filters,
            f_min,
            f_max,
        }
    }

    /// Returns the number of filters.
    #[inline]
    #[must_use]
    pub const fn n_filters(&self) -> NonZeroUsize {
        self.n_filters
    }

    /// Returns the lower bound of the frequency range.
    #[inline]
    #[must_use]
    pub const fn f_min(&self) -> f64 {
        self.f_min
    }

    /// Returns the upper bound of the frequency range.
    #[inline]
    #[must_use]
    pub const fn f_max(&self) -> f64 {
        self.f_max
    }

    /// Preset with 40 filters covering 0.0 to 8000.0.
    #[inline]
    #[must_use]
    pub const fn speech_standard() -> Self {
        // SAFETY: 40 >= 2, 0.0 is finite and non-negative, and 8000.0 is finite
        // and greater than 0.0, so these constants satisfy every check in `new`.
        unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
    }

    /// Preset with 64 filters covering 0.0 to `sample_rate / 2.0`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when `sample_rate / 2.0` is not a finite
    /// value greater than 0.0 (for example a zero, negative, NaN or infinite
    /// `sample_rate`).
    #[inline]
    pub fn music_standard(sample_rate: f64) -> SpectrogramResult<Self> {
        Self::new(nzu!(64), 0.0, sample_rate / 2.0)
    }
}
/// Maps a frequency in Hz onto the ERB axis used by this module.
///
/// The mapping is the affine function `24.7 * (4.37 * hz / 1000 + 1)`, so
/// `hz_to_erb(0.0)` is `24.7`. [`erb_to_hz`] applies the algebraically inverse
/// expression.
///
/// NOTE(review): this expression has the form of the Glasberg & Moore ERB
/// *bandwidth* formula rather than a logarithmic ERB-rate scale, which makes
/// equal steps on this axis equal steps in Hz — confirm that is intended.
#[inline]
#[must_use]
pub const fn hz_to_erb(hz: f64) -> f64 {
    const OFFSET_FACTOR: f64 = 24.7;
    const SLOPE_PER_KHZ: f64 = 4.37;

    // Same operation order as the closed-form expression, so results match bit for bit.
    let per_khz = SLOPE_PER_KHZ * hz / 1000.0;
    OFFSET_FACTOR * (per_khz + 1.0)
}
/// Maps a value on this module's ERB axis back to a frequency in Hz.
///
/// Evaluates `(erb / 24.7 - 1) * 1000 / 4.37`, the algebraic inverse of the
/// expression used by [`hz_to_erb`]; in particular `erb_to_hz(24.7)` is `0.0`.
#[inline]
#[must_use]
pub const fn erb_to_hz(erb: f64) -> f64 {
    const OFFSET_FACTOR: f64 = 24.7;
    const SLOPE_PER_KHZ: f64 = 4.37;

    // Same operation order as the closed-form expression, so results match bit for bit.
    let per_khz = erb / OFFSET_FACTOR - 1.0;
    per_khz * 1000.0 / SLOPE_PER_KHZ
}
/// A precomputed bank of filters expressed in the frequency domain.
///
/// Both fields are filled by `ErbFilterbank::generate` and are not modified
/// by any other method visible in this file.
#[derive(Debug, Clone)]
pub struct ErbFilterbank {
/// Center frequency of each filter, one entry per filter.
center_freqs: NonEmptyVec<f64>,
/// One row per filter; each row holds that filter's power response at every
/// frequency bin (`n_fft / 2 + 1` bins as built by `generate`).
response_matrix: NonEmptyVec<NonEmptyVec<f64>>,
}
impl ErbFilterbank {
pub(crate) fn generate(
params: &ErbParams,
sample_rate: f64,
n_fft: NonZeroUsize,
) -> SpectrogramResult<Self> {
if sample_rate <= 0.0 {
return Err(SpectrogramError::invalid_input("sample_rate must be > 0"));
}
let erb_min = hz_to_erb(params.f_min);
let erb_max = hz_to_erb(params.f_max);
let erb_step = (erb_max - erb_min) / (params.n_filters.get() - 1) as f64;
let center_freqs: Vec<f64> = (0..params.n_filters.get())
.map(|i| (i as f64).mul_add(erb_step, erb_min))
.map(erb_to_hz)
.collect::<Vec<f64>>();
let center_freqs = unsafe { NonEmptyVec::new_unchecked(center_freqs) };
let n_bins = n_fft.get() / 2 + 1; let freq_resolution = sample_rate / n_fft.get() as f64;
let mut response_matrix = Vec::with_capacity(params.n_filters.get());
for ¢er_freq in ¢er_freqs {
let erb_bandwidth = 24.7 * (4.37 * center_freq / 1000.0 + 1.0);
let bandwidth = 1.019 * erb_bandwidth;
let mut filter_response = Vec::with_capacity(n_bins);
for bin_idx in 0..n_bins {
let freq = bin_idx as f64 * freq_resolution;
let denom = Complex::new(1.0, (freq - center_freq) / bandwidth);
let denom_squared = denom * denom;
let denom_fourth = denom_squared * denom_squared;
let response_power = 1.0 / denom_fourth.norm_sqr();
filter_response.push(response_power);
}
let filter_response = unsafe { NonEmptyVec::new_unchecked(filter_response) };
response_matrix.push(filter_response);
}
let response_matrix = unsafe { NonEmptyVec::new_unchecked(response_matrix) };
Ok(Self {
center_freqs,
response_matrix,
})
}
#[inline]
#[must_use]
pub fn center_frequencies(&self) -> &NonEmptySlice<f64> {
&self.center_freqs
}
#[inline]
#[must_use]
pub const fn num_filters(&self) -> NonZeroUsize {
self.response_matrix.len()
}
#[inline]
pub fn apply_to_power_spectrum(
&self,
power_spectrum: &NonEmptySlice<f64>,
) -> SpectrogramResult<NonEmptyVec<f64>> {
let n_bins = power_spectrum.len();
let mut output = non_empty_vec![0.0; self.response_matrix.len()];
for (filter_idx, filter_response) in self.response_matrix.iter().enumerate() {
if filter_response.len() != n_bins {
return Err(SpectrogramError::dimension_mismatch(
n_bins.get(),
filter_response.len().get(),
));
}
let mut sum = 0.0;
for (bin_idx, &response_power) in filter_response.iter().enumerate() {
sum += response_power * power_spectrum[bin_idx];
}
output[filter_idx] = sum;
}
Ok(output)
}
}