1use std::ffi::{
4 OsStr,
5 OsString,
6 c_void,
7};
8use std::io;
9use std::marker::PhantomData;
10use std::os::windows::ffi::OsStringExt;
11
12use windows::Win32::Devices::FunctionDiscovery::PKEY_Device_FriendlyName;
13use windows::Win32::Graphics::Gdi::{
14 GetDC,
15 HDC,
16 ReleaseDC,
17};
18use windows::Win32::Media::Audio::{
19 DEVICE_STATE_ACTIVE,
20 IMMDevice,
21 IMMDeviceEnumerator,
22 MMDeviceEnumerator,
23 eConsole,
24 eRender,
25};
26use windows::Win32::System::Com::STGM_READ;
27use windows::Win32::System::Com::StructuredStorage::PROPVARIANT;
28use windows::Win32::UI::ColorSystem::{
29 GetDeviceGammaRamp,
30 SetDeviceGammaRamp,
31};
32use windows::core::GUID;
33
34use crate::com::{
35 ComInterfaceExt,
36 ComTaskMemory,
37};
38use crate::internal::{
39 ResultExt,
40 ReturnValue,
41};
42use crate::string::ZeroTerminatedWideString;
43
/// Screen device context acquired with `GetDC(None)` and handed back with
/// `ReleaseDC` when the value is dropped.
#[derive(Debug)]
pub(crate) struct ScreenDeviceContext {
    /// Raw handle returned by `GetDC`.
    raw_context: HDC,
    /// Raw-pointer marker: makes this type `!Send` and `!Sync`, presumably so the
    /// handle is only used on the thread that acquired it.
    phantom: PhantomData<*mut ()>,
}
49
50impl ScreenDeviceContext {
51 #[expect(dead_code)]
52 pub(crate) fn get() -> io::Result<Self> {
53 let result = unsafe { GetDC(None).if_null_to_error(|| io::ErrorKind::Other.into())? };
54 Ok(Self {
55 raw_context: result,
56 phantom: PhantomData,
57 })
58 }
59
60 #[expect(dead_code)]
61 pub(crate) fn get_raw_gamma_ramp(&self) -> io::Result<[[u16; 256]; 3]> {
62 let mut rgbs: [[u16; 256]; 3] = [[0; 256]; 3];
63 let _ = unsafe {
64 GetDeviceGammaRamp(self.raw_context, rgbs.as_mut_ptr().cast::<c_void>())
65 .if_null_to_error(|| io::ErrorKind::Other.into())?
66 };
67 Ok(rgbs)
68 }
69
70 #[expect(dead_code)]
71 pub(crate) fn set_raw_gamma_ramp(&self, values: &[[u16; 256]; 3]) -> io::Result<()> {
72 let _ = unsafe {
73 SetDeviceGammaRamp(self.raw_context, values.as_ptr().cast::<c_void>())
74 .if_null_to_error(|| io::ErrorKind::Other.into())?
75 };
76 Ok(())
77 }
78}
79
80impl Drop for ScreenDeviceContext {
81 fn drop(&mut self) {
82 unsafe { ReleaseDC(None, self.raw_context) }
83 .if_null_to_error_else_drop(|| io::ErrorKind::Other.into())
84 .unwrap_or_default_and_print_error();
85 }
86}
87
88impl ReturnValue for HDC {
89 const NULL_VALUE: Self = HDC(std::ptr::null_mut());
90}
91
/// Associates `IMMDeviceEnumerator` with the `MMDeviceEnumerator` COM class;
/// presumably `IMMDeviceEnumerator::new_instance()` (used below) creates an
/// object of this class — verify in `crate::com`.
impl ComInterfaceExt for IMMDeviceEnumerator {
    const CLASS_GUID: GUID = MMDeviceEnumerator;
}
95
/// An audio render (output) endpoint.
///
/// Equality is defined by `id` alone (see the hand-written `PartialEq` impl);
/// `friendly_name` does not take part in comparisons.
#[derive(Clone, Eq, Debug)]
pub struct AudioOutputDevice {
    /// Endpoint ID string as returned by `IMMDevice::GetId`.
    id: OsString,
    /// Display text of the endpoint's `PKEY_Device_FriendlyName` property.
    friendly_name: String,
}
102
103impl AudioOutputDevice {
104 pub fn get_active_devices() -> io::Result<Vec<Self>> {
106 let enumerator = IMMDeviceEnumerator::new_instance()?;
107 let endpoints = unsafe { enumerator.EnumAudioEndpoints(eRender, DEVICE_STATE_ACTIVE) }?;
108 let num_endpoints = unsafe { endpoints.GetCount() }?;
109 (0..num_endpoints)
110 .map(|idx| {
111 let item = unsafe { endpoints.Item(idx)? };
112 item.try_into()
113 })
114 .collect()
115 }
116
117 pub fn get_id(&self) -> &OsStr {
119 &self.id
120 }
121
122 pub fn get_friendly_name(&self) -> &str {
124 &self.friendly_name
125 }
126
127 pub fn get_global_default() -> io::Result<Self> {
129 let enumerator = IMMDeviceEnumerator::new_instance()?;
130 let raw_device = unsafe { enumerator.GetDefaultAudioEndpoint(eRender, eConsole) }?;
131 raw_device.try_into()
132 }
133
134 pub fn set_global_default(&self) -> io::Result<()> {
136 let policy_config = policy_config::IPolicyConfig::new_instance()?;
137 let result = unsafe {
138 policy_config.SetDefaultEndpoint(
139 ZeroTerminatedWideString::from_os_str(self.get_id()).as_raw_pcwstr(),
140 eConsole,
141 )
142 };
143 result.map_err(Into::into)
144 }
145}
146
147impl TryFrom<IMMDevice> for AudioOutputDevice {
148 type Error = io::Error;
149
150 fn try_from(item: IMMDevice) -> Result<Self, Self::Error> {
151 let raw_id = unsafe { item.GetId()? };
152 let _raw_id_memory = ComTaskMemory(raw_id.as_ptr());
153 let property_store = unsafe { item.OpenPropertyStore(STGM_READ) }?;
154 let friendly_name_prop: PROPVARIANT =
155 unsafe { property_store.GetValue(&PKEY_Device_FriendlyName)? };
156 let friendly_name = friendly_name_prop.to_string();
157 let copy = AudioOutputDevice {
158 id: OsString::from_wide(unsafe { raw_id.as_wide() }),
159 friendly_name,
160 };
161 Ok(copy)
162 }
163}
164
165impl PartialEq for AudioOutputDevice {
166 fn eq(&self, other: &Self) -> bool {
167 self.id == other.id
168 }
169}
170
/// Hand-written binding for the `IPolicyConfig` COM interface, used by
/// `AudioOutputDevice::set_global_default` to call `SetDefaultEndpoint`.
///
/// NOTE(review): this interface is declared manually rather than taken from the
/// `windows` crate; the IID, CLSID and vtable layout below cannot be checked from
/// this file and should be verified against the source they were copied from.
mod policy_config {
    // Names mirror the COM declarations (`SetDefaultEndpoint`, `deviceId`,
    // `CPolicyConfigClient`), so Rust naming lints are relaxed for this module.
    #![allow(non_upper_case_globals, non_snake_case)]

