use std::{
collections::{HashMap, hash_map},
convert::TryFrom,
fmt, mem,
net::{IpAddr, SocketAddr},
ops::{Index, IndexMut},
sync::Arc,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use rand::{Rng, RngExt, SeedableRng, rngs::StdRng};
use rustc_hash::FxHashMap;
use slab::Slab;
use thiserror::Error;
use tracing::{debug, error, trace, warn};
use crate::{
Duration, FourTuple, INITIAL_MTU, Instant, MAX_CID_SIZE, MIN_INITIAL_SIZE, PathId,
RESET_TOKEN_SIZE, ResetToken, Side, Transmit, TransportConfig, TransportError,
cid_generator::ConnectionIdGenerator,
coding::{BufMutExt, Decodable, Encodable, UnexpectedEnd},
config::{ClientConfig, EndpointConfig, ServerConfig},
connection::{Connection, ConnectionError, SideArgs},
crypto::{self, Keys, UnsupportedVersion},
frame,
packet::{
FixedLengthConnectionIdParser, Header, InitialHeader, InitialPacket, PacketDecodeError,
PacketNumber, PartialDecode, ProtectedInitialHeader,
},
shared::{
ConnectionEvent, ConnectionEventInner, ConnectionId, DatagramConnectionEvent, EcnCodepoint,
EndpointEvent, EndpointEventInner, IssuedCid,
},
token::{IncomingToken, InvalidRetryTokenError, Token, TokenPayload},
transport_parameters::{PreferredAddress, TransportParameters},
};
/// Shared state for routing datagrams to connections and for answering
/// datagrams that do not belong to any known connection.
pub struct Endpoint {
/// Random source; seeded from `EndpointConfig::rng_seed` when one is configured.
rng: StdRng,
/// Lookup tables mapping CIDs, addresses and reset tokens to routing targets.
index: ConnectionIndex,
/// Per-connection bookkeeping, keyed by `ConnectionHandle.0`.
connections: Slab<ConnectionMeta>,
/// Generates and validates the connection IDs this endpoint issues.
local_cid_generator: Box<dyn ConnectionIdGenerator>,
/// Endpoint-wide configuration.
config: Arc<EndpointConfig>,
/// Server-side configuration; while `None`, no new incoming connections are created.
server_config: Option<Arc<ServerConfig>>,
/// Passed unchanged to every `Connection::new` call.
allow_mtud: bool,
/// When the most recent stateless reset was sent; used to rate-limit resets.
last_stateless_reset: Option<Instant>,
/// Datagrams buffered for `Incoming` attempts that have not been resolved yet.
incoming_buffers: Slab<IncomingBuffer>,
/// Sum of `total_bytes` over all entries of `incoming_buffers`.
all_incoming_buffers_total_bytes: u64,
}
impl Endpoint {
/// Create an endpoint that tracks no connections yet.
///
/// `server_config` may be `None`, in which case no incoming connections are
/// created until one is installed with `set_server_config`. `allow_mtud` is
/// stored and handed to every connection this endpoint creates.
pub fn new(
    config: Arc<EndpointConfig>,
    server_config: Option<Arc<ServerConfig>>,
    allow_mtud: bool,
) -> Self {
    // A configured seed is used verbatim; otherwise seed from `rand::rng()`.
    let rng = match config.rng_seed {
        Some(seed) => StdRng::from_seed(seed),
        None => StdRng::from_rng(&mut rand::rng()),
    };
    let local_cid_generator = (config.connection_id_generator_factory.as_ref())();
    Self {
        rng,
        index: ConnectionIndex::default(),
        connections: Slab::new(),
        local_cid_generator,
        config,
        server_config,
        allow_mtud,
        last_stateless_reset: None,
        incoming_buffers: Slab::new(),
        all_incoming_buffers_total_bytes: 0,
    }
}
/// Replace the stored server configuration.
///
/// Passing `None` clears it; later first packets for unknown connections are
/// then handled as if the endpoint were not a server.
pub fn set_server_config(&mut self, server_config: Option<Arc<ServerConfig>>) {
self.server_config = server_config;
}
pub fn handle_event(
&mut self,
ch: ConnectionHandle,
event: EndpointEvent,
) -> Option<ConnectionEvent> {
use EndpointEventInner::*;
match event.0 {
NeedIdentifiers(path_id, now, n) => {
return Some(self.send_new_identifiers(path_id, now, ch, n));
}
ResetToken(path_id, remote, token) => {
if let Some(old) = self.connections[ch]
.reset_token
.insert(path_id, (remote, token))
{
self.index.connection_reset_tokens.remove(old.0, old.1);
}
if self.index.connection_reset_tokens.insert(remote, token, ch) {
warn!("duplicate reset token");
}
}
RetireResetToken(path_id) => {
if let Some(old) = self.connections[ch].reset_token.remove(&path_id) {
self.index.connection_reset_tokens.remove(old.0, old.1);
}
}
RetireConnectionId(now, path_id, seq, allow_more_cids) => {
if let Some(cid) = self.connections[ch]
.local_cids
.get_mut(&path_id)
.and_then(|pcid| pcid.cids.remove(&seq))
{
trace!(%path_id, "local CID retired {}: {}", seq, cid);
self.index.retire(cid);
if allow_more_cids {
return Some(self.send_new_identifiers(path_id, now, ch, 1));
}
}
}
Drained => {
if let Some(conn) = self.connections.try_remove(ch.0) {
self.index.remove(&conn);
} else {
error!(id = ch.0, "unknown connection drained");
}
}
}
None
}
/// Process one received datagram `data` that arrived over `network_path`.
///
/// Returns:
/// - `Some(DatagramEvent::ConnectionEvent)` when the datagram belongs to a known connection,
/// - `Some(DatagramEvent::NewConnection)` when it starts a new incoming connection attempt,
/// - `Some(DatagramEvent::Response)` when a reply was written into `buf`
///   (version negotiation, stateless reset, or a close),
/// - `None` when the datagram was dropped or buffered for a pending `Incoming`.
pub fn handle(
&mut self,
now: Instant,
network_path: FourTuple,
ecn: Option<EcnCodepoint>,
data: BytesMut,
buf: &mut Vec<u8>,
) -> Option<DatagramEvent> {
// Remember the full datagram size before `data` is consumed by the decoder.
let datagram_len = data.len();
let mut event = match PartialDecode::new(
data,
&FixedLengthConnectionIdParser::new(self.local_cid_generator.cid_len()),
&self.config.supported_versions,
self.config.grease_quic_bit,
) {
Ok((first_decode, remaining)) => DatagramConnectionEvent {
now,
network_path,
// Placeholder; overwritten below once the route is known.
path_id: PathId::ZERO, ecn,
first_decode,
remaining,
},
Err(PacketDecodeError::UnsupportedVersion {
src_cid,
dst_cid,
version,
}) => {
// Only servers answer unsupported versions.
if self.server_config.is_none() {
debug!("dropping packet with unsupported version");
return None;
}
trace!("sending version negotiation");
// The reply swaps the peer's source and destination CIDs.
Header::VersionNegotiate {
random: self.rng.random::<u8>() | 0x40,
src_cid: dst_cid,
dst_cid: src_cid,
}
.encode(buf);
// First list an extra version value chosen to differ from the one the peer sent
// (presumably a reserved "grease" version; see RFC 9000 section 15 — TODO confirm).
buf.write::<u32>(match version {
0x0a1a_2a3a => 0x0a1a_2a4a,
_ => 0x0a1a_2a3a,
});
// Then list every version this endpoint supports.
for &version in &self.config.supported_versions {
buf.write(version);
}
return Some(DatagramEvent::Response(Transmit {
destination: network_path.remote,
ecn: None,
size: buf.len(),
segment_size: None,
src_ip: network_path.local_ip,
}));
}
Err(e) => {
trace!("malformed header: {}", e);
return None;
}
};
let dst_cid = event.first_decode.dst_cid();
if let Some(route_to) = self.index.get(&network_path, &event.first_decode) {
event.path_id = match route_to {
RouteDatagramTo::Incoming(_) => PathId::ZERO,
RouteDatagramTo::Connection(_, path_id) => path_id,
};
match route_to {
RouteDatagramTo::Incoming(incoming_idx) => {
let incoming_buffer = &mut self.incoming_buffers[incoming_idx];
// A route to an `Incoming` is only ever inserted while a server config exists.
let config = &self.server_config.as_ref().unwrap();
// Buffer only while both the per-attempt and the endpoint-wide byte limits hold;
// otherwise the datagram is silently dropped.
if incoming_buffer
.total_bytes
.checked_add(datagram_len as u64)
.is_some_and(|n| n <= config.incoming_buffer_size)
&& self
.all_incoming_buffers_total_bytes
.checked_add(datagram_len as u64)
.is_some_and(|n| n <= config.incoming_buffer_size_total)
{
incoming_buffer.datagrams.push(event);
incoming_buffer.total_bytes += datagram_len as u64;
self.all_incoming_buffers_total_bytes += datagram_len as u64;
}
None
}
RouteDatagramTo::Connection(ch, _path_id) => Some(DatagramEvent::ConnectionEvent(
ch,
ConnectionEvent(ConnectionEventInner::Datagram(event)),
)),
}
} else if event.first_decode.initial_header().is_some() {
// Unknown connection, Initial packet: possibly a new connection attempt.
self.handle_first_packet(datagram_len, event, network_path, buf)
} else if event.first_decode.has_long_header() {
debug!(
"ignoring non-initial packet for unknown connection {}",
dst_cid
);
None
} else if !event.first_decode.is_initial()
&& self.local_cid_generator.validate(dst_cid).is_err()
{
// The CID was not produced by our generator, so no reset is sent.
debug!("dropping packet with invalid CID");
None
} else if dst_cid.is_empty() {
trace!("dropping unrecognized short packet without ID");
None
} else {
// Short-header packet for a connection we do not know: try a stateless reset.
self.stateless_reset(now, datagram_len, network_path, dst_cid, buf)
.map(DatagramEvent::Response)
}
}
/// Write a stateless reset for `dst_cid` into `buf`, if one may be sent.
///
/// Returns `None` without touching `buf` when the last reset was sent less
/// than `min_reset_interval` ago, or when the inciting datagram is too small
/// for a reset that is strictly shorter than it.
///
/// On success `buf` holds `padding_len` random bytes (the first byte forced to
/// start with bits `01`) followed by the reset token derived from `dst_cid`.
fn stateless_reset(
    &mut self,
    now: Instant,
    inciting_dgram_len: usize,
    network_path: FourTuple,
    dst_cid: ConnectionId,
    buf: &mut Vec<u8>,
) -> Option<Transmit> {
    const MIN_PADDING_LEN: usize = 5;
    const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;

    // Rate limit: at most one reset per `min_reset_interval`.
    match self.last_stateless_reset {
        Some(last) if last + self.config.min_reset_interval > now => {
            debug!("ignoring unexpected packet within minimum stateless reset interval");
            return None;
        }
        _ => {}
    }

    // Room left for padding once the token is accounted for.
    let headroom = inciting_dgram_len
        .checked_sub(RESET_TOKEN_SIZE)
        .filter(|&room| room > MIN_PADDING_LEN);
    let Some(headroom) = headroom else {
        debug!(
            "ignoring unexpected {} byte packet: not larger than minimum stateless reset size",
            inciting_dgram_len
        );
        return None;
    };
    // One byte less than the headroom keeps the reset shorter than the inciting datagram.
    let max_padding_len = headroom - 1;

    debug!(%dst_cid, %network_path.remote, "sending stateless reset");
    self.last_stateless_reset = Some(now);

    // Randomize the length when there is room to do so above the ideal minimum.
    let padding_len = if max_padding_len > IDEAL_MIN_PADDING_LEN {
        self.rng
            .random_range(IDEAL_MIN_PADDING_LEN..max_padding_len)
    } else {
        max_padding_len
    };

    buf.reserve(padding_len + RESET_TOKEN_SIZE);
    buf.resize(padding_len, 0);
    self.rng.fill_bytes(&mut buf[..padding_len]);
    // Clear the top bit and set the next one, keeping six random low bits.
    buf[0] = 0b0100_0000 | (buf[0] >> 2);
    let token = ResetToken::new(&*self.config.reset_key, dst_cid);
    buf.extend_from_slice(&token);
    debug_assert!(buf.len() < inciting_dgram_len);

    Some(Transmit {
        destination: network_path.remote,
        ecn: None,
        size: buf.len(),
        segment_size: None,
        src_ip: network_path.local_ip,
    })
}
/// Start an outgoing connection to `remote`.
///
/// Allocates a connection handle and an initial local CID, builds the
/// transport parameters, starts the TLS session for `server_name`, and
/// registers the new connection with this endpoint.
///
/// # Errors
///
/// - `ConnectError::CidsExhausted` if the local CID space is nearly used up.
/// - `ConnectError::InvalidRemoteAddress` if `remote` has port 0 or an
///   unspecified IP address.
/// - `ConnectError::UnsupportedVersion` if `config.version` is not one of the
///   endpoint's supported versions.
/// - Whatever error `start_session` reports.
pub fn connect(
    &mut self,
    now: Instant,
    config: ClientConfig,
    remote: SocketAddr,
    server_name: &str,
) -> Result<(ConnectionHandle, Connection), ConnectError> {
    if self.cids_exhausted() {
        return Err(ConnectError::CidsExhausted);
    }
    if remote.port() == 0 || remote.ip().is_unspecified() {
        return Err(ConnectError::InvalidRemoteAddress(remote));
    }
    if !self.config.supported_versions.contains(&config.version) {
        return Err(ConnectError::UnsupportedVersion);
    }

    let remote_id = (config.initial_dst_cid_provider)();
    trace!(initial_dcid = %remote_id);

    // Reserve the slab key now; `add_connection` asserts that it is the one actually used.
    let ch = ConnectionHandle(self.connections.vacant_key());
    let local_cid = self.new_cid(ch, PathId::ZERO);
    let params = TransportParameters::new(
        &config.transport,
        &self.config,
        self.local_cid_generator.as_ref(),
        local_cid,
        None,
        &mut self.rng,
    );
    let tls = match config
        .crypto
        .start_session(config.version, server_name, &params)
    {
        Ok(tls) => tls,
        Err(e) => {
            // No connection will own the CID reserved above, so release it; otherwise it
            // would keep routing to whichever connection later reuses this handle.
            self.index.retire(local_cid);
            return Err(e.into());
        }
    };

    let conn = self.add_connection(
        ch,
        config.version,
        remote_id,
        local_cid,
        remote_id,
        FourTuple {
            remote,
            local_ip: None,
        },
        now,
        tls,
        config.transport,
        SideArgs::Client {
            token_store: config.token_store,
            server_name: server_name.into(),
        },
        &params,
    );
    Ok((ch, conn))
}
/// Issue `num` fresh local CIDs for `path_id` on connection `ch`.
///
/// Each CID is registered in the routing index, recorded under the next
/// sequence number for that path, and paired with its reset token. The result
/// is the `NewIdentifiers` event to deliver to the connection.
fn send_new_identifiers(
    &mut self,
    path_id: PathId,
    now: Instant,
    ch: ConnectionHandle,
    num: u64,
) -> ConnectionEvent {
    let ids = (0..num)
        .map(|_| {
            let id = self.new_cid(ch, path_id);
            let per_path = self.connections[ch].local_cids.entry(path_id).or_default();
            // Sequence numbers are handed out consecutively per path.
            let sequence = per_path.issued;
            per_path.issued = sequence + 1;
            per_path.cids.insert(sequence, id);
            IssuedCid {
                path_id,
                sequence,
                id,
                reset_token: ResetToken::new(&*self.config.reset_key, id),
            }
        })
        .collect::<Vec<_>>();

    ConnectionEvent(ConnectionEventInner::NewIdentifiers(
        ids,
        now,
        self.local_cid_generator.cid_len(),
        self.local_cid_generator.cid_lifetime(),
    ))
}
/// Generate a local CID that is not in use yet and route it to `(ch, path_id)`.
///
/// An empty CID (zero-length generator) is returned as-is and never indexed.
/// Otherwise generation is retried until an unused CID comes up.
fn new_cid(&mut self, ch: ConnectionHandle, path_id: PathId) -> ConnectionId {
    loop {
        let candidate = self.local_cid_generator.generate_cid();
        if candidate.is_empty() {
            debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
            return candidate;
        }
        match self.index.connection_ids.entry(candidate) {
            hash_map::Entry::Vacant(slot) => {
                slot.insert((ch, path_id));
                return candidate;
            }
            // Collision with an existing CID: draw another one.
            hash_map::Entry::Occupied(_) => continue,
        }
    }
}
/// Handle an Initial packet that does not match any known connection.
///
/// Returns `NewConnection` with an `Incoming` for the application to resolve,
/// a `Response` (stateless reset or close) written into `buf`, or `None` when
/// the packet is ignored.
///
/// # Panics
///
/// Panics if `event.first_decode` has no initial header; the caller checks this.
fn handle_first_packet(
&mut self,
datagram_len: usize,
event: DatagramConnectionEvent,
network_path: FourTuple,
buf: &mut Vec<u8>,
) -> Option<DatagramEvent> {
let dst_cid = event.first_decode.dst_cid();
let header = event.first_decode.initial_header().unwrap();
// Without a server config no connection can be created; try a stateless reset instead.
let Some(server_config) = &self.server_config else {
debug!("packet for unrecognized connection {}", dst_cid);
return self
.stateless_reset(event.now, datagram_len, network_path, dst_cid, buf)
.map(DatagramEvent::Response);
};
// Datagrams carrying a first Initial must reach the minimum size.
if datagram_len < MIN_INITIAL_SIZE as usize {
debug!("ignoring short initial for connection {}", dst_cid);
return None;
}
// Initial keys are derived from the packet's version and destination CID.
let crypto = match server_config.crypto.initial_keys(header.version, dst_cid) {
Ok(keys) => keys,
Err(UnsupportedVersion) => {
debug!(
"ignoring initial packet version {:#x} unsupported by cryptographic layer",
header.version
);
return None;
}
};
// Cheap checks first; a failure is answered with a close built from the initial keys.
if let Err(reason) = self.early_validate_first_packet(header) {
return Some(DatagramEvent::Response(self.initial_close(
header.version,
network_path,
&crypto,
header.src_cid,
reason,
buf,
)));
}
// Complete header decoding with the remote header key.
let packet = match event.first_decode.finish(Some(&*crypto.header.remote)) {
Ok(packet) => packet,
Err(e) => {
trace!("unable to decode initial packet: {}", e);
return None;
}
};
if !packet.reserved_bits_valid() {
debug!("dropping connection attempt with invalid reserved bits");
return None;
}
let Header::Initial(header) = packet.header else {
panic!("non-initial packet in handle_first_packet()");
};
// Cloned so the config is no longer borrowed from `self` below.
let server_config = self.server_config.as_ref().unwrap().clone();
let token = match IncomingToken::from_header(&header, &server_config, network_path.remote) {
Ok(token) => token,
Err(InvalidRetryTokenError) => {
debug!("rejecting invalid retry token");
return Some(DatagramEvent::Response(self.initial_close(
header.version,
network_path,
&crypto,
header.src_cid,
TransportError::INVALID_TOKEN(""),
buf,
)));
}
};
// From now on, further Initial/0-RTT datagrams for this CID are buffered
// until the application resolves the `Incoming`.
let incoming_idx = self.incoming_buffers.insert(IncomingBuffer::default());
self.index
.insert_initial_incoming(header.dst_cid, incoming_idx);
Some(DatagramEvent::NewConnection(Incoming {
received_at: event.now,
network_path,
ecn: event.ecn,
packet: InitialPacket {
header,
header_data: packet.header_data,
payload: packet.payload,
},
rest: event.remaining,
crypto,
token,
incoming_idx,
improper_drop_warner: IncomingImproperDropWarner,
}))
}
/// Try to turn `incoming` into an established server-side connection.
///
/// `server_config` overrides the endpoint's own server configuration for this
/// connection when given. `buf` receives the bytes of any close packet that
/// the returned `AcceptError::response` refers to.
///
/// # Errors
///
/// - `ConnectionError::TimedOut` (no response) when the Initial is older than
///   the configured `max_idle_timeout`.
/// - `ConnectionError::CidsExhausted`, with a `CONNECTION_REFUSED` close as
///   response, when the local CID space is nearly used up.
/// - A `PROTOCOL_VIOLATION` transport error (no response) when the Initial
///   packet fails to decrypt.
/// - The error from processing the first packet; a close response is attached
///   only when that error is a `ConnectionError::TransportError`.
///
/// # Panics
///
/// Panics if `server_config` is `None` and the endpoint has no server
/// configuration either.
pub fn accept(
    &mut self,
    mut incoming: Incoming,
    now: Instant,
    buf: &mut Vec<u8>,
    server_config: Option<Arc<ServerConfig>>,
) -> Result<(ConnectionHandle, Connection), Box<AcceptError>> {
    let remote_address_validated = incoming.remote_address_validated();
    // From here on this function owns all cleanup for the `Incoming`.
    incoming.improper_drop_warner.dismiss();
    let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
    self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;

    let packet_number = incoming.packet.header.number.expand(0);
    let InitialHeader {
        src_cid,
        dst_cid,
        version,
        ..
    } = incoming.packet.header;
    let server_config =
        server_config.unwrap_or_else(|| self.server_config.as_ref().unwrap().clone());

    // The Initial has been waiting for longer than an idle connection would be kept.
    if server_config
        .transport
        .max_idle_timeout
        .is_some_and(|timeout| {
            incoming.received_at + Duration::from_millis(timeout.into()) <= now
        })
    {
        debug!("abandoning accept of stale initial");
        self.index.remove_initial(dst_cid);
        return Err(Box::new(AcceptError {
            cause: ConnectionError::TimedOut,
            response: None,
        }));
    }

    if self.cids_exhausted() {
        debug!("refusing connection");
        self.index.remove_initial(dst_cid);
        return Err(Box::new(AcceptError {
            cause: ConnectionError::CidsExhausted,
            response: Some(self.initial_close(
                version,
                incoming.network_path,
                &incoming.crypto,
                src_cid,
                TransportError::CONNECTION_REFUSED(""),
                buf,
            )),
        }));
    }

    // Decrypt the payload in place; on failure nothing is sent back.
    if incoming
        .crypto
        .packet
        .remote
        .decrypt(
            PathId::ZERO,
            packet_number,
            &incoming.packet.header_data,
            &mut incoming.packet.payload,
        )
        .is_err()
    {
        debug!(packet_number, "failed to authenticate initial packet");
        self.index.remove_initial(dst_cid);
        return Err(Box::new(AcceptError {
            cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
            response: None,
        }));
    };

    // Reserve the slab key now; `add_connection` asserts that it is the one actually used.
    let ch = ConnectionHandle(self.connections.vacant_key());
    let local_cid = self.new_cid(ch, PathId::ZERO);
    let mut params = TransportParameters::new(
        &server_config.transport,
        &self.config,
        self.local_cid_generator.as_ref(),
        local_cid,
        Some(&server_config),
        &mut self.rng,
    );
    params.stateless_reset_token = Some(ResetToken::new(&*self.config.reset_key, local_cid));
    params.original_dst_cid = Some(incoming.token.orig_dst_cid);
    params.retry_src_cid = incoming.token.retry_src_cid;

    // A preferred address needs its own CID and reset token.
    let mut pref_addr_cid = None;
    if server_config.has_preferred_address() {
        let cid = self.new_cid(ch, PathId::ZERO);
        pref_addr_cid = Some(cid);
        params.preferred_address = Some(PreferredAddress {
            address_v4: server_config.preferred_address_v4,
            address_v6: server_config.preferred_address_v6,
            connection_id: cid,
            stateless_reset_token: ResetToken::new(&*self.config.reset_key, cid),
        });
    }

    let tls = server_config.crypto.start_session(version, &params);
    let transport_config = server_config.transport.clone();
    let mut conn = self.add_connection(
        ch,
        version,
        dst_cid,
        local_cid,
        src_cid,
        incoming.network_path,
        incoming.received_at,
        tls,
        transport_config,
        SideArgs::Server {
            server_config,
            pref_addr_cid,
            path_validated: remote_address_validated,
        },
        &params,
    );
    // Re-point the initial DCID route from the pending buffer to the connection.
    self.index.insert_initial(dst_cid, ch);

    match conn.handle_first_packet(
        incoming.received_at,
        incoming.network_path,
        incoming.ecn,
        packet_number,
        incoming.packet,
        incoming.rest,
    ) {
        Ok(()) => {
            trace!(id = ch.0, icid = %dst_cid, "new connection");
            // Replay the datagrams buffered while the application was deciding.
            for event in incoming_buffer.datagrams {
                conn.handle_event(ConnectionEvent(ConnectionEventInner::Datagram(event)))
            }
            Ok((ch, conn))
        }
        Err(e) => {
            debug!("handshake failed: {}", e);
            // Undo the registration done by `add_connection`.
            self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
            let response = match e {
                ConnectionError::TransportError(ref e) => Some(self.initial_close(
                    version,
                    incoming.network_path,
                    &incoming.crypto,
                    src_cid,
                    e.clone(),
                    buf,
                )),
                _ => None,
            };
            Err(Box::new(AcceptError { cause: e, response }))
        }
    }
}
/// Checks made on a first Initial before any decryption work.
///
/// # Errors
///
/// - `CONNECTION_REFUSED` when the CID space is nearly exhausted or
///   `max_incoming` pending attempts already exist.
/// - `PROTOCOL_VIOLATION` when the destination CID is shorter than 8 bytes and
///   either no token is present or its length differs from our own CID length.
///
/// # Panics
///
/// Panics if the endpoint has no server configuration.
fn early_validate_first_packet(
    &mut self,
    header: &ProtectedInitialHeader,
) -> Result<(), TransportError> {
    let server_config = self.server_config.as_ref().unwrap();
    let at_capacity =
        self.cids_exhausted() || self.incoming_buffers.len() >= server_config.max_incoming;
    if at_capacity {
        return Err(TransportError::CONNECTION_REFUSED(""));
    }

    let dcid_len = header.dst_cid.len();
    if dcid_len >= 8 {
        return Ok(());
    }
    // A short DCID is tolerated only alongside a token, and only at our own CID length.
    if header.token_pos.is_empty() || dcid_len != self.local_cid_generator.cid_len() {
        debug!(
            "rejecting connection due to invalid DCID length {}",
            dcid_len
        );
        return Err(TransportError::PROTOCOL_VIOLATION(
            "invalid destination CID length",
        ));
    }
    Ok(())
}
pub fn refuse(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Transmit {
self.clean_up_incoming(&incoming);
incoming.improper_drop_warner.dismiss();
self.initial_close(
incoming.packet.header.version,
incoming.network_path,
&incoming.crypto,
incoming.packet.header.src_cid,
TransportError::CONNECTION_REFUSED(""),
buf,
)
}
pub fn retry(&mut self, incoming: Incoming, buf: &mut Vec<u8>) -> Result<Transmit, RetryError> {
if !incoming.may_retry() {
return Err(RetryError(Box::new(incoming)));
}
self.clean_up_incoming(&incoming);
incoming.improper_drop_warner.dismiss();
let server_config = self.server_config.as_ref().unwrap();
let local_cid = self.local_cid_generator.generate_cid();
let payload = TokenPayload::Retry {
address: incoming.network_path.remote,
orig_dst_cid: incoming.packet.header.dst_cid,
issued: server_config.time_source.now(),
};
let token = Token::new(payload, &mut self.rng).encode(&*server_config.token_key);
let header = Header::Retry {
src_cid: local_cid,
dst_cid: incoming.packet.header.src_cid,
version: incoming.packet.header.version,
};
let encode = header.encode(buf);
buf.put_slice(&token);
buf.extend_from_slice(&server_config.crypto.retry_tag(
incoming.packet.header.version,
incoming.packet.header.dst_cid,
buf,
));
encode.finish(buf, &*incoming.crypto.header.local, None);
Ok(Transmit {
destination: incoming.network_path.remote,
ecn: None,
size: buf.len(),
segment_size: None,
src_ip: incoming.network_path.local_ip,
})
}
/// Discard `incoming` without sending anything to the peer.
///
/// Releases its routing entry and buffered datagrams and silences the
/// improper-drop warning.
pub fn ignore(&mut self, incoming: Incoming) {
self.clean_up_incoming(&incoming);
incoming.improper_drop_warner.dismiss();
}
/// Release the endpoint state held for a pending `Incoming`: its initial-DCID
/// route and its datagram buffer, including the buffer's share of the global
/// byte count.
fn clean_up_incoming(&mut self, incoming: &Incoming) {
    let initial_dst_cid = incoming.packet.header.dst_cid;
    self.index.remove_initial(initial_dst_cid);
    let released = self
        .incoming_buffers
        .remove(incoming.incoming_idx)
        .total_bytes;
    self.all_incoming_buffers_total_bytes -= released;
}
/// Build a `Connection` and register it under the pre-reserved handle `ch`.
///
/// `local_cid` becomes sequence number 0 on `PathId::ZERO`; a preferred-address
/// CID from `side_args`, if any, becomes sequence number 1. The caller must
/// have obtained `ch` from `self.connections.vacant_key()` (checked by a debug
/// assertion).
fn add_connection(
    &mut self,
    ch: ConnectionHandle,
    version: u32,
    init_cid: ConnectionId,
    local_cid: ConnectionId,
    remote_cid: ConnectionId,
    network_path: FourTuple,
    now: Instant,
    tls: Box<dyn crypto::Session>,
    transport_config: Arc<TransportConfig>,
    side_args: SideArgs,
    params: &TransportParameters,
) -> Connection {
    // Every connection gets its own seed drawn from the endpoint's RNG.
    let mut rng_seed = [0; 32];
    self.rng.fill_bytes(&mut rng_seed);

    // Read these before `side_args` is moved into the connection.
    let side = side_args.side();
    let pref_addr_cid = side_args.pref_addr_cid();

    let qlog =
        transport_config.create_qlog_sink(side_args.side(), network_path.remote, init_cid, now);
    qlog.emit_connection_started(
        now,
        local_cid,
        remote_cid,
        network_path.remote,
        network_path.local_ip,
        params,
    );

    let conn = Connection::new(
        self.config.clone(),
        transport_config,
        init_cid,
        local_cid,
        remote_cid,
        network_path,
        tls,
        self.local_cid_generator.as_ref(),
        now,
        version,
        self.allow_mtud,
        rng_seed,
        side_args,
        qlog,
    );

    // Record the CIDs already issued for the initial path.
    let mut path_cids = PathLocalCids::default();
    let first_seq = path_cids.issued;
    path_cids.cids.insert(first_seq, local_cid);
    path_cids.issued = first_seq + 1;
    if let Some(cid) = pref_addr_cid {
        let pref_seq = path_cids.issued;
        debug_assert_eq!(pref_seq, 1, "preferred address cid seq must be 1");
        path_cids.cids.insert(pref_seq, cid);
        path_cids.issued = pref_seq + 1;
    }
    let mut local_cids = FxHashMap::default();
    local_cids.insert(PathId::ZERO, path_cids);

    let meta = ConnectionMeta {
        init_cid,
        local_cids,
        network_path,
        side,
        reset_token: Default::default(),
    };
    let id = self.connections.insert(meta);
    debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");

    self.index.insert_conn(network_path, local_cid, ch, side);
    conn
}
/// Write an Initial packet carrying a close frame for `reason` into `buf`.
///
/// The packet is addressed to `remote_id`, uses a freshly generated source CID
/// and packet number 0, and is protected with `crypto`.
fn initial_close(
    &mut self,
    version: u32,
    network_path: FourTuple,
    crypto: &Keys,
    remote_id: ConnectionId,
    reason: TransportError,
    buf: &mut Vec<u8>,
) -> Transmit {
    let local_id = self.local_cid_generator.generate_cid();
    let header = Header::Initial(InitialHeader {
        dst_cid: remote_id,
        src_cid: local_id,
        number: PacketNumber::U8(0),
        token: Bytes::new(),
        version,
    });
    let partial_encode = header.encode(buf);

    // Leave room for the header and the tag within the initial MTU.
    let tag_len = crypto.packet.local.tag_len();
    let max_len = INITIAL_MTU as usize - partial_encode.header_len - tag_len;
    frame::Close::from(reason).encoder(max_len).encode(buf);

    // Reserve space for the tag that `finish` fills in.
    let sealed_len = buf.len() + tag_len;
    buf.resize(sealed_len, 0);
    partial_encode.finish(
        buf,
        &*crypto.header.local,
        Some((0, Default::default(), &*crypto.packet.local)),
    );

    Transmit {
        destination: network_path.remote,
        ecn: None,
        size: buf.len(),
        segment_size: None,
        src_ip: network_path.local_ip,
    }
}
/// The endpoint-wide configuration this endpoint was created with.
pub fn config(&self) -> &EndpointConfig {
&self.config
}
/// Number of connections currently tracked by this endpoint.
pub fn open_connections(&self) -> usize {
self.connections.len()
}
/// Total bytes of datagrams currently buffered for unresolved `Incoming` attempts.
pub fn incoming_buffer_bytes(&self) -> u64 {
self.all_incoming_buffers_total_bytes
}
/// Test helper: number of tracked connections, with debug assertions that the
/// index tables are consistent with it.
#[cfg(test)]
pub(crate) fn known_connections(&self) -> usize {
let x = self.connections.len();
// Expects exactly one initial-CID entry per tracked connection.
debug_assert_eq!(x, self.index.connection_ids_initial.len());
// The remaining tables may hold fewer entries than there are connections.
debug_assert!(x >= self.index.connection_reset_tokens.0.len());
debug_assert!(x >= self.index.incoming_connection_remotes.len());
debug_assert!(x >= self.index.outgoing_connection_remotes.len());
x
}
/// Test helper: number of local CIDs currently present in the routing index.
#[cfg(test)]
pub(crate) fn known_cids(&self) -> usize {
self.index.connection_ids.len()
}
/// Whether so many local CIDs are in use that new connections should be refused.
///
/// Only CID lengths of 1 to 4 bytes are checked. The space counts as exhausted
/// once more than three quarters of it is indexed.
fn cids_exhausted(&self) -> bool {
    let cid_len = self.local_cid_generator.cid_len();
    if !(1..=4).contains(&cid_len) {
        return false;
    }
    let bits = (cid_len * 8) as u32;
    let total = 1u64 << bits;
    // A quarter of the space is kept in reserve.
    let reserved = total >> 2;
    let in_use = self.index.connection_ids.len() as u64;
    in_use > total - reserved
}
}
/// Debug output lists only the number of incoming buffers and their total
/// byte count, not their contents; `local_cid_generator`, `allow_mtud` and
/// `last_stateless_reset` are left out.
impl fmt::Debug for Endpoint {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Endpoint");
        out.field("rng", &self.rng);
        out.field("index", &self.index);
        out.field("connections", &self.connections);
        out.field("config", &self.config);
        out.field("server_config", &self.server_config);
        out.field("incoming_buffers.len", &self.incoming_buffers.len());
        out.field(
            "all_incoming_buffers_total_bytes",
            &self.all_incoming_buffers_total_bytes,
        );
        out.finish()
    }
}
/// Datagrams received for an `Incoming` that has not been accepted, refused,
/// retried or ignored yet.
#[derive(Default)]
struct IncomingBuffer {
/// Buffered datagram events, in arrival order.
datagrams: Vec<DatagramConnectionEvent>,
/// Sum of the sizes of the buffered datagrams.
total_bytes: u64,
}
/// Where a received datagram should be delivered.
#[derive(Copy, Clone, Debug)]
enum RouteDatagramTo {
/// Buffer for a pending `Incoming`; the value is its key in `Endpoint::incoming_buffers`.
Incoming(usize),
/// An existing connection and the path the datagram is attributed to.
Connection(ConnectionHandle, PathId),
}
/// Lookup tables used to find the destination of a received datagram.
#[derive(Default, Debug)]
struct ConnectionIndex {
/// Routes for the destination CID of a first Initial; consulted only for
/// Initial and 0-RTT packets.
connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
/// Locally issued, non-empty CIDs and the connection/path they belong to.
connection_ids: FxHashMap<ConnectionId, (ConnectionHandle, PathId)>,
/// Server-side connections with an empty local CID, keyed by four-tuple.
incoming_connection_remotes: HashMap<FourTuple, ConnectionHandle>,
/// Client-side connections with an empty local CID, keyed by remote address.
outgoing_connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
/// Reset tokens registered by connections, keyed by remote address.
connection_reset_tokens: ResetTokenTable,
}
impl ConnectionIndex {
    /// Route Initial/0-RTT packets for `dst_cid` to the pending buffer `incoming_key`.
    ///
    /// Empty CIDs are not indexed.
    fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
        if !dst_cid.is_empty() {
            let route = RouteDatagramTo::Incoming(incoming_key);
            self.connection_ids_initial.insert(dst_cid, route);
        }
    }

    /// Remove the initial-DCID route for `dst_cid`; a no-op for empty CIDs.
    ///
    /// Debug builds assert that the route existed.
    fn remove_initial(&mut self, dst_cid: ConnectionId) {
        if !dst_cid.is_empty() {
            let removed = self.connection_ids_initial.remove(&dst_cid);
            debug_assert!(removed.is_some());
        }
    }

    /// Route Initial/0-RTT packets for `dst_cid` to `connection` on `PathId::ZERO`.
    ///
    /// Empty CIDs are not indexed.
    fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
        if !dst_cid.is_empty() {
            let route = RouteDatagramTo::Connection(connection, PathId::ZERO);
            self.connection_ids_initial.insert(dst_cid, route);
        }
    }

    /// Register a new connection under its first local CID, or under its
    /// address when that CID is empty.
    fn insert_conn(
        &mut self,
        network_path: FourTuple,
        dst_cid: ConnectionId,
        connection: ConnectionHandle,
        side: Side,
    ) {
        match (dst_cid.len(), side) {
            // Zero-length CIDs cannot identify a connection; fall back to addresses.
            (0, Side::Server) => {
                self.incoming_connection_remotes
                    .insert(network_path, connection);
            }
            (0, Side::Client) => {
                self.outgoing_connection_remotes
                    .insert(network_path.remote, connection);
            }
            _ => {
                self.connection_ids
                    .insert(dst_cid, (connection, PathId::ZERO));
            }
        }
    }

    /// Stop routing `dst_cid` to any connection.
    fn retire(&mut self, dst_cid: ConnectionId) {
        self.connection_ids.remove(&dst_cid);
    }

    /// Drop every index entry that refers to the connection described by `conn`.
    fn remove(&mut self, conn: &ConnectionMeta) {
        // Only server-side connections have an initial-DCID route.
        if conn.side.is_server() {
            self.remove_initial(conn.init_cid);
        }
        for path_cids in conn.local_cids.values() {
            for cid in path_cids.cids.values() {
                self.connection_ids.remove(cid);
            }
        }
        self.incoming_connection_remotes.remove(&conn.network_path);
        self.outgoing_connection_remotes
            .remove(&conn.network_path.remote);
        for &(remote, token) in conn.reset_token.values() {
            self.connection_reset_tokens.remove(remote, token);
        }
    }

    /// Find where `datagram`, received over `network_path`, should go.
    ///
    /// Lookup order: local CID, initial DCID (Initial/0-RTT packets only),
    /// address (empty DCID only), and finally the trailing bytes of the
    /// datagram interpreted as a reset token.
    fn get(&self, network_path: &FourTuple, datagram: &PartialDecode) -> Option<RouteDatagramTo> {
        let dst_cid = datagram.dst_cid();

        let by_cid = if dst_cid.is_empty() {
            None
        } else {
            self.connection_ids.get(&dst_cid)
        };
        if let Some(&(ch, path_id)) = by_cid {
            return Some(RouteDatagramTo::Connection(ch, path_id));
        }

        let by_initial_cid = if datagram.is_initial() || datagram.is_0rtt() {
            self.connection_ids_initial.get(&dst_cid)
        } else {
            None
        };
        if let Some(&route) = by_initial_cid {
            return Some(route);
        }

        if dst_cid.is_empty() {
            // Server-side table first, then client-side.
            let by_address = self
                .incoming_connection_remotes
                .get(network_path)
                .or_else(|| self.outgoing_connection_remotes.get(&network_path.remote));
            if let Some(&ch) = by_address {
                return Some(RouteDatagramTo::Connection(ch, PathId::ZERO));
            }
        }

        // Last resort: the final RESET_TOKEN_SIZE bytes may be a known reset token.
        let data = datagram.data();
        let token_start = data.len().checked_sub(RESET_TOKEN_SIZE)?;
        self.connection_reset_tokens
            .get(network_path.remote, &data[token_start..])
            .map(|&ch| RouteDatagramTo::Connection(ch, PathId::ZERO))
    }
}
/// Per-connection data the endpoint keeps so it can clean up its index.
#[derive(Debug)]
pub(crate) struct ConnectionMeta {
/// CID under which the connection's initial route was registered.
init_cid: ConnectionId,
/// Locally issued CIDs that are still active, grouped by path.
local_cids: FxHashMap<PathId, PathLocalCids>,
/// Four-tuple the connection was created with.
network_path: FourTuple,
/// Whether this is the client or the server side of the connection.
side: Side,
/// Reset token (and the remote it is bound to) currently registered per path.
reset_token: FxHashMap<PathId, (SocketAddr, ResetToken)>,
}
/// Local CIDs issued for a single path.
#[derive(Debug, Default)]
struct PathLocalCids {
/// Number of CIDs issued so far; also the next sequence number to hand out.
issued: u64,
/// Active CIDs keyed by sequence number; retired ones are removed.
cids: FxHashMap<u64, ConnectionId>,
}
/// Identifies a connection within an `Endpoint`.
///
/// The wrapped value is the connection's key in the endpoint's slab.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);

