msquic_async/
listener.rs

1use crate::connection::Connection;
2
3use std::collections::VecDeque;
4use std::future::Future;
5use std::net::SocketAddr;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll, Waker};
8
9use thiserror::Error;
10use tracing::trace;
11
12/// Listener for incoming connections.
13pub struct Listener {
14    inner: Arc<ListenerInner>,
15    msquic_listener: msquic::Listener,
16}
17
18impl Listener {
19    /// Create a new listener.
20    pub fn new(
21        registration: &msquic::Registration,
22        configuration: msquic::Configuration,
23    ) -> Result<Self, ListenError> {
24        let inner = Arc::new(ListenerInner::new(configuration));
25        let inner_in_ev = inner.clone();
26        let msquic_listener = msquic::Listener::open(registration, move |_, ev| match ev {
27            msquic::ListenerEvent::NewConnection { info, connection } => {
28                inner_in_ev.handle_event_new_connection(info, connection)
29            }
30            msquic::ListenerEvent::StopComplete {
31                app_close_in_progress,
32            } => inner_in_ev.handle_event_stop_complete(app_close_in_progress),
33        })
34        .map_err(ListenError::OtherError)?;
35        trace!("Listener({:p}) new", inner);
36        Ok(Self {
37            inner,
38            msquic_listener,
39        })
40    }
41
42    /// Start the listener.
43    pub fn start<T: AsRef<[msquic::BufferRef]>>(
44        &self,
45        alpn: &T,
46        local_address: Option<SocketAddr>,
47    ) -> Result<(), ListenError> {
48        let mut exclusive = self.inner.exclusive.lock().unwrap();
49        match exclusive.state {
50            ListenerState::Open | ListenerState::ShutdownComplete => {}
51            ListenerState::StartComplete | ListenerState::Shutdown => {
52                return Err(ListenError::AlreadyStarted);
53            }
54        }
55        let local_address: Option<msquic::Addr> = local_address.map(|x| x.into());
56        self.msquic_listener
57            .start(alpn.as_ref(), local_address.as_ref())
58            .map_err(ListenError::OtherError)?;
59        exclusive.state = ListenerState::StartComplete;
60        Ok(())
61    }
62
63    /// Accept a new connection.
64    pub fn accept(&self) -> Accept {
65        Accept(self)
66    }
67
68    /// Poll to accept a new connection.
69    pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<Connection, ListenError>> {
70        trace!("Listener({:p}) poll_accept", self);
71        let mut exclusive = self.inner.exclusive.lock().unwrap();
72
73        if !exclusive.new_connections.is_empty() {
74            return Poll::Ready(Ok(exclusive.new_connections.pop_front().unwrap()));
75        }
76
77        match exclusive.state {
78            ListenerState::Open => {
79                return Poll::Ready(Err(ListenError::NotStarted));
80            }
81            ListenerState::StartComplete | ListenerState::Shutdown => {}
82            ListenerState::ShutdownComplete => {
83                return Poll::Ready(Err(ListenError::Finished));
84            }
85        }
86        exclusive.new_connection_waiters.push(cx.waker().clone());
87        Poll::Pending
88    }
89
90    /// Stop the listener.
91    pub fn stop(&self) -> Stop {
92        Stop(self)
93    }
94
95    /// Poll to stop the listener.
96    pub fn poll_stop(&self, cx: &mut Context<'_>) -> Poll<Result<(), ListenError>> {
97        trace!("Listener({:p}) poll_stop", self);
98        let mut call_stop = false;
99        {
100            let mut exclusive = self.inner.exclusive.lock().unwrap();
101
102            match exclusive.state {
103                ListenerState::Open => {
104                    return Poll::Ready(Err(ListenError::NotStarted));
105                }
106                ListenerState::StartComplete => {
107                    call_stop = true;
108                    exclusive.state = ListenerState::Shutdown;
109                }
110                ListenerState::Shutdown => {}
111                ListenerState::ShutdownComplete => {
112                    return Poll::Ready(Ok(()));
113                }
114            }
115            exclusive.shutdown_complete_waiters.push(cx.waker().clone());
116        }
117        if call_stop {
118            self.msquic_listener.stop();
119        }
120        Poll::Pending
121    }
122
123    /// Get the local address the listener is bound to.
124    pub fn local_addr(&self) -> Result<SocketAddr, ListenError> {
125        self.msquic_listener
126            .get_local_addr()
127            .map(|addr| addr.as_socket().expect("not a socket address"))
128            .map_err(|_| ListenError::Failed)
129    }
130}
131
132impl Drop for Listener {
133    fn drop(&mut self) {
134        trace!("Listener(Inner: {:p}) dropping", self.inner);
135    }
136}
137
138struct ListenerInner {
139    exclusive: Mutex<ListenerInnerExclusive>,
140    shared: ListenerInnerShared,
141}
142
143struct ListenerInnerExclusive {
144    state: ListenerState,
145    new_connections: VecDeque<Connection>,
146    new_connection_waiters: Vec<Waker>,
147    shutdown_complete_waiters: Vec<Waker>,
148}
149unsafe impl Sync for ListenerInnerExclusive {}
150unsafe impl Send for ListenerInnerExclusive {}
151
152struct ListenerInnerShared {
153    configuration: msquic::Configuration,
154}
155unsafe impl Sync for ListenerInnerShared {}
156unsafe impl Send for ListenerInnerShared {}
157
/// Lifecycle of the msquic listener, as tracked by this module.
///
/// Fieldless, so it is cheap to copy and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListenerState {
    /// Opened but not started yet (initial state).
    Open,
    /// `start` succeeded; connections may arrive.
    StartComplete,
    /// Stop requested; waiting for msquic's stop-complete event.
    Shutdown,
    /// msquic reported stop completion; the listener may be started again.
    ShutdownComplete,
}
165
166impl ListenerInner {
167    fn new(configuration: msquic::Configuration) -> Self {
168        Self {
169            exclusive: Mutex::new(ListenerInnerExclusive {
170                state: ListenerState::Open,
171                new_connections: VecDeque::new(),
172                new_connection_waiters: Vec::new(),
173                shutdown_complete_waiters: Vec::new(),
174            }),
175            shared: ListenerInnerShared { configuration },
176        }
177    }
178
179    fn handle_event_new_connection(
180        &self,
181        _info: msquic::NewConnectionInfo<'_>,
182        connection: msquic::ConnectionRef,
183    ) -> Result<(), msquic::Status> {
184        trace!("Listener({:p}) New connection", self);
185
186        connection.set_configuration(&self.shared.configuration)?;
187        let new_conn = Connection::from_raw(unsafe { connection.as_raw() });
188
189        let mut exclusive = self.exclusive.lock().unwrap();
190        exclusive.new_connections.push_back(new_conn);
191        exclusive
192            .new_connection_waiters
193            .drain(..)
194            .for_each(|waker| waker.wake());
195        Ok(())
196    }
197
198    fn handle_event_stop_complete(
199        &self,
200        app_close_in_progress: bool,
201    ) -> Result<(), msquic::Status> {
202        trace!(
203            "Listener({:p}) Stop complete: app_close_in_progress={}",
204            self,
205            app_close_in_progress
206        );
207        {
208            let mut exclusive = self.exclusive.lock().unwrap();
209            exclusive.state = ListenerState::ShutdownComplete;
210
211            exclusive
212                .new_connection_waiters
213                .drain(..)
214                .for_each(|waker| waker.wake());
215
216            exclusive
217                .shutdown_complete_waiters
218                .drain(..)
219                .for_each(|waker| waker.wake());
220            trace!(
221                "Listener({:p}) new_connections's len={}",
222                self,
223                exclusive.new_connections.len()
224            );
225        }
226        // unsafe {
227        //     Arc::from_raw(self as *const _);
228        // }
229        Ok(())
230    }
231}
232
233impl Drop for ListenerInner {
234    fn drop(&mut self) {
235        trace!("ListenerInner({:p}) dropping", self);
236    }
237}
238
239/// Future generated by `[Listener::accept()]`.
240pub struct Accept<'a>(&'a Listener);
241
242impl Future for Accept<'_> {
243    type Output = Result<Connection, ListenError>;
244
245    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
246        self.0.poll_accept(cx)
247    }
248}
249
250/// Future generated by `[Listener::stop()]`.
251pub struct Stop<'a>(&'a Listener);
252
253impl Future for Stop<'_> {
254    type Output = Result<(), ListenError>;
255
256    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
257        self.0.poll_stop(cx)
258    }
259}
260
261#[derive(Debug, Error, Clone)]
262pub enum ListenError {
263    #[error("Not started yet")]
264    NotStarted,
265    #[error("already started")]
266    AlreadyStarted,
267    #[error("finished")]
268    Finished,
269    #[error("failed")]
270    Failed,
271    #[error("other error: status {0:?}")]
272    OtherError(msquic::Status),
273}