use std::cell::Cell;
use std::iter;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;
use atomic::{Atomic, Ordering};
use bytemuck::NoUninit;
use cpal::{BufferSize, Device, SampleFormat, SampleRate, Stream, StreamConfig};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
/// Number of output channels the engine requires from the device; the same
/// value is passed to every source when it is built.
const CHANNELS: u16 = 2;
/// Preferred output sample rate in samples per second. If the device cannot
/// run at exactly this rate, only rates at or above it are considered.
const MIN_SAMPLE_RATE: u32 = 44100;
/// Multiplied by the sample rate to derive the fixed device buffer size in
/// frames, i.e. the requested buffer duration in seconds.
const LATENCY: f32 = 0.01;
/// A description of audio that can be turned into a playable [`AudioSource`].
///
/// The engine calls `build` on its worker thread rather than on the caller's
/// thread, which is why the produced source must be `Send`.
pub trait AudioInput {
/// Concrete source type produced by [`AudioInput::build`].
type Source: AudioSource + Send;
/// Consumes the input and creates the source for the engine's output
/// format, given as a channel count and a sample rate.
fn build(self, channels: u16, sample_rate: u32) -> Self::Source;
}
/// A producer of samples that the engine polls from the audio callback.
pub trait AudioSource {
/// Renders into `buf`, whose length equals the device output buffer for the
/// current callback, and reports the source's state.
///
/// The engine adds `buf` to the output mix only when `Playing` is returned.
/// NOTE(review): `buf` is presumably interleaved across channels — confirm.
fn get_samples(&mut self, buf: &mut [f32]) -> AudioSourceState;
}
/// State reported by [`AudioSource::get_samples`] for one audio callback.
///
/// The common value traits are derived so that callers can log, copy and
/// compare states (for example in tests of their own sources).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioSourceState {
    /// The source produced nothing to play; the engine ignores the buffer
    /// contents and records the frame at which playback stopped.
    Paused,
    /// The buffer holds samples that the engine adds to the output mix.
    Playing,
    /// The source is finished; the engine removes it and polls it no more.
    Drop,
}
/// Reference-counted handle to an [`AudioEngine`].
pub type AudioEngineRc = Rc<AudioEngine>;
/// Owns the output stream and the worker thread that builds audio sources.
pub struct AudioEngine {
// Stream configuration chosen in `new` (sample rate, buffer size).
config: StreamConfig,
// Output stream; shared with every `AudioTimestamp` returned by `add`.
stream: Rc<Stream>,
// Sends source constructors to the worker thread.
worker_tx: Sender<WorkerMessage>,
// Whether `start` has been called more recently than `pause`.
playing: Cell<bool>,
}
// Message for the worker thread: a deferred constructor for a source and the
// position cell that the audio callback will update for that source.
type WorkerMessage = (Box<dyn FnOnce() -> Box<dyn AudioSource + Send> + Send>, AudioPosAtomic);
// State shared between the audio callback and the worker thread.
type InnerMutex = Arc<Mutex<Inner>>;
// Mutex-protected engine state: the sources polled on every audio callback.
struct Inner { source_infos: Vec<SourceInfo>,
}
// A registered source together with its shared playback position.
struct SourceInfo {
// The source polled from the audio callback.
source: Box<dyn AudioSource + Send>,
// Position record written by the callback and read by `AudioTimestamp`.
pos_atomic: AudioPosAtomic,
}
impl AudioEngine {
/// Opens the default output device, builds the output stream and spawns the
/// worker thread that constructs sources.
///
/// The stream is created but not started; call `start` to begin playback.
///
/// # Panics
///
/// Panics if there is no default output device, if its supported formats
/// cannot be queried, if no stereo `f32` format at or above
/// `MIN_SAMPLE_RATE` exists, or if the stream cannot be built.
// `Default` is intentionally not implemented — presumably because
// construction opens the audio device and can panic (TODO confirm).
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
    let host = cpal::default_host();
    let device = host
        .default_output_device()
        .expect("Unable to determine default audio device");

    // Only formats with the engine's channel count and f32 samples qualify.
    let candidates: Vec<_> = device
        .supported_output_configs()
        .expect("Unable to determine supported formats")
        .filter(|cfg| cfg.channels() == CHANNELS && cfg.sample_format() == SampleFormat::F32)
        .collect();

