audio_samples 1.0.5

A typed audio processing library for Rust that treats audio as a first-class, invariant-preserving object rather than an unstructured numeric buffer.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
//! Beat tracking operations for [`AudioSamples`].
//!
//! This module implements the [`AudioBeatTracking`] trait and provides
//! supporting types ([`BeatTrackingConfig`], [`BeatTrackingData`]) and
//! lower-level helpers ([`onset_strength_envelope`], [`track_beats_core`])
//! for tempo-aware beat detection.
//!
//! Rhythmic analysis — tempo estimation, beat alignment, and
//! synchronisation — is central to music production, DJ software, and
//! audio feature extraction.  Encapsulating the detection pipeline
//! behind a single trait keeps callers isolated from the onset
//! detection internals.
//!
//! Build a [`BeatTrackingConfig`] with the target tempo and onset
//! detection parameters, then call [`AudioBeatTracking::detect_beats`]
//! on any [`AudioSamples`] value.  The lower-level [`track_beats_core`]
//! function is available when you already have a pre-computed onset
//! envelope.
//!
//! ```
//! use audio_samples::operations::beat::track_beats_core;
//! use non_empty_slice::NonEmptyVec;
//! use std::num::NonZeroUsize;
//!
//! let onset = NonEmptyVec::new(vec![0.0f64; 100]).unwrap();
//! let beats = track_beats_core(
//!     &onset,
//!     120.0,
//!     44100.0,
//!     NonZeroUsize::new(512).unwrap(),
//!     None,
//! ).unwrap();
//! assert!(!beats.is_empty());
//! ```

use std::num::NonZeroUsize;

use non_empty_slice::{NonEmptySlice, NonEmptyVec};

use crate::{
    AudioOnsetDetection, AudioSampleError, AudioSampleResult, AudioSamples, ParameterError,
    operations::{onset::OnsetDetectionConfig, traits::AudioBeatTracking},
    traits::StandardSample,
};

/// Results of a beat tracking analysis.
///
/// Holds the tempo and the detected beat timestamps returned
/// by [`AudioBeatTracking::detect_beats`].
///
/// # Tempo
///
/// The [`AudioBeatTracking::detect_beats`] implementation in this module
/// does not re-estimate the tempo from the signal: it copies
/// `config.tempo_bpm` into `tempo_bpm` unchanged.
///
/// # Ordering
///
/// The elements of `beat_times` are **not** sorted chronologically.
/// The first element is the timestamp of the global peak of the onset
/// strength envelope.  Beats found while stepping forward from that
/// peak follow; beats found while stepping backward are appended
/// afterwards, starting with the one closest to the peak.  Sort
/// `beat_times` if you need chronological order.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BeatTrackingData {
    /// Tempo in beats per minute.
    ///
    /// When produced by `detect_beats` in this module this is the
    /// configured target tempo, echoed back from the configuration.
    pub tempo_bpm: f64,
    /// Beat timestamps in seconds, in detection order.
    ///
    /// The first element is the global onset peak; forward beats
    /// follow; backward beats are appended last, nearest-to-peak
    /// first.  See the struct-level documentation.
    pub beat_times: Vec<f64>,
    /// The configuration used to produce this result (stored as a clone).
    pub config: BeatTrackingConfig,
}

impl BeatTrackingData {
    /// Create a `BeatTrackingData` from pre-computed components.
    ///
    /// The arguments are stored as given: no validation, sorting, or
    /// de-duplication of `beat_times` is performed.
    ///
    /// # Arguments
    /// - `tempo_bpm` – Tempo in beats per minute.
    /// - `beat_times` – Beat timestamps in seconds, in detection order.
    /// - `config` – The [`BeatTrackingConfig`] used to produce this
    ///   result.
    ///
    /// # Returns
    /// A new [`BeatTrackingData`] owning the supplied values.
    #[inline]
    #[must_use]
    pub const fn new(tempo_bpm: f64, beat_times: Vec<f64>, config: BeatTrackingConfig) -> Self {
        Self {
            tempo_bpm,
            beat_times,
            config,
        }
    }
}

impl core::fmt::Display for BeatTrackingData {
    /// Write a multi-line, human-readable summary.
    ///
    /// The output is a tempo line (two decimals), a header line, and then
    /// one line per beat timestamp (three decimals) in stored order.
    /// Every line, including the last, is terminated by a newline.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Estimated Tempo: {:.2} BPM", self.tempo_bpm)?;
        writeln!(f, "Detected Beats (s):")?;
        // Stops at, and returns, the first formatting error.
        self.beat_times
            .iter()
            .try_for_each(|time| writeln!(f, "{time:.3}"))
    }
}

