use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::{Deref, DerefMut};
use ndarray::{Array1, Array2};
use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
use num_complex::Complex;
#[cfg(feature = "python")]
use pyo3::prelude::*;
use crate::cqt::CqtKernel;
use crate::erb::ErbFilterbank;
use crate::{
CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
min_max_single_pass, nzu,
};
// Small positive epsilon (1e-12).
// NOTE(review): not referenced anywhere in the visible portion of this file; confirm it is
// used further down (or remove it).
const EPS: f64 = 1e-12;
/// Minimal row-oriented sparse matrix.
///
/// Every row keeps its stored weights in `values[row]` and the matching column
/// numbers in `indices[row]` (same position, same length).
#[derive(Debug, Clone)]
struct SparseMatrix {
    nrows: usize,
    ncols: usize,
    /// Stored (non-negligible) weights, one vector per row.
    values: Vec<Vec<f64>>,
    /// Column index for each weight in `values`, one vector per row.
    indices: Vec<Vec<usize>>,
}
impl SparseMatrix {
    /// Creates an `nrows x ncols` matrix with no stored entries.
    fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            nrows,
            ncols,
            values: (0..nrows).map(|_| Vec::new()).collect(),
            indices: (0..nrows).map(|_| Vec::new()).collect(),
        }
    }
    /// Appends `value` at `(row, col)`.
    ///
    /// Entries with `|value| <= 1e-10` are not stored. Out-of-range coordinates panic in
    /// debug builds and are silently ignored in release builds. No de-duplication is
    /// performed: setting the same cell twice stores two entries.
    fn set(&mut self, row: usize, col: usize, value: f64) {
        let in_bounds = row < self.nrows && col < self.ncols;
        debug_assert!(
            in_bounds,
            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        if in_bounds && value.abs() > 1e-10 {
            self.values[row].push(value);
            self.indices[row].push(col);
        }
    }
    /// Number of rows.
    const fn nrows(&self) -> usize {
        self.nrows
    }
    /// Number of columns.
    const fn ncols(&self) -> usize {
        self.ncols
    }
    /// Dense matrix-vector product: `out[r] = sum_c M[r, c] * input[c]`.
    ///
    /// `input` must have `ncols` elements and `out` must have `nrows` elements
    /// (checked only in debug builds). Every row of `out` is overwritten.
    #[inline]
    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
        debug_assert_eq!(input.len(), self.ncols);
        debug_assert_eq!(out.len(), self.nrows);
        let rows = self.values.iter().zip(self.indices.iter());
        for (slot, (row_weights, row_cols)) in rows.enumerate() {
            out[slot] = row_weights
                .iter()
                .zip(row_cols.iter())
                .fold(0.0, |sum, (&weight, &col)| sum + weight * input[col]);
        }
    }
}
pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;
pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;
pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;
pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
pub type LogMelSpectrogram = MelDbSpectrogram;
pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;
pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
use crate::fft_backend::r2c_output_size;
/// Reusable spectrogram pipeline: STFT framing -> frequency mapping -> amplitude scaling.
///
/// Built by [`SpectrogramPlanner`]. The plan owns its scratch buffers (`workspace`), so
/// repeated calls only resize them when the required sizes change.
pub struct SpectrogramPlan<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    params: SpectrogramParams,
    stft: StftPlan,
    mapping: FrequencyMapping<FreqScale>,
    scaling: AmplitudeScaling<AmpScale>,
    freq_axis: FrequencyAxis<FreqScale>,
    workspace: Workspace,
    _amp: PhantomData<AmpScale>,
}
impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// Parameters this plan was built with.
    #[inline]
    #[must_use]
    pub const fn params(&self) -> &SpectrogramParams {
        &self.params
    }
    /// Frequency axis (one entry per output row).
    #[inline]
    #[must_use]
    pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
        &self.freq_axis
    }
    /// Runs the full pipeline for one frame, leaving the result in `workspace.mapped`.
    ///
    /// Mappings that need the raw time-domain frame (CQT) get an unwindowed frame and no
    /// FFT; all others get the windowed power spectrum.
    fn process_frame(
        &mut self,
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
    ) -> SpectrogramResult<()> {
        if self.mapping.kind.needs_unwindowed_frame() {
            self.stft
                .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
        } else {
            self.stft
                .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
        }
        let Workspace {
            spectrum,
            mapped,
            frame,
            ..
        } = &mut self.workspace;
        self.mapping.apply(spectrum, frame, mapped)?;
        self.scaling.apply_in_place(mapped)?;
        Ok(())
    }
    /// Computes the spectrogram of `samples`.
    ///
    /// The returned data has shape `(n_bins, n_frames)`, one column per frame.
    ///
    /// # Errors
    /// Propagates framing, mapping and scaling errors.
    #[inline]
    pub fn compute(
        &mut self,
        samples: &NonEmptySlice<f64>,
    ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
        let n_frames = self.stft.frame_count(samples.len())?;
        let n_bins = self.mapping.output_bins();
        let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
        self.workspace
            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
        for frame_idx in 0..n_frames.get() {
            self.process_frame(samples, frame_idx)?;
            for (row, &val) in self.workspace.mapped.iter().enumerate() {
                data[[row, frame_idx]] = val;
            }
        }
        let times = build_time_axis_seconds(&self.params, n_frames);
        let axes = Axes::new(self.freq_axis.clone(), times);
        Ok(Spectrogram::new(data, axes, self.params.clone()))
    }
    /// Computes a single output column for frame `frame_idx`.
    ///
    /// # Errors
    /// Propagates framing, mapping and scaling errors.
    #[inline]
    pub fn compute_frame(
        &mut self,
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
    ) -> SpectrogramResult<NonEmptyVec<f64>> {
        let n_bins = self.mapping.output_bins();
        self.workspace
            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
        self.process_frame(samples, frame_idx)?;
        Ok(self.workspace.mapped.clone())
    }
    /// Computes the spectrogram into a caller-provided `(n_bins, n_frames)` array.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if `output` has the wrong shape (checked before
    /// any computation), and propagates framing, mapping and scaling errors.
    #[inline]
    pub fn compute_into(
        &mut self,
        samples: &NonEmptySlice<f64>,
        output: &mut Array2<f64>,
    ) -> SpectrogramResult<()> {
        let n_frames = self.stft.frame_count(samples.len())?;
        let n_bins = self.mapping.output_bins();
        if output.nrows() != n_bins.get() {
            return Err(SpectrogramError::dimension_mismatch(
                n_bins.get(),
                output.nrows(),
            ));
        }
        if output.ncols() != n_frames.get() {
            return Err(SpectrogramError::dimension_mismatch(
                n_frames.get(),
                output.ncols(),
            ));
        }
        self.workspace
            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
        for frame_idx in 0..n_frames.get() {
            self.process_frame(samples, frame_idx)?;
            for (row, &val) in self.workspace.mapped.iter().enumerate() {
                output[[row, frame_idx]] = val;
            }
        }
        Ok(())
    }
    /// Shape `(n_bins, n_frames)` that [`Self::compute`] would produce for a signal of
    /// `signal_length` samples.
    ///
    /// # Errors
    /// Propagates frame-count errors.
    #[inline]
    pub fn output_shape(
        &self,
        signal_length: NonZeroUsize,
    ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
        let n_frames = self.stft.frame_count(signal_length)?;
        let n_bins = self.mapping.output_bins();
        Ok((n_bins, n_frames))
    }
}
/// Complex short-time Fourier transform of a signal.
///
/// `data` is laid out as `(frequency bins, frames)`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StftResult {
    /// Complex STFT coefficients, shape `(n_bins, n_frames)`.
    pub data: Array2<Complex<f64>>,
    /// Frequency in Hz of each row of `data`.
    pub frequencies: NonEmptyVec<f64>,
    /// Sample rate in Hz of the analysed signal.
    pub sample_rate: f64,
    /// STFT parameters used to produce `data`.
    pub params: StftParams,
}
impl StftResult {
    /// Number of frequency bins (rows of `data`).
    ///
    /// # Panics
    /// Panics if `data` was replaced (it is a public field and reachable through
    /// `DerefMut`) by an array with zero rows.
    #[inline]
    #[must_use]
    pub fn n_bins(&self) -> NonZeroUsize {
        // Checked: `data` is publicly mutable, so a non-zero row count is not an invariant
        // this type can guarantee; `new_unchecked` here would be undefined behaviour.
        NonZeroUsize::new(self.data.nrows())
            .expect("StftResult::data must have at least one frequency bin")
    }
    /// Number of frames (columns of `data`).
    ///
    /// # Panics
    /// Panics if `data` was replaced by an array with zero columns.
    #[inline]
    #[must_use]
    pub fn n_frames(&self) -> NonZeroUsize {
        NonZeroUsize::new(self.data.ncols())
            .expect("StftResult::data must have at least one frame")
    }
    /// Spacing between frequency bins in Hz (`sample_rate / n_fft`).
    #[inline]
    #[must_use]
    pub fn frequency_resolution(&self) -> f64 {
        self.sample_rate / self.params.n_fft().get() as f64
    }
    /// Spacing between frames in seconds (`hop_size / sample_rate`).
    #[inline]
    #[must_use]
    pub fn time_resolution(&self) -> f64 {
        self.params.hop_size().get() as f64 / self.sample_rate
    }
    /// Element-wise magnitude `|z|` of the STFT coefficients.
    #[inline]
    #[must_use]
    pub fn norm(&self) -> Array2<f64> {
        self.as_ref().mapv(Complex::norm)
    }
}
impl AsRef<Array2<Complex<f64>>> for StftResult {
    #[inline]
    fn as_ref(&self) -> &Array2<Complex<f64>> {
        &self.data
    }
}
impl AsMut<Array2<Complex<f64>>> for StftResult {
    #[inline]
    fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
        &mut self.data
    }
}
impl Deref for StftResult {
    type Target = Array2<Complex<f64>>;
    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}
