1use cpal::{
4 BufferSize, FromSample, SampleFormat, SizedSample,
5 traits::{DeviceTrait, HostTrait, StreamTrait},
6};
7use crossbeam_channel::{Receiver, Sender, bounded};
8use fundsp::prelude32::*;
9use osclet::{BorderMode, DaubechiesFamily, Osclet};
10use resampler::{Attenuation, Latency, ResamplerFir, SampleRate};
11use ringbuffer::{AllocRingBuffer, RingBuffer};
12
13use crate::{
14 config::{AnalyzerConfig, DWT_LEVELS, TARGET_SAMPLING_RATE},
15 dsp,
16 error::{Error, Result},
17 types::{BeatTiming, BpmDetection},
18};
19
20#[allow(clippy::large_enum_variant)]
25enum TransientBuffer<'a> {
26 Full(BufferRef<'a>),
28 Partial {
30 buffer: BufferArray<U1>,
32 length: usize,
34 },
35}
36
37impl<'a> TransientBuffer<'a> {
38 fn process<N: AudioUnit>(&'a self, node: &mut N) -> (BufferArray<U1>, usize) {
44 match self {
45 TransientBuffer::Full(buffer_ref) => {
46 let mut buffer = BufferArray::<U1>::new();
47
48 node.process(MAX_BUFFER_SIZE, buffer_ref, &mut buffer.buffer_mut());
49
50 (buffer, MAX_BUFFER_SIZE)
51 }
52 TransientBuffer::Partial { buffer, length } => {
53 let mut output_buffer = BufferArray::<U1>::new();
54
55 node.process(
56 *length,
57 &buffer.buffer_ref(),
58 &mut output_buffer.buffer_mut(),
59 );
60
61 (output_buffer, *length)
62 }
63 }
64 }
65}
66
67pub fn begin(config: AnalyzerConfig) -> Result<Receiver<BpmDetection>> {
105 config.validate()?;
107 let host = cpal::default_host();
108
109 let device = host
110 .input_devices()?
111 .find_map(|device| match device.description() {
112 #[cfg(target_os = "macos")]
113 Ok(desc) if desc.name().contains("BlackHole") => Some(Ok(device)),
114 Err(e) => Some(Err(e)),
115 Ok(_) => None,
116 })
117 .transpose()?
118 .or_else(|| host.default_input_device())
119 .ok_or(Error::NoDeviceFound)?;
120
121 begin_with_device(config, &device)
122}
123
124pub fn begin_with_device(
164 config: AnalyzerConfig,
165 device: &cpal::Device,
166) -> Result<Receiver<BpmDetection>> {
167 config.validate()?;
169
170 let device_name = device.description()?.name().to_string();
171
172 tracing::info!("Using audio device: {}", device_name);
173
174 let supported_config = device.default_input_config()?;
175
176 let mut stream_config = supported_config.config();
177
178 stream_config.buffer_size = BufferSize::Fixed(config.buffer_size());
179
180 let sample_rate = stream_config.sample_rate as f64;
181
182 tracing::info!(
183 "Sampling with {:?} Hz on {} channels",
184 stream_config.sample_rate,
185 stream_config.channels
186 );
187
188 let (audio_sender, audio_receiver) = bounded(config.queue_size());
189 let (bpm_sender, bpm_receiver) = bounded(config.queue_size());
190
191 match supported_config.sample_format() {
192 SampleFormat::F32 => run::<f32>(device, &stream_config, audio_sender)?,
193 SampleFormat::I16 => run::<i16>(device, &stream_config, audio_sender)?,
194 SampleFormat::U16 => run::<u16>(device, &stream_config, audio_sender)?,
195 other => {
196 return Err(Error::UnsupportedSampleFormat(other));
197 }
198 }
199
200 std::thread::spawn(move || run_analysis(sample_rate, audio_receiver, bpm_sender, config));
201
202 Ok(bpm_receiver)
203}
204
205fn run_analysis(
222 sample_rate: f64,
223 audio_receiver: Receiver<(f32, f32)>,
224 bpm_sender: Sender<BpmDetection>,
225 config: AnalyzerConfig,
226) -> Result<()> {
227 let now = std::time::Instant::now();
228
229 let dwt_executor = Osclet::make_daubechies_f32(DaubechiesFamily::Db4, BorderMode::Wrap);
230
231 let input_sample_rate = match sample_rate as u32 {
233 16000 => SampleRate::Hz16000,
234 22050 => SampleRate::Hz22050,
235 32000 => SampleRate::Hz32000,
236 44100 => SampleRate::Hz44100,
237 48000 => SampleRate::Hz48000,
238 88200 => SampleRate::Hz88200,
239 96000 => SampleRate::Hz96000,
240 176400 => SampleRate::Hz176400,
241 192000 => SampleRate::Hz192000,
242 _ => return Err(Error::UnsupportedSampleRate(sample_rate as u32)),
243 };
244
245 let mut resampler = ResamplerFir::new(
246 1,
247 input_sample_rate,
248 SampleRate::Hz22050,
249 Latency::Sample64,
250 Attenuation::Db90,
251 );
252
253 tracing::info!("Resampling buffer: {}", resampler.buffer_size_output());
254
255 let resampling_factor = TARGET_SAMPLING_RATE / sample_rate;
256 let window_length = config.window_size() as f64 / TARGET_SAMPLING_RATE;
257
258 tracing::info!(
259 "Analysis window: {} samples ({:.2} seconds)",
260 config.window_size(),
261 window_length
262 );
263
264 tracing::info!(
265 "Resampling factor: {}, every {}th sample",
266 resampling_factor,
267 (sample_rate / TARGET_SAMPLING_RATE).round()
268 );
269
270 let mut ring_buffer = AllocRingBuffer::<f32>::new(config.window_size());
271
272 let once = std::sync::Once::new();
273
274 let mut filter_chain = dsp::alpha_lpf(0.99f32) >> dsp::fwr::<f32>();
275
276 let mut resampled_output = vec![0.0f32; resampler.buffer_size_output()];
277
278 let mut input_buffer = Vec::with_capacity(4096);
280 let mut signal = vec![0.0f32; config.window_size()];
281 let mut bands = vec![vec![0.0f32; 4096]; DWT_LEVELS];
282 let mut summed_bands = vec![0.0f32; 4096];
283 let mut peaks_buffer = Vec::with_capacity(1024);
284
285 let mut beat_timings: Vec<BeatTiming> = Vec::with_capacity(8);
287 let mut prev_summed_bands = vec![0.0f32; 4096];
288 let mut samples_processed = 0usize;
289 let mut current_bpm: Option<f32> = None;
290
291 loop {
292 input_buffer.clear();
294 input_buffer.extend(
295 audio_receiver
296 .try_iter()
297 .map(|(l, r)| (l + r) * 0.5),
299 );
300
301 let mut input_slice = &input_buffer[..];
302
303 while !input_slice.is_empty() {
304 let (consumed, produced) = resampler
305 .resample(input_slice, &mut resampled_output)
306 .map_err(Error::ResampleError)?;
307 ring_buffer.extend(resampled_output[..produced].iter().copied());
308 samples_processed += produced;
309
310 input_slice = &input_slice[consumed..];
311 }
312
313 if ring_buffer.is_full() {
314 once.call_once(|| {
315 let time = now.elapsed();
316 tracing::info!(
317 "Initial audio buffer filled with {} samples in {:.2?}",
318 ring_buffer.len(),
319 time
320 );
321 });
322
323 signal = ring_buffer.to_vec();
325
326 let dwt = dwt_executor.multi_dwt(&signal, DWT_LEVELS)?;
327
328 for (band_idx, level) in dwt.levels.into_iter().enumerate() {
330 filter_chain.reset();
331
332 let mut processed_samples = Vec::with_capacity(level.approximations.len());
334
335 for chunk in level.approximations.chunks(MAX_BUFFER_SIZE) {
336 let transient_buffer = if chunk.len() == MAX_BUFFER_SIZE {
337 let buffer = unsafe {
338 std::slice::from_raw_parts::<'_, F32x>(
339 chunk.as_ptr() as *const _,
340 MAX_BUFFER_SIZE / SIMD_LEN,
341 )
342 };
343 TransientBuffer::Full(BufferRef::new(buffer))
344 } else {
345 let mut buffer = BufferArray::<U1>::new();
346 buffer.channel_f32_mut(0)[..chunk.len()].copy_from_slice(chunk);
347 TransientBuffer::Partial {
348 buffer,
349 length: chunk.len(),
350 }
351 };
352
353 let (mut output, length) = transient_buffer.process(&mut filter_chain);
354 processed_samples.extend_from_slice(&output.channel_f32(0)[..length]);
355 }
356
357 let downsampling_factor = 1 << (DWT_LEVELS - 1 - band_idx);
358
359 let band_buffer = &mut bands[band_idx];
361 band_buffer.fill(0.0);
362
363 let downsampled_len = processed_samples.len() / downsampling_factor;
364 let samples_to_copy = std::cmp::min(downsampled_len, 4096);
365
366 for (i, sample_idx) in (0..processed_samples.len())
367 .step_by(downsampling_factor)
368 .take(samples_to_copy)
369 .enumerate()
370 {
371 band_buffer[i] = processed_samples[sample_idx];
372 }
373
374 if samples_to_copy > 0 {
376 let mean: f32 =
377 band_buffer[..samples_to_copy].iter().sum::<f32>() / samples_to_copy as f32;
378 band_buffer[..samples_to_copy]
379 .iter_mut()
380 .for_each(|sample| *sample -= mean);
381 }
382 }
383
384 summed_bands.fill(0.0);
391 for i in 0..4096 {
392 summed_bands[i] = bands[0][i] * 2.0 + bands[1][i] * 1.5 + bands[2][i] * 1.0 + bands[3][i] * 0.5; }
397
398 let mut onset_strengths = Vec::with_capacity(4096);
401
402 for i in 0..summed_bands.len() {
403 let onset = (summed_bands[i] - prev_summed_bands[i]).max(0.0);
405 onset_strengths.push(onset);
406 }
407
408 let mut sorted_onsets = onset_strengths.clone();
410 sorted_onsets.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
411 let percentile_90 = sorted_onsets[(sorted_onsets.len() * 9) / 10];
412
413 let threshold = (percentile_90 * 1.5).max(0.05);
415
416 for i in 10..(onset_strengths.len() - 10) {
418 let current = onset_strengths[i];
419
420 let is_local_max = (i.saturating_sub(10)..i).all(|j| current > onset_strengths[j])
423 && ((i + 1)..=std::cmp::min(i + 10, onset_strengths.len() - 1)).all(|j| current >= onset_strengths[j]);
424
425 if is_local_max && current > threshold {
426 let beat_sample = samples_processed - config.window_size() + (i * (config.window_size() / 4096));
428 let beat_time = beat_sample as f64 / TARGET_SAMPLING_RATE;
429
430 let mut tempo_valid = beat_timings.len() < 3; if !tempo_valid {
435 if let Some(bpm) = current_bpm {
436 if let Some(last_beat) = beat_timings.last() {
437 let interval = beat_time - last_beat.time_seconds;
438 let expected_interval = 60.0 / bpm as f64;
439
440 let deviation = (interval / expected_interval).abs();
443 tempo_valid = (0.7..=1.3).contains(&deviation) || (1.7..=2.3).contains(&deviation) || (0.4..=0.6).contains(&deviation); } else {
447 tempo_valid = true; }
449 } else {
450 tempo_valid = true; }
452 }
453
454 let should_add = beat_timings.last()
456 .map(|last: &BeatTiming| (beat_time - last.time_seconds) > 0.15)
457 .unwrap_or(true);
458
459 if should_add && tempo_valid {
460 let normalized_strength = (current / threshold).min(2.0);
461 beat_timings.push(BeatTiming::new(beat_time, normalized_strength));
462
463 if beat_timings.len() > 8 {
465 beat_timings.remove(0);
466 }
467 }
468 }
469 }
470
471 prev_summed_bands.copy_from_slice(&summed_bands);
473
474 let min_lag = ((4096.0 / window_length) * 60.0 / config.max_bpm() as f64) as usize;
475 let max_lag = ((4096.0 / window_length) * 60.0 / config.min_bpm() as f64) as usize;
476
477 let ac = autocorrelation(&summed_bands, max_lag);
478
479 peaks_buffer.clear();
481 peaks_buffer.extend(
482 ac.iter()
483 .enumerate()
484 .skip(min_lag)
485 .take(max_lag - min_lag)
486 .map(|(idx, &val)| (idx, val)),
487 );
488 peaks_buffer
489 .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
490
491 let peak_count = std::cmp::Ord::min(peaks_buffer.len().saturating_sub(1), 5);
493 if peak_count > 0 {
494 let mut result = [(0.0f32, 0.0f32); 5];
495 for (i, &(lag, v)) in peaks_buffer[1..=peak_count].iter().enumerate() {
496 let bpm = (60.0 * (4096.0 / window_length as f32)) / (lag as f32);
497 result[i] = (bpm, v);
498 }
499
500 if result[0].0 > 0.0 {
503 if let Some(old_bpm) = current_bpm {
504 let bpm_change = ((result[0].0 - old_bpm) / old_bpm).abs();
505 if bpm_change > 0.1 {
506 if !beat_timings.is_empty() {
509 let last = beat_timings.last().cloned().unwrap();
510 beat_timings.clear();
511 beat_timings.push(last);
512 }
513 }
514 }
515 current_bpm = Some(result[0].0);
516 }
517
518 let _ = bpm_sender.try_send(BpmDetection::with_beats(result, beat_timings.clone()));
519 }
520 }
521 }
522}
523
524fn run<T>(
527 device: &cpal::Device,
528 config: &cpal::StreamConfig,
529 sender: Sender<(f32, f32)>,
530) -> Result<()>
531where
532 T: SizedSample,
533 f32: FromSample<T>,
534{
535 let channels = config.channels as usize;
536 let err_fn = |err| tracing::error!("an error occurred on stream: {}", err);
537 let stream = device.build_input_stream(
538 config,
539 move |data: &[T], _: &cpal::InputCallbackInfo| read_data(data, channels, sender.clone()),
540 err_fn,
541 None,
542 );
543 if let Ok(stream) = stream
544 && let Ok(()) = stream.play()
545 {
546 std::mem::forget(stream);
547 }
548
549 tracing::info!("Input stream built.");
550
551 Ok(())
552}
553
554fn read_data<T>(input: &[T], channels: usize, sender: Sender<(f32, f32)>)
557where
558 T: SizedSample,
559 f32: FromSample<T>,
560{
561 for frame in input.chunks(channels) {
562 let left = if !frame.is_empty() {
563 frame[0].to_sample::<f32>()
564 } else {
565 0.0
566 };
567
568 let right = if channels > 1 && frame.len() > 1 {
569 frame[1].to_sample::<f32>()
570 } else {
571 left };
573
574 let _ = sender.try_send((left, right));
575 }
576}
577
/// Biased autocorrelation of `signal` for lags `0..min(max_lag, signal.len())`.
///
/// Entry `lag` is the sum of `signal[i] * signal[i + lag]` over every valid
/// `i`, divided by the full signal length rather than by the number of
/// overlapping terms, so long lags are naturally attenuated. An empty signal
/// or `max_lag == 0` yields an empty vector.
fn autocorrelation(signal: &[f32], max_lag: usize) -> Vec<f32> {
    let len = signal.len();
    let lag_count = std::cmp::min(max_lag, len);

    (0..lag_count)
        .map(|lag| {
            // Accumulate in index order starting from +0.0, like a plain loop.
            let dot = signal
                .iter()
                .zip(&signal[lag..])
                .fold(0.0f32, |acc, (&a, &b)| acc + a * b);
            dot / len as f32
        })
        .collect()
}