use std::ptr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;
use windows::{
Win32::Foundation::*, Win32::Media::Audio::*, Win32::Media::KernelStreaming::*,
Win32::Media::Multimedia::*, Win32::System::Com::*, Win32::System::Threading::*,
};
use crate::device_windows::{AudioDeviceType, AudioFormat, get_device_by_id};
use crate::error::{Error, Result};
/// One chunk of interleaved PCM audio handed to the playback engine.
pub struct PlaybackFrame {
    /// Raw little-endian sample bytes, interleaved by channel.
    pub data: Vec<u8>,
    /// Number of sample frames (samples per channel) contained in `data`.
    pub frames: i32,
    /// Interleaved channel count.
    pub channels: i32,
    /// Sample rate of this frame, in Hz.
    pub sample_rate: i32,
    /// Sample encoding of `data` (`S16` or `F32`).
    pub format: AudioFormat,
}
impl PlaybackFrame {
    /// Builds a frame from interleaved signed 16-bit samples.
    ///
    /// # Errors
    /// Returns [`Error::InvalidChannels`] when `channels` is not positive.
    pub fn from_s16(data: &[i16], channels: i32, sample_rate: i32) -> Result<Self> {
        if channels <= 0 {
            return Err(Error::InvalidChannels);
        }
        // Serialize each sample to little-endian bytes, reserving up front.
        let mut bytes = Vec::with_capacity(data.len() * 2);
        for &sample in data {
            bytes.extend_from_slice(&sample.to_le_bytes());
        }
        Ok(Self {
            frames: data.len() as i32 / channels,
            channels,
            sample_rate,
            format: AudioFormat::S16,
            data: bytes,
        })
    }
    /// Builds a frame from interleaved 32-bit float samples.
    ///
    /// # Errors
    /// Returns [`Error::InvalidChannels`] when `channels` is not positive.
    pub fn from_f32(data: &[f32], channels: i32, sample_rate: i32) -> Result<Self> {
        if channels <= 0 {
            return Err(Error::InvalidChannels);
        }
        // Serialize each sample to little-endian bytes, reserving up front.
        let mut bytes = Vec::with_capacity(data.len() * 4);
        for &sample in data {
            bytes.extend_from_slice(&sample.to_le_bytes());
        }
        Ok(Self {
            frames: data.len() as i32 / channels,
            channels,
            sample_rate,
            format: AudioFormat::F32,
            data: bytes,
        })
    }
}
/// Requested playback parameters.
///
/// NOTE(review): `AudioPlayback::new` currently only consults `device_id`;
/// the device's own mix format determines the actual rate/channels — query
/// `AudioPlayback::sample_rate()` / `channels()` for the effective values.
pub struct AudioPlaybackConfig {
    /// Output device id; `None` selects the default output device.
    pub device_id: Option<String>,
    /// Requested sample rate in Hz.
    pub sample_rate: i32,
    /// Requested channel count.
    pub channels: i32,
}
impl Default for AudioPlaybackConfig {
fn default() -> Self {
Self {
device_id: None,
sample_rate: 48000,
channels: 2,
}
}
}
/// State shared between the owner and the playback worker thread.
struct PlaybackContext {
    /// Pulls the next frame to render; returning `None` makes the worker
    /// release the buffer with the SILENT flag (silence output).
    callback: Box<dyn Fn() -> Option<PlaybackFrame> + Send + Sync>,
    /// Worker run flag; cleared (and the session event set) to request
    /// shutdown of the render loop.
    running: AtomicBool,
}
/// Everything owned by one WASAPI render session.
struct SessionData {
    audio_client: IAudioClient,
    render_client: IAudioRenderClient,
    /// Event WASAPI signals when buffer space becomes available
    /// (AUDCLNT_STREAMFLAGS_EVENTCALLBACK); also set manually by `stop()`
    /// to wake the worker promptly.
    event_handle: HANDLE,
    /// Sample encoding of the device mix format.
    format: AudioFormat,
    /// Device mix-format sample rate, in Hz.
    sample_rate: i32,
    /// Device mix-format channel count.
    channels: i32,
    /// Total device buffer capacity, in frames (from `GetBufferSize`).
    buffer_frames: u32,
}
/// Newtype that moves a Win32 `HANDLE` into the playback thread.
struct SendHandle(HANDLE);
// SAFETY: a HANDLE is an opaque kernel handle value; event handles may be
// waited on and signaled from any thread.
unsafe impl Send for SendHandle {}
impl SendHandle {
    /// Unwraps the raw handle.
    fn into_inner(self) -> HANDLE {
        self.0
    }
}
/// Newtype that moves a COM interface pointer into the playback thread.
struct SendPtr<T>(T);
// SAFETY(review): relies on the process using a multithreaded COM apartment
// (see `init_com_mta` in `new()`), where WASAPI interface pointers may be
// used from any thread — confirm no caller initializes STA instead.
unsafe impl Send for SendPtr<IAudioRenderClient> {}
unsafe impl Send for SendPtr<IAudioClient> {}
impl<T> SendPtr<T> {
    /// Unwraps the wrapped interface.
    fn into_inner(self) -> T {
        self.0
    }
}
/// Event-driven WASAPI output stream that pulls audio from a user callback.
pub struct AudioPlayback {
    /// Live session state; taken out (set to `None`) during `drop`.
    session: Option<SessionData>,
    /// State shared with the render thread.
    context: Option<Arc<PlaybackContext>>,
    /// Join handle of the render thread while playback is running.
    playback_thread: Option<thread::JoinHandle<()>>,
    /// Configuration this playback was created with (requested values).
    config: AudioPlaybackConfig,
    /// Sample rate actually in effect (device mix format).
    actual_sample_rate: i32,
    /// Channel count actually in effect (device mix format).
    actual_channels: i32,
}
impl AudioPlayback {
pub fn new<F>(config: AudioPlaybackConfig, callback: F) -> Result<Self>
where
F: Fn() -> Option<PlaybackFrame> + Send + Sync + 'static,
{
crate::device_windows::init_com_mta()?;
unsafe {
let device = get_device_by_id(config.device_id.as_deref(), AudioDeviceType::Output)?;
let audio_client: IAudioClient = device
.Activate(CLSCTX_ALL, None)
.map_err(|_| Error::SessionCreateFailed)?;
let mix_format = audio_client
.GetMixFormat()
.map_err(|_| Error::SessionCreateFailed)?;
let sample_rate = (*mix_format).nSamplesPerSec as i32;
let channels = (*mix_format).nChannels as i32;
let format = determine_playback_format(mix_format);
let event_handle =
CreateEventW(None, false, false, None).map_err(|_| Error::SessionCreateFailed)?;
let buffer_duration: i64 = 100_000; audio_client
.Initialize(
AUDCLNT_SHAREMODE_SHARED,
AUDCLNT_STREAMFLAGS_EVENTCALLBACK,
buffer_duration,
0,
mix_format,
None,
)
.map_err(|_| {
CoTaskMemFree(Some(mix_format as *const _));
let _ = CloseHandle(event_handle);
Error::SessionCreateFailed
})?;
CoTaskMemFree(Some(mix_format as *const _));
audio_client.SetEventHandle(event_handle).map_err(|_| {
let _ = CloseHandle(event_handle);
Error::SessionCreateFailed
})?;
let buffer_frames = audio_client.GetBufferSize().map_err(|_| {
let _ = CloseHandle(event_handle);
Error::SessionCreateFailed
})?;
let render_client: IAudioRenderClient = audio_client.GetService().map_err(|_| {
let _ = CloseHandle(event_handle);
Error::SessionCreateFailed
})?;
let context = Arc::new(PlaybackContext {
callback: Box::new(callback),
running: AtomicBool::new(false),
});
let session = SessionData {
audio_client,
render_client,
event_handle,
format,
sample_rate,
channels,
buffer_frames,
};
Ok(Self {
session: Some(session),
context: Some(context),
playback_thread: None,
config,
actual_sample_rate: sample_rate,
actual_channels: channels,
})
}
}
pub fn start(&mut self) -> Result<()> {
let session = self.session.as_ref().ok_or(Error::SessionStartFailed)?;
let context = self.context.as_ref().ok_or(Error::SessionStartFailed)?;
if context.running.load(Ordering::Acquire) {
return Ok(());
}
unsafe {
session
.audio_client
.Start()
.map_err(|_| Error::SessionStartFailed)?;
}
let render_client = SendPtr(session.render_client.clone());
let audio_client = SendPtr(session.audio_client.clone());
let event_handle = SendHandle(session.event_handle);
let format = session.format;
let sample_rate = session.sample_rate;
let channels = session.channels;
let buffer_frames = session.buffer_frames;
let context_clone = Arc::clone(context);
let handle = thread::Builder::new()
.name("audio-playback".into())
.spawn(move || {
playback_thread_func(
render_client.into_inner(),
audio_client.into_inner(),
event_handle.into_inner(),
format,
sample_rate,
channels,
buffer_frames,
context_clone,
);
})
.map_err(|_| Error::SessionStartFailed)?;
context.running.store(true, Ordering::Release);
self.playback_thread = Some(handle);
Ok(())
}
pub fn stop(&mut self) {
if let Some(context) = &self.context
&& context.running.load(Ordering::Acquire)
{
context.running.store(false, Ordering::Release);
if let Some(session) = &self.session {
unsafe {
let _ = SetEvent(session.event_handle);
}
}
if let Some(handle) = self.playback_thread.take() {
let _ = handle.join();
}
if let Some(session) = &self.session {
unsafe {
let _ = session.audio_client.Stop();
}
}
}
}
pub fn config(&self) -> &AudioPlaybackConfig {
&self.config
}
pub fn sample_rate(&self) -> i32 {
self.actual_sample_rate
}
pub fn channels(&self) -> i32 {
self.actual_channels
}
}
impl Drop for AudioPlayback {
    /// Stops the render thread (if running) and releases the event handle.
    fn drop(&mut self) {
        self.stop();
        if let Some(session) = self.session.take() {
            // SAFETY: the handle came from CreateEventW; taking `session`
            // out of the Option guarantees it is closed exactly once.
            unsafe {
                let _ = CloseHandle(session.event_handle);
            }
        }
    }
}
// SAFETY(review): the COM interfaces inside `SessionData` are created after
// `init_com_mta()`, i.e. in a multithreaded apartment where cross-thread use
// of WASAPI interfaces is permitted — confirm no STA initialization elsewhere.
unsafe impl Send for AudioPlayback {}
unsafe impl Sync for AudioPlayback {}
/// Maps a WASAPI mix format to the internal `AudioFormat`.
///
/// Returns `F32` for IEEE-float formats — either a plain
/// `WAVE_FORMAT_IEEE_FLOAT` tag, or `WAVE_FORMAT_EXTENSIBLE` with the
/// IEEE-float `SubFormat` GUID — and falls back to `S16` otherwise.
///
/// # Safety
/// `wave_format` must point to a valid, readable `WAVEFORMATEX`. When its
/// tag is `WAVE_FORMAT_EXTENSIBLE`, the allocation must be large enough to
/// hold a full `WAVEFORMATEXTENSIBLE` (the `SubFormat` read assumes this).
unsafe fn determine_playback_format(wave_format: *const WAVEFORMATEX) -> AudioFormat {
    let format_tag = unsafe { (*wave_format).wFormatTag };
    if format_tag == WAVE_FORMAT_IEEE_FLOAT as u16 {
        return AudioFormat::F32;
    }
    if format_tag == WAVE_FORMAT_EXTENSIBLE as u16 {
        let ext = wave_format as *const WAVEFORMATEXTENSIBLE;
        // read_unaligned: the GUID field may not be suitably aligned behind
        // a CoTaskMem-allocated WAVEFORMATEX pointer.
        let sub_format = unsafe { std::ptr::addr_of!((*ext).SubFormat).read_unaligned() };
        if sub_format == KSDATAFORMAT_SUBTYPE_IEEE_FLOAT {
            return AudioFormat::F32;
        }
    }
    AudioFormat::S16
}
/// Render-thread loop: waits on the WASAPI buffer event, pulls frames from
/// the user callback and writes them into the device buffer — converting
/// between S16 and F32 when the frame and device formats differ — until
/// `context.running` is cleared.
#[allow(clippy::too_many_arguments)]
fn playback_thread_func(
    render_client: IAudioRenderClient,
    audio_client: IAudioClient,
    event_handle: HANDLE,
    // Device mix format: the target encoding for the render buffer.
    format: AudioFormat,
    _sample_rate: i32,
    channels: i32,
    // Total device buffer capacity, in frames (from GetBufferSize).
    buffer_frames: u32,
    context: Arc<PlaybackContext>,
) {
    unsafe {
        // COM must be initialized on each thread that uses COM interfaces.
        let _ = CoInitializeEx(None, COINIT_MULTITHREADED);
        while context.running.load(Ordering::Acquire) {
            // 10 ms timeout so a stop request is noticed even if the device
            // stops signaling (stop() also sets the event to wake us).
            let wait_result = WaitForSingleObject(event_handle, 10);
            if !context.running.load(Ordering::Acquire) {
                break;
            }
            // Proceed on signal or timeout; skip this iteration on wait errors.
            if wait_result != WAIT_OBJECT_0 && wait_result != WAIT_TIMEOUT {
                continue;
            }
            // Frames still queued in the device buffer; we may only fill the
            // remaining space.
            let padding = match audio_client.GetCurrentPadding() {
                Ok(p) => p,
                Err(_) => continue,
            };
            let frames_available = buffer_frames.saturating_sub(padding);
            if frames_available == 0 {
                continue;
            }
            // Pull the next frame; a panic inside the user callback is
            // swallowed and treated as "no frame" (rendered as silence).
            let frame_opt =
                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (context.callback)()))
                    .unwrap_or(None);
            // NOTE(review): if GetBuffer fails here, the frame just pulled
            // from the callback is dropped rather than retried.
            let data_ptr = match render_client.GetBuffer(frames_available) {
                Ok(p) => p,
                Err(_) => continue,
            };
            if let Some(frame) = frame_opt {
                let bytes_per_sample: usize = match format {
                    AudioFormat::S16 => 2,
                    AudioFormat::F32 => 4,
                };
                // Byte size of the region GetBuffer handed us, with an
                // overflow guard; on overflow, release untouched and move on.
                let buffer_size = match (frames_available as usize)
                    .checked_mul(channels as usize)
                    .and_then(|n| n.checked_mul(bytes_per_sample))
                {
                    Some(size) => size,
                    None => {
                        let _ = render_client.ReleaseBuffer(frames_available, 0);
                        continue;
                    }
                };
                if frame.format == AudioFormat::F32 && format == AudioFormat::S16 {
                    // f32 -> i16: scale into 16-bit range and clamp.
                    // Unaligned reads because frame.data is a byte buffer.
                    let src_count = frame.data.len() / 4;
                    let dst_s16 =
                        std::slice::from_raw_parts_mut(data_ptr as *mut i16, buffer_size / 2);
                    let copy_len = src_count.min(dst_s16.len());
                    let src_ptr = frame.data.as_ptr() as *const f32;
                    for (i, dst) in dst_s16.iter_mut().enumerate().take(copy_len) {
                        let sample = src_ptr.add(i).read_unaligned();
                        *dst = (sample * 32767.0).clamp(-32768.0, 32767.0) as i16;
                    }
                    // Zero-fill whatever the frame did not cover.
                    dst_s16[copy_len..].fill(0);
                } else if frame.format == AudioFormat::S16 && format == AudioFormat::F32 {
                    // i16 -> f32: normalize into [-1.0, 1.0).
                    let src_count = frame.data.len() / 2;
                    let dst_f32 =
                        std::slice::from_raw_parts_mut(data_ptr as *mut f32, buffer_size / 4);
                    let copy_len = src_count.min(dst_f32.len());
                    let src_ptr = frame.data.as_ptr() as *const i16;
                    for (i, dst) in dst_f32.iter_mut().enumerate().take(copy_len) {
                        let sample = src_ptr.add(i).read_unaligned();
                        *dst = sample as f32 / 32768.0;
                    }
                    dst_f32[copy_len..].fill(0.0);
                } else {
                    // Formats match: raw byte copy, then zero-pad the rest.
                    let copy_len = frame.data.len().min(buffer_size);
                    ptr::copy_nonoverlapping(frame.data.as_ptr(), data_ptr, copy_len);
                    if copy_len < buffer_size {
                        ptr::write_bytes(data_ptr.add(copy_len), 0, buffer_size - copy_len);
                    }
                }
                let _ = render_client.ReleaseBuffer(frames_available, 0);
            } else {
                // No frame: let WASAPI render silence without us having to
                // zero the buffer ourselves.
                let _ = render_client
                    .ReleaseBuffer(frames_available, AUDCLNT_BUFFERFLAGS_SILENT.0 as u32);
            }
        }
        CoUninitialize();
    }
}