impl DerefMut for StftResult {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.data
    }
}
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct SpectrogramPlanner;
impl SpectrogramPlanner {
#[inline]
#[must_use]
pub const fn new() -> Self {
Self
}
#[inline]
pub fn compute_stft(
&self,
samples: &NonEmptySlice<f64>,
params: &SpectrogramParams,
) -> SpectrogramResult<StftResult> {
let mut plan = StftPlan::new(params)?;
plan.compute(samples, params)
}
#[inline]
pub fn compute_power_spectrum(
&self,
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
window: WindowType,
) -> SpectrogramResult<NonEmptyVec<f64>> {
if samples.len() > n_fft {
return Err(SpectrogramError::invalid_input(format!(
"Input length ({}) exceeds FFT size ({})",
samples.len(),
n_fft
)));
}
let window_samples = make_window(window, n_fft);
let out_len = r2c_output_size(n_fft.get());
#[cfg(feature = "realfft")]
let mut fft = {
let mut planner = crate::RealFftPlanner::new();
let plan = planner.get_or_create(n_fft.get());
crate::RealFftPlan::new(n_fft.get(), plan)
};
#[cfg(feature = "fftw")]
let mut fft = {
use std::sync::Arc;
let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
crate::FftwPlan::new(Arc::new(plan))
};
let mut windowed = vec![0.0; n_fft.get()];
for i in 0..samples.len().get() {
windowed[i] = samples[i] * window_samples[i];
}
let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
fft.process(&windowed, &mut fft_out)?;
let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
Ok(unsafe { NonEmptyVec::new_unchecked(power) })
}
#[inline]
pub fn compute_magnitude_spectrum(
&self,
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
window: WindowType,
) -> SpectrogramResult<NonEmptyVec<f64>> {
let power = self.compute_power_spectrum(samples, n_fft, window)?;
let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
Ok(unsafe { NonEmptyVec::new_unchecked(power) })
}
#[inline]
pub fn linear_plan<AmpScale>(
&self,
params: &SpectrogramParams,
db: Option<&LogParams>, ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
where
AmpScale: AmpScaleSpec + 'static,
{
let stft = StftPlan::new(params)?;
let mapping = FrequencyMapping::<LinearHz>::new(params)?;
let scaling = AmplitudeScaling::<AmpScale>::new(db);
let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
Ok(SpectrogramPlan {
params: params.clone(),
stft,
mapping,
scaling,
freq_axis,
workspace,
_amp: PhantomData,
})
}
#[inline]
pub fn mel_plan<AmpScale>(
&self,
params: &SpectrogramParams,
mel: &MelParams,
db: Option<&LogParams>, ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
where
AmpScale: AmpScaleSpec + 'static,
{
let nyquist = params.nyquist_hz();
if mel.f_max() > nyquist {
return Err(SpectrogramError::invalid_input(
"mel f_max must be <= Nyquist",
));
}
let stft = StftPlan::new(params)?;
let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
let scaling = AmplitudeScaling::<AmpScale>::new(db);
let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
Ok(SpectrogramPlan {
params: params.clone(),
stft,
mapping,
scaling,
freq_axis,
workspace,
_amp: PhantomData,
})
}
#[inline]
pub fn erb_plan<AmpScale>(
&self,
params: &SpectrogramParams,
erb: &ErbParams,
db: Option<&LogParams>,
) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
where
AmpScale: AmpScaleSpec + 'static,
{
let nyquist = params.nyquist_hz();
if erb.f_max() > nyquist {
return Err(SpectrogramError::invalid_input(format!(
"f_max={} exceeds Nyquist={}",
erb.f_max(),
nyquist
)));
}
let stft = StftPlan::new(params)?;
let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
let scaling = AmplitudeScaling::<AmpScale>::new(db);
let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
Ok(SpectrogramPlan {
params: params.clone(),
stft,
mapping,
scaling,
freq_axis,
workspace,
_amp: PhantomData,
})
}
#[inline]
pub fn log_hz_plan<AmpScale>(
&self,
params: &SpectrogramParams,
loghz: &LogHzParams,
db: Option<&LogParams>,
) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
where
AmpScale: AmpScaleSpec + 'static,
{
let nyquist = params.nyquist_hz();
if loghz.f_max() > nyquist {
return Err(SpectrogramError::invalid_input(format!(
"f_max={} exceeds Nyquist={}",
loghz.f_max(),
nyquist
)));
}
let stft = StftPlan::new(params)?;
let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
let scaling = AmplitudeScaling::<AmpScale>::new(db);
let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
Ok(SpectrogramPlan {
params: params.clone(),
stft,
mapping,
scaling,
freq_axis,
workspace,
_amp: PhantomData,
})
}
#[inline]
pub fn cqt_plan<AmpScale>(
&self,
params: &SpectrogramParams,
cqt: &CqtParams,
db: Option<&LogParams>, ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
where
AmpScale: AmpScaleSpec + 'static,
{
let stft = StftPlan::new(params)?;
let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
let scaling = AmplitudeScaling::<AmpScale>::new(db);
let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
Ok(SpectrogramPlan {
params: params.clone(),
stft,
mapping,
scaling,
freq_axis,
workspace,
_amp: PhantomData,
})
}
}
/// Precomputed short-time Fourier transform: window, FFT plan and scratch buffers.
pub struct StftPlan {
    /// FFT (and frame) length in samples.
    n_fft: NonZeroUsize,
    /// Distance between consecutive frame starts, in samples.
    hop_size: NonZeroUsize,
    /// Analysis window, `n_fft` coefficients.
    window: NonEmptyVec<f64>,
    /// When true, the signal is virtually zero-padded by `n_fft / 2` on both sides.
    centre: bool,
    /// Number of real-to-complex FFT output bins.
    out_len: NonZeroUsize,
    fft: Box<dyn R2cPlan>,
    /// Scratch output used by the "simple" (complex STFT) code path.
    fft_out: NonEmptyVec<Complex<f64>>,
    /// Scratch frame used by the "simple" (complex STFT) code path.
    frame: NonEmptyVec<f64>,
}
impl StftPlan {
    /// Builds the window and FFT plan described by `params.stft()`.
    ///
    /// # Errors
    /// Returns an invalid-input error if the FFT output length would be zero; propagates
    /// FFT-backend planning errors.
    #[inline]
    pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
        let stft = params.stft();
        let n_fft = stft.n_fft();
        let hop_size = stft.hop_size();
        let centre = stft.centre();
        let window = make_window(stft.window(), n_fft);
        let out_len = r2c_output_size(n_fft.get());
        let out_len = NonZeroUsize::new(out_len)
            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
        #[cfg(feature = "realfft")]
        let fft = {
            let mut planner = crate::RealFftPlanner::new();
            let plan = planner.get_or_create(n_fft.get());
            let plan = crate::RealFftPlan::new(n_fft.get(), plan);
            Box::new(plan)
        };
        #[cfg(feature = "fftw")]
        let fft = {
            use std::sync::Arc;
            let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
            Box::new(crate::FftwPlan::new(Arc::new(plan)))
        };
        Ok(Self {
            n_fft,
            hop_size,
            window,
            centre,
            out_len,
            fft,
            fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
            frame: non_empty_vec![0.0; n_fft],
        })
    }
    /// Amount of virtual zero padding added before the first sample.
    const fn padding(&self) -> usize {
        if self.centre {
            self.n_fft.get() / 2
        } else {
            0
        }
    }
    /// Number of frames produced for a signal of `n_samples` samples (at least 1).
    fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
        let pad = self.padding();
        let padded_len = n_samples.get() + 2 * pad;
        if padded_len < self.n_fft.get() {
            return Ok(nzu!(1));
        }
        let remaining = padded_len - self.n_fft.get();
        let n_frames = remaining / self.hop_size().get() + 1;
        let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
            SpectrogramError::invalid_input("computed number of frames must be non-zero")
        })?;
        Ok(n_frames)
    }
    /// Copies frame `frame_idx` of `samples` into the first `n_fft` slots of `out`,
    /// optionally multiplying by `window`.
    ///
    /// Frame `i` starts at padded position `i * hop_size`; padded positions before `pad`
    /// or past the end of `samples` read as 0.0.
    fn fill_frame(
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
        hop_size: NonZeroUsize,
        n_fft: NonZeroUsize,
        pad: usize,
        window: Option<&[f64]>,
        out: &mut [f64],
    ) -> SpectrogramResult<()> {
        let start = frame_idx
            .checked_mul(hop_size.get())
            .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
        let n_samples = samples.len().get();
        for (i, slot) in out.iter_mut().enumerate().take(n_fft.get()) {
            // Checked arithmetic instead of `as isize` casts, which could wrap for huge indices.
            let sample = start
                .checked_add(i)
                .and_then(|padded_idx| padded_idx.checked_sub(pad))
                .filter(|&s| s < n_samples)
                .map_or(0.0, |s| samples[s]);
            *slot = match window {
                Some(w) => sample * w[i],
                None => sample,
            };
        }
        Ok(())
    }
    /// Windows frame `frame_idx` into `self.frame` and FFTs it into `self.fft_out`.
    fn compute_frame_fft_simple(
        &mut self,
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
    ) -> SpectrogramResult<()> {
        let pad = self.padding();
        let out = self.frame.as_mut_slice();
        debug_assert_eq!(out.len(), self.n_fft.get());
        let window: &[f64] = &self.window;
        Self::fill_frame(
            samples,
            frame_idx,
            self.hop_size,
            self.n_fft,
            pad,
            Some(window),
            out,
        )?;
        let fft_out = self.fft_out.as_mut_slice();
        self.fft.process(out, fft_out)?;
        Ok(())
    }
    /// Windows frame `frame_idx` into `workspace.frame`, FFTs it into `workspace.fft_out`
    /// and writes the power spectrum `|X[k]|^2` into `workspace.spectrum`.
    fn compute_frame_spectrum(
        &mut self,
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
        workspace: &mut Workspace,
    ) -> SpectrogramResult<()> {
        let pad = self.padding();
        let out = workspace.frame.as_mut_slice();
        debug_assert_eq!(out.len(), self.n_fft.get());
        let window: &[f64] = &self.window;
        Self::fill_frame(
            samples,
            frame_idx,
            self.hop_size,
            self.n_fft,
            pad,
            Some(window),
            out,
        )?;
        let fft_out = workspace.fft_out.as_mut_slice();
        self.fft.process(out, fft_out)?;
        for (i, c) in workspace.fft_out.iter().enumerate() {
            workspace.spectrum[i] = c.norm_sqr();
        }
        Ok(())
    }
    /// Copies frame `frame_idx` into `workspace.frame` without windowing or FFT.
    fn fill_frame_unwindowed(
        &self,
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
        workspace: &mut Workspace,
    ) -> SpectrogramResult<()> {
        let out = workspace.frame.as_mut_slice();
        debug_assert_eq!(out.len(), self.n_fft.get());
        Self::fill_frame(
            samples,
            frame_idx,
            self.hop_size,
            self.n_fft,
            self.padding(),
            None,
            out,
        )
    }
    /// Computes the full complex STFT of `samples`, shape `(n_bins, n_frames)`.
    ///
    /// `params` supplies the sample rate and the metadata stored in the result; its FFT
    /// size must equal the plan's, otherwise the frequency axis would not describe `data`.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if `params.stft().n_fft()` differs from this
    /// plan's FFT size; propagates framing and FFT errors.
    #[inline]
    pub fn compute(
        &mut self,
        samples: &NonEmptySlice<f64>,
        params: &SpectrogramParams,
    ) -> SpectrogramResult<StftResult> {
        let requested_n_fft = params.stft().n_fft();
        if requested_n_fft != self.n_fft {
            return Err(SpectrogramError::dimension_mismatch(
                self.n_fft.get(),
                requested_n_fft.get(),
            ));
        }
        let n_frames = self.frame_count(samples.len())?;
        let n_bins = self.out_len;
        let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
        for frame_idx in 0..n_frames.get() {
            self.compute_frame_fft_simple(samples, frame_idx)?;
            for (bin_idx, &value) in self.fft_out.iter().enumerate() {
                data[[bin_idx, frame_idx]] = value;
            }
        }
        let frequencies: Vec<f64> = (0..n_bins.get())
            .map(|k| k as f64 * params.sample_rate_hz() / self.n_fft.get() as f64)
            .collect();
        // SAFETY: `n_bins` is a `NonZeroUsize`, so `frequencies` is non-empty.
        let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
        Ok(StftResult {
            data,
            frequencies,
            sample_rate: params.sample_rate_hz(),
            params: params.stft().clone(),
        })
    }
    /// Computes the complex spectrum of a single frame.
    ///
    /// # Errors
    /// Propagates framing and FFT errors.
    #[inline]
    pub fn compute_frame_simple(
        &mut self,
        samples: &NonEmptySlice<f64>,
        frame_idx: usize,
    ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
        self.compute_frame_fft_simple(samples, frame_idx)?;
        Ok(self.fft_out.clone())
    }
    /// Computes the complex STFT into a caller-provided `(n_bins, n_frames)` array.
    ///
    /// # Errors
    /// Returns a dimension-mismatch error if `output` has the wrong shape; propagates
    /// framing and FFT errors.
    #[inline]
    pub fn compute_into(
        &mut self,
        samples: &NonEmptySlice<f64>,
        output: &mut Array2<Complex<f64>>,
    ) -> SpectrogramResult<()> {
        let n_frames = self.frame_count(samples.len())?;
        let n_bins = self.out_len;
        if output.nrows() != n_bins.get() {
            return Err(SpectrogramError::dimension_mismatch(
                n_bins.get(),
                output.nrows(),
            ));
        }
        if output.ncols() != n_frames.get() {
            return Err(SpectrogramError::dimension_mismatch(
                n_frames.get(),
                output.ncols(),
            ));
        }
        for frame_idx in 0..n_frames.get() {
            self.compute_frame_fft_simple(samples, frame_idx)?;
            for (bin_idx, &value) in self.fft_out.iter().enumerate() {
                output[[bin_idx, frame_idx]] = value;
            }
        }
        Ok(())
    }
    /// Shape `(n_bins, n_frames)` of the STFT for a signal of `signal_length` samples.
    ///
    /// # Errors
    /// Propagates frame-count errors.
    #[inline]
    pub fn output_shape(
        &self,
        signal_length: NonZeroUsize,
    ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
        let n_frames = self.frame_count(signal_length)?;
        Ok((self.out_len, n_frames))
    }
    /// Number of FFT output bins.
    #[inline]
    #[must_use]
    pub const fn n_bins(&self) -> NonZeroUsize {
        self.out_len
    }
    /// FFT length.
    #[inline]
    #[must_use]
    pub const fn n_fft(&self) -> NonZeroUsize {
        self.n_fft
    }
    /// Hop size between frames.
    #[inline]
    #[must_use]
    pub const fn hop_size(&self) -> NonZeroUsize {
        self.hop_size
    }
}
/// Precomputed data for each supported frequency mapping.
#[derive(Debug, Clone)]
enum MappingKind {
    /// Linear FFT bins passed through unchanged (`out_len` bins).
    Identity { out_len: NonZeroUsize },
    /// Mel filterbank applied to the power spectrum.
    Mel { matrix: SparseMatrix },
    /// Log-frequency weighting matrix plus the output bin frequencies.
    LogHz {
        matrix: SparseMatrix,
        frequencies: NonEmptyVec<f64>,
    },
    /// ERB / gammatone filterbank applied to the power spectrum.
    Erb { filterbank: ErbFilterbank },
    /// Constant-Q kernel applied to the raw time-domain frame.
    Cqt { kernel: CqtKernel },
}
impl MappingKind {
    /// True when the mapping consumes the unwindowed time-domain frame instead of the
    /// power spectrum (only the CQT does).
    const fn needs_unwindowed_frame(&self) -> bool {
        match self {
            Self::Cqt { .. } => true,
            Self::Identity { .. } | Self::Mel { .. } | Self::LogHz { .. } | Self::Erb { .. } => {
                false
            }
        }
    }
}
/// Maps a frame's power spectrum (or, for CQT, its raw samples) onto output bins.
///
/// `FreqScale` is a type-level tag (held only through `PhantomData`); the runtime
/// behaviour is selected by `kind`.
#[derive(Debug, Clone)]
struct FrequencyMapping<FreqScale> {
    // Which mapping to apply, together with its precomputed data.
    kind: MappingKind,
    // Ties the mapping to its frequency-scale tag without storing a value of it.
    _marker: PhantomData<FreqScale>,
}
impl FrequencyMapping<LinearHz> {
    /// Builds the identity mapping over all `r2c_output_size(n_fft)` FFT bins.
    ///
    /// # Errors
    /// Returns an invalid-input error if that bin count would be zero.
    fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
        let fft_bins = r2c_output_size(params.stft().n_fft().get());
        match NonZeroUsize::new(fft_bins) {
            Some(out_len) => Ok(Self {
                kind: MappingKind::Identity { out_len },
                _marker: PhantomData,
            }),
            None => Err(SpectrogramError::invalid_input(
                "FFT output length must be non-zero",
            )),
        }
    }
}
impl FrequencyMapping<Mel> {
    /// Builds a mel filterbank mapping from `params` and `mel`.
    ///
    /// # Errors
    /// Returns an invalid-input error if the FFT output length is zero, if more than
    /// 10 000 mel bands are requested, if the filterbank builder rejects the parameters,
    /// or if the resulting matrix has an unexpected shape.
    fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
        const MAX_MELS: usize = 10_000;
        let n_fft = params.stft().n_fft();
        let fft_bins = match NonZeroUsize::new(r2c_output_size(n_fft.get())) {
            Some(bins) => bins,
            None => {
                return Err(SpectrogramError::invalid_input(
                    "FFT output length must be non-zero",
                ));
            }
        };
        let n_mels = mel.n_mels();
        if n_mels.get() > MAX_MELS {
            return Err(SpectrogramError::invalid_input(
                "n_mels is unreasonably large",
            ));
        }
        let matrix = build_mel_filterbank_matrix(
            params.sample_rate_hz(),
            n_fft,
            n_mels,
            mel.f_min(),
            mel.f_max(),
            mel.norm(),
        )?;
        let shape_ok = matrix.nrows() == n_mels.get() && matrix.ncols() == fft_bins.get();
        if !shape_ok {
            return Err(SpectrogramError::invalid_input(
                "mel filterbank matrix shape mismatch",
            ));
        }
        Ok(Self {
            kind: MappingKind::Mel { matrix },
            _marker: PhantomData,
        })
    }
}
impl FrequencyMapping<LogHz> {
    /// Builds a logarithmically spaced frequency mapping from `params` and `loghz`.
    ///
    /// # Errors
    /// Returns an invalid-input error if the FFT output length is zero, if more than
    /// 10 000 bins are requested, if the matrix builder rejects the parameters, or if the
    /// resulting matrix has an unexpected shape.
    fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
        const MAX_BINS: usize = 10_000;
        let n_fft = params.stft().n_fft();
        let fft_bins = match NonZeroUsize::new(r2c_output_size(n_fft.get())) {
            Some(bins) => bins,
            None => {
                return Err(SpectrogramError::invalid_input(
                    "FFT output length must be non-zero",
                ));
            }
        };
        let n_bins = loghz.n_bins();
        if n_bins.get() > MAX_BINS {
            return Err(SpectrogramError::invalid_input(
                "n_bins is unreasonably large",
            ));
        }
        let (matrix, frequencies) = build_loghz_matrix(
            params.sample_rate_hz(),
            n_fft,
            n_bins,
            loghz.f_min(),
            loghz.f_max(),
        )?;
        let shape_ok = matrix.nrows() == n_bins.get() && matrix.ncols() == fft_bins.get();
        if !shape_ok {
            return Err(SpectrogramError::invalid_input(
                "loghz matrix shape mismatch",
            ));
        }
        let kind = MappingKind::LogHz {
            matrix,
            frequencies,
        };
        Ok(Self {
            kind,
            _marker: PhantomData,
        })
    }
}
impl FrequencyMapping<Erb> {
fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
let n_fft = params.stft().n_fft();
let sample_rate = params.sample_rate_hz();
if erb.n_filters() > nzu!(10_000) {
return Err(SpectrogramError::invalid_input(
"n_filters is unreasonably large",
));
}
let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
Ok(Self {
kind: MappingKind::Erb { filterbank },
_marker: PhantomData,
})
}
}
impl FrequencyMapping<Cqt> {
    /// Builds a constant-Q transform mapping.
    ///
    /// # Errors
    /// Returns an invalid-input error if the highest CQT bin frequency is at or above
    /// `sample_rate / 2`.
    fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
        let sample_rate = params.sample_rate_hz();
        let nyquist = sample_rate / 2.0;
        let highest_bin = cqt.num_bins().get().saturating_sub(1);
        if cqt.bin_frequency(highest_bin) >= nyquist {
            return Err(SpectrogramError::invalid_input(
                "CQT maximum frequency must be below Nyquist frequency",
            ));
        }
        let kernel = CqtKernel::generate(cqt, sample_rate, params.stft().n_fft());
        Ok(Self {
            kind: MappingKind::Cqt { kernel },
            _marker: PhantomData,
        })
    }
}
impl<FreqScale> FrequencyMapping<FreqScale> {
    /// Number of output bins this mapping produces per frame.
    const fn output_bins(&self) -> NonZeroUsize {
        match &self.kind {
            MappingKind::Identity { out_len } => *out_len,
            MappingKind::Mel { matrix } | MappingKind::LogHz { matrix, .. } => {
                // SAFETY: the Mel/LogHz constructors verify that the matrix row count equals
                // the requested (NonZeroUsize) band count before building this variant.
                unsafe { NonZeroUsize::new_unchecked(matrix.nrows()) }
            }
            MappingKind::Erb { filterbank } => filterbank.num_filters(),
            MappingKind::Cqt { kernel } => kernel.num_bins(),
        }
    }
    /// Maps one frame into `out`.
    ///
    /// `spectrum` is the frame's power spectrum (used by Identity, Mel, LogHz and Erb);
    /// `frame` is the raw time-domain frame (used by Cqt, which writes `|c|^2` for each
    /// complex CQT coefficient).
    ///
    /// # Errors
    /// Returns a dimension-mismatch error when input or output lengths disagree with the
    /// mapping, and propagates filterbank/kernel errors.
    fn apply(
        &self,
        spectrum: &NonEmptySlice<f64>,
        frame: &NonEmptySlice<f64>,
        out: &mut NonEmptySlice<f64>,
    ) -> SpectrogramResult<()> {
        /// `Ok(())` when `actual == expected`, otherwise a dimension-mismatch error.
        fn ensure_len(expected: usize, actual: usize) -> SpectrogramResult<()> {
            if actual == expected {
                Ok(())
            } else {
                Err(SpectrogramError::dimension_mismatch(expected, actual))
            }
        }
        match &self.kind {
            MappingKind::Identity { out_len } => {
                let expected = out_len.get();
                ensure_len(expected, spectrum.len().get())?;
                ensure_len(expected, out.len().get())?;
                out.copy_from_slice(spectrum);
            }
            MappingKind::Mel { matrix } | MappingKind::LogHz { matrix, .. } => {
                ensure_len(matrix.ncols(), spectrum.len().get())?;
                ensure_len(matrix.nrows(), out.len().get())?;
                matrix.multiply_vec(spectrum, out);
            }
            MappingKind::Erb { filterbank } => {
                let bands = filterbank.apply_to_power_spectrum(spectrum)?;
                ensure_len(bands.len().get(), out.len().get())?;
                out.copy_from_slice(&bands);
            }
            MappingKind::Cqt { kernel } => {
                let coefficients = kernel.apply(frame)?;
                ensure_len(coefficients.len().get(), out.len().get())?;
                for (dst, c) in out.iter_mut().zip(coefficients.iter()) {
                    *dst = c.norm_sqr();
                }
            }
        }
        Ok(())
    }
    /// Frequency in Hz associated with each output bin.
    fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
        match &self.kind {
            MappingKind::Identity { out_len } => {
                let bin_spacing =
                    params.sample_rate_hz() / params.stft().n_fft().get() as f64;
                let freqs: Vec<f64> = (0..out_len.get())
                    .map(|k| k as f64 * bin_spacing)
                    .collect();
                // SAFETY: `out_len` is non-zero, so `freqs` is non-empty.
                unsafe { NonEmptyVec::new_unchecked(freqs) }
            }
            MappingKind::Mel { matrix } => {
                // SAFETY: see `output_bins`; the mel matrix always has >= 1 row.
                let n_mels = unsafe { NonZeroUsize::new_unchecked(matrix.nrows()) };
                // NOTE(review): the mel f_min/f_max used to build the filterbank are not
                // passed here (only sample rate and Nyquist); confirm that
                // `mel_band_centres_hz` reproduces the filterbank's actual band centres.
                mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
            }
            MappingKind::LogHz { frequencies, .. } => frequencies.clone(),
            MappingKind::Erb { filterbank } => filterbank.center_frequencies().to_non_empty_vec(),
            MappingKind::Cqt { kernel } => kernel.frequencies().to_non_empty_vec(),
        }
    }
}
/// Per-amplitude-scale conversion rules used by [`SpectrogramPlan`].
pub trait AmpScaleSpec: Sized + Send + Sync {
    /// Converts a single power value into this scale's (pre-dB) representation.
    fn apply_from_power(power: f64) -> f64;
    /// Converts a buffer to decibels in place, if this scale uses decibels.
    ///
    /// # Errors
    /// Implementations may reject an invalid `floor_db`.
    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
}
impl AmpScaleSpec for Power {
    /// Power is kept unchanged.
    #[inline]
    fn apply_from_power(power: f64) -> f64 {
        power
    }
    /// No-op: power output is never converted to dB.
    #[inline]
    fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
        Ok(())
    }
}
impl AmpScaleSpec for Magnitude {
    /// Magnitude is the square root of power.
    #[inline]
    fn apply_from_power(power: f64) -> f64 {
        power.sqrt()
    }
    /// No-op: magnitude output is never converted to dB.
    #[inline]
    fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
        Ok(())
    }
}
impl AmpScaleSpec for Decibels {
    /// Leaves power unchanged; the dB conversion happens in `apply_db_in_place`.
    #[inline]
    fn apply_from_power(power: f64) -> f64 {
        power
    }
    /// Replaces each power value `p` with `10 * log10(max(p, 10^(floor_db / 10)))`.
    ///
    /// # Errors
    /// Returns an invalid-input error if `floor_db` is not finite.
    #[inline]
    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
        if !floor_db.is_finite() {
            return Err(SpectrogramError::invalid_input("floor_db must be finite"));
        }
        // For finite floors below roughly -3233 dB, 10^(floor_db / 10) underflows to exactly
        // 0.0 and zero-power bins would become log10(0) = -inf. Clamping to the smallest
        // positive f64 (bit pattern 1, ~4.9e-324) keeps every output finite and changes
        // nothing for floors whose eps is representable.
        let eps = 10.0_f64.powf(floor_db / 10.0).max(f64::from_bits(1));
        for v in x.iter_mut() {
            *v = 10.0 * v.max(eps).log10();
        }
        Ok(())
    }
}
/// Applies an [`AmpScaleSpec`] conversion (and optional dB step) to mapped frames.
///
/// NOTE(review): the dB step runs only when `db_floor` is `Some`. Because
/// `Decibels::apply_from_power` returns power unchanged, a `Decibels` plan built with
/// `db = None` yields raw power values; confirm callers always pass `LogParams` for dB.
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    /// Floor in dB taken from `LogParams`, if any.
    db_floor: Option<f64>,
    _marker: PhantomData<AmpScale>,
}
impl<AmpScale> AmplitudeScaling<AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
{
    /// Captures the dB floor from `db` (or none).
    fn new(db: Option<&LogParams>) -> Self {
        Self {
            db_floor: db.map(LogParams::floor_db),
            _marker: PhantomData,
        }
    }
    /// Converts every power value in `x` to the target scale, then applies the dB step
    /// when a floor is configured.
    ///
    /// # Errors
    /// Propagates errors from `AmpScale::apply_db_in_place`.
    pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
        x.iter_mut()
            .for_each(|v| *v = AmpScale::apply_from_power(*v));
        match self.db_floor {
            Some(floor_db) => AmpScale::apply_db_in_place(x, floor_db),
            None => Ok(()),
        }
    }
}
/// Scratch buffers reused across frames by a [`SpectrogramPlan`].
#[derive(Debug, Clone)]
struct Workspace {
    /// Power spectrum of the current frame (`out_len` bins).
    spectrum: NonEmptyVec<f64>,
    /// Mapped and scaled output of the current frame (`n_bins` values).
    mapped: NonEmptyVec<f64>,
    /// Time-domain frame buffer (`n_fft` samples).
    frame: NonEmptyVec<f64>,
    /// Complex FFT output (`out_len` bins).
    fft_out: NonEmptyVec<Complex<f64>>,
}
impl Workspace {
    /// Allocates zero-filled buffers of the given sizes.
    fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
        let frame = non_empty_vec![0.0; n_fft];
        let fft_out = non_empty_vec![Complex::new(0.0, 0.0); out_len];
        let spectrum = non_empty_vec![0.0; out_len];
        let mapped = non_empty_vec![0.0; n_bins];
        Self {
            spectrum,
            mapped,
            frame,
            fft_out,
        }
    }
    /// Resizes (zero-filling new slots) only the buffers whose length differs.
    fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
        if self.frame.len() != n_fft {
            self.frame.resize(n_fft, 0.0);
        }
        if self.fft_out.len() != out_len {
            self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
        }
        if self.spectrum.len() != out_len {
            self.spectrum.resize(out_len, 0.0);
        }
        if self.mapped.len() != n_bins {
            self.mapped.resize(n_bins, 0.0);
        }
    }
}
/// Builds the frequency axis (Hz per output bin) for a mapping.
fn build_frequency_axis<FreqScale>(
    params: &SpectrogramParams,
    mapping: &FrequencyMapping<FreqScale>,
) -> FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    FrequencyAxis::new(mapping.frequencies_hz(params))
}
/// Time in seconds of each frame: `i * params.frame_period_seconds()` for `i` in `0..n_frames`.
fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
    let frame_period = params.frame_period_seconds();
    let times: Vec<f64> = (0..n_frames.get())
        .map(|frame| frame as f64 * frame_period)
        .collect();
    // SAFETY: `n_frames` is non-zero, so `times` holds at least one element.
    unsafe { NonEmptyVec::new_unchecked(times) }
}
/// Builds an analysis window with `n_fft` coefficients.
///
/// Hanning, Hamming and Blackman use the symmetric definition with denominator
/// `n_fft - 1`. For `n_fft == 1` those windows (like Rectangular and Kaiser) are `[1.0]`;
/// the textbook formula would otherwise divide 0 by 0 and yield NaN.
///
/// # Panics
/// Panics if `window` is `WindowType::Custom` and its `size` differs from `n_fft`.
#[inline]
#[must_use]
pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
    /// Fills `w` with `shape(2*pi*n / (N - 1))` for a symmetric cosine-sum window.
    fn fill_cosine_window(w: &mut [f64], shape: impl Fn(f64) -> f64) {
        if w.len() == 1 {
            // N - 1 == 0 would make the phase 0/0 = NaN; a 1-point window is just [1.0].
            w[0] = 1.0;
            return;
        }
        let n1 = (w.len() - 1) as f64;
        for (n, v) in w.iter_mut().enumerate() {
            let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
            *v = shape(a);
        }
    }
    let n_fft = n_fft.get();
    let mut w = vec![0.0; n_fft];
    match window {
        WindowType::Rectangular => w.fill(1.0),
        WindowType::Hanning => fill_cosine_window(&mut w, |a| 0.5f64.mul_add(-a.cos(), 0.5)),
        WindowType::Hamming => fill_cosine_window(&mut w, |a| 0.46f64.mul_add(-a.cos(), 0.54)),
        WindowType::Blackman => fill_cosine_window(&mut w, |a| {
            0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42))
        }),
        WindowType::Kaiser { beta } => {
            if n_fft == 1 {
                w[0] = 1.0;
            } else {
                // I0(x) >= 1 for every real x, so `denom` is never zero; with n_fft >= 2,
                // `half` is at least 0.5.
                let denom = modified_bessel_i0(beta);
                let half = (n_fft - 1) as f64 / 2.0;
                for (i, value) in w.iter_mut().enumerate() {
                    let normalized = (i as f64 - half) / half;
                    let ratio = (1.0 - normalized * normalized).max(0.0);
                    *value = modified_bessel_i0(beta * ratio.sqrt()) / denom;
                }
            }
        }
        WindowType::Gaussian { std } => {
            let center = (n_fft - 1) as f64 / 2.0;
            for (i, value) in w.iter_mut().enumerate() {
                let exponent = -0.5 * ((i as f64 - center) / std).powi(2);
                *value = exponent.exp();
            }
        }
        WindowType::Custom { coefficients, size } => {
            assert!(
                size.get() == n_fft,
                "Custom window size mismatch: expected {}, got {}. \
                 Custom windows must be pre-computed with the exact FFT size.",
                n_fft,
                size.get()
            );
            w.copy_from_slice(&coefficients);
        }
    }
    // SAFETY: `w` has `n_fft >= 1` elements.
    unsafe { NonEmptyVec::new_unchecked(w) }
}
/// Approximates I0(x), the zeroth-order modified Bessel function of the first kind.
///
/// Uses the polynomial approximations of Abramowitz & Stegun: eq. 9.8.1 for |x| <= 3.75
/// and eq. 9.8.2 for |x| > 3.75 (relative error on the order of 1e-7). I0 is even, so
/// only |x| matters.
fn modified_bessel_i0(x: f64) -> f64 {
    let ax = x.abs();
    if ax <= 3.75 {
        let t = x / 3.75;
        let t2 = t * t;
        1.0 + t2
            * (3.515_622_9
                + t2 * (3.089_942_4
                    + t2 * (1.206_749_2
                        + t2 * (0.265_973_2 + t2 * (0.036_076_8 + t2 * 0.004_581_3)))))
    } else {
        let t = 3.75 / ax;
        let poly = 0.398_942_28
            + t * (0.013_285_92
                + t * (0.002_253_19
                    + t * (-0.001_575_65
                        + t * (0.009_162_81
                            + t * (-0.020_577_06
                                + t * (0.026_355_37 + t * (-0.016_476_33 + t * 0.003_923_77)))))));
        // A&S 9.8.2 approximates sqrt(x) * e^(-x) * I0(x) by `poly`. Its leading coefficient
        // 0.39894228 already equals 1/sqrt(2*pi), so no extra 2*pi factor may be applied
        // (doing so made I0 ~2.5x too small above 3.75 and discontinuous at the boundary).
        (ax.exp() / ax.sqrt()) * poly
    }
}
// Slaney mel-scale constants shared by `hz_to_mel` and `mel_to_hz`: linear below 1 kHz
// (200/3 Hz per mel), logarithmic above with a step of ln(6.4) / 27 per mel.
const SLANEY_F_MIN: f64 = 0.0;
const SLANEY_F_SP: f64 = 200.0 / 3.0;
const SLANEY_MIN_LOG_HZ: f64 = 1000.0;
const SLANEY_MIN_LOG_MEL: f64 = (SLANEY_MIN_LOG_HZ - SLANEY_F_MIN) / SLANEY_F_SP;
const SLANEY_LOGSTEP: f64 = 0.068_751_777_420_949_23;

