use std::{num::NonZeroUsize, ops::Deref};
use ndarray::Array2;
use non_empty_slice::NonEmptySlice;
use crate::{SpectrogramError, SpectrogramResult, StftParams, nzu};
/// Number of pitch classes (chroma bins) per octave, labelled C, C#, ..., B by
/// `Chromagram::labels`.
pub const N_CHROMA: usize = 12;
/// Parameters controlling how FFT-bin energy is folded into the 12 pitch classes.
///
/// Fields are private; construct via [`ChromaParams::new`] (validated),
/// [`ChromaParams::music_standard`], or [`Default`].
/// NOTE(review): with the `serde` feature, deserialization fills these fields
/// directly and bypasses the checks performed by `new`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ChromaParams {
/// Frequency of A4 (MIDI note 69) in Hz; anchors the pitch-class grid.
tuning: f64,
/// Octave count recorded for these parameters; not read by `build_chroma_filterbank`.
n_octaves: NonZeroUsize,
/// Lowest FFT-bin frequency (Hz, inclusive) that contributes to the chroma.
f_min: f64,
/// Highest FFT-bin frequency (Hz, inclusive) that contributes to the chroma.
f_max: f64,
/// Per-frame normalization applied after projecting onto pitch classes.
norm: ChromaNorm,
}
/// Per-frame normalization applied to each chroma vector (one column of the
/// chromagram). Frames whose norm is not strictly positive are left unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum ChromaNorm {
/// Keep the raw projected energies.
None,
/// Scale each frame by its L1 norm.
L1,
/// Scale each frame by its Euclidean (L2) norm. This is the default.
#[default]
L2,
/// Scale each frame by its peak value.
Max,
}
impl Default for ChromaParams {
#[inline]
fn default() -> Self {
Self {
tuning: 440.0,
n_octaves: nzu!(7),
f_min: 32.7, f_max: 4186.0, norm: ChromaNorm::L2,
}
}
}
impl ChromaParams {
    /// Creates validated chroma parameters.
    ///
    /// `n_octaves` is derived as `ceil(log2(f_max / f_min))`, clamped to at least 1.
    /// Because the count is rounded up, bounds whose ratio slightly exceeds a power
    /// of two gain an extra octave. For example, `new(440.0, 32.7, 4186.0, _)`
    /// records 8 octaves (ratio ≈ 128.01), whereas [`ChromaParams::music_standard`]
    /// records 7.
    ///
    /// # Arguments
    ///
    /// * `tuning` - frequency of A4 (MIDI note 69) in Hz.
    /// * `f_min`, `f_max` - inclusive band (Hz) of FFT bins that contribute.
    /// * `norm` - per-frame normalization of the resulting chroma vectors.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error if `tuning` or `f_min` is not finite and
    /// positive, if `f_max` is not finite, or if `f_max <= f_min`.
    #[inline]
    pub fn new(tuning: f64, f_min: f64, f_max: f64, norm: ChromaNorm) -> SpectrogramResult<Self> {
        if !(tuning > 0.0 && tuning.is_finite()) {
            return Err(SpectrogramError::invalid_input(
                "tuning must be finite and > 0",
            ));
        }
        if !(f_min > 0.0 && f_min.is_finite()) {
            return Err(SpectrogramError::invalid_input(
                "f_min must be finite and > 0",
            ));
        }
        // A NaN `f_max` would pass the `<=` comparison below, and +inf would make
        // the octave count saturate to `usize::MAX`; reject both explicitly.
        if !f_max.is_finite() {
            return Err(SpectrogramError::invalid_input("f_max must be finite"));
        }
        if f_max <= f_min {
            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
        }
        // `as usize` saturates and maps NaN to 0. Falling back to 1 keeps the
        // count non-zero without needing `unsafe`.
        let octaves = (f_max / f_min).log2().ceil() as usize;
        let n_octaves = NonZeroUsize::new(octaves).unwrap_or(nzu!(1));
        Ok(Self {
            tuning,
            n_octaves,
            f_min,
            f_max,
            norm,
        })
    }

    /// Standard Western-music setup: A4 = 440 Hz, a 32.7–4186 Hz band
    /// (roughly C1–C8), 7 recorded octaves, and L2 normalization.
    #[inline]
    #[must_use]
    pub const fn music_standard() -> Self {
        Self {
            tuning: 440.0,
            n_octaves: nzu!(7),
            f_min: 32.7,
            f_max: 4186.0,
            norm: ChromaNorm::L2,
        }
    }

    /// Returns a copy of these parameters with the normalization replaced by `norm`.
    #[inline]
    #[must_use]
    pub const fn with_norm(mut self, norm: ChromaNorm) -> Self {
        self.norm = norm;
        self
    }

    /// Frequency of A4 in Hz.
    #[inline]
    #[must_use]
    pub const fn tuning(&self) -> f64 {
        self.tuning
    }

