bpm_analyzer/
analyzer.rs

1//! Core BPM analysis functionality.
2
3use cpal::{
4    BufferSize, FromSample, SampleFormat, SizedSample,
5    traits::{DeviceTrait, HostTrait, StreamTrait},
6};
7use crossbeam_channel::{Receiver, Sender, bounded};
8use fundsp::prelude32::*;
9use osclet::{BorderMode, DaubechiesFamily, Osclet};
10use resampler::{Attenuation, Latency, ResamplerFir, SampleRate};
11use ringbuffer::{AllocRingBuffer, RingBuffer};
12
13use crate::{
14    config::{AnalyzerConfig, DWT_LEVELS, TARGET_SAMPLING_RATE},
15    dsp,
16    error::{Error, Result},
17    types::{BeatTiming, BpmDetection},
18};
19
20/// Internal buffer wrapper for processing audio chunks of varying sizes.
21///
22/// This enum handles both full-sized buffers (aligned for SIMD processing)
23/// and partial buffers (for the last chunk that may be smaller than the max buffer size).
24#[allow(clippy::large_enum_variant)]
25enum TransientBuffer<'a> {
26    /// A full buffer that can be processed using SIMD operations
27    Full(BufferRef<'a>),
28    /// A partial buffer for the remaining samples
29    Partial {
30        /// The buffer containing the samples
31        buffer: BufferArray<U1>,
32        /// The actual number of valid samples in the buffer
33        length: usize,
34    },
35}
36
37impl<'a> TransientBuffer<'a> {
38    /// Processes this buffer through an audio processing node.
39    ///
40    /// # Returns
41    ///
42    /// A tuple containing the processed buffer and the number of valid samples.
43    fn process<N: AudioUnit>(&'a self, node: &mut N) -> (BufferArray<U1>, usize) {
44        match self {
45            TransientBuffer::Full(buffer_ref) => {
46                let mut buffer = BufferArray::<U1>::new();
47
48                node.process(MAX_BUFFER_SIZE, buffer_ref, &mut buffer.buffer_mut());
49
50                (buffer, MAX_BUFFER_SIZE)
51            }
52            TransientBuffer::Partial { buffer, length } => {
53                let mut output_buffer = BufferArray::<U1>::new();
54
55                node.process(
56                    *length,
57                    &buffer.buffer_ref(),
58                    &mut output_buffer.buffer_mut(),
59                );
60
61                (output_buffer, *length)
62            }
63        }
64    }
65}
66
67/// Starts the BPM analyzer with the given configuration using the default device.
68///
69/// This automatically selects an audio input device, preferring BlackHole on macOS,
70/// then falling back to the system default input device.
71///
72/// # Arguments
73///
74/// * `config` - The analyzer configuration
75///
76/// # Returns
77///
78/// A `Receiver` that yields `BpmDetection` results containing up to 5 BPM candidates
79/// sorted by confidence (highest first).
80///
81/// # Errors
82///
83/// Returns an error if:
84/// - The configuration is invalid
85/// - No suitable audio input device is found
86/// - The audio stream cannot be configured or started
87/// - The device's sample format is unsupported
88///
89/// # Example
90///
91/// ```no_run
92/// use bpm_analyzer::{AnalyzerConfig, begin};
93///
94/// let config = AnalyzerConfig::electronic();
95/// let receiver = begin(config)?;
96///
97/// for detection in receiver.iter() {
98///     if let Some(bpm) = detection.bpm() {
99///         println!("Detected BPM: {:.1}", bpm);
100///     }
101/// }
102/// # Ok::<(), bpm_analyzer::Error>(())
103/// ```
104pub fn begin(config: AnalyzerConfig) -> Result<Receiver<BpmDetection>> {
105    // Validate configuration
106    config.validate()?;
107    let host = cpal::default_host();
108
109    let device = host
110        .input_devices()?
111        .find_map(|device| match device.description() {
112            #[cfg(target_os = "macos")]
113            Ok(desc) if desc.name().contains("BlackHole") => Some(Ok(device)),
114            Err(e) => Some(Err(e)),
115            Ok(_) => None,
116        })
117        .transpose()?
118        .or_else(|| host.default_input_device())
119        .ok_or(Error::NoDeviceFound)?;
120
121    begin_with_device(config, &device)
122}
123
124/// Starts the BPM analyzer with the given configuration and a specific audio device.
125///
126/// Use this function when you need to specify which audio device to use.
127/// To list available devices, use [`list_audio_devices`](crate::list_audio_devices).
128/// To get a device by name, use [`get_device_by_name`](crate::get_device_by_name).
129///
130/// # Arguments
131///
132/// * `config` - The analyzer configuration
133/// * `device` - The audio input device to use
134///
135/// # Returns
136///
137/// A `Receiver` that yields `BpmDetection` results containing up to 5 BPM candidates
138/// sorted by confidence (highest first).
139///
140/// # Errors
141///
142/// Returns an error if:
143/// - The configuration is invalid
144/// - The audio stream cannot be configured or started
145/// - The device's sample format is unsupported
146///
147/// # Example
148///
149/// ```no_run
150/// use bpm_analyzer::{AnalyzerConfig, begin_with_device, get_device_by_name};
151///
152/// let config = AnalyzerConfig::electronic();
153/// let device = get_device_by_name("BlackHole 2ch")?;
154/// let receiver = begin_with_device(config, &device)?;
155///
156/// for detection in receiver.iter() {
157///     if let Some(bpm) = detection.bpm() {
158///         println!("Detected BPM: {:.1}", bpm);
159///     }
160/// }
161/// # Ok::<(), bpm_analyzer::Error>(())
162/// ```
163pub fn begin_with_device(
164    config: AnalyzerConfig,
165    device: &cpal::Device,
166) -> Result<Receiver<BpmDetection>> {
167    // Validate configuration
168    config.validate()?;
169
170    let device_name = device.description()?.name().to_string();
171
172    tracing::info!("Using audio device: {}", device_name);
173
174    let supported_config = device.default_input_config()?;
175
176    let mut stream_config = supported_config.config();
177
178    stream_config.buffer_size = BufferSize::Fixed(config.buffer_size());
179
180    let sample_rate = stream_config.sample_rate as f64;
181
182    tracing::info!(
183        "Sampling with {:?} Hz on {} channels",
184        stream_config.sample_rate,
185        stream_config.channels
186    );
187
188    let (audio_sender, audio_receiver) = bounded(config.queue_size());
189    let (bpm_sender, bpm_receiver) = bounded(config.queue_size());
190
191    match supported_config.sample_format() {
192        SampleFormat::F32 => run::<f32>(device, &stream_config, audio_sender)?,
193        SampleFormat::I16 => run::<i16>(device, &stream_config, audio_sender)?,
194        SampleFormat::U16 => run::<u16>(device, &stream_config, audio_sender)?,
195        other => {
196            return Err(Error::UnsupportedSampleFormat(other));
197        }
198    }
199
200    std::thread::spawn(move || run_analysis(sample_rate, audio_receiver, bpm_sender, config));
201
202    Ok(bpm_receiver)
203}
204
205/// Main analysis loop that processes audio samples and detects BPM.
206///
207/// This function:
208/// 1. Resamples audio to 22.05 kHz
209/// 2. Accumulates samples in a ring buffer
210/// 3. Performs multi-level discrete wavelet transform
211/// 4. Extracts onset envelopes from each frequency band
212/// 5. Computes autocorrelation to find periodic patterns
213/// 6. Identifies BPM candidates based on peak autocorrelation values
214///
215/// # Arguments
216///
217/// * `sample_rate` - Original audio sample rate
218/// * `audio_receiver` - Channel receiving stereo audio samples
219/// * `bpm_sender` - Channel to send BPM detection results
220/// * `config` - Analyzer configuration
221fn run_analysis(
222    sample_rate: f64,
223    audio_receiver: Receiver<(f32, f32)>,
224    bpm_sender: Sender<BpmDetection>,
225    config: AnalyzerConfig,
226) -> Result<()> {
227    let now = std::time::Instant::now();
228
229    let dwt_executor = Osclet::make_daubechies_f32(DaubechiesFamily::Db4, BorderMode::Wrap);
230
231    // Create resampler based on actual sample rate
232    let input_sample_rate = match sample_rate as u32 {
233        16000 => SampleRate::Hz16000,
234        22050 => SampleRate::Hz22050,
235        32000 => SampleRate::Hz32000,
236        44100 => SampleRate::Hz44100,
237        48000 => SampleRate::Hz48000,
238        88200 => SampleRate::Hz88200,
239        96000 => SampleRate::Hz96000,
240        176400 => SampleRate::Hz176400,
241        192000 => SampleRate::Hz192000,
242        _ => return Err(Error::UnsupportedSampleRate(sample_rate as u32)),
243    };
244
245    let mut resampler = ResamplerFir::new(
246        1,
247        input_sample_rate,
248        SampleRate::Hz22050,
249        Latency::Sample64,
250        Attenuation::Db90,
251    );
252
253    tracing::info!("Resampling buffer: {}", resampler.buffer_size_output());
254
255    let resampling_factor = TARGET_SAMPLING_RATE / sample_rate;
256    let window_length = config.window_size() as f64 / TARGET_SAMPLING_RATE;
257
258    tracing::info!(
259        "Analysis window: {} samples ({:.2} seconds)",
260        config.window_size(),
261        window_length
262    );
263
264    tracing::info!(
265        "Resampling factor: {}, every {}th sample",
266        resampling_factor,
267        (sample_rate / TARGET_SAMPLING_RATE).round()
268    );
269
270    let mut ring_buffer = AllocRingBuffer::<f32>::new(config.window_size());
271
272    let once = std::sync::Once::new();
273
274    let mut filter_chain = dsp::alpha_lpf(0.99f32) >> dsp::fwr::<f32>();
275
276    let mut resampled_output = vec![0.0f32; resampler.buffer_size_output()];
277
278    // Pre-allocate buffers to reduce allocations in hot loop
279    let mut input_buffer = Vec::with_capacity(4096);
280    let mut signal = vec![0.0f32; config.window_size()];
281    let mut bands = vec![vec![0.0f32; 4096]; DWT_LEVELS];
282    let mut summed_bands = vec![0.0f32; 4096];
283    let mut peaks_buffer = Vec::with_capacity(1024);
284    
285    // Beat detection state
286    let mut beat_timings: Vec<BeatTiming> = Vec::with_capacity(8);
287    let mut prev_summed_bands = vec![0.0f32; 4096];
288    let mut samples_processed = 0usize;
289    let mut current_bpm: Option<f32> = None;
290
291    loop {
292        // Read all available audio samples from the channel
293        input_buffer.clear();
294        input_buffer.extend(
295            audio_receiver
296                .try_iter()
297                // Mix to mono
298                .map(|(l, r)| (l + r) * 0.5),
299        );
300
301        let mut input_slice = &input_buffer[..];
302
303        while !input_slice.is_empty() {
304            let (consumed, produced) = resampler
305                .resample(input_slice, &mut resampled_output)
306                .map_err(Error::ResampleError)?;
307            ring_buffer.extend(resampled_output[..produced].iter().copied());
308            samples_processed += produced;
309
310            input_slice = &input_slice[consumed..];
311        }
312
313        if ring_buffer.is_full() {
314            once.call_once(|| {
315                let time = now.elapsed();
316                tracing::info!(
317                    "Initial audio buffer filled with {} samples in {:.2?}",
318                    ring_buffer.len(),
319                    time
320                );
321            });
322
323            // Copy ring buffer into pre-allocated signal buffer
324            signal = ring_buffer.to_vec();
325
326            let dwt = dwt_executor.multi_dwt(&signal, DWT_LEVELS)?;
327
328            // Process each band and store in pre-allocated buffers
329            for (band_idx, level) in dwt.levels.into_iter().enumerate() {
330                filter_chain.reset();
331
332                // Process approximations through filter chain
333                let mut processed_samples = Vec::with_capacity(level.approximations.len());
334
335                for chunk in level.approximations.chunks(MAX_BUFFER_SIZE) {
336                    let transient_buffer = if chunk.len() == MAX_BUFFER_SIZE {
337                        let buffer = unsafe {
338                            std::slice::from_raw_parts::<'_, F32x>(
339                                chunk.as_ptr() as *const _,
340                                MAX_BUFFER_SIZE / SIMD_LEN,
341                            )
342                        };
343                        TransientBuffer::Full(BufferRef::new(buffer))
344                    } else {
345                        let mut buffer = BufferArray::<U1>::new();
346                        buffer.channel_f32_mut(0)[..chunk.len()].copy_from_slice(chunk);
347                        TransientBuffer::Partial {
348                            buffer,
349                            length: chunk.len(),
350                        }
351                    };
352
353                    let (mut output, length) = transient_buffer.process(&mut filter_chain);
354                    processed_samples.extend_from_slice(&output.channel_f32(0)[..length]);
355                }
356
357                let downsampling_factor = 1 << (DWT_LEVELS - 1 - band_idx);
358
359                // Downsample and store in pre-allocated band buffer
360                let band_buffer = &mut bands[band_idx];
361                band_buffer.fill(0.0);
362
363                let downsampled_len = processed_samples.len() / downsampling_factor;
364                let samples_to_copy = std::cmp::min(downsampled_len, 4096);
365
366                for (i, sample_idx) in (0..processed_samples.len())
367                    .step_by(downsampling_factor)
368                    .take(samples_to_copy)
369                    .enumerate()
370                {
371                    band_buffer[i] = processed_samples[sample_idx];
372                }
373
374                // Normalization (mean removal)
375                if samples_to_copy > 0 {
376                    let mean: f32 =
377                        band_buffer[..samples_to_copy].iter().sum::<f32>() / samples_to_copy as f32;
378                    band_buffer[..samples_to_copy]
379                        .iter_mut()
380                        .for_each(|sample| *sample -= mean);
381                }
382            }
383
384            // Sum bands into pre-allocated buffer with frequency weighting
385            // Weight higher frequencies more for better beat detection across the spectrum
386            // Band 0 (highest): 2.0x weight for high-frequency transients (hi-hats, cymbals)
387            // Band 1 (high):    1.5x weight for mid-high transients (snares)
388            // Band 2 (mid):     1.0x weight for mid-range (claps, vocals)
389            // Band 3 (low):     0.5x weight for low-end (kicks) - reduced to balance energy
390            summed_bands.fill(0.0);
391            for i in 0..4096 {
392                summed_bands[i] = bands[0][i] * 2.0  // Highest frequencies
393                                + bands[1][i] * 1.5  // High frequencies
394                                + bands[2][i] * 1.0  // Mid frequencies
395                                + bands[3][i] * 0.5; // Low frequencies
396            }
397            
398            // Simple but effective beat detection using energy differences
399            // Calculate onset strength for each point and find peaks
400            let mut onset_strengths = Vec::with_capacity(4096);
401            
402            for i in 0..summed_bands.len() {
403                // Calculate increase in energy from previous frame
404                let onset = (summed_bands[i] - prev_summed_bands[i]).max(0.0);
405                onset_strengths.push(onset);
406            }
407            
408            // Use 90th percentile instead of median for stricter threshold
409            let mut sorted_onsets = onset_strengths.clone();
410            sorted_onsets.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
411            let percentile_90 = sorted_onsets[(sorted_onsets.len() * 9) / 10];
412            
413            // Threshold is 1.5x the 90th percentile - captures strong onsets only
414            let threshold = (percentile_90 * 1.5).max(0.05);
415            
416            // Peak picking: find prominent local maxima
417            for i in 10..(onset_strengths.len() - 10) {
418                let current = onset_strengths[i];
419                
420                // Stricter local maximum check over a larger 21-sample window
421                // Must be stronger than all neighbors within 10 samples
422                let is_local_max = (i.saturating_sub(10)..i).all(|j| current > onset_strengths[j])
423                    && ((i + 1)..=std::cmp::min(i + 10, onset_strengths.len() - 1)).all(|j| current >= onset_strengths[j]);
424                
425                if is_local_max && current > threshold {
426                    // Calculate the time of this beat
427                    let beat_sample = samples_processed - config.window_size() + (i * (config.window_size() / 4096));
428                    let beat_time = beat_sample as f64 / TARGET_SAMPLING_RATE;
429                    
430                    // Tempo-based validation: only enforce after we have at least 3 beats
431                    // This allows the system to bootstrap and adapt to tempo changes
432                    let mut tempo_valid = beat_timings.len() < 3; // Always allow first 3 beats
433                    
434                    if !tempo_valid {
435                        if let Some(bpm) = current_bpm {
436                            if let Some(last_beat) = beat_timings.last() {
437                                let interval = beat_time - last_beat.time_seconds;
438                                let expected_interval = 60.0 / bpm as f64;
439                                
440                                // Allow ±30% deviation from expected interval
441                                // Also accept half-time (2x interval) and double-time (0.5x interval)
442                                let deviation = (interval / expected_interval).abs();
443                                tempo_valid = (0.7..=1.3).contains(&deviation)  // Normal time (±30%)
444                                    || (1.7..=2.3).contains(&deviation)         // Half time
445                                    || (0.4..=0.6).contains(&deviation);        // Double time
446                            } else {
447                                tempo_valid = true; // No previous beat to compare
448                            }
449                        } else {
450                            tempo_valid = true; // No BPM reference yet
451                        }
452                    }
453                    
454                    // Add beat if it's not too close to the previous one (0.15s min spacing)
455                    let should_add = beat_timings.last()
456                        .map(|last: &BeatTiming| (beat_time - last.time_seconds) > 0.15)
457                        .unwrap_or(true);
458                    
459                    if should_add && tempo_valid {
460                        let normalized_strength = (current / threshold).min(2.0);
461                        beat_timings.push(BeatTiming::new(beat_time, normalized_strength));
462                        
463                        // Keep only the last 8 beats
464                        if beat_timings.len() > 8 {
465                            beat_timings.remove(0);
466                        }
467                    }
468                }
469            }
470            
471            // Store current bands for next iteration
472            prev_summed_bands.copy_from_slice(&summed_bands);
473
474            let min_lag = ((4096.0 / window_length) * 60.0 / config.max_bpm() as f64) as usize;
475            let max_lag = ((4096.0 / window_length) * 60.0 / config.min_bpm() as f64) as usize;
476
477            let ac = autocorrelation(&summed_bands, max_lag);
478
479            // Reuse peaks buffer
480            peaks_buffer.clear();
481            peaks_buffer.extend(
482                ac.iter()
483                    .enumerate()
484                    .skip(min_lag)
485                    .take(max_lag - min_lag)
486                    .map(|(idx, &val)| (idx, val)),
487            );
488            peaks_buffer
489                .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
490
491            // Skip the first peak (which is usually the zero-lag) and take up to 5 peaks
492            let peak_count = std::cmp::Ord::min(peaks_buffer.len().saturating_sub(1), 5);
493            if peak_count > 0 {
494                let mut result = [(0.0f32, 0.0f32); 5];
495                for (i, &(lag, v)) in peaks_buffer[1..=peak_count].iter().enumerate() {
496                    let bpm = (60.0 * (4096.0 / window_length as f32)) / (lag as f32);
497                    result[i] = (bpm, v);
498                }
499                
500                // Update current BPM for beat validation (use the highest confidence BPM)
501                // If BPM changes significantly (>10%), reset beat timings to allow re-syncing
502                if result[0].0 > 0.0 {
503                    if let Some(old_bpm) = current_bpm {
504                        let bpm_change = ((result[0].0 - old_bpm) / old_bpm).abs();
505                        if bpm_change > 0.1 {
506                            // Significant BPM change detected (song change or tempo shift)
507                            // Keep only the last beat to maintain continuity
508                            if !beat_timings.is_empty() {
509                                let last = beat_timings.last().cloned().unwrap();
510                                beat_timings.clear();
511                                beat_timings.push(last);
512                            }
513                        }
514                    }
515                    current_bpm = Some(result[0].0);
516                }
517
518                let _ = bpm_sender.try_send(BpmDetection::with_beats(result, beat_timings.clone()));
519            }
520        }
521    }
522}
523
524/// Starts the audio input stream on the given device with the given configuration,
525/// sending audio samples to the provided channel sender.
526fn run<T>(
527    device: &cpal::Device,
528    config: &cpal::StreamConfig,
529    sender: Sender<(f32, f32)>,
530) -> Result<()>
531where
532    T: SizedSample,
533    f32: FromSample<T>,
534{
535    let channels = config.channels as usize;
536    let err_fn = |err| tracing::error!("an error occurred on stream: {}", err);
537    let stream = device.build_input_stream(
538        config,
539        move |data: &[T], _: &cpal::InputCallbackInfo| read_data(data, channels, sender.clone()),
540        err_fn,
541        None,
542    );
543    if let Ok(stream) = stream
544        && let Ok(()) = stream.play()
545    {
546        std::mem::forget(stream);
547    }
548
549    tracing::info!("Input stream built.");
550
551    Ok(())
552}
553
554/// Callback function to read audio data from the input device
555/// and sends it to the provided channel sender.
556fn read_data<T>(input: &[T], channels: usize, sender: Sender<(f32, f32)>)
557where
558    T: SizedSample,
559    f32: FromSample<T>,
560{
561    for frame in input.chunks(channels) {
562        let left = if !frame.is_empty() {
563            frame[0].to_sample::<f32>()
564        } else {
565            0.0
566        };
567
568        let right = if channels > 1 && frame.len() > 1 {
569            frame[1].to_sample::<f32>()
570        } else {
571            left // For mono, duplicate to both channels
572        };
573
574        let _ = sender.try_send((left, right));
575    }
576}
577
/// Computes the (biased) autocorrelation of a signal for lags [0, max_lag).
///
/// `max_lag` is clamped to `signal.len()`. For each lag, the products
/// `signal[i] * signal[i + lag]` over all valid `i` are summed and divided by the
/// full signal length (not by the number of overlapping terms).
///
/// # Arguments
///
/// * `signal` - Input signal (e.g., summed band envelopes)
/// * `max_lag` - Exclusive upper bound on the lags to compute
///
/// # Returns
///
/// A `Vec<f32>` with one autocorrelation value per lag, starting at lag 0.
fn autocorrelation(signal: &[f32], max_lag: usize) -> Vec<f32> {
    let len = signal.len();
    let lag_count = max_lag.min(len);
    let scale = len as f32;

    (0..lag_count)
        .map(|lag| {
            let overlap = signal[..len - lag].iter().zip(&signal[lag..]);
            overlap.fold(0.0f32, |acc, (a, b)| acc + a * b) / scale
        })
        .collect()
}