#[cfg(test)]
mod endpoint_test;
use alloc::boxed::Box;
use alloc::collections::VecDeque;
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use core::fmt;
use core::iter;
use core::net::{IpAddr, SocketAddr};
use core::ops::{Index, IndexMut};
use std::collections::HashMap;
use std::time::Instant;
use crate::chunk::{chunk_init::ChunkInit, chunk_type::CT_INIT};
use crate::config::MAX_SNAP_INIT_BYTES;
use crate::config::{ClientConfig, EndpointConfig, ServerConfig, TransportConfig};
use crate::packet::PartialDecode;
use crate::shared::AssociationEvent;
use crate::shared::{AssociationEventInner, AssociationId};
use crate::shared::{EndpointEvent, EndpointEventInner};
use crate::util::{AssociationIdGenerator, RandomAssociationIdGenerator};
use crate::{EcnCodepoint, Payload, Transmit};
use crate::{association::Association, chunk::Chunk};
use bytes::Bytes;
use log::{debug, trace, warn};
use rand::{SeedableRng, rngs::StdRng};
use rustc_hash::FxHashMap;
use slab::Slab;
use thiserror::Error;
/// Tracks the associations served by one endpoint and routes incoming
/// datagrams to them.
///
/// Associations are stored as [`AssociationMeta`] entries in a slab and are
/// addressed by [`AssociationHandle`]. Two id maps translate the tags found in
/// incoming packets into handles.
pub struct Endpoint {
/// Random number generator seeded in [`Endpoint::new`].
rng: StdRng,
/// Queue of outgoing transmits, drained front-first by `poll_transmit`.
transmits: VecDeque<Transmit>,
/// Maps a peer's initiate tag to its association; consulted for INIT packets
/// whose verification tag is zero.
association_ids_init: HashMap<AssociationId, AssociationHandle>,
/// Maps a local association id to its association; consulted for packets
/// carrying a non-zero verification tag.
association_ids: FxHashMap<AssociationId, AssociationHandle>,
/// Per-association bookkeeping, keyed by the index inside `AssociationHandle`.
associations: Slab<AssociationMeta>,
/// Source of candidate local association ids (see `new_aid`).
local_cid_generator: Box<dyn AssociationIdGenerator>,
/// Endpoint-wide configuration.
config: Arc<EndpointConfig>,
/// Configuration applied to associations created from incoming first packets.
server_config: Option<Arc<ServerConfig>>,
/// Once set, incoming first packets are refused.
reject_new_associations: bool,
}
/// Manual `Debug` implementation that skips `local_cid_generator`.
///
/// The struct is reported under its real name (`Endpoint`, it has no type
/// parameter) and every label matches the field it prints. The trailing `..`
/// emitted by `finish_non_exhaustive` signals that a field was left out.
impl fmt::Debug for Endpoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Endpoint")
            .field("rng", &self.rng)
            .field("transmits", &self.transmits)
            .field("association_ids_init", &self.association_ids_init)
            .field("association_ids", &self.association_ids)
            .field("associations", &self.associations)
            .field("config", &self.config)
            .field("server_config", &self.server_config)
            .field("reject_new_associations", &self.reject_new_associations)
            .finish_non_exhaustive()
    }
}
impl Endpoint {
/// Create an endpoint with no associations.
///
/// * `config` - endpoint-wide settings; its `aid_generator_factory` is invoked
///   once here to build the local association id generator.
/// * `server_config` - optional configuration used for associations created
///   from incoming first packets.
pub fn new(config: Arc<EndpointConfig>, server_config: Option<Arc<ServerConfig>>) -> Self {
    // Build the generator before `config` is moved into the struct.
    let local_cid_generator = (config.aid_generator_factory.as_ref())();
    let rng = StdRng::from_rng(&mut rand::rng());

    Self {
        rng,
        transmits: VecDeque::new(),
        association_ids_init: HashMap::default(),
        association_ids: FxHashMap::default(),
        associations: Slab::new(),
        local_cid_generator,
        config,
        server_config,
        reject_new_associations: false,
    }
}
/// Pop the oldest queued outgoing transmit, or `None` when the queue is empty.
#[must_use]
pub fn poll_transmit(&mut self) -> Option<Transmit> {
self.transmits.pop_front()
}
/// Replace the server configuration stored on the endpoint.
///
/// The stored value is read when an incoming first packet creates a new
/// association.
pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
self.server_config = server_config;
}
/// Process an event reported by the association identified by `ch`.
///
/// `Drained` is the only event: the association's metadata is dropped and its
/// routing entries are removed from both id maps.
///
/// Two safeguards apply while cleaning up:
/// * a handle that no longer refers to a live association is ignored instead
///   of panicking;
/// * the `association_ids_init` entry is removed only when it actually points
///   at `ch`. Associations created by the plain `connect` path store a
///   placeholder `init_cid` that is never registered in that map, so an
///   unconditional removal could delete another association's entry.
///
/// Always returns `None`.
pub fn handle_event(
    &mut self,
    ch: AssociationHandle,
    event: EndpointEvent,
) -> Option<AssociationEvent> {
    match event.0 {
        EndpointEventInner::Drained => {
            // Stale or repeated handle: nothing left to clean up.
            let Some(meta) = self.associations.try_remove(ch.0) else {
                return None;
            };
            if self.association_ids_init.get(&meta.init_cid) == Some(&ch) {
                self.association_ids_init.remove(&meta.init_cid);
            }
            for aid in meta.loc_cids.values() {
                self.association_ids.remove(aid);
            }
        }
    }
    None
}
/// Route an incoming datagram.
///
/// The header is partially decoded first; undecodable input is traced and
/// dropped (`None`). The owning association is then looked up:
/// * non-zero verification tag: by that tag in `association_ids`;
/// * zero verification tag and a leading INIT carrying an initiate tag: by
///   that tag in `association_ids_init`.
///
/// When an association is found, its handle is returned together with a
/// `DatagramEvent::AssociationEvent` wrapping the packet. Otherwise the packet
/// is passed to `handle_first_packet`, which may yield
/// `DatagramEvent::NewAssociation`.
pub fn handle(
    &mut self,
    now: Instant,
    remote: SocketAddr,
    local_ip: Option<IpAddr>,
    ecn: Option<EcnCodepoint>,
    data: Bytes,
) -> Option<(AssociationHandle, DatagramEvent)> {
    let partial_decode = PartialDecode::unmarshal(&data)
        .inspect_err(|err| trace!("malformed header: {}", err))
        .ok()?;

    let dst_cid = partial_decode.common_header.verification_tag;
    let known_ch = if dst_cid > 0 {
        self.association_ids.get(&dst_cid).copied()
    } else if partial_decode.first_chunk_type == CT_INIT {
        partial_decode
            .initiate_tag
            .and_then(|init_tag| self.association_ids_init.get(&init_tag).copied())
    } else {
        None
    };

    let Some(ch) = known_ch else {
        // Unknown peer: possibly the first packet of a new association.
        return self
            .handle_first_packet(now, remote, local_ip, ecn, partial_decode)
            .map(|(ch, assoc)| (ch, DatagramEvent::NewAssociation(assoc)));
    };

    let datagram = Transmit {
        now,
        remote,
        ecn,
        payload: Payload::PartialDecode(partial_decode),
        local_ip,
    };
    let event = AssociationEvent(AssociationEventInner::Datagram(datagram));
    Some((ch, DatagramEvent::AssociationEvent(event)))
}
/// Start an outgoing association to `remote`.
///
/// When both `local_sctp_init` and `remote_sctp_init` are present in `config`,
/// the association is built from those tokens via `connect_with_snap`. When
/// only one is present a warning is logged and the normal handshake path is
/// used, the same as when neither is set.
///
/// # Errors
///
/// * `ConnectError::TooManyAssociations` when the endpoint is full.
/// * `ConnectError::InvalidRemoteAddress` when `remote` has port 0.
/// * Any error returned by `connect_with_snap` on the SNAP path.
pub fn connect(
    &mut self,
    config: ClientConfig,
    remote: SocketAddr,
) -> Result<(AssociationHandle, Association), ConnectError> {
    if self.is_full() {
        return Err(ConnectError::TooManyAssociations);
    }
    if remote.port() == 0 {
        return Err(ConnectError::InvalidRemoteAddress(remote));
    }

    match (config.local_sctp_init, config.remote_sctp_init) {
        (Some(local_init), Some(remote_init)) => {
            return self.connect_with_snap(config.transport, remote, local_init, remote_init);
        }
        (None, None) => {}
        _ => warn!(
            "partial SNAP config: both local_sctp_init and \
remote_sctp_init must be set; falling back to normal handshake"
        ),
    }

    // The remote id is a random placeholder here; the local id must be unique.
    let remote_aid = RandomAssociationIdGenerator::new().generate_aid();
    let local_aid = self.new_aid();
    let created = self.add_association(
        remote_aid,
        local_aid,
        remote,
        None,
        Instant::now(),
        None,
        config.transport,
    );
    Ok(created)
}
fn connect_with_snap(
&mut self,
transport: Arc<TransportConfig>,
remote: SocketAddr,
local_snap_bytes: Bytes,
remote_snap_bytes: Bytes,
) -> Result<(AssociationHandle, Association), ConnectError> {
if local_snap_bytes.len() > MAX_SNAP_INIT_BYTES {
return Err(ConnectError::Snap(SnapError::OversizedInit {
side: SnapSide::Local,
len: local_snap_bytes.len(),
}));
}
if remote_snap_bytes.len() > MAX_SNAP_INIT_BYTES {
return Err(ConnectError::Snap(SnapError::OversizedInit {
side: SnapSide::Remote,
len: remote_snap_bytes.len(),
}));
}
let local_init = ChunkInit::unmarshal(&local_snap_bytes).map_err(|err| {
ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Local,
reason: err.to_string(),
})
})?;
let remote_init = ChunkInit::unmarshal(&remote_snap_bytes).map_err(|err| {
ConnectError::Snap(SnapError::ParseFailed {
side: SnapSide::Remote,
reason: err.to_string(),
})
})?;
if local_init.is_ack {
return Err(ConnectError::Snap(SnapError::InvalidInitAck {
side: SnapSide::Local,
}));
}
if remote_init.is_ack {
return Err(ConnectError::Snap(SnapError::InvalidInitAck {
side: SnapSide::Remote,
}));
}
let local_aid = local_init.initiate_tag;
let remote_aid = remote_init.initiate_tag;
if local_aid == 0 {
return Err(ConnectError::Snap(SnapError::ZeroInitiateTag {
side: SnapSide::Local,
}));
}
if remote_aid == 0 {
return Err(ConnectError::Snap(SnapError::ZeroInitiateTag {
side: SnapSide::Remote,
}));
}
if local_init.initial_tsn == 0 {
return Err(ConnectError::Snap(SnapError::ZeroInitialTsn {
side: SnapSide::Local,
}));
}
if remote_init.initial_tsn == 0 {
return Err(ConnectError::Snap(SnapError::ZeroInitialTsn {
side: SnapSide::Remote,
}));
}
local_init.check().map_err(|err| {
ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Local,
reason: err.to_string(),
})
})?;
remote_init.check().map_err(|err| {
ConnectError::Snap(SnapError::InvalidInit {
side: SnapSide::Remote,
reason: err.to_string(),
})
})?;
if local_aid == remote_aid {
return Err(ConnectError::Snap(SnapError::AidCollision {
kind: AidCollisionKind::LocalEqualsRemote,
aid: local_aid,
}));
}
if self.association_ids.contains_key(&local_aid) {
return Err(ConnectError::Snap(SnapError::AidCollision {
kind: AidCollisionKind::LocalInAssociationIds,
aid: local_aid,
}));
}
if self.association_ids.contains_key(&remote_aid) {
return Err(ConnectError::Snap(SnapError::AidCollision {
kind: AidCollisionKind::RemoteInAssociationIds,
aid: remote_aid,
}));
}
if self.association_ids_init.contains_key(&remote_aid) {
return Err(ConnectError::Snap(SnapError::AidCollision {
kind: AidCollisionKind::RemoteInAssociationIdsInit,
aid: remote_aid,
}));
}
if self.association_ids_init.contains_key(&local_aid) {
return Err(ConnectError::Snap(SnapError::AidCollision {
kind: AidCollisionKind::LocalInAssociationIdsInit,
aid: local_aid,
}));
}
let conn = Association::new_with_out_of_band_init(
transport,
self.config.get_max_payload_size(),
remote,
None,
local_init,
remote_init,
)
.map_err(|err| ConnectError::Snap(SnapError::AssociationFailed(err.to_string())))?;
let id = self.associations.insert(AssociationMeta {
init_cid: remote_aid,
cids_issued: 1,
loc_cids: iter::once((0, local_aid)).collect(),
initial_remote: remote,
});
let ch = AssociationHandle(id);
self.association_ids.insert(local_aid, ch);
self.association_ids_init.insert(remote_aid, ch);
debug!(
"Created SNAP association: local_aid={:#x} remote_aid={:#x}",
local_aid, remote_aid
);
Ok((ch, conn))
}
/// Produce a local association id that is non-zero and not currently in use.
///
/// Candidates come from `local_cid_generator` and are retried until one
/// qualifies. Zero is rejected because `handle` only consults
/// `association_ids` for packets whose verification tag is greater than zero,
/// so an association registered under id 0 could never be reached.
fn new_aid(&mut self) -> AssociationId {
    loop {
        let aid = self.local_cid_generator.generate_aid();
        if aid != 0 && !self.association_ids.contains_key(&aid) {
            break aid;
        }
    }
}
fn handle_first_packet(
&mut self,
now: Instant,
remote: SocketAddr,
local_ip: Option<IpAddr>,
ecn: Option<EcnCodepoint>,
partial_decode: PartialDecode,
) -> Option<(AssociationHandle, Association)> {
if partial_decode.first_chunk_type != CT_INIT
|| (partial_decode.first_chunk_type == CT_INIT && partial_decode.initiate_tag.is_none())
{
debug!("refusing first packet with Non-INIT or emtpy initial_tag INIT");
return None;
}
let server_config = self.server_config.as_ref().unwrap();
if self.associations.len() >= server_config.concurrent_associations as usize
|| self.reject_new_associations
|| self.is_full()
{
debug!("refusing association");
return None;
}
let server_config = server_config.clone();
let transport_config = server_config.transport.clone();
let remote_aid = *partial_decode.initiate_tag.as_ref().unwrap();
let local_aid = self.new_aid();
let (ch, mut conn) = self.add_association(
remote_aid,
local_aid,
remote,
local_ip,
now,
Some(server_config),
transport_config,
);
self.association_ids_init.insert(remote_aid, ch);
conn.handle_event(AssociationEvent(AssociationEventInner::Datagram(
Transmit {
now,
remote,
ecn,
payload: Payload::PartialDecode(partial_decode),
local_ip,
},
)));
Some((ch, conn))
}
/// Build an `Association` and register it under `local_aid`.
///
/// A metadata entry is stored in the slab (remembering `remote_aid` as
/// `init_cid`) and `local_aid` is mapped to the resulting handle in
/// `association_ids`. `association_ids_init` is not touched here.
///
/// Returns the new handle together with the association.
#[allow(clippy::too_many_arguments)]
fn add_association(
    &mut self,
    remote_aid: AssociationId,
    local_aid: AssociationId,
    remote_addr: SocketAddr,
    local_ip: Option<IpAddr>,
    now: Instant,
    server_config: Option<Arc<ServerConfig>>,
    transport_config: Arc<TransportConfig>,
) -> (AssociationHandle, Association) {
    let max_payload_size = self.config.get_max_payload_size();
    let conn = Association::new(
        server_config,
        transport_config,
        max_payload_size,
        local_aid,
        remote_addr,
        local_ip,
        now,
    );

    let meta = AssociationMeta {
        init_cid: remote_aid,
        cids_issued: 1,
        loc_cids: iter::once((0, local_aid)).collect(),
        initial_remote: remote_addr,
    };
    let ch = AssociationHandle(self.associations.insert(meta));
    self.association_ids.insert(local_aid, ch);

    (ch, conn)
}
/// Stop accepting new incoming associations.
///
/// Sets a flag that makes the endpoint refuse incoming first packets. Nothing
/// in the code visible here clears the flag again.
pub fn reject_new_associations(&mut self) {
self.reject_new_associations = true;
}
/// Borrow the endpoint-wide configuration.
pub fn config(&self) -> &EndpointConfig {
&self.config
}
/// Report whether the number of registered local association ids exceeds
/// three quarters of the `u32` range.
fn is_full(&self) -> bool {
    const THRESHOLD: usize = ((u32::MAX >> 1) + (u32::MAX >> 2)) as usize;
    self.association_ids.len() > THRESHOLD
}
}
/// Bookkeeping the endpoint keeps for each association in its slab.
#[derive(Debug)]
pub(crate) struct AssociationMeta {
/// Remote association id recorded when the association was created; used as
/// the key to clear from `association_ids_init` when the association drains.
init_cid: AssociationId,
/// Initialised to 1 at creation; not updated by the code visible here.
cids_issued: u64,
/// Local association ids owned by this association, keyed by an index
/// (entry `0` is the id assigned at creation).
loc_cids: FxHashMap<u64, AssociationId>,
/// Remote address supplied when the association was created.
initial_remote: SocketAddr,
}
/// Opaque identifier for an association tracked by an endpoint.
///
/// Wraps the slab index under which the association's metadata is stored.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct AssociationHandle(pub usize);

