1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io,
6 io::IoSliceMut,
7 mem,
8 net::{SocketAddr, SocketAddrV6},
9 pin::Pin,
10 str,
11 sync::{Arc, Mutex},
12 task::{Context, Poll, Waker},
13};
14
15#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
16use super::runtime::default_runtime;
17use super::{
18 runtime::{AsyncUdpSocket, Runtime},
19 udp_transmit,
20};
21use crate::Instant;
22use crate::{
23 ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent, EndpointEvent,
24 ServerConfig,
25};
26use bytes::{Bytes, BytesMut};
27use pin_project_lite::pin_project;
28use quinn_udp::{BATCH_SIZE, RecvMeta};
29use rustc_hash::FxHashMap;
30#[cfg(all(not(wasm_browser), feature = "network-discovery"))]
31use socket2::{Domain, Protocol, Socket, Type};
32use tokio::sync::{Notify, futures::Notified, mpsc};
33use tracing::{Instrument, Span};
34
35use super::{
36 ConnectionEvent, IO_LOOP_BOUND, RECV_TIME_BOUND, connection::Connecting,
37 work_limiter::WorkLimiter,
38};
39use crate::{EndpointConfig, VarInt};
40
/// A QUIC endpoint.
///
/// Clones share the same underlying endpoint state: socket, protocol state machine and
/// connection set.
#[derive(Debug, Clone)]
pub struct Endpoint {
    // Shared handle to the endpoint state, also held by the spawned driver future.
    pub(crate) inner: EndpointRef,
    // Client configuration used by `connect`; `None` until `set_default_client_config`.
    pub(crate) default_client_config: Option<ClientConfig>,
    // Runtime used to wrap sockets, read the clock, and hand to new connections.
    runtime: Arc<dyn Runtime>,
}
53
54impl Endpoint {
55 #[cfg(all(not(wasm_browser), feature = "network-discovery"))]
80 pub fn client(addr: SocketAddr) -> io::Result<Self> {
81 let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
82 if addr.is_ipv6() {
83 if let Err(e) = socket.set_only_v6(false) {
84 tracing::debug!(%e, "unable to make socket dual-stack");
85 }
86 }
87 socket.bind(&addr.into())?;
88 let runtime =
89 default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
90 Self::new_with_abstract_socket(
91 EndpointConfig::default(),
92 None,
93 runtime.wrap_udp_socket(socket.into())?,
94 runtime,
95 )
96 }
97
98 pub fn stats(&self) -> EndpointStats {
100 self.inner.state.lock().unwrap().stats
101 }
102
103 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
111 let socket = std::net::UdpSocket::bind(addr)?;
112 let runtime =
113 default_runtime().ok_or_else(|| io::Error::other("no async runtime found"))?;
114 Self::new_with_abstract_socket(
115 EndpointConfig::default(),
116 Some(config),
117 runtime.wrap_udp_socket(socket)?,
118 runtime,
119 )
120 }
121
122 #[cfg(not(wasm_browser))]
124 pub fn new(
125 config: EndpointConfig,
126 server_config: Option<ServerConfig>,
127 socket: std::net::UdpSocket,
128 runtime: Arc<dyn Runtime>,
129 ) -> io::Result<Self> {
130 let socket = runtime.wrap_udp_socket(socket)?;
131 Self::new_with_abstract_socket(config, server_config, socket, runtime)
132 }
133
134 pub fn new_with_abstract_socket(
139 config: EndpointConfig,
140 server_config: Option<ServerConfig>,
141 socket: Arc<dyn AsyncUdpSocket>,
142 runtime: Arc<dyn Runtime>,
143 ) -> io::Result<Self> {
144 let addr = socket.local_addr()?;
145 let allow_mtud = !socket.may_fragment();
146 let rc = EndpointRef::new(
147 socket,
148 crate::endpoint::Endpoint::new(
149 Arc::new(config),
150 server_config.map(Arc::new),
151 allow_mtud,
152 None,
153 ),
154 addr.is_ipv6(),
155 runtime.clone(),
156 );
157 let driver = EndpointDriver(rc.clone());
158 runtime.spawn(Box::pin(
159 async {
160 if let Err(e) = driver.await {
161 tracing::error!("I/O error: {}", e);
162 }
163 }
164 .instrument(Span::current()),
165 ));
166 Ok(Self {
167 inner: rc,
168 default_client_config: None,
169 runtime,
170 })
171 }
172
173 pub fn accept(&self) -> Accept<'_> {
180 Accept {
181 endpoint: self,
182 notify: self.inner.shared.incoming.notified(),
183 }
184 }
185
186 pub fn set_default_client_config(&mut self, config: ClientConfig) {
188 self.default_client_config = Some(config);
189 }
190
191 pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
200 let config = match &self.default_client_config {
201 Some(config) => config.clone(),
202 None => return Err(ConnectError::NoDefaultClientConfig),
203 };
204
205 self.connect_with(config, addr, server_name)
206 }
207
208 pub fn connect_with(
214 &self,
215 config: ClientConfig,
216 addr: SocketAddr,
217 server_name: &str,
218 ) -> Result<Connecting, ConnectError> {
219 let mut endpoint = self.inner.state.lock().unwrap();
220 if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
221 return Err(ConnectError::EndpointStopping);
222 }
223 if addr.is_ipv6() && !endpoint.ipv6 {
224 return Err(ConnectError::InvalidRemoteAddress(addr));
225 }
226 let addr = if endpoint.ipv6 {
227 SocketAddr::V6(ensure_ipv6(addr))
228 } else {
229 addr
230 };
231
232 let (ch, conn) = endpoint
233 .inner
234 .connect(self.runtime.now(), config, addr, server_name)?;
235
236 let socket = endpoint.socket.clone();
237 endpoint.stats.outgoing_handshakes += 1;
238 Ok(endpoint
239 .recv_state
240 .connections
241 .insert(ch, conn, socket, self.runtime.clone()))
242 }
243
244 #[cfg(not(wasm_browser))]
248 pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
249 self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
250 }
251
252 pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
259 let addr = socket.local_addr()?;
260 let mut inner = self.inner.state.lock().unwrap();
261 inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
262 inner.ipv6 = addr.is_ipv6();
263
264 for sender in inner.recv_state.connections.senders.values() {
266 let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
268 }
269
270 Ok(())
271 }
272
273 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
277 self.inner
278 .state
279 .lock()
280 .unwrap()
281 .inner
282 .set_server_config(server_config.map(Arc::new))
283 }
284
285 pub fn local_addr(&self) -> io::Result<SocketAddr> {
287 self.inner.state.lock().unwrap().socket.local_addr()
288 }
289
290 pub fn open_connections(&self) -> usize {
292 self.inner.state.lock().unwrap().inner.open_connections()
293 }
294
295 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
301 let reason = Bytes::copy_from_slice(reason);
302 let mut endpoint = self.inner.state.lock().unwrap();
303 endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
304 for sender in endpoint.recv_state.connections.senders.values() {
305 let _ = sender.send(ConnectionEvent::Close {
307 error_code,
308 reason: reason.clone(),
309 });
310 }
311 self.inner.shared.incoming.notify_waiters();
312 }
313
314 pub async fn wait_idle(&self) {
325 loop {
326 {
327 let endpoint = &mut *self.inner.state.lock().unwrap();
328 if endpoint.recv_state.connections.is_empty() {
329 break;
330 }
331 self.inner.shared.idle.notified()
333 }
334 .await;
335 }
336 }
337}
338
/// Handshake statistics for an [`Endpoint`].
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Incoming connection attempts that were successfully accepted.
    pub accepted_handshakes: u64,
    /// Outgoing connections started with `connect`/`connect_with`.
    pub outgoing_handshakes: u64,
    /// Incoming connection attempts explicitly refused by the application.
    ///
    /// Attempts refused automatically because the endpoint was closing are not counted.
    pub refused_handshakes: u64,
    /// Incoming connection attempts explicitly ignored by the application.
    pub ignored_handshakes: u64,
}
352
/// Future that drives the endpoint's I/O.
///
/// Each poll receives datagrams and forwards events between the protocol endpoint and its
/// connections. The future completes once `State::ref_count` is zero and the connection set
/// is empty. Dropping it marks the driver as lost.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
366
impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    /// Runs one round of receive and event processing.
    ///
    /// Returns `Ready(Err(_))` on an I/O error from the current socket. Returns `Ready(Ok(()))`
    /// when no other handle remains and there are no connections.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        // Only the first poll's waker is stored. `EndpointRef::drop` uses it to wake the
        // driver when the last external handle goes away.
        // NOTE(review): the stored waker is never refreshed if the task's waker changes
        // between polls; confirm the runtimes in use keep it stable.
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        // Wake pending `accept` futures while still holding the state lock.
        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            // Release the lock before self-waking to reduce contention.
            drop(endpoint);
            // A work bound was hit, so more work may be ready; schedule another poll.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}
