use crate::chromagram::{ChromaVector, ChromagramAnalyzer, ChromagramConfig};
/// Number of pitch-class bins stored per chroma frame.
const PITCH_CLASSES: usize = 12;

/// A chromagram flattened into one contiguous, frame-major buffer.
///
/// Frame `i` occupies `data[i * 12..(i + 1) * 12]`. The analysis parameters are
/// kept alongside the bins so frame indices can be mapped back to time.
#[derive(Debug, Clone)]
pub struct CachedChromagram {
    /// Frame-major bin storage; always `n_frames * 12` values long.
    data: Vec<f64>,
    /// Number of frames held in `data`.
    n_frames: usize,
    /// Sample rate in samples per second (used to convert hops to seconds).
    pub sample_rate: f32,
    /// Analysis window size recorded at construction (not read by this type's methods).
    pub window_size: usize,
    /// Number of samples between the times of consecutive frames.
    pub hop_size: usize,
}

impl CachedChromagram {
    /// Copies the 12 bins of every [`ChromaVector`] into a flat, frame-major buffer.
    ///
    /// `sample_rate`, `window_size` and `hop_size` are stored verbatim in the
    /// corresponding public fields.
    #[must_use]
    pub fn from_chroma_vectors(
        vectors: &[ChromaVector],
        sample_rate: f32,
        window_size: usize,
        hop_size: usize,
    ) -> Self {
        let mut data = Vec::with_capacity(vectors.len() * PITCH_CLASSES);
        for cv in vectors {
            data.extend_from_slice(&cv.bins[..PITCH_CLASSES]);
        }
        Self {
            data,
            n_frames: vectors.len(),
            sample_rate,
            window_size,
            hop_size,
        }
    }

    /// Returns the number of stored frames.
    #[must_use]
    pub fn n_frames(&self) -> usize {
        self.n_frames
    }

    /// Returns `true` when no frames are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.n_frames == 0
    }

    /// Returns a copy of the 12 bins of frame `frame_idx`, or `None` if out of range.
    #[must_use]
    pub fn frame(&self, frame_idx: usize) -> Option<[f64; PITCH_CLASSES]> {
        self.frame_slice(frame_idx).map(Self::row_to_array)
    }

    /// Returns the per-bin average over all frames, or all zeros when empty.
    #[must_use]
    // usize -> f64 is exact for any realistic frame count (below 2^53).
    #[allow(clippy::cast_precision_loss)]
    pub fn mean_chroma(&self) -> [f64; PITCH_CLASSES] {
        if self.n_frames == 0 {
            return [0.0; PITCH_CLASSES];
        }
        let mut mean = [0.0_f64; PITCH_CLASSES];
        for row in self.data.chunks_exact(PITCH_CLASSES) {
            Self::accumulate(&mut mean, row);
        }
        let n = self.n_frames as f64;
        mean.iter_mut().for_each(|v| *v /= n);
        mean
    }

    /// Returns the sum of the 12 bins of frame `frame_idx`, or `0.0` if out of range.
    #[must_use]
    pub fn frame_energy(&self, frame_idx: usize) -> f64 {
        self.frame_slice(frame_idx)
            .map_or(0.0, |row| row.iter().sum())
    }

    /// Returns copies of up to `n` frames starting at `start_frame`.
    ///
    /// The range is clamped to the stored frames. A `start_frame` past the end
    /// yields an empty vector, and very large `n` does not overflow.
    #[must_use]
    pub fn frames_range(&self, start_frame: usize, n: usize) -> Vec<[f64; PITCH_CLASSES]> {
        // skip/take clamp naturally; the previous `start_frame + n` could overflow.
        self.data
            .chunks_exact(PITCH_CLASSES)
            .skip(start_frame)
            .take(n)
            .map(Self::row_to_array)
            .collect()
    }

    /// Sums the frames whose time lies in `[start_secs, end_secs)` and normalises
    /// the result so its bins sum to 1.
    ///
    /// Frame `i` is placed at `i * (hop_size / sample_rate)` seconds, computed in `f32`.
    /// If the summed bins total at most `1e-12` (for example, when no frame falls
    /// in the window), the un-normalised sum is returned unchanged.
    #[must_use]
    // usize -> f32 is exact for frame indices below 2^24; rounding beyond that is accepted.
    #[allow(clippy::cast_precision_loss)]
    pub fn aggregate_window(&self, start_secs: f32, end_secs: f32) -> [f64; PITCH_CLASSES] {
        let hop_secs = self.hop_size as f32 / self.sample_rate;
        let mut sum = [0.0_f64; PITCH_CLASSES];
        for (frame, row) in self.data.chunks_exact(PITCH_CLASSES).enumerate() {
            let frame_time = frame as f32 * hop_secs;
            if frame_time >= start_secs && frame_time < end_secs {
                Self::accumulate(&mut sum, row);
            }
        }
        // An empty window leaves `sum` all zeros, so this guard also covers that case.
        let total: f64 = sum.iter().sum();
        if total > 1e-12 {
            sum.iter_mut().for_each(|v| *v /= total);
        }
        sum
    }

    /// Borrows the bins of frame `frame_idx`, or `None` if it is out of range.
    fn frame_slice(&self, frame_idx: usize) -> Option<&[f64]> {
        if frame_idx >= self.n_frames {
            return None;
        }
        // Cannot overflow: frame_idx < n_frames and data.len() == n_frames * 12.
        let start = frame_idx * PITCH_CLASSES;
        Some(&self.data[start..start + PITCH_CLASSES])
    }

    /// Copies one 12-value row into a fixed-size array.
    fn row_to_array(row: &[f64]) -> [f64; PITCH_CLASSES] {
        let mut bins = [0.0; PITCH_CLASSES];
        bins.copy_from_slice(row);
        bins
    }

    /// Adds `row` element-wise into `acc`.
    fn accumulate(acc: &mut [f64; PITCH_CLASSES], row: &[f64]) {
        for (a, &v) in acc.iter_mut().zip(row) {
            *a += v;
        }
    }
}
/// Lazily computes a [`CachedChromagram`] for one signal and keeps it.
///
/// The first call to [`ChromaCache::get`] runs a [`ChromagramAnalyzer`] built from
/// the stored config. Later calls return the stored result until
/// [`ChromaCache::invalidate`] clears it.
pub struct ChromaCache {
    /// Analyzer configuration. Its `sample_rate`, `window_size` and `hop_size`
    /// are also recorded in the cached chromagram.
    config: ChromagramConfig,
    /// The stored result: `None` until the first `get`, or after `invalidate`.
    cache: Option<CachedChromagram>,
}

