bpm_analyzer/
lib.rs

1//! # BPM Analyzer
2//!
3//! A real-time BPM (beats per minute) detection library that analyzes audio input
4//! using wavelet decomposition and autocorrelation techniques.
5//!
6//! ## Features
7//!
8//! - Real-time audio capture from system audio devices
9//! - Wavelet-based onset detection using Discrete Wavelet Transform (DWT)
10//! - Multi-band envelope analysis
11//! - Autocorrelation-based tempo estimation
12//! - Configurable BPM range and analysis parameters
13//!
14//! ## Example
15//!
16//! ```no_run
17//! use bpm_analyzer::{AnalyzerConfig, begin};
18//!
//! // Configure the analyzer with a custom 60–180 BPM detection range
20//! let config = AnalyzerConfig::builder()
21//!     .min_bpm(60.0)
22//!     .max_bpm(180.0)
23//!     .build();
24//!
25//! // Start the analyzer and receive BPM candidates
26//! let bpm_receiver = begin(config).expect("Failed to start analyzer");
27//!
28//! // Process BPM candidates
29//! for peaks in bpm_receiver.iter() {
30//!     // Each entry contains (bpm, confidence) pairs
31//!     if let Some((bpm, confidence)) = peaks.first() {
32//!         println!("Detected BPM: {} (confidence: {})", bpm, confidence);
33//!     }
34//! }
35//! ```
36
37use cpal::{
38    BufferSize, FromSample, SampleFormat, SizedSample,
39    traits::{DeviceTrait, HostTrait, StreamTrait},
40};
41use crossbeam_channel::{Sender, bounded};
42use fundsp::prelude32::*;
43use osclet::{BorderMode, DaubechiesFamily, Osclet};
44use resampler::{Attenuation, Latency, ResamplerFir, SampleRate};
45use ringbuffer::{AllocRingBuffer, RingBuffer};
46
47pub mod dsp;
48
/// Size of the cross-thread audio queue
const QUEUE_SIZE: usize = 4096;
/// Number of levels for discrete wavelet transform decomposition
const DWT_LEVELS: usize = 4;
/// Size of the audio capture buffer, passed to cpal as `BufferSize::Fixed`
const AUDIO_BUFFER_SIZE: u32 = 256;
/// Window size for DWT analysis (must be power of 2)
const DWT_WINDOW_SIZE: usize = 65536;
// Enforce the documented power-of-two requirement at compile time.
const _: () = assert!(DWT_WINDOW_SIZE.is_power_of_two());
/// Target sampling rate for analysis (Hz)
const TARGET_SAMPLING_RATE: f64 = 22050.0;