/// Configuration for beat detection.
///
/// Controls the target tempo, timing tolerance, and the underlying
/// onset detection pipeline.  Pass to
/// [`AudioBeatTracking::detect_beats`].
///
/// # Invariants
/// - `tempo_bpm` must be positive (> 0).
/// - `tolerance`, when `Some`, should be positive and smaller than
///   the inter-beat interval.
///
/// These invariants are not checked when the configuration is built.
/// `tempo_bpm` is validated by [`track_beats_core`], which returns a
/// parameter error for non-positive values; `tolerance` is converted
/// to a whole number of analysis frames there and clamped to at least
/// one frame.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BeatTrackingConfig {
    /// Target tempo in beats per minute; must be > 0.
    pub tempo_bpm: f64,
    /// Beat timing tolerance in seconds.
    ///
    /// The beat tracker searches for a local onset peak within
    /// ±`tolerance` seconds of each expected beat position.
    /// When `None`, defaults to 10 % of the inter-beat interval.
    pub tolerance: Option<f64>,
    /// Configuration forwarded to the onset detection pipeline.
    ///
    /// Its `hop_size` is also used to convert frame indices into seconds.
    pub onset_config: OnsetDetectionConfig,
}

impl BeatTrackingConfig {
    /// Create a new [`BeatTrackingConfig`].
    ///
    /// The arguments are stored as given; nothing is validated here.
    /// An invalid `tempo_bpm` is only reported once the configuration
    /// is used for beat tracking.
    ///
    /// # Arguments
    /// - `tempo_bpm` – Target tempo in beats per minute; must be > 0.
    /// - `tolerance` – Beat timing tolerance in seconds.  When `None`,
    ///   the tracker defaults to 10 % of the inter-beat interval.
    /// - `onset_config` – Configuration for the onset detection
    ///   pipeline.
    ///
    /// # Returns
    /// A new [`BeatTrackingConfig`].
    #[inline]
    #[must_use]
    pub const fn new(
        tempo_bpm: f64,
        tolerance: Option<f64>,
        onset_config: OnsetDetectionConfig,
    ) -> Self {
        Self {
            tempo_bpm,
            tolerance,
            onset_config,
        }
    }
}

impl<T> AudioBeatTracking for AudioSamples<'_, T>
where
    T: StandardSample,
{
    /// Detect beat positions in the audio signal at the target tempo.
    ///
    /// Computes an onset strength envelope from the signal using the
    /// onset detection configuration in `config`, then locates beat
    /// frames by walking forward and backward from the global onset
    /// peak in steps of one inter-beat interval.
    ///
    /// # Arguments
    /// - `config` – Beat tracking configuration: target tempo,
    ///   optional timing tolerance, and onset detection parameters.
    ///
    /// # Returns
    /// A [`BeatTrackingData`] containing the target tempo and the
    /// detected beat timestamps in seconds.  Beat times are in
    /// detection order (global peak first, then forward beats, then
    /// backward beats in reverse); sort `beat_times` for
    /// chronological order.
    ///
    /// # Errors
    /// - [crate::AudioSampleError::Parameter] – if `config.tempo_bpm` is
    ///   ≤ 0 or if the inter-beat interval is too small relative to
    ///   the hop size.
    ///
    /// # Examples
    /// ```no_run
    /// use audio_samples::{AudioSamples, sample_rate};
    /// use audio_samples::operations::beat::{BeatTrackingConfig, BeatTrackingData};
    /// use audio_samples::operations::traits::AudioBeatTracking;
    ///
    /// # fn example(audio: AudioSamples<'_, f32>, config: BeatTrackingConfig) {
    /// let result = audio.detect_beats(&config).unwrap();
    /// println!("Tempo: {:.1} BPM", result.tempo_bpm);
    /// for &t in &result.beat_times {
    ///     println!("  beat at {:.3} s", t);
    /// }
    /// # }
    /// ```
    fn detect_beats(&self, config: &BeatTrackingConfig) -> AudioSampleResult<BeatTrackingData> {
        let sr = self.sample_rate_hz();

        // Decide channel strategy explicitly
        let onset = onset_strength_envelope(self, &config.onset_config, None)?;

        let beats = track_beats_core(
            &onset,
            config.tempo_bpm,
            sr,
            config.onset_config.hop_size,
            config.tolerance,
        )?;

        Ok(BeatTrackingData {
            tempo_bpm: config.tempo_bpm,
            beat_times: beats,
            config: config.clone(),
        })
    }
}