impl ChromaCache {
    /// Creates an empty cache that will analyse audio using `config`.
    #[must_use]
    pub fn new(config: ChromagramConfig) -> Self {
        Self {
            config,
            cache: None,
        }
    }

    /// Creates an empty cache from [`ChromagramConfig::default`] with `sample_rate` overridden.
    #[must_use]
    pub fn with_sample_rate(sample_rate: f32) -> Self {
        Self::new(ChromagramConfig {
            sample_rate,
            ..ChromagramConfig::default()
        })
    }

    /// Returns the cached chromagram, computing it from `samples` if the cache is empty.
    ///
    /// `samples` is only read when nothing is cached. Once the cache is populated,
    /// later calls return the stored chromagram regardless of the `samples` passed.
    /// Call [`ChromaCache::invalidate`] before analysing a different signal.
    pub fn get(&mut self, samples: &[f32]) -> &CachedChromagram {
        // Borrow the config separately so the closure does not capture all of `self`
        // while `self.cache` is mutably borrowed.
        let config = &self.config;
        self.cache.get_or_insert_with(|| {
            let analyzer = ChromagramAnalyzer::new(config.clone());
            let vectors = analyzer.compute(samples);
            CachedChromagram::from_chroma_vectors(
                &vectors,
                config.sample_rate,
                config.window_size,
                config.hop_size,
            )
        })
    }

    /// Returns `true` if a chromagram is currently cached.
    #[must_use]
    pub fn is_populated(&self) -> bool {
        self.cache.is_some()
    }