    use std::ffi::c_void;

    use windows::Win32::Media::Audio::ERole;
    use windows::core::{
        GUID,
        Interface,
        PCWSTR,
    };

    use crate::com::ComInterfaceExt;

    /// Interface pointer wrapper; `#[repr(transparent)]` gives it exactly the
    /// layout of the wrapped `IUnknown`.
    #[repr(transparent)]
    pub struct IPolicyConfig(windows::core::IUnknown);

    impl IPolicyConfig {
        /// Calls the `SetDefaultEndpoint` vtable entry with this object's raw
        /// interface pointer and converts the returned `HRESULT` via `.ok()`.
        ///
        /// # Safety
        ///
        /// `deviceId` must convert into a pointer to a valid NUL-terminated wide
        /// string that stays alive for the duration of the call.
        pub unsafe fn SetDefaultEndpoint<P0, P1>(
            &self,
            deviceId: P0,
            eRole: P1,
        ) -> windows::core::Result<()>
        where
            P0: Into<PCWSTR>,
            P1: Into<ERole>,
        {
            // SAFETY: the vtable entry is invoked with this object's own interface
            // pointer; argument validity is the caller's contract (see `# Safety`).
            unsafe {
                (Interface::vtable(self).SetDefaultEndpoint)(
                    Interface::as_raw(self),
                    deviceId.into(),
                    eRole.into(),
                )
                .ok()
            }
        }
    }

