oximedia_align/
beat_align.rs1#![allow(dead_code)]
7
/// A regular grid of beats described by a tempo and the time of its first beat.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BeatGrid {
    /// Tempo in beats per minute. A value that is not strictly positive
    /// (including NaN) describes a grid without a usable tempo.
    pub bpm: f64,
    /// Time of beat 0, in milliseconds.
    pub phase_offset_ms: f64,
}

impl BeatGrid {
    /// Creates a grid at `bpm` whose beat 0 falls at 0 ms.
    #[must_use]
    pub fn new(bpm: f64) -> Self {
        Self::with_phase(bpm, 0.0)
    }

    /// Creates a grid at `bpm` whose beat 0 falls at `phase_offset_ms`.
    #[must_use]
    pub fn with_phase(bpm: f64, phase_offset_ms: f64) -> Self {
        Self {
            bpm,
            phase_offset_ms,
        }
    }

    /// Returns the spacing between consecutive beats, in milliseconds.
    ///
    /// A grid without a usable tempo (`bpm <= 0` or NaN) returns
    /// [`f64::INFINITY`]: the next beat never arrives.
    #[must_use]
    pub fn interval_ms(&self) -> f64 {
        // Written as `> 0.0` so that a NaN tempo also takes the infinite branch.
        if self.bpm > 0.0 {
            60_000.0 / self.bpm
        } else {
            f64::INFINITY
        }
    }

    /// Returns the time of beat `beat_index`, in milliseconds.
    ///
    /// Beat 0 is always at `phase_offset_ms`, even on a grid without a usable
    /// tempo (where every later beat lies at infinity).
    #[must_use]
    pub fn beat_time_ms(&self, beat_index: u32) -> f64 {
        if beat_index == 0 {
            // `0.0 * f64::INFINITY` is NaN, so beat 0 must bypass the
            // multiplication when the interval is infinite.
            return self.phase_offset_ms;
        }
        self.phase_offset_ms + f64::from(beat_index) * self.interval_ms()
    }

    /// Returns the index of the beat closest to `time_ms`.
    ///
    /// Times before beat 0, and any time on a grid with `bpm <= 0`, map to
    /// beat 0. Indices beyond `u32::MAX` saturate.
    // The value is rounded and clamped to >= 0 first; float-to-int `as` saturates.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    #[must_use]
    pub fn nearest_beat(&self, time_ms: f64) -> u32 {
        if self.bpm <= 0.0 {
            return 0;
        }
        let beats_from_origin = (time_ms - self.phase_offset_ms) / self.interval_ms();
        beats_from_origin.round().max(0.0) as u32
    }
}
64
65#[derive(Debug, Clone)]
67pub struct BeatAlignConfig {
68 pub grid: BeatGrid,
70 pub tolerance: f64,
72 pub sample_rate: u32,
74}
75
76impl BeatAlignConfig {
77 #[must_use]
79 pub fn new(grid: BeatGrid, sample_rate: u32) -> Self {
80 Self {
81 grid,
82 tolerance: 20.0,
83 sample_rate,
84 }
85 }
86
87 #[must_use]
89 pub fn tolerance_ms(&self) -> f64 {
90 self.tolerance
91 }
92}
93
/// Outcome of aligning a detected downbeat to a beat grid.
#[derive(Debug, Clone, Copy)]
pub struct BeatAlignResult {
    /// Grid beat time minus detected downbeat time, in milliseconds; positive
    /// when the grid beat comes after the downbeat.
    pub offset: f64,
    /// Match quality: 1.0 for an exact hit, decreasing linearly with the
    /// distance to the grid beat.
    pub confidence: f64,
    /// Index of the grid beat the downbeat was matched to.
    pub matched_beat_index: u32,
}