impl From<ConnectionHandle> for usize {
    /// Unwrap the raw slab key.
    fn from(ConnectionHandle(raw): ConnectionHandle) -> Self {
        raw
    }
}
/// Allows `connections[ch]` as shorthand for indexing by the raw slab key.
impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
    type Output = ConnectionMeta;

    fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
        let key = usize::from(ch);
        &self[key]
    }
}
/// Mutable counterpart of indexing the connection slab by `ConnectionHandle`.
impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
    fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
        let key = usize::from(ch);
        &mut self[key]
    }
}
/// Result of `Endpoint::handle` for a datagram that was not dropped or buffered.
pub enum DatagramEvent {
/// Deliver the event to the connection with the given handle.
ConnectionEvent(ConnectionHandle, ConnectionEvent),
/// A new connection attempt to accept, refuse, retry or ignore.
NewConnection(Incoming),
/// A reply was written into the caller's buffer and should be sent as described.
Response(Transmit),
}
/// A connection attempt that has not been accepted or rejected yet.
///
/// Must be passed to `Endpoint::accept`, `refuse`, `retry` or `ignore`;
/// dropping it otherwise logs a warning.
#[derive(derive_more::Debug)]
pub struct Incoming {
/// When the first datagram was handled.
#[debug(skip)]
received_at: Instant,
/// Addresses the first datagram was received over.
network_path: FourTuple,
/// ECN codepoint reported for the first datagram.
ecn: Option<EcnCodepoint>,
/// The Initial packet; its payload is still encrypted.
#[debug(skip)]
packet: InitialPacket,
/// Bytes that followed the Initial packet in the same datagram, if any.
#[debug(skip)]
rest: Option<BytesMut>,
/// Initial keys derived for this attempt.
#[debug(skip)]
crypto: Keys,
/// Token information decoded from the Initial header.
token: IncomingToken,
/// Key of this attempt's entry in `Endpoint::incoming_buffers`.
incoming_idx: usize,
/// Logs a warning if this value is dropped without being dismissed.
#[debug(skip)]
improper_drop_warner: IncomingImproperDropWarner,
}
impl Incoming {
    /// Local IP address the first datagram was received on, if known.
    pub fn local_ip(&self) -> Option<IpAddr> {
        self.network_path.local_ip
    }

