1#[cfg(test)]
2mod endpoint_test;
3
4use std::{
5 collections::HashMap,
6 fmt, iter,
7 net::SocketAddr,
8 ops::{Index, IndexMut},
9 sync::Arc,
10 time::Instant,
11};
12
13use crate::Payload;
14use crate::association::Association;
15use crate::chunk::chunk_type::CT_INIT;
16use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
17use crate::packet::PartialDecode;
18use crate::shared::{
19 AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
20};
21use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
22use shared::{EcnCodepoint, TransportContext, TransportMessage, TransportProtocol};
23
24use bytes::Bytes;
25use log::{debug, trace, warn};
26use slab::Slab;
27use thiserror::Error;
28
29pub struct Endpoint {
35 local_addr: SocketAddr,
36 transport_protocol: TransportProtocol,
37 association_ids_init: HashMap<AssociationId, AssociationHandle>,
41 association_ids: HashMap<AssociationId, AssociationHandle>,
45
46 associations: Slab<AssociationMeta>,
47 local_cid_generator: Box<dyn AssociationIdGenerator>,
48 endpoint_config: Arc<EndpointConfig>,
49 server_config: Option<Arc<ServerConfig>>,
50 reject_new_associations: bool,
54}
55
56impl fmt::Debug for Endpoint {
57 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
58 fmt.debug_struct("Endpoint<T>")
59 .field("association_ids_initial", &self.association_ids_init)
60 .field("association_ids", &self.association_ids)
61 .field("associations", &self.associations)
62 .field("config", &self.endpoint_config)
63 .field("server_config", &self.server_config)
64 .field("reject_new_associations", &self.reject_new_associations)
65 .finish()
66 }
67}
68
69impl Endpoint {
70 pub fn new(
74 local_addr: SocketAddr,
75 transport_protocol: TransportProtocol,
76 endpoint_config: Arc<EndpointConfig>,
77 server_config: Option<Arc<ServerConfig>>,
78 ) -> Self {
79 Self {
80 local_addr,
81 transport_protocol,
82 association_ids_init: HashMap::default(),
83 association_ids: HashMap::default(),
84 associations: Slab::new(),
85 local_cid_generator: (endpoint_config.aid_generator_factory.as_ref())(),
86 reject_new_associations: false,
87 endpoint_config,
88 server_config,
89 }
90 }
91
92 pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
94 self.server_config = server_config;
95 }
96
97 pub fn handle_event(&mut self, ch: AssociationHandle, event: EndpointEvent) {
99 match event.0 {
100 EndpointEventInner::Drained => {
101 let conn = self.associations.remove(ch.0);
102 self.association_ids_init.remove(&conn.init_cid);
103 for cid in conn.loc_cids.values() {
104 self.association_ids.remove(cid);
105 }
106 }
107 }
108 }
109
110 pub fn handle(
112 &mut self,
113 now: Instant,
114 remote: SocketAddr,
115 ecn: Option<EcnCodepoint>,
116 data: Bytes,
117 ) -> Option<(AssociationHandle, DatagramEvent)> {
118 let partial_decode = match PartialDecode::unmarshal(&data) {
119 Ok(x) => x,
120 Err(err) => {
121 trace!("malformed header: {}", err);
122 return None;
123 }
124 };
125
126 let dst_cid = partial_decode.common_header.verification_tag;
130 let known_ch = if dst_cid > 0 {
131 self.association_ids.get(&dst_cid).cloned()
132 } else {
133 if partial_decode.first_chunk_type == CT_INIT {
135 if let Some(dst_cid) = partial_decode.initiate_tag {
136 self.association_ids.get(&dst_cid).cloned()
137 } else {
138 None
139 }
140 } else {
141 None
142 }
143 };
144
145 if let Some(ch) = known_ch {
146 return Some((
147 ch,
148 DatagramEvent::AssociationEvent(AssociationEvent(AssociationEventInner::Datagram(
149 TransportMessage {
150 now,
151 transport: TransportContext {
152 local_addr: self.local_addr,
153 peer_addr: remote,
154 ecn,
155 transport_protocol: self.transport_protocol,
156 },
157 message: Payload::PartialDecode(partial_decode),
158 },
159 ))),
160 ));
161 }
162
163 self.handle_first_packet(now, remote, ecn, partial_decode)
167 .map(|(ch, a)| (ch, DatagramEvent::NewAssociation(a)))
168 }
169
170 pub fn connect(
172 &mut self,
173 config: ClientConfig,
174 remote: SocketAddr,
175 ) -> Result<(AssociationHandle, Association), ConnectError> {
176 if self.is_full() {
177 return Err(ConnectError::TooManyAssociations);
178 }
179 if remote.port() == 0 {
180 return Err(ConnectError::InvalidRemoteAddress(remote));
181 }
182
183 let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
184 let local_aid = self.new_aid();
185
186 let (ch, conn) = self.add_association(
187 remote_aid,
188 local_aid,
189 remote,
190 Instant::now(),
191 None,
192 config.transport,
193 );
194 Ok((ch, conn))
195 }
196
197 fn new_aid(&mut self) -> AssociationId {
198 loop {
199 let aid = self.local_cid_generator.generate_aid();
200 if !self.association_ids.contains_key(&aid) {
201 break aid;
202 }
203 }
204 }
205
206 fn handle_first_packet(
207 &mut self,
208 now: Instant,
209 remote: SocketAddr,
210 ecn: Option<EcnCodepoint>,
211 partial_decode: PartialDecode,
212 ) -> Option<(AssociationHandle, Association)> {
213 if partial_decode.first_chunk_type != CT_INIT
214 || (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
215 {
216 debug!("refusing first packet with Non-INIT or empty initial_tag INIT");
217 return None;
218 }
219
220 let server_config = if let Some(server_config) = self.server_config.as_ref() {
221 server_config
222 } else {
223 warn!("refusing first packet due to empty server_config");
224 return None;
225 };
226
227 if self.associations.len() >= server_config.concurrent_associations as usize
228 || self.reject_new_associations
229 || self.is_full()
230 {
231 debug!("refusing association");
232 return None;
234 }
235
236 let server_config = server_config.clone();
237 let transport_config = server_config.transport.clone();
238
239 let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
240 let local_aid = self.new_aid();
241
242 let (ch, mut conn) = self.add_association(
243 remote_aid,
244 local_aid,
245 remote,
246 now,
247 Some(server_config),
248 transport_config,
249 );
250
251 conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
252 TransportMessage {
253 now,
254 transport: TransportContext {
255 local_addr: self.local_addr,
256 peer_addr: remote,
257 ecn,
258 transport_protocol: self.transport_protocol,
259 },
260 message: Payload::PartialDecode(partial_decode),
261 },
262 )));
263
264 Some((ch, conn))
265 }
266
267 #[allow(clippy::too_many_arguments)]
268 fn add_association(
269 &mut self,
270 remote_aid: AssociationId,
271 local_aid: AssociationId,
272 remote_addr: SocketAddr,
273 now: Instant,
274 server_config: Option<Arc<ServerConfig>>,
275 transport_config: Arc<TransportConfig>,
276 ) -> (AssociationHandle, Association) {
277 let conn = Association::new(
278 server_config,
279 transport_config,
280 self.endpoint_config.get_max_payload_size(),
281 local_aid,
282 remote_addr,
283 self.local_addr,
284 self.transport_protocol,
285 now,
286 );
287
288 let id = self.associations.insert(AssociationMeta {
289 init_cid: remote_aid,
290 cids_issued: 0,
291 loc_cids: iter::once((0, local_aid)).collect(),
292 initial_remote: remote_addr,
293 });
294
295 let ch = AssociationHandle(id);
296 self.association_ids.insert(local_aid, ch);
297
298 (ch, conn)
299 }
300
301 pub fn reject_new_associations(&mut self) {
303 self.reject_new_associations = true;
304 }
305
306 pub fn endpoint_config(&self) -> &EndpointConfig {
308 &self.endpoint_config
309 }
310
311 fn is_full(&self) -> bool {
313 (((u32::MAX >> 1) + (u32::MAX >> 2)) as usize) < self.association_ids.len()
314 }
315}
316
317#[derive(Debug)]
318pub(crate) struct AssociationMeta {
319 init_cid: AssociationId,
320 cids_issued: u64,
322 loc_cids: HashMap<u64, AssociationId>,
323 initial_remote: SocketAddr,
328}
329
/// Opaque index of an association inside an `Endpoint`.
#[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct AssociationHandle(pub usize);

impl From<AssociationHandle> for usize {
    /// Extracts the raw slab index wrapped by the handle.
    fn from(handle: AssociationHandle) -> usize {
        let AssociationHandle(index) = handle;
        index
    }
}
339
340impl Index<AssociationHandle> for Slab<AssociationMeta> {
341 type Output = AssociationMeta;
342 fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
343 &self[ch.0]
344 }
345}
346
347impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
348 fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
349 &mut self[ch.0]
350 }
351}
352
353#[allow(clippy::large_enum_variant)] pub enum DatagramEvent {
356 AssociationEvent(AssociationEvent),
358 NewAssociation(Association),
360}
361
362#[derive(Debug, Error, Clone, PartialEq, Eq)]
366pub enum ConnectError {
367 #[error("endpoint stopping")]
371 EndpointStopping,
372 #[error("too many associations")]
376 TooManyAssociations,
377 #[error("invalid DNS name: {0}")]
379 InvalidDnsName(String),
380 #[error("invalid remote address: {0}")]
384 InvalidRemoteAddress(SocketAddr),
385 #[error("no default client config")]
389 NoDefaultClientConfig,
390}