Skip to main content

oximedia_align/
beat_align.rs

1//! Beat-grid alignment for music and rhythmic media synchronisation.
2//!
3//! Detects downbeats in an audio signal and aligns it to a target beat grid,
4//! producing a sample-accurate offset.
5
6#![allow(dead_code)]
7
/// A regular beat grid defined by a tempo.
///
/// Beat `n` (0-indexed) lies at `phase_offset_ms + n * interval_ms()`.
///
/// A tempo that is zero, negative or NaN describes a degenerate grid: its
/// interval is reported as infinite, so only beat 0 (at `phase_offset_ms`)
/// has a finite timestamp.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BeatGrid {
    /// Tempo in beats per minute.
    pub bpm: f64,
    /// Phase offset of the first beat in milliseconds from the stream start.
    pub phase_offset_ms: f64,
}

impl BeatGrid {
    /// Creates a new beat grid at the given BPM with zero phase offset.
    #[must_use]
    pub fn new(bpm: f64) -> Self {
        Self::with_phase(bpm, 0.0)
    }

    /// Creates a beat grid with an explicit phase offset.
    #[must_use]
    pub fn with_phase(bpm: f64, phase_offset_ms: f64) -> Self {
        Self {
            bpm,
            phase_offset_ms,
        }
    }

    /// Returns `true` when the tempo is strictly positive.
    ///
    /// Written as `bpm > 0.0` (rather than the negation of `bpm <= 0.0`) so
    /// that a NaN tempo is classified as unusable as well.
    fn has_usable_tempo(&self) -> bool {
        self.bpm > 0.0
    }

    /// Returns the interval between consecutive beats in milliseconds.
    ///
    /// Yields `f64::INFINITY` when the tempo is zero, negative or NaN.
    #[must_use]
    pub fn interval_ms(&self) -> f64 {
        if self.has_usable_tempo() {
            60_000.0 / self.bpm
        } else {
            f64::INFINITY
        }
    }

    /// Returns the timestamp (ms from stream start) of the n-th beat (0-indexed).
    ///
    /// Beat 0 is always exactly `phase_offset_ms`. On a degenerate grid every
    /// later beat is at `+INFINITY` (for a finite phase offset).
    #[must_use]
    pub fn beat_time_ms(&self, beat_index: u32) -> f64 {
        if beat_index == 0 {
            // Answer directly: on a degenerate grid `0.0 * INFINITY` would be
            // NaN, although the first beat is well defined by the phase alone.
            return self.phase_offset_ms;
        }
        self.phase_offset_ms + f64::from(beat_index) * self.interval_ms()
    }

    /// Returns the nearest beat index for a given timestamp (ms).
    ///
    /// Timestamps before the first beat map to beat 0, as does any timestamp
    /// on a degenerate grid.
    // The float -> u32 cast is intentional: the value is rounded and clamped
    // to be non-negative first, and `as` saturates at `u32::MAX` (NaN -> 0).
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    #[must_use]
    pub fn nearest_beat(&self, time_ms: f64) -> u32 {
        if !self.has_usable_tempo() {
            return 0;
        }
        let beats_from_phase = (time_ms - self.phase_offset_ms) / self.interval_ms();
        beats_from_phase.round().max(0.0) as u32
    }
}
64
65/// Configuration for the beat-alignment algorithm.
66#[derive(Debug, Clone)]
67pub struct BeatAlignConfig {
68    /// Target beat grid to align to.
69    pub grid: BeatGrid,
70    /// Maximum allowed alignment error before it is considered a non-match (ms).
71    pub tolerance: f64,
72    /// Sample rate of the input audio.
73    pub sample_rate: u32,
74}
75
76impl BeatAlignConfig {
77    /// Creates a new config.
78    #[must_use]
79    pub fn new(grid: BeatGrid, sample_rate: u32) -> Self {
80        Self {
81            grid,
82            tolerance: 20.0,
83            sample_rate,
84        }
85    }
86
87    /// Returns the alignment tolerance in milliseconds.
88    #[must_use]
89    pub fn tolerance_ms(&self) -> f64 {
90        self.tolerance
91    }
92}
93
/// Outcome of one beat-alignment run.
#[derive(Debug, Clone, Copy)]
pub struct BeatAlignResult {
    /// Shift (ms) to apply to the signal so that it lines up with the grid.
    pub offset: f64,
    /// How trustworthy the detected downbeat is, from 0.0 to 1.0.
    pub confidence: f64,
    /// Index of the target-grid beat the detected downbeat was matched to.
    pub matched_beat_index: u32,
}

