1use std::{
2 collections::{hash_map, HashMap},
3 convert::TryFrom,
4 fmt, mem,
5 net::{IpAddr, SocketAddr},
6 ops::{Index, IndexMut},
7 sync::Arc,
8};
9
10use bytes::{BufMut, Bytes, BytesMut};
11use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};
12use rustc_hash::FxHashMap;
13use slab::Slab;
14use thiserror::Error;
15use tracing::{debug, error, trace, warn};
16
17use crate::{
18 cid_generator::ConnectionIdGenerator,
19 coding::BufMutExt,
20 config::{ClientConfig, EndpointConfig, ServerConfig},
21 connection::{Connection, ConnectionError, SideArgs},
22 crypto::{self, Keys, UnsupportedVersion},
23 frame,
24 packet::{
25 FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, PacketDecodeError,
26 PacketNumber, PartialDecode, ProtectedInitialHeader,
27 },
28 shared::{
29 ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
30 EndpointEvent, EndpointEventInner, IssuedCid,
31 },
32 token::{IncomingToken, InvalidRetryTokenError, Token, TokenPayload},
33 transport_parameters::{PreferredAddress, TransportParameters},
34 Duration, Instant, ResetToken, Side, Transmit, TransportConfig, TransportError, INITIAL_MTU,
35 MAX_CID_SIZE, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE,
36};
37
/// The main entry point to the library: owns all connection-routing state for a single UDP
/// socket and turns incoming datagrams into connection events or stateless responses.
pub struct Endpoint {
    /// Endpoint-level RNG; also used to derive the per-connection RNG seeds.
    rng: StdRng,
    /// Tables used to route incoming datagrams to connections or pending `Incoming`s.
    index: ConnectionIndex,
    /// Per-connection bookkeeping, keyed by `ConnectionHandle`.
    connections: Slab<ConnectionMeta>,
    /// Produces the local connection IDs issued by this endpoint.
    local_cid_generator: Box<dyn ConnectionIdGenerator>,
    config: Arc<EndpointConfig>,
    /// When `None`, incoming connection attempts are not accepted.
    server_config: Option<Arc<ServerConfig>>,
    /// Forwarded unchanged to every `Connection::new` call.
    allow_mtud: bool,
    /// Time the last stateless reset was sent; used to rate-limit resets by
    /// `config.min_reset_interval`.
    last_stateless_reset: Option<Instant>,
    /// Datagrams buffered for `Incoming` connections that have not been accepted, refused,
    /// retried or ignored yet.
    incoming_buffers: Slab<IncomingBuffer>,
    /// Sum of `total_bytes` over every entry of `incoming_buffers`.
    all_incoming_buffers_total_bytes: u64,
}
57
58impl Endpoint {
59 pub fn new(
70 config: Arc<EndpointConfig>,
71 server_config: Option<Arc<ServerConfig>>,
72 allow_mtud: bool,
73 rng_seed: Option<[u8; 32]>,
74 ) -> Self {
75 let rng_seed = rng_seed.or(config.rng_seed);
76 Self {
77 rng: rng_seed.map_or(StdRng::from_entropy(), StdRng::from_seed),
78 index: ConnectionIndex::default(),
79 connections: Slab::new(),
80 local_cid_generator: (config.connection_id_generator_factory.as_ref())(),
81 config,
82 server_config,
83 allow_mtud,
84 last_stateless_reset: None,
85 incoming_buffers: Slab::new(),
86 all_incoming_buffers_total_bytes: 0,
87 }
88 }
89
90 pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
92 self.server_config = server_config;
93 }
94
95 pub fn handle_event(
99 &mut self,
100 ch: ConnectionHandle,
101 event: EndpointEvent,
102 ) -> Option<ConnectionEvent> {
103 use EndpointEventInner::*;
104 match event.0 {
105 NeedIdentifiers(now, n) => {
106 return Some(self.send_new_identifiers(now, ch, n));
107 }
108 ResetToken(remote, token) => {
109 if let Some(old) = self.connections[ch].reset_token.replace((remote, token)) {
110 self.index.connection_reset_tokens.remove(old.0, old.1);
111 }
112 if self.index.connection_reset_tokens.insert(remote, token, ch) {
113 warn!("duplicate reset token");
114 }
115 }
116 RetireConnectionId(now, seq, allow_more_cids) => {
117 if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) {
118 trace!("peer retired CID {}: {}", seq, cid);
119 self.index.retire(cid);
120 if allow_more_cids {
121 return Some(self.send_new_identifiers(now, ch, 1));
122 }
123 }
124 }
125 Drained => {
126 if let Some(conn) = self.connections.try_remove(ch.0) {
127 self.index.remove(&conn);
128 } else {
129 error!(id = ch.0, "unknown connection drained");
133 }
134 }
135 }
136 None
137 }
138
139 pub fn handle(
141 &mut self,
142 now: Instant,
143 remote: SocketAddr,
144 local_ip: Option<IpAddr>,
145 ecn: Option<EcnCodepoint>,
146 data: BytesMut,
147 buf: &mut Vec<u8>,
148 ) -> Option<DatagramEvent> {
149 let datagram_len = data.len();
151 let event = match PartialDecode::new(
152 data,
153 &FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
154 &self.config.supported_versions,
155 self.config.grease_quic_bit,
156 ) {
157 Ok((first_decode, remaining)) => DatagramConnectionEvent {
158 now,
159 remote,
160 ecn,
161 first_decode,
162 remaining,
163 },
164 Err(PacketDecodeError::UnsupportedVersion {
165 src_cid,
166 dst_cid,
167 version,
168 }) => {
169 if self.server_config.is_none() {
170 debug!("dropping packet with unsupported version");
171 return None;
172 }
173 trace!("sending version negotiation");
174 Header::VersionNegotiate {
176 random: self.rng.gen::<u8>() | 0x40,
177 src_cid: dst_cid,
178 dst_cid: src_cid,
179 }
180 .encode(buf);
181 buf.write::<u32>(match version {
183 0x0a1a_2a3a => 0x0a1a_2a4a,
184 _ => 0x0a1a_2a3a,
185 });
186 for &version in &self.config.supported_versions {
187 buf.write(version);
188 }
189 return Some(DatagramEvent::Response(Transmit {
190 destination: remote,
191 ecn: None,
192 size: buf.len(),
193 segment_size: None,
194 src_ip: local_ip,
195 }));
196 }
197 Err(e) => {
198 trace!("malformed header: {}", e);
199 return None;
200 }
201 };
202
203 let addresses = FourTuple { remote, local_ip };
204 let dst_cid = event.first_decode.dst_cid();
205
206 if let Some(route_to) = self.index.get(&addresses, &event.first_decode) {
207 match route_to {
209 RouteDatagramTo::Incoming(incoming_idx) => {
210 let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
211 let config = &self.server_config.as_ref().unwrap();
212
213 if incoming_buffer
214 .total_bytes
215 .checked_add(datagram_len as u64)
216 .is_some_and(|n| n <= config.incoming_buffer_size)
217 && self
218 .all_incoming_buffers_total_bytes
219 .checked_add(datagram_len as u64)
220 .is_some_and(|n| n <= config.incoming_buffer_size_total)
221 {
222 incoming_buffer.datagrams.push(event);
223 incoming_buffer.total_bytes += datagram_len as u64;
224 self.all_incoming_buffers_total_bytes += datagram_len as u64;
225 }
226
227 None
228 }
229 RouteDatagramTo::Connection(ch) => Some(DatagramEvent::ConnectionEvent(
230 ch,
231 ConnectionEvent(ConnectionEventInner::Datagram(event)),
232 )),
233 }
234 } else if event.first_decode.initial_header().is_some() {
235 self.handle_first_packet(datagram_len, event, addresses, buf)
238 } else if event.first_decode.has_long_header() {
239 debug!(
240 "ignoring non-initial packet for unknown connection {}",
241 dst_cid
242 );
243 None
244 } else if !event.first_decode.is_initial()
245 && self.local_cid_generator.validate(dst_cid).is_err()
246 {
247 debug!("dropping packet with invalid CID");
251 None
252 } else if dst_cid.is_empty() {
253 trace!("dropping unrecognized short packet without ID");
254 None
255 } else {
256 self.stateless_reset(now, datagram_len, addresses, *dst_cid, buf)
257 .map(DatagramEvent::Response)
258 }
259 }
260
261 fn stateless_reset(
262 &mut self,
263 now: Instant,
264 inciting_dgram_len: usize,
265 addresses: FourTuple,
266 dst_cid: ConnectionId,
267 buf: &mut Vec<u8>,
268 ) -> Option<Transmit> {
269 if self
270 .last_stateless_reset
271 .is_some_and(|last| last + self.config.min_reset_interval > now)
272 {
273 debug!("ignoring unexpected packet within minimum stateless reset interval");
274 return None;
275 }
276
277 const MIN_PADDING_LEN: usize = 5;
279
280 let max_padding_len = match inciting_dgram_len.checked_sub(RESET_TOKEN_SIZE) {
283 Some(headroom) if headroom > MIN_PADDING_LEN => headroom - 1,
284 _ => {
285 debug!("ignoring unexpected {} byte packet: not larger than minimum stateless reset size", inciting_dgram_len);
286 return None;
287 }
288 };
289
290 debug!(
291 "sending stateless reset for {} to {}",
292 dst_cid, addresses.remote
293 );
294 self.last_stateless_reset = Some(now);
295 const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;
297 let padding_len = if max_padding_len <= IDEAL_MIN_PADDING_LEN {
298 max_padding_len
299 } else {
300 self.rng.gen_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
301 };
302 buf.reserve(padding_len + RESET_TOKEN_SIZE);
303 buf.resize(padding_len, 0);
304 self.rng.fill_bytes(&mut buf[0..padding_len]);
305 buf[0] = 0b0100_0000 | buf[0] >> 2;
306 buf.extend_from_slice(&ResetToken::new(&*self.config.reset_key, dst_cid));
307
308 debug_assert!(buf.len() < inciting_dgram_len);
309
310 Some(Transmit {
311 destination: addresses.remote,
312 ecn: None,
313 size: buf.len(),
314 segment_size: None,
315 src_ip: addresses.local_ip,
316 })
317 }
318
319 pub fn connect(
321 &mut self,
322 now: Instant,
323 config: ClientConfig,
324 remote: SocketAddr,
325 server_name: &str,
326 ) -> Result<(ConnectionHandle, Connection), ConnectError> {
327 if self.cids_exhausted() {
328 return Err(ConnectError::CidsExhausted);
329 }
330 if remote.port() == 0 || remote.ip().is_unspecified() {
331 return Err(ConnectError::InvalidRemoteAddress(remote));
332 }
333 if !self.config.supported_versions.contains(&config.version) {
334 return Err(ConnectError::UnsupportedVersion);
335 }
336
337 let remote_id = (config.initial_dst_cid_provider)();
338 trace!(initial_dcid = %remote_id);
339
340 let ch = ConnectionHandle(self.connections.vacant_key());
341 let loc_cid = self.new_cid(ch);
342 let params = TransportParameters::new(
343 &config.transport,
344 &self.config,
345 self.local_cid_generator.as_ref(),
346 loc_cid,
347 None,
348 &mut self.rng,
349 );
350 let tls = config
351 .crypto
352 .start_session(config.version, server_name, ¶ms)?;
353
354 let conn = self.add_connection(
355 ch,
356 config.version,
357 remote_id,
358 loc_cid,
359 remote_id,
360 FourTuple {
361 remote,
362 local_ip: None,
363 },
364 now,
365 tls,
366 config.transport,
367 SideArgs::Client {
368 token_store: config.token_store,
369 server_name: server_name.into(),
370 },
371 );
372 Ok((ch, conn))
373 }
374
375 fn send_new_identifiers(
376 &mut self,
377 now: Instant,
378 ch: ConnectionHandle,
379 num: u64,
380 ) -> ConnectionEvent {
381 let mut ids = vec![];
382 for _ in 0..num {
383 let id = self.new_cid(ch);
384 let meta = &mut self.connections[ch];
385 let sequence = meta.cids_issued;
386 meta.cids_issued += 1;
387 meta.loc_cids.insert(sequence, id);
388 ids.push(IssuedCid {
389 sequence,
390 id,
391 reset_token: ResetToken::new(&*self.config.reset_key, id),
392 });
393 }
394 ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
395 }
396
397 fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
399 loop {
400 let cid = self.local_cid_generator.generate_cid();
401 if cid.len() == 0 {
402 debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
404 return cid;
405 }
406 if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
407 e.insert(ch);
408 break cid;
409 }
410 }
411 }
412
413 fn handle_first_packet(
414 &mut self,
415 datagram_len: usize,
416 event: DatagramConnectionEvent,
417 addresses: FourTuple,
418 buf: &mut Vec<u8>,
419 ) -> Option<DatagramEvent> {
420 let dst_cid = event.first_decode.dst_cid();
421 let header = event.first_decode.initial_header().unwrap();
422
423 let Some(server_config) = &self.server_config else {
424 debug!("packet for unrecognized connection {}", dst_cid);
425 return self
426 .stateless_reset(event.now, datagram_len, addresses, *dst_cid, buf)
427 .map(DatagramEvent::Response);
428 };
429
430 if datagram_len < MIN_INITIAL_SIZE as usize {
431 debug!("ignoring short initial for connection {}", dst_cid);
432 return None;
433 }
434
435 let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
436 Ok(keys) => keys,
437 Err(UnsupportedVersion) => {
438 debug!(
441 "ignoring initial packet version {:#x} unsupported by cryptographic layer",
442 header.version
443 );
444 return None;
445 }
446 };
447
448 if let Err(reason) = self.early_validate_first_packet(header) {
449 return Some(DatagramEvent::Response(self.initial_close(
450 header.version,
451 addresses,
452 &crypto,
453 &header.src_cid,
454 reason,
455 buf,
456 )));
457 }
458
459 let packet = match event.first_decode.finish(Some(&*crypto.header.remote)) {
460 Ok(packet) => packet,
461 Err(e) => {
462 trace!("unable to decode initial packet: {}", e);
463 return None;
464 }
465 };
466
467 if !packet.reserved_bits_valid() {
468 debug!("dropping connection attempt with invalid reserved bits");
469 return None;
470 }
471
472 let Header::Initial(header) = packet.header else {
473 panic!("non-initial packet in handle_first_packet()");
474 };
475
476 let server_config = self.server_config.as_ref().unwrap().clone();
477
478 let token = match IncomingToken::from_header(&header, &server_config, addresses.remote) {
479 Ok(token) => token,
480 Err(InvalidRetryTokenError) => {
481 debug!("rejecting invalid retry token");
482 return Some(DatagramEvent::Response(self.initial_close(
483 header.version,
484 addresses,
485 &crypto,
486 &header.src_cid,
487 TransportError::INVALID_TOKEN(""),
488 buf,
489 )));
490 }
491 };
492
493 let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
494 self.index
495 .insert_initial_incoming(header.dst_cid, incoming_idx);
496
497 Some(DatagramEvent::NewConnection(Incoming {
498 received_at: event.now,
499 addresses,
500 ecn: event.ecn,
501 packet: InitialPacket {
502 header,
503 header_data: packet.header_data,
504 payload: packet.payload,
505 },
506 rest: event.remaining,
507 crypto,
508 token,
509 incoming_idx,
510 improper_drop_warner: IncomingImproperDropWarner,
511 }))
512 }
513
514 pub fn accept(
516 &mut self,
517 mut incoming: Incoming,
518 now: Instant,
519 buf: &mut Vec<u8>,
520 server_config: Option<Arc<ServerConfig>>,
521 ) -> Result<(ConnectionHandle, Connection), AcceptError> {
522 let remote_address_validated = incoming.remote_address_validated();
523 incoming.improper_drop_warner.dismiss();
524 let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
525 self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
526
527 let packet_number = incoming.packet.header.number.expand(0);
528 let InitialHeader {
529 src_cid,
530 dst_cid,
531 version,
532 ..
533 } = incoming.packet.header;
534 let server_config =
535 server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());
536
537 if server_config
538 .transport
539 .max_idle_timeout
540 .is_some_and(|timeout| {
541 incoming.received_at + Duration::from_millis(timeout.into()) <= now
542 })
543 {
544 debug!("abandoning accept of stale initial");
545 self.index.remove_initial(dst_cid);
546 return Err(AcceptError {
547 cause: ConnectionError::TimedOut,
548 response: None,
549 });
550 }
551
552 if self.cids_exhausted() {
553 debug!("refusing connection");
554 self.index.remove_initial(dst_cid);
555 return Err(AcceptError {
556 cause: ConnectionError::CidsExhausted,
557 response: Some(self.initial_close(
558 version,
559 incoming.addresses,
560 &incoming.crypto,
561 &src_cid,
562 TransportError::CONNECTION_REFUSED(""),
563 buf,
564 )),
565 });
566 }
567
568 if incoming
569 .crypto
570 .packet
571 .remote
572 .decrypt(
573 packet_number,
574 &incoming.packet.header_data,
575 &mut incoming.packet.payload,
576 )
577 .is_err()
578 {
579 debug!(packet_number, "failed to authenticate initial packet");
580 self.index.remove_initial(dst_cid);
581 return Err(AcceptError {
582 cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
583 response: None,
584 });
585 };
586
587 let ch = ConnectionHandle(self.connections.vacant_key());
588 let loc_cid = self.new_cid(ch);
589 let mut params = TransportParameters::new(
590 &server_config.transport,
591 &self.config,
592 self.local_cid_generator.as_ref(),
593 loc_cid,
594 Some(&server_config),
595 &mut self.rng,
596 );
597 params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, loc_cid));
598 params.original_dst_cid = Some(incoming.token.orig_dst_cid);
599 params.retry_src_cid = incoming.token.retry_src_cid;
600 let mut pref_addr_cid = None;
601 if server_config.preferred_address_v4.is_some()
602 || server_config.preferred_address_v6.is_some()
603 {
604 let cid = self.new_cid(ch);
605 pref_addr_cid = Some(cid);
606 params.preferred_address = Some(PreferredAddress {
607 address_v4: server_config.preferred_address_v4,
608 address_v6: server_config.preferred_address_v6,
609 connection_id: cid,
610 stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
611 });
612 }
613
614 let tls = server_config.crypto.clone().start_session(version, ¶ms);
615 let transport_config = server_config.transport.clone();
616 let mut conn = self.add_connection(
617 ch,
618 version,
619 dst_cid,
620 loc_cid,
621 src_cid,
622 incoming.addresses,
623 incoming.received_at,
624 tls,
625 transport_config,
626 SideArgs::Server {
627 server_config,
628 pref_addr_cid,
629 path_validated: remote_address_validated,
630 },
631 );
632 self.index.insert_initial(dst_cid, ch);
633
634 match conn.handle_first_packet(
635 incoming.received_at,
636 incoming.addresses.remote,
637 incoming.ecn,
638 packet_number,
639 incoming.packet,
640 incoming.rest,
641 ) {
642 Ok(()) => {
643 trace!(id = ch.0, icid = %dst_cid, "new connection");
644
645 for event in incoming_buffer.datagrams {
646 conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
647 }
648
649 Ok((ch, conn))
650 }
651 Err(e) => {
652 debug!("handshake failed: {}", e);
653 self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
654 let response = match e {
655 ConnectionError::TransportError(ref e) => Some(self.initial_close(
656 version,
657 incoming.addresses,
658 &incoming.crypto,
659 &src_cid,
660 e.clone(),
661 buf,
662 )),
663 _ => None,
664 };
665 Err(AcceptError { cause: e, response })
666 }
667 }
668 }
669
670 fn early_validate_first_packet(
672 &mut self,
673 header: &ProtectedInitialHeader,
674 ) -> Result<(), TransportError> {
675 let config = &self.server_config.as_ref().unwrap();
676 if self.cids_exhausted() || self.incoming_buffers.len() >= config.max_incoming {
677 return Err(TransportError::CONNECTION_REFUSED(""));
678 }
679
680 if header.dst_cid.len() < 8
685 && (header.token_pos.is_empty()
686 || header.dst_cid.len() != self.local_cid_generator.cid_len())
687 {
688 debug!(
689 "rejecting connection due to invalid DCID length {}",
690 header.dst_cid.len()
691 );
692 return Err(TransportError::PROTOCOL_VIOLATION(
693 "invalid destination CID length",
694 ));
695 }
696
697 Ok(())
698 }
699
700 pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
702 self.clean_up_incoming(&incoming);
703 incoming.improper_drop_warner.dismiss();
704
705 self.initial_close(
706 incoming.packet.header.version,
707 incoming.addresses,
708 &incoming.crypto,
709 &incoming.packet.header.src_cid,
710 TransportError::CONNECTION_REFUSED(""),
711 buf,
712 )
713 }
714
715 pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
719 if !incoming.may_retry() {
720 return Err(RetryError(incoming));
721 }
722
723 self.clean_up_incoming(&incoming);
724 incoming.improper_drop_warner.dismiss();
725
726 let server_config = self.server_config.as_ref().unwrap();
727
728 let loc_cid = self.local_cid_generator.generate_cid();
735
736 let payload = TokenPayload::Retry {
737 address: incoming.addresses.remote,
738 orig_dst_cid: incoming.packet.header.dst_cid,
739 issued: server_config.time_source.now(),
740 };
741 let token = Token::new(payload, &mut self.rng).encode(&*server_config.token_key);
742
743 let header = Header::Retry {
744 src_cid: loc_cid,
745 dst_cid: incoming.packet.header.src_cid,
746 version: incoming.packet.header.version,
747 };
748
749 let encode = header.encode(buf);
750 buf.put_slice(&token);
751 buf.extend_from_slice(&server_config.crypto.retry_tag(
752 incoming.packet.header.version,
753 &incoming.packet.header.dst_cid,
754 buf,
755 ));
756 encode.finish(buf, &*incoming.crypto.header.local, None);
757
758 Ok(Transmit {
759 destination: incoming.addresses.remote,
760 ecn: None,
761 size: buf.len(),
762 segment_size: None,
763 src_ip: incoming.addresses.local_ip,
764 })
765 }
766
767 pub fn ignore(&mut self, incoming: Incoming) {
772 self.clean_up_incoming(&incoming);
773 incoming.improper_drop_warner.dismiss();
774 }
775
776 fn clean_up_incoming(&mut self, incoming: &Incoming) {
778 self.index.remove_initial(incoming.packet.header.dst_cid);
779 let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
780 self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
781 }
782
783 fn add_connection(
784 &mut self,
785 ch: ConnectionHandle,
786 version: u32,
787 init_cid: ConnectionId,
788 loc_cid: ConnectionId,
789 rem_cid: ConnectionId,
790 addresses: FourTuple,
791 now: Instant,
792 tls: Box<dyn crypto::Session>,
793 transport_config: Arc<TransportConfig>,
794 side_args: SideArgs,
795 ) -> Connection {
796 let mut rng_seed = [0; 32];
797 self.rng.fill_bytes(&mut rng_seed);
798 let side = side_args.side();
799 let pref_addr_cid = side_args.pref_addr_cid();
800 let conn = Connection::new(
801 self.config.clone(),
802 transport_config,
803 init_cid,
804 loc_cid,
805 rem_cid,
806 addresses.remote,
807 addresses.local_ip,
808 tls,
809 self.local_cid_generator.as_ref(),
810 now,
811 version,
812 self.allow_mtud,
813 rng_seed,
814 side_args,
815 );
816
817 let mut cids_issued = 0;
818 let mut loc_cids = FxHashMap::default();
819
820 loc_cids.insert(cids_issued, loc_cid);
821 cids_issued += 1;
822
823 if let Some(cid) = pref_addr_cid {
824 debug_assert_eq!(cids_issued, 1, "preferred address cid seq must be 1");
825 loc_cids.insert(cids_issued, cid);
826 cids_issued += 1;
827 }
828
829 let id = self.connections.insert(ConnectionMeta {
830 init_cid,
831 cids_issued,
832 loc_cids,
833 addresses,
834 side,
835 reset_token: None,
836 });
837 debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
838
839 self.index.insert_conn(addresses, loc_cid, ch, side);
840
841 conn
842 }
843
844 fn initial_close(
845 &mut self,
846 version: u32,
847 addresses: FourTuple,
848 crypto: &Keys,
849 remote_id: &ConnectionId,
850 reason: TransportError,
851 buf: &mut Vec<u8>,
852 ) -> Transmit {
853 let local_id = self.local_cid_generator.generate_cid();
857 let number = PacketNumber::U8(0);
858 let header = Header::Initial(InitialHeader {
859 dst_cid: *remote_id,
860 src_cid: local_id,
861 number,
862 token: Bytes::new(),
863 version,
864 });
865
866 let partial_encode = header.encode(buf);
867 let max_len =
868 INITIAL_MTU as usize - partial_encode.header_len - crypto.packet.local.tag_len();
869 frame::Close::from(reason).encode(buf, max_len);
870 buf.resize(buf.len() + crypto.packet.local.tag_len(), 0);
871 partial_encode.finish(buf, &*crypto.header.local, Some((0, &*crypto.packet.local)));
872 Transmit {
873 destination: addresses.remote,
874 ecn: None,
875 size: buf.len(),
876 segment_size: None,
877 src_ip: addresses.local_ip,
878 }
879 }
880
881 pub fn config(&self) -> &EndpointConfig {
883 &self.config
884 }
885
886 pub fn open_connections(&self) -> usize {
888 self.connections.len()
889 }
890
891 pub fn incoming_buffer_bytes(&self) -> u64 {
894 self.all_incoming_buffers_total_bytes
895 }
896
897 #[cfg(test)]
898 pub(crate) fn known_connections(&self) -> usize {
899 let x = self.connections.len();
900 debug_assert_eq!(x, self.index.connection_ids_initial.len());
901 debug_assert!(x >= self.index.connection_reset_tokens.0.len());
903 debug_assert!(x >= self.index.incoming_connection_remotes.len());
905 debug_assert!(x >= self.index.outgoing_connection_remotes.len());
906 x
907 }
908
909 #[cfg(test)]
910 pub(crate) fn known_cids(&self) -> usize {
911 self.index.connection_ids.len()
912 }
913
914 fn cids_exhausted(&self) -> bool {
919 self.local_cid_generator.cid_len() <= 4
920 && self.local_cid_generator.cid_len() != 0
921 && (2usize.pow(self.local_cid_generator.cid_len() as u32 * 8)
922 - self.index.connection_ids.len())
923 < 2usize.pow(self.local_cid_generator.cid_len() as u32 * 8 - 2)
924 }
925}
926
927impl fmt::Debug for Endpoint {
928 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
929 fmt.debug_struct("Endpoint")
930 .field("rng", &self.rng)
931 .field("index", &self.index)
932 .field("connections", &self.connections)
933 .field("config", &self.config)
934 .field("server_config", &self.server_config)
935 .field("incoming_buffers.len", &self.incoming_buffers.len())
937 .field(
938 "all_incoming_buffers_total_bytes",
939 &self.all_incoming_buffers_total_bytes,
940 )
941 .finish()
942 }
943}
944
/// Datagrams received for a pending `Incoming` before the application has decided what to do
/// with it.
#[derive(Default)]
struct IncomingBuffer {
    /// Buffered datagrams, in arrival order; replayed into the connection on accept.
    datagrams: Vec<DatagramConnectionEvent>,
    /// Sum of the lengths of the buffered datagrams.
    total_bytes: u64,
}
951
/// Where a routed datagram should go.
#[derive(Copy, Clone, Debug)]
enum RouteDatagramTo {
    /// Key into `Endpoint::incoming_buffers` for a not-yet-accepted connection attempt.
    Incoming(usize),
    /// An established (or establishing) connection.
    Connection(ConnectionHandle),
}
958
/// Lookup tables used to route incoming datagrams.
#[derive(Default, Debug)]
struct ConnectionIndex {
    /// Client-chosen initial DCIDs. Consulted only for Initial and 0-RTT packets; they route
    /// either to a pending `Incoming` or to the connection it was accepted as.
    connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
    /// Non-empty local CIDs issued by this endpoint.
    connection_ids: FxHashMap<ConnectionId, ConnectionHandle>,
    /// Server-side connections using zero-length local CIDs, keyed by full 4-tuple.
    incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
    /// Client-side connections using zero-length local CIDs, keyed by remote address only.
    outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
    /// Reset tokens reported via `EndpointEventInner::ResetToken`. They are matched against
    /// the trailing bytes of datagrams that no other table routes.
    connection_reset_tokens: ResetTokenTable,
}
991
992impl ConnectionIndex {
993 fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
995 if dst_cid.len() == 0 {
996 return;
997 }
998 self.connection_ids_initial
999 .insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
1000 }
1001
1002 fn remove_initial(&mut self, dst_cid: ConnectionId) {
1004 if dst_cid.len() == 0 {
1005 return;
1006 }
1007 let removed = self.connection_ids_initial.remove(&dst_cid);
1008 debug_assert!(removed.is_some());
1009 }
1010
1011 fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
1013 if dst_cid.len() == 0 {
1014 return;
1015 }
1016 self.connection_ids_initial
1017 .insert(dst_cid, RouteDatagramTo::Connection(connection));
1018 }
1019
1020 fn insert_conn(
1023 &mut self,
1024 addresses: FourTuple,
1025 dst_cid: ConnectionId,
1026 connection: ConnectionHandle,
1027 side: Side,
1028 ) {
1029 match dst_cid.len() {
1030 0 => match side {
1031 Side::Server => {
1032 self.incoming_connection_remotes
1033 .insert(addresses, connection);
1034 }
1035 Side::Client => {
1036 self.outgoing_connection_remotes
1037 .insert(addresses.remote, connection);
1038 }
1039 },
1040 _ => {
1041 self.connection_ids.insert(dst_cid, connection);
1042 }
1043 }
1044 }
1045
1046 fn retire(&mut self, dst_cid: ConnectionId) {
1048 self.connection_ids.remove(&dst_cid);
1049 }
1050
1051 fn remove(&mut self, conn: &ConnectionMeta) {
1053 if conn.side.is_server() {
1054 self.remove_initial(conn.init_cid);
1055 }
1056 for cid in conn.loc_cids.values() {
1057 self.connection_ids.remove(cid);
1058 }
1059 self.incoming_connection_remotes.remove(&conn.addresses);
1060 self.outgoing_connection_remotes
1061 .remove(&conn.addresses.remote);
1062 if let Some((remote, token)) = conn.reset_token {
1063 self.connection_reset_tokens.remove(remote, token);
1064 }
1065 }
1066
1067 fn get(&self, addresses: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
1069 if datagram.dst_cid().len() != 0 {
1070 if let Some(&ch) = self.connection_ids.get(datagram.dst_cid()) {
1071 return Some(RouteDatagramTo::Connection(ch));
1072 }
1073 }
1074 if datagram.is_initial() || datagram.is_0rtt() {
1075 if let Some(&ch) = self.connection_ids_initial.get(datagram.dst_cid()) {
1076 return Some(ch);
1077 }
1078 }
1079 if datagram.dst_cid().len() == 0 {
1080 if let Some(&ch) = self.incoming_connection_remotes.get(addresses) {
1081 return Some(RouteDatagramTo::Connection(ch));
1082 }
1083 if let Some(&ch) = self.outgoing_connection_remotes.get(&addresses.remote) {
1084 return Some(RouteDatagramTo::Connection(ch));
1085 }
1086 }
1087 let data = datagram.data();
1088 if data.len() < RESET_TOKEN_SIZE {
1089 return None;
1090 }
1091 self.connection_reset_tokens
1092 .get(addresses.remote, &data[data.len() - RESET_TOKEN_SIZE..])
1093 .cloned()
1094 .map(RouteDatagramTo::Connection)
1095 }
1096}
1097
/// Endpoint-side bookkeeping for a single connection.
#[derive(Debug)]
pub(crate) struct ConnectionMeta {
    /// DCID of the connection's first Initial packet. For servers this is also the key
    /// in `ConnectionIndex::connection_ids_initial`.
    init_cid: ConnectionId,
    /// Sequence number the next issued local CID will receive.
    cids_issued: u64,
    /// Local CIDs that are still active, keyed by sequence number.
    loc_cids: FxHashMap<u64, ConnectionId>,
    /// Addresses recorded when the connection was added; used to undo address-based routing.
    addresses: FourTuple,
    side: Side,
    /// Most recent `(remote, token)` pair reported via `EndpointEventInner::ResetToken`.
    reset_token: Option<(SocketAddr, ResetToken)>,
}
1114
/// Opaque identifier for a connection owned by an `Endpoint`; wraps the connection's slab key.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);