impl BeatAlignResult {
    /// Returns [`BeatAlignResult::offset`], in milliseconds.
    #[must_use]
    pub fn offset_ms(&self) -> f64 {
        let Self { offset, .. } = *self;
        offset
    }
}
112
113#[derive(Debug)]
115pub struct BeatAligner {
116 config: BeatAlignConfig,
117}
118
119impl BeatAligner {
120 #[must_use]
122 pub fn new(config: BeatAlignConfig) -> Self {
123 Self { config }
124 }
125
126 #[must_use]
128 pub fn config(&self) -> &BeatAlignConfig {
129 &self.config
130 }
131
132 #[allow(clippy::cast_precision_loss)]
139 #[must_use]
140 pub fn detect_downbeat(&self, samples: &[f32]) -> Option<usize> {
141 if samples.is_empty() {
142 return None;
143 }
144 let window = (self.config.sample_rate / 100) as usize; let window = window.max(1);
146 let mut best_idx = 0usize;
147 let mut best_rms = 0.0f64;
148
149 let mut i = 0usize;
150 while i + window <= samples.len() {
151 let rms: f64 = samples[i..i + window]
152 .iter()
153 .map(|&s| f64::from(s) * f64::from(s))
154 .sum::<f64>()
155 / window as f64;
156 if rms > best_rms {
157 best_rms = rms;
158 best_idx = i;
159 }
160 i += window;
161 }
162 Some(best_idx)
163 }
164
165 #[allow(clippy::cast_precision_loss)]
170 #[must_use]
171 pub fn align_to_grid(&self, samples: &[f32]) -> Option<BeatAlignResult> {
172 let downbeat_sample = self.detect_downbeat(samples)?;
173 let downbeat_ms = (downbeat_sample as f64 / f64::from(self.config.sample_rate)) * 1000.0;
174
175 let beat_idx = self.config.grid.nearest_beat(downbeat_ms);
177 let grid_beat_ms = self.config.grid.beat_time_ms(beat_idx);
178 let offset_ms = grid_beat_ms - downbeat_ms;
179
180 let error = offset_ms.abs();
182 let tolerance = self.config.tolerance_ms();
183 let confidence = if error > tolerance {
184 0.0
185 } else {
186 1.0 - error / tolerance
187 };
188
189 if confidence < 0.1 {
190 return None;
191 }
192
193 Some(BeatAlignResult {
194 offset: offset_ms,
195 confidence,
196 matched_beat_index: beat_idx,
197 })
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default configuration for a `bpm` grid at 48 kHz.
    fn make_config(bpm: f64) -> BeatAlignConfig {
        BeatAlignConfig::new(BeatGrid::new(bpm), 48_000)
    }

    /// Asserts that two floats agree to within 1e-9.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Returns `total` silent samples with a full-scale burst of `len`
    /// samples starting at `start`.
    fn click_at(start: usize, len: usize, total: usize) -> Vec<f32> {
        let mut samples = vec![0.0f32; total];
        samples[start..start + len].fill(1.0);
        samples
    }

    #[test]
    fn test_beat_grid_interval_120bpm() {
        assert_close(BeatGrid::new(120.0).interval_ms(), 500.0);
    }

    #[test]
    fn test_beat_grid_interval_60bpm() {
        assert_close(BeatGrid::new(60.0).interval_ms(), 1000.0);
    }

    #[test]
    fn test_beat_grid_interval_zero_bpm() {
        assert!(BeatGrid::new(0.0).interval_ms().is_infinite());
    }

    #[test]
    fn test_beat_grid_beat_time_ms() {
        let grid = BeatGrid::new(120.0);
        assert_close(grid.beat_time_ms(0), 0.0);
        assert_close(grid.beat_time_ms(1), 500.0);
        assert_close(grid.beat_time_ms(4), 2000.0);
    }

    #[test]
    fn test_beat_grid_with_phase() {
        let grid = BeatGrid::with_phase(120.0, 250.0);
        assert_close(grid.beat_time_ms(0), 250.0);
        assert_close(grid.beat_time_ms(1), 750.0);
    }

    #[test]
    fn test_beat_grid_nearest_beat() {
        let grid = BeatGrid::new(120.0);
        assert_eq!(grid.nearest_beat(0.0), 0);
        assert_eq!(grid.nearest_beat(499.0), 1);
        assert_eq!(grid.nearest_beat(1000.0), 2);
        // Times before beat 0 clamp to beat 0.
        assert_eq!(grid.nearest_beat(-300.0), 0);
    }

    #[test]
    fn test_config_tolerance_ms() {
        assert_close(make_config(120.0).tolerance_ms(), 20.0);
    }

    #[test]
    fn test_beat_align_result_offset_ms() {
        let r = BeatAlignResult {
            offset: 12.5,
            confidence: 0.9,
            matched_beat_index: 3,
        };
        assert_close(r.offset_ms(), 12.5);
    }

    #[test]
    fn test_detect_downbeat_empty() {
        let aligner = BeatAligner::new(make_config(120.0));
        assert!(aligner.detect_downbeat(&[]).is_none());
    }

    #[test]
    fn test_detect_downbeat_finds_loudest_region() {
        let aligner = BeatAligner::new(make_config(120.0));
        // 48 kHz -> 480-sample (10 ms) windows; 4800 is a window boundary,
        // so the burst fills exactly one window.
        let mut samples = vec![0.01f32; 9600];
        samples[4800..5280].fill(1.0);
        let idx = aligner
            .detect_downbeat(&samples)
            .expect("non-empty input always yields an index");
        assert_eq!(idx, 4800);
    }

    #[test]
    fn test_align_to_grid_empty() {
        let aligner = BeatAligner::new(make_config(120.0));
        assert!(aligner.align_to_grid(&[]).is_none());
    }

    #[test]
    fn test_align_to_grid_returns_result() {
        let aligner = BeatAligner::new(make_config(120.0));
        // Burst at 0 ms lands exactly on beat 0 of a 120 BPM grid.
        let result = aligner
            .align_to_grid(&click_at(0, 480, 48_000))
            .expect("exact hit should align");
        assert_close(result.offset_ms(), 0.0);
        assert_close(result.confidence, 1.0);
        assert_eq!(result.matched_beat_index, 0);
    }

    #[test]
    fn test_align_to_grid_partial_confidence() {
        let aligner = BeatAligner::new(make_config(120.0));
        // Burst at 510 ms: 10 ms after beat 1 (500 ms), half the 20 ms tolerance.
        let result = aligner
            .align_to_grid(&click_at(24_480, 480, 48_000))
            .expect("10 ms miss is within tolerance");
        assert_eq!(result.matched_beat_index, 1);
        assert_close(result.offset_ms(), -10.0);
        assert_close(result.confidence, 0.5);
    }

    #[test]
    fn test_align_to_grid_out_of_tolerance() {
        let aligner = BeatAligner::new(make_config(120.0));
        // Burst at 250 ms sits halfway between beats: 250 ms error >> 20 ms.
        assert!(aligner
            .align_to_grid(&click_at(12_000, 480, 48_000))
            .is_none());
    }

    #[test]
    fn test_aligner_config_accessor() {
        let aligner = BeatAligner::new(make_config(100.0));
        assert_close(aligner.config().grid.bpm, 100.0);
    }
}