Skip to main content

oximedia_align/
tempo_align.rs

1#![allow(dead_code)]
2//! Tempo-based audio alignment for music synchronization.
3//!
4//! This module aligns audio streams by detecting and matching musical tempo,
5//! beat positions, and rhythmic structures. It is particularly useful for
6//! aligning multiple recordings of the same musical performance.
7
/// Settings that control tempo detection.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TempoConfig {
    /// Audio sample rate, in Hz.
    pub sample_rate: u32,
    /// Lower bound of the detectable tempo range, in BPM.
    pub min_bpm: f64,
    /// Upper bound of the detectable tempo range, in BPM.
    pub max_bpm: f64,
    /// Hop size used for analysis, in samples.
    pub hop_size: usize,
    /// How many onset frames are accumulated before estimation.
    pub accumulation_frames: usize,
}

impl Default for TempoConfig {
    /// Defaults: 44.1 kHz audio, a 40-240 BPM range, a 512-sample hop and
    /// 256 accumulation frames.
    fn default() -> Self {
        let (min_bpm, max_bpm) = (40.0, 240.0);
        TempoConfig {
            sample_rate: 44_100,
            hop_size: 512,
            accumulation_frames: 256,
            min_bpm,
            max_bpm,
        }
    }
}
34
/// One beat detected in an audio stream.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BeatPosition {
    /// Beat time, in seconds.
    pub time_secs: f64,
    /// Beat strength; [`BeatPosition::new`] clamps it to `0.0..=1.0`.
    pub strength: f64,
    /// Sequential beat index, starting from 0.
    pub index: u32,
}

impl BeatPosition {
    /// Build a beat position, clamping `strength` into `0.0..=1.0`.
    #[must_use]
    pub fn new(time_secs: f64, strength: f64, index: u32) -> Self {
        let strength = strength.clamp(0.0, 1.0);
        BeatPosition {
            time_secs,
            strength,
            index,
        }
    }