impl BeatAlignResult {
    /// Returns the alignment offset in milliseconds.
    #[must_use]
    pub fn offset_ms(&self) -> f64 {
        let Self { offset, .. } = *self;
        offset
    }
}
112
113/// Performs beat-grid alignment on an audio signal.
114#[derive(Debug)]
115pub struct BeatAligner {
116    config: BeatAlignConfig,
117}
118
119impl BeatAligner {
120    /// Creates a new aligner with the given configuration.
121    #[must_use]
122    pub fn new(config: BeatAlignConfig) -> Self {
123        Self { config }
124    }
125
126    /// Returns a reference to the current configuration.
127    #[must_use]
128    pub fn config(&self) -> &BeatAlignConfig {
129        &self.config
130    }
131
132    /// Detects the approximate position of the first downbeat in `samples`.
133    ///
134    /// Uses a simple energy-onset heuristic: returns the sample index of the
135    /// frame with the highest short-window RMS energy.
136    ///
137    /// Returns `None` when the signal is empty.
138    #[allow(clippy::cast_precision_loss)]
139    #[must_use]
140    pub fn detect_downbeat(&self, samples: &[f32]) -> Option<usize> {
141        if samples.is_empty() {
142            return None;
143        }
144        let window = (self.config.sample_rate / 100) as usize; // 10 ms window
145        let window = window.max(1);
146        let mut best_idx = 0usize;
147        let mut best_rms = 0.0f64;
148
149        let mut i = 0usize;
150        while i + window <= samples.len() {
151            let rms: f64 = samples[i..i + window]
152                .iter()
153                .map(|&s| f64::from(s) * f64::from(s))
154                .sum::<f64>()
155                / window as f64;
156            if rms > best_rms {
157                best_rms = rms;
158                best_idx = i;
159            }
160            i += window;
161        }
162        Some(best_idx)
163    }
164
165    /// Aligns `samples` to the configured beat grid.
166    ///
167    /// Returns `None` when no reliable downbeat is found or the confidence is
168    /// below an acceptable threshold.
169    #[allow(clippy::cast_precision_loss)]
170    #[must_use]
171    pub fn align_to_grid(&self, samples: &[f32]) -> Option<BeatAlignResult> {
172        let downbeat_sample = self.detect_downbeat(samples)?;
173        let downbeat_ms = (downbeat_sample as f64 / f64::from(self.config.sample_rate)) * 1000.0;
174
175        // Find the grid beat nearest to the detected downbeat.
176        let beat_idx = self.config.grid.nearest_beat(downbeat_ms);
177        let grid_beat_ms = self.config.grid.beat_time_ms(beat_idx);
178        let offset_ms = grid_beat_ms - downbeat_ms;
179
180        // Simple confidence: full confidence when error is zero.
181        let error = offset_ms.abs();
182        let tolerance = self.config.tolerance_ms();
183        let confidence = if error > tolerance {
184            0.0
185        } else {
186            1.0 - error / tolerance
187        };
188
189        if confidence < 0.1 {
190            return None;
191        }
192
193        Some(BeatAlignResult {
194            offset: offset_ms,
195            confidence,
196            matched_beat_index: beat_idx,
197        })
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate (Hz) shared by every fixture in this module.
    const SAMPLE_RATE: u32 = 48_000;

    /// Builds a config on a zero-phase grid at `bpm` and `SAMPLE_RATE`.
    fn make_config(bpm: f64) -> BeatAlignConfig {
        BeatAlignConfig::new(BeatGrid::new(bpm), SAMPLE_RATE)
    }

    /// Asserts that two floats agree to within 1e-9.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_beat_grid_interval_120bpm() {
        assert_close(BeatGrid::new(120.0).interval_ms(), 500.0);
    }

    #[test]
    fn test_beat_grid_interval_60bpm() {
        assert_close(BeatGrid::new(60.0).interval_ms(), 1000.0);
    }

    #[test]
    fn test_beat_grid_interval_zero_bpm() {
        let grid = BeatGrid::new(0.0);
        assert!(grid.interval_ms().is_infinite());
    }

    #[test]
    fn test_beat_grid_beat_time_ms() {
        let grid = BeatGrid::new(120.0); // 500 ms per beat
        assert_close(grid.beat_time_ms(0), 0.0);
        assert_close(grid.beat_time_ms(1), 500.0);
        assert_close(grid.beat_time_ms(4), 2000.0);
    }

    #[test]
    fn test_beat_grid_with_phase() {
        let grid = BeatGrid::with_phase(120.0, 250.0);
        assert_close(grid.beat_time_ms(0), 250.0);
        assert_close(grid.beat_time_ms(1), 750.0);
    }

    #[test]
    fn test_beat_grid_nearest_beat() {
        let grid = BeatGrid::new(120.0); // 500 ms / beat
        assert_eq!(grid.nearest_beat(0.0), 0);
        assert_eq!(grid.nearest_beat(499.0), 1);
        assert_eq!(grid.nearest_beat(1000.0), 2);
    }

    #[test]
    fn test_config_tolerance_ms() {
        assert_close(make_config(120.0).tolerance_ms(), 20.0);
    }

    #[test]
    fn test_beat_align_result_offset_ms() {
        let r = BeatAlignResult {
            offset: 12.5,
            confidence: 0.9,
            matched_beat_index: 3,
        };
        assert_close(r.offset_ms(), 12.5);
    }

    #[test]
    fn test_detect_downbeat_empty() {
        let aligner = BeatAligner::new(make_config(120.0));
        assert!(aligner.detect_downbeat(&[]).is_none());
    }

    #[test]
    fn test_detect_downbeat_finds_loudest_region() {
        let aligner = BeatAligner::new(make_config(120.0));
        // Quiet signal with a loud burst starting at sample 4800.
        let mut samples = vec![0.01f32; 9600];
        for s in &mut samples[4800..5280] {
            *s = 1.0;
        }
        let idx = aligner
            .detect_downbeat(&samples)
            .expect("idx should be valid");
        // Should be somewhere near 4800.
        assert!((4320..=5280).contains(&idx), "unexpected index {}", idx);
    }

    #[test]
    fn test_align_to_grid_empty() {
        let aligner = BeatAligner::new(make_config(120.0));
        assert!(aligner.align_to_grid(&[]).is_none());
    }

    #[test]
    fn test_align_to_grid_returns_result() {
        let aligner = BeatAligner::new(make_config(120.0));
        // All the energy sits in the first 10 ms, i.e. exactly on beat 0 of
        // the zero-phase grid, so the alignment must be a perfect match.
        let mut samples = vec![0.0f32; 48_000];
        for s in &mut samples[0..480] {
            *s = 1.0;
        }
        let result = aligner
            .align_to_grid(&samples)
            .expect("downbeat on beat 0 should align");
        assert_eq!(result.matched_beat_index, 0);
        assert_close(result.offset_ms(), 0.0);
        assert_close(result.confidence, 1.0);
    }

    #[test]
    fn test_aligner_config_accessor() {
        let aligner = BeatAligner::new(make_config(100.0));
        assert_close(aligner.config().grid.bpm, 100.0);
    }
}