use std::{num::NonZeroUsize, ops::Deref};
use ndarray::Array2;
use non_empty_slice::{NonEmptySlice, NonEmptyVec};
use num_complex::Complex;
use crate::{SpectrogramError, SpectrogramResult, WindowType, nzu};
/// Configuration for a Constant-Q Transform (CQT).
///
/// Bin `k` is centered at `f_min * 2^(k / bins_per_octave)` and the transform
/// spans `bins_per_octave * n_octaves` bins in total.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct CqtParams {
/// Number of frequency bins in each octave.
pub bins_per_octave: NonZeroUsize,
/// Number of octaves covered, starting at `f_min`.
pub n_octaves: NonZeroUsize,
/// Center frequency of bin 0, in the same unit as the sample rate
/// (presumably Hz).
pub f_min: f64,
/// Quality factor: a bin's bandwidth is its center frequency divided by this
/// value, and its kernel length is `round(q_factor * sample_rate / frequency)`.
pub q_factor: f64,
/// Window function applied to every kernel.
pub window: WindowType,
/// Relative threshold: kernel coefficients whose magnitude is below
/// `threshold * max_magnitude` are zeroed. Values `<= 0` disable thresholding.
pub sparsity_threshold: f64,
/// Whether each kernel is rescaled to unit energy.
pub normalize: bool,
}
impl CqtParams {
/// Creates CQT parameters with default Q factor, window, sparsity and
/// normalization settings.
///
/// The Q factor defaults to `1 / (2^(1 / bins_per_octave) - 1)`, the window
/// to [`WindowType::Hanning`], the sparsity threshold to `0.01`, and
/// normalization is enabled.
///
/// # Errors
///
/// Returns an invalid-input error when `f_min` is not a finite, strictly
/// positive number. NaN is rejected as well.
#[inline]
pub fn new(
    bins_per_octave: NonZeroUsize,
    n_octaves: NonZeroUsize,
    f_min: f64,
) -> SpectrogramResult<Self> {
    // Written as a negated positive test so that NaN (for which every
    // comparison is false) is rejected together with zero, negatives and
    // infinities.
    if !(f_min > 0.0 && f_min.is_finite()) {
        return Err(SpectrogramError::invalid_input(
            "f_min must be finite and > 0",
        ));
    }
    Ok(Self {
        bins_per_octave,
        n_octaves,
        f_min,
        q_factor: 1.0 / ((1.0 / bins_per_octave.get() as f64).exp2() - 1.0),
        window: WindowType::Hanning,
        sparsity_threshold: 0.01,
        normalize: true,
    })
}
/// Creates CQT parameters without validating `f_min`.
///
/// Uses the same defaults as [`CqtParams::new`]: Q factor
/// `1 / (2^(1 / bins_per_octave) - 1)`, Hanning window, sparsity threshold
/// `0.01`, normalization enabled.
///
/// # Safety
///
/// No check is performed on `f_min`. The caller is expected to pass a value
/// that [`CqtParams::new`] would accept, i.e. finite and strictly positive.
#[inline]
#[must_use]
pub unsafe fn new_unchecked(
    bins_per_octave: NonZeroUsize,
    n_octaves: NonZeroUsize,
    f_min: f64,
) -> Self {
    // Ratio between neighbouring bin frequencies is 2^(1 / bins_per_octave).
    let bin_ratio = (1.0 / bins_per_octave.get() as f64).exp2();
    let q_factor = 1.0 / (bin_ratio - 1.0);
    Self {
        bins_per_octave,
        n_octaves,
        f_min,
        q_factor,
        window: WindowType::Hanning,
        sparsity_threshold: 0.01,
        normalize: true,
    }
}
/// Replaces the Q factor.
///
/// # Errors
///
/// Returns an invalid-input error when `q_factor` is not finite and strictly
/// positive (NaN included).
#[inline]
pub fn with_q_factor(mut self, q_factor: f64) -> SpectrogramResult<Self> {
    let is_valid = q_factor > 0.0 && q_factor.is_finite();
    if is_valid {
        self.q_factor = q_factor;
        Ok(self)
    } else {
        Err(SpectrogramError::invalid_input(
            "q_factor must be finite and > 0",
        ))
    }
}
#[inline]
#[must_use]
pub fn with_window(mut self, window: WindowType) -> Self {
self.window = window;
self
}
/// Returns the parameters with the sparsity threshold replaced.
///
/// The value is passed through `f64::max(threshold, 0.0)`, so negative
/// thresholds are stored as `0.0`.
#[inline]
#[must_use]
pub const fn with_sparsity(mut self, threshold: f64) -> Self {
    let clamped = threshold.max(0.0);
    self.sparsity_threshold = clamped;
    self
}
/// Returns the parameters with the kernel-normalization flag set to
/// `normalize`.
#[inline]
#[must_use]
pub const fn with_normalize(mut self, normalize: bool) -> Self {
self.normalize = normalize;
self
}
/// Total number of CQT bins: `bins_per_octave * n_octaves`.
///
/// # Panics
///
/// Panics if the product does not fit in a `usize`. The checked
/// multiplication avoids building a `NonZeroUsize` from a product that
/// wrapped around to zero in builds without overflow checks.
#[inline]
#[must_use]
pub const fn num_bins(&self) -> NonZeroUsize {
    match self.bins_per_octave.checked_mul(self.n_octaves) {
        Some(total) => total,
        None => panic!("bins_per_octave * n_octaves overflows usize"),
    }
}
/// Center frequency of bin `bin_idx`: `f_min * 2^(bin_idx / bins_per_octave)`.
///
/// `bin_idx` is not range-checked against [`CqtParams::num_bins`].
#[inline]
#[must_use]
pub fn bin_frequency(&self, bin_idx: usize) -> f64 {
    let octaves_above_min = bin_idx as f64 / self.bins_per_octave.get() as f64;
    self.f_min * octaves_above_min.exp2()
}
/// Bandwidth of bin `bin_idx`: its center frequency divided by the Q factor.
#[inline]
#[must_use]
pub fn bin_bandwidth(&self, bin_idx: usize) -> f64 {
    let center = self.bin_frequency(bin_idx);
    center / self.q_factor
}
/// Center frequencies of all `num_bins()` bins, in ascending bin order.
#[inline]
#[must_use]
pub fn frequencies(&self) -> NonEmptyVec<f64> {
    let bin_count = self.num_bins().get();
    let mut centers = Vec::with_capacity(bin_count);
    for bin in 0..bin_count {
        centers.push(self.bin_frequency(bin));
    }
    // SAFETY: `num_bins()` is a `NonZeroUsize`, so the loop above pushed at
    // least one element.
    unsafe { NonEmptyVec::new_unchecked(centers) }
}
/// Preset: 12 bins per octave over 7 octaves starting at `f_min = 32.7`,
/// with every other setting left at the constructor defaults.
#[inline]
#[must_use]
pub fn percussive() -> Self {
    let (bins_per_octave, n_octaves) = (nzu!(12), nzu!(7));
    // SAFETY: `new_unchecked` only skips the `f_min` validation, and 32.7 is
    // finite and strictly positive.
    unsafe { Self::new_unchecked(bins_per_octave, n_octaves, 32.7) }
}
#[inline]
#[must_use]
pub fn onset_detection() -> Self {
let mut params = unsafe { Self::new_unchecked(nzu!(24), nzu!(6), 55.0) };
params.q_factor = 0.5; params.sparsity_threshold = 0.02;
params.normalize = true;
params
}
#[inline]
#[must_use]
pub fn chord_detection() -> Self {
let mut params = unsafe { Self::new_unchecked(nzu!(36), nzu!(5), 82.4) };
params.q_factor = 0.8; params.sparsity_threshold = 0.02;
params.normalize = true;
params
}
#[inline]
#[must_use]
pub fn harmonic() -> Self {
let mut params = unsafe { Self::new_unchecked(nzu!(24), nzu!(7), 55.0) };
params.q_factor = 1.0; params.sparsity_threshold = 0.005;
params.normalize = true;
params
}
#[inline]
#[must_use]
pub fn musical() -> Self {
let mut params = unsafe { Self::new_unchecked(nzu!(12), nzu!(7), 32.7) };
params.q_factor = 1.0; params.sparsity_threshold = 0.01;
params.normalize = true;
params
}
}
/// Bank of time-domain CQT kernels, one per generated frequency bin.
#[derive(Debug, Clone)]
pub struct CqtKernel {
// One windowed complex-exponential kernel per bin, in ascending bin order.
kernels: NonEmptyVec<NonEmptyVec<Complex<f64>>>,
// Length requested for each kernel; indexed in parallel with `kernels`.
kernel_lengths: Vec<NonZeroUsize>,
// Next power of two of twice the signal length; stored by `generate` but not
// read anywhere in the visible code.
_fft_size: usize,
// Center frequency of each kernel; indexed in parallel with `kernels`.
frequencies: NonEmptyVec<f64>,
}
impl CqtKernel {
/// Builds one kernel per CQT bin whose center frequency lies below
/// `sample_rate / 2`.
///
/// Bins are visited in ascending order and generation stops at the first bin
/// at or above `sample_rate / 2`. Each kernel has
/// `round(q_factor * sample_rate / center_freq)` samples, clamped to
/// `1..=signal_length`. Sparsity thresholding is applied first, then the
/// optional energy normalization.
///
/// # Panics
///
/// Panics when no bin qualifies (for example `f_min >= sample_rate / 2`).
/// Without this check the non-empty containers below would be built from
/// empty vectors.
pub(crate) fn generate(
    params: &CqtParams,
    sample_rate: f64,
    signal_length: NonZeroUsize,
) -> Self {
    let num_bins = params.num_bins().get();
    let mut kernels = Vec::with_capacity(num_bins);
    let mut frequencies = Vec::with_capacity(num_bins);
    let mut kernel_lengths = Vec::with_capacity(num_bins);
    let fft_size = (signal_length.get() * 2).next_power_of_two();
    let nyquist = sample_rate / 2.0;
    for bin_idx in 0..num_bins {
        let center_freq = params.bin_frequency(bin_idx);
        if center_freq >= nyquist {
            break;
        }
        // A rounded length of zero (also produced by NaN through the saturating
        // float-to-int cast) becomes 1; anything longer than the signal is
        // capped at the signal length.
        let rounded = (params.q_factor * sample_rate / center_freq).round() as usize;
        let kernel_length = NonZeroUsize::new(rounded)
            .map_or(NonZeroUsize::MIN, |len| len.min(signal_length));
        let mut kernel = Self::generate_kernel_bin(
            center_freq,
            kernel_length,
            sample_rate,
            params.window.clone(),
        );
        Self::apply_sparsity_threshold(&mut kernel, params.sparsity_threshold);
        if params.normalize {
            Self::normalize_kernel(&mut kernel);
        }
        kernels.push(kernel);
        frequencies.push(center_freq);
        kernel_lengths.push(kernel_length);
    }
    assert!(
        !kernels.is_empty(),
        "CqtKernel::generate: no CQT bin has a center frequency below sample_rate / 2"
    );
    // SAFETY: the assertion above guarantees at least one kernel was pushed, and
    // `frequencies` receives exactly one element per kernel.
    let kernels = unsafe { NonEmptyVec::new_unchecked(kernels) };
    let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
    Self {
        kernels,
        kernel_lengths,
        _fft_size: fft_size,
        frequencies,
    }
}
/// Builds the kernel for a single bin: the complex exponential
/// `exp(i * 2π * center_freq * n / sample_rate)` multiplied sample by sample
/// with the requested window.
///
/// At most `kernel_length` window samples are consumed.
fn generate_kernel_bin(
    center_freq: f64,
    kernel_length: NonZeroUsize,
    sample_rate: f64,
    window_type: WindowType,
) -> NonEmptyVec<Complex<f64>> {
    let window = crate::spectrogram::make_window(window_type, kernel_length);
    let kernel: Vec<Complex<f64>> = window
        .iter()
        .enumerate()
        .take(kernel_length.get())
        .map(|(n, w)| {
            let t = n as f64 / sample_rate;
            let phase = 2.0 * std::f64::consts::PI * center_freq * t;
            Complex::new(phase.cos(), phase.sin()) * w
        })
        .collect();
    // SAFETY: relies on `make_window` yielding at least one sample for a
    // non-zero `kernel_length` (assumed, as in the original code).
    unsafe { NonEmptyVec::new_unchecked(kernel) }
}
/// Zeroes every coefficient whose magnitude is below
/// `threshold * max_magnitude`.
///
/// Does nothing when `threshold <= 0` or when all coefficients are zero.
fn apply_sparsity_threshold(kernel: &mut [Complex<f64>], threshold: f64) {
    if threshold <= 0.0 {
        return;
    }
    let peak = kernel.iter().fold(0.0_f64, |acc, c| acc.max(c.norm()));
    if peak == 0.0 {
        return;
    }
    let cutoff = peak * threshold;
    kernel
        .iter_mut()
        .filter(|c| c.norm() < cutoff)
        .for_each(|c| *c = Complex::new(0.0, 0.0));
}
fn normalize_kernel(kernel: &mut NonEmptySlice<Complex<f64>>) {
let energy: f64 = kernel.iter().map(num_complex::Complex::norm_sqr).sum();
if energy > 0.0 {
let norm_factor = 1.0 / energy.sqrt();
for coefficient in kernel.iter_mut() {
*coefficient *= norm_factor;
}
}
}
/// Center frequencies of the generated kernels, one per bin, in bin order.
#[inline]
#[must_use]
pub fn frequencies(&self) -> &NonEmptySlice<f64> {
&self.frequencies
}
/// Number of kernels actually generated. This can be smaller than
/// `CqtParams::num_bins` because generation stops at the first bin at or above
/// half the sample rate.
#[inline]
#[must_use]
pub const fn num_bins(&self) -> NonZeroUsize {
self.kernels.len()
}
/// Correlates `samples` with every kernel and returns one complex coefficient
/// per bin.
///
/// Each kernel is aligned with the last `kernel_length` samples of the input;
/// when the input is shorter than the kernel, it is aligned with the start of
/// the kernel and the remaining kernel coefficients are ignored. Each output
/// is the sum of `conj(kernel[i]) * sample[i]` over the overlapping samples.
///
/// # Errors
///
/// The visible implementation always returns `Ok`.
#[inline]
pub fn apply(
    &self,
    samples: &NonEmptySlice<f64>,
) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
    let n_samples = samples.len().get();
    let mut bins = Vec::with_capacity(self.kernels.len().get());
    for (bin_idx, kernel) in self.kernels.iter().enumerate() {
        let offset = n_samples.saturating_sub(self.kernel_lengths[bin_idx].get());
        let tail = &samples[offset..n_samples];
        let mut acc = Complex::new(0.0, 0.0);
        // `zip` stops at the shorter side, which replaces the per-sample
        // bounds test on the input.
        for (&coeff, &sample) in kernel.iter().zip(tail.iter()) {
            acc += coeff.conj() * sample;
        }
        bins.push(acc);
    }
    // SAFETY: `self.kernels` is non-empty and one value is pushed per kernel.
    Ok(unsafe { NonEmptyVec::new_unchecked(bins) })
}
}
/// Output of [`cqt`]: complex CQT coefficients plus the metadata needed to
/// interpret them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct CqtResult {
/// Complex coefficients; rows are frequency bins, columns are frames.
pub data: Array2<Complex<f64>>,
/// Center frequency of each bin (one entry per row of `data`).
pub frequencies: NonEmptyVec<f64>,
/// Sample rate that was passed to [`cqt`].
pub sample_rate: f64,
/// Number of samples between the starts of consecutive frames.
pub hop_size: NonZeroUsize,
}
/// Borrows the underlying coefficient matrix.
impl AsRef<Array2<Complex<f64>>> for CqtResult {
#[inline]
fn as_ref(&self) -> &Array2<Complex<f64>> {
&self.data
}
}
/// Lets a `CqtResult` be used wherever a reference to its coefficient matrix
/// is expected.
impl Deref for CqtResult {
type Target = Array2<Complex<f64>>;
#[inline]
fn deref(&self) -> &Self::Target {
&self.data
}
}
impl CqtResult {
/// Number of frequency bins (rows of `data`).
///
/// # Panics
///
/// Panics if `data` has zero rows. `data` is a public field (and can be
/// deserialized when the `serde` feature is on), so emptiness cannot be ruled
/// out here; a checked conversion is used instead of
/// `NonZeroUsize::new_unchecked`.
#[inline]
#[must_use]
pub fn n_bins(&self) -> NonZeroUsize {
    NonZeroUsize::new(self.data.nrows())
        .expect("CqtResult::data must contain at least one frequency bin (row)")
}
/// Number of time frames (columns of `data`).
///
/// # Panics
///
/// Panics if `data` has zero columns. `data` is a public field (and can be
/// deserialized when the `serde` feature is on), so emptiness cannot be ruled
/// out here; a checked conversion is used instead of
/// `NonZeroUsize::new_unchecked`.
#[inline]
#[must_use]
pub fn n_frames(&self) -> NonZeroUsize {
    NonZeroUsize::new(self.data.ncols())
        .expect("CqtResult::data must contain at least one frame (column)")
}
/// Time between consecutive frames: `hop_size / sample_rate`.
#[inline]
#[must_use]
pub fn time_resolution(&self) -> f64 {
    let hop_samples = self.hop_size.get() as f64;
    hop_samples / self.sample_rate
}
/// Returns a new array of the same shape holding the magnitude (`norm`) of
/// every coefficient.
#[inline]
#[must_use]
pub fn to_magnitude(&self) -> Array2<f64> {
    Array2::from_shape_fn(self.data.dim(), |(bin, frame)| {
        self.data[[bin, frame]].norm()
    })
}
/// Returns a new array of the same shape holding the squared magnitude
/// (`norm_sqr`) of every coefficient.
#[inline]
#[must_use]
pub fn to_power(&self) -> Array2<f64> {
    Array2::from_shape_fn(self.data.dim(), |(bin, frame)| {
        self.data[[bin, frame]].norm_sqr()
    })
}
}
#[inline]
pub fn cqt(
samples: &NonEmptySlice<f64>,
sample_rate: f64,
params: &CqtParams,
hop_size: NonZeroUsize,
) -> SpectrogramResult<CqtResult> {
let kernel_length = samples.len().min(nzu!(16384));
let kernel = CqtKernel::generate(params, sample_rate, kernel_length);
let n_bins = kernel.num_bins();
let frequencies = kernel.frequencies().to_non_empty_vec();
let n_frames = if samples.len() < kernel_length {
1
} else {
(samples.len().get() - kernel_length.get()) / hop_size.get() + 1
};
let mut cqt_data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames));
for frame_idx in 0..n_frames {
let start = frame_idx * hop_size.get();
let end = (start + kernel_length.get()).min(samples.len().get());
if end <= start {
break;
}
let frame = non_empty_slice::non_empty_slice!(&samples[start..end]);
let cqt_frame = kernel.apply(frame)?;
for (bin_idx, &val) in cqt_frame.iter().enumerate() {
if bin_idx < n_bins.get() {
cqt_data[[bin_idx, frame_idx]] = val;
}
}
}
Ok(CqtResult {
data: cqt_data,
frequencies,
sample_rate,
hop_size,
})
}