audio_device/wasapi/
client.rs

1use crate::loom::sync::Arc;
2use crate::wasapi::{ClientConfig, Error, InitializedClient, Sample, SampleFormat};
3use crate::windows::{AsyncEvent, Event, RawEvent};
4use std::marker;
5use std::mem;
6use std::ptr;
7use windows_sys::Windows::Win32::Com as com;
8use windows_sys::Windows::Win32::CoreAudio as core;
9use windows_sys::Windows::Win32::Multimedia as mm;
10use windows_sys::Windows::Win32::SystemServices as ss;
11
12/// An audio client.
13pub struct Client {
14    pub(super) tag: ste::Tag,
15    pub(super) audio_client: core::IAudioClient,
16}
17
18impl Client {
19    /// Get the default client configuration.
20    pub fn default_client_config(&self) -> Result<ClientConfig, Error> {
21        let tag = ste::Tag::current_thread();
22
23        unsafe {
24            let mut mix_format = mem::MaybeUninit::<*mut mm::WAVEFORMATEX>::zeroed();
25
26            self.audio_client
27                .GetMixFormat(mix_format.as_mut_ptr())
28                .ok()?;
29
30            let mix_format = mix_format.assume_init() as *const mm::WAVEFORMATEX;
31
32            let bits_per_sample = (*mix_format).wBitsPerSample;
33
34            let sample_format = match (*mix_format).wFormatTag as u32 {
35                core::WAVE_FORMAT_EXTENSIBLE => {
36                    debug_assert_eq! {
37                        (*mix_format).cbSize as usize,
38                        mem::size_of::<mm::WAVEFORMATEXTENSIBLE>() - mem::size_of::<mm::WAVEFORMATEX>()
39                    };
40
41                    let mix_format = mix_format as *const mm::WAVEFORMATEXTENSIBLE;
42
43                    if bits_per_sample == 32
44                        && matches!((*mix_format).SubFormat, mm::KSDATAFORMAT_SUBTYPE_IEEE_FLOAT)
45                    {
46                        SampleFormat::F32
47                    } else {
48                        return Err(Error::UnsupportedMixFormat);
49                    }
50                }
51                mm::WAVE_FORMAT_PCM => {
52                    assert!((*mix_format).cbSize == 0);
53
54                    if bits_per_sample == 16 {
55                        SampleFormat::I16
56                    } else {
57                        return Err(Error::UnsupportedMixFormat);
58                    }
59                }
60                _ => {
61                    return Err(Error::UnsupportedMixFormat);
62                }
63            };
64
65            let channels = (*mix_format).nChannels;
66            let sample_rate = (*mix_format).nSamplesPerSec;
67
68            Ok(ClientConfig {
69                tag,
70                channels,
71                sample_rate,
72                sample_format,
73            })
74        }
75    }
76
77    /// Try to initialize the client with the given configuration.
78    pub fn initialize<T>(&self, config: ClientConfig) -> Result<InitializedClient<T, Event>, Error>
79    where
80        T: Sample,
81    {
82        self.initialize_inner(config, || Event::new(false, false))
83    }
84
85    cfg_events_driver! {
86        /// Try to initialize the client with the given configuration.
87        ///
88        /// # Panics
89        ///
90        /// Panics if the audio runtime is not available.
91        ///
92        /// See [Runtime][crate::runtime::Runtime] for more.
93        pub fn initialize_async<T>(
94            &self,
95            config: ClientConfig,
96        ) -> Result<InitializedClient<T, AsyncEvent>, Error>
97        where
98            T: Sample,
99        {
100            self.initialize_inner(config, || AsyncEvent::new(false))
101        }
102    }
103
104    /// Try to initialize the client with the given configuration.
105    fn initialize_inner<T, F, E>(
106        &self,
107        mut config: ClientConfig,
108        event: F,
109    ) -> Result<InitializedClient<T, E>, Error>
110    where
111        T: Sample,
112        F: FnOnce() -> windows::Result<E>,
113        E: RawEvent,
114    {
115        unsafe {
116            let mut mix_format = T::mix_format(config);
117            let mut closest_match: *mut mm::WAVEFORMATEXTENSIBLE = ptr::null_mut();
118
119            let result: windows::HRESULT = self.audio_client.IsFormatSupported(
120                core::AUDCLNT_SHAREMODE::AUDCLNT_SHAREMODE_SHARED,
121                &mix_format as *const _ as *const mm::WAVEFORMATEX,
122                &mut closest_match as *mut _ as *mut *mut mm::WAVEFORMATEX,
123            );
124
125            if result == ss::S_FALSE {
126                if !T::is_compatible_with(closest_match as *const _) {
127                    return Err(Error::UnsupportedMixFormat);
128                }
129
130                mix_format = *closest_match;
131                config.sample_rate = mix_format.Format.nSamplesPerSec;
132                config.channels = mix_format.Format.nChannels;
133                com::CoTaskMemFree(closest_match as *mut _);
134            } else {
135                debug_assert!(closest_match.is_null());
136                result.ok()?;
137            };
138
139            self.audio_client
140                .Initialize(
141                    core::AUDCLNT_SHAREMODE::AUDCLNT_SHAREMODE_SHARED,
142                    core::AUDCLNT_STREAMFLAGS_EVENTCALLBACK,
143                    0,
144                    0,
145                    &mix_format as *const _ as *const mm::WAVEFORMATEX,
146                    ptr::null_mut(),
147                )
148                .ok()?;
149
150            let event = Arc::new(event()?);
151
152            self.audio_client.SetEventHandle(event.raw_event()).ok()?;
153
154            let mut buffer_size = mem::MaybeUninit::<u32>::uninit();
155            self.audio_client
156                .GetBufferSize(buffer_size.as_mut_ptr())
157                .ok()?;
158            let buffer_size = buffer_size.assume_init();
159
160            Ok(InitializedClient {
161                tag: self.tag,
162                audio_client: self.audio_client.clone(),
163                config,
164                buffer_size,
165                event,
166                _marker: marker::PhantomData,
167            })
168        }
169    }
170
171    /// Start playback on device.
172    pub fn start(&self) -> Result<(), Error> {
173        unsafe {
174            self.audio_client.Start().ok()?;
175        }
176
177        Ok(())
178    }
179
180    /// Stop playback on device.
181    pub fn stop(&self) -> Result<(), Error> {
182        unsafe {
183            self.audio_client.Stop().ok()?;
184        }
185
186        Ok(())
187    }
188}
189
190// Safety: thread safety is ensured through tagging with ste::Tag.
191unsafe impl Send for Client {}