    // First choice: the earliest range that contains the preferred rate.
    let preferred = candidates.iter().find(|cfg| {
        cfg.min_sample_rate().0 <= MIN_SAMPLE_RATE && cfg.max_sample_rate().0 >= MIN_SAMPLE_RATE
    });
    let (range, sample_rate) = match preferred {
        Some(cfg) => (cfg, MIN_SAMPLE_RATE),
        None => {
            // Otherwise the range with the lowest minimum rate that is still
            // at or above the preferred one; ties go to the earliest range.
            let cfg = candidates
                .iter()
                .filter(|cfg| cfg.min_sample_rate().0 >= MIN_SAMPLE_RATE)
                .min_by_key(|cfg| cfg.min_sample_rate().0)
                .expect("No supported format");
            (cfg, cfg.min_sample_rate().0)
        }
    };

    let mut config: StreamConfig = range.with_sample_rate(SampleRate(sample_rate)).config();
    // Request a fixed buffer of LATENCY seconds worth of frames.
    let frames_per_buffer = (config.sample_rate.0 as f32 * LATENCY) as u32;
    config.buffer_size = BufferSize::Fixed(frames_per_buffer);

    let inner_mutex = Arc::new(Mutex::new(Inner {
        source_infos: Vec::new(),
    }));
    let stream = Self::build_stream(&device, &config, Arc::clone(&inner_mutex));

    // The worker thread builds sources and registers them with the callback.
    let (worker_tx, worker_rx) = mpsc::channel();
    thread::spawn(move || Self::worker_impl(inner_mutex, worker_rx));