impl From<ConnectionHandle> for usize {
    /// Unwrap the handle into its raw slab key.
    fn from(handle: ConnectionHandle) -> Self {
        let ConnectionHandle(key) = handle;
        key
    }
}
1124
1125impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
1126 type Output = ConnectionMeta;
1127 fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
1128 &self[ch.0]
1129 }
1130}
1131
1132impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
1133 fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
1134 &mut self[ch.0]
1135 }
1136}
1137
/// Outcome of [`Endpoint::handle`] for a datagram that was not simply dropped.
pub enum DatagramEvent {
    /// The datagram belongs to an existing connection; deliver the event to it.
    ConnectionEvent(ConnectionHandle, ConnectionEvent),
    /// A new connection attempt, to be passed to `accept`, `refuse`, `retry` or `ignore`.
    NewConnection(Incoming),
    /// A connectionless response has been written to the caller's buffer and should be sent.
    Response(Transmit),
}
1147
/// A connection attempt that has passed initial validation but has not yet been accepted,
/// refused, retried or ignored.
pub struct Incoming {
    /// Time the first Initial packet was received.
    received_at: Instant,
    addresses: FourTuple,
    ecn: Option<EcnCodepoint>,
    /// The first Initial packet, header-decoded but with its payload still encrypted.
    packet: InitialPacket,
    /// Any coalesced data that followed the first packet in the same datagram.
    rest: Option<BytesMut>,
    /// Initial keys derived from the client-chosen DCID.
    crypto: Keys,
    token: IncomingToken,
    /// Key into `Endpoint::incoming_buffers`.
    incoming_idx: usize,
    /// Warns (via `Drop`) if this value is dropped without being handed back to the endpoint.
    improper_drop_warner: IncomingImproperDropWarner,
}
1160
1161impl Incoming {
1162 pub fn local_ip(&self) -> Option<IpAddr> {
1166 self.addresses.local_ip
1167 }
1168
1169 pub fn remote_address(&self) -> SocketAddr {
1171 self.addresses.remote
1172 }
1173
1174 pub fn remote_address_validated(&self) -> bool {
1182 self.token.validated
1183 }
1184
1185 pub fn may_retry(&self) -> bool {
1190 self.token.retry_src_cid.is_none()
1191 }
1192
1193 pub fn orig_dst_cid(&self) -> &ConnectionId {
1195 &self.token.orig_dst_cid
1196 }
1197}
1198
1199impl fmt::Debug for Incoming {
1200 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
1201 f.debug_struct("Incoming")
1202 .field("addresses", &self.addresses)
1203 .field("ecn", &self.ecn)
1204 .field("token", &self.token)
1207 .field("incoming_idx", &self.incoming_idx)
1208 .finish_non_exhaustive()
1210 }
1211}
1212
1213struct IncomingImproperDropWarner;
1214
1215impl IncomingImproperDropWarner {
1216 fn dismiss(self) {
1217 mem::forget(self);
1218 }
1219}
1220
1221impl Drop for IncomingImproperDropWarner {
1222 fn drop(&mut self) {
1223 warn!("quinn_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
1224 (may cause memory leak and eventual inability to accept new connections)");
1225 }
1226}
1227
/// Errors that can prevent a client connection from being initiated.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
    /// The endpoint is shutting down. Not produced by `Endpoint::connect` itself; presumably
    /// used by higher-level wrappers (TODO confirm).
    #[error("endpoint stopping")]
    EndpointStopping,
    /// Returned by `Endpoint::connect` when too few local CIDs remain to issue new ones.
    #[error("CIDs exhausted")]
    CidsExhausted,
    /// The server name was not valid; presumably produced when starting the crypto session.
    #[error("invalid server name: {0}")]
    InvalidServerName(String),
    /// Returned by `Endpoint::connect` when the remote address has port 0 or an unspecified IP.
    #[error("invalid remote address: {0}")]
    InvalidRemoteAddress(SocketAddr),
    /// No default client configuration was set. Not produced in this module; presumably used
    /// by higher-level wrappers.
    #[error("no default client config")]
    NoDefaultClientConfig,
    /// Returned by `Endpoint::connect` when the configured version is not in the endpoint's
    /// `supported_versions`.
    #[error("unsupported QUIC version")]
    UnsupportedVersion,
}
1260
/// Error returned by [`Endpoint::accept`].
#[derive(Debug)]
pub struct AcceptError {
    /// Why the connection could not be accepted.
    pub cause: ConnectionError,
    /// A close packet to send to the peer, if one was generated.
    pub response: Option<Transmit>,
}
1269
1270#[derive(Debug, Error)]
1272#[error("retry() with validated Incoming")]
1273pub struct RetryError(Incoming);
1274
1275impl RetryError {
1276 pub fn into_incoming(self) -> Incoming {
1278 self.0
1279 }
1280}
1281
1282#[derive(Default, Debug)]
1287struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
1288
1289impl ResetTokenTable {
1290 fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
1291 self.0
1292 .entry(remote)
1293 .or_default()
1294 .insert(token, ch)
1295 .is_some()
1296 }
1297
1298 fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
1299 use std::collections::hash_map::Entry;
1300 match self.0.entry(remote) {
1301 Entry::Vacant(_) => {}
1302 Entry::Occupied(mut e) => {
1303 e.get_mut().remove(&token);
1304 if e.get().is_empty() {
1305 e.remove_entry();
1306 }
1307 }
1308 }
1309 }
1310
1311 fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
1312 let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
1313 self.0.get(&remote)?.get(&token)
1314 }
1315}
1316
/// The addressing information of a datagram path as seen by this endpoint.
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
struct FourTuple {
    remote: SocketAddr,
    /// Local IP the datagram was received on. `None` when unknown (client connections are
    /// created with `None`); presumably the platform could not report it.
    local_ip: Option<IpAddr>,
}