use std::{num::NonZeroUsize, ops::Deref};
use ndarray::Array2;
use non_empty_slice::{NonEmptySlice, NonEmptyVec};
use crate::{SpectrogramError, SpectrogramResult, StftParams, nzu};
/// Configuration for MFCC extraction.
///
/// Build with [`MfccParams::new`] / [`MfccParams::speech_standard`] (or `Default`) and
/// adjust with the `with_*` builder methods.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MfccParams {
    // Number of DCT coefficients computed per frame (before the optional c0 drop).
    n_mfcc: NonZeroUsize,
    // Whether coefficient c0 is kept in the output.
    include_c0: bool,
    // Sinusoidal lifter parameter; 0 means no liftering is applied.
    lifter: usize,
}
impl Default for MfccParams {
#[inline]
fn default() -> Self {
Self {
n_mfcc: nzu!(13),
include_c0: true,
lifter: 22,
}
}
}
impl MfccParams {
    /// Creates parameters producing `n_mfcc` coefficients, with c0 kept and a lifter of 22.
    #[inline]
    #[must_use]
    pub const fn new(n_mfcc: NonZeroUsize) -> Self {
        Self {
            n_mfcc,
            include_c0: true,
            lifter: 22,
        }
    }

    /// The common speech configuration: equivalent to `MfccParams::new(13)`.
    #[inline]
    #[must_use]
    pub const fn speech_standard() -> Self {
        Self::new(nzu!(13))
    }

    /// Returns a copy with the c0 flag replaced by `include_c0`.
    #[inline]
    #[must_use]
    pub const fn with_c0(self, include_c0: bool) -> Self {
        Self { include_c0, ..self }
    }

    /// Returns a copy with the lifter parameter replaced by `lifter` (0 disables liftering).
    #[inline]
    #[must_use]
    pub const fn with_lifter(self, lifter: usize) -> Self {
        Self { lifter, ..self }
    }

    /// Number of DCT coefficients computed per frame.
    #[inline]
    #[must_use]
    pub const fn n_mfcc(&self) -> NonZeroUsize {
        self.n_mfcc
    }

    /// Whether coefficient c0 is kept in the output.
    #[inline]
    #[must_use]
    pub const fn include_c0(&self) -> bool {
        self.include_c0
    }

