1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io,
6 io::IoSliceMut,
7 mem,
8 net::{SocketAddr, SocketAddrV6},
9 pin::Pin,
10 str,
11 sync::{Arc, Mutex},
12 task::{Context, Poll, Waker},
13};
14
15#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
16use super::runtime::default_runtime;
17use super::{
18 runtime::{AsyncUdpSocket, Runtime},
19 udp_transmit,
20};
21use crate::Instant;
22use bytes::{Bytes, BytesMut};
23use pin_project_lite::pin_project;
24use crate::{
25 ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
26 EndpointEvent, ServerConfig,
27};
28use rustc_hash::FxHashMap;
29#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"), feature = "production-ready"))]
30use socket2::{Domain, Protocol, Socket, Type};
31use tokio::sync::{Notify, futures::Notified, mpsc};
32use tracing::{Instrument, Span};
33use quinn_udp::{BATCH_SIZE, RecvMeta};
34
35use super::{
36 ConnectionEvent, IO_LOOP_BOUND, RECV_TIME_BOUND,
37 connection::Connecting, work_limiter::WorkLimiter,
38};
39use crate::{EndpointConfig, VarInt};
40
/// Handle to an endpoint that can both accept (`accept`) and initiate (`connect`)
/// connections.
///
/// Clones share the same underlying `EndpointRef` state and the same driver task.
#[derive(Debug, Clone)]
pub struct Endpoint {
    // Shared endpoint state; a clone of it is also owned by the spawned `EndpointDriver`.
    pub(crate) inner: EndpointRef,
    // Configuration used by `connect`; `None` until `set_default_client_config` is called.
    pub(crate) default_client_config: Option<ClientConfig>,
    // Runtime used to spawn the driver, wrap std sockets and read the current time.
    runtime: Arc<dyn Runtime>,
}
53
54impl Endpoint {
55 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"), feature = "production-ready"))] pub fn client(addr: SocketAddr) -> io::Result<Self> {
81 let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
82 if addr.is_ipv6() {
83 if let Err(e) = socket.set_only_v6(false) {
84 tracing::debug!(%e, "unable to make socket dual-stack");
85 }
86 }
87 socket.bind(&addr.into())?;
88 let runtime =
89 default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
90 Self::new_with_abstract_socket(
91 EndpointConfig::default(),
92 None,
93 runtime.wrap_udp_socket(socket.into())?,
94 runtime,
95 )
96 }
97
98 pub fn stats(&self) -> EndpointStats {
100 self.inner.state.lock().unwrap().stats
101 }
102
103 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
111 let socket = std::net::UdpSocket::bind(addr)?;
112 let runtime =
113 default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
114 Self::new_with_abstract_socket(
115 EndpointConfig::default(),
116 Some(config),
117 runtime.wrap_udp_socket(socket)?,
118 runtime,
119 )
120 }
121
122 #[cfg(not(wasm_browser))]
124 pub fn new(
125 config: EndpointConfig,
126 server_config: Option<ServerConfig>,
127 socket: std::net::UdpSocket,
128 runtime: Arc<dyn Runtime>,
129 ) -> io::Result<Self> {
130 let socket = runtime.wrap_udp_socket(socket)?;
131 Self::new_with_abstract_socket(config, server_config, socket, runtime)
132 }
133
134 pub fn new_with_abstract_socket(
139 config: EndpointConfig,
140 server_config: Option<ServerConfig>,
141 socket: Arc<dyn AsyncUdpSocket>,
142 runtime: Arc<dyn Runtime>,
143 ) -> io::Result<Self> {
144 let addr = socket.local_addr()?;
145 let allow_mtud = !socket.may_fragment();
146 let rc = EndpointRef::new(
147 socket,
148 crate::endpoint::Endpoint::new(
149 Arc::new(config),
150 server_config.map(Arc::new),
151 allow_mtud,
152 None,
153 ),
154 addr.is_ipv6(),
155 runtime.clone(),
156 );
157 let driver = EndpointDriver(rc.clone());
158 runtime.spawn(Box::pin(
159 async {
160 if let Err(e) = driver.await {
161 tracing::error!("I/O error: {}", e);
162 }
163 }
164 .instrument(Span::current()),
165 ));
166 Ok(Self {
167 inner: rc,
168 default_client_config: None,
169 runtime,
170 })
171 }
172
173 pub fn accept(&self) -> Accept<'_> {
180 Accept {
181 endpoint: self,
182 notify: self.inner.shared.incoming.notified(),
183 }
184 }
185
186 pub fn set_default_client_config(&mut self, config: ClientConfig) {
188 self.default_client_config = Some(config);
189 }
190
191 pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
200 let config = match &self.default_client_config {
201 Some(config) => config.clone(),
202 None => return Err(ConnectError::NoDefaultClientConfig),
203 };
204
205 self.connect_with(config, addr, server_name)
206 }
207
208 pub fn connect_with(
214 &self,
215 config: ClientConfig,
216 addr: SocketAddr,
217 server_name: &str,
218 ) -> Result<Connecting, ConnectError> {
219 let mut endpoint = self.inner.state.lock().unwrap();
220 if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
221 return Err(ConnectError::EndpointStopping);
222 }
223 if addr.is_ipv6() && !endpoint.ipv6 {
224 return Err(ConnectError::InvalidRemoteAddress(addr));
225 }
226 let addr = if endpoint.ipv6 {
227 SocketAddr::V6(ensure_ipv6(addr))
228 } else {
229 addr
230 };
231
232 let (ch, conn) = endpoint
233 .inner
234 .connect(self.runtime.now(), config, addr, server_name)?;
235
236 let socket = endpoint.socket.clone();
237 endpoint.stats.outgoing_handshakes += 1;
238 Ok(endpoint
239 .recv_state
240 .connections
241 .insert(ch, conn, socket, self.runtime.clone()))
242 }
243
244 #[cfg(not(wasm_browser))]
248 pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
249 self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
250 }
251
252 pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
259 let addr = socket.local_addr()?;
260 let mut inner = self.inner.state.lock().unwrap();
261 inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
262 inner.ipv6 = addr.is_ipv6();
263
264 for sender in inner.recv_state.connections.senders.values() {
266 let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
268 }
269
270 Ok(())
271 }
272
273 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
277 self.inner
278 .state
279 .lock()
280 .unwrap()
281 .inner
282 .set_server_config(server_config.map(Arc::new))
283 }
284
285 pub fn local_addr(&self) -> io::Result<SocketAddr> {
287 self.inner.state.lock().unwrap().socket.local_addr()
288 }
289
290 pub fn open_connections(&self) -> usize {
292 self.inner.state.lock().unwrap().inner.open_connections()
293 }
294
295 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
301 let reason = Bytes::copy_from_slice(reason);
302 let mut endpoint = self.inner.state.lock().unwrap();
303 endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
304 for sender in endpoint.recv_state.connections.senders.values() {
305 let _ = sender.send(ConnectionEvent::Close {
307 error_code,
308 reason: reason.clone(),
309 });
310 }
311 self.inner.shared.incoming.notify_waiters();
312 }
313
314 pub async fn wait_idle(&self) {
325 loop {
326 {
327 let endpoint = &mut *self.inner.state.lock().unwrap();
328 if endpoint.recv_state.connections.is_empty() {
329 break;
330 }
331 self.inner.shared.idle.notified()
333 }
334 .await;
335 }
336 }
337}
338
/// Handshake counters for an endpoint, as returned by `Endpoint::stats`.
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Incoming connection attempts successfully accepted via `EndpointInner::accept`.
    pub accepted_handshakes: u64,
    /// Outgoing connections started via `Endpoint::connect_with` (and thus `connect`).
    pub outgoing_handshakes: u64,
    /// Incoming attempts refused via `EndpointInner::refuse`. Attempts refused automatically
    /// while the endpoint is closing are not counted here.
    pub refused_handshakes: u64,
    /// Incoming attempts ignored via `EndpointInner::ignore`. Attempts still queued when the
    /// endpoint state is dropped are not counted here.
    pub ignored_handshakes: u64,
}
352
/// Future that performs an endpoint's socket receives and event processing.
///
/// `Endpoint::new_with_abstract_socket` spawns one per endpoint. It completes once the shared
/// reference count is zero and no connections remain. It fails if receiving on the active
/// socket returns an error.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
366
367impl Future for EndpointDriver {
368 type Output = Result<(), io::Error>;
369
370 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
371 let mut endpoint = self.0.state.lock().unwrap();
372 if endpoint.driver.is_none() {
373 endpoint.driver = Some(cx.waker().clone());
374 }
375
376 let now = endpoint.runtime.now();
377 let mut keep_going = false;
378 keep_going |= endpoint.drive_recv(cx, now)?;
379 keep_going |= endpoint.handle_events(cx, &self.0.shared);
380
381 if !endpoint.recv_state.incoming.is_empty() {
382 self.0.shared.incoming.notify_waiters();
383 }
384
385 if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
386 Poll::Ready(Ok(()))
387 } else {
388 drop(endpoint);
389 if keep_going {
393 cx.waker().wake_by_ref();
394 }
395 Poll::Pending
396 }
397 }
398}
399
400impl Drop for EndpointDriver {
401 fn drop(&mut self) {
402 let mut endpoint = self.0.state.lock().unwrap();
403 endpoint.driver_lost = true;
404 self.0.shared.incoming.notify_waiters();
405 endpoint.recv_state.connections.senders.clear();
408 }
409}
410
/// Shared core of an endpoint, reference-counted through `EndpointRef`.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    // All mutable endpoint state, behind one lock.
    pub(crate) state: Mutex<State>,
    // Notifications used to wake `accept()` and `wait_idle()` callers.
    pub(crate) shared: Shared,
}
416
417impl EndpointInner {
418 pub(crate) fn accept(
419 &self,
420 incoming: crate::Incoming,
421 server_config: Option<Arc<ServerConfig>>,
422 ) -> Result<Connecting, ConnectionError> {
423 let mut state = self.state.lock().unwrap();
424 let mut response_buffer = Vec::new();
425 let now = state.runtime.now();
426 match state
427 .inner
428 .accept(incoming, now, &mut response_buffer, server_config)
429 {
430 Ok((handle, conn)) => {
431 state.stats.accepted_handshakes += 1;
432 let socket = state.socket.clone();
433 let runtime = state.runtime.clone();
434 Ok(state
435 .recv_state
436 .connections
437 .insert(handle, conn, socket, runtime))
438 }
439 Err(error) => {
440 if let Some(transmit) = error.response {
441 respond(transmit, &response_buffer, &*state.socket);
442 }
443 Err(error.cause)
444 }
445 }
446 }
447
448 pub(crate) fn refuse(&self, incoming: crate::Incoming) {
449 let mut state = self.state.lock().unwrap();
450 state.stats.refused_handshakes += 1;
451 let mut response_buffer = Vec::new();
452 let transmit = state.inner.refuse(incoming, &mut response_buffer);
453 respond(transmit, &response_buffer, &*state.socket);
454 }
455
456 pub(crate) fn retry(&self, incoming: crate::Incoming) -> Result<(), crate::endpoint::RetryError> {
457 let mut state = self.state.lock().unwrap();
458 let mut response_buffer = Vec::new();
459 let transmit = state.inner.retry(incoming, &mut response_buffer)?;
460 respond(transmit, &response_buffer, &*state.socket);
461 Ok(())
462 }
463
464 pub(crate) fn ignore(&self, incoming: crate::Incoming) {
465 let mut state = self.state.lock().unwrap();
466 state.stats.ignored_handshakes += 1;
467 state.inner.ignore(incoming);
468 }
469}
470
/// Mutable endpoint state, guarded by `EndpointInner::state`.
#[derive(Debug)]
pub(crate) struct State {
    // Current socket. It is handed to new connections and used for responses sent from
    // `EndpointInner`.
    socket: Arc<dyn AsyncUdpSocket>,
    // Socket replaced by the most recent rebind. It is still polled for receives until
    // connection traffic arrives on `socket`, or until it reports an error.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    // Protocol-level endpoint that datagrams and connection events are fed into.
    inner: crate::endpoint::Endpoint,
    // Incoming queue, per-connection channels and receive buffers.
    recv_state: RecvState,
    // Waker of the driver task. `EndpointRef::drop` takes it to wake the driver.
    driver: Option<Waker>,
    // Whether the endpoint operates in IPv6 mode (follows the current socket's address).
    ipv6: bool,
    // Receiving half of the channel that connections use to send events to the endpoint.
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    // Incremented by `EndpointRef::clone` and decremented (never below zero) by
    // `EndpointRef::drop`. The driver may finish once it is zero and no connections remain.
    ref_count: usize,
    // Set when the `EndpointDriver` is dropped.
    driver_lost: bool,
    runtime: Arc<dyn Runtime>,
    stats: EndpointStats,
}
488
/// Notifications shared between the driver and the endpoint's async APIs.
#[derive(Debug)]
pub(crate) struct Shared {
    // Signalled when incoming attempts are queued, when the endpoint is closed, and when the
    // driver is dropped. `Accept` waits on it.
    incoming: Notify,
    // Signalled when the connection set becomes empty. `Endpoint::wait_idle` waits on it.
    idle: Notify,
}
494
impl State {
    /// Receives pending datagrams on the previous socket (if any) and on the current socket.
    ///
    /// Returns `Ok(true)` when the receive work limiter stopped early, meaning the driver
    /// should be polled again soon.
    ///
    /// # Errors
    ///
    /// Propagates receive errors from the current socket. Errors from `prev_socket` are not
    /// propagated; they only cause that socket to be discarded.
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        // The socket replaced by `rebind` is still drained, but its progress is not reported.
        if let Some(socket) = &self.prev_socket {
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        // Finish the limiter cycle before `?` so an error does not leave the cycle open.
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Connection traffic arrived on the current socket, so the previous socket is
            // no longer needed.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    /// Processes up to `IO_LOOP_BOUND` events sent by connections.
    ///
    /// Returns `true` if the bound was reached, meaning more events may be queued. Returns
    /// `false` once the event channel has no more ready events.
    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                // `ConnectionSet::sender` keeps the channel open for as long as `self` lives.
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            // A drained connection is forgotten here; waiters on `idle` are woken once the
            // last one goes.
            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // A send error means the connection's receiver was already dropped; ignore it.
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}
553
554impl Drop for State {
555 fn drop(&mut self) {
556 for incoming in self.recv_state.incoming.drain(..) {
557 self.inner.ignore(incoming);
558 }
559 }
560}
561
562fn respond(transmit: crate::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
563 _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
584}
585
586#[inline]
587fn proto_ecn(ecn: quinn_udp::EcnCodepoint) -> crate::EcnCodepoint {
588 match ecn {
589 quinn_udp::EcnCodepoint::Ect0 => crate::EcnCodepoint::Ect0,
590 quinn_udp::EcnCodepoint::Ect1 => crate::EcnCodepoint::Ect1,
591 quinn_udp::EcnCodepoint::Ce => crate::EcnCodepoint::Ce,
592 }
593}
594
/// Channels to every connection tracked by the endpoint.
#[derive(Debug)]
struct ConnectionSet {
    // Per-connection channel carrying endpoint-to-connection events.
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    // Cloned into each new connection so it can send events back to the endpoint.
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    // Set by `Endpoint::close`. Connections inserted afterwards receive the same close event.
    close: Option<(VarInt, Bytes)>,
}
604
605impl ConnectionSet {
606 fn insert(
607 &mut self,
608 handle: ConnectionHandle,
609 conn: crate::Connection,
610 socket: Arc<dyn AsyncUdpSocket>,
611 runtime: Arc<dyn Runtime>,
612 ) -> Connecting {
613 let (send, recv) = mpsc::unbounded_channel();
614 if let Some((error_code, ref reason)) = self.close {
615 send.send(ConnectionEvent::Close {
616 error_code,
617 reason: reason.clone(),
618 })
619 .unwrap();
620 }
621 self.senders.insert(handle, send);
622 Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
623 }
624
625 fn is_empty(&self) -> bool {
626 self.senders.is_empty()
627 }
628}
629
/// Returns `x` as an IPv6 socket address.
///
/// IPv4 addresses become their IPv4-mapped IPv6 form (`::ffff:a.b.c.d`) with zero flow info
/// and scope id. IPv6 addresses are returned unchanged.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(already_v6) => return already_v6,
        SocketAddr::V4(v4) => v4,
    };
    SocketAddrV6::new(v4.ip().to_ipv6_mapped(), v4.port(), 0, 0)
}
636
pin_project! {
    /// Future returned by `Endpoint::accept`, yielding the next incoming connection attempt.
    pub struct Accept<'a> {
        // Endpoint whose incoming queue is polled.
        endpoint: &'a Endpoint,
        // Registration on `Shared::incoming`; re-armed after each notification.
        #[pin]
        notify: Notified<'a>,
    }
}
645
impl Future for Accept<'_> {
    type Output = Option<super::incoming::Incoming>;
    /// Yields the next queued incoming connection attempt.
    ///
    /// Outcomes, in the order they are checked:
    /// - Resolves to `None` if the driver has been lost, even if attempts are still queued.
    /// - Pops the front of the queue if one is waiting.
    /// - Resolves to `None` if the endpoint has been closed (queue already empty).
    /// - Otherwise waits on the shared `incoming` notification.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the state lock first: cloning `EndpointRef` below locks it again, and
            // `std::sync::Mutex` is not reentrant.
            drop(endpoint);
            let incoming = super::incoming::Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // The state lock is still held, and every `notify_waiters` caller in this
                // file also holds it. So no notification can slip in between the checks
                // above and this registration.
                Poll::Pending => return Poll::Pending,
                // A notification fired that this poll does not act on. Re-arm with a fresh
                // `Notified` and poll it again.
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}
675
/// Counted handle to the shared `EndpointInner`.
///
/// Besides the `Arc` count, clones and drops adjust `State::ref_count`, which the driver
/// checks to decide whether it may finish.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
678
679impl EndpointRef {
680 pub(crate) fn new(
681 socket: Arc<dyn AsyncUdpSocket>,
682 inner: crate::endpoint::Endpoint,
683 ipv6: bool,
684 runtime: Arc<dyn Runtime>,
685 ) -> Self {
686 let (sender, events) = mpsc::unbounded_channel();
687 let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
688 Self(Arc::new(EndpointInner {
689 shared: Shared {
690 incoming: Notify::new(),
691 idle: Notify::new(),
692 },
693 state: Mutex::new(State {
694 socket,
695 prev_socket: None,
696 inner,
697 ipv6,
698 events,
699 driver: None,
700 ref_count: 0,
701 driver_lost: false,
702 recv_state,
703 runtime,
704 stats: EndpointStats::default(),
705 }),
706 }))
707 }
708}
709
710impl Clone for EndpointRef {
711 fn clone(&self) -> Self {
712 self.0.state.lock().unwrap().ref_count += 1;
713 Self(self.0.clone())
714 }
715}
716
717impl Drop for EndpointRef {
718 fn drop(&mut self) {
719 let endpoint = &mut *self.0.state.lock().unwrap();
720 if let Some(x) = endpoint.ref_count.checked_sub(1) {
721 endpoint.ref_count = x;
722 if x == 0 {
723 if let Some(task) = endpoint.driver.take() {
726 task.wake();
727 }
728 }
729 }
730 }
731}
732
733impl std::ops::Deref for EndpointRef {
734 type Target = EndpointInner;
735 fn deref(&self) -> &Self::Target {
736 &self.0
737 }
738}
739
/// Receive-side state: queued incoming attempts, connection channels and socket buffers.
struct RecvState {
    // Incoming connection attempts waiting to be picked up by `Accept`.
    incoming: VecDeque<crate::Incoming>,
    connections: ConnectionSet,
    // Backing storage for one batch of receives, split into `BATCH_SIZE` equal slots.
    recv_buf: Box<[u8]>,
    // Bounds how long a single driver poll spends receiving.
    recv_limiter: WorkLimiter,
}
747
impl RecvState {
    /// Creates empty receive state.
    ///
    /// The buffer has `BATCH_SIZE` slots. Each slot holds `max_receive_segments` datagrams of
    /// the endpoint's configured maximum UDP payload size, capped at 64 KiB per datagram.
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &crate::endpoint::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Receives from `socket` and dispatches each datagram to the protocol `endpoint`.
    ///
    /// Receiving continues until the socket returns `Pending` (reported with
    /// `keep_going: false`) or the work limiter refuses more work (reported with
    /// `keep_going: true`).
    ///
    /// Each datagram leads to one of these actions:
    /// - a new connection attempt is queued, or refused if the endpoint is closed;
    /// - a connection event is forwarded to its connection;
    /// - a response is sent on this same `socket`.
    ///
    /// # Errors
    ///
    /// `ConnectionReset` errors are skipped, presumably because a reset has no defined
    /// meaning for this protocol and could be injected — confirm. Any other receive error is
    /// returned.
    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut crate::endpoint::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // `recv_buf`'s length is a multiple of BATCH_SIZE (see `new`), so `bufs` yields
            // exactly BATCH_SIZE chunks and this `expect` cannot fire.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // One receive may hold several datagrams back-to-back, each
                        // `meta.stride` bytes long (the last one possibly shorter).
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        // Closing: refuse immediately. This does not update
                                        // `EndpointStats::refused_handshakes`.
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    received_connection_packet = true;
                                    // A send error means the connection's receiver was already
                                    // dropped; ignore it.
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Retry straight away; note this path skips the work-limiter check below.
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
861
862impl fmt::Debug for RecvState {
863 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
864 f.debug_struct("RecvState")
865 .field("incoming", &self.incoming)
866 .field("connections", &self.connections)
867 .field("recv_limiter", &self.recv_limiter)
869 .finish_non_exhaustive()
870 }
871}
872
/// Outcome of one `RecvState::poll_socket` call.
#[derive(Default)]
struct PollProgress {
    // At least one received datagram was routed to an existing connection.
    received_connection_packet: bool,
    // The work limiter stopped receiving before the socket returned `Pending`, so the driver
    // should be polled again.
    keep_going: bool,
}