/// Default minimum BPM for detection range
const MIN_BPM: f32 = 40.0;
/// Default maximum BPM for detection range
const MAX_BPM: f32 = 240.0;
65
66/// Internal buffer wrapper for processing audio chunks of varying sizes.
67///
68/// This enum handles both full-sized buffers (aligned for SIMD processing)
69/// and partial buffers (for the last chunk that may be smaller than the max buffer size).
70#[allow(clippy::large_enum_variant)]
71enum TransientBuffer<'a> {
72    /// A full buffer that can be processed using SIMD operations
73    Full(BufferRef<'a>),
74    /// A partial buffer for the remaining samples
75    Partial {
76        /// The buffer containing the samples
77        buffer: BufferArray<U1>,
78        /// The actual number of valid samples in the buffer
79        length: usize,
80    },
81}
82
83/// Errors that can occur during BPM analysis
84#[derive(Debug, thiserror::Error)]
85pub enum Error {
86    #[error("CPAL error: {0}")]
87    DevicesError(#[from] cpal::DevicesError),
88    #[error("Unsupported sample format: {0}")]
89    UnsupportedSampleFormat(SampleFormat),
90    #[error("Failed to get device name: {0}")]
91    DeviceNameError(#[from] cpal::DeviceNameError),
92    #[error("No input device found")]
93    NoDeviceFound,
94    #[error("Failed to get stream configuration: {0}")]
95    StreamConfigError(#[from] cpal::DefaultStreamConfigError),
96    #[error("Resampling error: {0}")]
97    ResampleError(resampler::ResampleError),
98    #[error("Osclet error: {0}")]
99    Osclet(#[from] osclet::OscletError),
100}
101
102pub type Result<T, E = Error> = std::result::Result<T, E>;
103
104pub use crossbeam_channel::Receiver;
105
106/// Configuration for the BPM analyzer.
107///
108/// Use the builder pattern to customize analyzer parameters:
109///
110/// # Example
111///
112/// ```
113/// use bpm_analyzer::AnalyzerConfig;
114///
115/// let config = AnalyzerConfig::builder()
116///     .min_bpm(60.0)
117///     .max_bpm(180.0)
118///     .window_size(32768)
119///     .build();
120/// ```
121#[derive(Clone, Debug, Copy, bon::Builder)]
122pub struct AnalyzerConfig {
123    /// Minimum BPM to detect (default: 40.0)
124    #[builder(default = MIN_BPM)]
125    min_bpm: f32,
126    /// Maximum BPM to detect (default: 240.0)
127    #[builder(default = MAX_BPM)]
128    max_bpm: f32,
129    /// Size of the analysis window in samples (default: 65536)
130    #[builder(default = DWT_WINDOW_SIZE)]
131    window_size: usize,
132    /// Size of the audio queue between threads (default: 4096)
133    #[builder(default = QUEUE_SIZE)]
134    queue_size: usize,
135    /// Size of the audio capture buffer (default: 256)
136    #[builder(default = AUDIO_BUFFER_SIZE)]
137    buffer_size: u32,
138}
139
140/// Starts the BPM analyzer with the given configuration.
141///
142/// This function initializes the audio input stream and spawns a background thread
143/// that performs real-time BPM analysis. It returns a receiver that yields arrays of
144/// BPM candidates with their confidence values.
145///
146/// # Arguments
147///
148/// * `config` - The analyzer configuration
149///
150/// # Returns
151///
152/// A `Receiver` that yields arrays of 5 `(bpm, confidence)` tuples, where:
153/// - `bpm`: The detected tempo in beats per minute
154/// - `confidence`: The autocorrelation value indicating detection confidence
155///
156/// The results are sorted by confidence (highest first).
157///
158/// # Errors
159///
160/// Returns an error if:
161/// - No suitable audio input device is found
162/// - The audio stream cannot be configured or started
163/// - The device's sample format is unsupported
164///
165/// # Example
166///
167/// ```no_run
168/// use bpm_analyzer::{AnalyzerConfig, begin};
169///
170/// let config = AnalyzerConfig::builder().build();
171/// let receiver = begin(config)?;
172///
173/// for peaks in receiver.iter() {
174///     println!("Top BPM candidate: {:.1}", peaks[0].0);
175/// }
176/// # Ok::<(), bpm_analyzer::Error>(())
177/// ```
178pub fn begin(config: AnalyzerConfig) -> Result<crossbeam_channel::Receiver<[(f32, f32); 5]>> {
179    let host = cpal::default_host();
180
181    let loopback_device = host
182        .input_devices()?
183        .find_map(|device| match device.description() {
184            #[cfg(target_os = "macos")]
185            Ok(desc) if desc.name().contains("BlackHole") => Some(Ok(device)),
186            Err(e) => Some(Err(e)),
187            Ok(_) => None,
188        })
189        .transpose()?
190        .or_else(|| host.default_input_device())
191        .ok_or(Error::NoDeviceFound)?;
192
193    let device_name = loopback_device.description()?.name().to_string();
194
195    tracing::info!("Using audio device: {}", device_name);
196
197    let supported_config = loopback_device.default_input_config()?;
198
199    let mut stream_config = supported_config.config();
200
201    stream_config.buffer_size = BufferSize::Fixed(config.buffer_size);
202
203    let sample_rate = stream_config.sample_rate as f64;
204
205    tracing::info!(
206        "Sampling with {:?} Hz on {} channels",
207        stream_config.sample_rate,
208        stream_config.channels
209    );
210
211    let (audio_sender, audio_receiver) = bounded(config.queue_size);
212    let (bmp_sender, bpm_receiver) = bounded(config.queue_size);
213
214    match supported_config.sample_format() {
215        SampleFormat::F32 => run::<f32>(&loopback_device, &stream_config, audio_sender)?,
216        SampleFormat::I16 => run::<i16>(&loopback_device, &stream_config, audio_sender)?,
217        SampleFormat::U16 => run::<u16>(&loopback_device, &stream_config, audio_sender)?,
218        other => {
219            return Err(Error::UnsupportedSampleFormat(other));
220        }
221    }
222
223    std::thread::spawn(move || run_analysis(sample_rate, audio_receiver, bmp_sender, config));
224
225    Ok(bpm_receiver)
226}
227
228impl<'a> TransientBuffer<'a> {
229    /// Processes this buffer through an audio processing node.
230    ///
231    /// # Returns
232    ///
233    /// A tuple containing the processed buffer and the number of valid samples.
234    fn process<N: AudioUnit>(&'a self, node: &mut N) -> (BufferArray<U1>, usize) {
235        match self {
236            TransientBuffer::Full(buffer_ref) => {
237                let mut buffer = BufferArray::<U1>::new();
238
239                node.process(MAX_BUFFER_SIZE, buffer_ref, &mut buffer.buffer_mut());
240
241                (buffer, MAX_BUFFER_SIZE)
242            }
243            TransientBuffer::Partial { buffer, length } => {
244                let mut output_buffer = BufferArray::<U1>::new();
245
246                node.process(
247                    *length,
248                    &buffer.buffer_ref(),
249                    &mut output_buffer.buffer_mut(),
250                );
251
252                (output_buffer, *length)
253            }
254        }
255    }
256}
257
258/// Main analysis loop that processes audio samples and detects BPM.
259///
260/// This function:
261/// 1. Resamples audio to 22.05 kHz
262/// 2. Accumulates samples in a ring buffer
263/// 3. Performs multi-level discrete wavelet transform
264/// 4. Extracts onset envelopes from each frequency band
265/// 5. Computes autocorrelation to find periodic patterns
266/// 6. Identifies BPM candidates based on peak autocorrelation values
267///
268/// # Arguments
269///
270/// * `sample_rate` - Original audio sample rate
271/// * `audio_receiver` - Channel receiving stereo audio samples
272/// * `bpm_sender` - Channel to send BPM detection results
273/// * `config` - Analyzer configuration
274fn run_analysis(
275    sample_rate: f64,
276    audio_receiver: Receiver<(f32, f32)>,
277    bpm_sender: Sender<[(f32, f32); 5]>,
278    config: AnalyzerConfig,
279) -> Result<()> {
280    let now = std::time::Instant::now();
281
282    let dwt_executor = Osclet::make_daubechies_f32(DaubechiesFamily::Db4, BorderMode::Wrap);
283
284    let mut resampler = ResamplerFir::new(
285        1,
286        SampleRate::Hz48000,
287        SampleRate::Hz22050,
288        Latency::Sample64,
289        Attenuation::Db90,
290    );
291
292    tracing::info!("Resampling buffer: {}", resampler.buffer_size_output());
293
294    let resampling_factor = TARGET_SAMPLING_RATE / sample_rate;
295    tracing::info!(
296        "Resampling factor: {}, every {}th sample",
297        resampling_factor,
298        (sample_rate / TARGET_SAMPLING_RATE).round()
299    );
300
301    let mut ring_buffer = AllocRingBuffer::<f32>::new(config.window_size);
302
303    let once = std::sync::Once::new();
304
305    let mut filter_chain = dsp::alpha_lpf(0.99f32) >> dsp::fwr::<f32>();
306
307    let mut resampled_output = vec![0.0f32; resampler.buffer_size_output()];
308
309    loop {
310        // Read all available audio samples from the channel
311        let input = audio_receiver
312            .try_iter()
313            // Mix to mono
314            .map(|(l, r)| (l + r) * 0.5)
315            .collect::<Vec<_>>();
316
317        let mut input = &input[..];
318
319        while !input.is_empty() {
320            let (consumed, produced) = resampler
321                .resample(input, &mut resampled_output)
322                .map_err(Error::ResampleError)?;
323            ring_buffer.extend(resampled_output[..produced].iter().copied());
324
325            input = &input[consumed..];
326        }
327
328        if ring_buffer.is_full() {
329            once.call_once(|| {
330                let time = now.elapsed();
331                tracing::info!(
332                    "Initial audio buffer filled with {} samples in {:.2?}",
333                    ring_buffer.len(),
334                    time
335                );
336            });
337
338            let signal = ring_buffer.to_vec();
339
340            let dwt = dwt_executor.multi_dwt(&signal, DWT_LEVELS)?;
341
342            let bands = dwt
343                .levels
344                .into_iter()
345                .enumerate()
346                .map(|(i, band)| {
347                    filter_chain.reset();
348
349                    let band = band
350                        .approximations
351                        .chunks(MAX_BUFFER_SIZE)
352                        .map(|chunk| {
353                            if chunk.len() == MAX_BUFFER_SIZE {
354                                let buffer = unsafe {
355                                    std::slice::from_raw_parts::<'_, F32x>(
356                                        chunk.as_ptr() as *const _,
357                                        MAX_BUFFER_SIZE / SIMD_LEN,
358                                    )
359                                };
360
361                                TransientBuffer::Full(BufferRef::new(buffer))
362                            } else {
363                                let mut buffer = BufferArray::<U1>::new();
364                                buffer.channel_f32_mut(0)[..chunk.len()].copy_from_slice(chunk);
365
366                                TransientBuffer::Partial {
367                                    buffer,
368                                    length: chunk.len(),
369                                }
370                            }
371                        })
372                        .map(|buffer| buffer.process(&mut filter_chain))
373                        .flat_map(|(mut band, length)| {
374                            band.channel_f32(0)
375                                .iter()
376                                .copied()
377                                .take(length)
378                                .collect::<Vec<_>>()
379                        })
380                        .collect::<Vec<_>>();
381
382                    let downsampling_factor = 1 << (3 - i);
383
384                    let mut band = (0..band.len())
385                        .step_by(downsampling_factor)
386                        .map(|i| band[i])
387                        .collect::<Vec<_>>();
388
389                    let mean = band.iter().copied().sum::<f32>() / band.len() as f32;
390
391                    // Normalization in each band (mean removal)
392                    band.iter_mut().for_each(|sample| *sample -= mean);
393
394                    band.resize(4096, 0.0);
395
396                    band
397                })
398                .collect::<Vec<_>>();
399
400            let summed_bands = (0..4096)
401                .map(|i| bands.iter().map(|band| band[i]).sum::<f32>())
402                .collect::<Vec<_>>();
403
404            let min_lag = ((4096.0 / 3.0) * 60.0 / config.max_bpm) as usize;
405            let max_lag = ((4096.0 / 3.0) * 60.0 / config.min_bpm) as usize;
406
407            let ac = autocorrelation(&summed_bands, max_lag);
408
409            let mut peaks = ac
410                .into_iter()
411                .enumerate()
412                .skip(min_lag)
413                .take(max_lag - min_lag)
414                .collect::<Vec<_>>();
415            peaks.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
416
417            let peaks = &peaks[1..6];
418
419            let peaks = peaks
420                .iter()
421                .copied()
422                .map(|(lag, v)| ((60.0 * (4096.0 / 3.0)) / (lag as f32), v))
423                .collect::<Vec<_>>();
424
425            if let Ok(()) = bpm_sender.try_send(peaks.try_into().unwrap()) {}
426        }
427    }
428}
429
430/// Starts the audio input stream on the given device with the given configuration,
431/// sending audio samples to the provided channel sender.
432fn run<T>(
433    device: &cpal::Device,
434    config: &cpal::StreamConfig,
435    sender: Sender<(f32, f32)>,
436) -> Result<()>
437where
438    T: SizedSample,
439    f32: FromSample<T>,
440{
441    let channels = config.channels as usize;
442    let err_fn = |err| tracing::error!("an error occurred on stream: {}", err);
443    let stream = device.build_input_stream(
444        config,
445        move |data: &[T], _: &cpal::InputCallbackInfo| read_data(data, channels, sender.clone()),
446        err_fn,
447        None,
448    );
449    if let Ok(stream) = stream
450        && let Ok(()) = stream.play()
451    {
452        std::mem::forget(stream);
453    }
454
455    tracing::info!("Input stream built.");
456
457    Ok(())
458}
459
460/// Callback function to read audio data from the input device.
461/// and sends it to the provided channel sender.
462fn read_data<T>(input: &[T], channels: usize, sender: Sender<(f32, f32)>)
463where
464    T: SizedSample,
465    f32: FromSample<T>,
466{
467    for frame in input.chunks(channels) {
468        let mut left = 0.0;
469        let mut right = 0.0;
470        for (channel, sample) in frame.iter().enumerate() {
471            if channel & 1 == 0 {
472                left = sample.to_sample::<f32>();
473            } else {
474                right = sample.to_sample::<f32>();
475            }
476        }
477
478        if let Ok(()) = sender.try_send((left, right)) {}
479    }
480}
481
/// Computes the autocorrelation of `signal` for every lag in `[0, max_lag)`.
///
/// Each coefficient is the sum of `signal[i] * signal[i + lag]` over all valid `i`,
/// divided by the full signal length (the biased estimate). `max_lag` is clamped to
/// `signal.len()`, so the result holds `min(max_lag, signal.len())` values.
///
/// * `signal` - input samples (e.g. summed band envelopes)
/// * `max_lag` - exclusive upper bound of the lags to compute
pub fn autocorrelation(signal: &[f32], max_lag: usize) -> Vec<f32> {
    let len = signal.len();
    let scale = len as f32;

    (0..max_lag.min(len))
        .map(|lag| {
            let dot = signal
                .iter()
                .zip(&signal[lag..])
                .fold(0.0f32, |acc, (a, b)| acc + a * b);
            dot / scale
        })
        .collect()
}