msquic_async/
listener.rs

1use crate::connection::Connection;
2
3use std::collections::VecDeque;
4use std::future::Future;
5use std::net::SocketAddr;
6use std::sync::Mutex;
7use std::task::{Context, Poll, Waker};
8
9use libc::c_void;
10use thiserror::Error;
11use tracing::trace;
12
/// Listener for incoming connections.
///
/// Thin handle around a heap-allocated `ListenerInner`. `Listener::new`
/// passes the address of the boxed value to msquic as the callback context,
/// which is why the state is kept behind a `Box` rather than stored inline.
pub struct Listener(Box<ListenerInner>);
15
16impl Listener {
17    /// Create a new listener.
18    pub fn new(
19        msquic_listener: msquic::Listener,
20        registration: &msquic::Registration,
21        configuration: msquic::Configuration,
22    ) -> Result<Self, ListenError> {
23        let inner = Box::new(ListenerInner::new(msquic_listener, configuration));
24        {
25            inner
26                .shared
27                .msquic_listener
28                .open(
29                    registration,
30                    ListenerInner::native_callback,
31                    &*inner as *const _ as *const c_void,
32                )
33                .map_err(ListenError::OtherError)?;
34        }
35        Ok(Self(inner))
36    }
37
38    /// Start the listener.
39    pub fn start<T: AsRef<[msquic::Buffer]>>(
40        &self,
41        alpn: &T,
42        local_address: Option<SocketAddr>,
43    ) -> Result<(), ListenError> {
44        let mut exclusive = self.0.exclusive.lock().unwrap();
45        match exclusive.state {
46            ListenerState::Open | ListenerState::ShutdownComplete => {}
47            ListenerState::StartComplete | ListenerState::Shutdown => {
48                return Err(ListenError::AlreadyStarted);
49            }
50        }
51        let local_address: Option<msquic::Addr> = local_address.map(|x| x.into());
52        self.0
53            .shared
54            .msquic_listener
55            .start(alpn.as_ref(), local_address.as_ref())
56            .map_err(ListenError::OtherError)?;
57        exclusive.state = ListenerState::StartComplete;
58        Ok(())
59    }
60
61    /// Accept a new connection.
62    pub fn accept(&self) -> Accept {
63        Accept(self)
64    }
65
66    /// Poll to accept a new connection.
67    pub fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<Connection, ListenError>> {
68        let mut exclusive = self.0.exclusive.lock().unwrap();
69
70        if !exclusive.new_connections.is_empty() {
71            return Poll::Ready(Ok(exclusive.new_connections.pop_front().unwrap()));
72        }
73
74        match exclusive.state {
75            ListenerState::Open => {
76                return Poll::Ready(Err(ListenError::NotStarted));
77            }
78            ListenerState::StartComplete | ListenerState::Shutdown => {}
79            ListenerState::ShutdownComplete => {
80                return Poll::Ready(Err(ListenError::Finished));
81            }
82        }
83        exclusive.new_connection_waiters.push(cx.waker().clone());
84        Poll::Pending
85    }
86
87    /// Stop the listener.
88    pub fn stop(&self) -> Stop {
89        Stop(self)
90    }
91
92    /// Poll to stop the listener.
93    pub fn poll_stop(&self, cx: &mut Context<'_>) -> Poll<Result<(), ListenError>> {
94        let mut call_stop = false;
95        {
96            let mut exclusive = self.0.exclusive.lock().unwrap();
97
98            match exclusive.state {
99                ListenerState::Open => {
100                    return Poll::Ready(Err(ListenError::NotStarted));
101                }
102                ListenerState::StartComplete => {
103                    call_stop = true;
104                    exclusive.state = ListenerState::Shutdown;
105                }
106                ListenerState::Shutdown => {}
107                ListenerState::ShutdownComplete => {
108                    return Poll::Ready(Ok(()));
109                }
110            }
111            exclusive.shutdown_complete_waiters.push(cx.waker().clone());
112        }
113        if call_stop {
114            self.0.shared.msquic_listener.stop();
115        }
116        Poll::Pending
117    }
118
119    /// Get the local address the listener is bound to.
120    pub fn local_addr(&self) -> Result<SocketAddr, ListenError> {
121        self.0
122            .shared
123            .msquic_listener
124            .get_local_addr()
125            .map(|addr| addr.as_socket().expect("not a socket address"))
126            .map_err(|_| ListenError::Failed)
127    }
128}
129
/// Heap-allocated listener state whose address is given to msquic as the
/// callback context.
struct ListenerInner {
    /// State mutated by both the public `Listener` API and the msquic
    /// callback; every access goes through this mutex.
    exclusive: Mutex<ListenerInnerExclusive>,
    /// Handles that are accessed without taking the mutex.
    shared: ListenerInnerShared,
}
134
/// Mutex-protected part of the listener state.
struct ListenerInnerExclusive {
    /// Current lifecycle state of the listener.
    state: ListenerState,
    /// Connections delivered by the msquic callback that have not yet been
    /// handed out by `poll_accept`.
    new_connections: VecDeque<Connection>,
    /// Wakers of tasks waiting in `poll_accept`.
    new_connection_waiters: Vec<Waker>,
    /// Wakers of tasks waiting in `poll_stop`.
    shutdown_complete_waiters: Vec<Waker>,
}
// NOTE(review): these impls assert thread-safety manually; presumably they
// exist because `Connection` does not provide `Send`/`Sync` on its own.
// Confirm that moving and sharing `Connection` across threads is sound.
unsafe impl Sync for ListenerInnerExclusive {}
unsafe impl Send for ListenerInnerExclusive {}
143
/// Part of the listener state that is used without taking the mutex.
struct ListenerInnerShared {
    /// The underlying msquic listener handle.
    msquic_listener: msquic::Listener,
    /// Configuration applied to every connection delivered by the callback.
    configuration: msquic::Configuration,
}
// NOTE(review): these impls assume the msquic listener and configuration
// handles may be used from several threads at once — confirm against the
// msquic API guarantees.
unsafe impl Sync for ListenerInnerShared {}
unsafe impl Send for ListenerInnerShared {}
150
/// Lifecycle of a listener.
///
/// The enum is fieldless, so it also derives `Copy` and `Eq`; states can be
/// read out of the mutex-guarded struct by value and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ListenerState {
    /// Created and opened, but never started.
    Open,
    /// Started successfully; new connections may arrive.
    StartComplete,
    /// A stop has been requested and its completion event is still pending.
    Shutdown,
    /// The stop-complete event has been received.
    ShutdownComplete,
}
158
159impl ListenerInner {
160    fn new(msquic_listener: msquic::Listener, configuration: msquic::Configuration) -> Self {
161        Self {
162            exclusive: Mutex::new(ListenerInnerExclusive {
163                state: ListenerState::Open,
164                new_connections: VecDeque::new(),
165                new_connection_waiters: Vec::new(),
166                shutdown_complete_waiters: Vec::new(),
167            }),
168            shared: ListenerInnerShared {
169                msquic_listener,
170                configuration,
171            },
172        }
173    }
174
175    fn handle_event_new_connection(
176        inner: &Self,
177        payload: &msquic::ListenerEventNewConnection,
178    ) -> u32 {
179        trace!("Listener({:p}) new connection event", inner);
180
181        let new_conn = Connection::from_handle(payload.connection);
182        if let Err(status) = new_conn.set_configuration(&inner.shared.configuration) {
183            return status;
184        }
185
186        let mut exclusive = inner.exclusive.lock().unwrap();
187        exclusive.new_connections.push_back(new_conn);
188        exclusive
189            .new_connection_waiters
190            .drain(..)
191            .for_each(|waker| waker.wake());
192        msquic::QUIC_STATUS_SUCCESS
193    }
194
195    fn handle_event_stop_complete(
196        inner: &Self,
197        _payload: &msquic::ListenerEventStopComplete,
198    ) -> u32 {
199        trace!("Listener({:p}) stop complete", inner);
200        let mut exclusive = inner.exclusive.lock().unwrap();
201        exclusive.state = ListenerState::ShutdownComplete;
202
203        exclusive
204            .new_connection_waiters
205            .drain(..)
206            .for_each(|waker| waker.wake());
207
208        exclusive
209            .shutdown_complete_waiters
210            .drain(..)
211            .for_each(|waker| waker.wake());
212        msquic::QUIC_STATUS_SUCCESS
213    }
214
215    extern "C" fn native_callback(
216        _listener: msquic::Handle,
217        context: *mut c_void,
218        event: &msquic::ListenerEvent,
219    ) -> u32 {
220        let inner = unsafe { &mut *(context as *mut Self) };
221        match event.event_type {
222            msquic::LISTENER_EVENT_NEW_CONNECTION => {
223                Self::handle_event_new_connection(inner, unsafe { &event.payload.new_connection })
224            }
225            msquic::LISTENER_EVENT_STOP_COMPLETE => {
226                Self::handle_event_stop_complete(inner, unsafe { &event.payload.stop_complete })
227            }
228
229            _ => {
230                println!("Other callback {}", event.event_type);
231                0
232            }
233        }
234    }
235}
236
237/// Future generated by `[Listener::accept()]`.
238pub struct Accept<'a>(&'a Listener);
239
240impl Future for Accept<'_> {
241    type Output = Result<Connection, ListenError>;
242
243    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
244        self.0.poll_accept(cx)
245    }
246}
247
248/// Future generated by `[Listener::stop()]`.
249pub struct Stop<'a>(&'a Listener);
250
251impl Future for Stop<'_> {
252    type Output = Result<(), ListenError>;
253
254    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
255        self.0.poll_stop(cx)
256    }
257}
258
/// Errors reported by [`Listener`] operations.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ListenError {
    /// The listener has not been started; returned by `poll_accept` and
    /// `poll_stop` while the listener is still in its initial state.
    #[error("Not started yet")]
    NotStarted,
    /// `start` was called while the listener is running or stopping.
    #[error("already started")]
    AlreadyStarted,
    /// The listener has finished stopping and no queued connections remain;
    /// returned by `poll_accept`.
    #[error("finished")]
    Finished,
    /// The local address could not be obtained; returned by `local_addr`.
    #[error("failed")]
    Failed,
    /// msquic reported the contained status code while opening or starting
    /// the listener.
    #[error("other error: status 0x{0:x}")]
    OtherError(u32),
}