/// Compute a smoothed, log-compressed onset strength envelope.
///
/// Runs the onset detection function on `audio`, smooths it with a
/// centred moving average, then maps each smoothed value through
/// `ln(1 + μ × x)` to compress the dynamic range.
///
/// # Smoothing window
///
/// `config.window_size` (3 when `None`) is used as the **half-width**
/// `w` of the window: output frame `i` is the mean of input frames
/// `i - w ..= i + w`, i.e. up to `2w + 1` frames.  Near the edges the
/// window is clipped to the signal and the mean is taken over the
/// frames that remain.
///
/// # Arguments
/// - `audio` – Input audio signal.
/// - `config` – Onset detection configuration.
/// - `log_compression` – Compression factor μ in `ln(1 + μ × x)`.
///   Defaults to `0.5` when `None`.  If `μ × x` is ≤ −1 for some frame
///   the result for that frame is not finite.
///
/// # Returns
/// A non-empty vector of onset strength values, one per analysis
/// frame.
///
/// # Errors
/// - Propagates any error from the underlying onset detection
///   function.
pub fn onset_strength_envelope<T>(
    audio: &AudioSamples<'_, T>,
    config: &OnsetDetectionConfig,
    log_compression: Option<f64>,
) -> AudioSampleResult<NonEmptyVec<f64>>
where
    T: StandardSample,
{
    let (_times, odf) = audio.onset_detection_function(config)?;
    let odf = odf.to_vec();
    let n_frames = odf.len();

    let half_width = config.window_size.unwrap_or(crate::nzu!(3)).get();
    let compression = log_compression.unwrap_or(0.5);

    // Smooth and compress in a single pass; no intermediate buffer is needed.
    let env: Vec<f64> = (0..n_frames)
        .map(|i| {
            let start = i.saturating_sub(half_width);
            // Saturating arithmetic keeps an extreme window size from overflowing.
            let end = i
                .saturating_add(half_width)
                .saturating_add(1)
                .min(n_frames);
            let sum: f64 = odf[start..end].iter().fold(0.0, |acc, &x| acc + x);
            let mean = sum / (end - start) as f64;
            (compression * mean).ln_1p()
        })
        .collect();

    // SAFETY: `env` has exactly one element per frame of `odf`, so it is
    // non-empty whenever `odf` is.
    // NOTE(review): this relies on `onset_detection_function` never returning
    // an empty detection function — confirm against its contract.
    let env = unsafe { NonEmptyVec::new_unchecked(env) };
    Ok(env)
}

/// Return the index of the largest value in `slice`.
///
/// Ties resolve to the earliest index.  Values that do not compare
/// greater than negative infinity (NaN, or negative infinity itself)
/// are never selected, so an empty slice, or one containing only such
/// values, yields `0`.
#[inline]
fn peak_index(slice: &[f64]) -> usize {
    let mut best_i = 0usize;
    let mut best_v = f64::NEG_INFINITY;

    // Safe iteration carries no per-element bounds check, so no `unsafe`
    // indexing is needed here.
    for (i, &v) in slice.iter().enumerate() {
        if v > best_v {
            best_v = v;
            best_i = i;
        }
    }

    best_i
}

