1use std::{
2 collections::VecDeque,
3 fmt,
4 future::Future,
5 io,
6 io::IoSliceMut,
7 mem,
8 net::{SocketAddr, SocketAddrV6},
9 pin::Pin,
10 str,
11 sync::{Arc, Mutex},
12 task::{Context, Poll, Waker},
13};
14
15#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))]
16use crate::runtime::default_runtime;
17use crate::{
18 runtime::{AsyncUdpSocket, Runtime},
19 udp_transmit, Instant,
20};
21use bytes::{Bytes, BytesMut};
22use pin_project_lite::pin_project;
23use proto::{
24 self as proto, ClientConfig, ConnectError, ConnectionError, ConnectionHandle, DatagramEvent,
25 EndpointEvent, ServerConfig,
26};
27use rustc_hash::FxHashMap;
28#[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring"),))]
29use socket2::{Domain, Protocol, Socket, Type};
30use tokio::sync::{futures::Notified, mpsc, Notify};
31use tracing::{Instrument, Span};
32use udp::{RecvMeta, BATCH_SIZE};
33
34use crate::{
35 connection::Connecting, incoming::Incoming, work_limiter::WorkLimiter, ConnectionEvent,
36 EndpointConfig, VarInt, IO_LOOP_BOUND, RECV_TIME_BOUND,
37};
38
/// A QUIC endpoint.
///
/// An endpoint owns one UDP socket (plus, briefly after a rebind, the previous one) and can both
/// initiate connections (`connect`/`connect_with`) and accept them (`accept`).
///
/// Cloning an `Endpoint` yields another handle to the same underlying endpoint state. The
/// default client configuration is stored per handle, so `set_default_client_config` only
/// affects the handle it is called on (and clones made from it afterwards).
#[derive(Debug, Clone)]
pub struct Endpoint {
    // Shared, reference-counted endpoint state (also held by the driver task).
    pub(crate) inner: EndpointRef,
    // Config used by `connect`; `None` makes `connect` fail with `NoDefaultClientConfig`.
    pub(crate) default_client_config: Option<ClientConfig>,
    // Runtime used for timestamps and for wrapping sockets passed to `rebind`.
    runtime: Arc<dyn Runtime>,
}
51
52impl Endpoint {
53 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn client(addr: SocketAddr) -> io::Result<Self> {
72 let socket = Socket::new(Domain::for_address(addr), Type::DGRAM, Some(Protocol::UDP))?;
73 if addr.is_ipv6() {
74 if let Err(e) = socket.set_only_v6(false) {
75 tracing::debug!(%e, "unable to make socket dual-stack");
76 }
77 }
78 socket.bind(&addr.into())?;
79 let runtime = default_runtime()
80 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
81 Self::new_with_abstract_socket(
82 EndpointConfig::default(),
83 None,
84 runtime.wrap_udp_socket(socket.into())?,
85 runtime,
86 )
87 }
88
89 pub fn stats(&self) -> EndpointStats {
91 self.inner.state.lock().unwrap().stats
92 }
93
94 #[cfg(all(not(wasm_browser), any(feature = "aws-lc-rs", feature = "ring")))] pub fn server(config: ServerConfig, addr: SocketAddr) -> io::Result<Self> {
102 let socket = std::net::UdpSocket::bind(addr)?;
103 let runtime = default_runtime()
104 .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "no async runtime found"))?;
105 Self::new_with_abstract_socket(
106 EndpointConfig::default(),
107 Some(config),
108 runtime.wrap_udp_socket(socket)?,
109 runtime,
110 )
111 }
112
113 #[cfg(not(wasm_browser))]
115 pub fn new(
116 config: EndpointConfig,
117 server_config: Option<ServerConfig>,
118 socket: std::net::UdpSocket,
119 runtime: Arc<dyn Runtime>,
120 ) -> io::Result<Self> {
121 let socket = runtime.wrap_udp_socket(socket)?;
122 Self::new_with_abstract_socket(config, server_config, socket, runtime)
123 }
124
125 pub fn new_with_abstract_socket(
130 config: EndpointConfig,
131 server_config: Option<ServerConfig>,
132 socket: Arc<dyn AsyncUdpSocket>,
133 runtime: Arc<dyn Runtime>,
134 ) -> io::Result<Self> {
135 let addr = socket.local_addr()?;
136 let allow_mtud = !socket.may_fragment();
137 let rc = EndpointRef::new(
138 socket,
139 proto::Endpoint::new(
140 Arc::new(config),
141 server_config.map(Arc::new),
142 allow_mtud,
143 None,
144 ),
145 addr.is_ipv6(),
146 runtime.clone(),
147 );
148 let driver = EndpointDriver(rc.clone());
149 runtime.spawn(Box::pin(
150 async {
151 if let Err(e) = driver.await {
152 tracing::error!("I/O error: {}", e);
153 }
154 }
155 .instrument(Span::current()),
156 ));
157 Ok(Self {
158 inner: rc,
159 default_client_config: None,
160 runtime,
161 })
162 }
163
164 pub fn accept(&self) -> Accept<'_> {
171 Accept {
172 endpoint: self,
173 notify: self.inner.shared.incoming.notified(),
174 }
175 }
176
177 pub fn set_default_client_config(&mut self, config: ClientConfig) {
179 self.default_client_config = Some(config);
180 }
181
182 pub fn connect(&self, addr: SocketAddr, server_name: &str) -> Result<Connecting, ConnectError> {
191 let config = match &self.default_client_config {
192 Some(config) => config.clone(),
193 None => return Err(ConnectError::NoDefaultClientConfig),
194 };
195
196 self.connect_with(config, addr, server_name)
197 }
198
199 pub fn connect_with(
205 &self,
206 config: ClientConfig,
207 addr: SocketAddr,
208 server_name: &str,
209 ) -> Result<Connecting, ConnectError> {
210 let mut endpoint = self.inner.state.lock().unwrap();
211 if endpoint.driver_lost || endpoint.recv_state.connections.close.is_some() {
212 return Err(ConnectError::EndpointStopping);
213 }
214 if addr.is_ipv6() && !endpoint.ipv6 {
215 return Err(ConnectError::InvalidRemoteAddress(addr));
216 }
217 let addr = if endpoint.ipv6 {
218 SocketAddr::V6(ensure_ipv6(addr))
219 } else {
220 addr
221 };
222
223 let (ch, conn) = endpoint
224 .inner
225 .connect(self.runtime.now(), config, addr, server_name)?;
226
227 let socket = endpoint.socket.clone();
228 endpoint.stats.outgoing_handshakes += 1;
229 Ok(endpoint
230 .recv_state
231 .connections
232 .insert(ch, conn, socket, self.runtime.clone()))
233 }
234
235 #[cfg(not(wasm_browser))]
239 pub fn rebind(&self, socket: std::net::UdpSocket) -> io::Result<()> {
240 self.rebind_abstract(self.runtime.wrap_udp_socket(socket)?)
241 }
242
243 pub fn rebind_abstract(&self, socket: Arc<dyn AsyncUdpSocket>) -> io::Result<()> {
250 let addr = socket.local_addr()?;
251 let mut inner = self.inner.state.lock().unwrap();
252 inner.prev_socket = Some(mem::replace(&mut inner.socket, socket));
253 inner.ipv6 = addr.is_ipv6();
254
255 for sender in inner.recv_state.connections.senders.values() {
257 let _ = sender.send(ConnectionEvent::Rebind(inner.socket.clone()));
259 }
260
261 Ok(())
262 }
263
264 pub fn set_server_config(&self, server_config: Option<ServerConfig>) {
268 self.inner
269 .state
270 .lock()
271 .unwrap()
272 .inner
273 .set_server_config(server_config.map(Arc::new))
274 }
275
276 pub fn local_addr(&self) -> io::Result<SocketAddr> {
278 self.inner.state.lock().unwrap().socket.local_addr()
279 }
280
281 pub fn open_connections(&self) -> usize {
283 self.inner.state.lock().unwrap().inner.open_connections()
284 }
285
286 pub fn close(&self, error_code: VarInt, reason: &[u8]) {
292 let reason = Bytes::copy_from_slice(reason);
293 let mut endpoint = self.inner.state.lock().unwrap();
294 endpoint.recv_state.connections.close = Some((error_code, reason.clone()));
295 for sender in endpoint.recv_state.connections.senders.values() {
296 let _ = sender.send(ConnectionEvent::Close {
298 error_code,
299 reason: reason.clone(),
300 });
301 }
302 self.inner.shared.incoming.notify_waiters();
303 }
304
305 pub async fn wait_idle(&self) {
316 loop {
317 {
318 let endpoint = &mut *self.inner.state.lock().unwrap();
319 if endpoint.recv_state.connections.is_empty() {
320 break;
321 }
322 self.inner.shared.idle.notified()
324 }
325 .await;
326 }
327 }
328}
329
/// Statistics on [`Endpoint`] handshake activity.
#[non_exhaustive]
#[derive(Debug, Default, Copy, Clone)]
pub struct EndpointStats {
    /// Cumulative number of incoming handshakes accepted (successful `EndpointInner::accept` calls).
    pub accepted_handshakes: u64,
    /// Cumulative number of outgoing handshakes started (successful `Endpoint::connect_with` calls).
    pub outgoing_handshakes: u64,
    /// Cumulative number of incoming handshakes refused through `EndpointInner::refuse`.
    ///
    /// Attempts refused automatically while the endpoint is closing are not counted here.
    pub refused_handshakes: u64,
    /// Cumulative number of incoming handshakes ignored through `EndpointInner::ignore`.
    pub ignored_handshakes: u64,
}
343
/// A future that drives I/O on an endpoint.
///
/// It receives datagrams, dispatches endpoint events coming from connections, and wakes
/// `accept` waiters. `Endpoint::new_with_abstract_socket` spawns one on the endpoint's runtime.
/// When it is dropped, the endpoint is marked as having lost its driver.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
#[derive(Debug)]
pub(crate) struct EndpointDriver(pub(crate) EndpointRef);
357
impl Future for EndpointDriver {
    type Output = Result<(), io::Error>;