    Self {
        config,
        stream: Rc::new(stream),
        worker_tx,
        playing: Cell::new(false),
    }
}
/// Builds the output stream whose data callback mixes all registered sources.
///
/// On every callback the output buffer is zeroed, each source is polled once
/// through a shared scratch buffer, and the samples of sources that report
/// `Playing` are summed into the output. The callback also maintains each
/// source's `AudioPos` so that `AudioTimestamp` can derive a playback time.
///
/// Errors reported by the stream are written to stderr instead of being
/// silently discarded.
///
/// # Panics
///
/// Panics if the device refuses to build the stream. The data callback itself
/// panics if the device asks for more samples than one second of audio, or if
/// the shared mutex is poisoned.
fn build_stream(device: &Device, config: &StreamConfig, inner_mutex: InnerMutex) -> Stream {
    let sample_rate = config.sample_rate.0;
    // Scratch buffer that sources render into; sized for one second of audio
    // so that the callback never has to allocate.
    let mut source_buf = Vec::from_iter(iter::repeat_n(0.0, CHANNELS as usize * sample_rate as usize));
    // Number of frames handed to the device in all previous callbacks.
    let mut frame_count = 0_u64;
    device
        .build_output_stream(
            config,
            move |buf: &mut [f32], _| {
                let buf_len = buf.len();
                assert!(buf_len <= source_buf.len());
                let source_buf_sl = &mut source_buf.as_mut_slice()[..buf_len];
                // Start from silence; playing sources are added on top.
                buf.fill(0.0);
                let mut inner = inner_mutex.lock().unwrap();
                let source_infos = &mut inner.source_infos;
                let mut i = 0;
                let frame_count_pause = frame_count + sample_rate as u64;
                // Index loop because `swap_remove` shrinks the vector while
                // it is being walked; `i` only advances when an entry is kept.
                while i < source_infos.len() {
                    let source_info = &mut source_infos[i];
                    let mut pos = source_info.pos_atomic.load(Ordering::Relaxed);
                    // NOTE(review): with the stores visible in this file, `end` is
                    // only ever 0, a past frame count, or u64::MAX, so this skip
                    // branch looks unreachable — confirm the intended condition.
                    if pos.end != u64::MAX && pos.end > frame_count_pause {
                        i += 1;
                    } else {
                        match source_info.source.get_samples(source_buf_sl) {
                            AudioSourceState::Paused => {
                                // Playing -> paused: close the current segment.
                                if pos.end == u64::MAX {
                                    pos.end = frame_count;
                                    source_info.pos_atomic.store(pos, Ordering::Relaxed);
                                }
                                i += 1;
                            }
                            AudioSourceState::Playing => {
                                // Paused -> playing: bank the finished segment in
                                // `offset` and open a new one at the current frame.
                                if pos.end != u64::MAX {
                                    pos.offset += pos.end - pos.start;
                                    pos.start = frame_count;
                                    pos.end = u64::MAX;
                                    source_info.pos_atomic.store(pos, Ordering::Relaxed);
                                }
                                for (src_sample, dst_sample) in source_buf_sl.iter().zip(buf.iter_mut()) {
                                    *dst_sample += *src_sample;
                                }
                                i += 1;
                            }
                            AudioSourceState::Drop => {
                                // The last entry moves into slot `i`, so `i` stays
                                // put and that entry is polled next.
                                source_infos.swap_remove(i);
                            }
                        }
                    }
                }
                frame_count += buf_len as u64 / CHANNELS as u64;
            },
            // Surface device errors instead of dropping them silently.
            |err| eprintln!("Audio stream error: {err}"),
            None,
        )
        .expect("Unable to build stream")
}
pub fn add<T: AudioInput + Send + 'static>(&self, input: T) -> AudioTimestamp {
let sample_rate = self.config.sample_rate.0;
let pos = AudioPos {
start: 0,
end: 0,
offset: 0,
};
let pos_atomic = Arc::new(Atomic::new(pos));
self.worker_tx.send((Box::new(move || Box::new(input.build(CHANNELS, sample_rate))), Arc::clone(&pos_atomic))).unwrap();
AudioTimestamp::new(sample_rate, Rc::clone(&self.stream), pos_atomic)
}
/// Starts the output stream; does nothing if it is already marked playing.
///
/// # Panics
///
/// Panics if the stream cannot be started.
pub fn start(&self) {
    if self.playing.get() {
        return;
    }
    self.stream.play().expect("Unable to start stream");
    self.playing.set(true);
}
/// Pauses the output stream; does nothing if it is not marked playing.
///
/// # Panics
///
/// Panics if the stream cannot be paused.
pub fn pause(&self) {
    if !self.playing.get() {
        return;
    }
    self.stream.pause().expect("Unable to pause stream");
    self.playing.set(false);
}
/// Worker thread body: builds each queued source and registers it.
///
/// Runs until every sender for `worker_rx` has been dropped. Each source is
/// built before the mutex is taken, so the lock is held only for the push.
fn worker_impl(inner_mutex: InnerMutex, worker_rx: Receiver<WorkerMessage>) {
    while let Ok((build_func, pos_atomic)) = worker_rx.recv() {
        let source = build_func();
        inner_mutex
            .lock()
            .unwrap()
            .source_infos
            .push(SourceInfo { source, pos_atomic });
    }
}
}
/// Handle for querying how long a source added to the engine has played.
pub struct AudioTimestamp {
// Engine sample rate, stored as f64 for the timestamp arithmetic.
sample_rate: f64,
// The engine's output stream, queried for the current stream time.
stream: Rc<Stream>,
// Position record shared with the audio callback.
pos_atomic: AudioPosAtomic,
}
impl AudioTimestamp {
fn new(sample_rate: u32, stream: Rc<Stream>, pos_atomic: AudioPosAtomic) -> Self {
Self {
sample_rate: sample_rate.into(),
stream,
pos_atomic,
}
}
pub fn get_timestamp(&self) -> Option<f64> {
let stream_ts = self.stream.get_timestamp()?;
let pos = self.pos_atomic.load(Ordering::Relaxed);
let start_ts = pos.start as f64 / self.sample_rate;
let end_ts = pos.end as f64 / self.sample_rate;
let ts = if stream_ts < end_ts {
0.0_f64.max(stream_ts - start_ts) } else {
end_ts - start_ts };
Some(pos.offset as f64 / self.sample_rate + ts)
}
}
// Playback position shared between the audio callback (writer) and
// `AudioTimestamp` (reader).
type AudioPosAtomic = Arc<Atomic<AudioPos>>;
// Playback position of one source, measured in frames of the output stream.
// All fields start at zero when a source is added.
#[repr(C)]
#[derive(Clone, Copy, NoUninit)]
struct AudioPos {
// Frame count at which the current (or most recent) play segment began.
start: u64,
// Frame count at which the source last paused; u64::MAX while it is playing.
end: u64,
// Total frames played in earlier segments, banked each time playback resumes.
offset: u64,
}