399
400impl Drop for EndpointDriver {
401 fn drop(&mut self) {
402 let mut endpoint = self.0.state.lock().unwrap();
403 endpoint.driver_lost = true;
404 self.0.shared.incoming.notify_waiters();
405 endpoint.recv_state.connections.senders.clear();
408 }
409}
410
/// Shared core of an endpoint: mutex-protected state plus notifications kept outside the
/// mutex.
#[derive(Debug)]
pub(crate) struct EndpointInner {
    /// All mutable endpoint state.
    pub(crate) state: Mutex<State>,
    /// Wake-up notifications for `accept` and `wait_idle`.
    pub(crate) shared: Shared,
}
416
417impl EndpointInner {
418 pub(crate) fn accept(
419 &self,
420 incoming: crate::Incoming,
421 server_config: Option<Arc<ServerConfig>>,
422 ) -> Result<Connecting, ConnectionError> {
423 let mut state = self.state.lock().unwrap();
424 let mut response_buffer = Vec::new();
425 let now = state.runtime.now();
426 match state
427 .inner
428 .accept(incoming, now, &mut response_buffer, server_config)
429 {
430 Ok((handle, conn)) => {
431 state.stats.accepted_handshakes += 1;
432 let socket = state.socket.clone();
433 let runtime = state.runtime.clone();
434 Ok(state
435 .recv_state
436 .connections
437 .insert(handle, conn, socket, runtime))
438 }
439 Err(error) => {
440 if let Some(transmit) = error.response {
441 respond(transmit, &response_buffer, &*state.socket);
442 }
443 Err(error.cause)
444 }
445 }
446 }
447
448 pub(crate) fn refuse(&self, incoming: crate::Incoming) {
449 let mut state = self.state.lock().unwrap();
450 state.stats.refused_handshakes += 1;
451 let mut response_buffer = Vec::new();
452 let transmit = state.inner.refuse(incoming, &mut response_buffer);
453 respond(transmit, &response_buffer, &*state.socket);
454 }
455
456 pub(crate) fn retry(
457 &self,
458 incoming: crate::Incoming,
459 ) -> Result<(), crate::endpoint::RetryError> {
460 let mut state = self.state.lock().unwrap();
461 let mut response_buffer = Vec::new();
462 let transmit = state.inner.retry(incoming, &mut response_buffer)?;
463 respond(transmit, &response_buffer, &*state.socket);
464 Ok(())
465 }
466
467 pub(crate) fn ignore(&self, incoming: crate::Incoming) {
468 let mut state = self.state.lock().unwrap();
469 state.stats.ignored_handshakes += 1;
470 state.inner.ignore(incoming);
471 }
472}
473
/// Mutable endpoint state, guarded by `EndpointInner::state`.
#[derive(Debug)]
pub(crate) struct State {
    /// Socket currently used for sending and receiving.
    socket: Arc<dyn AsyncUdpSocket>,
    /// Socket replaced by the last rebind.
    ///
    /// `drive_recv` keeps polling it until a connection packet arrives on `socket`, or
    /// until this socket returns an error.
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    /// The protocol-level endpoint state machine.
    inner: crate::endpoint::Endpoint,
    /// Receive buffers, queued incoming attempts and the connection set.
    recv_state: RecvState,
    /// Driver task's waker, captured on its first poll.
    ///
    /// `EndpointRef::drop` takes it and wakes it when `ref_count` reaches zero.
    driver: Option<Waker>,
    /// Whether the current socket is IPv6.
    ///
    /// When it is, outgoing IPv4 addresses are mapped to IPv6.
    ipv6: bool,
    /// Events sent by connections to the endpoint; drained by `handle_events`.
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    /// Number of `EndpointRef` handles, not counting the first.
    ///
    /// Each clone adds one and each drop removes one (saturating at zero), so zero means at
    /// most one handle remains.
    ref_count: usize,
    /// Set once the `EndpointDriver` has been dropped.
    driver_lost: bool,
    /// Runtime used for timestamps and passed to new connections.
    runtime: Arc<dyn Runtime>,
    /// Handshake counters exposed by `Endpoint::stats`.
    stats: EndpointStats,
}
491
/// Notifications kept next to (not inside) the state mutex.
#[derive(Debug)]
pub(crate) struct Shared {
    /// Signalled when:
    /// - incoming attempts are queued,
    /// - the endpoint is closed, or
    /// - the driver is dropped.
    incoming: Notify,
    /// Signalled when the connection set becomes empty.
    idle: Notify,
}
497
impl State {
    /// Receives from the previous socket (if any) and then from the current socket, within
    /// one work-limiter cycle.
    ///
    /// Returns `Ok(true)` when the current socket stopped because the receive budget ran out,
    /// meaning the driver should be polled again soon.
    ///
    /// # Errors
    ///
    /// Returns I/O errors from the current socket. An error on the previous socket only
    /// causes that socket to be discarded.
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // Progress on the old socket is deliberately ignored; only a failure matters,
            // and a failure retires the old socket.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        // Close the limiter cycle before propagating any error from the current socket.
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        if poll_res.received_connection_packet {
            // Connection traffic has arrived on the current socket, so the old one is no
            // longer needed.
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    /// Processes up to `IO_LOOP_BOUND` events sent by connections.
    ///
    /// A "drained" event removes the connection from the set, and wakes `wait_idle` callers
    /// if the set becomes empty. Any follow-up event from the protocol endpoint is forwarded
    /// to the connection.
    ///
    /// Returns `false` once no more events are ready. Returns `true` if the bound was hit and
    /// more events may be pending.
    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // A send error (receiver already dropped) is ignored.
            // NOTE(review): the `unwrap` assumes the protocol endpoint never returns an event
            // for a handle whose sender was just removed above; confirm that invariant.
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}
556
557impl Drop for State {
558 fn drop(&mut self) {
559 for incoming in self.recv_state.incoming.drain(..) {
560 self.inner.ignore(incoming);
561 }
562 }
563}
564
565fn respond(transmit: crate::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
566 _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
587}
588
589#[inline]
590fn proto_ecn(ecn: quinn_udp::EcnCodepoint) -> crate::EcnCodepoint {
591 match ecn {
592 quinn_udp::EcnCodepoint::Ect0 => crate::EcnCodepoint::Ect0,
593 quinn_udp::EcnCodepoint::Ect1 => crate::EcnCodepoint::Ect1,
594 quinn_udp::EcnCodepoint::Ce => crate::EcnCodepoint::Ce,
595 }
596}
597
/// The connections registered with an endpoint, plus the endpoint-wide close state.
#[derive(Debug)]
struct ConnectionSet {
    /// Per-connection channels used to deliver `ConnectionEvent`s.
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    /// Cloned into each new connection so it can send events back to the endpoint.
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    /// Error code and reason recorded by `Endpoint::close`.
    ///
    /// Connections inserted after it is set are closed immediately.
    close: Option<(VarInt, Bytes)>,
}
607
608impl ConnectionSet {
609 fn insert(
610 &mut self,
611 handle: ConnectionHandle,
612 conn: crate::Connection,
613 socket: Arc<dyn AsyncUdpSocket>,
614 runtime: Arc<dyn Runtime>,
615 ) -> Connecting {
616 let (send, recv) = mpsc::unbounded_channel();
617 if let Some((error_code, ref reason)) = self.close {
618 send.send(ConnectionEvent::Close {
619 error_code,
620 reason: reason.clone(),
621 })
622 .unwrap();
623 }
624 self.senders.insert(handle, send);
625 Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
626 }
627
628 fn is_empty(&self) -> bool {
629 self.senders.is_empty()
630 }
631}
632
/// Returns `x` as an IPv6 socket address.
///
/// IPv6 inputs are returned unchanged. IPv4 inputs are converted to the IPv4-mapped
/// `::ffff:a.b.c.d` form with the same port, and with flow info and scope id set to zero.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(already) => return already,
        SocketAddr::V4(v4) => v4,
    };
    let mapped = v4.ip().to_ipv6_mapped();
    SocketAddrV6::new(mapped, v4.port(), 0, 0)
}
639
pin_project! {
    /// Future returned by `Endpoint::accept`.
    pub struct Accept<'a> {
        // Endpoint whose incoming-connection queue is polled.
        endpoint: &'a Endpoint,
        // Registration on `Shared::incoming`; replaced with a fresh one after each wake-up.
        #[pin]
        notify: Notified<'a>,
    }
}
648
impl Future for Accept<'_> {
    type Output = Option<super::incoming::Incoming>;

    /// Yields the next queued incoming attempt.
    ///
    /// Resolves to `None` if the driver is gone, or if the endpoint is closed and the queue
    /// is empty.
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the state lock first: cloning the `EndpointRef` below locks it again.
            drop(endpoint);
            let incoming = super::incoming::Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        // Queued attempts are still handed out above even after close; only then stop.
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // Every `notify_waiters` on `incoming` in this module runs while the state
                // lock is held. Since we still hold it, no wake-up can slip in between the
                // checks above and this registration.
                Poll::Pending => return Poll::Pending,
                // Woken with nothing for us (e.g. another `Accept` took the item): register a
                // fresh `Notified` and poll it so the waker is installed.
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}
678
/// Handle to an `EndpointInner`.
///
/// Besides the `Arc` count, clones and drops maintain `State::ref_count`, which the driver
/// uses to decide when it may finish.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
681
682impl EndpointRef {
683 pub(crate) fn new(
684 socket: Arc<dyn AsyncUdpSocket>,
685 inner: crate::endpoint::Endpoint,
686 ipv6: bool,
687 runtime: Arc<dyn Runtime>,
688 ) -> Self {
689 let (sender, events) = mpsc::unbounded_channel();
690 let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
691 Self(Arc::new(EndpointInner {
692 shared: Shared {
693 incoming: Notify::new(),
694 idle: Notify::new(),
695 },
696 state: Mutex::new(State {
697 socket,
698 prev_socket: None,
699 inner,
700 ipv6,
701 events,
702 driver: None,
703 ref_count: 0,
704 driver_lost: false,
705 recv_state,
706 runtime,
707 stats: EndpointStats::default(),
708 }),
709 }))
710 }
711}
712
713impl Clone for EndpointRef {
714 fn clone(&self) -> Self {
715 self.0.state.lock().unwrap().ref_count += 1;
716 Self(self.0.clone())
717 }
718}
719
720impl Drop for EndpointRef {
721 fn drop(&mut self) {
722 let endpoint = &mut *self.0.state.lock().unwrap();
723 if let Some(x) = endpoint.ref_count.checked_sub(1) {
724 endpoint.ref_count = x;
725 if x == 0 {
726 if let Some(task) = endpoint.driver.take() {
729 task.wake();
730 }
731 }
732 }
733 }
734}
735
736impl std::ops::Deref for EndpointRef {
737 type Target = EndpointInner;
738 fn deref(&self) -> &Self::Target {
739 &self.0
740 }
741}
742
/// Receive-side state of the endpoint.
struct RecvState {
    /// Incoming connection attempts waiting to be handed out by `Endpoint::accept`.
    incoming: VecDeque<crate::Incoming>,
    /// Registered connections and the endpoint-wide close state.
    connections: ConnectionSet,
    /// Backing storage for batched receives; `poll_socket` splits it into `BATCH_SIZE`
    /// equal chunks.
    recv_buf: Box<[u8]>,
    /// Limits receive work per driver poll (constructed with `RECV_TIME_BOUND`).
    recv_limiter: WorkLimiter,
}
750
impl RecvState {
    /// Creates receive state with an empty connection set.
    ///
    /// The receive buffer holds `BATCH_SIZE` slots. Each slot has room for
    /// `max_receive_segments` datagrams of the endpoint's maximum UDP payload size, capped
    /// at 64 KiB per datagram.
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &crate::endpoint::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Receives batches from `socket` and feeds each datagram to the protocol `endpoint`.
    ///
    /// Received datagrams are routed as follows:
    /// - New connection attempts are queued, or refused if the endpoint is closing.
    /// - Connection events are forwarded to their connection.
    /// - Stateless responses are sent back (best-effort).
    ///
    /// The loop stops when the socket has nothing more (`keep_going: false`) or when the
    /// work limiter's budget runs out (`keep_going: true`).
    ///
    /// # Errors
    ///
    /// Returns any receive error except `ConnectionReset`, which is skipped.
    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut crate::endpoint::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // `recv_buf` is a multiple of BATCH_SIZE long, so the chunking yields exactly
            // BATCH_SIZE slices and `from_fn` consumes exactly that many.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // A single receive may carry several datagrams of `meta.stride`
                        // bytes each (presumably coalesced/GRO receives); split them apart.
                        // NOTE(review): a zero `stride` with non-empty data would never
                        // terminate; relies on quinn_udp always reporting stride > 0.
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        // Closing: refuse immediately. This is not counted
                                        // in `refused_handshakes`.
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    received_connection_packet = true;
                                    // A send error (receiver already dropped) is ignored.
                                    // The `unwrap` assumes the handle is still registered.
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // ECONNRESET is skipped and receiving continues.
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            // Budget exhausted: yield and ask the driver to poll again.
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
864
865impl fmt::Debug for RecvState {
866 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
867 f.debug_struct("RecvState")
868 .field("incoming", &self.incoming)
869 .field("connections", &self.connections)
870 .field("recv_limiter", &self.recv_limiter)
872 .finish_non_exhaustive()
873 }
874}
875
/// Outcome of a successful `RecvState::poll_socket` call.
#[derive(Default)]
struct PollProgress {
    /// Whether any received datagram was routed to an existing connection.
    received_connection_packet: bool,
    /// Whether receiving stopped because the work limiter ran out, so more data may be
    /// pending.
    keep_going: bool,
}