Skip to main content

jugar_probar/av_sync/
detection.rs

1//! Audio onset detection via RMS energy analysis.
2//!
3//! Detects percussive audio events (ticks/clicks) in PCM audio streams
4//! using sliding-window RMS energy thresholding.
5
6use super::types::AudioOnset;
7
/// Tunable parameters for the RMS-based onset detector.
///
/// Durations are given in milliseconds; levels are given in decibels.
#[derive(Debug, Clone)]
pub struct DetectionConfig {
    /// Sampling rate of the PCM input, in Hz (default: 48000).
    pub sample_rate: u32,
    /// Length of the sliding RMS analysis window, in milliseconds (default: 10 ms).
    pub window_ms: f64,
    /// Level, in dB, that a window's RMS must reach to count as an onset (default: -40.0).
    pub threshold_db: f64,
    /// Shortest allowed spacing between two reported onsets, in milliseconds (default: 200 ms).
    pub min_gap_ms: f64,
    /// How far back from a threshold crossing the refinement step searches,
    /// in milliseconds (default: 5 ms).
    pub refine_lookback_ms: f64,
}
22
23impl Default for DetectionConfig {
24    fn default() -> Self {
25        Self {
26            sample_rate: 48000,
27            window_ms: 10.0,
28            threshold_db: -40.0,
29            min_gap_ms: 200.0,
30            refine_lookback_ms: 5.0,
31        }
32    }
33}
34
35impl DetectionConfig {
36    /// Create a new detection config with the given sample rate.
37    #[must_use]
38    pub fn with_sample_rate(mut self, sample_rate: u32) -> Self {
39        self.sample_rate = sample_rate;
40        self
41    }
42
43    /// Set the energy threshold in dB.
44    #[must_use]
45    pub fn with_threshold_db(mut self, threshold_db: f64) -> Self {
46        self.threshold_db = threshold_db;
47        self
48    }
49
50    /// Set the minimum gap between onsets.
51    #[must_use]
52    pub fn with_min_gap_ms(mut self, min_gap_ms: f64) -> Self {
53        self.min_gap_ms = min_gap_ms;
54        self
55    }
56
57    /// Window size in samples.
58    fn window_samples(&self) -> usize {
59        ((self.window_ms / 1000.0) * f64::from(self.sample_rate)) as usize
60    }
61
62    /// Minimum gap in samples.
63    fn min_gap_samples(&self) -> usize {
64        ((self.min_gap_ms / 1000.0) * f64::from(self.sample_rate)) as usize
65    }
66
67    /// Lookback size in samples.
68    fn lookback_samples(&self) -> usize {
69        ((self.refine_lookback_ms / 1000.0) * f64::from(self.sample_rate)) as usize
70    }
71}
72
/// Root-mean-square amplitude of `samples`.
///
/// Each sample is widened to `f64` before squaring. An empty slice yields
/// `0.0` instead of dividing by zero.
fn rms_energy(samples: &[f32]) -> f64 {
    let count = samples.len();
    if count == 0 {
        return 0.0;
    }

    let squares = samples.iter().map(|&sample| {
        let wide = f64::from(sample);
        wide * wide
    });
    let mean_square = squares.sum::<f64>() / count as f64;
    mean_square.sqrt()
}
81
/// Convert a linear RMS amplitude to decibels (`20 * log10(rms)`), floored at -120 dB.
///
/// Non-positive input maps to the floor. Positive input whose level would fall
/// below the floor is clamped to it, so the result is monotonic in `rms`
/// (exact silence never reads louder than faint noise). NaN input yields NaN.
fn rms_to_db(rms: f64) -> f64 {
    const FLOOR_DB: f64 = -120.0;

    if rms <= 0.0 {
        return FLOOR_DB;
    }

    let db = 20.0 * rms.log10();
    // Written as a comparison rather than `f64::max` so that NaN still propagates.
    if db < FLOOR_DB {
        FLOOR_DB
    } else {
        db
    }
}
89
90/// Detect audio onsets in PCM samples.
91///
92/// Uses RMS energy windowing with threshold crossing detection.
93/// Returns onsets sorted by time.
94pub fn detect_onsets(samples: &[f32], config: &DetectionConfig) -> Vec<AudioOnset> {
95    let window_size = config.window_samples();
96    let min_gap = config.min_gap_samples();
97    let lookback = config.lookback_samples();
98
99    if samples.len() < window_size || window_size == 0 {
100        return Vec::new();
101    }
102
103    let mut onsets = Vec::new();
104    let mut last_onset_sample: Option<usize> = None;
105    let mut was_below = true;
106
107    // Slide window across the signal
108    let step = window_size / 2; // 50% overlap
109    let step = if step == 0 { 1 } else { step };
110    let mut pos = 0;
111
112    while pos + window_size <= samples.len() {
113        let window = &samples[pos..pos + window_size];
114        let rms = rms_energy(window);
115        let db = rms_to_db(rms);
116
117        if db >= config.threshold_db && was_below {
118            // Threshold crossing detected
119            let onset_sample = if lookback > 0 && pos >= lookback {
120                refine_onset(
121                    samples,
122                    pos,
123                    lookback,
124                    config.threshold_db,
125                    config.sample_rate,
126                )
127            } else {
128                pos
129            };
130
131            // Enforce minimum gap
132            let gap_ok = match last_onset_sample {
133                Some(last) => onset_sample.saturating_sub(last) >= min_gap,
134                None => true,
135            };
136
137            if gap_ok {
138                let time_secs = onset_sample as f64 / f64::from(config.sample_rate);
139                onsets.push(AudioOnset {
140                    time_secs,
141                    energy_db: db,
142                    sample_index: onset_sample,
143                });
144                last_onset_sample = Some(onset_sample);
145            }
146            was_below = false;
147        } else if db < config.threshold_db {
148            was_below = true;
149        }
150
151        pos += step;
152    }
153
154    onsets
155}
156
157/// Refine onset position by looking back to find true start.
158fn refine_onset(
159    samples: &[f32],
160    detected_pos: usize,
161    lookback: usize,
162    threshold_db: f64,
163    sample_rate: u32,
164) -> usize {
165    let start = detected_pos.saturating_sub(lookback);
166    let micro_window = (sample_rate as f64 * 0.002) as usize; // 2ms micro windows
167    let micro_window = if micro_window == 0 { 1 } else { micro_window };
168
169    let mut earliest = detected_pos;
170
171    let mut pos = start;
172    while pos + micro_window <= detected_pos {
173        let window = &samples[pos..pos + micro_window];
174        let rms = rms_energy(window);
175        let db = rms_to_db(rms);
176        if db >= threshold_db {
177            earliest = pos;
178            break;
179        }
180        pos += micro_window;
181    }
182
183    earliest
184}
185
#[cfg(test)]
// Kept from the original; none of the tests below currently call `unwrap`/`expect`.
// NOTE(review): presumably the crate denies these lints outside tests — confirm.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Generate a synthetic PCM signal: silence with a 20 ms, 1 kHz sine burst
    /// (amplitude 0.5) starting at each time in `tick_times` (seconds).
    fn synthetic_signal(sample_rate: u32, duration_secs: f64, tick_times: &[f64]) -> Vec<f32> {
        let total_samples = (duration_secs * f64::from(sample_rate)) as usize;
        let mut samples = vec![0.0f32; total_samples];
        let tick_duration_samples = (0.02 * f64::from(sample_rate)) as usize; // 20ms tick

        for &tick_time in tick_times {
            let start = (tick_time * f64::from(sample_rate)) as usize;
            for i in 0..tick_duration_samples {
                if start + i < total_samples {
                    // Generate a short burst at 0.5 amplitude
                    let phase =
                        (i as f64 / f64::from(sample_rate)) * 1000.0 * std::f64::consts::TAU;
                    samples[start + i] = (phase.sin() * 0.5) as f32;
                }
            }
        }

        samples
    }

    #[test]
    fn test_rms_energy_silence() {
        let silence = vec![0.0f32; 480];
        let rms = rms_energy(&silence);
        assert!(rms < f64::EPSILON);
    }

    #[test]
    fn test_rms_energy_constant() {
        let signal = vec![0.5f32; 480];
        let rms = rms_energy(&signal);
        assert!((rms - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_rms_energy_empty() {
        let empty: Vec<f32> = vec![];
        let rms = rms_energy(&empty);
        assert!(rms < f64::EPSILON);
    }

    #[test]
    fn test_rms_to_db_unity() {
        let db = rms_to_db(1.0);
        assert!(db.abs() < 0.01); // 0 dB
    }

    #[test]
    fn test_rms_to_db_half() {
        let db = rms_to_db(0.5);
        assert!((db - (-6.02)).abs() < 0.1); // ~-6 dB
    }

    #[test]
    fn test_rms_to_db_zero() {
        let db = rms_to_db(0.0);
        assert_eq!(db, -120.0);
    }

    #[test]
    fn test_rms_to_db_negative() {
        let db = rms_to_db(-1.0);
        assert_eq!(db, -120.0);
    }

    #[test]
    fn test_detect_onsets_empty() {
        let config = DetectionConfig::default();
        let onsets = detect_onsets(&[], &config);
        assert!(onsets.is_empty());
    }

    #[test]
    fn test_detect_onsets_silence() {
        let config = DetectionConfig::default();
        let silence = vec![0.0f32; 48000]; // 1 second of silence
        let onsets = detect_onsets(&silence, &config);
        assert!(onsets.is_empty());
    }

    #[test]
    fn test_detect_onsets_single_tick() {
        let config = DetectionConfig::default();
        let signal = synthetic_signal(48000, 2.0, &[1.0]);
        let onsets = detect_onsets(&signal, &config);
        assert_eq!(onsets.len(), 1, "expected exactly 1 onset");
        // Allow 15ms tolerance for detection precision
        assert!(
            (onsets[0].time_secs - 1.0).abs() < 0.015,
            "onset at {:.3}s, expected ~1.0s",
            onsets[0].time_secs
        );
    }

    #[test]
    fn test_detect_onsets_multiple_ticks() {
        let config = DetectionConfig::default();
        let signal = synthetic_signal(48000, 5.0, &[1.0, 2.0, 3.0]);
        let onsets = detect_onsets(&signal, &config);
        assert_eq!(onsets.len(), 3, "expected 3 onsets, got {}", onsets.len());

        for (i, expected_time) in [1.0, 2.0, 3.0].iter().enumerate() {
            assert!(
                (onsets[i].time_secs - expected_time).abs() < 0.015,
                "onset[{}] at {:.3}s, expected ~{:.1}s",
                i,
                onsets[i].time_secs,
                expected_time
            );
        }
    }

    #[test]
    fn test_detect_onsets_minimum_gap_enforcement() {
        // Two ticks 100ms apart should merge (min_gap=200ms)
        let config = DetectionConfig::default();
        let signal = synthetic_signal(48000, 2.0, &[1.0, 1.1]);
        let onsets = detect_onsets(&signal, &config);
        assert_eq!(
            onsets.len(),
            1,
            "ticks 100ms apart should merge, got {} onsets",
            onsets.len()
        );
    }

    #[test]
    fn test_detect_onsets_respects_threshold() {
        // 0 dB is above the ~-9 dB RMS of a 0.5-amplitude sine burst.
        let config = DetectionConfig::default().with_threshold_db(0.0);
        let signal = synthetic_signal(48000, 2.0, &[1.0]);
        let onsets = detect_onsets(&signal, &config);
        assert!(
            onsets.is_empty(),
            "high threshold should reject quiet ticks"
        );
    }

    #[test]
    fn test_detect_onsets_too_short() {
        let config = DetectionConfig::default();
        let short = vec![0.5f32; 10]; // Way too short for a window
        let onsets = detect_onsets(&short, &config);
        assert!(onsets.is_empty());
    }

    #[test]
    fn test_detection_config_default() {
        let config = DetectionConfig::default();
        assert_eq!(config.sample_rate, 48000);
        assert!((config.window_ms - 10.0).abs() < f64::EPSILON);
        assert!((config.threshold_db - (-40.0)).abs() < f64::EPSILON);
        assert!((config.min_gap_ms - 200.0).abs() < f64::EPSILON);
        assert!((config.refine_lookback_ms - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_detection_config_builders() {
        let config = DetectionConfig::default()
            .with_sample_rate(44100)
            .with_threshold_db(-30.0)
            .with_min_gap_ms(100.0);
        assert_eq!(config.sample_rate, 44100);
        assert!((config.threshold_db - (-30.0)).abs() < f64::EPSILON);
        assert!((config.min_gap_ms - 100.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_window_samples_calculation() {
        let config = DetectionConfig::default(); // 48000 Hz, 10ms window
        assert_eq!(config.window_samples(), 480);
    }

    #[test]
    fn test_min_gap_samples_calculation() {
        let config = DetectionConfig::default(); // 48000 Hz, 200ms gap
        assert_eq!(config.min_gap_samples(), 9600);
    }

    #[test]
    fn test_lookback_samples_calculation() {
        let config = DetectionConfig::default(); // 48000 Hz, 5ms look-back
        assert_eq!(config.lookback_samples(), 240);
    }

    #[test]
    fn test_onset_ordering() {
        let config = DetectionConfig::default();
        // Tick times are deliberately given out of order.
        let signal = synthetic_signal(48000, 5.0, &[3.0, 1.0, 2.0]);
        let onsets = detect_onsets(&signal, &config);
        // Guard against a vacuous pass: the ordering loop checks nothing if
        // fewer than two onsets were detected.
        assert_eq!(onsets.len(), 3, "expected 3 onsets, got {}", onsets.len());
        // Onsets should be in chronological order
        for pair in onsets.windows(2) {
            assert!(pair[0].time_secs <= pair[1].time_secs);
        }
    }
}