/// Converts a frequency in Hz to Slaney mels.
fn hz_to_mel(hz: f64) -> f64 {
    let in_log_region = hz >= SLANEY_MIN_LOG_HZ;
    if in_log_region {
        let octaves_above = (hz / SLANEY_MIN_LOG_HZ).ln();
        SLANEY_MIN_LOG_MEL + octaves_above / SLANEY_LOGSTEP
    } else {
        (hz - SLANEY_F_MIN) / SLANEY_F_SP
    }
}

/// Converts Slaney mels back to Hz (inverse of [`hz_to_mel`]).
fn mel_to_hz(mel: f64) -> f64 {
    let in_log_region = mel >= SLANEY_MIN_LOG_MEL;
    if in_log_region {
        let exponent = SLANEY_LOGSTEP * (mel - SLANEY_MIN_LOG_MEL);
        SLANEY_MIN_LOG_HZ * exponent.exp()
    } else {
        SLANEY_F_SP.mul_add(mel, SLANEY_F_MIN)
    }
}
/// Builds a triangular mel filterbank as an `n_mels x r2c_output_size(n_fft)` sparse matrix.
///
/// Band edges are `n_mels + 2` points evenly spaced on the Slaney mel scale between
/// `f_min` and `f_max`. Filter `m` rises linearly from edge `m` to edge `m + 1` and falls
/// to edge `m + 2`. Each row is then normalised according to `norm`:
/// - `None`: no normalisation.
/// - `Slaney`: each row is scaled by `2 / (right edge - left edge)` in Hz.
/// - `L1` / `L2`: each non-zero row is scaled to unit L1 / L2 norm.
///
/// # Errors
/// Returns an invalid-input error if:
/// - `sample_rate_hz` is not finite and positive;
/// - `f_min` is negative or not finite (including NaN);
/// - `f_max` is NaN or not greater than `f_min`;
/// - `f_max` exceeds Nyquist.
fn build_mel_filterbank_matrix(
    sample_rate_hz: f64,
    n_fft: NonZeroUsize,
    n_mels: NonZeroUsize,
    f_min: f64,
    f_max: f64,
    norm: MelNorm,
) -> SpectrogramResult<SparseMatrix> {
    if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
        return Err(SpectrogramError::invalid_input(
            "sample_rate_hz must be finite and > 0",
        ));
    }
    // NaN fails every ordered comparison, so it must be rejected explicitly; otherwise it
    // slips through and produces an all-zero filterbank.
    if !f_min.is_finite() || f_min < 0.0 {
        return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
    }
    if f_max.is_nan() || f_max <= f_min {
        return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
    }
    if f_max > sample_rate_hz * 0.5 {
        return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
    }
    let n_mels = n_mels.get();
    let n_fft = n_fft.get();
    let out_len = r2c_output_size(n_fft);
    let df = sample_rate_hz / n_fft as f64;
    let mel_min = hz_to_mel(f_min);
    let mel_max = hz_to_mel(f_max);
    let n_points = n_mels + 2;
    let step = (mel_max - mel_min) / (n_points - 1) as f64;
    let mel_points: Vec<f64> = (0..n_points)
        .map(|i| (i as f64).mul_add(step, mel_min))
        .collect();
    let hz_points: Vec<f64> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
    let mut fb = SparseMatrix::new(n_mels, out_len);
    // `windows(3)` yields exactly `n_mels` (left, centre, right) edge triples.
    for (m, edges) in hz_points.windows(3).enumerate() {
        let (freq_left, freq_center, freq_right) = (edges[0], edges[1], edges[2]);
        let fdiff_left = freq_center - freq_left;
        let fdiff_right = freq_right - freq_center;
        if fdiff_left == 0.0 || fdiff_right == 0.0 {
            // Degenerate triangle; leave the row empty.
            continue;
        }
        for k in 0..out_len {
            let bin_freq = k as f64 * df;
            let lower = (bin_freq - freq_left) / fdiff_left;
            let upper = (freq_right - bin_freq) / fdiff_right;
            let weight = lower.min(upper).clamp(0.0, 1.0);
            if weight > 0.0 {
                fb.set(m, k, weight);
            }
        }
    }
    match norm {
        MelNorm::None => {}
        MelNorm::Slaney => {
            // `hz_points[i]` is exactly `mel_to_hz(mel_points[i])`, so reuse it.
            for (row, edges) in fb.values.iter_mut().zip(hz_points.windows(3)) {
                let enorm = 2.0 / (edges[2] - edges[0]);
                for val in row.iter_mut() {
                    *val *= enorm;
                }
            }
        }
        MelNorm::L1 => {
            for row in &mut fb.values {
                let sum: f64 = row.iter().sum();
                if sum > 0.0 {
                    let normalizer = 1.0 / sum;
                    for val in row.iter_mut() {
                        *val *= normalizer;
                    }
                }
            }
        }
        MelNorm::L2 => {
            for row in &mut fb.values {
                let norm_val: f64 = row.iter().map(|&v| v * v).sum::<f64>().sqrt();
                if norm_val > 0.0 {
                    let normalizer = 1.0 / norm_val;
                    for val in row.iter_mut() {
                        *val *= normalizer;
                    }
                }
            }
        }
    }
    Ok(fb)
}
/// Builds a sparse matrix that resamples the linear FFT bins onto `n_bins`
/// frequencies spaced evenly in `ln(f)` between `f_min` and `f_max`.
///
/// Each output row linearly interpolates between the two FFT bins that bracket
/// its target frequency. The target frequencies (Hz) are returned alongside the
/// matrix. With `n_bins == 1` the single target frequency is `f_min`, matching
/// `numpy.geomspace(f_min, f_max, 1)`.
///
/// # Errors
///
/// Returns an invalid-input error when:
/// - `sample_rate_hz` is not finite and positive;
/// - `f_min` is not finite and positive;
/// - `f_max` is NaN or not greater than `f_min`;
/// - `f_max` exceeds the Nyquist frequency.
fn build_loghz_matrix(
    sample_rate_hz: f64,
    n_fft: NonZeroUsize,
    n_bins: NonZeroUsize,
    f_min: f64,
    f_max: f64,
) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
    if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
        return Err(SpectrogramError::invalid_input(
            "sample_rate_hz must be finite and > 0",
        ));
    }
    // Written as a negated positive test so NaN is rejected too.
    if !(f_min > 0.0 && f_min.is_finite()) {
        return Err(SpectrogramError::invalid_input(
            "f_min must be finite and > 0",
        ));
    }
    // `f_max <= f_min` is false for NaN, so NaN needs its own test.
    if f_max.is_nan() || f_max <= f_min {
        return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
    }
    if f_max > sample_rate_hz * 0.5 {
        return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
    }

    let n_bins = n_bins.get();
    let n_fft = n_fft.get();
    let out_len = r2c_output_size(n_fft);
    // Spacing of the FFT bins in Hz.
    let df = sample_rate_hz / n_fft as f64;
    let log_f_min = f_min.ln();
    let log_f_max = f_max.ln();
    // With a single bin, dividing by (n_bins - 1) = 0 would give an infinite
    // step, and 0 * inf = NaN would poison the only frequency.
    let log_step = if n_bins > 1 {
        (log_f_max - log_f_min) / (n_bins - 1) as f64
    } else {
        0.0
    };
    let log_frequencies: Vec<f64> = (0..n_bins)
        .map(|i| (i as f64).mul_add(log_step, log_f_min).exp())
        .collect();
    // SAFETY: `n_bins` comes from a NonZeroUsize, so the vector is non-empty.
    let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };

    let mut matrix = SparseMatrix::new(n_bins, out_len);
    for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
        let exact_bin = target_freq / df;
        let lower_bin = exact_bin.floor() as usize;
        if lower_bin >= out_len {
            continue;
        }
        // Clamp so float rounding at the Nyquist edge never addresses a
        // non-existent bin. After clamping, `upper_bin < out_len` always holds.
        let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
        if lower_bin == upper_bin {
            matrix.set(bin_idx, lower_bin, 1.0);
        } else {
            let frac = exact_bin - lower_bin as f64;
            matrix.set(bin_idx, lower_bin, 1.0 - frac);
            matrix.set(bin_idx, upper_bin, frac);
        }
    }
    Ok((matrix, log_frequencies))
}
/// Centre frequencies (Hz) of `n_mels` mel bands spanning 0 Hz up to
/// `min(nyquist_hz, sample_rate_hz / 2)`.
///
/// The mel range is cut into `n_mels + 1` equal steps. The interior points
/// are returned, so neither end point is included.
fn mel_band_centres_hz(
    n_mels: NonZeroUsize,
    sample_rate_hz: f64,
    nyquist_hz: f64,
) -> NonEmptyVec<f64> {
    let upper_hz = nyquist_hz.min(sample_rate_hz * 0.5);
    let lower_mel = hz_to_mel(0.0);
    let upper_mel = hz_to_mel(upper_hz);
    let count = n_mels.get();
    let mel_step = (upper_mel - lower_mel) / (count + 1) as f64;

    let centres: Vec<f64> = (1..=count)
        .map(|k| mel_to_hz((k as f64).mul_add(mel_step, lower_mel)))
        .collect();
    // SAFETY: `count` comes from a NonZeroUsize, so at least one centre exists.
    unsafe { NonEmptyVec::new_unchecked(centres) }
}
/// A computed spectrogram: a 2-D data matrix together with its frequency and
/// time axes and the parameters associated with it.
///
/// `FreqScale` is a type-level marker for the frequency axis (e.g. [`LinearHz`],
/// [`Mel`]). `AmpScale` marks the amplitude representation (e.g. [`Power`],
/// [`Decibels`]). Neither is stored at runtime.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// Values shaped `(frequency bins, time frames)`; `new` debug-asserts this
    /// against the axes.
    data: Array2<f64>,
    /// Frequency and time axes matching `data`'s rows and columns.
    axes: Axes<FreqScale>,
    /// STFT settings and sample rate associated with this spectrogram.
    params: SpectrogramParams,
    /// Zero-sized amplitude-scale tag; skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    _amp: PhantomData<AmpScale>,
}
impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// Label for the time (x) axis.
    #[inline]
    #[must_use]
    pub const fn x_axis_label() -> &'static str {
        "Time (s)"
    }

    /// Label for the frequency (y) axis, selected from the `FreqScale` marker.
    ///
    /// Marker types other than the five built-in scales yield `"Frequency"`.
    #[inline]
    #[must_use]
    pub fn y_axis_label() -> &'static str {
        match std::any::TypeId::of::<FreqScale>() {
            id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
            id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
            id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
            id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
            id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
            _ => "Frequency",
        }
    }

    /// Assembles a spectrogram from precomputed parts.
    ///
    /// In debug builds this asserts that `data` is shaped
    /// `(frequencies.len(), times.len())`.
    pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
        debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
        debug_assert_eq!(data.ncols(), axes.times().len().get());
        Self {
            data,
            axes,
            params,
            _amp: PhantomData,
        }
    }

    /// Replaces the data matrix.
    ///
    /// The axes are left untouched and the new shape is not validated.
    /// NOTE(review): callers are presumably expected to keep the shape
    /// consistent with the axes.
    #[inline]
    pub fn set_data(&mut self, data: Array2<f64>) {
        self.data = data;
    }

    /// Borrows the data matrix.
    #[inline]
    #[must_use]
    pub const fn data(&self) -> &Array2<f64> {
        &self.data
    }

    /// Consumes the spectrogram and returns the data matrix.
    #[inline]
    #[must_use]
    pub fn into_data(self) -> Array2<f64> {
        self.data
    }

    /// Borrows the frequency/time axes.
    #[inline]
    #[must_use]
    pub const fn axes(&self) -> &Axes<FreqScale> {
        &self.axes
    }

    /// Frequencies of the rows of the data matrix.
    #[inline]
    #[must_use]
    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
        self.axes.frequencies()
    }

    /// `(first, last)` frequency of the frequency axis.
    #[inline]
    #[must_use]
    pub const fn frequency_range(&self) -> (f64, f64) {
        self.axes.frequency_range()
    }

    /// Times of the columns of the data matrix.
    #[inline]
    #[must_use]
    pub fn times(&self) -> &NonEmptySlice<f64> {
        self.axes.times()
    }

    /// Parameters associated with this spectrogram.
    #[inline]
    #[must_use]
    pub const fn params(&self) -> &SpectrogramParams {
        &self.params
    }

    /// Last value of the time axis.
    #[inline]
    #[must_use]
    pub fn duration(&self) -> f64 {
        self.axes.duration()
    }

    /// Range `(min, max)` of the data expressed in decibels.
    ///
    /// - `Decibels` data is reported as-is.
    /// - `Power` values `p` are mapped to `10·log10(p + EPS)`.
    /// - `Magnitude` values `m` are mapped to `10·log10(m² + EPS)`; NaN entries
    ///   are skipped in this and the `Power` conversion.
    /// - Any other amplitude scale yields `None`.
    #[inline]
    #[must_use]
    pub fn db_range(&self) -> Option<(f64, f64)> {
        // Running (min, max) that skips NaN, because NaN fails both comparisons.
        fn extent(values: impl Iterator<Item = f64>) -> (f64, f64) {
            values.fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), v| {
                (if v < lo { v } else { lo }, if v > hi { v } else { hi })
            })
        }

        let type_self = std::any::TypeId::of::<AmpScale>();
        if type_self == std::any::TypeId::of::<Decibels>() {
            // Min/max ignore element order, so any contiguous layout (C or
            // Fortran) takes the fast path. Strided data is iterated instead of
            // being reported as `None`.
            Some(match self.data.as_slice_memory_order() {
                Some(values) => min_max_single_pass(values),
                None => extent(self.data.iter().copied()),
            })
        } else if type_self == std::any::TypeId::of::<Power>() {
            Some(extent(self.data.iter().map(|&p| 10.0 * (p + EPS).log10())))
        } else if type_self == std::any::TypeId::of::<Magnitude>() {
            Some(extent(
                self.data.iter().map(|&m| 10.0 * (m * m + EPS).log10()),
            ))
        } else {
            None
        }
    }

    /// Number of frequency bins (rows of the data matrix).
    ///
    /// # Panics
    ///
    /// Panics if the data matrix has no rows, e.g. after an empty array was
    /// passed to [`set_data`](Self::set_data).
    #[inline]
    #[must_use]
    pub fn n_bins(&self) -> NonZeroUsize {
        // Checked on purpose: `set_data` is safe and accepts any shape, so an
        // unchecked conversion of 0 here would be undefined behaviour.
        NonZeroUsize::new(self.data.nrows())
            .expect("spectrogram data must have at least one frequency bin")
    }

    /// Number of time frames (columns of the data matrix).
    ///
    /// # Panics
    ///
    /// Panics if the data matrix has no columns, e.g. after an empty array was
    /// passed to [`set_data`](Self::set_data).
    #[inline]
    #[must_use]
    pub fn n_frames(&self) -> NonZeroUsize {
        NonZeroUsize::new(self.data.ncols())
            .expect("spectrogram data must have at least one time frame")
    }
}
/// Exposes the spectrogram's data matrix through the standard `AsRef` trait.
impl<FreqScale, AmpScale> AsRef<Array2<f64>> for Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// Borrows the underlying data matrix.
    #[inline]
    fn as_ref(&self) -> &Array2<f64> {
        let Self { data, .. } = self;
        data
    }
}
/// Lets a spectrogram be used wherever an `&Array2<f64>` is expected.
impl<FreqScale, AmpScale> Deref for Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    type Target = Array2<f64>;

    /// Dereferences to the data matrix.
    #[inline]
    fn deref(&self) -> &Array2<f64> {
        &self.data
    }
}
impl<AmpScale> Spectrogram<LinearHz, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
{
    /// Computes a linear-frequency spectrogram of `samples` in one call.
    ///
    /// A fresh [`SpectrogramPlanner`] builds a linear plan from `params` (and
    /// `db`, when given), which is run once on `samples`.
    ///
    /// # Errors
    ///
    /// Propagates any error from plan construction or computation.
    #[inline]
    pub fn compute(
        samples: &NonEmptySlice<f64>,
        params: &SpectrogramParams,
        db: Option<&LogParams>,
    ) -> SpectrogramResult<Self> {
        let builder = SpectrogramPlanner::new();
        let mut linear = builder.linear_plan(params, db)?;
        linear.compute(samples)
    }
}
impl<AmpScale> Spectrogram<Mel, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
{
    /// Computes a mel spectrogram of `samples` in one call.
    ///
    /// A fresh [`SpectrogramPlanner`] builds a mel plan from `params`, `mel`
    /// and the optional `db` settings, which is run once on `samples`.
    ///
    /// # Errors
    ///
    /// Propagates any error from plan construction or computation.
    #[inline]
    pub fn compute(
        samples: &NonEmptySlice<f64>,
        params: &SpectrogramParams,
        mel: &MelParams,
        db: Option<&LogParams>,
    ) -> SpectrogramResult<Self> {
        let builder = SpectrogramPlanner::new();
        let mut mel_plan = builder.mel_plan(params, mel, db)?;
        mel_plan.compute(samples)
    }
}
impl<AmpScale> Spectrogram<Erb, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
{
    /// Computes an ERB (gammatone-style) spectrogram of `samples` in one call.
    ///
    /// A fresh [`SpectrogramPlanner`] builds an ERB plan from `params`, `erb`
    /// and the optional `db` settings, which is run once on `samples`.
    ///
    /// # Errors
    ///
    /// Propagates any error from plan construction or computation.
    #[inline]
    pub fn compute(
        samples: &NonEmptySlice<f64>,
        params: &SpectrogramParams,
        erb: &ErbParams,
        db: Option<&LogParams>,
    ) -> SpectrogramResult<Self> {
        let builder = SpectrogramPlanner::new();
        let mut erb_plan = builder.erb_plan(params, erb, db)?;
        erb_plan.compute(samples)
    }
}
impl<AmpScale> Spectrogram<LogHz, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
{
    /// Computes a log-frequency spectrogram of `samples` in one call.
    ///
    /// A fresh [`SpectrogramPlanner`] builds a log-Hz plan from `params`,
    /// `loghz` and the optional `db` settings, which is run once on `samples`.
    ///
    /// # Errors
    ///
    /// Propagates any error from plan construction or computation.
    #[inline]
    pub fn compute(
        samples: &NonEmptySlice<f64>,
        params: &SpectrogramParams,
        loghz: &LogHzParams,
        db: Option<&LogParams>,
    ) -> SpectrogramResult<Self> {
        let builder = SpectrogramPlanner::new();
        let mut log_plan = builder.log_hz_plan(params, loghz, db)?;
        log_plan.compute(samples)
    }
}
impl<AmpScale> Spectrogram<Cqt, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
{
    /// Computes a constant-Q spectrogram of `samples` in one call.
    ///
    /// A fresh [`SpectrogramPlanner`] builds a CQT plan from `params`, `cqt`
    /// and the optional `db` settings, which is run once on `samples`.
    ///
    /// # Errors
    ///
    /// Propagates any error from plan construction or computation.
    #[inline]
    pub fn compute(
        samples: &NonEmptySlice<f64>,
        params: &SpectrogramParams,
        cqt: &CqtParams,
        db: Option<&LogParams>,
    ) -> SpectrogramResult<Self> {
        let builder = SpectrogramPlanner::new();
        let mut cqt_plan = builder.cqt_plan(params, cqt, db)?;
        cqt_plan.compute(samples)
    }
}
/// Human-readable name of an amplitude-scale marker type.
///
/// Returns `"Unknown"` for types other than `Power`, `Magnitude` and `Decibels`.
fn amp_scale_name<AmpScale>() -> &'static str
where
    AmpScale: AmpScaleSpec + 'static,
{
    use std::any::TypeId;

    let id = TypeId::of::<AmpScale>();
    if id == TypeId::of::<Power>() {
        "Power"
    } else if id == TypeId::of::<Magnitude>() {
        "Magnitude"
    } else if id == TypeId::of::<Decibels>() {
        "Decibels"
    } else {
        "Unknown"
    }
}
/// Human-readable name of a frequency-scale marker type.
///
/// Returns `"Unknown"` for types other than the five built-in scales.
fn freq_scale_name<FreqScale>() -> &'static str
where
    FreqScale: Copy + Clone + 'static,
{
    use std::any::TypeId;

    let id = TypeId::of::<FreqScale>();
    let table = [
        (TypeId::of::<LinearHz>(), "Linear Hz"),
        (TypeId::of::<LogHz>(), "Log Hz"),
        (TypeId::of::<Mel>(), "Mel"),
        (TypeId::of::<Erb>(), "ERB"),
        (TypeId::of::<Cqt>(), "CQT"),
    ];
    table
        .into_iter()
        .find_map(|(candidate, name)| (candidate == id).then_some(name))
        .unwrap_or("Unknown")
}
impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// The plain form (`{}`) is a one-line summary: scales, shape, frequency
    /// range and duration.
    ///
    /// The alternate form (`{:#}`) is a multi-line report with:
    /// - the STFT parameters;
    /// - min/max/mean statistics of the data;
    /// - the top-left corner (at most 5×5) of the data matrix.
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let (freq_min, freq_max) = self.frequency_range();
        let duration = self.duration();
        let (rows, cols) = self.data.dim();
        if f.alternate() {
            writeln!(f, "Spectrogram {{")?;
            writeln!(f, " Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
            writeln!(f, " Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
            writeln!(f, " Shape: {rows} frequency bins × {cols} time frames")?;
            writeln!(f, " Frequency Range: {freq_min:.2} Hz - {freq_max:.2} Hz")?;
            writeln!(f, " Duration: {duration:.3} s")?;
            writeln!(f)?;
            writeln!(f, " Parameters:")?;
            writeln!(f, " Sample Rate: {} Hz", self.params.sample_rate_hz())?;
            writeln!(f, " FFT Size: {}", self.params.stft().n_fft())?;
            writeln!(f, " Hop Size: {}", self.params.stft().hop_size())?;
            writeln!(f, " Window: {:?}", self.params.stft().window())?;
            writeln!(f, " Centered: {}", self.params.stft().centre())?;
            writeln!(f)?;
            // Min/max/mean do not depend on element order, so read the buffer
            // in memory order: Fortran-layout data then gets statistics as
            // well. Only genuinely strided arrays skip this section.
            let data_slice = self.data.as_slice_memory_order().unwrap_or(&[]);
            if !data_slice.is_empty() {
                let (min_val, max_val) = min_max_single_pass(data_slice);
                let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
                writeln!(f, " Data Statistics:")?;
                writeln!(f, " Min: {min_val:.6}")?;
                writeln!(f, " Max: {max_val:.6}")?;
                writeln!(f, " Mean: {mean:.6}")?;
                writeln!(f)?;
            }
            writeln!(f, " Data Matrix:")?;
            let max_rows_to_display = 5;
            let max_cols_to_display = 5;
            for i in 0..rows.min(max_rows_to_display) {
                write!(f, " [")?;
                for j in 0..cols.min(max_cols_to_display) {
                    if j > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{:9.4}", self.data[[i, j]])?;
                }
                if cols > max_cols_to_display {
                    write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
                }
                writeln!(f, "]")?;
            }
            if rows > max_rows_to_display {
                writeln!(f, " ... ({} more rows)", rows - max_rows_to_display)?;
            }
            write!(f, "}}")?;
        } else {
            write!(
                f,
                "Spectrogram<{}, {}>[{}x{}] ({:.2}-{:.2} Hz, {:.3}s)",
                freq_scale_name::<FreqScale>(),
                amp_scale_name::<AmpScale>(),
                rows,
                cols,
                freq_min,
                freq_max,
                duration
            )?;
        }
        Ok(())
    }
}
/// The frequency axis of a spectrogram, tagged with its frequency-scale marker.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// One frequency per spectrogram row; never empty.
    frequencies: NonEmptyVec<f64>,
    /// Zero-sized scale tag; skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<FreqScale>,
}
impl<FreqScale> FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Wraps a non-empty list of bin frequencies.
    pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
        Self {
            _marker: PhantomData,
            frequencies,
        }
    }

    /// Frequency of every bin, in axis order.
    #[inline]
    #[must_use]
    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
        &self.frequencies
    }

    /// First and last frequency, returned as `(min, max)`.
    ///
    /// The values are the true extremes only when the axis is sorted ascending.
    #[inline]
    #[must_use]
    pub const fn frequency_range(&self) -> (f64, f64) {
        let freqs = self.frequencies.as_slice();
        (freqs[0], freqs[freqs.len().saturating_sub(1)])
    }

    /// Number of frequency bins.
    #[inline]
    #[must_use]
    pub const fn len(&self) -> NonZeroUsize {
        self.frequencies.len()
    }
}
/// Frequency and time axes of a spectrogram.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// One frequency per spectrogram row.
    freq: FrequencyAxis<FreqScale>,
    /// One time value per spectrogram column; never empty.
    times: NonEmptyVec<f64>,
}
impl<FreqScale> Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Pairs a frequency axis with a non-empty time axis.
    pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
        Self { times, freq }
    }

    /// Frequency of every row.
    #[inline]
    #[must_use]
    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
        let axis = &self.freq;
        axis.frequencies()
    }

    /// Time of every column.
    #[inline]
    #[must_use]
    pub fn times(&self) -> &NonEmptySlice<f64> {
        &self.times
    }

    /// `(first, last)` frequency of the frequency axis.
    #[inline]
    #[must_use]
    pub const fn frequency_range(&self) -> (f64, f64) {
        let axis = &self.freq;
        axis.frequency_range()
    }

    /// Last value on the time axis.
    ///
    /// NOTE(review): this is the final column's timestamp, which is not
    /// necessarily the full length of the analysed signal.
    #[inline]
    #[must_use]
    pub fn duration(&self) -> f64 {
        let last_time = self.times.last();
        *last_time
    }
}
/// Frequency-scale marker: linearly spaced FFT bins in Hz.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Frequency-scale marker: logarithmically spaced frequencies.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Frequency-scale marker: mel-spaced filterbank bands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Frequency-scale marker: ERB-spaced bands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Alias of [`Erb`] for gammatone-filterbank terminology.
pub type Gammatone = Erb;
/// Frequency-scale marker: constant-Q transform bins.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Amplitude-scale marker: power values (`db_range` treats them as `p`, i.e.
/// `10·log10(p)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Amplitude-scale marker: values already expressed in decibels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Amplitude-scale marker: magnitude values (`db_range` squares them before
/// converting to dB).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    /// Placeholder variant; the type is used as a type-level tag.
    _Phantom,
}
/// Short-time Fourier transform settings.
///
/// Invariants enforced by [`StftParams::new`]: `hop_size <= n_fft`, and a
/// `WindowType::Custom` window has exactly `n_fft` coefficients.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StftParams {
    /// FFT length (samples per frame).
    n_fft: NonZeroUsize,
    /// Distance between consecutive frames, in samples.
    hop_size: NonZeroUsize,
    /// Analysis window applied to each frame.
    window: WindowType,
    /// Whether frames are centred on their time stamps (padding semantics are
    /// implemented elsewhere).
    centre: bool,
}
impl StftParams {
    /// Validated STFT settings.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `hop_size > n_fft`, or if a
    /// `WindowType::Custom` window's size differs from `n_fft`.
    #[inline]
    pub fn new(
        n_fft: NonZeroUsize,
        hop_size: NonZeroUsize,
        window: WindowType,
        centre: bool,
    ) -> SpectrogramResult<Self> {
        let fft_len = n_fft.get();
        if hop_size.get() > fft_len {
            return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
        }
        if let WindowType::Custom { size, .. } = &window {
            let custom_len = size.get();
            if custom_len != fft_len {
                return Err(SpectrogramError::invalid_input(format!(
                    "Custom window size ({}) must match n_fft ({})",
                    custom_len, fft_len
                )));
            }
        }
        Ok(Self {
            n_fft,
            hop_size,
            window,
            centre,
        })
    }