    /// Sinusoidal lifter parameter; 0 means liftering is skipped.
    #[inline]
    #[must_use]
    pub const fn lifter(&self) -> usize {
        self.lifter
    }
}
/// MFCC matrix together with the parameters that produced it.
///
/// `data` is laid out as `(coefficients, frames)`: each row is one cepstral coefficient
/// and each column is one frame. The field is public, so its shape is not enforced by
/// the type. `n_coefficients`/`n_frames` expect at least one row and one column.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Mfcc {
    /// Coefficient matrix, shape `(n_coefficients, n_frames)`.
    pub data: Array2<f64>,
    // Parameters used to compute `data`.
    params: MfccParams,
}
/// Borrows the underlying coefficient matrix.
impl AsRef<Array2<f64>> for Mfcc {
    #[inline]
    fn as_ref(&self) -> &Array2<f64> {
        let Self { data, .. } = self;
        data
    }
}
impl Deref for Mfcc {
type Target = Array2<f64>;
#[inline]
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl Mfcc {
    /// Number of cepstral coefficients, i.e. the number of rows of `data`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has no rows. `data` is a public (and deserializable) field, so its
    /// shape cannot be trusted; the previous `new_unchecked` here was undefined behaviour
    /// on an empty matrix.
    #[inline]
    #[must_use]
    pub fn n_coefficients(&self) -> NonZeroUsize {
        NonZeroUsize::new(self.data.nrows())
            .expect("Mfcc data must contain at least one coefficient row")
    }

    /// Number of frames, i.e. the number of columns of `data`.
    ///
    /// # Panics
    ///
    /// Panics if `data` has no columns. As with `n_coefficients`, the public `data` field
    /// places no constraint on its shape, so this is checked rather than assumed.
    #[inline]
    #[must_use]
    pub fn n_frames(&self) -> NonZeroUsize {
        NonZeroUsize::new(self.data.ncols())
            .expect("Mfcc data must contain at least one frame column")
    }

    /// The parameters used to compute this MFCC matrix.
    #[inline]
    #[must_use]
    pub const fn params(&self) -> &MfccParams {
        &self.params
    }
}
/// Computes MFCCs from a precomputed log-mel spectrogram.
///
/// `log_mel_spec` is read as an `(n_mels, n_frames)` matrix: rows are mel bands and
/// columns are frames. For each frame:
/// - take the unnormalized DCT-II of its mel column (via `dct_ii`);
/// - keep the first `params.n_mfcc()` coefficients;
/// - if `params.lifter()` is non-zero, apply sinusoidal liftering.
///
/// When `params.include_c0()` is `false` and `n_mfcc > 1`, coefficient c0 (row 0) is
/// dropped, so the result has `n_mfcc - 1` rows. With `n_mfcc == 1`, c0 is kept even if
/// `include_c0` is `false`, since dropping it would leave no coefficients.
///
/// # Errors
///
/// Returns an invalid-input error if:
/// - `params.n_mfcc()` exceeds the number of mel bands (including when `log_mel_spec` has
///   no rows); or
/// - `log_mel_spec` has no columns, because an [`Mfcc`] must contain at least one frame.
#[inline]
pub fn mfcc_from_log_mel(
    log_mel_spec: &Array2<f64>,
    params: &MfccParams,
) -> SpectrogramResult<Mfcc> {
    let n_mels = log_mel_spec.nrows();
    let n_frames = log_mel_spec.ncols();
    let n_mfcc = params.n_mfcc.get();
    if n_mfcc > n_mels {
        return Err(SpectrogramError::invalid_input("n_mfcc must be <= n_mels"));
    }
    // `Mfcc::n_frames` reports a `NonZeroUsize`, so a zero-frame result must never be built.
    if n_frames == 0 {
        return Err(SpectrogramError::invalid_input(
            "log_mel_spec must contain at least one frame",
        ));
    }

    let mut mfcc_data = Array2::<f64>::zeros((n_mfcc, n_frames));
    // Contiguous scratch buffer for one frame, reused across all frames.
    let mut mel_frame = vec![0.0; n_mels];
    for (mel_column, mut mfcc_column) in log_mel_spec
        .axis_iter(ndarray::Axis(1))
        .zip(mfcc_data.axis_iter_mut(ndarray::Axis(1)))
    {
        for (slot, &value) in mel_frame.iter_mut().zip(mel_column.iter()) {
            *slot = value;
        }
        let dct_coeffs = dct_ii(&mel_frame);
        // `mfcc_column` has `n_mfcc` entries, so the zip keeps only the leading coefficients.
        for (out, &coeff) in mfcc_column.iter_mut().zip(dct_coeffs.iter()) {
            *out = coeff;
        }
    }

    if params.lifter > 0 {
        apply_liftering(&mut mfcc_data, params.lifter);
    }

    // Liftering is applied before the c0 drop so the weight index matches the coefficient index.
    let final_data = if !params.include_c0 && n_mfcc > 1 {
        mfcc_data.slice(ndarray::s![1.., ..]).to_owned()
    } else {
        mfcc_data
    };
    Ok(Mfcc {
        data: final_data,
        params: *params,
    })
}
/// Unnormalized DCT-II of `input`.
///
/// Computes `X[k] = sum_{i=0}^{n-1} x[i] * cos(pi * k * (i + 0.5) / n)` for every
/// `k` in `0..n`, where `n = input.len()`. No orthonormal or other scaling is applied.
///
/// # Panics
///
/// Panics if `input` is empty. The result type is non-empty, so an empty input has no
/// valid output; the previous unchecked construction was undefined behaviour in that case.
fn dct_ii(input: &[f64]) -> NonEmptyVec<f64> {
    assert!(!input.is_empty(), "dct_ii requires a non-empty input");
    let n = input.len();
    let n_f = n as f64;
    let output: Vec<f64> = (0..n)
        .map(|k| {
            let k_f = k as f64;
            input.iter().enumerate().fold(0.0, |acc, (i, &val)| {
                val.mul_add(
                    (std::f64::consts::PI * k_f * (i as f64 + 0.5) / n_f).cos(),
                    acc,
                )
            })
        })
        .collect();
    // SAFETY: `output` has exactly `n` elements, and `n > 0` was asserted above.
    unsafe { NonEmptyVec::new_unchecked(output) }
}
fn apply_liftering(mfcc: &mut Array2<f64>, lifter: usize) {
let n_mfcc = mfcc.nrows();
let n_frames = mfcc.ncols();
let mut weights = vec![0.0; n_mfcc];
for (i, w) in weights.iter_mut().enumerate().take(n_mfcc) {
*w = (lifter as f64 / 2.0)
.mul_add((std::f64::consts::PI * i as f64 / lifter as f64).sin(), 1.0);
}
for frame_idx in 0..n_frames {
for (coeff_idx, &weight) in weights.iter().enumerate().take(n_mfcc) {
mfcc[[coeff_idx, frame_idx]] *= weight;
}
}
}
/// Computes MFCCs directly from audio samples.
///
/// Convenience wrapper that builds a log-mel spectrogram with fixed settings and then
/// calls [`mfcc_from_log_mel`]:
/// - mel filterbank of `n_mels` bands from `MelParams::new(n_mels, 0.0, sample_rate / 2.0)`
///   (presumably 0 Hz up to Nyquist — confirm against `MelParams::new`);
/// - log scaling from `LogParams::new(-80.0)` (presumably an -80 dB floor — confirm
///   against `LogParams`).
///
/// # Errors
///
/// Propagates any error from:
/// - building the spectrogram, mel, or log parameters;
/// - `MelDbSpectrogram::compute`;
/// - [`mfcc_from_log_mel`] (for example when `mfcc_params.n_mfcc()` exceeds `n_mels`).
#[inline]
pub fn mfcc(
    samples: &NonEmptySlice<f64>,
    stft_params: &StftParams,
    sample_rate: f64,
    n_mels: NonZeroUsize,
    mfcc_params: &MfccParams,
) -> SpectrogramResult<Mfcc> {
    use crate::{LogParams, MelDbSpectrogram, MelParams, SpectrogramParams};
    let params = SpectrogramParams::new(stft_params.clone(), sample_rate)?;
    let mel = MelParams::new(n_mels, 0.0, sample_rate / 2.0)?;
    let db = LogParams::new(-80.0)?;
    // Previously mis-encoded as `¶ms` (HTML `&para;` entity), which did not compile.
    let log_mel_spec = MelDbSpectrogram::compute(samples, &params, &mel, Some(&db))?;
    mfcc_from_log_mel(log_mel_spec.data(), mfcc_params)
}