sctp_proto/endpoint/mod.rs

1#[cfg(test)]
2mod endpoint_test;
3
4use std::{
5    collections::{HashMap, VecDeque},
6    fmt, iter,
7    net::{IpAddr, SocketAddr},
8    ops::{Index, IndexMut},
9    sync::Arc,
10    time::Instant,
11};
12
13use crate::association::Association;
14use crate::chunk::chunk_type::CT_INIT;
15use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
16use crate::packet::PartialDecode;
17use crate::shared::{
18    AssociationEvent, AssociationEventInner, AssociationId, EndpointEvent, EndpointEventInner,
19};
20use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
21use crate::{EcnCodepoint, Payload, Transmit};
22
23use bytes::Bytes;
24use log::{debug, trace};
25use rand::{rngs::StdRng, SeedableRng};
26use rustc_hash::FxHashMap;
27use slab::Slab;
28use thiserror::Error;
29
30/// The main entry point to the library
31///
32/// This object performs no I/O whatsoever. Instead, it generates a stream of packets to send via
33/// `poll_transmit`, and consumes incoming packets and association-generated events via `handle` and
34/// `handle_event`.
35pub struct Endpoint {
36    rng: StdRng,
37    transmits: VecDeque<Transmit>,
38    /// Identifies associations based on the INIT Dst AID the peer utilized
39    ///
40    /// Uses a standard `HashMap` to protect against hash collision attacks.
41    association_ids_init: HashMap<AssociationId, AssociationHandle>,
42    /// Identifies associations based on locally created CIDs
43    ///
44    /// Uses a cheaper hash function since keys are locally created
45    association_ids: FxHashMap<AssociationId, AssociationHandle>,
46
47    associations: Slab<AssociationMeta>,
48    local_cid_generator: Box<dyn AssociationIdGenerator>,
49    config: Arc<EndpointConfig>,
50    server_config: Option<Arc<ServerConfig>>,
51    /// Whether incoming associations should be unconditionally rejected by a server
52    ///
53    /// Equivalent to a `ServerConfig.accept_buffer` of `0`, but can be changed after the endpoint is constructed.
54    reject_new_associations: bool,
55}
56
57impl fmt::Debug for Endpoint {
58    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
59        fmt.debug_struct("Endpoint<T>")
60            .field("rng", &self.rng)
61            .field("transmits", &self.transmits)
62            .field("association_ids_initial", &self.association_ids_init)
63            .field("association_ids", &self.association_ids)
64            .field("associations", &self.associations)
65            .field("config", &self.config)
66            .field("server_config", &self.server_config)
67            .field("reject_new_associations", &self.reject_new_associations)
68            .finish()
69    }
70}
71
72impl Endpoint {
73    /// Create a new endpoint
74    ///
75    /// Returns `Err` if the configuration is invalid.
76    pub fn new(config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>) -> Self {
77        let rng = {
78            let mut base = rand::rng();
79            StdRng::from_rng(&mut base)
80        };
81        Self {
82            rng,
83            transmits: VecDeque::new(),
84            association_ids_init: HashMap::default(),
85            association_ids: FxHashMap::default(),
86            associations: Slab::new(),
87            local_cid_generator: (config.aid_generator_factory.as_ref())(),
88            reject_new_associations: false,
89            config,
90            server_config,
91        }
92    }
93
94    /// Get the next packet to transmit
95    #[must_use]
96    pub fn poll_transmit(&mut self) -> Option<Transmit> {
97        self.transmits.pop_front()
98    }
99
100    /// Replace the server configuration, affecting new incoming associations only
101    pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
102        self.server_config = server_config;
103    }
104
105    /// Process `EndpointEvent`s emitted from related `Association`s
106    ///
107    /// In turn, processing this event may return a `AssociationEvent` for the same `Association`.
108    pub fn handle_event(
109        &mut self,
110        ch: AssociationHandle,
111        event: EndpointEvent,
112    ) -> Option<AssociationEvent> {
113        match event.0 {
114            EndpointEventInner::Drained => {
115                let conn = self.associations.remove(ch.0);
116                self.association_ids_init.remove(&conn.init_cid);
117                for cid in conn.loc_cids.values() {
118                    self.association_ids.remove(cid);
119                }
120            }
121        }
122        None
123    }
124
125    /// Process an incoming UDP datagram
126    pub fn handle(
127        &mut self,
128        now: Instant,
129        remote: SocketAddr,
130        local_ip: Option<IpAddr>,
131        ecn: Option<EcnCodepoint>,
132        data: Bytes,
133    ) -> Option<(AssociationHandle, DatagramEvent)> {
134        let partial_decode = match PartialDecode::unmarshal(&data) {
135            Ok(x) => x,
136            Err(err) => {
137                trace!("malformed header: {}", err);
138                return None;
139            }
140        };
141
142        //
143        // Handle packet on existing association, if any
144        //
145        let dst_cid = partial_decode.common_header.verification_tag;
146        let known_ch = if dst_cid > 0 {
147            self.association_ids.get(&dst_cid).cloned()
148        } else {
149            //TODO: improve INIT handling for DoS attack
150            if partial_decode.first_chunk_type == CT_INIT {
151                if let Some(dst_cid) = partial_decode.initiate_tag {
152                    self.association_ids.get(&dst_cid).cloned()
153                } else {
154                    None
155                }
156            } else {
157                None
158            }
159        };
160
161        if let Some(ch) = known_ch {
162            return Some((
163                ch,
164                DatagramEvent::AssociationEvent(AssociationEvent(AssociationEventInner::Datagram(
165                    Transmit {
166                        now,
167                        remote,
168                        ecn,
169                        payload: Payload::PartialDecode(partial_decode),
170                        local_ip,
171                    },
172                ))),
173            ));
174        }
175
176        //
177        // Potentially create a new association
178        //
179        self.handle_first_packet(now, remote, local_ip, ecn, partial_decode)
180            .map(|(ch, a)| (ch, DatagramEvent::NewAssociation(a)))
181    }
182
183    /// Initiate an Association
184    pub fn connect(
185        &mut self,
186        config: ClientConfig,
187        remote: SocketAddr,
188    ) -> Result<(AssociationHandle, Association), ConnectError> {
189        if self.is_full() {
190            return Err(ConnectError::TooManyAssociations);
191        }
192        if remote.port() == 0 {
193            return Err(ConnectError::InvalidRemoteAddress(remote));
194        }
195
196        let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
197        let local_aid = self.new_aid();
198
199        let (ch, conn) = self.add_association(
200            remote_aid,
201            local_aid,
202            remote,
203            None,
204            Instant::now(),
205            None,
206            config.transport,
207        );
208        Ok((ch, conn))
209    }
210
211    fn new_aid(&mut self) -> AssociationId {
212        loop {
213            let aid = self.local_cid_generator.generate_aid();
214            if !self.association_ids.contains_key(&aid) {
215                break aid;
216            }
217        }
218    }
219
220    fn handle_first_packet(
221        &mut self,
222        now: Instant,
223        remote: SocketAddr,
224        local_ip: Option<IpAddr>,
225        ecn: Option<EcnCodepoint>,
226        partial_decode: PartialDecode,
227    ) -> Option<(AssociationHandle, Association)> {
228        if partial_decode.first_chunk_type != CT_INIT
229            || (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
230        {
231            debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
232            return None;
233        }
234
235        let server_config = self.server_config.as_ref().unwrap();
236
237        if self.associations.len() >= server_config.concurrent_associations as usize
238            || self.reject_new_associations
239            || self.is_full()
240        {
241            debug!("refusing association");
242            //TODO: self.initial_close();
243            return None;
244        }
245
246        let server_config = server_config.clone();
247        let transport_config = server_config.transport.clone();
248
249        let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
250        let local_aid = self.new_aid();
251
252        let (ch, mut conn) = self.add_association(
253            remote_aid,
254            local_aid,
255            remote,
256            local_ip,
257            now,
258            Some(server_config),
259            transport_config,
260        );
261
262        conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
263            Transmit {
264                now,
265                remote,
266                ecn,
267                payload: Payload::PartialDecode(partial_decode),
268                local_ip,
269            },
270        )));
271
272        Some((ch, conn))
273    }
274
275    #[allow(clippy::too_many_arguments)]
276    fn add_association(
277        &mut self,
278        remote_aid: AssociationId,
279        local_aid: AssociationId,
280        remote_addr: SocketAddr,
281        local_ip: Option<IpAddr>,
282        now: Instant,
283        server_config: Option<Arc<ServerConfig>>,
284        transport_config: Arc<TransportConfig>,
285    ) -> (AssociationHandle, Association) {
286        let conn = Association::new(
287            server_config,
288            transport_config,
289            self.config.get_max_payload_size(),
290            local_aid,
291            remote_addr,
292            local_ip,
293            now,
294        );
295
296        let id = self.associations.insert(AssociationMeta {
297            init_cid: remote_aid,
298            cids_issued: 0,
299            loc_cids: iter::once((0, local_aid)).collect(),
300            initial_remote: remote_addr,
301        });
302
303        let ch = AssociationHandle(id);
304        self.association_ids.insert(local_aid, ch);
305
306        (ch, conn)
307    }
308
309    /// Unconditionally reject future incoming associations
310    pub fn reject_new_associations(&mut self) {
311        self.reject_new_associations = true;
312    }
313
314    /// Access the configuration used by this endpoint
315    pub fn config(&self) -> &EndpointConfig {
316        &self.config
317    }
318
319    /// Whether we've used up 3/4 of the available AID space
320    fn is_full(&self) -> bool {
321        (((u32::MAX >> 1) + (u32::MAX >> 2)) as usize) < self.association_ids.len()
322    }
323}
324
325#[derive(Debug)]
326pub(crate) struct AssociationMeta {
327    init_cid: AssociationId,
328    /// Number of local association IDs.
329    cids_issued: u64,
330    loc_cids: FxHashMap<u64, AssociationId>,
331    /// Remote address the association began with
332    ///
333    /// Only needed to support associations with zero-length AIDs, which cannot migrate, so we don't
334    /// bother keeping it up to date.
335    initial_remote: SocketAddr,
336}
337
/// Internal identifier for an `Association` currently associated with an endpoint
///
/// Wraps the index of the association's bookkeeping entry in the endpoint's slab.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct AssociationHandle(pub usize);

