use windows::{
Win32::Devices::FunctionDiscovery::*, Win32::Media::Audio::*,
Win32::System::Com::StructuredStorage::*, Win32::System::Com::*, Win32::System::Variant::*,
Win32::UI::Shell::PropertiesSystem::*, core::*,
};
use crate::error::{Error, Result};
pub(crate) fn init_com_mta() -> Result<()> {
unsafe {
let hr = CoInitializeEx(None, COINIT_MULTITHREADED);
if hr.is_ok() {
Ok(())
} else {
Err(Error::ComInitFailed)
}
}
}
/// Direction of an audio endpoint.
///
/// The enumeration helpers in this file translate this into the WASAPI data
/// flow passed to the device enumerator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioDeviceType {
/// A capture endpoint; mapped to `eCapture` when querying devices.
Input,
/// A render endpoint; mapped to `eRender` when querying devices.
Output,
}
/// Sample format selector.
///
/// NOTE(review): this enum is not used by any code visible in this part of the
/// file; the variant descriptions below are inferred from the names only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AudioFormat {
/// Presumably signed 16-bit integer samples — TODO confirm against users.
S16,
/// Presumably 32-bit floating-point samples — TODO confirm against users.
F32,
}
/// Snapshot of one audio endpoint as captured during enumeration.
///
/// All fields are plain owned values; no COM interface is kept alive by this
/// struct.
pub struct AudioDevice {
/// Friendly name read from the endpoint's property store, or
/// `"Unknown Device"` when it could not be read.
name: String,
/// Endpoint ID string as returned by `IMMDevice::GetId`.
unique_id: String,
/// Channel count taken from the endpoint's mix format (falls back to 2).
channels: i32,
/// Sample rate in samples per second taken from the endpoint's mix format
/// (falls back to 48000).
sample_rate: i32,
/// Whether the endpoint was enumerated as an input or an output.
device_type: AudioDeviceType,
}
impl AudioDevice {
    /// Returns an owned copy of the device's friendly name.
    ///
    /// The `Result` wrapper is part of the public signature; this
    /// implementation always returns `Ok`.
    pub fn name(&self) -> Result<String> {
        Ok(String::from(self.name.as_str()))
    }

    /// Returns an owned copy of the device's endpoint ID string.
    ///
    /// The `Result` wrapper is part of the public signature; this
    /// implementation always returns `Ok`.
    pub fn unique_id(&self) -> Result<String> {
        Ok(String::from(self.unique_id.as_str()))
    }

    /// Channel count recorded when the device was enumerated.
    pub fn channels(&self) -> i32 {
        let Self { channels, .. } = self;
        *channels
    }

    /// Sample rate (samples per second) recorded when the device was
    /// enumerated.
    pub fn sample_rate(&self) -> i32 {
        let Self { sample_rate, .. } = self;
        *sample_rate
    }

    /// Direction (input or output) this device was enumerated under.
    pub fn device_type(&self) -> AudioDeviceType {
        let Self { device_type, .. } = self;
        *device_type
    }
}
unsafe impl Send for AudioDevice {}
unsafe impl Sync for AudioDevice {}
/// An owned list of [`AudioDevice`] snapshots produced by one of the
/// `enumerate*` constructors.
pub struct AudioDeviceList {
/// Devices in the order they were returned by enumeration.
devices: Vec<AudioDevice>,
}
impl AudioDeviceList {
    /// Enumerates the active input (capture) devices.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying enumeration.
    pub fn enumerate_input() -> Result<Self> {
        enumerate_devices_by_type(AudioDeviceType::Input).map(|devices| Self { devices })
    }

    /// Enumerates the active output (render) devices.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying enumeration.
    pub fn enumerate_output() -> Result<Self> {
        enumerate_devices_by_type(AudioDeviceType::Output).map(|devices| Self { devices })
    }

    /// Enumerates input devices followed by output devices.
    ///
    /// # Errors
    ///
    /// Fails if either enumeration fails; inputs are queried first.
    pub fn enumerate() -> Result<Self> {
        let mut combined = enumerate_devices_by_type(AudioDeviceType::Input)?;
        let outputs = enumerate_devices_by_type(AudioDeviceType::Output)?;
        combined.extend(outputs);
        Ok(Self { devices: combined })
    }

    /// Borrows the enumerated devices as a slice.
    pub fn devices(&self) -> &[AudioDevice] {
        self.devices.as_slice()
    }

    /// Number of devices in the list.
    pub fn len(&self) -> usize {
        self.devices().len()
    }

