Skip to main content

oximedia_mir/
streaming.rs

//! Streaming / incremental MIR analysis for real-time use.
//!
//! [`StreamingAnalyzer`] processes audio chunk-by-chunk, maintaining an internal
//! ring-buffer and emitting updated analysis estimates as data accumulates.
//! It is designed for latency-sensitive pipelines where the full signal is not
//! available in advance (e.g. live broadcast, real-time DJ monitoring).
//!
//! # Design
//!
//! * A bounded internal buffer is maintained; it keeps the most recent
//!   `full_analysis_samples + window_size` samples.
//! * Lightweight feature estimates are refreshed on every pushed chunk: RMS
//!   energy and ZCR always, spectral centroid and onset strength only when the
//!   chunk holds at least one full analysis window (`window_size` samples).
//! * Full analysis (tempo, dominant pitch class, onset times) is triggered only
//!   when `full_analysis_samples` have accumulated since the last full run.
//! * All state is stored as plain `Vec<f32>` — no ndarray.
17
18use crate::utils::{hann_window, magnitude_spectrum, mean};
19use crate::MirResult;
20
/// Lightweight features measured on the most recently pushed chunk.
///
/// The all-zero [`Default`] value is the state before any audio has been seen.
/// `PartialEq` is derived so that successive snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct StreamingFrameFeatures {
    /// Spectral centroid estimate (normalised 0–1 relative to Nyquist).
    pub spectral_centroid: f32,
    /// Zero-crossing rate (sign changes divided by the chunk length).
    pub zero_crossing_rate: f32,
    /// Onset strength (normalised, clamped to 0–1).
    pub onset_strength: f32,
    /// RMS energy level of the chunk.
    pub rms_energy: f32,
    /// Number of audio samples analysed so far.
    pub samples_processed: usize,
}
35
/// Full analysis summary emitted after enough audio has accumulated.
///
/// The [`Default`] value describes "nothing analysed yet": every numeric field
/// is zero, the histories are empty, and `dominant_pitch_class` holds the
/// documented "unknown" sentinel (255) rather than 0, which would read as C.
#[derive(Debug, Clone, PartialEq)]
pub struct StreamingAnalysisSummary {
    /// Estimated BPM (0.0 if not yet determined).
    pub bpm: f32,
    /// BPM confidence (0–1).
    pub bpm_confidence: f32,
    /// Whether the performance appears to be rubato.
    pub is_rubato: bool,
    /// Dominant pitch class (0 = C … 11 = B), or 255 if unknown.
    pub dominant_pitch_class: u8,
    /// Onset times (seconds) detected so far.
    pub onset_times: Vec<f32>,
    /// Per-chunk spectral history (centroid, one value per chunk).
    pub centroid_history: Vec<f32>,
    /// Per-chunk RMS history.
    pub rms_history: Vec<f32>,
    /// Total duration analysed (seconds).
    pub duration_secs: f32,
}