    /// Builds STFT settings without validation.
    ///
    /// # Safety
    ///
    /// The caller must uphold the invariants that [`StftParams::new`] checks:
    /// `hop_size <= n_fft`, and a `WindowType::Custom` window of size `n_fft`.
    const unsafe fn new_unchecked(
        n_fft: NonZeroUsize,
        hop_size: NonZeroUsize,
        window: WindowType,
        centre: bool,
    ) -> Self {
        Self {
            n_fft,
            hop_size,
            window,
            centre,
        }
    }

    /// FFT length in samples.
    #[inline]
    #[must_use]
    pub const fn n_fft(&self) -> NonZeroUsize {
        self.n_fft
    }

    /// Hop between frames in samples.
    #[inline]
    #[must_use]
    pub const fn hop_size(&self) -> NonZeroUsize {
        self.hop_size
    }

    /// A clone of the analysis window description.
    #[inline]
    #[must_use]
    pub fn window(&self) -> WindowType {
        let window = &self.window;
        window.clone()
    }

    /// Whether frames are centred.
    #[inline]
    #[must_use]
    pub const fn centre(&self) -> bool {
        self.centre
    }

    /// A builder with a `WindowType::Hanning` window and centring enabled.
    #[inline]
    #[must_use]
    pub fn builder() -> StftParamsBuilder {
        <StftParamsBuilder as Default>::default()
    }
}
/// Builder for [`StftParams`]; `n_fft` and `hop_size` must be set before
/// [`StftParamsBuilder::build`].
#[derive(Debug, Clone)]
pub struct StftParamsBuilder {
    /// FFT length; required.
    n_fft: Option<NonZeroUsize>,
    /// Hop size; required.
    hop_size: Option<NonZeroUsize>,
    /// Analysis window (defaults to `WindowType::Hanning`).
    window: WindowType,
    /// Frame centring (defaults to `true`).
    centre: bool,
}
impl Default for StftParamsBuilder {
    /// A `WindowType::Hanning` window, centred frames, and no FFT length or
    /// hop size chosen yet.
    #[inline]
    fn default() -> Self {
        Self {
            window: WindowType::Hanning,
            centre: true,
            n_fft: None,
            hop_size: None,
        }
    }
}
impl StftParamsBuilder {
    /// Sets the FFT length.
    #[inline]
    #[must_use]
    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
        self.n_fft = Some(n_fft);
        self
    }

    /// Sets the hop size.
    #[inline]
    #[must_use]
    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
        self.hop_size = Some(hop_size);
        self
    }

    /// Sets the analysis window.
    #[inline]
    #[must_use]
    pub fn window(mut self, window: WindowType) -> Self {
        self.window = window;
        self
    }

    /// Sets whether frames are centred.
    #[inline]
    #[must_use]
    pub const fn centre(mut self, centre: bool) -> Self {
        self.centre = centre;
        self
    }

    /// Validates and builds the [`StftParams`].
    ///
    /// # Errors
    ///
    /// Fails if `n_fft` or `hop_size` was never set (checked in that order), or
    /// if [`StftParams::new`] rejects the combination.
    #[inline]
    pub fn build(self) -> SpectrogramResult<StftParams> {
        let Self {
            n_fft,
            hop_size,
            window,
            centre,
        } = self;
        let n_fft =
            n_fft.ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
        let hop_size =
            hop_size.ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
        StftParams::new(n_fft, hop_size, window, centre)
    }
}
/// Per-band normalisation applied to mel filterbank weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
#[derive(Default)]
pub enum MelNorm {
    /// Unnormalised triangles with weights in `[0, 1]` (the default).
    #[default]
    None,
    /// Scale each band by `2 / (right edge − left edge)` in Hz, roughly
    /// equalising the area under each triangle.
    Slaney,
    /// Scale each band so its weights sum to 1.
    L1,
    /// Scale each band to unit Euclidean (L2) norm.
    L2,
}
/// Mel filterbank settings.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MelParams {
    /// Number of mel bands.
    n_mels: NonZeroUsize,
    /// Lower band edge in Hz (`>= 0`).
    f_min: f64,
    /// Upper band edge in Hz (`> f_min`).
    f_max: f64,
    /// Per-band weight normalisation.
    norm: MelNorm,
}
impl MelParams {
    /// Validated mel parameters with no weight normalisation ([`MelNorm::None`]).
    ///
    /// # Errors
    ///
    /// Same as [`MelParams::with_norm`].
    #[inline]
    pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
        Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
    }

    /// Validated mel parameters with an explicit normalisation mode.
    ///
    /// `f_max` is not compared against any sample rate here.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `f_min` is negative or NaN, or if
    /// `f_max` is NaN or not greater than `f_min`.
    #[inline]
    pub fn with_norm(
        n_mels: NonZeroUsize,
        f_min: f64,
        f_max: f64,
        norm: MelNorm,
    ) -> SpectrogramResult<Self> {
        // NaN compares false against everything, so `f_min < 0.0` and
        // `f_max <= f_min` alone would let NaN through.
        if f_min.is_nan() || f_min < 0.0 {
            return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
        }
        if f_max.is_nan() || f_max <= f_min {
            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
        }
        Ok(Self {
            n_mels,
            f_min,
            f_max,
            norm,
        })
    }

    /// Mel parameters built without validation, using [`MelNorm::None`].
    ///
    /// # Safety
    ///
    /// The caller must ensure `0 <= f_min < f_max` with neither value NaN, as
    /// [`MelParams::with_norm`] would otherwise check.
    pub const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
        Self {
            n_mels,
            f_min,
            f_max,
            norm: MelNorm::None,
        }
    }

    /// Number of mel bands.
    #[inline]
    #[must_use]
    pub const fn n_mels(&self) -> NonZeroUsize {
        self.n_mels
    }

    /// Lower band edge in Hz.
    #[inline]
    #[must_use]
    pub const fn f_min(&self) -> f64 {
        self.f_min
    }

    /// Upper band edge in Hz.
    #[inline]
    #[must_use]
    pub const fn f_max(&self) -> f64 {
        self.f_max
    }

    /// Per-band weight normalisation.
    #[inline]
    #[must_use]
    pub const fn norm(&self) -> MelNorm {
        self.norm
    }

    /// 128 bands from 0 Hz to `sample_rate / 2`, without normalisation.
    ///
    /// # Panics
    ///
    /// Panics unless `sample_rate > 0.0` (NaN included).
    #[inline]
    #[must_use]
    pub const fn standard(sample_rate: f64) -> Self {
        assert!(sample_rate > 0.0);
        // SAFETY: sample_rate > 0 was asserted, so 0 < sample_rate / 2.
        unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
    }

    /// 40 bands from 0 Hz to 8 kHz, without normalisation.
    #[inline]
    #[must_use]
    pub const fn speech_standard() -> Self {
        // SAFETY: 0.0 < 8000.0.
        unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
    }
}
/// Settings for a logarithmically spaced frequency axis.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogHzParams {
    /// Number of output frequencies.
    n_bins: NonZeroUsize,
    /// Lowest frequency in Hz (finite, `> 0`).
    f_min: f64,
    /// Highest frequency in Hz (`> f_min`).
    f_max: f64,
}
impl LogHzParams {
    /// Validated log-frequency parameters.
    ///
    /// `f_max` is not compared against any sample rate here.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `f_min` is not finite and positive, or
    /// if `f_max` is NaN or not greater than `f_min`.
    #[inline]
    pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
        if !(f_min > 0.0 && f_min.is_finite()) {
            return Err(SpectrogramError::invalid_input(
                "f_min must be finite and > 0",
            ));
        }
        // `f_max <= f_min` is false for NaN, so NaN needs its own test.
        if f_max.is_nan() || f_max <= f_min {
            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
        }
        Ok(Self {
            n_bins,
            f_min,
            f_max,
        })
    }

    /// Builds the parameters without validation.
    ///
    /// # Safety
    ///
    /// The caller must ensure `f_min` is finite and positive and that
    /// `f_max > f_min`, as [`LogHzParams::new`] would otherwise check.
    const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
        Self {
            n_bins,
            f_min,
            f_max,
        }
    }

    /// Number of output frequencies.
    #[inline]
    #[must_use]
    pub const fn n_bins(&self) -> NonZeroUsize {
        self.n_bins
    }

    /// Lowest frequency in Hz.
    #[inline]
    #[must_use]
    pub const fn f_min(&self) -> f64 {
        self.f_min
    }

    /// Highest frequency in Hz.
    #[inline]
    #[must_use]
    pub const fn f_max(&self) -> f64 {
        self.f_max
    }

    /// 128 frequencies from 20 Hz to `sample_rate / 2`.
    ///
    /// The range is only valid when `sample_rate > 40.0`; this is not checked
    /// here.
    #[inline]
    #[must_use]
    pub fn standard(sample_rate: f64) -> Self {
        // SAFETY: f_min = 20 is finite and positive. f_max > f_min holds only
        // for sample_rate > 40, which the caller is relied upon to provide
        // (NOTE(review): unchecked).
        unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
    }

    /// 84 frequencies from 27.5 Hz to 4186 Hz.
    #[inline]
    #[must_use]
    pub const fn music_standard() -> Self {
        // SAFETY: 0 < 27.5 < 4186.
        unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
    }
}
/// Decibel-conversion settings.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogParams {
    /// Floor value in dB; presumably the lower clamp applied during dB
    /// conversion (applied elsewhere — confirm against the planner).
    floor_db: f64,
}
impl LogParams {
    /// Validated dB settings.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `floor_db` is NaN or infinite.
    #[inline]
    pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
        if floor_db.is_finite() {
            Ok(Self { floor_db })
        } else {
            Err(SpectrogramError::invalid_input("floor_db must be finite"))
        }
    }

    /// Builds the settings without checking `floor_db`.
    ///
    /// Despite its name this is a safe function; a non-finite floor is stored
    /// as-is.
    #[inline]
    pub fn new_unchecked(floor_db: f64) -> Self {
        Self { floor_db }
    }

    /// The floor value in dB.
    #[inline]
    #[must_use]
    pub const fn floor_db(&self) -> f64 {
        self.floor_db
    }
}
/// STFT settings together with the sample rate of the analysed signal.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpectrogramParams {
    /// Frame/FFT/window settings.
    pub(crate) stft: StftParams,
    /// Sample rate in Hz (finite, `> 0` when built via `new`).
    pub(crate) sample_rate_hz: f64,
}
impl SpectrogramParams {
    /// Combines STFT settings with a validated sample rate.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error unless `sample_rate_hz` is finite and
    /// positive.
    #[inline]
    pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
        let rate_ok = sample_rate_hz > 0.0 && sample_rate_hz.is_finite();
        if rate_ok {
            Ok(Self {
                stft,
                sample_rate_hz,
            })
        } else {
            Err(SpectrogramError::invalid_input(
                "sample_rate_hz must be finite and > 0",
            ))
        }
    }

    /// Combines STFT settings with a sample rate without validating it.
    #[inline]
    pub fn new_unchecked(stft: StftParams, sample_rate_hz: f64) -> Self {
        Self {
            stft,
            sample_rate_hz,
        }
    }

    /// A builder with a `WindowType::Hanning` window and centring enabled.
    #[inline]
    #[must_use]
    pub fn builder() -> SpectrogramParamsBuilder {
        <SpectrogramParamsBuilder as Default>::default()
    }

    /// 512-point FFT, hop 160, `WindowType::Hanning`, centred frames.
    ///
    /// # Errors
    ///
    /// Fails if `sample_rate_hz` is not finite and positive.
    #[inline]
    pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
        Self::new(
            // SAFETY: hop 160 <= n_fft 512 and the window is not Custom.
            unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) },
            sample_rate_hz,
        )
    }

    /// 2048-point FFT, hop 512, `WindowType::Hanning`, centred frames.
    ///
    /// # Errors
    ///
    /// Fails if `sample_rate_hz` is not finite and positive.
    #[inline]
    pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
        Self::new(
            // SAFETY: hop 512 <= n_fft 2048 and the window is not Custom.
            unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) },
            sample_rate_hz,
        )
    }

    /// The STFT settings.
    #[inline]
    #[must_use]
    pub const fn stft(&self) -> &StftParams {
        &self.stft
    }

    /// Sample rate in Hz.
    #[inline]
    #[must_use]
    pub const fn sample_rate_hz(&self) -> f64 {
        self.sample_rate_hz
    }

    /// Seconds between consecutive frames (`hop_size / sample_rate_hz`).
    #[inline]
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn frame_period_seconds(&self) -> f64 {
        let hop = self.stft.hop_size().get() as f64;
        hop / self.sample_rate_hz
    }

    /// Half the sample rate, in Hz.
    #[inline]
    #[must_use]
    pub fn nyquist_hz(&self) -> f64 {
        0.5 * self.sample_rate_hz
    }
}
/// Builder for [`SpectrogramParams`]; `sample_rate`, `n_fft` and `hop_size`
/// must be set before [`SpectrogramParamsBuilder::build`].
#[derive(Debug, Clone)]
pub struct SpectrogramParamsBuilder {
    /// Sample rate in Hz; required.
    sample_rate: Option<f64>,
    /// FFT length; required.
    n_fft: Option<NonZeroUsize>,
    /// Hop size; required.
    hop_size: Option<NonZeroUsize>,
    /// Analysis window (defaults to `WindowType::Hanning`).
    window: WindowType,
    /// Frame centring (defaults to `true`).
    centre: bool,
}
impl Default for SpectrogramParamsBuilder {
    /// A `WindowType::Hanning` window, centred frames, and no sample rate, FFT
    /// length or hop size chosen yet.
    #[inline]
    fn default() -> Self {
        Self {
            window: WindowType::Hanning,
            centre: true,
            sample_rate: None,
            n_fft: None,
            hop_size: None,
        }
    }
}
impl SpectrogramParamsBuilder {
    /// Sets the sample rate in Hz.
    #[inline]
    #[must_use]
    pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
        self.sample_rate = Some(sample_rate);
        self
    }

    /// Sets the FFT length.
    #[inline]
    #[must_use]
    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
        self.n_fft = Some(n_fft);
        self
    }

    /// Sets the hop size.
    #[inline]
    #[must_use]
    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
        self.hop_size = Some(hop_size);
        self
    }

    /// Sets the analysis window.
    #[inline]
    #[must_use]
    pub fn window(mut self, window: WindowType) -> Self {
        self.window = window;
        self
    }

    /// Sets whether frames are centred.
    #[inline]
    #[must_use]
    pub const fn centre(mut self, centre: bool) -> Self {
        self.centre = centre;
        self
    }

    /// Validates and builds the [`SpectrogramParams`].
    ///
    /// # Errors
    ///
    /// Fails if `sample_rate`, `n_fft` or `hop_size` is missing (checked in
    /// that order), or if [`StftParams::new`] / [`SpectrogramParams::new`]
    /// reject the values.
    #[inline]
    pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
        let Self {
            sample_rate,
            n_fft,
            hop_size,
            window,
            centre,
        } = self;
        let sample_rate = sample_rate
            .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
        let n_fft =
            n_fft.ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
        let hop_size =
            hop_size.ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
        let stft = StftParams::new(n_fft, hop_size, window, centre)?;
        SpectrogramParams::new(stft, sample_rate)
    }

    /// Builds the parameters without any checks.
    ///
    /// # Safety
    ///
    /// `sample_rate`, `n_fft` and `hop_size` must all have been set (otherwise
    /// this is undefined behaviour). The caller must also uphold the
    /// invariants of `StftParams::new_unchecked`. The sample rate is not
    /// validated.
    #[inline]
    pub unsafe fn build_unchecked(self) -> SpectrogramParams {
        let Self {
            sample_rate,
            n_fft,
            hop_size,
            window,
            centre,
        } = self;
        // SAFETY: the caller guarantees every option is `Some` and that the
        // STFT invariants hold.
        let stft = unsafe {
            StftParams::new_unchecked(
                n_fft.unwrap_unchecked(),
                hop_size.unwrap_unchecked(),
                window,
                centre,
            )
        };
        // SAFETY: as above, `sample_rate` is guaranteed to be set.
        let sample_rate = unsafe { sample_rate.unwrap_unchecked() };
        SpectrogramParams::new_unchecked(stft, sample_rate)
    }
}
#[inline]
pub fn fft(
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
) -> SpectrogramResult<Array1<Complex<f64>>> {
if samples.len() > n_fft {
return Err(SpectrogramError::invalid_input(format!(
"Input length ({}) exceeds FFT size ({})",
samples.len(),
n_fft
)));
}
let out_len = r2c_output_size(n_fft.get());
#[cfg(feature = "realfft")]
let mut fft = {
use crate::fft_backend::get_or_create_r2c_plan;
let plan = get_or_create_r2c_plan(n_fft.get())?;
(*plan).clone()
};
#[cfg(feature = "fftw")]
let mut fft = {
use std::sync::Arc;
let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
crate::FftwPlan::new(Arc::new(plan))
};
let input = if samples.len() < n_fft {
let mut padded = vec![0.0; n_fft.get()];
padded[..samples.len().get()].copy_from_slice(samples);
unsafe { NonEmptyVec::new_unchecked(padded) }
} else {
samples.to_non_empty_vec()
};
let mut output = vec![Complex::new(0.0, 0.0); out_len];
fft.process(&input, &mut output)?;
let output = Array1::from_vec(output);
Ok(output)
}
/// Magnitude spectrum `|FFT(samples)|`, zero-padded to `n_fft`.
///
/// Despite the name, this returns real magnitudes rather than complex bins.
///
/// # Errors
///
/// Same as [`fft`].
#[inline]
pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
    fft(samples, n_fft).map(|bins| bins.mapv(Complex::norm))
}
#[inline]
pub fn power_spectrum(
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
window: Option<WindowType>,
) -> SpectrogramResult<NonEmptyVec<f64>> {
if samples.len() > n_fft {
return Err(SpectrogramError::invalid_input(format!(
"Input length ({}) exceeds FFT size ({})",
samples.len(),
n_fft
)));
}
let mut windowed = vec![0.0; n_fft.get()];
windowed[..samples.len().get()].copy_from_slice(samples);
if let Some(win_type) = window {
let window_samples = make_window(win_type, n_fft);
for i in 0..n_fft.get() {
windowed[i] *= window_samples[i];
}
}
let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
let fft_result = fft(windowed, n_fft)?;
let fft_result = fft_result
.iter()
.map(num_complex::Complex::norm_sqr)
.collect();
Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
}
/// Magnitude spectrum `|FFT(x)|`: the element-wise square root of
/// [`power_spectrum`].
///
/// # Errors
///
/// Same as [`power_spectrum`].
#[inline]
pub fn magnitude_spectrum(
    samples: &NonEmptySlice<f64>,
    n_fft: NonZeroUsize,
    window: Option<WindowType>,
) -> SpectrogramResult<NonEmptyVec<f64>> {
    let magnitudes: Vec<f64> = power_spectrum(samples, n_fft, window)?
        .iter()
        .map(|p| p.sqrt())
        .collect();
    // SAFETY: one magnitude per power value, and the power spectrum is non-empty.
    Ok(unsafe { NonEmptyVec::new_unchecked(magnitudes) })
}
/// Short-time Fourier transform of `samples`.
///
/// Returns the complex STFT matrix produced by
/// `SpectrogramPlanner::compute_stft`. It is presumably laid out as frequency
/// bins × frames, which is the layout [`istft`] expects.
///
/// # Errors
///
/// Fails if the STFT parameters are invalid (see [`StftParams::new`]) or if
/// the computation fails.
#[inline]
pub fn stft(
    samples: &NonEmptySlice<f64>,
    n_fft: NonZeroUsize,
    hop_size: NonZeroUsize,
    window: WindowType,
    center: bool,
) -> SpectrogramResult<Array2<Complex<f64>>> {
    let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
    // The complex STFT values do not depend on the sample rate; 1 Hz is a
    // placeholder that only satisfies SpectrogramParams' validation.
    let params = SpectrogramParams::new(stft_params, 1.0)?;
    let planner = SpectrogramPlanner::new();
    let result = planner.compute_stft(samples, &params)?;
    Ok(result.data)
}
#[inline]
pub fn irfft(
spectrum: &NonEmptySlice<Complex<f64>>,
n_fft: NonZeroUsize,
) -> SpectrogramResult<NonEmptyVec<f64>> {
use crate::fft_backend::{C2rPlan, r2c_output_size};
let n_fft = n_fft.get();
let expected_len = r2c_output_size(n_fft);
if spectrum.len().get() != expected_len {
return Err(SpectrogramError::dimension_mismatch(
expected_len,
spectrum.len().get(),
));
}
#[cfg(feature = "realfft")]
let mut ifft = {
use crate::fft_backend::get_or_create_c2r_plan;
let plan = get_or_create_c2r_plan(n_fft)?;
(*plan).clone()
};
#[cfg(feature = "fftw")]
let mut ifft = {
use crate::fft_backend::C2rPlanner;
let mut planner = crate::FftwPlanner::new();
planner.plan_c2r(n_fft)?
};
let mut output = vec![0.0; n_fft];
ifft.process(spectrum.as_slice(), &mut output)?;
Ok(unsafe { NonEmptyVec::new_unchecked(output) })
}
/// Inverse short-time Fourier transform by weighted overlap-add.
///
/// Steps:
/// 1. Each column of `stft_matrix` (one frame of `r2c_output_size(n_fft)`
///    bins) is inverse-FFT'd.
/// 2. The frame is multiplied by the synthesis window and added at offset
///    `frame * hop_size`.
/// 3. Each output sample is divided by the accumulated squared window where
///    that sum exceeds 1e-10.
/// 4. With `center`, `n_fft / 2` samples are trimmed from each end, unless
///    that would leave nothing.
///
/// # Errors
///
/// Fails if:
/// - the matrix has the wrong number of rows;
/// - `hop_size > n_fft`;
/// - the matrix has no columns;
/// - the output length overflows `usize`;
/// - the FFT backend fails.
#[inline]
pub fn istft(
    stft_matrix: &Array2<Complex<f64>>,
    n_fft: NonZeroUsize,
    hop_size: NonZeroUsize,
    window: WindowType,
    center: bool,
) -> SpectrogramResult<NonEmptyVec<f64>> {
    use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
    let n_bins = stft_matrix.nrows();
    let n_frames = stft_matrix.ncols();
    let expected_bins = r2c_output_size(n_fft.get());
    if n_bins != expected_bins {
        return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
    }
    if hop_size.get() > n_fft.get() {
        return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
    }
    // Without this guard `n_frames - 1` below underflows. In release builds the
    // wrapped length could even be 0, which the old unchecked NonZeroUsize
    // conversion turned into undefined behaviour.
    if n_frames == 0 {
        return Err(SpectrogramError::invalid_input(
            "stft_matrix must contain at least one frame",
        ));
    }
    #[cfg(feature = "realfft")]
    let mut ifft = {
        let mut planner = crate::RealFftPlanner::new();
        planner.plan_c2r(n_fft.get())?
    };
    #[cfg(feature = "fftw")]
    let mut ifft = {
        let mut planner = crate::FftwPlanner::new();
        planner.plan_c2r(n_fft.get())?
    };
    let window_samples = make_window(window, n_fft);
    let n_fft = n_fft.get();
    let hop_size = hop_size.get();
    let pad = if center { n_fft / 2 } else { 0 };
    // (n_frames - 1) * hop + n_fft >= n_fft >= 1; checked so a pathological
    // size reports an error instead of wrapping.
    let output_len = (n_frames - 1)
        .checked_mul(hop_size)
        .and_then(|len| len.checked_add(n_fft))
        .and_then(NonZeroUsize::new)
        .ok_or_else(|| SpectrogramError::invalid_input("istft output length overflows usize"))?;
    let unpadded_len = output_len.get().saturating_sub(2 * pad);
    let mut output = non_empty_vec![0.0; output_len];
    // Accumulated squared window per output sample, for normalisation.
    let mut norm = non_empty_vec![0.0; output_len];
    let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
    let mut time_frame = vec![0.0; n_fft];
    for frame_idx in 0..n_frames {
        for bin_idx in 0..n_bins {
            frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
        }
        ifft.process(&frame_buffer, &mut time_frame)?;
        for i in 0..n_fft {
            time_frame[i] *= window_samples[i];
        }
        let start = frame_idx * hop_size;
        for i in 0..n_fft {
            let pos = start + i;
            if pos < output_len.get() {
                output[pos] += time_frame[i];
                norm[pos] += window_samples[i] * window_samples[i];
            }
        }
    }
    for i in 0..output_len.get() {
        if norm[i] > 1e-10 {
            output[i] /= norm[i];
        }
    }
    // The result must be non-empty, so when trimming would remove everything
    // the untrimmed signal is returned instead.
    if center && unpadded_len > 0 {
        let start = pad;
        let end = start + unpadded_len;
        output = unsafe {
            // SAFETY: unpadded_len > 0 and end <= output_len, so the slice is
            // non-empty.
            NonEmptySlice::new_unchecked(&output[start..end.min(output_len.get())])
                .to_non_empty_vec()
        };
    }
    Ok(output)
}
/// FFT helper that owns a backend planner, whichever of the `realfft` or
/// `fftw` features is enabled.
///
/// Its methods request plans from this owned planner (`plan_r2c` / `plan_c2r`).
pub struct FftPlanner {
    /// Backend planner when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
    /// Backend planner when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
}
impl FftPlanner {
#[inline]
#[must_use]
pub fn new() -> Self {
Self {
#[cfg(feature = "realfft")]
inner: crate::RealFftPlanner::new(),
#[cfg(feature = "fftw")]
inner: crate::FftwPlanner::new(),
}
}
#[inline]
pub fn fft(
&mut self,
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
) -> SpectrogramResult<Array1<Complex<f64>>> {
use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
if samples.len() > n_fft {
return Err(SpectrogramError::invalid_input(format!(
"Input length ({}) exceeds FFT size ({})",
samples.len(),
n_fft
)));
}
let out_len = r2c_output_size(n_fft.get());
let mut plan = self.inner.plan_r2c(n_fft.get())?;
let input = if samples.len() < n_fft {
let mut padded = vec![0.0; n_fft.get()];
padded[..samples.len().get()].copy_from_slice(samples);
unsafe { NonEmptyVec::new_unchecked(padded) }
} else {
samples.to_non_empty_vec()
};
let mut output = vec![Complex::new(0.0, 0.0); out_len];
plan.process(&input, &mut output)?;
let output = Array1::from_vec(output);
Ok(output)
}
#[inline]
pub fn rfft(
&mut self,
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
) -> SpectrogramResult<Array1<f64>> {
let fft_with_complex = fft(samples, n_fft)?;
Ok(fft_with_complex.mapv(Complex::norm))
}
#[inline]
pub fn irfft(
&mut self,
spectrum: &NonEmptySlice<Complex<f64>>,
n_fft: NonZeroUsize,
) -> SpectrogramResult<NonEmptyVec<f64>> {
use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
let expected_len = r2c_output_size(n_fft.get());
if spectrum.len().get() != expected_len {
return Err(SpectrogramError::dimension_mismatch(
expected_len,
spectrum.len().get(),
));
}
let mut plan = self.inner.plan_c2r(n_fft.get())?;
let mut output = vec![0.0; n_fft.get()];
plan.process(spectrum, &mut output)?;
let output = unsafe { NonEmptyVec::new_unchecked(output) };
Ok(output)
}
#[inline]
pub fn power_spectrum(
&mut self,
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
window: Option<WindowType>,
) -> SpectrogramResult<NonEmptyVec<f64>> {
if samples.len() > n_fft {
return Err(SpectrogramError::invalid_input(format!(
"Input length ({}) exceeds FFT size ({})",
samples.len(),
n_fft
)));
}
let mut windowed = vec![0.0; n_fft.get()];
windowed[..samples.len().get()].copy_from_slice(samples);
if let Some(win_type) = window {
let window_samples = make_window(win_type, n_fft);
for i in 0..n_fft.get() {
windowed[i] *= window_samples[i];
}
}
let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
let fft_result = self.fft(windowed, n_fft)?;
let f = fft_result
.iter()
.map(num_complex::Complex::norm_sqr)
.collect();
Ok(unsafe { NonEmptyVec::new_unchecked(f) })
}
#[inline]
pub fn magnitude_spectrum(
&mut self,
samples: &NonEmptySlice<f64>,
n_fft: NonZeroUsize,
window: Option<WindowType>,
) -> SpectrogramResult<NonEmptyVec<f64>> {
let power = self.power_spectrum(samples, n_fft, window)?;
let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
Ok(unsafe { NonEmptyVec::new_unchecked(power) })
}
}
impl Default for FftPlanner {
#[inline]
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `multiply_vec` accumulates `value * input[col]` over each row's entries.
    #[test]
    fn test_sparse_matrix_basic() {
        let entries = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];
        let mut sparse = SparseMatrix::new(3, 5);
        for &(row, col, value) in &entries {
            sparse.set(row, col, value);
        }

