1use cpal::{
38 BufferSize, FromSample, SampleFormat, SizedSample,
39 traits::{DeviceTrait, HostTrait, StreamTrait},
40};
41use crossbeam_channel::{Sender, bounded};
42use fundsp::prelude32::*;
43use osclet::{BorderMode, DaubechiesFamily, Osclet};
44use resampler::{Attenuation, Latency, ResamplerFir, SampleRate};
45use ringbuffer::{AllocRingBuffer, RingBuffer};
46
/// DSP building blocks used by the analysis filter chain (`alpha_lpf`, `fwr`).
pub mod dsp;

/// Default capacity of the bounded channels created by [`begin`] (used for both the
/// audio-frame channel and the tempo-estimate channel).
const QUEUE_SIZE: usize = 4096;
/// Number of decomposition levels requested from the multi-level DWT.
const DWT_LEVELS: usize = 4;
/// Default capture buffer size, requested from the device as `BufferSize::Fixed`.
const AUDIO_BUFFER_SIZE: u32 = 256;
/// Default capacity, in resampled samples, of the sliding window fed to the DWT.
const DWT_WINDOW_SIZE: usize = 65536;
/// Rate, in Hz, the analysis expects after resampling (the resampler targets 22050 Hz).
const TARGET_SAMPLING_RATE: f64 = 22050.0;