    /// Address the first datagram was sent from.
    pub fn remote_address(&self) -> SocketAddr {
        self.network_path.remote
    }

    /// Whether the token carried by the Initial marks the remote address as validated.
    pub fn remote_address_validated(&self) -> bool {
        self.token.validated
    }

    /// Whether `Endpoint::retry` is allowed for this attempt.
    ///
    /// False once the token records a retry source CID.
    pub fn may_retry(&self) -> bool {
        self.token.retry_src_cid.is_none()
    }

    /// The original destination CID recorded in the token.
    pub fn orig_dst_cid(&self) -> ConnectionId {
        self.token.orig_dst_cid
    }

    /// Decrypt a copy of the Initial payload, leaving `self` untouched.
    ///
    /// Returns `None` if decryption fails.
    pub fn decrypt(&self) -> Option<DecryptedInitial> {
        let packet_number = self.packet.header.number.expand(0);
        // Work on a clone so the stored payload stays encrypted for `accept`.
        let mut plaintext = self.packet.payload.clone();
        let outcome = self.crypto.packet.remote.decrypt(
            PathId::ZERO,
            packet_number,
            &self.packet.header_data,
            &mut plaintext,
        );
        if outcome.is_err() {
            return None;
        }
        Some(DecryptedInitial(plaintext.freeze()))
    }
}
/// Decrypted payload of an Initial packet, as produced by `Incoming::decrypt`.
pub struct DecryptedInitial(Bytes);

