1use std::{
2 collections::VecDeque,
3 io,
4 mem::ManuallyDrop,
5 net::{SocketAddr, SocketAddrV6},
6 ops::Deref,
7 pin::pin,
8 sync::{Arc, Mutex},
9 task::{Context, Poll, Waker},
10 time::Instant,
11};
12
13use compio_buf::{BufResult, bytes::Bytes};
14use compio_log::{Instrument, error};
15#[cfg(rustls)]
16use compio_net::ToSocketAddrsAsync;
17use compio_net::UdpSocket;
18use compio_runtime::JoinHandle;
19use flume::{Receiver, Sender, unbounded};
20use futures_util::{
21 FutureExt, StreamExt,
22 future::{self},
23 select,
24 task::AtomicWaker,
25};
26use quinn_proto::{
27 ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointConfig,
28 EndpointEvent, ServerConfig, Transmit, VarInt,
29};
30use rustc_hash::FxHashMap as HashMap;
31
32use crate::{Connecting, ConnectionEvent, Incoming, RecvMeta, Socket};
33
34#[derive(Debug)]
35struct EndpointState {
36 endpoint: quinn_proto::Endpoint,
37 worker: Option<JoinHandle<()>>,
38 connections: HashMap<ConnectionHandle, Sender<ConnectionEvent>>,
39 close: Option<(VarInt, Bytes)>,
40 exit_on_idle: bool,
41 incoming: VecDeque<quinn_proto::Incoming>,
42 incoming_wakers: VecDeque<Waker>,
43}
44
45impl EndpointState {
46 fn handle_data(&mut self, meta: RecvMeta, buf: &[u8], respond_fn: impl Fn(Vec<u8>, Transmit)) {
47 let now = Instant::now();
48 for data in buf[..meta.len]
49 .chunks(meta.stride.min(meta.len))
50 .map(Into::into)
51 {
52 let mut resp_buf = Vec::new();
53 match self.endpoint.handle(
54 now,
55 meta.remote,
56 meta.local_ip,
57 meta.ecn,
58 data,
59 &mut resp_buf,
60 ) {
61 Some(DatagramEvent::NewConnection(incoming)) => {
62 if self.close.is_none() {
63 self.incoming.push_back(incoming);
64 } else {
65 let transmit = self.endpoint.refuse(incoming, &mut resp_buf);
66 respond_fn(resp_buf, transmit);
67 }
68 }
69 Some(DatagramEvent::ConnectionEvent(ch, event)) => {
70 let _ = self
71 .connections
72 .get(&ch)
73 .unwrap()
74 .send(ConnectionEvent::Proto(event));
75 }
76 Some(DatagramEvent::Response(transmit)) => respond_fn(resp_buf, transmit),
77 None => {}
78 }
79 }
80 }
81
82 fn handle_event(&mut self, ch: ConnectionHandle, event: EndpointEvent) {
83 if event.is_drained() {
84 self.connections.remove(&ch);
85 }
86 if let Some(event) = self.endpoint.handle_event(ch, event) {
87 let _ = self
88 .connections
89 .get(&ch)
90 .unwrap()
91 .send(ConnectionEvent::Proto(event));
92 }
93 }
94
95 fn is_idle(&self) -> bool {
96 self.connections.is_empty()
97 }
98
99 fn poll_incoming(&mut self, cx: &mut Context) -> Poll<Option<quinn_proto::Incoming>> {
100 if self.close.is_none() {
101 if let Some(incoming) = self.incoming.pop_front() {
102 Poll::Ready(Some(incoming))
103 } else {
104 self.incoming_wakers.push_back(cx.waker().clone());
105 Poll::Pending
106 }
107 } else {
108 Poll::Ready(None)
109 }
110 }
111
112 fn new_connection(
113 &mut self,
114 handle: ConnectionHandle,
115 conn: quinn_proto::Connection,
116 socket: Socket,
117 events_tx: Sender<(ConnectionHandle, EndpointEvent)>,
118 ) -> Connecting {
119 let (tx, rx) = unbounded();
120 if let Some((error_code, reason)) = &self.close {
121 tx.send(ConnectionEvent::Close(*error_code, reason.clone()))
122 .unwrap();
123 }
124 self.connections.insert(handle, tx);
125 Connecting::new(handle, conn, socket, events_tx, rx)
126 }
127}
128
129type ChannelPair<T> = (Sender<T>, Receiver<T>);
130
131#[derive(Debug)]
132pub(crate) struct EndpointInner {
133 state: Mutex<EndpointState>,
134 socket: Socket,
135 ipv6: bool,
136 events: ChannelPair<(ConnectionHandle, EndpointEvent)>,
137 done: AtomicWaker,
138}
139
140impl EndpointInner {
141 fn new(
142 socket: UdpSocket,
143 config: EndpointConfig,
144 server_config: Option<ServerConfig>,
145 ) -> io::Result<Self> {
146 let socket = Socket::new(socket)?;
147 let ipv6 = socket.local_addr()?.is_ipv6();
148 let allow_mtud = !socket.may_fragment();
149
150 Ok(Self {
151 state: Mutex::new(EndpointState {
152 endpoint: quinn_proto::Endpoint::new(
153 Arc::new(config),
154 server_config.map(Arc::new),
155 allow_mtud,
156 None,
157 ),
158 worker: None,
159 connections: HashMap::default(),
160 close: None,
161 exit_on_idle: false,
162 incoming: VecDeque::new(),
163 incoming_wakers: VecDeque::new(),
164 }),
165 socket,
166 ipv6,
167 events: unbounded(),
168 done: AtomicWaker::new(),
169 })
170 }
171
172 fn connect(
173 &self,
174 remote: SocketAddr,
175 server_name: &str,
176 config: ClientConfig,
177 ) -> Result<Connecting, ConnectError> {
178 let mut state = self.state.lock().unwrap();
179
180 if state.worker.is_none() {
181 return Err(ConnectError::EndpointStopping);
182 }
183 if remote.is_ipv6() && !self.ipv6 {
184 return Err(ConnectError::InvalidRemoteAddress(remote));
185 }
186 let remote = if self.ipv6 {
187 SocketAddr::V6(match remote {
188 SocketAddr::V4(addr) => {
189 SocketAddrV6::new(addr.ip().to_ipv6_mapped(), addr.port(), 0, 0)
190 }
191 SocketAddr::V6(addr) => addr,
192 })
193 } else {
194 remote
195 };
196
197 let (handle, conn) = state
198 .endpoint
199 .connect(Instant::now(), config, remote, server_name)?;
200
201 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
202 }
203
204 fn respond(&self, buf: Vec<u8>, transmit: Transmit) {
205 let socket = self.socket.clone();
206 compio_runtime::spawn(async move {
207 let _ = socket.send(buf, &transmit).await;
208 })
209 .detach();
210 }
211
212 pub(crate) fn accept(
213 &self,
214 incoming: quinn_proto::Incoming,
215 server_config: Option<ServerConfig>,
216 ) -> Result<Connecting, ConnectionError> {
217 let mut state = self.state.lock().unwrap();
218 let mut resp_buf = Vec::new();
219 let now = Instant::now();
220 match state
221 .endpoint
222 .accept(incoming, now, &mut resp_buf, server_config.map(Arc::new))
223 {
224 Ok((handle, conn)) => {
225 Ok(state.new_connection(handle, conn, self.socket.clone(), self.events.0.clone()))
226 }
227 Err(err) => {
228 if let Some(transmit) = err.response {
229 self.respond(resp_buf, transmit);
230 }
231 Err(err.cause)
232 }
233 }
234 }
235
236 pub(crate) fn refuse(&self, incoming: quinn_proto::Incoming) {
237 let mut state = self.state.lock().unwrap();
238 let mut resp_buf = Vec::new();
239 let transmit = state.endpoint.refuse(incoming, &mut resp_buf);
240 self.respond(resp_buf, transmit);
241 }
242
243 #[allow(clippy::result_large_err)]
244 pub(crate) fn retry(
245 &self,
246 incoming: quinn_proto::Incoming,
247 ) -> Result<(), quinn_proto::RetryError> {
248 let mut state = self.state.lock().unwrap();
249 let mut resp_buf = Vec::new();
250 let transmit = state.endpoint.retry(incoming, &mut resp_buf)?;
251 self.respond(resp_buf, transmit);
252 Ok(())
253 }
254
255 pub(crate) fn ignore(&self, incoming: quinn_proto::Incoming) {
256 let mut state = self.state.lock().unwrap();
257 state.endpoint.ignore(incoming);
258 }
259
260 async fn run(&self) -> io::Result<()> {
261 let respond_fn = |buf: Vec<u8>, transmit: Transmit| self.respond(buf, transmit);
262
263 let mut recv_fut = pin!(
264 self.socket
265 .recv(Vec::with_capacity(
266 self.state
267 .lock()
268 .unwrap()
269 .endpoint
270 .config()
271 .get_max_udp_payload_size()
272 .min(64 * 1024) as usize
273 * self.socket.max_gro_segments(),
274 ))
275 .fuse()
276 );
277
278 let mut event_stream = self.events.1.stream().ready_chunks(100);
279
280 loop {
281 let mut state = select! {
282 BufResult(res, recv_buf) = recv_fut => {
283 let mut state = self.state.lock().unwrap();
284 match res {
285 Ok(meta) => state.handle_data(meta, &recv_buf, respond_fn),
286 Err(e) if e.kind() == io::ErrorKind::ConnectionReset => {}
287 #[cfg(windows)]
288 Err(e) if e.raw_os_error() == Some(windows_sys::Win32::Foundation::ERROR_PORT_UNREACHABLE as _) => {}
289 Err(e) => break Err(e),
290 }
291 recv_fut.set(self.socket.recv(recv_buf).fuse());
292 state
293 },
294 events = event_stream.select_next_some() => {
295 let mut state = self.state.lock().unwrap();
296 for (ch, event) in events {
297 state.handle_event(ch, event);
298 }
299 state
300 },
301 };
302
303 if state.exit_on_idle && state.is_idle() {
304 break Ok(());
305 }
306 if !state.incoming.is_empty() {
307 let n = state.incoming.len().min(state.incoming_wakers.len());
308 state.incoming_wakers.drain(..n).for_each(Waker::wake);
309 }
310 }
311 }
312}
313
314#[derive(Debug, Clone)]
315pub(crate) struct EndpointRef(Arc<EndpointInner>);
316
317impl EndpointRef {
318 unsafe fn try_unwrap_inner(&self) -> Option<EndpointInner> {
320 let ptr = unsafe { std::ptr::read(&self.0) };
321 match Arc::try_unwrap(ptr) {
322 Ok(inner) => Some(inner),
323 Err(ptr) => {
324 std::mem::forget(ptr);
325 None
326 }
327 }
328 }
329
330 async fn shutdown(self) -> io::Result<()> {
331 let (worker, idle) = {
332 let mut state = self.0.state.lock().unwrap();
333 let idle = state.is_idle();
334 if !idle {
335 state.exit_on_idle = true;
336 }
337 (state.worker.take(), idle)
338 };
339 if let Some(worker) = worker {
340 if idle {
341 worker.cancel().await;
342 } else {
343 let _ = worker.await;
344 }
345 }
346
347 let this = ManuallyDrop::new(self);
348 let inner = future::poll_fn(move |cx| {
349 if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
350 return Poll::Ready(inner);
351 }
352
353 this.done.register(cx.waker());
354
355 if let Some(inner) = unsafe { Self::try_unwrap_inner(&this) } {
356 Poll::Ready(inner)
357 } else {
358 Poll::Pending
359 }
360 })
361 .await;
362
363 inner.socket.close().await
364 }
365}
366
367impl Drop for EndpointRef {
368 fn drop(&mut self) {
369 if Arc::strong_count(&self.0) == 2 {
370 self.0.done.wake();
373 self.0.state.lock().unwrap().exit_on_idle = true;
375 }
376 }
377}
378
379impl Deref for EndpointRef {
380 type Target = EndpointInner;
381
382 fn deref(&self) -> &Self::Target {
383 &self.0
384 }
385}
386
387#[derive(Debug, Clone)]
389pub struct Endpoint {
390 inner: EndpointRef,
391 pub default_client_config: Option<ClientConfig>,
393}
394
395impl Endpoint {
396 pub fn new(
398 socket: UdpSocket,
399 config: EndpointConfig,
400 server_config: Option<ServerConfig>,
401 default_client_config: Option<ClientConfig>,
402 ) -> io::Result<Self> {
403 let inner = EndpointRef(Arc::new(EndpointInner::new(socket, config, server_config)?));
404 let worker = compio_runtime::spawn({
405 let inner = inner.clone();
406 async move {
407 #[allow(unused)]
408 if let Err(e) = inner.run().await {
409 error!("I/O error: {}", e);
410 }
411 }
412 .in_current_span()
413 });
414 inner.state.lock().unwrap().worker = Some(worker);
415 Ok(Self {
416 inner,
417 default_client_config,
418 })
419 }
420
421 #[cfg(rustls)]
436 pub async fn client(addr: impl ToSocketAddrsAsync) -> io::Result<Endpoint> {
437 let socket = UdpSocket::bind(addr).await?;
439 Self::new(socket, EndpointConfig::default(), None, None)
440 }
441
442 #[cfg(rustls)]
451 pub async fn server(addr: impl ToSocketAddrsAsync, config: ServerConfig) -> io::Result<Self> {
452 let socket = UdpSocket::bind(addr).await?;
453 Self::new(socket, EndpointConfig::default(), Some(config), None)
454 }
455
456 pub fn connect(
458 &self,
459 remote: SocketAddr,
460 server_name: &str,
461 config: Option<ClientConfig>,
462 ) -> Result<Connecting, ConnectError> {
463 let config = config
464 .or_else(|| self.default_client_config.clone())
465 .ok_or(ConnectError::NoDefaultClientConfig)?;
466
467 self.inner.connect(remote, server_name, config)
468 }
469
470 pub async fn wait_incoming(&self) -> Option<Incoming> {
479 future::poll_fn(|cx| self.inner.state.lock().unwrap().poll_incoming(cx))
480 .await
481 .map(|incoming| Incoming::new(incoming, self.inner.clone()))
482 }
483
484 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
490 self.inner
491 .state
492 .lock()
493 .unwrap()
494 .endpoint
495 .set_server_config(server_config.map(Arc::new))
496 }
497
498 pub fn local_addr(&self) -> io::Result<SocketAddr> {
500 self.inner.socket.local_addr()
501 }
502
503 pub fn open_connections(&self) -> usize {
505 self.inner.state.lock().unwrap().endpoint.open_connections()
506 }
507
508 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
515 let reason = Bytes::copy_from_slice(reason);
516 let mut state = self.inner.state.lock().unwrap();
517 if state.close.is_some() {
518 return;
519 }
520 state.close = Some((error_code, reason.clone()));
521 for conn in state.connections.values() {
522 let _ = conn.send(ConnectionEvent::Close(error_code, reason.clone()));
523 }
524 state.incoming_wakers.drain(..).for_each(Waker::wake);
525 }
526
527 pub async fn shutdown(self) -> io::Result<()> {
544 self.inner.shutdown().await
545 }
546}