use cpal::{
BufferSize, FromSample, SampleFormat, SizedSample,
traits::{DeviceTrait, HostTrait, StreamTrait},
};
use crossbeam_channel::{Receiver, Sender, bounded};
use fundsp::prelude32::*;
use osclet::{BorderMode, DaubechiesFamily, Osclet};
use resampler::{Attenuation, Latency, ResamplerFir, SampleRate};
use ringbuffer::{AllocRingBuffer, RingBuffer};
use crate::{
config::{AnalyzerConfig, DWT_LEVELS, TARGET_SAMPLING_RATE},
dsp,
error::{Error, Result},
types::{BeatTiming, BpmDetection},
};
/// One chunk of input handed to a fundsp [`AudioUnit`] for block processing.
///
/// A chunk is either a borrowed, already SIMD-packed full block (`Full`) or an
/// owned scratch buffer of which only the first `length` samples are meaningful
/// (`Partial`).
// The `Partial` variant stores a whole `BufferArray` inline, which is what
// trips `large_enum_variant`. NOTE(review): presumably boxing was avoided to
// keep per-chunk processing free of heap allocations — confirm.
#[allow(clippy::large_enum_variant)]
enum TransientBuffer<'a> {
    /// Borrowed block that is processed as `MAX_BUFFER_SIZE` samples.
    Full(BufferRef<'a>),
    /// Owned block; only the first `length` samples are processed.
    Partial {
        buffer: BufferArray<U1>,
        length: usize,
    },
}