impl DecryptedInitial {
    /// Extract the ALPN protocol list from the ClientHello carried in this payload.
    ///
    /// Returns `None` if the payload contains a malformed frame, has no CRYPTO
    /// frame, the CRYPTO data does not form a contiguous prefix starting at
    /// offset 0, or no ALPN extension can be located.
    pub fn alpns(&self) -> Option<IncomingAlpns> {
        let frames = frame::Iter::new(self.0.clone()).ok()?;

        // Keep the first CRYPTO frame apart so the common single-frame case
        // needs no reassembly.
        let mut first = None;
        let mut rest = Vec::new();
        for frame in frames {
            let frame::Frame::Crypto(crypto) = frame.ok()? else {
                continue;
            };
            if first.is_none() {
                first = Some(crypto);
            } else {
                rest.push(crypto);
            }
        }
        let first = first?;

        if rest.is_empty() && first.offset == 0 {
            return find_alpn_data(&first.data)
                .ok()
                .map(|data| IncomingAlpns { data, pos: 0 });
        }

        // Several frames, or one not starting at offset 0: reassemble first.
        rest.push(first);
        let source = assemble_crypto_frames(&mut rest)?;
        find_alpn_data(&source)
            .ok()
            .map(|data| IncomingAlpns { data, pos: 0 })
    }
}
/// TLS handshake message type of a ClientHello.
const TLS_HANDSHAKE_TYPE_CLIENT_HELLO: u8 = 0x01;
/// TLS extension type of application-layer protocol negotiation (ALPN).
const TLS_EXTENSION_TYPE_ALPN: u16 = 0x0010;
/// Bytes skipped at the start of the ClientHello body before its
/// length-prefixed fields: a 2-byte and a 32-byte field (in TLS 1.3, the
/// legacy version and the random).
const TLS_CLIENT_HELLO_FIXED_LEN: usize = 2 + 32;
/// Iterator over the protocol names of an ALPN list.
///
/// Each entry is a one-byte length followed by that many bytes.
pub struct IncomingAlpns {
    /// The raw list, without its outer length prefix.
    data: Bytes,
    /// Offset of the next entry's length byte.
    pos: usize,
}

