1use std::{
2 collections::VecDeque,
3 io,
4 mem::ManuallyDrop,
5 net::{SocketAddr, SocketAddrV6},
6 pin::pin,
7 sync::{Arc, Mutex},
8 task::{Context, Poll, Waker},
9 time::Instant,
10};
11
12use compio_buf::{BufResult, bytes::Bytes};
13use compio_log::{Instrument, error};
14#[cfg(rustls)]
15use compio_net::ToSocketAddrsAsync;
16use compio_net::UdpSocket;
17use compio_runtime::JoinHandle;
18use flume::{Receiver, Sender, unbounded};
19use futures_util::{
20 FutureExt, StreamExt,
21 future::{self},
22 select,
23 task::AtomicWaker,
24};
25use quinn_proto::{
26 ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
27 EndpointEvent, ServerConfig, Transmit, VarInt,
28};
29use rustc_hash::FxHashMap as HashMap;
30
31use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket};
32
/// Mutable endpoint state, kept behind `EndpointInner::state`'s mutex and shared between the
/// worker task (`EndpointInner::run`) and the user-facing API.
#[derive(Debug)]
struct EndpointState {
    /// The `quinn_proto` endpoint; all socket I/O for it is performed by this module.
    endpoint: quinn_proto::Endpoint,
    /// Handle of the task running `EndpointInner::run`. Stored by `Endpoint::new`, taken by
    /// `Endpoint::shutdown`; `EndpointInner::connect` fails while this is `None`.
    worker: Option<JoinHandle<()>>,
    /// Event senders of live connections, keyed by handle. An entry is removed when its
    /// connection reports that it has drained.
    connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
    /// Error code and reason recorded by `Endpoint::close`; `Some` means the endpoint is closing.
    close: Option<(VarInt, Bytes)>,
    /// When set, the worker loop exits after a processed receive/event batch finds
    /// `connections` empty.
    exit_on_idle: bool,
    /// Connection attempts received but not yet handed out by `poll_incoming`.
    incoming: VecDeque<quinn_proto::Incoming>,
    /// Wakers of tasks waiting in `poll_incoming` for a connection attempt.
    incoming_wakers: VecDeque<Waker>,
}
43
44impl EndpointState {
45 fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
46 let now = Instant::now();
47 for data in buf[..meta.len]
48 .chunks(meta.stride.min(meta.len))
49 .map(Into::into)
50 {
51 let mut resp_buf = Vec::new();
52 match self.endpoint.handle(
53 now,
54 meta.remote,
55 meta.local_ip,
56 meta.ecn,
57 data,
58 &mut resp_buf,
59 ) {
60 Some(DatagramEvent::NewConnection(incoming)) => {
61 if self.close.is_none() {
62 self.incoming.push_back(incoming);
63 } else {
64 let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
65 respond_fn(resp_buf, transmit);
66 }
67 }
68 Some(DatagramEvent::ConnectionEvent(ch, event)) => {
69 let _ = self
70 .connections
71 .get(&ch)
72 .unwrap()
73 .send(ConnectionEvent::Proto(event));
74 }
75 Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
76 None => {}
77 }
78 }
79 }
80
81 fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
82 if event.is_drained() {
83 self.connections.remove(&ch);
84 }
85 if let Some(event) = self.endpoint.handle_event(ch, event) {
86 let _ = self
87 .connections
88 .get(&ch)
89 .unwrap()
90 .send(ConnectionEvent::Proto(event));
91 }
92 }
93
94 fn is_idle(&self) -> bool {
95 self.connections.is_empty()
96 }
97
98 fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
99 if self.close.is_none() {
100 if let Some(incoming) = self.incoming.pop_front() {
101 Poll::Ready(Some(incoming))
102 } else {
103 self.incoming_wakers.push_back(cx.waker().clone());
104 Poll::Pending
105 }
106 } else {
107 Poll::Ready(None)
108 }
109 }
110
111 fn new_connection(
112 &mut self,
113 handle: ConnectionHandle,
114 conn: quinn_proto::Connection,
115 socket: Socket,
116 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
117 ) -> Connecting {
118 let (tx, rx) = unbounded();
119 if let Some((error_code, reason)) = &self.close {
120 tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
121 .unwrap();
122 }
123 self.connections.insert(handle, tx);
124 Connecting::new(handle, conn, socket, events_tx, rx)
125 }
126}
127
/// A sender/receiver pair for the same channel, as returned by `flume::unbounded`.
type ChannelPair<T> = (Sender<T>, Receiver<T>);
129
/// Shared core of an [`Endpoint`], held through an `Arc` by endpoint handles, by
/// [`Incoming`] values, and by the worker task spawned in `Endpoint::new`.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// Mutable endpoint state; locked briefly by both the worker loop and API calls.
    state: Mutex<EndpointState>,
    /// The UDP socket; clones are handed to new connections and to response tasks.
    socket: Socket,
    /// Whether the local socket is IPv6. In that case `connect` maps IPv4 remotes to
    /// IPv4-mapped IPv6 addresses.
    ipv6: bool,
    /// Channel carrying `(handle, EndpointEvent)` from connections to the worker. The sender
    /// is cloned into each new `Connecting`; the receiver is consumed by `run`.
    events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
    /// Woken by `Endpoint::drop` so that a pending `Endpoint::shutdown` retries taking sole
    /// ownership of this value.
    done: AtomicWaker,
}
138
139impl EndpointInner {
140 fn new(
141 socket: UdpSocket,
142 config: EndpointConfig,
143 server_config: Option<ServerConfig>,
144 ) -> io::Result<Self> {
145 let socket = Socket::new(socket)?;
146 let ipv6 = socket.local_addr()?.is_ipv6();
147 let allow_mtud = !socket.may_fragment();
148
149 Ok(Self {
150 state: Mutex::new(EndpointState {
151 endpoint: quinn_proto::Endpoint::new(
152 Arc::new(config),
153 server_config.map(Arc::new),
154 allow_mtud,
155 None,
156 ),
157 worker: None,
158 connections: HashMap::default(),
159 close: None,
160 exit_on_idle: false,
161 incoming: VecDeque::new(),
162 incoming_wakers: VecDeque::new(),
163 }),
164 socket,
165 ipv6,
166 events: unbounded(),
167 done: AtomicWaker::new(),
168 })
169 }
170
171 fn connect(
172 &self,
173 remote: SocketAddr,
174 server_name: &str,
175 config: ClientConfig,
176 ) -> Result<Connecting, ConnectError> {
177 let mut state = self.state.lock().unwrap();
178
179 if state.worker.is_none() {
180 return Err(ConnectError::EndpointStopping);
181 }
182 if remote.is_ipv6() && !self.ipv6 {
183 return Err(ConnectError::InvalidRemoteAddress(remote));
184 }
185 let remote = if self.ipv6 {
186 SocketAddr::V6(match remote {
187 SocketAddr::V4(addr) => {
188 SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
189 }
190 SocketAddr::V6(addr) => addr,
191 })
192 } else {
193 remote
194 };
195
196 let (handle, conn) = state
197 .endpoint
198 .connect(Instant::now(), config, remote, server_name)?;
199
200 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
201 }
202
203 fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
204 let socket = self.socket.clone();
205 compio_runtime::spawn(async move {
206 let _ = socket.send(buf, &transmit).await;
207 })
208 .detach();
209 }
210
211 pub(crate) fn accept(
212 &self,
213 incoming: quinn_proto::Incoming,
214 server_config: Option<ServerConfig>,
215 ) -> Result<Connecting, ConnectionError> {
216 let mut state = self.state.lock().unwrap();
217 let mut resp_buf = Vec::new();
218 let now = Instant::now();
219 match state
220 .endpoint
221 .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
222 {
223 Ok((handle, conn)) => {
224 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
225 }
226 Err(err) => {
227 if let Some(transmit) = err.response {
228 self.respond(resp_buf, transmit);
229 }
230 Err(err.cause)
231 }
232 }
233 }
234
235 pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
236 let mut state = self.state.lock().unwrap();
237 let mut resp_buf = Vec::new();
238 let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
239 self.respond(resp_buf, transmit);
240 }
241
242 #[allow(clippy::result_large_err)]
243 pub(crate) fn retry(
244 &self,
245 incoming: quinn_proto::Incoming,
246 ) -> Result<(), quinn_proto::RetryError> {
247 let mut state = self.state.lock().unwrap();
248 let mut resp_buf = Vec::new();
249 let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
250 self.respond(resp_buf, transmit);
251 Ok(())
252 }
253
254 pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
255 let mut state = self.state.lock().unwrap();
256 state.endpoint.ignore(incoming);
257 }
258
259 async fn run(&self) -> io::Result<()> {
260 let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
261
262 let mut recv_fut = pin!(
263 self.socket
264 .recv(Vec::with_capacity(
265 self.state
266 .lock()
267 .unwrap()
268 .endpoint
269 .config()
270 .get_max_udp_payload_size()
271 .min(64 * 1024) as usize
272 * self.socket.max_gro_segments(),
273 ))
274 .fuse()
275 );
276
277 let mut event_stream = self.events.1.stream().ready_chunks(100);
278
279 loop {
280 let mut state = select! {
281 BufResult(res, recv_buf) = recv_fut => {
282 let mut state = self.state.lock().unwrap();
283 match res {
284 Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
285 Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
286 #[cfg(windows)]
287 Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
288 Err(e) => break Err(e),
289 }
290 recv_fut.set(self.socket.recv(recv_buf).fuse());
291 state
292 },
293 events = event_stream.select_next_some() => {
294 let mut state = self.state.lock().unwrap();
295 for (ch, event) in events {
296 state.handle_event(ch, event);
297 }
298 state
299 },
300 };
301
302 if state.exit_on_idle && state.is_idle() {
303 break Ok(());
304 }
305 if !state.incoming.is_empty() {
306 let n = state.incoming.len().min(state.incoming_wakers.len());
307 state.incoming_wakers.drain(..n).for_each(Waker::wake);
308 }
309 }
310 }
311}
312
/// A QUIC endpoint bound to a single UDP socket. It can initiate connections
/// ([`Endpoint::connect`]) and, with a server config, accept them ([`Endpoint::wait_incoming`]).
///
/// Clones share the same socket, worker task and connection state through an `Arc`.
#[derive(Debug, Clone)]
pub struct Endpoint {
    inner: Arc<EndpointInner>,
    /// Client configuration used by [`Endpoint::connect`] when no explicit config is passed.
    pub default_client_config: Option<ClientConfig>,
}
320
321impl Endpoint {
322 pub fn new(
324 socket: UdpSocket,
325 config: EndpointConfig,
326 server_config: Option<ServerConfig>,
327 default_client_config: Option<ClientConfig>,
328 ) -> io::Result<Self> {
329 let inner = Arc::new(EndpointInner::new(socket, config, server_config)?);
330 let worker = compio_runtime::spawn({
331 let inner = inner.clone();
332 async move {
333 #[allow(unused)]
334 if let Err(e) = inner.run().await {
335 error!("I/O error: {}", e);
336 }
337 }
338 .in_current_span()
339 });
340 inner.state.lock().unwrap().worker = Some(worker);
341 Ok(Self {
342 inner,
343 default_client_config,
344 })
345 }
346
347 #[cfg(rustls)]
362 pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
363 let socket = UdpSocket::bind(addr).await?;
365 Self::new(socket, EndpointConfig::default(), None, None)
366 }
367
368 #[cfg(rustls)]
377 pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
378 let socket = UdpSocket::bind(addr).await?;
379 Self::new(socket, EndpointConfig::default(), Some(config), None)
380 }
381
382 pub fn connect(
384 &self,
385 remote: SocketAddr,
386 server_name: &str,
387 config: Option<ClientConfig>,
388 ) -> Result<Connecting, ConnectError> {
389 let config = config
390 .or_else(|| self.default_client_config.clone())
391 .ok_or(ConnectError::NoDefaultClientConfig)?;
392
393 self.inner.connect(remote, server_name, config)
394 }
395
396 pub async fn wait_incoming(&self) -> Option<Incoming> {
405 future::poll_fn(|cx| self.inner.state.lock().unwrap().poll_incoming(cx))
406 .await
407 .map(|incoming| Incoming::new(incoming, self.inner.clone()))
408 }
409
410 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
416 self.inner
417 .state
418 .lock()
419 .unwrap()
420 .endpoint
421 .set_server_config(server_config.map(Arc::new))
422 }
423
424 pub fn local_addr(&self) -> io::Result<SocketAddr> {
426 self.inner.socket.local_addr()
427 }
428
429 pub fn open_connections(&self) -> usize {
431 self.inner.state.lock().unwrap().endpoint.open_connections()
432 }
433
434 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
441 let reason = Bytes::copy_from_slice(reason);
442 let mut state = self.inner.state.lock().unwrap();
443 if state.close.is_some() {
444 return;
445 }
446 state.close = Some((error_code, reason.clone()));
447 for conn in state.connections.values() {
448 let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
449 }
450 state.incoming_wakers.drain(..).for_each(Waker::wake);
451 }
452
453 unsafe fn try_unwrap_inner(this: &ManuallyDrop<Self>) -> Option<EndpointInner> {
455 let ptr = ManuallyDrop::new(std::ptr::read(&this.inner));
456 match Arc::try_unwrap(ManuallyDrop::into_inner(ptr)) {
457 Ok(inner) => Some(inner),
458 Err(ptr) => {
459 std::mem::forget(ptr);
460 None
461 }
462 }
463 }
464
465 pub async fn shutdown(self) -> io::Result<()> {
482 let worker = self.inner.state.lock().unwrap().worker.take();
483 if let Some(worker) = worker {
484 if self.inner.state.lock().unwrap().is_idle() {
485 worker.cancel().await;
486 } else {
487 self.inner.state.lock().unwrap().exit_on_idle = true;
488 let _ = worker.await;
489 }
490 }
491
492 let this = ManuallyDrop::new(self);
493 let inner = future::poll_fn(move |cx| {
494 if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
495 return Poll::Ready(inner);
496 }
497
498 this.inner.done.register(cx.waker());
499
500 if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
501 Poll::Ready(inner)
502 } else {
503 Poll::Pending
504 }
505 })
506 .await;
507
508 inner.socket.close().await
509 }
510}
511
512impl Drop for Endpoint {
513 fn drop(&mut self) {
514 if Arc::strong_count(&self.inner) == 2 {
515 self.inner.done.wake();
518 self.inner.state.lock().unwrap().exit_on_idle = true;
520 }
521 }
522}