    /// Seconds elapsed from this beat to `next` (negative if `next` is earlier).
    #[must_use]
    pub fn interval_to(&self, next: &Self) -> f64 {
        let (from, to) = (self.time_secs, next.time_secs);
        to - from
    }
}
63
64/// Result of tempo estimation on an audio segment.
65#[derive(Debug, Clone, PartialEq)]
66pub struct TempoEstimate {
67    /// Estimated tempo in BPM.
68    pub bpm: f64,
69    /// Confidence of the estimate (0.0..1.0).
70    pub confidence: f64,
71    /// Detected beat positions.
72    pub beats: Vec<BeatPosition>,
73    /// Alternative tempo candidates (e.g., half/double time).
74    pub alternatives: Vec<f64>,
75}
76
77impl TempoEstimate {
78    /// Create a new tempo estimate.
79    #[must_use]
80    pub fn new(bpm: f64, confidence: f64) -> Self {
81        Self {
82            bpm,
83            confidence: confidence.clamp(0.0, 1.0),
84            beats: Vec::new(),
85            alternatives: Vec::new(),
86        }
87    }
88
89    /// Compute the beat period in seconds from the detected BPM.
90    #[must_use]
91    pub fn beat_period_secs(&self) -> f64 {
92        if self.bpm > 0.0 {
93            60.0 / self.bpm
94        } else {
95            0.0
96        }
97    }
98
99    /// Compute the mean inter-beat interval from detected beats.
100    #[allow(clippy::cast_precision_loss)]
101    #[must_use]
102    pub fn mean_ibi(&self) -> f64 {
103        if self.beats.len() < 2 {
104            return 0.0;
105        }
106        let total: f64 = self
107            .beats
108            .windows(2)
109            .map(|w| w[1].time_secs - w[0].time_secs)
110            .sum();
111        total / (self.beats.len() - 1) as f64
112    }
113
114    /// Check whether this tempo is harmonically related to another tempo.
115    #[must_use]
116    pub fn is_harmonic_of(&self, other_bpm: f64) -> bool {
117        if other_bpm <= 0.0 || self.bpm <= 0.0 {
118            return false;
119        }
120        let ratio = self.bpm / other_bpm;
121        let rounded = ratio.round();
122        if rounded < 1.0 {
123            return false;
124        }
125        (ratio - rounded).abs() < 0.05
126    }
127}
128
129/// Onset detection function type.
130#[derive(Debug, Clone, Copy, PartialEq, Eq)]
131pub enum OnsetFunction {
132    /// Energy-based onset detection.
133    Energy,
134    /// Spectral flux onset detection.
135    SpectralFlux,
136    /// High-frequency content onset detection.
137    HighFrequencyContent,
138    /// Complex domain onset detection.
139    ComplexDomain,
140}
141
142/// Onset envelope analyzer for beat tracking.
143#[derive(Debug)]
144pub struct OnsetAnalyzer {
145    /// Configuration.
146    config: TempoConfig,
147    /// Type of onset function to use.
148    onset_fn: OnsetFunction,
149    /// Accumulated onset envelope.
150    envelope: Vec<f64>,
151}
152
153impl OnsetAnalyzer {
154    /// Create a new onset analyzer.
155    #[must_use]
156    pub fn new(config: TempoConfig, onset_fn: OnsetFunction) -> Self {
157        Self {
158            config,
159            onset_fn,
160            envelope: Vec::new(),
161        }
162    }
163
164    /// Compute the onset envelope from audio samples.
165    #[allow(clippy::cast_precision_loss)]
166    pub fn compute_envelope(&mut self, samples: &[f32]) {
167        self.envelope.clear();
168        if samples.is_empty() || self.config.hop_size == 0 {
169            return;
170        }
171        let hop = self.config.hop_size;
172        let num_frames = samples.len() / hop;
173
174        for i in 0..num_frames {
175            let start = i * hop;
176            let end = (start + hop).min(samples.len());
177            let frame = &samples[start..end];
178
179            let value = match self.onset_fn {
180                OnsetFunction::Energy => {
181                    frame.iter().map(|&s| f64::from(s).powi(2)).sum::<f64>() / frame.len() as f64
182                }
183                OnsetFunction::SpectralFlux => {
184                    // Simplified: use absolute differences between consecutive samples
185                    if frame.len() < 2 {
186                        0.0
187                    } else {
188                        frame
189                            .windows(2)
190                            .map(|w| f64::from(w[1] - w[0]).abs())
191                            .sum::<f64>()
192                            / (frame.len() - 1) as f64
193                    }
194                }
195                OnsetFunction::HighFrequencyContent => {
196                    // Simplified: weight by sample position within frame
197                    frame
198                        .iter()
199                        .enumerate()
200                        .map(|(j, &s)| (j as f64 + 1.0) * f64::from(s).abs())
201                        .sum::<f64>()
202                        / frame.len() as f64
203                }
204                OnsetFunction::ComplexDomain => {
205                    // Simplified: combination of energy and flux
206                    let energy: f64 = frame.iter().map(|&s| f64::from(s).powi(2)).sum::<f64>()
207                        / frame.len() as f64;
208                    let flux: f64 = if frame.len() < 2 {
209                        0.0
210                    } else {
211                        frame
212                            .windows(2)
213                            .map(|w| f64::from(w[1] - w[0]).abs())
214                            .sum::<f64>()
215                            / (frame.len() - 1) as f64
216                    };
217                    (energy + flux) / 2.0
218                }
219            };
220            self.envelope.push(value);
221        }
222    }
223
224    /// Return a reference to the computed onset envelope.
225    #[must_use]
226    pub fn envelope(&self) -> &[f64] {
227        &self.envelope
228    }
229
230    /// Pick peaks in the onset envelope above a threshold.
231    #[must_use]
232    pub fn pick_peaks(&self, threshold: f64) -> Vec<usize> {
233        let mut peaks = Vec::new();
234        if self.envelope.len() < 3 {
235            return peaks;
236        }
237        for i in 1..self.envelope.len() - 1 {
238            if self.envelope[i] > threshold
239                && self.envelope[i] > self.envelope[i - 1]
240                && self.envelope[i] >= self.envelope[i + 1]
241            {
242                peaks.push(i);
243            }
244        }
245        peaks
246    }
247}
248
/// Outcome of a tempo-based alignment between two audio streams.
#[derive(Debug, Clone, PartialEq)]
pub struct TempoAlignResult {
    /// Estimated offset in seconds (stream B relative to stream A).
    pub offset_secs: f64,
    /// Tempo of stream A, in BPM.
    pub tempo_a: f64,
    /// Tempo of stream B, in BPM.
    pub tempo_b: f64,
    /// Alignment confidence; [`TempoAlignResult::new`] clamps it to `0.0..=1.0`.
    pub confidence: f64,
    /// Number of matched beat pairs.
    pub matched_beats: usize,
}