    /// Runs one round of endpoint work: receive, then handle connection events.
    ///
    /// Completes with `Ok(())` once `ref_count` has dropped to zero and every connection has been
    /// drained. Completes with `Err` when receiving on the current socket fails with an error
    /// other than `ConnectionReset`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut endpoint = self.0.state.lock().unwrap();
        // Remember a waker so `EndpointRef::drop` can wake us when the last counted reference
        // goes away. It is only stored when absent (it is `take()`n on use).
        if endpoint.driver.is_none() {
            endpoint.driver = Some(cx.waker().clone());
        }

        let now = endpoint.runtime.now();
        let mut keep_going = false;
        keep_going |= endpoint.drive_recv(cx, now)?;
        keep_going |= endpoint.handle_events(cx, &self.0.shared);

        // Notify while still holding the state lock; `Accept::poll` relies on this to avoid
        // missing a newly queued `Incoming`.
        if !endpoint.recv_state.incoming.is_empty() {
            self.0.shared.incoming.notify_waiters();
        }

        if endpoint.ref_count == 0 && endpoint.recv_state.connections.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            drop(endpoint);
            // Work was cut short (by the receive work limiter or by IO_LOOP_BOUND), so there may
            // be more ready immediately. Ask to be polled again rather than monopolizing this
            // poll.
            if keep_going {
                cx.waker().wake_by_ref();
            }
            Poll::Pending
        }
    }
}
390
391impl Drop for EndpointDriver {
392 fn drop(&mut self) {
393 let mut endpoint = self.0.state.lock().unwrap();
394 endpoint.driver_lost = true;
395 self.0.shared.incoming.notify_waiters();
396 endpoint.recv_state.connections.senders.clear();
399 }
400}
401
/// Endpoint state shared by every `EndpointRef` (user handles, incoming handles, and the driver).
#[derive(Debug)]
pub(crate) struct EndpointInner {
    // All mutable endpoint state; every access goes through this lock.
    pub(crate) state: Mutex<State>,
    // Wake-up notifiers, kept outside the mutex so waiters can be created and notified while
    // `state` is locked.
    pub(crate) shared: Shared,
}
407
408impl EndpointInner {
409 pub(crate) fn accept(
410 &self,
411 incoming: proto::Incoming,
412 server_config: Option<Arc<ServerConfig>>,
413 ) -> Result<Connecting, ConnectionError> {
414 let mut state = self.state.lock().unwrap();
415 let mut response_buffer = Vec::new();
416 let now = state.runtime.now();
417 match state
418 .inner
419 .accept(incoming, now, &mut response_buffer, server_config)
420 {
421 Ok((handle, conn)) => {
422 state.stats.accepted_handshakes += 1;
423 let socket = state.socket.clone();
424 let runtime = state.runtime.clone();
425 Ok(state
426 .recv_state
427 .connections
428 .insert(handle, conn, socket, runtime))
429 }
430 Err(error) => {
431 if let Some(transmit) = error.response {
432 respond(transmit, &response_buffer, &*state.socket);
433 }
434 Err(error.cause)
435 }
436 }
437 }
438
439 pub(crate) fn refuse(&self, incoming: proto::Incoming) {
440 let mut state = self.state.lock().unwrap();
441 state.stats.refused_handshakes += 1;
442 let mut response_buffer = Vec::new();
443 let transmit = state.inner.refuse(incoming, &mut response_buffer);
444 respond(transmit, &response_buffer, &*state.socket);
445 }
446
447 pub(crate) fn retry(&self, incoming: proto::Incoming) -> Result<(), proto::RetryError> {
448 let mut state = self.state.lock().unwrap();
449 let mut response_buffer = Vec::new();
450 let transmit = state.inner.retry(incoming, &mut response_buffer)?;
451 respond(transmit, &response_buffer, &*state.socket);
452 Ok(())
453 }
454
455 pub(crate) fn ignore(&self, incoming: proto::Incoming) {
456 let mut state = self.state.lock().unwrap();
457 state.stats.ignored_handshakes += 1;
458 state.inner.ignore(incoming);
459 }
460}
461
/// Mutable endpoint state, always accessed under `EndpointInner::state`'s lock.
#[derive(Debug)]
pub(crate) struct State {
    // Current socket used for sending and receiving.
    socket: Arc<dyn AsyncUdpSocket>,
    // Socket replaced by the last rebind. It is still polled for incoming datagrams until traffic
    // for a connection arrives on `socket` or it errors (see `drive_recv`).
    prev_socket: Option<Arc<dyn AsyncUdpSocket>>,
    // Protocol-level endpoint.
    inner: proto::Endpoint,
    recv_state: RecvState,
    // Waker of the driver task; woken when `ref_count` hits zero.
    driver: Option<Waker>,
    // Whether the current socket is IPv6; controls address mapping in `connect_with`.
    ipv6: bool,
    // Endpoint events from connections; the sending half is `ConnectionSet::sender`.
    events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
    // Maintained by `EndpointRef`'s `Clone`/`Drop`. The driver finishes once this is zero and
    // no connections remain.
    ref_count: usize,
    // Set when the `EndpointDriver` is dropped.
    driver_lost: bool,
    runtime: Arc<dyn Runtime>,
    stats: EndpointStats,
}
479
/// Notifiers shared between the endpoint handles and the driver.
#[derive(Debug)]
pub(crate) struct Shared {
    // Notified when incoming attempts are queued, on `close`, and when the driver is dropped.
    incoming: Notify,
    // Notified when the last connection has been drained (used by `wait_idle`).
    idle: Notify,
}
485
impl State {
    /// Receives from the previous socket (if any) and the current socket within a single
    /// receive work-limiter cycle.
    ///
    /// Returns `Ok(true)` if the current socket still had data when the limiter stopped it, and
    /// `Ok(false)` if it was drained. Errors from the current socket are returned. Errors from the
    /// previous socket only cause that socket to be discarded.
    fn drive_recv(&mut self, cx: &mut Context, now: Instant) -> Result<bool, io::Error> {
        let get_time = || self.runtime.now();
        self.recv_state.recv_limiter.start_cycle(get_time);
        if let Some(socket) = &self.prev_socket {
            // The old socket's progress is irrelevant; an error just means we stop polling it.
            let poll_res =
                self.recv_state
                    .poll_socket(cx, &mut self.inner, &**socket, &*self.runtime, now);
            if poll_res.is_err() {
                self.prev_socket = None;
            }
        };
        let poll_res =
            self.recv_state
                .poll_socket(cx, &mut self.inner, &*self.socket, &*self.runtime, now);
        self.recv_state.recv_limiter.finish_cycle(get_time);
        let poll_res = poll_res?;
        // Connection traffic has arrived on the new socket, so the old one is no longer needed.
        if poll_res.received_connection_packet {
            self.prev_socket = None;
        }
        Ok(poll_res.keep_going)
    }