/// Unwraps the handle into its raw index.
impl From<AssociationHandle> for usize {
    fn from(AssociationHandle(index): AssociationHandle) -> usize {
        index
    }
}
/// Allows `slab[handle]` lookups; panics like `slab[usize]` when the handle
/// does not refer to an occupied entry.
impl Index<AssociationHandle> for Slab<AssociationMeta> {
    type Output = AssociationMeta;

    fn index(&self, ch: AssociationHandle) -> &AssociationMeta {
        let key = usize::from(ch);
        &self[key]
    }
}
/// Mutable counterpart of the handle-based lookup; panics like
/// `slab[usize]` when the handle does not refer to an occupied entry.
impl IndexMut<AssociationHandle> for Slab<AssociationMeta> {
    fn index_mut(&mut self, ch: AssociationHandle) -> &mut AssociationMeta {
        let key = usize::from(ch);
        &mut self[key]
    }
}
/// Outcome of feeding an incoming datagram to `Endpoint::handle`.
// The size difference between the variants is accepted, hence the lint allowance.
#[allow(clippy::large_enum_variant)] pub enum DatagramEvent {
/// The datagram belongs to an existing association; carries the event to
/// deliver to it.
AssociationEvent(AssociationEvent),
/// The datagram caused a new association to be created.
NewAssociation(Association),
}
/// Errors reported when starting an outgoing association.
#[non_exhaustive]
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
/// The endpoint is stopping. Not produced by the code visible in this file.
#[error("endpoint stopping")]
EndpointStopping,
/// The endpoint already tracks too many associations (see `is_full`).
#[error("too many associations")]
TooManyAssociations,
/// The supplied DNS name is invalid. Not produced by the code visible in
/// this file.
#[error("invalid DNS name: {0}")]
InvalidDnsName(String),
/// The remote address cannot be connected to; `connect` returns this for
/// port 0.
#[error("invalid remote address: {0}")]
InvalidRemoteAddress(SocketAddr),
/// No default client configuration is available. Not produced by the code
/// visible in this file.
#[error("no default client config")]
NoDefaultClientConfig,
/// A SNAP-based connection attempt failed; wraps the detailed [`SnapError`]
/// (convertible via `From`).
#[error("SNAP error: {0}")]
Snap(#[from] SnapError),
}
/// Identifies which of the two SNAP INIT tokens an error refers to.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SnapSide {
    /// The token describing the local side.
    Local,
    /// The token describing the remote side.
    Remote,
}