impl Iterator for IncomingAlpns {
    type Item = Result<Bytes, UnexpectedEnd>;

    /// Yields the next protocol name, `Err(UnexpectedEnd)` when an entry's
    /// declared length runs past the end of the list, or `None` once the
    /// position has reached the end.
    fn next(&mut self) -> Option<Self::Item> {
        let &len_byte = self.data.get(self.pos)?;
        let start = self.pos + 1;
        let end = start + usize::from(len_byte);
        // The length byte is consumed even when the entry turns out to be truncated.
        self.pos = start;
        if end > self.data.len() {
            return Some(Err(UnexpectedEnd));
        }
        self.pos = end;
        Some(Ok(self.data.slice(start..end)))
    }
}
/// Reassemble CRYPTO frames into the contiguous byte stream starting at offset 0.
///
/// Sorts `frames` by offset in place. Overlapping data is taken from whichever
/// frame supplied it first after sorting. Returns `None` if there is a gap.
fn assemble_crypto_frames(frames: &mut [frame::Crypto]) -> Option<Bytes> {
    frames.sort_by_key(|f| f.offset);
    let upper_bound: usize = frames.iter().map(|f| f.data.len()).sum();
    let mut assembled = Vec::with_capacity(upper_bound);
    for frame in frames.iter() {
        let start = frame.offset as usize;
        let have = assembled.len();
        if start > have {
            // Bytes between `have` and `start` are missing.
            return None;
        }
        // Skip the part of this frame that is already assembled.
        let already_present = have - start;
        if already_present < frame.data.len() {
            assembled.extend_from_slice(&frame.data[already_present..]);
        }
    }
    Some(Bytes::from(assembled))
}
/// Locate the ALPN protocol list inside a TLS ClientHello handshake message.
///
/// `source` must start with the handshake header. The result is a slice of
/// `source` holding the list without its outer length prefix.
///
/// # Errors
///
/// Returns `UnexpectedEnd` if the message is not a ClientHello, is truncated,
/// or has no ALPN extension.
fn find_alpn_data(source: &Bytes) -> Result<Bytes, UnexpectedEnd> {
    let mut cursor: &[u8] = &source[..];

    let handshake_type = u8::decode(&mut cursor)?;
    if handshake_type != TLS_HANDSHAKE_TYPE_CLIENT_HELLO {
        return Err(UnexpectedEnd);
    }
    let body_len = decode_u24(&mut cursor)?;
    let mut hello = take(&mut cursor, body_len)?;

    // Skip the fixed-size prefix, then three length-prefixed fields with
    // 1-, 2- and 1-byte prefixes.
    skip(&mut hello, TLS_CLIENT_HELLO_FIXED_LEN)?;
    skip_u8_prefixed(&mut hello)?;
    skip_u16_prefixed(&mut hello)?;
    skip_u8_prefixed(&mut hello)?;

    let mut extensions = take_u16_prefixed(&mut hello)?;
    while extensions.has_remaining() {
        let ext_type = u16::decode(&mut extensions)?;
        let mut ext_body = take_u16_prefixed(&mut extensions)?;
        if ext_type != TLS_EXTENSION_TYPE_ALPN {
            continue;
        }
        // The extension body is itself a u16-length-prefixed list.
        let protocol_list = take_u16_prefixed(&mut ext_body)?;
        return Ok(source.slice_ref(protocol_list));
    }
    Err(UnexpectedEnd)
}
/// Read a big-endian 24-bit unsigned integer, advancing `r` by three bytes.
fn decode_u24(r: &mut &[u8]) -> Result<usize, UnexpectedEnd> {
    // Leave the most significant byte zero and fill in the remaining three.
    let mut be = [0u8; 4];
    for byte in &mut be[1..] {
        *byte = u8::decode(r)?;
    }
    Ok(u32::from_be_bytes(be) as usize)
}
/// Split the first `len` bytes off the front of `r` and return them.
///
/// `r` is left unchanged when fewer than `len` bytes remain.
fn take<'a>(r: &mut &'a [u8], len: usize) -> Result<&'a [u8], UnexpectedEnd> {
    let whole: &'a [u8] = *r;
    if whole.len() < len {
        return Err(UnexpectedEnd);
    }
    let (head, tail) = whole.split_at(len);
    *r = tail;
    Ok(head)
}
/// Read a `u16` length prefix, then take that many bytes from `r`.
fn take_u16_prefixed<'a>(r: &mut &'a [u8]) -> Result<&'a [u8], UnexpectedEnd> {
    let prefix = u16::decode(r)?;
    take(r, usize::from(prefix))
}
/// Advance `r` by `len` bytes, failing if fewer remain.
fn skip(r: &mut &[u8], len: usize) -> Result<(), UnexpectedEnd> {
    take(r, len).map(|_| ())
}
/// Read a `u8` length prefix, then skip that many bytes of `r`.
fn skip_u8_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
    let prefix = u8::decode(r)?;
    skip(r, usize::from(prefix))
}
/// Read a `u16` length prefix, then skip that many bytes of `r`.
fn skip_u16_prefixed(r: &mut &[u8]) -> Result<(), UnexpectedEnd> {
    let prefix = u16::decode(r)?;
    skip(r, usize::from(prefix))
}
/// Guard stored inside `Incoming` that logs a warning when dropped.
///
/// The endpoint methods that consume an `Incoming` call `dismiss` so that the
/// warning fires only when an `Incoming` is dropped without being handled.
struct IncomingImproperDropWarner;
impl IncomingImproperDropWarner {
/// Consume the guard without running its `Drop` implementation.
fn dismiss(self) {
mem::forget(self);
}
}
impl Drop for IncomingImproperDropWarner {
fn drop(&mut self) {
warn!(
"noq_proto::Incoming dropped without passing to Endpoint::accept/refuse/retry/ignore \
(may cause memory leak and eventual inability to accept new connections)"
);
}
}
/// Reasons an outgoing connection attempt can fail before it starts.
///
/// `Endpoint::connect` in this file constructs `CidsExhausted`,
/// `InvalidRemoteAddress` and `UnsupportedVersion` directly; the other
/// variants are presumably produced by the TLS layer or by callers — verify
/// at their use sites.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
/// The endpoint is stopping.
#[error("endpoint stopping")]
EndpointStopping,
/// Too many local connection IDs are already in use.
#[error("CIDs exhausted")]
CidsExhausted,
/// The given server name was rejected; carries the offending name.
#[error("invalid server name: {0}")]
InvalidServerName(String),
/// The remote address has port 0 or an unspecified IP.
#[error("invalid remote address: {0}")]
InvalidRemoteAddress(SocketAddr),
/// No default client configuration was available.
#[error("no default client config")]
NoDefaultClientConfig,
/// The requested QUIC version is not among the endpoint's supported versions.
#[error("unsupported QUIC version")]
UnsupportedVersion,
}
/// Failure returned by `Endpoint::accept`.
#[derive(Debug)]
pub struct AcceptError {
/// Why the connection could not be accepted.
pub cause: ConnectionError,
/// Describes a reply that was written into the buffer passed to `accept`, if any.
pub response: Option<Transmit>,
}
/// Returned by `Endpoint::retry` when the `Incoming` may not be retried.
///
/// Carries the untouched `Incoming` so the caller can still resolve it.
#[derive(Debug, Error)]
#[error("retry() with validated Incoming")]
pub struct RetryError(Box<Incoming>);