impl<'a> TransientBuffer<'a> {
    /// Runs `node` over this chunk and returns the freshly filled output buffer
    /// together with the number of samples that were processed
    /// (`MAX_BUFFER_SIZE` for `Full`, the stored `length` for `Partial`).
    fn process<N: AudioUnit>(&'a self, node: &mut N) -> (BufferArray<U1>, usize) {
        // Both variants write into the same kind of single-channel output.
        let mut output = BufferArray::<U1>::new();
        let processed = match self {
            Self::Full(input) => {
                node.process(MAX_BUFFER_SIZE, input, &mut output.buffer_mut());
                MAX_BUFFER_SIZE
            }
            Self::Partial { buffer, length } => {
                let input = buffer.buffer_ref();
                node.process(*length, &input, &mut output.buffer_mut());
                *length
            }
        };
        (output, processed)
    }
}
pub fn begin(config: AnalyzerConfig) -> Result<Receiver<BpmDetection>> {
config.validate()?;
let host = cpal::default_host();
let device = host
.input_devices()?
.find_map(|device| match device.description() {
#[cfg(target_os = "macos")]
Ok(desc) if desc.name().contains("BlackHole") => Some(Ok(device)),
Err(e) => Some(Err(e)),
Ok(_) => None,
})
.transpose()?
.or_else(|| host.default_input_device())
.ok_or(Error::NoDeviceFound)?;
begin_with_device(config, &device)
}
pub fn begin_with_device(
config: AnalyzerConfig,
device: &cpal::Device,
) -> Result<Receiver<BpmDetection>> {
config.validate()?;
let device_name = device.description()?.name().to_string();
tracing::info!("Using audio device: {}", device_name);
let supported_config = device.default_input_config()?;
let mut stream_config = supported_config.config();
stream_config.buffer_size = BufferSize::Fixed(config.buffer_size());
let sample_rate = stream_config.sample_rate as f64;
tracing::info!(
"Sampling with {:?} Hz on {} channels",
stream_config.sample_rate,
stream_config.channels
);
let (audio_sender, audio_receiver) = bounded(config.queue_size());
let (bpm_sender, bpm_receiver) = bounded(config.queue_size());
match supported_config.sample_format() {
SampleFormat::F32 => run::<f32>(device, &stream_config, audio_sender)?,
SampleFormat::I16 => run::<i16>(device, &stream_config, audio_sender)?,
SampleFormat::U16 => run::<u16>(device, &stream_config, audio_sender)?,
other => {
return Err(Error::UnsupportedSampleFormat(other));
}
}
std::thread::spawn(move || run_analysis(sample_rate, audio_receiver, bpm_sender, config));
Ok(bpm_receiver)
}
/// Analysis worker: consumes stereo frames from `audio_receiver` and publishes
/// tempo estimates on `bpm_sender`.
///
/// Each pass of the loop:
/// 1. waits for audio, down-mixes it to mono, resamples it and appends it to a
///    ring buffer holding the latest `config.window_size()` samples;
/// 2. once the ring buffer is full, decomposes the window with a multi-level
///    DWT, runs every level's approximation coefficients through the
///    `alpha_lpf >> fwr` chain, decimates them and removes the mean, giving one
///    fixed-length envelope per level;
/// 3. combines the envelopes with fixed weights, derives onset strengths as the
///    positive difference to the previous pass' combined envelope and records
///    thresholded local maxima as beats;
/// 4. autocorrelates the combined envelope over the lag range derived from
///    `config.min_bpm()` / `config.max_bpm()` and sends up to five BPM
///    candidates (plus the recent beats) without blocking.
///
/// # Returns
///
/// `Ok(())` once either channel is disconnected (no more audio can arrive, or
/// nobody is listening for results any more).
///
/// # Errors
///
/// * [`Error::UnsupportedSampleRate`] if `sample_rate` is not one of the rates
///   the resampler is configured for here;
/// * [`Error::ResampleError`] if resampling fails;
/// * whatever error the DWT call converts into.
fn run_analysis(
    sample_rate: f64,
    audio_receiver: Receiver<(f32, f32)>,
    bpm_sender: Sender<BpmDetection>,
    config: AnalyzerConfig,
) -> Result<()> {
    // Number of points in every decimated band envelope and in their sum.
    const ENVELOPE_LEN: usize = 4096;

    let now = std::time::Instant::now();
    let dwt_executor = Osclet::make_daubechies_f32(DaubechiesFamily::Db4, BorderMode::Wrap);

    let input_sample_rate = match sample_rate as u32 {
        16000 => SampleRate::Hz16000,
        22050 => SampleRate::Hz22050,
        32000 => SampleRate::Hz32000,
        44100 => SampleRate::Hz44100,
        48000 => SampleRate::Hz48000,
        88200 => SampleRate::Hz88200,
        96000 => SampleRate::Hz96000,
        176400 => SampleRate::Hz176400,
        192000 => SampleRate::Hz192000,
        _ => return Err(Error::UnsupportedSampleRate(sample_rate as u32)),
    };
    let mut resampler = ResamplerFir::new(
        1,
        input_sample_rate,
        SampleRate::Hz22050,
        Latency::Sample64,
        Attenuation::Db90,
    );
    tracing::info!("Resampling buffer: {}", resampler.buffer_size_output());

    let resampling_factor = TARGET_SAMPLING_RATE / sample_rate;
    let window_length = config.window_size() as f64 / TARGET_SAMPLING_RATE;
    tracing::info!(
        "Analysis window: {} samples ({:.2} seconds)",
        config.window_size(),
        window_length
    );
    tracing::info!(
        "Resampling factor: {}, every {}th sample",
        resampling_factor,
        (sample_rate / TARGET_SAMPLING_RATE).round()
    );

    let mut ring_buffer = AllocRingBuffer::<f32>::new(config.window_size());
    let once = std::sync::Once::new();
    let mut filter_chain = dsp::alpha_lpf(0.99f32) >> dsp::fwr::<f32>();
    let mut resampled_output = vec![0.0f32; resampler.buffer_size_output()];
    let mut input_buffer = Vec::with_capacity(4096);
    let mut signal: Vec<f32> = Vec::with_capacity(config.window_size());
    let mut bands = vec![vec![0.0f32; ENVELOPE_LEN]; DWT_LEVELS];
    let mut summed_bands = vec![0.0f32; ENVELOPE_LEN];
    let mut peaks_buffer = Vec::with_capacity(1024);
    let mut beat_timings: Vec<BeatTiming> = Vec::with_capacity(8);
    let mut prev_summed_bands = vec![0.0f32; ENVELOPE_LEN];
    let mut samples_processed = 0usize;
    let mut current_bpm: Option<f32> = None;

    loop {
        // ---- 1. Collect new audio -------------------------------------------------
        input_buffer.clear();
        // Block for the first frame so the thread sleeps while no audio is
        // arriving instead of spinning and re-analysing an unchanged window.
        match audio_receiver.recv() {
            Ok((l, r)) => input_buffer.push((l + r) * 0.5),
            // Every sender is gone: no audio will ever arrive again.
            Err(_) => return Ok(()),
        }
        // Then take whatever else is already queued, without blocking.
        input_buffer.extend(audio_receiver.try_iter().map(|(l, r)| (l + r) * 0.5));

        let mut input_slice = &input_buffer[..];
        while !input_slice.is_empty() {
            let (consumed, produced) = resampler
                .resample(input_slice, &mut resampled_output)
                .map_err(Error::ResampleError)?;
            ring_buffer.extend(resampled_output[..produced].iter().copied());
            samples_processed += produced;
            if consumed == 0 && produced == 0 {
                // No progress is possible with the remaining input; bail out
                // rather than looping forever.
                break;
            }
            input_slice = &input_slice[consumed..];
        }

        if !ring_buffer.is_full() {
            continue;
        }
        once.call_once(|| {
            let time = now.elapsed();
            tracing::info!(
                "Initial audio buffer filled with {} samples in {:.2?}",
                ring_buffer.len(),
                time
            );
        });

        // ---- 2. Per-level envelopes -----------------------------------------------
        // Reuse the window allocation instead of building a new Vec every pass.
        signal.clear();
        signal.extend(ring_buffer.iter().copied());

        let dwt = dwt_executor.multi_dwt(&signal, DWT_LEVELS)?;
        for (band_idx, level) in dwt.levels.into_iter().enumerate() {
            filter_chain.reset();
            let mut processed_samples = Vec::with_capacity(level.approximations.len());

            for chunk in level.approximations.chunks(MAX_BUFFER_SIZE) {
                // SAFETY: `F32x` is fundsp's packed-`f32` SIMD vector, so every
                // bit pattern of the underlying `f32`s is a valid value, and
                // `align_to` only ever places correctly aligned elements in the
                // middle slice.
                let (head, lanes, tail) = unsafe { chunk.align_to::<F32x>() };
                let transient_buffer = if head.is_empty()
                    && tail.is_empty()
                    && lanes.len() == MAX_BUFFER_SIZE / SIMD_LEN
                {
                    // Full, suitably aligned block: process it in place.
                    TransientBuffer::Full(BufferRef::new(lanes))
                } else {
                    // Short or under-aligned block: go through a scratch copy.
                    let mut buffer = BufferArray::<U1>::new();
                    buffer.channel_f32_mut(0)[..chunk.len()].copy_from_slice(chunk);
                    TransientBuffer::Partial {
                        buffer,
                        length: chunk.len(),
                    }
                };
                let (mut output, length) = transient_buffer.process(&mut filter_chain);
                processed_samples.extend_from_slice(&output.channel_f32(0)[..length]);
            }

            // Earlier levels are decimated more strongly than later ones.
            let downsampling_factor = 1usize << (DWT_LEVELS - 1 - band_idx);
            let band_buffer = &mut bands[band_idx];
            band_buffer.fill(0.0);
            let downsampled_len = processed_samples.len() / downsampling_factor;
            let samples_to_copy = std::cmp::min(downsampled_len, ENVELOPE_LEN);
            for (i, sample_idx) in (0..processed_samples.len())
                .step_by(downsampling_factor)
                .take(samples_to_copy)
                .enumerate()
            {
                band_buffer[i] = processed_samples[sample_idx];
            }

            // Remove the mean of the populated part of the envelope.
            if samples_to_copy > 0 {
                let mean: f32 =
                    band_buffer[..samples_to_copy].iter().sum::<f32>() / samples_to_copy as f32;
                band_buffer[..samples_to_copy]
                    .iter_mut()
                    .for_each(|sample| *sample -= mean);
            }
        }

        // ---- 3. Onsets and beats --------------------------------------------------
        // Fixed weighting of the first four bands (every element is overwritten).
        for i in 0..ENVELOPE_LEN {
            summed_bands[i] =
                bands[0][i] * 2.0 + bands[1][i] * 1.5 + bands[2][i] * 1.0 + bands[3][i] * 0.5;
        }

        // Onset strength: positive change relative to the previous pass.
        let onset_strengths: Vec<f32> = summed_bands
            .iter()
            .zip(&prev_summed_bands)
            .map(|(cur, prev)| (cur - prev).max(0.0))
            .collect();

        // Threshold: 1.5 x the 90th percentile, but never below 0.05.
        let mut sorted_onsets = onset_strengths.clone();
        sorted_onsets.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let percentile_90 = sorted_onsets[(sorted_onsets.len() * 9) / 10];
        let threshold = (percentile_90 * 1.5).max(0.05);

        for i in 10..(onset_strengths.len() - 10) {
            let current = onset_strengths[i];
            // Strictly greater than the 10 predecessors, at least equal to the
            // 10 successors.
            let is_local_max = (i.saturating_sub(10)..i).all(|j| current > onset_strengths[j])
                && ((i + 1)..=std::cmp::min(i + 10, onset_strengths.len() - 1))
                    .all(|j| current >= onset_strengths[j]);
            if !(is_local_max && current > threshold) {
                continue;
            }

            // Map the envelope index back to an absolute resampled-sample position.
            let beat_sample = samples_processed - config.window_size()
                + (i * (config.window_size() / ENVELOPE_LEN));
            let beat_time = beat_sample as f64 / TARGET_SAMPLING_RATE;

            // With fewer than three beats (or no tempo / no previous beat) every
            // candidate is accepted; otherwise the interval to the last beat must
            // be close to 1x, 2x or 0.5x the expected beat interval.
            let tempo_valid = if beat_timings.len() < 3 {
                true
            } else {
                match (current_bpm, beat_timings.last()) {
                    (Some(bpm), Some(last_beat)) => {
                        let interval = beat_time - last_beat.time_seconds;
                        let expected_interval = 60.0 / bpm as f64;
                        let deviation = (interval / expected_interval).abs();
                        (0.7..=1.3).contains(&deviation)
                            || (1.7..=2.3).contains(&deviation)
                            || (0.4..=0.6).contains(&deviation)
                    }
                    _ => true,
                }
            };

            // Enforce a minimum spacing of 0.15 s between recorded beats.
            let should_add = beat_timings
                .last()
                .map(|last: &BeatTiming| (beat_time - last.time_seconds) > 0.15)
                .unwrap_or(true);

            if should_add && tempo_valid {
                let normalized_strength = (current / threshold).min(2.0);
                beat_timings.push(BeatTiming::new(beat_time, normalized_strength));
                // Keep only the eight most recent beats.
                if beat_timings.len() > 8 {
                    beat_timings.remove(0);
                }
            }
        }
        prev_summed_bands.copy_from_slice(&summed_bands);

        // ---- 4. Tempo candidates --------------------------------------------------
        let min_lag =
            ((ENVELOPE_LEN as f64 / window_length) * 60.0 / config.max_bpm() as f64) as usize;
        let max_lag =
            ((ENVELOPE_LEN as f64 / window_length) * 60.0 / config.min_bpm() as f64) as usize;
        let ac = autocorrelation(&summed_bands, max_lag);

        peaks_buffer.clear();
        peaks_buffer.extend(
            ac.iter()
                .enumerate()
                .skip(min_lag)
                // Saturating: an inverted BPM range yields no candidates
                // instead of an arithmetic-overflow panic.
                .take(max_lag.saturating_sub(min_lag))
                .map(|(idx, &val)| (idx, val)),
        );
        // Strongest correlation first.
        peaks_buffer
            .sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));

        let peak_count = std::cmp::Ord::min(peaks_buffer.len().saturating_sub(1), 5);
        if peak_count == 0 {
            continue;
        }

        // NOTE(review): index 0 (the strongest candidate after sorting) is
        // skipped here, exactly as before — confirm that this is intended.
        let mut result = [(0.0f32, 0.0f32); 5];
        for (i, &(lag, v)) in peaks_buffer[1..=peak_count].iter().enumerate() {
            let bpm = (60.0 * (ENVELOPE_LEN as f32 / window_length as f32)) / (lag as f32);
            result[i] = (bpm, v);
        }

        if result[0].0 > 0.0 {
            if let Some(old_bpm) = current_bpm {
                let bpm_change = ((result[0].0 - old_bpm) / old_bpm).abs();
                // A tempo jump of more than 10 % invalidates the beat history;
                // only the most recent beat survives.
                if bpm_change > 0.1 {
                    if let Some(last) = beat_timings.last().cloned() {
                        beat_timings.clear();
                        beat_timings.push(last);
                    }
                }
            }
            current_bpm = Some(result[0].0);
        }

        // Best effort: a full queue just drops this detection, but a
        // disconnected receiver means nobody is listening, so stop the worker.
        let detection = BpmDetection::with_beats(result, beat_timings.clone());
        if let Err(err) = bpm_sender.try_send(detection) {
            if err.is_disconnected() {
                return Ok(());
            }
        }
    }
}
fn run<T>(
device: &cpal::Device,
config: &cpal::StreamConfig,
sender: Sender<(f32, f32)>,
) -> Result<()>
where
T: SizedSample,
f32: FromSample<T>,
{
let channels = config.channels as usize;
let err_fn = |err| tracing::error!("an error occurred on stream: {}", err);
let stream = device.build_input_stream(
config,
move |data: &[T], _: &cpal::InputCallbackInfo| read_data(data, channels, sender.clone()),
err_fn,
None,
);
if let Ok(stream) = stream
&& let Ok(()) = stream.play()
{
std::mem::forget(stream);
}
tracing::info!("Input stream built.");
Ok(())
}
/// Splits an interleaved sample buffer into frames of `channels` samples and
/// sends each frame to `sender` as a `(left, right)` pair of `f32`s.
///
/// * The first sample of a frame is the left channel, the second the right.
/// * A frame with a single sample (mono input, or a truncated last frame) is
///   sent with the same value on both sides.
/// * Sending never blocks; a frame that cannot be queued is dropped.
fn read_data<T>(input: &[T], channels: usize, sender: Sender<(f32, f32)>)
where
    T: SizedSample,
    f32: FromSample<T>,
{
    for frame in input.chunks(channels) {
        let stereo = match frame {
            // `chunks` does not produce empty frames; kept for completeness.
            [] => (0.0, 0.0),
            [only] => {
                let mono = (*only).to_sample::<f32>();
                (mono, mono)
            }
            [first, second, ..] => (
                (*first).to_sample::<f32>(),
                (*second).to_sample::<f32>(),
            ),
        };
        // Best effort: the result of a failed send is deliberately ignored.
        let _ = sender.try_send(stereo);
    }
}
/// Computes the autocorrelation of `signal` for lags `0..min(max_lag, signal.len())`.
///
/// Entry `lag` of the result is the sum of `signal[i] * signal[i + lag]` over
/// all valid `i`, divided by the full signal length (not by the number of
/// overlapping samples).
///
/// Returns an empty vector when `signal` is empty or `max_lag` is zero.
fn autocorrelation(signal: &[f32], max_lag: usize) -> Vec<f32> {
    let len = signal.len();
    let lag_count = std::cmp::min(max_lag, len);
    let norm = len as f32;

    (0..lag_count)
        .map(|lag| {
            // Pair every sample with the one `lag` positions later; the
            // explicit left-to-right fold keeps the summation order fixed.
            let dot = signal[..len - lag]
                .iter()
                .zip(&signal[lag..])
                .fold(0.0f32, |acc, (a, b)| acc + a * b);
            dot / norm
        })
        .collect()
}