    // Declares `IUnknown` as the parent interface of `IPolicyConfig` (macro from
    // `windows::core::imp`). Its expansion triggers
    // `clippy::transmute_ptr_to_ptr`, which the `expect` acknowledges.
    #[expect(clippy::transmute_ptr_to_ptr)]
    mod interface_hierarchy {
        use super::IPolicyConfig;

        windows::core::imp::interface_hierarchy!(IPolicyConfig, windows::core::IUnknown);
    }

    // The following trait impls all forward to the wrapped `IUnknown`.
    impl Clone for IPolicyConfig {
        fn clone(&self) -> Self {
            Self(self.0.clone())
        }
    }
    impl PartialEq for IPolicyConfig {
        fn eq(&self, other: &Self) -> bool {
            self.0 == other.0
        }
    }
    impl Eq for IPolicyConfig {}
    impl core::fmt::Debug for IPolicyConfig {
        fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
            f.debug_tuple("IPolicyConfig").field(&self.0).finish()
        }
    }

    // SAFETY (review): implementing `Interface` asserts that `IPolicyConfig_Vtbl`
    // matches the object's real vtable and that `IID` identifies this interface;
    // both are hand-written here and cannot be checked by the compiler.
    unsafe impl Interface for IPolicyConfig {
        type Vtable = IPolicyConfig_Vtbl;
        const IID: GUID = GUID::from_u128(0xf8679f50_850a_41cf_9c72_430f290290c8);
    }

    /// `#[repr(C)]` vtable layout. Only `SetDefaultEndpoint` is typed; the other
    /// methods are kept as opaque pointer-sized slots so offsets stay correct.
    #[repr(C)]
    #[allow(non_camel_case_types)]
    pub struct IPolicyConfig_Vtbl {
        /// Entries inherited from `IUnknown`.
        pub base__: windows::core::IUnknown_Vtbl,
        /// Ten untyped slots for the methods declared before `SetDefaultEndpoint`;
        /// their count fixes the offset of the next field.
        padding: [*const c_void; 10],
        /// `SetDefaultEndpoint(this, wszDeviceId, eRole) -> HRESULT`.
        pub SetDefaultEndpoint: unsafe extern "system" fn(
            this: *mut c_void,
            wszDeviceId: PCWSTR,
            eRole: ERole,
        ) -> windows::core::HRESULT,
        /// One trailing untyped slot, unused by this binding.
        padding2: [*const c_void; 1],
    }

    /// CLSID of the COM class instantiated for `IPolicyConfig`.
    const CPolicyConfigClient: GUID = GUID::from_u128(0x870af99c_171d_4f9e_af0d_e63df40c2bc9);

    // Associates `IPolicyConfig` with `CPolicyConfigClient`; presumably consumed
    // by `IPolicyConfig::new_instance()` in `crate::com` — verify there.
    impl ComInterfaceExt for IPolicyConfig {
        const CLASS_GUID: GUID = CPolicyConfigClient;
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// Enumeration must succeed; when at least one device is present, the first
    /// one must carry a non-empty endpoint ID.
    #[test]
    fn check_audio_device_list() -> io::Result<()> {
        let devices = AudioOutputDevice::get_active_devices()?;
        let Some(first) = devices.first() else {
            return Ok(());
        };
        assert!(!first.id.is_empty());
        Ok(())
    }

    /// Fetching the default device may legitimately fail (e.g. no audio
    /// hardware), so only a successful result is exercised.
    #[test]
    fn check_get_global_default() {
        let outcome = AudioOutputDevice::get_global_default();
        if let Ok(device) = outcome.as_ref() {
            std::hint::black_box(device);
        }
    }
}