impl From<AssociationHandle> for usize {
    /// Unwraps the handle into its raw slab index.
    fn from(handle: AssociationHandle) -> usize {
        let AssociationHandle(index) = handle;
        index
    }
}
347
348impl Index<AssociationHandle> for Slab<AssociationMeta> {
349    type Output = AssociationMeta;
350    fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
351        &self[ch.0]
352    }
353}
354
355impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
356    fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
357        &mut self[ch.0]
358    }
359}
360
361/// Event resulting from processing a single datagram
362#[allow(clippy::large_enum_variant)] // Not passed around extensively
363pub enum DatagramEvent {
364    /// The datagram is redirected to its `Association`
365    AssociationEvent(AssociationEvent),
366    /// The datagram has resulted in starting a new `Association`
367    NewAssociation(Association),
368}
369
370/// Errors in the parameters being used to create a new association
371///
372/// These arise before any I/O has been performed.
373#[derive(Debug, Error, Clone, PartialEq, Eq)]
374pub enum ConnectError {
375    /// The endpoint can no longer create new associations
376    ///
377    /// Indicates that a necessary component of the endpoint has been dropped or otherwise disabled.
378    #[error("endpoint stopping")]
379    EndpointStopping,
380    /// The number of active associations on the local endpoint is at the limit
381    ///
382    /// Try using longer association IDs.
383    #[error("too many associations")]
384    TooManyAssociations,
385    /// The domain name supplied was malformed
386    #[error("invalid DNS name: {0}")]
387    InvalidDnsName(String),
388    /// The remote [`SocketAddr`] supplied was malformed
389    ///
390    /// Examples include attempting to connect to port 0, or using an inappropriate address family.
391    #[error("invalid remote address: {0}")]
392    InvalidRemoteAddress(SocketAddr),
393    /// No default client configuration was set up
394    ///
395    /// Use `Endpoint::connect_with` to specify a client configuration.
396    #[error("no default client config")]
397    NoDefaultClientConfig,
398}