/// Core beat tracking kernel operating on a pre-computed onset envelope.
///
/// Finds the global peak of `onset`, then walks forward and backward
/// from that peak in steps of one inter-beat interval.  At each
/// expected beat position the frame with the largest onset value
/// inside the tolerance window is selected as the beat.
///
/// The expected position always advances by exactly one inter-beat
/// interval; it is not re-anchored on the frame that was selected.
///
/// Beat times are returned in detection order: the first element is
/// the global onset peak; forward beats follow; backward beats are
/// appended last, nearest-to-peak first.  Sort the result if
/// chronological order is needed.
///
/// # Arguments
/// - `onset` – Pre-computed onset strength envelope, one value per
///   analysis frame.
/// - `tempo_bpm` – Target tempo in beats per minute; must be finite
///   and > 0.
/// - `sample_rate` – Sample rate of the original audio in Hz; must be
///   finite and > 0.
/// - `hop_size` – Number of audio samples per analysis frame.
/// - `tolerance_seconds` – Half-width of the local search window in
///   seconds.  When `None`, defaults to 10 % of the inter-beat
///   interval.  Either way the half-width is rounded to whole frames
///   and is at least 1 frame.  The window covers the frames
///   `center - tol ..= center + tol`, clipped to the envelope.
///
/// # Returns
/// Beat timestamps in seconds, in detection order.
///
/// # Errors
/// - [`crate::AudioSampleError::Parameter`] – if `tempo_bpm` or
///   `sample_rate` is not a finite positive number, or if the
///   inter-beat interval rounds to less than one frame.
///
/// # Examples
/// ```
/// use audio_samples::operations::beat::track_beats_core;
/// use non_empty_slice::NonEmptyVec;
/// use std::num::NonZeroUsize;
///
/// let onset = NonEmptyVec::new(vec![0.0f64; 100]).unwrap();
/// let beats = track_beats_core(
///     &onset,
///     120.0,
///     44100.0,
///     NonZeroUsize::new(512).unwrap(),
///     None,
/// ).unwrap();
/// assert!(!beats.is_empty());
/// ```
#[inline]
pub fn track_beats_core(
    onset: &NonEmptySlice<f64>,
    tempo_bpm: f64,
    sample_rate: f64,
    hop_size: NonZeroUsize,
    tolerance_seconds: Option<f64>,
) -> AudioSampleResult<Vec<f64>> {
    if !tempo_bpm.is_finite() || tempo_bpm <= 0.0 {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "tempo_bpm",
            tempo_bpm,
        )));
    }
    // A non-positive or non-finite sample rate makes the frame duration
    // meaningless; report it under its own name instead of letting it
    // surface as a tempo error (or as arithmetic overflow) further down.
    if !sample_rate.is_finite() || sample_rate <= 0.0 {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "sample_rate",
            sample_rate,
        )));
    }

    // --- Timing model ---
    let hop_time = hop_size.get() as f64 / sample_rate;
    let ibi_seconds = 60.0 / tempo_bpm;

    // The float-to-int cast saturates, so an extremely slow tempo yields
    // `isize::MAX` here; the arithmetic below is written to tolerate that.
    let ibi_frames = (ibi_seconds / hop_time).round() as isize;
    if ibi_frames <= 0 {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "tempo_bpm",
            "Inter-beat interval too small",
        )));
    }

    let tol_frames = tolerance_seconds
        .map_or_else(
            || (ibi_frames as f64 * 0.1).round() as isize,
            |t| (t / hop_time).round() as isize,
        )
        .max(1);

    let n_frames = onset.len().get();
    let len = n_frames as isize;

    // --- Starting peak ---
    let start = peak_index(&onset[0..n_frames]) as isize;

    // Frame with the largest onset value in `center - tol ..= center + tol`,
    // clipped to the envelope.  Callers guarantee `0 <= center < len`, so the
    // window always contains at least `center` itself.
    let local_peak = |center: isize| -> isize {
        let lo = center.saturating_sub(tol_frames).max(0) as usize;
        let hi = center
            .saturating_add(tol_frames)
            .saturating_add(1)
            .min(len) as usize;
        lo as isize + peak_index(&onset[lo..hi]) as isize
    };

    // Upper bound on the beat count: the start peak plus one beat per full
    // inter-beat interval that fits into the envelope.
    let est_beats = (len / ibi_frames).saturating_add(1) as usize;
    let mut beat_frames: Vec<isize> = Vec::with_capacity(est_beats);
    beat_frames.push(start);

    // --- Forward tracking ---
    // `center` advances by exactly `ibi_frames` each step so the loop always
    // terminates.  The bound is written as `center < len - ibi_frames`
    // (equivalent to `center + ibi_frames < len`) so it cannot overflow.
    let mut center = start;
    while center < len - ibi_frames {
        center += ibi_frames;
        beat_frames.push(local_peak(center));
    }

    // --- Backward tracking ---
    let mut center = start;
    while center >= ibi_frames {
        center -= ibi_frames;
        beat_frames.push(local_peak(center));
    }

    // --- Convert to seconds ---
    Ok(beat_frames
        .into_iter()
        .map(|frame| frame as f64 * hop_time)
        .collect())
}

#[cfg(test)]
mod tests {
    use super::*;
    use non_empty_slice::NonEmptyVec;
    use proptest::prelude::*;

    /// Distance, in frames, between consecutive impulses in [`synthetic_onset`].
    const PEAK_SPACING: usize = 20;