impl RetryError {
    /// Recover the `Incoming` that could not be retried.
    pub fn into_incoming(self) -> Incoming {
        let Self(boxed) = self;
        *boxed
    }
}
#[derive(Default, Debug)]
struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
impl ResetTokenTable {
fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
self.0
.entry(remote)
.or_default()
.insert(token, ch)
.is_some()
}
fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
use std::collections::hash_map::Entry;
match self.0.entry(remote) {
Entry::Vacant(_) => {}
Entry::Occupied(mut e) => {
e.get_mut().remove(&token);
if e.get().is_empty() {
e.remove_entry();
}
}
}
}
fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
let token = ResetToken::from(<[u8; RESET_TOKEN_SIZE]>::try_from(token).ok()?);
self.0.get(&remote)?.get(&token)
}
}
// Unit tests for `assemble_crypto_frames`.
#[cfg(test)]
mod tests {
use super::*;
// Two frames that abut exactly are concatenated.
#[test]
fn assemble_contiguous() {
let data = b"hello world";
let mut frames = vec![
frame::Crypto {
offset: 0,
data: Bytes::from_static(&data[..5]),
},
frame::Crypto {
offset: 5,
data: Bytes::from_static(&data[5..]),
},
];
assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
}
// Frames supplied in reverse order are sorted by offset before assembly.
#[test]
fn assemble_out_of_order() {
let data = b"hello world";
let mut frames = vec![
frame::Crypto {
offset: 5,
data: Bytes::from_static(&data[5..]),
},
frame::Crypto {
offset: 0,
data: Bytes::from_static(&data[..5]),
},
];
assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
}
// Bytes 5..7 appear in both frames and must be emitted only once.
#[test]
fn assemble_with_overlap() {
let data = b"hello world";
let mut frames = vec![
frame::Crypto {
offset: 0,
data: Bytes::from_static(&data[..7]),
},
frame::Crypto {
offset: 5,
data: Bytes::from_static(&data[5..]),
},
];
assert_eq!(&assemble_crypto_frames(&mut frames).unwrap()[..], &data[..]);
}
// A lone frame that does not start at offset 0 leaves a gap, so assembly fails.
#[test]
fn assemble_with_gap() {
let mut frames = vec![frame::Crypto {
offset: 10,
data: Bytes::from_static(b"world"),
}];
assert!(assemble_crypto_frames(&mut frames).is_none());
}
}