    /// Returns `true` when the list holds no devices.
    pub fn is_empty(&self) -> bool {
        self.devices().is_empty()
    }
}
unsafe impl Send for AudioDeviceList {}
unsafe impl Sync for AudioDeviceList {}
fn enumerate_devices_by_type(device_type: AudioDeviceType) -> Result<Vec<AudioDevice>> {
init_com_mta()?;
unsafe {
let enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
.map_err(|_| Error::DeviceAccessDenied)?;
let data_flow = match device_type {
AudioDeviceType::Input => eCapture,
AudioDeviceType::Output => eRender,
};
let collection: IMMDeviceCollection = enumerator
.EnumAudioEndpoints(data_flow, DEVICE_STATE_ACTIVE)
.map_err(|_| Error::DeviceAccessDenied)?;
let count = collection
.GetCount()
.map_err(|_| Error::DeviceAccessDenied)?;
let mut devices = Vec::new();
for i in 0..count {
let device = match collection.Item(i) {
Ok(d) => d,
Err(_) => continue,
};
let device_id = match device.GetId() {
Ok(id) => id.to_string().unwrap_or_default(),
Err(_) => continue,
};
let props = device.OpenPropertyStore(STGM_READ);
let name = if let Ok(props) = props {
get_device_name(&props).unwrap_or_else(|| "Unknown Device".to_string())
} else {
"Unknown Device".to_string()
};
let (channels, sample_rate) = get_device_format(&device).unwrap_or((2, 48000));
devices.push(AudioDevice {
name,
unique_id: device_id,
channels,
sample_rate,
device_type,
});
}
Ok(devices)
}
}
/// Reads the `PKEY_Device_FriendlyName` property from `props` as a `String`.
///
/// Returns `None` when the property cannot be read, is not of type
/// `VT_LPWSTR`, holds a null string pointer, or is not valid UTF-16.
fn get_device_name(props: &IPropertyStore) -> Option<String> {
    // SAFETY: the string pointer is only read when the variant's tag says it
    // is a `VT_LPWSTR` and the pointer is non-null, and it is copied into an
    // owned `String` before the variant is cleared.
    unsafe {
        let mut variant = props.GetValue(&PKEY_Device_FriendlyName).ok()?;
        let friendly_name = if variant.Anonymous.Anonymous.vt == VT_LPWSTR {
            let wide = variant.Anonymous.Anonymous.Anonymous.pwszVal;
            if wide.is_null() {
                None
            } else {
                // Scans to the terminating NUL and decodes the UTF-16 data.
                wide.to_string().ok()
            }
        } else {
            None
        };
        // Release whatever the variant owns on every path; a failure to clear
        // is deliberately ignored.
        PropVariantClear(&mut variant).ok();
        friendly_name
    }
}
/// Queries the mix format of `device` and returns `(channels, sample_rate)`.
///
/// Returns `None` when the audio client cannot be activated or the mix format
/// cannot be retrieved.
fn get_device_format(device: &IMMDevice) -> Option<(i32, i32)> {
    // SAFETY: `mix_format` is only dereferenced after `GetMixFormat`
    // succeeded, and it is freed exactly once after its fields were copied.
    unsafe {
        let client: IAudioClient = device.Activate(CLSCTX_ALL, None).ok()?;
        let mix_format = client.GetMixFormat().ok()?;
        // Copy the fields out before releasing the buffer.
        let format = (
            i32::from((*mix_format).nChannels),
            (*mix_format).nSamplesPerSec as i32,
        );
        CoTaskMemFree(Some(mix_format as *const _));
        Some(format)
    }
}
pub(crate) fn get_device_by_id(
device_id: Option<&str>,
device_type: AudioDeviceType,
) -> Result<IMMDevice> {
init_com_mta()?;
unsafe {
let enumerator: IMMDeviceEnumerator =
CoCreateInstance(&MMDeviceEnumerator, None, CLSCTX_ALL)
.map_err(|_| Error::DeviceAccessDenied)?;
if let Some(id) = device_id {
let wide_id: Vec<u16> = id.encode_utf16().chain(std::iter::once(0)).collect();
let pcwstr = PCWSTR::from_raw(wide_id.as_ptr());
enumerator
.GetDevice(pcwstr)
.map_err(|_| Error::DeviceNotFound)
} else {
let data_flow = match device_type {
AudioDeviceType::Input => eCapture,
AudioDeviceType::Output => eRender,
};
enumerator
.GetDefaultAudioEndpoint(data_flow, eConsole)
.map_err(|_| Error::DeviceNotFound)
}
}
}