    /// Handles up to `IO_LOOP_BOUND` endpoint events sent by connections.
    ///
    /// Returns `false` once the event channel has no more ready events, and `true` if the bound
    /// was hit (more events may be waiting). Drained connections are removed from the
    /// connection set, and `shared.idle` is notified when the set becomes empty.
    fn handle_events(&mut self, cx: &mut Context, shared: &Shared) -> bool {
        for _ in 0..IO_LOOP_BOUND {
            let (ch, event) = match self.events.poll_recv(cx) {
                Poll::Ready(Some(x)) => x,
                // The sending half lives in `recv_state.connections.sender`, owned by this
                // `State`, so the channel cannot close while we are polling it.
                Poll::Ready(None) => unreachable!("EndpointInner owns one sender"),
                Poll::Pending => {
                    return false;
                }
            };

            if event.is_drained() {
                self.recv_state.connections.senders.remove(&ch);
                if self.recv_state.connections.is_empty() {
                    shared.idle.notify_waiters();
                }
            }
            let Some(event) = self.inner.handle_event(ch, event) else {
                continue;
            };
            // The `unwrap` assumes the protocol endpoint only answers for connections still
            // present in `senders`. NOTE(review): confirm `handle_event` returns `None` for
            // drained handles. A send error (the receiver was already dropped) is ignored.
            let _ = self
                .recv_state
                .connections
                .senders
                .get_mut(&ch)
                .unwrap()
                .send(ConnectionEvent::Proto(event));
        }

        true
    }
}
544
545impl Drop for State {
546 fn drop(&mut self) {
547 for incoming in self.recv_state.incoming.drain(..) {
548 self.inner.ignore(incoming);
549 }
550 }
551}
552
553fn respond(transmit: proto::Transmit, response_buffer: &[u8], socket: &dyn AsyncUdpSocket) {
554 _ = socket.try_send(&udp_transmit(&transmit, &response_buffer[..transmit.size]));
575}
576
577#[inline]
578fn proto_ecn(ecn: udp::EcnCodepoint) -> proto::EcnCodepoint {
579 match ecn {
580 udp::EcnCodepoint::Ect0 => proto::EcnCodepoint::Ect0,
581 udp::EcnCodepoint::Ect1 => proto::EcnCodepoint::Ect1,
582 udp::EcnCodepoint::Ce => proto::EcnCodepoint::Ce,
583 }
584}
585
/// The set of live connections attached to an endpoint.
#[derive(Debug)]
struct ConnectionSet {
    // Per-connection channel for events flowing from the endpoint to each connection. Each
    // connection's only sender lives here.
    senders: FxHashMap<ConnectionHandle, mpsc::UnboundedSender<ConnectionEvent>>,
    // Cloned into every new connection so it can send endpoint events back to the driver.
    sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
    // Set by `Endpoint::close`; newly inserted connections are told to close with these values.
    close: Option<(VarInt, Bytes)>,
}
595
596impl ConnectionSet {
597 fn insert(
598 &mut self,
599 handle: ConnectionHandle,
600 conn: proto::Connection,
601 socket: Arc<dyn AsyncUdpSocket>,
602 runtime: Arc<dyn Runtime>,
603 ) -> Connecting {
604 let (send, recv) = mpsc::unbounded_channel();
605 if let Some((error_code, ref reason)) = self.close {
606 send.send(ConnectionEvent::Close {
607 error_code,
608 reason: reason.clone(),
609 })
610 .unwrap();
611 }
612 self.senders.insert(handle, send);
613 Connecting::new(handle, conn, self.sender.clone(), recv, socket, runtime)
614 }
615
616 fn is_empty(&self) -> bool {
617 self.senders.is_empty()
618 }
619}
620
/// Converts an address to IPv6 form.
///
/// IPv6 addresses are returned unchanged. IPv4 addresses become IPv4-mapped IPv6 addresses with
/// the same port and zero flowinfo/scope id.
fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
    let v4 = match x {
        SocketAddr::V6(already_v6) => return already_v6,
        SocketAddr::V4(v4) => v4,
    };
    let mapped_ip = v4.ip().to_ipv6_mapped();
    SocketAddrV6::new(mapped_ip, v4.port(), 0, 0)
}
627
pin_project! {
    /// Future produced by [`Endpoint::accept`].
    ///
    /// Resolves to the next incoming connection attempt, or to `None` once the endpoint is
    /// closed with nothing queued or its driver has been lost.
    pub struct Accept<'a> {
        // Endpoint handle being accepted on.
        endpoint: &'a Endpoint,
        // Created in `Endpoint::accept`, before the first poll, so notifications sent after
        // that point are not missed.
        #[pin]
        notify: Notified<'a>,
    }
}
636
impl Future for Accept<'_> {
    type Output = Option<Incoming>;
    fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();
        let mut endpoint = this.endpoint.inner.state.lock().unwrap();
        // Without a driver, no new attempts will ever arrive.
        if endpoint.driver_lost {
            return Poll::Ready(None);
        }
        if let Some(incoming) = endpoint.recv_state.incoming.pop_front() {
            // Release the state lock first: cloning the `EndpointRef` below locks the same
            // (non-reentrant) mutex to bump `ref_count`.
            drop(endpoint);
            let incoming = Incoming::new(incoming, this.endpoint.inner.clone());
            return Poll::Ready(Some(incoming));
        }
        // Attempts queued before `close` are still handed out above; only an empty queue
        // reports closure.
        if endpoint.recv_state.connections.close.is_some() {
            return Poll::Ready(None);
        }
        loop {
            match this.notify.as_mut().poll(ctx) {
                // We still hold the state lock, and the driver only notifies while holding it.
                // So no new `Incoming` can have been queued between the checks above and this
                // registration.
                Poll::Pending => return Poll::Pending,
                // A notification arrived but left nothing for us (e.g. another accept took the
                // attempt). Arm a fresh `Notified` and poll it so the waker is registered.
                Poll::Ready(()) => this
                    .notify
                    .set(this.endpoint.inner.shared.incoming.notified()),
            }
        }
    }
}
666
/// Reference-counted handle to the shared endpoint state.
///
/// Besides the `Arc`, clones are tracked in `State::ref_count`. The count starts at zero in
/// `new`, is incremented by `clone`, and is decremented (saturating at zero) on drop, so it reaches
/// zero when all but one handle are gone. NOTE(review): in `new_with_abstract_socket` the
/// remaining uncounted handle is effectively the driver's; confirm this accounting for other
/// construction sites.
#[derive(Debug)]
pub(crate) struct EndpointRef(Arc<EndpointInner>);
669
670impl EndpointRef {
671 pub(crate) fn new(
672 socket: Arc<dyn AsyncUdpSocket>,
673 inner: proto::Endpoint,
674 ipv6: bool,
675 runtime: Arc<dyn Runtime>,
676 ) -> Self {
677 let (sender, events) = mpsc::unbounded_channel();
678 let recv_state = RecvState::new(sender, socket.max_receive_segments(), &inner);
679 Self(Arc::new(EndpointInner {
680 shared: Shared {
681 incoming: Notify::new(),
682 idle: Notify::new(),
683 },
684 state: Mutex::new(State {
685 socket,
686 prev_socket: None,
687 inner,
688 ipv6,
689 events,
690 driver: None,
691 ref_count: 0,
692 driver_lost: false,
693 recv_state,
694 runtime,
695 stats: EndpointStats::default(),
696 }),
697 }))
698 }
699}
700
701impl Clone for EndpointRef {
702 fn clone(&self) -> Self {
703 self.0.state.lock().unwrap().ref_count += 1;
704 Self(self.0.clone())
705 }
706}
707
708impl Drop for EndpointRef {
709 fn drop(&mut self) {
710 let endpoint = &mut *self.0.state.lock().unwrap();
711 if let Some(x) = endpoint.ref_count.checked_sub(1) {
712 endpoint.ref_count = x;
713 if x == 0 {
714 if let Some(task) = endpoint.driver.take() {
717 task.wake();
718 }
719 }
720 }
721 }
722}
723
724impl std::ops::Deref for EndpointRef {
725 type Target = EndpointInner;
726 fn deref(&self) -> &Self::Target {
727 &self.0
728 }
729}
730
/// Receive-side endpoint state: queued incoming attempts, live connections, and receive buffers.
struct RecvState {
    // Incoming connection attempts waiting to be handed out by `Accept`.
    incoming: VecDeque<proto::Incoming>,
    connections: ConnectionSet,
    // Backing storage for batched receives; `poll_socket` splits it into `BATCH_SIZE` slots.
    recv_buf: Box<[u8]>,
    // Limits receive work per driver poll. NOTE(review): presumably time-based, since it is
    // built from `RECV_TIME_BOUND`; confirm in `work_limiter`.
    recv_limiter: WorkLimiter,
}
738
impl RecvState {
    /// Creates receive state with a zero-filled receive buffer and an empty connection set.
    ///
    /// The buffer has `BATCH_SIZE` slots. Each slot fits `max_receive_segments` datagrams of the
    /// endpoint's maximum UDP payload size, with that size capped at 64 KiB.
    fn new(
        sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
        max_receive_segments: usize,
        endpoint: &proto::Endpoint,
    ) -> Self {
        let recv_buf = vec![
            0;
            endpoint.config().get_max_udp_payload_size().min(64 * 1024) as usize
                * max_receive_segments
                * BATCH_SIZE
        ];
        Self {
            connections: ConnectionSet {
                senders: FxHashMap::default(),
                sender,
                close: None,
            },
            incoming: VecDeque::new(),
            recv_buf: recv_buf.into(),
            recv_limiter: WorkLimiter::new(RECV_TIME_BOUND),
        }
    }