/// Renders the side as the lowercase word used in error messages.
impl fmt::Display for SnapSide {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            SnapSide::Local => "local",
            SnapSide::Remote => "remote",
        };
        f.write_str(label)
    }
}
/// Describes which association id collided, and with what, during SNAP setup.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AidCollisionKind {
    /// The local and remote ids are identical.
    LocalEqualsRemote,
    /// The local id is already a key in `association_ids`.
    LocalInAssociationIds,
    /// The remote id is already a key in `association_ids`.
    RemoteInAssociationIds,
    /// The remote id is already a key in `association_ids_init`.
    RemoteInAssociationIdsInit,
    /// The local id is already a key in `association_ids_init`.
    LocalInAssociationIdsInit,
}

/// Renders a short description naming the id and the map it collided in.
impl fmt::Display for AidCollisionKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match *self {
            AidCollisionKind::LocalEqualsRemote => "local_aid equals remote_aid",
            AidCollisionKind::LocalInAssociationIds => "local_aid in association_ids",
            AidCollisionKind::RemoteInAssociationIds => "remote_aid in association_ids",
            AidCollisionKind::RemoteInAssociationIdsInit => "remote_aid in association_ids_init",
            AidCollisionKind::LocalInAssociationIdsInit => "local_aid in association_ids_init",
        };
        f.write_str(description)
    }
}
/// Detailed reasons a SNAP-based connection attempt can fail.
///
/// Variants carrying a `side` identify whether the local or the remote INIT
/// token triggered the failure.
#[non_exhaustive]
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum SnapError {
/// The token bytes could not be parsed as an INIT chunk; `reason` holds the
/// parser's error text.
#[error("failed to parse {side} SNAP token: {reason}")]
ParseFailed {
side: SnapSide,
reason: String,
},
/// The token parsed as an INIT-ACK where an INIT was required.
#[error("invalid {side} SNAP token: expected INIT, got INIT-ACK")]
InvalidInitAck {
side: SnapSide,
},
/// The token's initiate tag is zero.
#[error("{side} SNAP token has zero initiate_tag")]
ZeroInitiateTag {
side: SnapSide,
},
/// The token's initial TSN is zero.
#[error("{side} SNAP token has zero initial_tsn")]
ZeroInitialTsn {
side: SnapSide,
},
/// An association id from the tokens clashes with the other token or with
/// an id the endpoint already tracks; `kind` says which, `aid` is the
/// offending id.
#[error("SNAP collision: {kind} {aid:#x} already in use")]
AidCollision {
kind: AidCollisionKind,
aid: u32,
},
/// The token is longer than `MAX_SNAP_INIT_BYTES`; `len` is its actual
/// length in bytes.
#[error("{side} SNAP token too large: {len} bytes (max {max})", max = MAX_SNAP_INIT_BYTES)]
OversizedInit {
side: SnapSide,
len: usize,
},
/// The parsed INIT failed its own validity check; `reason` holds the
/// check's error text.
#[error("invalid {side} SNAP token: {reason}")]
InvalidInit {
side: SnapSide,
reason: String,
},
/// Constructing the association from the validated tokens failed; carries
/// the underlying error text.
#[error("failed to create SNAP association: {0}")]
AssociationFailed(String),
}