1#[cfg(test)]
2mod endpoint_test;
3
4use std::{
5 collections::{HashMap, VecDeque},
6 fmt, iter,
7 net::{IpAddr, SocketAddr},
8 ops::{Index, IndexMut},
9 sync::Arc,
10 time::Instant,
11};
12
13use crate::association::Association;
14use crate::chunk::chunk_type::CT_INIT;
15use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
16use crate::packet::PartialDecode;
17use crate::shared::{
18 AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
19};
20use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
21use crate::{EcnCodepoint, Payload, Transmit};
22
23use bytes::Bytes;
24use log::{debug, trace};
25use rand::{rngs::StdRng, SeedableRng};
26use rustc_hash::FxHashMap;
27use slab::Slab;
28use thiserror::Error;
29
30pub struct Endpoint {
36 rng: StdRng,
37 transmits: VecDeque<Transmit>,
38 association_ids_init: HashMap<AssociationId, AssociationHandle>,
42 association_ids: FxHashMap<AssociationId, AssociationHandle>,
46
47 associations: Slab<AssociationMeta>,
48 local_cid_generator: Box<dyn AssociationIdGenerator>,
49 config: Arc<EndpointConfig>,
50 server_config: Option<Arc<ServerConfig>>,
51 reject_new_associations: bool,
55}
56
57impl fmt::Debug for Endpoint {
58 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
59 fmt.debug_struct("Endpoint<T>")
60 .field("rng", &self.rng)
61 .field("transmits", &self.transmits)
62 .field("association_ids_initial", &self.association_ids_init)
63 .field("association_ids", &self.association_ids)
64 .field("associations", &self.associations)
65 .field("config", &self.config)
66 .field("server_config", &self.server_config)
67 .field("reject_new_associations", &self.reject_new_associations)
68 .finish()
69 }
70}
71
72impl Endpoint {
73 pub fn new(config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>) -> Self {
77 let rng = {
78 let mut base = rand::rng();
79 StdRng::from_rng(&mut base)
80 };
81 Self {
82 rng,
83 transmits: VecDeque::new(),
84 association_ids_init: HashMap::default(),
85 association_ids: FxHashMap::default(),
86 associations: Slab::new(),
87 local_cid_generator: (config.aid_generator_factory.as_ref())(),
88 reject_new_associations: false,
89 config,
90 server_config,
91 }
92 }
93
94 #[must_use]
96 pub fn poll_transmit(&mut self) -> Option<Transmit> {
97 self.transmits.pop_front()
98 }
99
100 pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
102 self.server_config = server_config;
103 }
104
105 pub fn handle_event(
109 &mut self,
110 ch: AssociationHandle,
111 event: EndpointEvent,
112 ) -> Option<AssociationEvent> {
113 match event.0 {
114 EndpointEventInner::Drained => {
115 let conn = self.associations.remove(ch.0);
116 self.association_ids_init.remove(&conn.init_cid);
117 for cid in conn.loc_cids.values() {
118 self.association_ids.remove(cid);
119 }
120 }
121 }
122 None
123 }
124
125 pub fn handle(
127 &mut self,
128 now: Instant,
129 remote: SocketAddr,
130 local_ip: Option<IpAddr>,
131 ecn: Option<EcnCodepoint>,
132 data: Bytes,
133 ) -> Option<(AssociationHandle, DatagramEvent)> {
134 let partial_decode = match PartialDecode::unmarshal(&data) {
135 Ok(x) => x,
136 Err(err) => {
137 trace!("malformed header: {}", err);
138 return None;
139 }
140 };
141
142 let dst_cid = partial_decode.common_header.verification_tag;
146 let known_ch = if dst_cid > 0 {
147 self.association_ids.get(&dst_cid).cloned()
148 } else {
149 if partial_decode.first_chunk_type == CT_INIT {
151 if let Some(dst_cid) = partial_decode.initiate_tag {
152 self.association_ids.get(&dst_cid).cloned()
153 } else {
154 None
155 }
156 } else {
157 None
158 }
159 };
160
161 if let Some(ch) = known_ch {
162 return Some((
163 ch,
164 DatagramEvent::AssociationEvent(AssociationEvent(AssociationEventInner::Datagram(
165 Transmit {
166 now,
167 remote,
168 ecn,
169 payload: Payload::PartialDecode(partial_decode),
170 local_ip,
171 },
172 ))),
173 ));
174 }
175
176 self.handle_first_packet(now, remote, local_ip, ecn, partial_decode)
180 .map(|(ch, a)| (ch, DatagramEvent::NewAssociation(a)))
181 }
182
183 pub fn connect(
185 &mut self,
186 config: ClientConfig,
187 remote: SocketAddr,
188 ) -> Result<(AssociationHandle, Association), ConnectError> {
189 if self.is_full() {
190 return Err(ConnectError::TooManyAssociations);
191 }
192 if remote.port() == 0 {
193 return Err(ConnectError::InvalidRemoteAddress(remote));
194 }
195
196 let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
197 let local_aid = self.new_aid();
198
199 let (ch, conn) = self.add_association(
200 remote_aid,
201 local_aid,
202 remote,
203 None,
204 Instant::now(),
205 None,
206 config.transport,
207 );
208 Ok((ch, conn))
209 }
210
211 fn new_aid(&mut self) -> AssociationId {
212 loop {
213 let aid = self.local_cid_generator.generate_aid();
214 if !self.association_ids.contains_key(&aid) {
215 break aid;
216 }
217 }
218 }
219
220 fn handle_first_packet(
221 &mut self,
222 now: Instant,
223 remote: SocketAddr,
224 local_ip: Option<IpAddr>,
225 ecn: Option<EcnCodepoint>,
226 partial_decode: PartialDecode,
227 ) -> Option<(AssociationHandle, Association)> {
228 if partial_decode.first_chunk_type != CT_INIT
229 || (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
230 {
231 debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
232 return None;
233 }
234
235 let server_config = self.server_config.as_ref().unwrap();
236
237 if self.associations.len() >= server_config.concurrent_associations as usize
238 || self.reject_new_associations
239 || self.is_full()
240 {
241 debug!("refusing association");
242 return None;
244 }
245
246 let server_config = server_config.clone();
247 let transport_config = server_config.transport.clone();
248
249 let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
250 let local_aid = self.new_aid();
251
252 let (ch, mut conn) = self.add_association(
253 remote_aid,
254 local_aid,
255 remote,
256 local_ip,
257 now,
258 Some(server_config),
259 transport_config,
260 );
261
262 conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
263 Transmit {
264 now,
265 remote,
266 ecn,
267 payload: Payload::PartialDecode(partial_decode),
268 local_ip,
269 },
270 )));
271
272 Some((ch, conn))
273 }
274
275 #[allow(clippy::too_many_arguments)]
276 fn add_association(
277 &mut self,
278 remote_aid: AssociationId,
279 local_aid: AssociationId,
280 remote_addr: SocketAddr,
281 local_ip: Option<IpAddr>,
282 now: Instant,
283 server_config: Option<Arc<ServerConfig>>,
284 transport_config: Arc<TransportConfig>,
285 ) -> (AssociationHandle, Association) {
286 let conn = Association::new(
287 server_config,
288 transport_config,
289 self.config.get_max_payload_size(),
290 local_aid,
291 remote_addr,
292 local_ip,
293 now,
294 );
295
296 let id = self.associations.insert(AssociationMeta {
297 init_cid: remote_aid,
298 cids_issued: 0,
299 loc_cids: iter::once((0, local_aid)).collect(),
300 initial_remote: remote_addr,
301 });
302
303 let ch = AssociationHandle(id);
304 self.association_ids.insert(local_aid, ch);
305
306 (ch, conn)
307 }
308
309 pub fn reject_new_associations(&mut self) {
311 self.reject_new_associations = true;
312 }
313
314 pub fn config(&self) -> &EndpointConfig {
316 &self.config
317 }
318
319 fn is_full(&self) -> bool {
321 (((u32::MAX >> 1) + (u32::MAX >> 2)) as usize) < self.association_ids.len()
322 }
323}
324
325#[derive(Debug)]
326pub(crate) struct AssociationMeta {
327 init_cid: AssociationId,
328 cids_issued: u64,
330 loc_cids: FxHashMap<u64, AssociationId>,
331 initial_remote: SocketAddr,
336}
337
/// Opaque identifier for an association; wraps its index in the endpoint's association table.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct AssociationHandle(pub usize);
341
342impl From<AssociationHandle> for usize {
343 fn from(x: AssociationHandle) -> usize {
344 x.0
345 }
346}
347
348impl Index<AssociationHandle> for Slab<AssociationMeta> {
349 type Output = AssociationMeta;
350 fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
351 &self[ch.0]
352 }
353}
354
355impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
356 fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
357 &mut self[ch.0]
358 }
359}
360
361#[allow(clippy::large_enum_variant)] pub enum DatagramEvent {
364 AssociationEvent(AssociationEvent),
366 NewAssociation(Association),
368}
369
370#[derive(Debug, Error, Clone, PartialEq, Eq)]
374pub enum ConnectError {
375 #[error("endpoint stopping")]
379 EndpointStopping,
380 #[error("too many associations")]
384 TooManyAssociations,
385 #[error("invalid DNS name: {0}")]
387 InvalidDnsName(String),
388 #[error("invalid remote address: {0}")]
392 InvalidRemoteAddress(SocketAddr),
393 #[error("no default client config")]
397 NoDefaultClientConfig,
398}