    /// Drops the cached chromagram so the next [`ChromaCache::get`] recomputes it.
    pub fn invalidate(&mut self) {
        self.cache = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::TAU;

    /// Generates `seconds` of a unit-amplitude sine at `freq` Hz, sampled at `sr` Hz.
    fn make_sine(freq: f32, sr: f32, seconds: f32) -> Vec<f32> {
        let n = (sr * seconds) as usize;
        (0..n).map(|i| (TAU * freq * i as f32 / sr).sin()).collect()
    }

    /// Builds a frame whose only non-zero bin is `bin`, set to `value`.
    fn one_hot(bin: usize, value: f64) -> ChromaVector {
        let mut cv = ChromaVector { bins: [0.0; 12] };
        cv.bins[bin] = value;
        cv
    }

    #[test]
    fn test_cached_chromagram_empty() {
        let cache = CachedChromagram::from_chroma_vectors(&[], 44100.0, 4096, 512);
        assert!(cache.is_empty());
        assert_eq!(cache.n_frames(), 0);
    }

    #[test]
    fn test_cached_chromagram_frame_access() {
        let cv = ChromaVector {
            bins: [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        };
        let cache = CachedChromagram::from_chroma_vectors(&[cv], 44100.0, 4096, 512);
        let frame = cache.frame(0).expect("frame 0 must exist");
        assert!((frame[0] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cached_chromagram_out_of_range() {
        let cv = ChromaVector { bins: [0.0; 12] };
        let cache = CachedChromagram::from_chroma_vectors(&[cv], 44100.0, 4096, 512);
        assert!(cache.frame(1).is_none());
    }

    #[test]
    fn test_mean_chroma_all_equal() {
        let cv = ChromaVector { bins: [2.0; 12] };
        let cache =
            CachedChromagram::from_chroma_vectors(&[cv.clone(), cv.clone()], 44100.0, 4096, 512);
        let mean = cache.mean_chroma();
        for &v in &mean {
            assert!((v - 2.0).abs() < 1e-9);
        }
    }

    #[test]
    fn test_mean_chroma_empty() {
        let cache = CachedChromagram::from_chroma_vectors(&[], 44100.0, 4096, 512);
        assert!(cache.mean_chroma().iter().all(|v| v.abs() < f64::EPSILON));
    }

    #[test]
    fn test_frame_energy() {
        let mut cv = ChromaVector { bins: [0.0; 12] };
        cv.bins[0] = 3.0;
        cv.bins[6] = 2.0;
        let cache = CachedChromagram::from_chroma_vectors(&[cv], 44100.0, 4096, 512);
        let energy = cache.frame_energy(0);
        assert!((energy - 5.0).abs() < 1e-9);
    }

    #[test]
    fn test_frame_energy_out_of_range_is_zero() {
        let cache =
            CachedChromagram::from_chroma_vectors(&[one_hot(0, 1.0)], 44100.0, 4096, 512);
        assert!(cache.frame_energy(1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_frames_range() {
        let frames: Vec<ChromaVector> = (0..5)
            .map(|i| {
                let mut cv = ChromaVector { bins: [0.0; 12] };
                cv.bins[0] = i as f64;
                cv
            })
            .collect();
        let cache = CachedChromagram::from_chroma_vectors(&frames, 44100.0, 4096, 512);
        let range = cache.frames_range(1, 3);
        assert_eq!(range.len(), 3);
        assert!((range[0][0] - 1.0).abs() < 1e-9);
        assert!((range[2][0] - 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_frames_range_clamps_to_available_frames() {
        let frames: Vec<ChromaVector> = (0..5).map(|_| one_hot(0, 1.0)).collect();
        let cache = CachedChromagram::from_chroma_vectors(&frames, 44100.0, 4096, 512);
        assert_eq!(cache.frames_range(3, 10).len(), 2);
        assert!(cache.frames_range(10, 2).is_empty());
    }

    #[test]
    fn test_chroma_cache_lazy_computation() {
        let mut chroma_cache = ChromaCache::with_sample_rate(44100.0);
        assert!(!chroma_cache.is_populated());
        let signal = make_sine(440.0, 44100.0, 0.5);
        let _ = chroma_cache.get(&signal);
        assert!(chroma_cache.is_populated());
        let result1_n = chroma_cache.get(&signal).n_frames();
        let result2_n = chroma_cache.get(&signal).n_frames();
        assert_eq!(result1_n, result2_n);
    }

    #[test]
    fn test_chroma_cache_invalidate() {
        let mut chroma_cache = ChromaCache::with_sample_rate(44100.0);
        let signal = make_sine(440.0, 44100.0, 0.5);
        let _ = chroma_cache.get(&signal);
        assert!(chroma_cache.is_populated());
        chroma_cache.invalidate();
        assert!(!chroma_cache.is_populated());
    }

    #[test]
    fn test_aggregate_window_silence() {
        let cv = ChromaVector { bins: [0.0; 12] };
        let frames = vec![cv; 10];
        let cache = CachedChromagram::from_chroma_vectors(&frames, 44100.0, 512, 512);
        let agg = cache.aggregate_window(0.0, 1.0);
        let total: f64 = agg.iter().sum();
        assert!(total < 1e-12);
    }

    #[test]
    fn test_aggregate_window_normalizes_selected_frames() {
        // A hop of one sample at 1 Hz places frame i at exactly i seconds,
        // so the window [1.0, 3.0) selects frames 1 and 2 only.
        let frames: Vec<ChromaVector> = (0..4).map(|i| one_hot(i, 1.0)).collect();
        let cache = CachedChromagram::from_chroma_vectors(&frames, 1.0, 1, 1);
        let agg = cache.aggregate_window(1.0, 3.0);
        assert!(agg[0].abs() < 1e-12);
        assert!((agg[1] - 0.5).abs() < 1e-12);
        assert!((agg[2] - 0.5).abs() < 1e-12);
        assert!(agg[3].abs() < 1e-12);
    }
}