    /// Lower bound (Hz, inclusive) of the contributing frequency band.
    #[inline]
    #[must_use]
    pub const fn f_min(&self) -> f64 {
        self.f_min
    }

    /// Upper bound (Hz, inclusive) of the contributing frequency band.
    #[inline]
    #[must_use]
    pub const fn f_max(&self) -> f64 {
        self.f_max
    }

    /// Octave count recorded for these parameters.
    #[inline]
    #[must_use]
    pub const fn n_octaves(&self) -> NonZeroUsize {
        self.n_octaves
    }
}
/// Chroma features together with the parameters used to compute them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Chromagram {
/// Chroma matrix: one row per pitch class (see `Chromagram::labels`) and one
/// column per frame. `n_frames`/`n_bins` assume this is non-empty in both
/// dimensions; callers replacing it must preserve that.
pub data: Array2<f64>,
/// Parameters that produced `data`.
params: ChromaParams,
}
impl Chromagram {
    /// Number of frames (columns of `data`).
    ///
    /// # Panics
    ///
    /// Panics if `data` has zero columns. `data` is a public field, so the
    /// non-zero invariant is checked at run time rather than assumed.
    #[inline]
    #[must_use]
    pub fn n_frames(&self) -> NonZeroUsize {
        NonZeroUsize::new(self.data.ncols())
            .expect("Chromagram::data must have at least one column (frame)")
    }

    /// Number of chroma bins (rows of `data`). Chromagrams built by
    /// `chromagram_from_spectrogram` have `N_CHROMA` rows.
    ///
    /// # Panics
    ///
    /// Panics if `data` has zero rows. `data` is a public field, so the
    /// non-zero invariant is checked at run time rather than assumed.
    #[inline]
    #[must_use]
    pub fn n_bins(&self) -> NonZeroUsize {
        NonZeroUsize::new(self.data.nrows())
            .expect("Chromagram::data must have at least one row (chroma bin)")
    }

    /// Parameters used to compute this chromagram.
    #[inline]
    #[must_use]
    pub const fn params(&self) -> &ChromaParams {
        &self.params
    }

    /// Pitch-class names, indexed by row of `data` (row 0 is C).
    #[inline]
    #[must_use]
    pub const fn labels() -> [&'static str; 12] {
        [
            "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B",
        ]
    }
}
impl AsRef<Array2<f64>> for Chromagram {
    /// Borrows the chroma matrix stored in `data`.
    #[inline]
    fn as_ref(&self) -> &Array2<f64> {
        let Self { data, .. } = self;
        data
    }
}
impl Deref for Chromagram {
    type Target = Array2<f64>;