        let input = [1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = [0.0; 3];
        sparse.multiply_vec(&input, &mut output);

        // Row 0: 2*2; row 1: 0.5*3 + 1.5*4; row 2: 3*1 + 1*5.
        assert_eq!(output, [4.0, 7.5, 8.0]);
    }

    /// Explicit zeros passed to `set` must not be stored.
    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        let mut sparse = SparseMatrix::new(2, 3);
        for (col, value) in [1.0, 0.0, 2.0].iter().copied().enumerate() {
            sparse.set(0, col, value);
        }

        assert_eq!(sparse.values[0].len(), 2);
        assert_eq!(sparse.indices[0].len(), 2);
        assert_eq!(sparse.indices[0], vec![0, 2]);
        assert_eq!(sparse.values[0], vec![1.0, 2.0]);
    }

    /// Every log-Hz row interpolates between one or two FFT bins.
    #[test]
    fn test_loghz_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);
        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, 20.0, sample_rate / 2.0).unwrap();

        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(
                nnz <= 2,
                "Row {} has {} non-zeros, expected at most 2",
                row_idx,
                nnz
            );
            assert!(nnz >= 1, "Row {} has no non-zeros", row_idx);
        }

        let total_nnz = matrix.values.iter().map(Vec::len).sum::<usize>();
        assert!(total_nnz <= n_bins.get() * 2);
        assert!(total_nnz >= n_bins.get());
    }

    /// Triangular mel filters touch only a small fraction of the FFT bins.
    #[test]
    fn test_mel_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);
        let matrix = build_mel_filterbank_matrix(
            sample_rate,
            n_fft,
            n_mels,
            0.0,
            sample_rate / 2.0,
            MelNorm::None,
        )
        .unwrap();

        let n_fft_bins = r2c_output_size(n_fft.get());
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        let total_elements = n_mels.get() * n_fft_bins;
        let sparsity = 1.0 - (total_nnz as f64 / total_elements as f64);
        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(
                nnz < n_fft_bins / 2,
                "Mel filter {} is not sparse enough",
                row_idx
            );
        }
    }
}