Skip to main content

msquic_async/
listener.rs

1use crate::connection::Connection;
2
3#[cfg(feature = "msquic-2-5")]
4use msquic_v2_5 as msquic;
5#[cfg(feature = "msquic-seera")]
6use seera_msquic as msquic;
7
8use std::collections::VecDeque;
9use std::future::Future;
10use std::net::SocketAddr;
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll, Waker};
13
14use thiserror::Error;
15use tracing::{error, info, trace};
16
/// Listener for incoming connections.
///
/// Wraps an msquic listener handle together with the shared state that the
/// msquic event callback updates (pending connections, wakers, lifecycle
/// state).
pub struct Listener {
    /// State shared with the msquic event callback registered in `new`.
    inner: Arc<ListenerInner>,
    /// The underlying msquic listener handle used to start/stop listening.
    msquic_listener: msquic::Listener,
}
22
23impl Listener {
24    /// Create a new listener.
25    pub fn new(
26        registration: &msquic::Registration,
27        configuration: msquic::Configuration,
28    ) -> Result<Self, ListenError> {
29        let inner = Arc::new(ListenerInner::new(configuration));
30        let inner_in_ev = inner.clone();
31        let msquic_listener = msquic::Listener::open(registration, move |_, ev| match ev {
32            msquic::ListenerEvent::NewConnection { info, connection } => {
33                inner_in_ev.handle_event_new_connection(info, connection)
34            }
35            msquic::ListenerEvent::StopComplete {
36                app_close_in_progress,
37            } => inner_in_ev.handle_event_stop_complete(app_close_in_progress),
38        })
39        .map_err(ListenError::OtherError)?;
40        trace!("Listener({:p}) new", inner);
41        Ok(Self {
42            inner,
43            msquic_listener,
44        })
45    }
46
47    /// Start the listener.
48    pub fn start<T: AsRef<[msquic::BufferRef]>>(
49        &self,
50        alpn: &T,
51        local_address: Option<SocketAddr>,
52    ) -> Result<(), ListenError> {
53        let mut exclusive = self.inner.exclusive.lock().unwrap();
54        match exclusive.state {
55            ListenerState::Open | ListenerState::ShutdownComplete => {}
56            ListenerState::StartComplete | ListenerState::Shutdown => {
57                return Err(ListenError::AlreadyStarted);
58            }
59        }
60        let local_address: Option<msquic::Addr> = local_address.map(|x| x.into());
61        self.msquic_listener
62            .start(alpn.as_ref(), local_address.as_ref())
63            .map_err(ListenError::OtherError)?;
64        exclusive.state = ListenerState::StartComplete;
65        Ok(())
66    }
67
68    /// Accept a new connection.
69    pub fn accept(&self) -> Accept<'_> {
70        Accept(self)
71    }
72
73    /// Poll to accept a new connection.
74    pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<Connection, ListenError>> {
75        trace!("Listener({:p}) poll_accept", self);
76        let mut exclusive = self.inner.exclusive.lock().unwrap();
77
78        if !exclusive.new_connections.is_empty() {
79            return Poll::Ready(Ok(exclusive.new_connections.pop_front().unwrap()));
80        }
81
82        match exclusive.state {
83            ListenerState::Open => {
84                return Poll::Ready(Err(ListenError::NotStarted));
85            }
86            ListenerState::StartComplete | ListenerState::Shutdown => {}
87            ListenerState::ShutdownComplete => {
88                return Poll::Ready(Err(ListenError::Finished));
89            }
90        }
91        exclusive.new_connection_waiters.push(cx.waker().clone());
92        Poll::Pending
93    }
94
95    /// Stop the listener.
96    pub fn stop(&self) -> Stop<'_> {
97        Stop(self)
98    }
99
100    /// Poll to stop the listener.
101    pub fn poll_stop(&self, cx: &mut Context<'_>) -> Poll<Result<(), ListenError>> {
102        trace!("Listener({:p}) poll_stop", self);
103        let mut call_stop = false;
104        {
105            let mut exclusive = self.inner.exclusive.lock().unwrap();
106
107            match exclusive.state {
108                ListenerState::Open => {
109                    return Poll::Ready(Err(ListenError::NotStarted));
110                }
111                ListenerState::StartComplete => {
112                    call_stop = true;
113                    exclusive.state = ListenerState::Shutdown;
114                }
115                ListenerState::Shutdown => {}
116                ListenerState::ShutdownComplete => {
117                    return Poll::Ready(Ok(()));
118                }
119            }
120            exclusive.shutdown_complete_waiters.push(cx.waker().clone());
121        }
122        if call_stop {
123            self.msquic_listener.stop();
124        }
125        Poll::Pending
126    }
127
128    /// Get the local address the listener is bound to.
129    pub fn local_addr(&self) -> Result<SocketAddr, ListenError> {
130        self.msquic_listener
131            .get_local_addr()
132            .map(|addr| addr.as_socket().expect("not a socket address"))
133            .map_err(|_| ListenError::Failed)
134    }
135
136    /// Set the SSL key log file for new connections.
137    pub fn set_sslkeylog_file(&self, file: std::fs::File) -> Result<(), ListenError> {
138        let mut exclusive = self.inner.exclusive.lock().unwrap();
139        if exclusive.sslkeylog_file.is_some() {
140            return Err(ListenError::SslKeyLogFileAlreadySet);
141        }
142        exclusive.sslkeylog_file = Some(file);
143        Ok(())
144    }
145}
146
147impl Drop for Listener {
148    fn drop(&mut self) {
149        trace!("Listener(Inner: {:p}) dropping", self.inner);
150    }
151}
152
/// State shared between the `Listener` handle and the msquic event callback.
struct ListenerInner {
    /// Mutable state; every access goes through this mutex.
    exclusive: Mutex<ListenerInnerExclusive>,
    /// State that is read without taking the mutex.
    shared: ListenerInnerShared,
}
157
/// Mutable listener state guarded by the `ListenerInner::exclusive` mutex.
struct ListenerInnerExclusive {
    /// Current lifecycle state of the listener.
    state: ListenerState,
    /// Connections queued by the new-connection handler and not yet returned
    /// by `poll_accept`.
    new_connections: VecDeque<Connection>,
    /// Wakers of tasks pending in `poll_accept`.
    new_connection_waiters: Vec<Waker>,
    /// Wakers of tasks pending in `poll_stop`.
    shutdown_complete_waiters: Vec<Waker>,
    /// Optional SSL key log file; cloned for each new connection when set.
    sslkeylog_file: Option<std::fs::File>,
}
// NOTE(review): these manual impls assert thread-safety for every field,
// including `Connection`, whose definition is not visible here — confirm that
// the assertion is sound (or that the auto traits would not already apply).
unsafe impl Sync for ListenerInnerExclusive {}
unsafe impl Send for ListenerInnerExclusive {}
167
/// Listener state that is accessed without locking.
struct ListenerInnerShared {
    /// Configuration applied to every incoming connection.
    configuration: msquic::Configuration,
}
// NOTE(review): these manual impls assert that `msquic::Configuration` may be
// shared and sent across threads — confirm against the msquic bindings.
unsafe impl Sync for ListenerInnerShared {}
unsafe impl Send for ListenerInnerShared {}
173
/// Lifecycle state of a listener.
///
/// The enum is fieldless, so it is `Copy` and `Eq`; this lets the state be
/// read by value out of the mutex guard and compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListenerState {
    /// Created but never started.
    Open,
    /// Started and accepting connections.
    StartComplete,
    /// A stop has been requested but has not completed yet.
    Shutdown,
    /// The stop has completed; the listener may be started again.
    ShutdownComplete,
}
181
182impl ListenerInner {
183    fn new(configuration: msquic::Configuration) -> Self {
184        Self {
185            exclusive: Mutex::new(ListenerInnerExclusive {
186                state: ListenerState::Open,
187                new_connections: VecDeque::new(),
188                new_connection_waiters: Vec::new(),
189                shutdown_complete_waiters: Vec::new(),
190                sslkeylog_file: None,
191            }),
192            shared: ListenerInnerShared { configuration },
193        }
194    }
195
196    fn handle_event_new_connection(
197        &self,
198        _info: msquic::NewConnectionInfo<'_>,
199        #[cfg(feature = "msquic-2-5")] connection: msquic::ConnectionRef,
200        #[cfg(not(feature = "msquic-2-5"))] connection: msquic::Connection,
201    ) -> Result<(), msquic::Status> {
202        trace!("Listener({:p}) New connection", self);
203
204        let mut exclusive = self.exclusive.lock().unwrap();
205
206        let (sslkeylog_file, tls_secrets) = if let Some(file) = exclusive.sslkeylog_file.as_ref() {
207            let sslkeylog_file = match file.try_clone() {
208                Ok(f) => {
209                    info!(
210                        "Listener({:p}) SSL key log file set for new connection",
211                        self
212                    );
213                    Some(f)
214                }
215                Err(e) => {
216                    error!(
217                        "Listener({:p}) Failed to clone SSL key log file: {}",
218                        self, e
219                    );
220                    None
221                }
222            };
223
224            if sslkeylog_file.is_none() {
225                (None, None)
226            } else {
227                // Create a QUIC_TLS_SECRETS structure with zeroed fields
228                let tls_secrets = Box::new(msquic::ffi::QUIC_TLS_SECRETS {
229                    SecretLength: 0,
230                    ClientRandom: [0; 32],
231                    IsSet: msquic::ffi::QUIC_TLS_SECRETS__bindgen_ty_1 {
232                        _bitfield_align_1: [0; 0],
233                        _bitfield_1: msquic::ffi::QUIC_TLS_SECRETS__bindgen_ty_1::new_bitfield_1(
234                            0u8, 0u8, 0u8, 0u8, 0u8, 0u8,
235                        ),
236                    },
237                    ClientEarlyTrafficSecret: [0; 64],
238                    ClientHandshakeTrafficSecret: [0; 64],
239                    ServerHandshakeTrafficSecret: [0; 64],
240                    ClientTrafficSecret0: [0; 64],
241                    ServerTrafficSecret0: [0; 64],
242                });
243                unsafe {
244                    msquic::Api::set_param(
245                        connection.as_raw(),
246                        msquic::ffi::QUIC_PARAM_CONN_TLS_SECRETS,
247                        std::mem::size_of::<msquic::ffi::QUIC_TLS_SECRETS>() as u32,
248                        tls_secrets.as_ref() as *const _ as *const _,
249                    )
250                }?;
251                (sslkeylog_file, Some(tls_secrets))
252            }
253        } else {
254            (None, None)
255        };
256        connection.set_configuration(&self.shared.configuration)?;
257        #[cfg(feature = "msquic-2-5")]
258        let new_conn =
259            Connection::from_raw(unsafe { connection.as_raw() }, tls_secrets, sslkeylog_file);
260        #[cfg(not(feature = "msquic-2-5"))]
261        let new_conn = Connection::from_raw(connection, tls_secrets, sslkeylog_file);
262
263        exclusive.new_connections.push_back(new_conn);
264        exclusive
265            .new_connection_waiters
266            .drain(..)
267            .for_each(|waker| waker.wake());
268        Ok(())
269    }
270
271    fn handle_event_stop_complete(
272        &self,
273        app_close_in_progress: bool,
274    ) -> Result<(), msquic::Status> {
275        trace!(
276            "Listener({:p}) Stop complete: app_close_in_progress={}",
277            self,
278            app_close_in_progress
279        );
280        {
281            let mut exclusive = self.exclusive.lock().unwrap();
282            exclusive.state = ListenerState::ShutdownComplete;
283
284            exclusive
285                .new_connection_waiters
286                .drain(..)
287                .for_each(|waker| waker.wake());
288
289            exclusive
290                .shutdown_complete_waiters
291                .drain(..)
292                .for_each(|waker| waker.wake());
293            trace!(
294                "Listener({:p}) new_connections's len={}",
295                self,
296                exclusive.new_connections.len()
297            );
298        }
299        // unsafe {
300        //     Arc::from_raw(self as *const _);
301        // }
302        Ok(())
303    }
304}
305
306impl Drop for ListenerInner {
307    fn drop(&mut self) {
308        trace!("ListenerInner({:p}) dropping", self);
309    }
310}
311
/// Future returned by [`Listener::accept()`].
///
/// Borrows the listener and resolves to the next incoming connection or a
/// `ListenError`.
pub struct Accept<'a>(&'a Listener);
314
315impl Future for Accept<'_> {
316    type Output = Result<Connection, ListenError>;
317
318    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
319        self.0.poll_accept(cx)
320    }
321}
322
/// Future returned by [`Listener::stop()`].
///
/// Borrows the listener and resolves once the listener stop has completed, or
/// with a `ListenError`.
pub struct Stop<'a>(&'a Listener);
325
326impl Future for Stop<'_> {
327    type Output = Result<(), ListenError>;
328
329    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
330        self.0.poll_stop(cx)
331    }
332}
333
/// Errors returned by `Listener` operations.
#[derive(Debug, Error, Clone)]
pub enum ListenError {
    /// The listener has not been started; returned by `poll_accept` and
    /// `poll_stop`.
    #[error("Not started yet")]
    NotStarted,
    /// `start` was called while the listener is started or still stopping.
    #[error("already started")]
    AlreadyStarted,
    /// The listener has stopped and no queued connections remain; returned by
    /// `poll_accept`.
    #[error("finished")]
    Finished,
    /// Returned by `local_addr` when the local address cannot be obtained.
    #[error("failed")]
    Failed,
    /// `set_sslkeylog_file` was called after a key log file had been set.
    #[error("SSL key log file already set")]
    SslKeyLogFileAlreadySet,
    /// An msquic call failed; carries the status it reported.
    #[error("other error: status {0:?}")]
    OtherError(msquic::Status),
}