    /// Build a deterministic, noise-free envelope of `len` frames: `1.0` at
    /// frame 0 and at every `PEAK_SPACING`-th frame after it, `0.0` elsewhere.
    ///
    /// Panics if `len` is zero, because the envelope must be non-empty.
    fn synthetic_onset(len: usize) -> NonEmptyVec<f64> {
        let mut envelope = vec![0.0; len];
        for peak in envelope.iter_mut().step_by(PEAK_SPACING) {
            *peak = 1.0;
        }
        NonEmptyVec::new(envelope).unwrap()
    }

    proptest! {
        /// Every reported beat time is a finite, non-negative number.
        #[test]
        fn beat_times_are_finite_and_non_negative(
            len in 64usize..2048,
            tempo in 40.0f64..240.0,
            sr in 8_000.0f64..96_000.0,
            hop in 1usize..2048,
        ) {
            let onset = synthetic_onset(len);
            let hop = NonZeroUsize::new(hop).unwrap();
            let beats = track_beats_core(
                &onset,
                tempo,
                sr,
                hop,
                None,
            ).unwrap();

            for &t in &beats {
                prop_assert!(t.is_finite());
                prop_assert!(t >= 0.0);
            }
        }

        /// No beat is reported after the end of the analysed signal.
        #[test]
        fn beat_times_within_signal_bounds(
            len in 64usize..4096,
            tempo in 40.0f64..240.0,
            sr in 8_000.0f64..96_000.0,
            hop in 1usize..1024,
        ) {
            let onset = synthetic_onset(len);
            let duration = (len as f64 * hop as f64) / sr;
            let hop = NonZeroUsize::new(hop).unwrap();
            let beats = track_beats_core(
                &onset,
                tempo,
                sr,
                hop,
                None,
            ).unwrap();

            for &t in &beats {
                prop_assert!(t <= duration + 1e-6);
            }
        }

        /// The first reported beat is the frame holding the global maximum.
        #[test]
        fn first_beat_is_global_peak(
            len in 128usize..2048,
            tempo in 60.0f64..180.0,
            sr in 16_000.0f64..48_000.0,
            hop in 1usize..1024,
        ) {
            let hop = NonZeroUsize::new(hop).unwrap();
            let mut onset = NonEmptyVec::new(vec![0.0; len]).unwrap();
            let peak_idx = len / 3;
            onset[peak_idx] = 10.0;

            let beats = track_beats_core(
                &onset,
                tempo,
                sr,
                hop,
                None,
            ).unwrap();

            // Undo the frame-to-seconds conversion to recover the frame index.
            let first_frame = (beats[0] * sr / hop.get() as f64).round() as usize;
            prop_assert_eq!(first_frame, peak_idx);
        }

        /// After the starting peak, all forward beats precede all backward beats.
        #[test]
        fn insertion_order_preserves_forward_then_backward_structure(
            len in 256usize..4096,
            tempo in 60.0f64..180.0,
            sr in 16_000.0f64..48_000.0,
            hop in 1usize..512,
        ) {
            let onset = synthetic_onset(len);
            let hop = NonZeroUsize::new(hop).unwrap();
            let beats = track_beats_core(
                &onset,
                tempo,
                sr,
                hop,
                None,
            ).unwrap();

            if beats.len() >= 2 {
                // First forward step should move forward in time
                prop_assert!(beats[1] >= beats[0]);
            }

            // Beats after index 0 are either all forward (>= beats[0]) or a run
            // of forward beats followed by a run of backward beats (< beats[0]).
            if beats.len() >= 3 {
                let first = beats[0];
                let mut seen_forward = false;
                let mut seen_backward = false;

                for &t in &beats[1..] {
                    if t >= first {
                        // Forward beat
                        prop_assert!(!seen_backward, "Forward beat found after backward beat");
                        seen_forward = true;
                    } else {
                        // Backward beat
                        seen_backward = true;
                    }
                }

                if seen_forward && seen_backward {
                    // Both kinds are present: check each side of the split point.
                    let first_backward_idx = beats.iter().position(|&t| t < first).unwrap();
                    let (forward, backward) = beats.split_at(first_backward_idx);

                    for (i, &t) in forward.iter().enumerate().skip(1) {
                        prop_assert!(t >= first, "Beat at index {} should be forward", i);
                    }
                    for (offset, &t) in backward.iter().enumerate() {
                        let i = first_backward_idx + offset;
                        prop_assert!(t < first, "Beat at index {} should be backward", i);
                    }
                }
            }
        }
    }
}