/// Default lower bound of the tempo search range, in beats per minute.
const MIN_BPM: f32 = 40.0;
/// Default upper bound of the tempo search range, in beats per minute.
const MAX_BPM: f32 = 240.0;
65
/// One block of at most `MAX_BUFFER_SIZE` mono samples to be fed through a fundsp node.
// `Partial` stores a whole `BufferArray<U1>` inline while `Full` only holds a borrow,
// which is what trips `large_enum_variant`. NOTE(review): presumably allowed because
// these values are short-lived per-chunk temporaries and boxing would allocate per
// chunk — confirm.
#[allow(clippy::large_enum_variant)]
enum TransientBuffer<'a> {
    /// Exactly `MAX_BUFFER_SIZE` samples, borrowed from existing sample memory.
    Full(BufferRef<'a>),
    /// A shorter block copied into an owned single-channel buffer.
    Partial {
        /// Owned buffer; only the first `length` samples of channel 0 are meaningful.
        buffer: BufferArray<U1>,
        /// Number of valid samples in `buffer`.
        length: usize,
    },
}
82
/// Errors produced while setting up capture or running the tempo analysis.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Enumerating the host's input devices failed.
    #[error("CPAL error: {0}")]
    DevicesError(#[from] cpal::DevicesError),
    /// The device's default input configuration uses a sample format that is not handled.
    #[error("Unsupported sample format: {0}")]
    UnsupportedSampleFormat(SampleFormat),
    /// Reading a device's description failed.
    #[error("Failed to get device name: {0}")]
    DeviceNameError(#[from] cpal::DeviceNameError),
    /// No suitable input device was available.
    #[error("No input device found")]
    NoDeviceFound,
    /// Querying the device's default input stream configuration failed.
    #[error("Failed to get stream configuration: {0}")]
    StreamConfigError(#[from] cpal::DefaultStreamConfigError),
    /// Resampling captured audio to the analysis rate failed.
    // NOTE(review): unlike the other wrapped errors this field has no `#[from]`/`#[source]`,
    // so `std::error::Error::source()` does not expose the inner error — confirm intent.
    #[error("Resampling error: {0}")]
    ResampleError(resampler::ResampleError),
    /// The wavelet transform failed.
    #[error("Osclet error: {0}")]
    Osclet(#[from] osclet::OscletError),
}
101
/// Result type used throughout this crate; the error defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Re-exported so callers can name the receiver type returned by [`begin`].
pub use crossbeam_channel::Receiver;
105
/// Tunables for [`begin`], constructed through the `bon`-derived builder
/// (`AnalyzerConfig::builder()`); every field falls back to a crate constant.
#[derive(Clone, Debug, Copy, bon::Builder)]
pub struct AnalyzerConfig {
    /// Lowest tempo searched, in BPM; determines the longest autocorrelation lag.
    /// NOTE(review): presumably must be below `max_bpm` — not validated here.
    #[builder(default = MIN_BPM)]
    min_bpm: f32,
    /// Highest tempo searched, in BPM; determines the shortest autocorrelation lag.
    #[builder(default = MAX_BPM)]
    max_bpm: f32,
    /// Capacity, in resampled samples, of the sliding window analysed by the DWT.
    #[builder(default = DWT_WINDOW_SIZE)]
    window_size: usize,
    /// Capacity of both the audio-frame channel and the tempo-estimate channel.
    #[builder(default = QUEUE_SIZE)]
    queue_size: usize,
    /// Capture buffer size requested from the device via `BufferSize::Fixed`.
    #[builder(default = AUDIO_BUFFER_SIZE)]
    buffer_size: u32,
}
139
140pub fn begin(config: AnalyzerConfig) -> Result<crossbeam_channel::Receiver<[(f32, f32); 5]>> {
179 let host = cpal::default_host();
180
181 let loopback_device = host
182 .input_devices()?
183 .find_map(|device| match device.description() {
184 #[cfg(target_os = "macos")]
185 Ok(desc) if desc.name().contains("BlackHole") => Some(Ok(device)),
186 Err(e) => Some(Err(e)),
187 Ok(_) => None,
188 })
189 .transpose()?
190 .or_else(|| host.default_input_device())
191 .ok_or(Error::NoDeviceFound)?;
192
193 let device_name = loopback_device.description()?.name().to_string();
194
195 tracing::info!("Using audio device: {}", device_name);
196
197 let supported_config = loopback_device.default_input_config()?;
198
199 let mut stream_config = supported_config.config();
200
201 stream_config.buffer_size = BufferSize::Fixed(config.buffer_size);
202
203 let sample_rate = stream_config.sample_rate as f64;
204
205 tracing::info!(
206 "Sampling with {:?} Hz on {} channels",
207 stream_config.sample_rate,
208 stream_config.channels
209 );
210
211 let (audio_sender, audio_receiver) = bounded(config.queue_size);
212 let (bmp_sender, bpm_receiver) = bounded(config.queue_size);
213
214 match supported_config.sample_format() {
215 SampleFormat::F32 => run::<f32>(&loopback_device, &stream_config, audio_sender)?,
216 SampleFormat::I16 => run::<i16>(&loopback_device, &stream_config, audio_sender)?,
217 SampleFormat::U16 => run::<u16>(&loopback_device, &stream_config, audio_sender)?,
218 other => {
219 return Err(Error::UnsupportedSampleFormat(other));
220 }
221 }
222
223 std::thread::spawn(move || run_analysis(sample_rate, audio_receiver, bmp_sender, config));
224
225 Ok(bpm_receiver)
226}
227
228impl<'a> TransientBuffer<'a> {
229 fn process<N: AudioUnit>(&'a self, node: &mut N) -> (BufferArray<U1>, usize) {
235 match self {
236 TransientBuffer::Full(buffer_ref) => {
237 let mut buffer = BufferArray::<U1>::new();
238
239 node.process(MAX_BUFFER_SIZE, buffer_ref, &mut buffer.buffer_mut());
240
241 (buffer, MAX_BUFFER_SIZE)
242 }
243 TransientBuffer::Partial { buffer, length } => {
244 let mut output_buffer = BufferArray::<U1>::new();
245
246 node.process(
247 *length,
248 &buffer.buffer_ref(),
249 &mut output_buffer.buffer_mut(),
250 );
251
252 (output_buffer, *length)
253 }
254 }
255 }
256}
257
258fn run_analysis(
275 sample_rate: f64,
276 audio_receiver: Receiver<(f32, f32)>,
277 bpm_sender: Sender<[(f32, f32); 5]>,
278 config: AnalyzerConfig,
279) -> Result<()> {
280 let now = std::time::Instant::now();
281
282 let dwt_executor = Osclet::make_daubechies_f32(DaubechiesFamily::Db4, BorderMode::Wrap);
283
284 let mut resampler = ResamplerFir::new(
285 1,
286 SampleRate::Hz48000,
287 SampleRate::Hz22050,
288 Latency::Sample64,
289 Attenuation::Db90,
290 );
291
292 tracing::info!("Resampling buffer: {}", resampler.buffer_size_output());
293
294 let resampling_factor = TARGET_SAMPLING_RATE / sample_rate;
295 tracing::info!(
296 "Resampling factor: {}, every {}th sample",
297 resampling_factor,
298 (sample_rate / TARGET_SAMPLING_RATE).round()
299 );
300
301 let mut ring_buffer = AllocRingBuffer::<f32>::new(config.window_size);
302
303 let once = std::sync::Once::new();
304
305 let mut filter_chain = dsp::alpha_lpf(0.99f32) >> dsp::fwr::<f32>();
306
307 let mut resampled_output = vec![0.0f32; resampler.buffer_size_output()];
308
309 loop {
310 let input = audio_receiver
312 .try_iter()
313 .map(|(l, r)| (l + r) * 0.5)
315 .collect::<Vec<_>>();
316
317 let mut input = &input[..];
318
319 while !input.is_empty() {
320 let (consumed, produced) = resampler
321 .resample(input, &mut resampled_output)
322 .map_err(Error::ResampleError)?;
323 ring_buffer.extend(resampled_output[..produced].iter().copied());
324
325 input = &input[consumed..];
326 }
327
328 if ring_buffer.is_full() {
329 once.call_once(|| {
330 let time = now.elapsed();
331 tracing::info!(
332 "Initial audio buffer filled with {} samples in {:.2?}",
333 ring_buffer.len(),
334 time
335 );
336 });
337
338 let signal = ring_buffer.to_vec();
339
340 let dwt = dwt_executor.multi_dwt(&signal, DWT_LEVELS)?;
341
342 let bands = dwt
343 .levels
344 .into_iter()
345 .enumerate()
346 .map(|(i, band)| {
347 filter_chain.reset();
348
349 let band = band
350 .approximations
351 .chunks(MAX_BUFFER_SIZE)
352 .map(|chunk| {
353 if chunk.len() == MAX_BUFFER_SIZE {
354 let buffer = unsafe {
355 std::slice::from_raw_parts::<'_, F32x>(
356 chunk.as_ptr() as *const _,
357 MAX_BUFFER_SIZE / SIMD_LEN,
358 )
359 };
360
361 TransientBuffer::Full(BufferRef::new(buffer))
362 } else {
363 let mut buffer = BufferArray::<U1>::new();
364 buffer.channel_f32_mut(0)[..chunk.len()].copy_from_slice(chunk);
365
366 TransientBuffer::Partial {
367 buffer,
368 length: chunk.len(),
369 }
370 }
371 })
372 .map(|buffer| buffer.process(&mut filter_chain))
373 .flat_map(|(mut band, length)| {
374 band.channel_f32(0)
375 .iter()
376 .copied()
377 .take(length)
378 .collect::<Vec<_>>()
379 })
380 .collect::<Vec<_>>();
381
382 let downsampling_factor = 1 << (3 - i);
383
384 let mut band = (0..band.len())
385 .step_by(downsampling_factor)
386 .map(|i| band[i])
387 .collect::<Vec<_>>();
388
389 let mean = band.iter().copied().sum::<f32>() / band.len() as f32;
390
391 band.iter_mut().for_each(|sample| *sample -= mean);
393
394 band.resize(4096, 0.0);
395
396 band
397 })
398 .collect::<Vec<_>>();
399
400 let summed_bands = (0..4096)
401 .map(|i| bands.iter().map(|band| band[i]).sum::<f32>())
402 .collect::<Vec<_>>();
403
404 let min_lag = ((4096.0 / 3.0) * 60.0 / config.max_bpm) as usize;
405 let max_lag = ((4096.0 / 3.0) * 60.0 / config.min_bpm) as usize;
406
407 let ac = autocorrelation(&summed_bands, max_lag);
408
409 let mut peaks = ac
410 .into_iter()
411 .enumerate()
412 .skip(min_lag)
413 .take(max_lag - min_lag)
414 .collect::<Vec<_>>();
415 peaks.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
416
417 let peaks = &peaks[1..6];
418
419 let peaks = peaks
420 .iter()
421 .copied()
422 .map(|(lag, v)| ((60.0 * (4096.0 / 3.0)) / (lag as f32), v))
423 .collect::<Vec<_>>();
424
425 if let Ok(()) = bpm_sender.try_send(peaks.try_into().unwrap()) {}
426 }
427 }
428}
429
430fn run<T>(
433 device: &cpal::Device,
434 config: &cpal::StreamConfig,
435 sender: Sender<(f32, f32)>,
436) -> Result<()>
437where
438 T: SizedSample,
439 f32: FromSample<T>,
440{
441 let channels = config.channels as usize;
442 let err_fn = |err| tracing::error!("an error occurred on stream: {}", err);
443 let stream = device.build_input_stream(
444 config,
445 move |data: &[T], _: &cpal::InputCallbackInfo| read_data(data, channels, sender.clone()),
446 err_fn,
447 None,
448 );
449 if let Ok(stream) = stream
450 && let Ok(()) = stream.play()
451 {
452 std::mem::forget(stream);
453 }
454
455 tracing::info!("Input stream built.");
456
457 Ok(())
458}
459
460fn read_data<T>(input: &[T], channels: usize, sender: Sender<(f32, f32)>)
463where
464 T: SizedSample,
465 f32: FromSample<T>,
466{
467 for frame in input.chunks(channels) {
468 let mut left = 0.0;
469 let mut right = 0.0;
470 for (channel, sample) in frame.iter().enumerate() {
471 if channel & 1 == 0 {
472 left = sample.to_sample::<f32>();
473 } else {
474 right = sample.to_sample::<f32>();
475 }
476 }
477
478 if let Ok(()) = sender.try_send((left, right)) {}
479 }
480}
481
/// Biased autocorrelation of `signal` for lags `0..min(max_lag, signal.len())`.
///
/// Entry `lag` is `Σ signal[i] * signal[i + lag]` over all valid `i`, divided by the
/// full signal length (not by the number of overlapping terms). An empty signal
/// yields an empty vector.
pub fn autocorrelation(signal: &[f32], max_lag: usize) -> Vec<f32> {
    let len = signal.len();
    let scale = len as f32;

    (0..max_lag.min(len))
        .map(|lag| {
            let dot = signal
                .iter()
                .zip(&signal[lag..])
                .fold(0.0f32, |acc, (a, b)| acc + a * b);
            dot / scale
        })
        .collect()
}