    /// Receives and dispatches datagrams from `socket` until it would block or the work limiter
    /// says to yield.
    ///
    /// For each datagram:
    /// - New connection attempts are queued for `accept`, or refused immediately if the endpoint
    ///   is closing.
    /// - Events for existing connections are forwarded on their channels.
    /// - Stateless responses are sent best-effort.
    ///
    /// `ConnectionReset` errors are skipped; any other receive error is returned.
    fn poll_socket(
        &mut self,
        cx: &mut Context,
        endpoint: &mut proto::Endpoint,
        socket: &dyn AsyncUdpSocket,
        runtime: &dyn Runtime,
        now: Instant,
    ) -> Result<PollProgress, io::Error> {
        let mut received_connection_packet = false;
        let mut metas = [RecvMeta::default(); BATCH_SIZE];
        let mut iovs: [IoSliceMut; BATCH_SIZE] = {
            let mut bufs = self
                .recv_buf
                .chunks_mut(self.recv_buf.len() / BATCH_SIZE)
                .map(IoSliceMut::new);

            // `new` sizes `recv_buf` as a multiple of `BATCH_SIZE`, so there are exactly
            // `BATCH_SIZE` equal chunks. A zero-length buffer would make the chunk size 0,
            // which panics in `chunks_mut`.
            std::array::from_fn(|_| bufs.next().expect("BATCH_SIZE elements"))
        };
        loop {
            match socket.poll_recv(cx, &mut iovs, &mut metas) {
                Poll::Ready(Ok(msgs)) => {
                    self.recv_limiter.record_work(msgs);
                    for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) {
                        let mut data: BytesMut = buf[0..meta.len].into();
                        // A slot may hold several datagrams packed back to back, `stride` bytes
                        // apart; the last one may be shorter. NOTE(review): presumably coalesced
                        // (GRO-style) receive; confirm in the `udp` crate.
                        while !data.is_empty() {
                            let buf = data.split_to(meta.stride.min(data.len()));
                            let mut response_buffer = Vec::new();
                            match endpoint.handle(
                                now,
                                meta.addr,
                                meta.dst_ip,
                                meta.ecn.map(proto_ecn),
                                buf,
                                &mut response_buffer,
                            ) {
                                Some(DatagramEvent::NewConnection(incoming)) => {
                                    if self.connections.close.is_none() {
                                        self.incoming.push_back(incoming);
                                    } else {
                                        // Closing: refuse straight away instead of queueing.
                                        // (Not counted in `EndpointStats::refused_handshakes`.)
                                        let transmit =
                                            endpoint.refuse(incoming, &mut response_buffer);
                                        respond(transmit, &response_buffer, socket);
                                    }
                                }
                                Some(DatagramEvent::ConnectionEvent(handle, event)) => {
                                    received_connection_packet = true;
                                    // The `unwrap` assumes the protocol endpoint only routes to
                                    // handles still present in `senders`. A send error (the
                                    // connection dropped but is not yet cleaned up) is ignored.
                                    let _ = self
                                        .connections
                                        .senders
                                        .get_mut(&handle)
                                        .unwrap()
                                        .send(ConnectionEvent::Proto(event));
                                }
                                Some(DatagramEvent::Response(transmit)) => {
                                    respond(transmit, &response_buffer, socket);
                                }
                                None => {}
                            }
                        }
                    }
                }
                Poll::Pending => {
                    return Ok(PollProgress {
                        received_connection_packet,
                        keep_going: false,
                    });
                }
                // Connection-reset errors are skipped rather than treated as fatal. Note that this
                // `continue` also bypasses the work-limiter check below.
                Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionReset => {
                    continue;
                }
                Poll::Ready(Err(e)) => {
                    return Err(e);
                }
            }
            // Budget exhausted: report that more work may be pending so the driver reschedules.
            if !self.recv_limiter.allow_work(|| runtime.now()) {
                return Ok(PollProgress {
                    received_connection_packet,
                    keep_going: true,
                });
            }
        }
    }
}
852
853impl fmt::Debug for RecvState {
854 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
855 f.debug_struct("RecvState")
856 .field("incoming", &self.incoming)
857 .field("connections", &self.connections)
858 .field("recv_limiter", &self.recv_limiter)
860 .finish_non_exhaustive()
861 }
862}
863
/// Outcome of one `RecvState::poll_socket` call.
#[derive(Default)]
struct PollProgress {
    // Whether any datagram was routed to an existing connection during this call.
    received_connection_packet: bool,
    // `true` when polling stopped because the work limiter ran out, not because the socket would
    // block; the driver should be polled again soon.
    keep_going: bool,
}