    /// Exposes the chroma matrix in `data`, so array methods can be called on
    /// the chromagram directly.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { data, .. } = self;
        data
    }
}
/// Builds an `(N_CHROMA, n_fft / 2 + 1)` matrix that maps FFT bins to pitch classes.
///
/// Each bin whose frequency lies in `[params.f_min, params.f_max]` (and is > 0)
/// is converted to a fractional MIDI pitch, using `params.tuning` as A4 (MIDI 69).
/// That pitch is reduced to a pitch class in `[0, 12)`. The pitch class is then
/// spread over all 12 rows with a Gaussian (σ = 1 semitone) of the circular
/// semitone distance. Finally, every row with positive total weight is scaled to
/// sum to 1.
///
/// # Errors
///
/// Returns an invalid-input error if `sample_rate` is not finite and positive.
#[inline]
pub fn build_chroma_filterbank(
    sample_rate: f64,
    n_fft: NonZeroUsize,
    params: &ChromaParams,
) -> SpectrogramResult<Array2<f64>> {
    use std::f64::consts::LN_2;

    /// MIDI note number of A4, the note `params.tuning` refers to.
    const A4_MIDI: f64 = 69.0;
    /// Width of the Gaussian spreading kernel, in semitones.
    const SIGMA: f64 = 1.0;

    if !(sample_rate > 0.0 && sample_rate.is_finite()) {
        return Err(SpectrogramError::invalid_input(
            "sample_rate must be finite and > 0",
        ));
    }

    let bin_count = n_fft.get() / 2 + 1;
    let bin_width = sample_rate / n_fft.get() as f64;
    let mut filterbank = Array2::<f64>::zeros((N_CHROMA, bin_count));

    for bin in 0..bin_count {
        let freq = bin as f64 * bin_width;
        // DC and out-of-band bins keep an all-zero column.
        if freq <= 0.0 || freq < params.f_min || freq > params.f_max {
            continue;
        }
        let pitch_class = (A4_MIDI + 12.0 * (freq / params.tuning).ln() / LN_2).rem_euclid(12.0);
        let mut column = filterbank.column_mut(bin);
        for (class, weight) in column.iter_mut().enumerate() {
            let offset = (pitch_class - class as f64).abs();
            // Pitch classes wrap around (B is adjacent to C).
            let wrapped = offset.min(12.0 - offset);
            *weight = (-0.5 * (wrapped / SIGMA).powi(2)).exp();
        }
    }

    for mut row in filterbank.rows_mut() {
        let total: f64 = row.iter().sum();
        if total > 0.0 {
            row.mapv_inplace(|w| w / total);
        }
    }
    Ok(filterbank)
}
/// Projects a spectrogram onto the 12 pitch classes and normalizes each frame.
///
/// `spectrogram` must have shape `(n_fft / 2 + 1, n_frames)`: one row per
/// non-negative FFT bin and one column per frame. The result is
/// `chroma[c, f] = Σ_b filterbank[c, b] · spectrogram[b, f]`, where the
/// filterbank comes from [`build_chroma_filterbank`]. `params.norm` is then
/// applied to each frame.
///
/// # Errors
///
/// - A dimension-mismatch error if the row count is not `n_fft / 2 + 1`.
/// - An invalid-input error if `spectrogram` has no columns (frames).
/// - Any error from [`build_chroma_filterbank`], e.g. an invalid `sample_rate`.
#[inline]
pub fn chromagram_from_spectrogram(
    spectrogram: &Array2<f64>,
    sample_rate: f64,
    n_fft: NonZeroUsize,
    params: &ChromaParams,
) -> SpectrogramResult<Chromagram> {
    let (n_bins, n_frames) = spectrogram.dim();
    let expected_bins = n_fft.get() / 2 + 1;
    if n_bins != expected_bins {
        return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
    }
    // `Chromagram::n_frames` returns a `NonZeroUsize`, so an empty chromagram
    // must never be constructed.
    if n_frames == 0 {
        return Err(SpectrogramError::invalid_input(
            "spectrogram must contain at least one frame",
        ));
    }
    let filterbank = build_chroma_filterbank(sample_rate, n_fft, params)?;
    // Sequential accumulation over bins, matching a plain dot-product loop.
    let mut chroma_data = Array2::from_shape_fn((N_CHROMA, n_frames), |(chroma_idx, frame_idx)| {
        (0..n_bins).fold(0.0, |acc, bin_idx| {
            acc + filterbank[[chroma_idx, bin_idx]] * spectrogram[[bin_idx, frame_idx]]
        })
    });
    apply_chroma_normalization(&mut chroma_data, params.norm);
    Ok(Chromagram {
        data: chroma_data,
        params: *params,
    })
}
/// Normalizes every frame (column) of `chroma` in place according to `norm`.
///
/// Frames whose norm is not strictly positive (e.g. all-zero, silent frames, or
/// frames containing NaN) are left unchanged, which avoids dividing by zero.
/// The L1 and max norms use absolute values, so inputs containing negative
/// entries are scaled by a true norm rather than a signed sum or signed maximum.
/// For non-negative input, such as magnitudes, this is identical to the plain
/// sum and maximum.
fn apply_chroma_normalization(chroma: &mut Array2<f64>, norm: ChromaNorm) {
    for mut frame in chroma.columns_mut() {
        let scale = match norm {
            // Nothing to normalize: leave the whole matrix untouched.
            ChromaNorm::None => return,
            ChromaNorm::L1 => frame.iter().map(|v| v.abs()).sum::<f64>(),
            ChromaNorm::L2 => frame.iter().map(|v| v * v).sum::<f64>().sqrt(),
            ChromaNorm::Max => frame.iter().fold(0.0, |peak: f64, v| peak.max(v.abs())),
        };
        if scale > 0.0 {
            frame.mapv_inplace(|v| v / scale);
        }
    }
}
/// Computes a chromagram directly from time-domain samples.
///
/// Runs an STFT configured by `stft_params` at `sample_rate` and takes the
/// magnitude (`norm()`) of every STFT bin. The result is projected onto pitch
/// classes with [`chromagram_from_spectrogram`], using the STFT's `n_fft`.
///
/// # Errors
///
/// Propagates errors from `SpectrogramParams::new`, the STFT computation, and
/// [`chromagram_from_spectrogram`].
#[inline]
pub fn chromagram(
    samples: &NonEmptySlice<f64>,
    stft_params: &StftParams,
    sample_rate: f64,
    chroma_params: &ChromaParams,
) -> SpectrogramResult<Chromagram> {
    use crate::{SpectrogramParams, SpectrogramPlanner};

    let params = SpectrogramParams::new(stft_params.clone(), sample_rate)?;
    let planner = SpectrogramPlanner::new();
    let stft_result = planner.compute_stft(samples, &params)?;
    // The chroma projection works on the magnitude of each STFT bin.
    let magnitude_spec = stft_result.data.mapv(|bin| bin.norm());
    chromagram_from_spectrogram(
        &magnitude_spec,
        sample_rate,
        stft_result.params.n_fft(),
        chroma_params,
    )
}