impl TempoAlignResult {
    /// Build an alignment result, clamping `confidence` into `0.0..=1.0`.
    #[must_use]
    pub fn new(
        offset_secs: f64,
        tempo_a: f64,
        tempo_b: f64,
        confidence: f64,
        matched_beats: usize,
    ) -> Self {
        let confidence = confidence.clamp(0.0, 1.0);
        TempoAlignResult {
            offset_secs,
            tempo_a,
            tempo_b,
            confidence,
            matched_beats,
        }
    }

    /// Ratio `tempo_a / tempo_b`, or `0.0` when `tempo_b` is not positive.
    #[must_use]
    pub fn tempo_ratio(&self) -> f64 {
        match self.tempo_b {
            divisor if divisor > 0.0 => self.tempo_a / divisor,
            _ => 0.0,
        }
    }

    /// Whether the two tempos differ by strictly less than `tolerance_bpm`.
    #[must_use]
    pub fn tempos_match(&self, tolerance_bpm: f64) -> bool {
        let gap = (self.tempo_a - self.tempo_b).abs();
        gap < tolerance_bpm
    }
}
299
300/// Align two sets of beat positions by finding the best offset.
301#[allow(clippy::cast_precision_loss)]
302#[must_use]
303pub fn align_beats(
304    beats_a: &[BeatPosition],
305    beats_b: &[BeatPosition],
306    tolerance_secs: f64,
307) -> TempoAlignResult {
308    if beats_a.is_empty() || beats_b.is_empty() {
309        return TempoAlignResult::new(0.0, 0.0, 0.0, 0.0, 0);
310    }
311
312    // Estimate tempos from beat intervals
313    let tempo_a = estimate_bpm_from_beats(beats_a);
314    let tempo_b = estimate_bpm_from_beats(beats_b);
315
316    // Try each possible offset by pairing first beats
317    let mut best_offset = 0.0;
318    let mut best_count = 0_usize;
319
320    for a_beat in beats_a.iter().take(beats_a.len().min(8)) {
321        for b_beat in beats_b.iter().take(beats_b.len().min(8)) {
322            let candidate_offset = a_beat.time_secs - b_beat.time_secs;
323            let count = count_matched_beats(beats_a, beats_b, candidate_offset, tolerance_secs);
324            if count > best_count {
325                best_count = count;
326                best_offset = candidate_offset;
327            }
328        }
329    }
330
331    let max_possible = beats_a.len().min(beats_b.len());
332    let confidence = if max_possible > 0 {
333        (best_count as f64 / max_possible as f64).clamp(0.0, 1.0)
334    } else {
335        0.0
336    };
337
338    TempoAlignResult::new(best_offset, tempo_a, tempo_b, confidence, best_count)
339}
340
341/// Count how many beats match between two beat sequences given an offset.
342fn count_matched_beats(
343    beats_a: &[BeatPosition],
344    beats_b: &[BeatPosition],
345    offset_secs: f64,
346    tolerance_secs: f64,
347) -> usize {
348    let mut count = 0;
349    for a in beats_a {
350        let shifted = a.time_secs - offset_secs;
351        for b in beats_b {
352            if (shifted - b.time_secs).abs() < tolerance_secs {
353                count += 1;
354                break;
355            }
356        }
357    }
358    count
359}
360
361/// Estimate BPM from a series of beat positions.
362#[allow(clippy::cast_precision_loss)]
363fn estimate_bpm_from_beats(beats: &[BeatPosition]) -> f64 {
364    if beats.len() < 2 {
365        return 0.0;
366    }
367    let total_time = beats
368        .last()
369        .expect("beats non-empty: len < 2 check returned above")
370        .time_secs
371        - beats
372            .first()
373            .expect("beats non-empty: len < 2 check returned above")
374            .time_secs;
375    if total_time <= 0.0 {
376        return 0.0;
377    }
378    let intervals = (beats.len() - 1) as f64;
379    let avg_interval = total_time / intervals;
380    if avg_interval > 0.0 {
381        60.0 / avg_interval
382    } else {
383        0.0
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` full-strength beats spaced `period` seconds apart,
    /// the first one at `start` seconds.
    fn evenly_spaced(count: u32, start: f64, period: f64) -> Vec<BeatPosition> {
        (0..count)
            .map(|i| BeatPosition::new(start + f64::from(i) * period, 1.0, i))
            .collect()
    }

    /// Default configuration with only the hop size overridden.
    fn config_with_hop(hop_size: usize) -> TempoConfig {
        TempoConfig {
            hop_size,
            ..TempoConfig::default()
        }
    }

    #[test]
    fn test_tempo_config_default() {
        let cfg = TempoConfig::default();
        assert_eq!(cfg.sample_rate, 44100);
        assert!((cfg.min_bpm - 40.0).abs() < f64::EPSILON);
        assert!((cfg.max_bpm - 240.0).abs() < f64::EPSILON);
        assert_eq!(cfg.hop_size, 512);
        assert_eq!(cfg.accumulation_frames, 256);
    }

    #[test]
    fn test_beat_position_interval() {
        let a = BeatPosition::new(1.0, 0.9, 0);
        let b = BeatPosition::new(1.5, 0.8, 1);
        assert!((a.interval_to(&b) - 0.5).abs() < 1e-10);
        assert!((b.interval_to(&a) + 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_beat_position_strength_clamped() {
        let too_strong = BeatPosition::new(0.0, 2.0, 0);
        assert!((too_strong.strength - 1.0).abs() < f64::EPSILON);
        let negative = BeatPosition::new(0.0, -1.0, 0);
        assert!(negative.strength.abs() < f64::EPSILON);
    }

    #[test]
    fn test_tempo_estimate_beat_period() {
        let te = TempoEstimate::new(120.0, 0.9);
        assert!((te.beat_period_secs() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_tempo_estimate_zero_bpm() {
        let te = TempoEstimate::new(0.0, 0.0);
        assert!((te.beat_period_secs()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tempo_estimate_mean_ibi() {
        let mut te = TempoEstimate::new(120.0, 0.9);
        te.beats = evenly_spaced(3, 0.0, 0.5);
        assert!((te.mean_ibi() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_tempo_estimate_mean_ibi_single() {
        let mut te = TempoEstimate::new(120.0, 0.9);
        te.beats.push(BeatPosition::new(0.0, 1.0, 0));
        assert!((te.mean_ibi()).abs() < f64::EPSILON);
    }

    #[test]
    fn test_is_harmonic_double_time() {
        let te = TempoEstimate::new(120.0, 0.9);
        assert!(te.is_harmonic_of(60.0));
        assert!(te.is_harmonic_of(120.0));
    }

    #[test]
    fn test_is_harmonic_is_directional() {
        // A tempo is a harmonic of a slower one, never of a faster one.
        let half_time = TempoEstimate::new(60.0, 0.9);
        assert!(!half_time.is_harmonic_of(120.0));
    }

    #[test]
    fn test_is_harmonic_not_related() {
        let te = TempoEstimate::new(120.0, 0.9);
        assert!(!te.is_harmonic_of(73.0));
        assert!(!te.is_harmonic_of(0.0));
    }

    #[test]
    fn test_onset_analyzer_energy() {
        let mut analyzer = OnsetAnalyzer::new(config_with_hop(4), OnsetFunction::Energy);
        let samples = vec![0.5_f32, 0.3, 0.1, 0.0, 0.8, 0.6, 0.4, 0.2];
        analyzer.compute_envelope(&samples);
        assert_eq!(analyzer.envelope().len(), 2);
        assert!(analyzer.envelope()[0] > 0.0);
        // The second frame is louder than the first.
        assert!(analyzer.envelope()[1] > analyzer.envelope()[0]);
    }

    #[test]
    fn test_onset_analyzer_empty() {
        let config = TempoConfig::default();
        let mut analyzer = OnsetAnalyzer::new(config, OnsetFunction::Energy);
        analyzer.compute_envelope(&[]);
        assert!(analyzer.envelope().is_empty());
    }

    #[test]
    fn test_pick_peaks() {
        let mut analyzer = OnsetAnalyzer::new(config_with_hop(1), OnsetFunction::Energy);
        // With a hop of 1 the energy envelope is the squared signal, so the
        // single maximum at sample 3 is the only peak.
        let samples: Vec<f32> = vec![0.0, 0.1, 0.5, 0.9, 0.5, 0.1, 0.0];
        analyzer.compute_envelope(&samples);
        assert_eq!(analyzer.envelope().len(), samples.len());
        assert_eq!(analyzer.pick_peaks(0.01), vec![3]);
        assert!(analyzer.pick_peaks(1.0).is_empty());
    }

    #[test]
    fn test_align_beats_exact_match() {
        let beats_a = evenly_spaced(4, 0.0, 0.5);
        let beats_b = evenly_spaced(4, 0.0, 0.5);
        let result = align_beats(&beats_a, &beats_b, 0.05);
        assert!(result.offset_secs.abs() < 1e-9);
        assert_eq!(result.matched_beats, 4);
        assert!((result.confidence - 1.0).abs() < f64::EPSILON);
        assert!((result.tempo_a - 120.0).abs() < 1e-9);
        assert!((result.tempo_b - 120.0).abs() < 1e-9);
    }

    #[test]
    fn test_align_beats_with_offset() {
        let beats_a = evenly_spaced(4, 1.0, 0.5);
        let beats_b = evenly_spaced(4, 0.0, 0.5);
        let result = align_beats(&beats_a, &beats_b, 0.05);
        assert!((result.offset_secs - 1.0).abs() < 1e-9);
        assert_eq!(result.matched_beats, 4);
    }

    #[test]
    fn test_align_beats_empty() {
        let result = align_beats(&[], &[], 0.05);
        assert_eq!(result.matched_beats, 0);
        assert!((result.confidence).abs() < f64::EPSILON);
    }

    #[test]
    fn test_tempo_align_result_ratio() {
        let r = TempoAlignResult::new(0.0, 120.0, 60.0, 0.9, 8);
        assert!((r.tempo_ratio() - 2.0).abs() < 1e-10);
        let no_tempo_b = TempoAlignResult::new(0.0, 120.0, 0.0, 0.9, 8);
        assert!(no_tempo_b.tempo_ratio().abs() < f64::EPSILON);
    }

    #[test]
    fn test_tempo_align_result_match() {
        let r = TempoAlignResult::new(0.0, 120.0, 120.5, 0.9, 8);
        assert!(r.tempos_match(1.0));
        assert!(!r.tempos_match(0.1));
    }
}