impl Default for StreamingAnalysisSummary {
    /// Build the "no analysis has run yet" summary.
    fn default() -> Self {
        Self {
            bpm: 0.0,
            bpm_confidence: 0.0,
            is_rubato: false,
            // 255 is the field's documented "unknown" marker; a derived
            // default of 0 would claim the dominant pitch class is C.
            dominant_pitch_class: 255,
            onset_times: Vec::new(),
            centroid_history: Vec::new(),
            rms_history: Vec::new(),
            duration_secs: 0.0,
        }
    }
}
56
/// Tunable parameters of the streaming analyzer.
#[derive(Clone, Debug)]
pub struct StreamingConfig {
    /// Sample rate of the incoming audio, in Hz.
    pub sample_rate: f32,
    /// Minimum number of samples to gather before frame-level features are
    /// computed.
    // NOTE(review): no code visible in this file reads this field — confirm
    // whether it is consumed elsewhere.
    pub min_analysis_samples: usize,
    /// Number of new samples that must arrive before a full tempo/key
    /// analysis is run again.
    pub full_analysis_samples: usize,
    /// FFT window length, in samples, used for spectral analysis.
    pub window_size: usize,
    /// Hop size, in samples.
    pub hop_size: usize,
    /// Lower bound of the tempo search range, in BPM.
    pub min_bpm: f32,
    /// Upper bound of the tempo search range, in BPM.
    pub max_bpm: f32,
}
75
76impl Default for StreamingConfig {
77    fn default() -> Self {
78        Self {
79            sample_rate: 44100.0,
80            // ~93 ms chunks at 44.1 kHz
81            min_analysis_samples: 4096,
82            // ~3 seconds worth of audio before full analysis
83            full_analysis_samples: 44100 * 3,
84            window_size: 2048,
85            hop_size: 512,
86            min_bpm: 60.0,
87            max_bpm: 200.0,
88        }
89    }
90}
91
92/// Incremental streaming MIR analyzer.
93///
94/// Call [`StreamingAnalyzer::push_chunk`] repeatedly with new audio blocks.
95/// After each call, retrieve the lightweight [`StreamingFrameFeatures`] via
96/// [`StreamingAnalyzer::frame_features`].  The heavier
97/// [`StreamingAnalysisSummary`] is refreshed lazily via
98/// [`StreamingAnalyzer::summary`] (it only re-runs when enough new audio
99/// has arrived since the last full analysis).
100pub struct StreamingAnalyzer {
101    config: StreamingConfig,
102    /// Internal ring-buffer holding the most recent samples.
103    buffer: Vec<f32>,
104    /// Total samples pushed.
105    total_samples: usize,
106    /// Samples at the time of the last full analysis run.
107    last_full_analysis_at: usize,
108    /// Latest per-frame feature estimates.
109    frame_features: StreamingFrameFeatures,
110    /// Latest full-analysis summary.
111    summary: StreamingAnalysisSummary,
112    /// Previous magnitude spectrum (for onset detection).
113    prev_magnitude: Vec<f32>,
114    /// Accumulated onset samples (for tempo estimation).
115    onset_history: Vec<f32>,
116    /// Centroid history per processed chunk.
117    centroid_history: Vec<f32>,
118    /// RMS history per processed chunk.
119    rms_history: Vec<f32>,
120}
121
122impl StreamingAnalyzer {
123    /// Create a new streaming analyzer with the given configuration.
124    #[must_use]
125    pub fn new(config: StreamingConfig) -> Self {
126        let window_size = config.window_size;
127        Self {
128            config,
129            buffer: Vec::with_capacity(window_size * 4),
130            total_samples: 0,
131            last_full_analysis_at: 0,
132            frame_features: StreamingFrameFeatures::default(),
133            summary: StreamingAnalysisSummary::default(),
134            prev_magnitude: vec![0.0; window_size / 2 + 1],
135            onset_history: Vec::new(),
136            centroid_history: Vec::new(),
137            rms_history: Vec::new(),
138        }
139    }
140
141    /// Create a streaming analyzer with default config for the given sample rate.
142    #[must_use]
143    pub fn with_sample_rate(sample_rate: f32) -> Self {
144        Self::new(StreamingConfig {
145            sample_rate,
146            ..StreamingConfig::default()
147        })
148    }
149
150    /// Push a new chunk of mono audio samples into the analyzer.
151    ///
152    /// Frame-level features are recomputed on every call.  Full analysis is
153    /// triggered automatically when enough samples have accumulated.
154    ///
155    /// # Errors
156    ///
157    /// Returns error if internal analysis fails.
158    pub fn push_chunk(&mut self, chunk: &[f32]) -> MirResult<()> {
159        if chunk.is_empty() {
160            return Ok(());
161        }
162
163        self.buffer.extend_from_slice(chunk);
164        self.total_samples += chunk.len();
165
166        // Keep buffer bounded: retain the most recent `full_analysis_samples`
167        // samples plus one extra window.
168        let max_buffer = self.config.full_analysis_samples + self.config.window_size;
169        if self.buffer.len() > max_buffer {
170            let drop = self.buffer.len() - max_buffer;
171            self.buffer.drain(..drop);
172        }
173
174        // Compute frame-level features on this chunk (even tiny chunks use
175        // the RMS path; spectral path requires at least one full window).
176        self.update_frame_features(chunk)?;
177
178        // Run full analysis when enough new audio has accumulated.
179        let new_samples_since_full = self.total_samples - self.last_full_analysis_at;
180        if new_samples_since_full >= self.config.full_analysis_samples {
181            self.run_full_analysis()?;
182            self.last_full_analysis_at = self.total_samples;
183        }
184
185        Ok(())
186    }
187
188    /// Return the latest lightweight per-chunk features (updated every call to
189    /// `push_chunk`).
190    #[must_use]
191    pub fn frame_features(&self) -> &StreamingFrameFeatures {
192        &self.frame_features
193    }
194
195    /// Return the latest full-analysis summary.
196    ///
197    /// This is recomputed automatically inside `push_chunk` when sufficient
198    /// audio has accumulated.
199    #[must_use]
200    pub fn summary(&self) -> &StreamingAnalysisSummary {
201        &self.summary
202    }
203
204    /// Total number of samples pushed so far.
205    #[must_use]
206    pub fn samples_processed(&self) -> usize {
207        self.total_samples
208    }
209
210    /// Duration analysed so far, in seconds.
211    #[must_use]
212    pub fn duration_secs(&self) -> f32 {
213        self.total_samples as f32 / self.config.sample_rate
214    }
215
216    /// Reset all internal state.
217    pub fn reset(&mut self) {
218        self.buffer.clear();
219        self.total_samples = 0;
220        self.last_full_analysis_at = 0;
221        self.frame_features = StreamingFrameFeatures::default();
222        self.summary = StreamingAnalysisSummary::default();
223        self.prev_magnitude = vec![0.0; self.config.window_size / 2 + 1];
224        self.onset_history.clear();
225        self.centroid_history.clear();
226        self.rms_history.clear();
227    }
228
229    // ── Private helpers ───────────────────────────────────────────────────────
230
231    /// Update lightweight per-chunk features from the latest `chunk`.
232    #[allow(clippy::cast_precision_loss)]
233    fn update_frame_features(&mut self, chunk: &[f32]) -> MirResult<()> {
234        // RMS energy
235        let rms = {
236            let sq_sum: f32 = chunk.iter().map(|&s| s * s).sum();
237            (sq_sum / chunk.len() as f32).sqrt()
238        };
239
240        // Zero-crossing rate
241        let zcr = if chunk.len() >= 2 {
242            let crossings = chunk
243                .windows(2)
244                .filter(|w| (w[0] >= 0.0) != (w[1] >= 0.0))
245                .count();
246            crossings as f32 / chunk.len() as f32
247        } else {
248            0.0
249        };
250
251        // Spectral centroid and onset strength — only when chunk is large enough
252        let (centroid, onset_strength) = if chunk.len() >= self.config.window_size {
253            self.compute_spectral_features(chunk)?
254        } else {
255            // For tiny chunks use previous values
256            (self.frame_features.spectral_centroid, 0.0)
257        };
258
259        self.rms_history.push(rms);
260        self.centroid_history.push(centroid);
261
262        self.frame_features = StreamingFrameFeatures {
263            spectral_centroid: centroid,
264            zero_crossing_rate: zcr,
265            onset_strength,
266            rms_energy: rms,
267            samples_processed: self.total_samples,
268        };
269
270        Ok(())
271    }
272
273    /// Compute spectral centroid and onset strength for a window of audio.
274    #[allow(clippy::cast_precision_loss)]
275    fn compute_spectral_features(&mut self, chunk: &[f32]) -> MirResult<(f32, f32)> {
276        let win = self.config.window_size;
277        let hop = self.config.hop_size;
278
279        // Use the last `win` samples of the chunk (or pad with zeros if too short)
280        let start = if chunk.len() >= win {
281            chunk.len() - win
282        } else {
283            0
284        };
285        let frame_slice = &chunk[start..];
286
287        // Apply Hann window
288        let window = hann_window(win);
289        let windowed: Vec<f32> = frame_slice
290            .iter()
291            .zip(window.iter().take(frame_slice.len()))
292            .map(|(&s, &w)| s * w)
293            .chain(std::iter::repeat(0.0_f32).take(win.saturating_sub(frame_slice.len())))
294            .take(win)
295            .collect();
296
297        let fft_input: Vec<oxifft::Complex<f32>> = windowed
298            .iter()
299            .map(|&s| oxifft::Complex::new(s, 0.0))
300            .collect();
301
302        let spectrum = oxifft::fft(&fft_input);
303        let mag = magnitude_spectrum(&spectrum);
304        let n_bins = mag.len().min(win / 2 + 1);
305
306        let sr = self.config.sample_rate;
307        let freq_per_bin = sr / win as f32;
308
309        // Spectral centroid (normalised to Nyquist)
310        let (weighted_sum, total_mag) = mag[..n_bins]
311            .iter()
312            .enumerate()
313            .fold((0.0_f32, 0.0_f32), |(ws, tm), (k, &m)| {
314                (ws + k as f32 * freq_per_bin * m, tm + m)
315            });
316        let centroid_hz = if total_mag > 1e-9 {
317            weighted_sum / total_mag
318        } else {
319            0.0
320        };
321        let centroid_norm = (centroid_hz / (sr * 0.5)).clamp(0.0, 1.0);
322
323        // Onset strength: sum of positive spectral flux
324        let prev = &self.prev_magnitude;
325        let onset: f32 = mag[..n_bins]
326            .iter()
327            .zip(prev.iter())
328            .map(|(&m, &p)| (m - p).max(0.0))
329            .sum();
330        let onset_norm = (onset / (n_bins as f32)).clamp(0.0, 1.0);
331
332        // Update previous magnitude
333        self.prev_magnitude = mag[..n_bins].to_vec();
334        // Pad to expected length if needed
335        if self.prev_magnitude.len() < win / 2 + 1 {
336            self.prev_magnitude.resize(win / 2 + 1, 0.0);
337        }
338
339        // Accumulate onset for tempo estimation (use onset as scalar per frame)
340        self.onset_history.push(onset_norm);
341
342        // Keep onset history bounded to full_analysis_samples / hop frames
343        let max_frames = self.config.full_analysis_samples / hop + 1;
344        if self.onset_history.len() > max_frames {
345            let drop = self.onset_history.len() - max_frames;
346            self.onset_history.drain(..drop);
347        }
348
349        Ok((centroid_norm, onset_norm))
350    }
351
352    /// Run a full (heavyweight) tempo + chromagram analysis on the buffered audio.
353    #[allow(clippy::cast_precision_loss)]
354    fn run_full_analysis(&mut self) -> MirResult<()> {
355        let sr = self.config.sample_rate;
356        let buf_len = self.buffer.len();
357
358        if buf_len < (sr as usize) {
359            // Not enough audio yet for a meaningful full analysis
360            return Ok(());
361        }
362
363        // ── Tempo from onset autocorrelation ──────────────────────────────
364        let (bpm, bpm_confidence, is_rubato) = self.estimate_tempo()?;
365
366        // ── Dominant pitch class from chromagram ──────────────────────────
367        let dominant_pitch = self.estimate_dominant_pitch();
368
369        // ── Onset times from onset history ────────────────────────────────
370        let hop = self.config.hop_size;
371        let onset_times: Vec<f32> = self
372            .onset_history
373            .iter()
374            .enumerate()
375            .filter(|(_, &v)| v > 0.1)
376            .map(|(i, _)| {
377                // Approximate onset time: buffer offset in seconds
378                let sample_offset = (buf_len as isize
379                    - (self.onset_history.len() as isize - i as isize) * hop as isize)
380                    .max(0) as usize;
381                (self.total_samples.saturating_sub(buf_len) + sample_offset) as f32 / sr
382            })
383            .collect();
384
385        self.summary = StreamingAnalysisSummary {
386            bpm,
387            bpm_confidence,
388            is_rubato,
389            dominant_pitch_class: dominant_pitch,
390            onset_times,
391            centroid_history: self.centroid_history.clone(),
392            rms_history: self.rms_history.clone(),
393            duration_secs: self.total_samples as f32 / sr,
394        };
395
396        Ok(())
397    }
398
399    /// Estimate tempo from the onset envelope using autocorrelation.
400    #[allow(clippy::cast_precision_loss)]
401    fn estimate_tempo(&self) -> MirResult<(f32, f32, bool)> {
402        if self.onset_history.len() < 16 {
403            return Ok((0.0, 0.0, false));
404        }
405
406        let acf = crate::utils::autocorrelation(&self.onset_history)
407            .unwrap_or_else(|_| vec![0.0; self.onset_history.len()]);
408
409        let sr = self.config.sample_rate;
410        let hop = self.config.hop_size as f32;
411        let fps = sr / hop; // frames per second
412
413        // Convert BPM range to lag range in frames
414        let min_lag = ((fps * 60.0 / self.config.max_bpm) as usize).max(1);
415        let max_lag =
416            ((fps * 60.0 / self.config.min_bpm) as usize).min(acf.len().saturating_sub(1));
417
418        if min_lag >= max_lag {
419            return Ok((0.0, 0.0, false));
420        }
421
422        let peaks = crate::utils::find_peaks(&acf[min_lag..max_lag], 3);
423        if peaks.is_empty() {
424            return Ok((0.0, 0.0, false));
425        }
426
427        // Best peak
428        let best_peak = peaks
429            .iter()
430            .copied()
431            .max_by(|&a, &b| {
432                acf[a + min_lag]
433                    .partial_cmp(&acf[b + min_lag])
434                    .unwrap_or(std::cmp::Ordering::Equal)
435            })
436            .unwrap_or(0);
437
438        let lag = best_peak + min_lag;
439        let bpm = fps * 60.0 / lag as f32;
440
441        let max_acf = acf[min_lag..max_lag]
442            .iter()
443            .fold(f32::NEG_INFINITY, |a, &b| a.max(b));
444        let confidence = if max_acf > 0.0 {
445            (acf[lag] / max_acf).clamp(0.0, 1.0)
446        } else {
447            0.0
448        };
449
450        // Stability: measure CV of inter-onset intervals
451        let stability = self.measure_onset_stability(lag);
452        let is_rubato = stability < 0.45;
453
454        Ok((bpm, confidence, is_rubato))
455    }
456
457    /// Measure onset stability as inverse coefficient-of-variation at the detected period.
458    #[allow(clippy::cast_precision_loss)]
459    fn measure_onset_stability(&self, period_frames: usize) -> f32 {
460        if period_frames == 0 || self.onset_history.len() < period_frames * 2 {
461            return 0.5;
462        }
463        let samples: Vec<f32> = (period_frames..self.onset_history.len())
464            .step_by(period_frames)
465            .map(|i| self.onset_history[i])
466            .collect();
467        if samples.is_empty() {
468            return 0.5;
469        }
470        let m = mean(&samples);
471        if m < 1e-9 {
472            return 0.5;
473        }
474        let variance: f32 =
475            samples.iter().map(|v| (v - m).powi(2)).sum::<f32>() / samples.len() as f32;
476        let cv = variance.sqrt() / m;
477        (1.0 - cv.min(1.0)).clamp(0.0, 1.0)
478    }
479
480    /// Estimate dominant pitch class from the buffered audio chromagram.
481    #[allow(clippy::cast_precision_loss)]
482    fn estimate_dominant_pitch(&self) -> u8 {
483        if self.buffer.len() < self.config.window_size {
484            return 255;
485        }
486
487        // Use at most the last 2 × full_analysis_samples worth of audio
488        let buf = &self.buffer;
489        let win = self.config.window_size;
490        let hop = self.config.hop_size;
491        let sr = self.config.sample_rate as f64;
492
493        // Accumulate chroma bins
494        let mut chroma = [0.0_f64; 12];
495        let n_frames = (buf.len().saturating_sub(win)) / hop + 1;
496
497        for frame_idx in 0..n_frames {
498            let start = frame_idx * hop;
499            let end = start + win;
500            if end > buf.len() {
501                break;
502            }
503            let frame = &buf[start..end];
504
505            for k in 1..(win / 2) {
506                let freq = k as f64 * sr / win as f64;
507                if !(65.0..=2093.0).contains(&freq) {
508                    continue;
509                }
510                // Goertzel magnitude estimate
511                let omega = 2.0 * std::f64::consts::PI * k as f64 / win as f64;
512                let coeff = 2.0 * omega.cos();
513                let (mut s1, mut s2) = (0.0_f64, 0.0_f64);
514                for &sample in frame {
515                    let s0 = f64::from(sample) + coeff * s1 - s2;
516                    s2 = s1;
517                    s1 = s0;
518                }
519                let mag = (s1 * s1 + s2 * s2 - coeff * s1 * s2).abs().sqrt();
520
521                // Map to chroma bin
522                let midi = 12.0 * (freq / 440.0).log2() + 69.0;
523                let pc = (midi.round() as i64).rem_euclid(12) as usize;
524                chroma[pc] += mag;
525            }
526        }
527
528        // Find dominant pitch class
529        chroma
530            .iter()
531            .enumerate()
532            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
533            .map_or(255, |(i, _)| i as u8)
534    }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use std::f32::consts::TAU;

    /// Sample rate shared by every test.
    const SR: f32 = 44100.0;

    /// Unit-amplitude sine of `freq` Hz lasting `seconds` at sample rate `sr`.
    fn make_sine(freq: f32, sr: f32, seconds: f32) -> Vec<f32> {
        let len = (sr * seconds) as usize;
        let mut samples = Vec::with_capacity(len);
        for i in 0..len {
            samples.push((TAU * freq * i as f32 / sr).sin());
        }
        samples
    }

    #[test]
    fn test_streaming_analyzer_default() {
        // A fresh analyzer has seen no audio at all.
        let analyzer = StreamingAnalyzer::with_sample_rate(SR);
        assert_eq!(analyzer.samples_processed(), 0);
        assert!((analyzer.duration_secs() - 0.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_push_empty_chunk() {
        // Empty input succeeds and leaves the sample counter untouched.
        let mut analyzer = StreamingAnalyzer::with_sample_rate(SR);
        assert!(analyzer.push_chunk(&[]).is_ok());
        assert_eq!(analyzer.samples_processed(), 0);
    }

    #[test]
    fn test_push_small_chunk_accumulates() {
        // A chunk shorter than one analysis window is still counted.
        let mut analyzer = StreamingAnalyzer::with_sample_rate(SR);
        let chunk = [0.0f32; 512];
        assert!(analyzer.push_chunk(&chunk).is_ok());
        assert_eq!(analyzer.samples_processed(), 512);
    }

    #[test]
    fn test_push_large_chunk_updates_features() {
        let mut analyzer = StreamingAnalyzer::with_sample_rate(SR);
        let sine = make_sine(440.0, SR, 1.0);
        assert!(analyzer.push_chunk(&sine).is_ok());
        assert_eq!(analyzer.samples_processed(), 44100);
        // One second of a 440 Hz sine must give a non-zero centroid.
        assert!(analyzer.frame_features().spectral_centroid > 0.0);
    }

    #[test]
    fn test_reset_clears_state() {
        let mut analyzer = StreamingAnalyzer::with_sample_rate(SR);
        let sine = make_sine(440.0, SR, 0.1);
        let _ = analyzer.push_chunk(&sine);
        assert!(analyzer.samples_processed() > 0);

        analyzer.reset();
        assert_eq!(analyzer.samples_processed(), 0);
        assert_eq!(analyzer.frame_features().samples_processed, 0);
    }

    #[test]
    fn test_streaming_multiple_chunks() {
        let mut analyzer = StreamingAnalyzer::with_sample_rate(SR);
        let chunk_size = 4096_usize;
        // 8 s at 44.1 kHz is 352 800 samples, i.e. 86 whole 4096-sample
        // chunks; the trailing partial chunk is not pushed.
        let sine = make_sine(220.0, SR, 8.0);
        for block in sine.chunks_exact(chunk_size) {
            analyzer.push_chunk(block).expect("push failed");
        }
        assert!(analyzer.samples_processed() >= 20 * chunk_size);
        // A full analysis has run by now, so the summary carries history.
        assert!(!analyzer.summary().centroid_history.is_empty());
    }

    #[test]
    fn test_full_analysis_triggers_on_threshold() {
        // Require only 2 seconds of audio before the full analysis runs.
        let config = StreamingConfig {
            sample_rate: SR,
            min_analysis_samples: 4096,
            full_analysis_samples: 44100 * 2,
            window_size: 2048,
            hop_size: 512,
            min_bpm: 60.0,
            max_bpm: 200.0,
        };
        let mut analyzer = StreamingAnalyzer::new(config);
        let sine = make_sine(440.0, SR, 3.0);
        analyzer.push_chunk(&sine).expect("push failed");
        // After 3 s the full analysis has run and the duration is set.
        assert!(analyzer.summary().duration_secs > 0.0);
    }

    #[test]
    fn test_zcr_silent_signal() {
        let mut analyzer = StreamingAnalyzer::with_sample_rate(SR);
        let silence = [0.0f32; 8192];
        analyzer.push_chunk(&silence).expect("push failed");
        // Digital silence never changes sign, so the ZCR is 0.
        assert!((analyzer.frame_features().zero_crossing_rate - 0.0).abs() < 1e-4);
    }
}