use std::{
collections::{HashMap, VecDeque},
fmt, iter,
net::SocketAddr,
ops::{Index, IndexMut},
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use bytes::{BufMut, BytesMut};
use err_derive::Error;
use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};
use slab::Slab;
use tracing::{debug, trace, warn};
use crate::{
coding::BufMutExt,
connection::{initial_close, Connection, ConnectionError},
crypto::{
self, ClientConfig as ClientCryptoConfig, HmacKey, Keys, ServerConfig as ServerCryptoConfig,
},
packet::{Header, Packet, PacketDecodeError, PartialDecode},
shared::{
ClientConfig, ConfigError, ConnectionEvent, ConnectionEventInner, ConnectionId,
EcnCodepoint, EndpointConfig, EndpointEvent, EndpointEventInner, IssuedCid, ResetToken,
ServerConfig,
},
transport_parameters::TransportParameters,
Side, Transmit, TransportError, LOC_CID_COUNT, MAX_CID_SIZE, MIN_INITIAL_SIZE,
RESET_TOKEN_SIZE, VERSION,
};
/// Connection-independent state for a single endpoint.
///
/// Routes incoming datagrams to connections (see `handle`), creates new connections
/// (`connect` / incoming Initial packets) and queues datagrams that the endpoint itself needs
/// to send (retrieved with `poll_transmit`).
pub struct Endpoint<S>
where
S: crypto::Session,
{
/// Random source for connection IDs, stateless reset padding and version negotiation.
rng: StdRng,
/// Datagrams generated by the endpoint itself, drained by `poll_transmit`.
transmits: VecDeque<Transmit>,
/// Destination CIDs taken from clients' first Initial packets, mapped to their connection.
connection_ids_initial: HashMap<ConnectionId, ConnectionHandle>,
/// Locally generated CIDs mapped to the connection they were issued for.
connection_ids: HashMap<ConnectionId, ConnectionHandle>,
/// Remote address lookup, populated only when `local_cid_len` is zero.
connection_remotes: HashMap<SocketAddr, ConnectionHandle>,
/// Reset tokens registered through `ResetToken` endpoint events.
connection_reset_tokens: ResetTokenTable,
/// Per-connection bookkeeping, indexed by `ConnectionHandle`.
connections: Slab<ConnectionMeta>,
config: Arc<EndpointConfig>,
/// Present when this endpoint accepts incoming connections.
server_config: Option<Arc<ServerConfig<S>>>,
/// Number of incoming connections created but not yet acknowledged via `accept`.
incoming_handshakes: usize,
/// When set, every new incoming connection attempt is refused.
reject_new_connections: bool,
/// Key used to derive stateless reset tokens from connection IDs.
reset_key: S::HmacKey,
/// Key used to sign and verify retry tokens; `Some` only when `server_config` is set.
token_key: Option<S::HmacKey>,
}
impl<S> Endpoint<S>
where
S: crypto::Session,
{
/// Create an endpoint from the given configuration.
///
/// Passing `Some` for `server_config` makes the endpoint accept incoming connections.
///
/// # Errors
///
/// Fails if `config` does not validate or if either HMAC key cannot be constructed from
/// the configured key material.
pub fn new(
    config: Arc<EndpointConfig>,
    server_config: Option<Arc<ServerConfig<S>>>,
) -> Result<Self, ConfigError> {
    config.validate()?;

    let reset_key = S::HmacKey::new(&config.reset_key)?;
    // The retry-token key only exists for endpoints that act as a server.
    let token_key = match server_config {
        Some(ref server) => Some(S::HmacKey::new(&server.token_key)?),
        None => None,
    };

    Ok(Self {
        rng: StdRng::from_entropy(),
        transmits: VecDeque::new(),
        connection_ids_initial: HashMap::new(),
        connection_ids: HashMap::new(),
        connection_remotes: HashMap::new(),
        connection_reset_tokens: ResetTokenTable::default(),
        connections: Slab::new(),
        incoming_handshakes: 0,
        reject_new_connections: false,
        reset_key,
        token_key,
        config,
        server_config,
    })
}
/// Whether this endpoint was constructed with a server configuration.
fn is_server(&self) -> bool {
self.server_config.is_some()
}
/// Take the oldest datagram queued by the endpoint itself, if any.
///
/// Returns `None` once the queue is empty.
pub fn poll_transmit(&mut self) -> Option<Transmit> {
self.transmits.pop_front()
}
pub fn handle_event(
&mut self,
ch: ConnectionHandle,
event: EndpointEvent,
) -> Option<ConnectionEvent> {
use EndpointEventInner::*;
match event.0 {
NeedIdentifiers(max) => {
if self.config.local_cid_len != 0 {
return Some(
self.send_new_identifiers(ch, max.min(LOC_CID_COUNT - 1) as usize),
);
}
}
ResetToken(remote, token) => {
if let Some(old) = self.connections[ch].reset_token.replace((remote, token)) {
self.connection_reset_tokens.remove(old.0, old.1);
}
if self.connection_reset_tokens.insert(remote, token, ch) {
warn!("duplicate reset token");
}
}
RetireConnectionId(seq) => {
if let Some(cid) = self.connections[ch].loc_cids.remove(&seq) {
trace!("peer retired CID {}: {}", seq, cid);
self.connection_ids.remove(&cid);
return Some(self.send_new_identifiers(ch, 1));
}
}
Drained => {
let conn = self.connections.remove(ch.0);
if conn.init_cid.len() > 0 {
self.connection_ids_initial.remove(&conn.init_cid);
}
for cid in conn.loc_cids.values() {
self.connection_ids.remove(&cid);
}
self.connection_remotes.remove(&conn.initial_remote);
if let Some((remote, token)) = conn.reset_token {
self.connection_reset_tokens.remove(remote, token);
}
}
}
None
}
/// Process an incoming UDP datagram.
///
/// Returns the handle of the connection the datagram belongs to, paired with either an event
/// to forward to that existing connection or a newly created `Connection`. Returns `None`
/// when the datagram is dropped or answered by the endpoint itself; such replies (version
/// negotiation, stateless reset, retry, close) are queued for `poll_transmit`.
pub fn handle(
&mut self,
now: Instant,
remote: SocketAddr,
ecn: Option<EcnCodepoint>,
data: BytesMut,
) -> Option<(ConnectionHandle, DatagramEvent<S>)> {
let datagram_len = data.len();
let (first_decode, remaining) = match PartialDecode::new(data, self.config.local_cid_len) {
Ok(x) => x,
Err(PacketDecodeError::UnsupportedVersion {
source,
destination,
}) => {
// Only servers answer unknown versions; clients silently drop them.
if !self.is_server() {
debug!("dropping packet with unsupported version");
return None;
}
trace!("sending version negotiation");
let mut buf = Vec::<u8>::new();
// The reply swaps the CIDs of the offending packet.
Header::VersionNegotiate {
random: self.rng.gen::<u8>() | 0x40,
src_cid: destination,
dst_cid: source,
}
.encode(&mut buf);
// NOTE(review): 0x0a1a_2a3a looks like a reserved (0x?a?a?a?a) greasing version
// advertised ahead of the real one — confirm against the QUIC spec.
buf.write::<u32>(0x0a1a_2a3a);
buf.write(VERSION);
self.transmits.push_back(Transmit {
destination: remote,
ecn: None,
contents: buf.into(),
});
return None;
}
Err(e) => {
trace!("malformed header: {}", e);
return None;
}
};
let dst_cid = first_decode.dst_cid();
// Find an existing connection for this datagram. The lookups are tried in order:
// locally issued CID, initial destination CID, remote address, then reset token.
let known_ch = {
let ch = if self.config.local_cid_len > 0 {
self.connection_ids.get(&dst_cid)
} else {
None
};
ch.or_else(|| {
// Initial packets may still carry the destination CID the client chose.
if first_decode.is_initial() {
self.connection_ids_initial.get(&dst_cid)
} else {
None
}
})
.or_else(|| {
// With zero-length local CIDs the remote address is the only routing key.
if self.config.local_cid_len == 0 {
self.connection_remotes.get(&remote)
} else {
None
}
})
.or_else(|| {
// Treat the trailing RESET_TOKEN_SIZE bytes as a candidate stateless reset token.
let data = first_decode.data();
if data.len() < RESET_TOKEN_SIZE {
return None;
}
self.connection_reset_tokens
.get(remote, &data[data.len() - RESET_TOKEN_SIZE..])
})
.cloned()
};
if let Some(ch) = known_ch {
return Some((
ch,
DatagramEvent::ConnectionEvent(ConnectionEvent(ConnectionEventInner::Datagram {
now,
remote,
ecn,
first_decode,
remaining,
})),
));
}
// No connection matched. A client-only endpoint never creates connections from
// incoming packets, so the only possible response is a stateless reset.
if !self.is_server() {
debug!("packet for unrecognized connection {}", dst_cid);
self.stateless_reset(datagram_len, remote, &dst_cid);
return None;
}
if first_decode.has_long_header() {
if !first_decode.is_initial() {
debug!(
"ignoring non-initial packet for unknown connection {}",
dst_cid
);
return None;
}
if datagram_len < MIN_INITIAL_SIZE {
debug!("ignoring short initial for connection {}", dst_cid);
return None;
}
// Initial keys are derived from the destination CID chosen by the client.
let crypto = S::Keys::new_initial(&dst_cid, Side::Server);
let header_crypto = crypto.header_keys();
return match first_decode.finish(Some(&header_crypto)) {
Ok(packet) => self
.handle_first_packet(
now,
remote,
ecn,
packet,
remaining,
&crypto,
&header_crypto,
)
.map(|(ch, conn)| (ch, DatagramEvent::NewConnection(conn))),
Err(e) => {
trace!("unable to decode initial packet: {}", e);
None
}
};
}
// Short-header packet for an unknown connection.
if !dst_cid.is_empty() {
self.stateless_reset(datagram_len, remote, &dst_cid);
} else {
trace!("dropping unrecognized short packet without ID");
}
None
}
/// Queue a stateless reset for `dst_cid` in response to a datagram of `inciting_dgram_len`
/// bytes received from `remote`.
///
/// Nothing is sent when the inciting datagram is too small: the reset is always strictly
/// shorter than the datagram that triggered it.
fn stateless_reset(
    &mut self,
    inciting_dgram_len: usize,
    remote: SocketAddr,
    dst_cid: &ConnectionId,
) {
    /// Smallest amount of random padding placed in front of the token.
    const MIN_PADDING_LEN: usize = 5;
    /// Padding length from which the padding is additionally randomized in size.
    const IDEAL_MIN_PADDING_LEN: usize = MIN_PADDING_LEN + MAX_CID_SIZE;

    // Room left for padding once the token is accounted for; must exceed the minimum.
    let headroom = inciting_dgram_len
        .checked_sub(RESET_TOKEN_SIZE)
        .filter(|&room| room > MIN_PADDING_LEN);
    let max_padding_len = if let Some(room) = headroom {
        // One byte less than the headroom keeps the reply shorter than the trigger.
        room - 1
    } else {
        debug!("ignoring unexpected {} byte packet: not larger than minimum stateless reset size", inciting_dgram_len);
        return;
    };
    debug!("sending stateless reset for {} to {}", dst_cid, remote);

    let padding_len = if max_padding_len > IDEAL_MIN_PADDING_LEN {
        self.rng.gen_range(IDEAL_MIN_PADDING_LEN, max_padding_len)
    } else {
        max_padding_len
    };

    let mut buf = Vec::<u8>::with_capacity(padding_len + RESET_TOKEN_SIZE);
    buf.resize(padding_len, 0);
    self.rng.fill_bytes(&mut buf[..padding_len]);
    // Force the two most significant bits of the first byte to `01`.
    // NOTE(review): presumably so the datagram resembles a short-header packet — confirm.
    buf[0] = 0b0100_0000 | (buf[0] >> 2);
    buf.extend_from_slice(&reset_token_for(&self.reset_key, dst_cid));
    debug_assert!(buf.len() < inciting_dgram_len);

    self.transmits.push_back(Transmit {
        destination: remote,
        ecn: None,
        contents: buf.into(),
    });
}
/// Initiate a connection to `remote`, using `server_name` for the crypto session.
///
/// # Errors
///
/// Returns `ConnectError::TooManyConnections` when the local CID space is nearly
/// exhausted, and propagates transport configuration and session start-up failures.
pub fn connect(
    &mut self,
    config: ClientConfig<S>,
    remote: SocketAddr,
    server_name: &str,
) -> Result<(ConnectionHandle, Connection<S>), ConnectError> {
    if self.is_full() {
        return Err(ConnectError::TooManyConnections);
    }
    config.transport.validate()?;

    // A random CID serves both as the initial destination CID and as the peer's CID
    // until the peer supplies its own.
    let initial_dcid = ConnectionId::random(&mut self.rng, MAX_CID_SIZE);
    trace!(initial_dcid = %initial_dcid);

    let opts = ConnectionOpts::Client {
        config,
        server_name: String::from(server_name),
    };
    self.add_connection(initial_dcid, initial_dcid, remote, opts, Instant::now())
}
/// Generate `num` new local connection IDs for connection `ch`, register them for routing
/// and return the event that hands them (with their reset tokens) to the connection.
fn send_new_identifiers(&mut self, ch: ConnectionHandle, num: usize) -> ConnectionEvent {
    let issued: Vec<_> = (0..num)
        .map(|_| {
            let id = self.new_cid();
            self.connection_ids.insert(id, ch);

            // Sequence numbers continue from the last one issued to this connection.
            let meta = &mut self.connections[ch];
            meta.cids_issued += 1;
            let sequence = meta.cids_issued;
            meta.loc_cids.insert(sequence, id);

            IssuedCid {
                sequence,
                id,
                reset_token: reset_token_for(&self.reset_key, &id),
            }
        })
        .collect();
    ConnectionEvent(ConnectionEventInner::NewIdentifiers(issued))
}
/// Generate a random local connection ID that is not currently used for routing.
///
/// # Panics
///
/// Panics if a collision occurs while `local_cid_len` is zero, since retrying could never
/// produce a different ID.
fn new_cid(&mut self) -> ConnectionId {
    loop {
        let candidate = ConnectionId::random(&mut self.rng, self.config.local_cid_len);
        if self.connection_ids.contains_key(&candidate) {
            // Collision: only worth retrying when IDs actually carry random bytes.
            assert!(self.config.local_cid_len > 0);
            continue;
        }
        return candidate;
    }
}
/// Create a connection and register it in the endpoint's routing tables.
///
/// * `init_cid` - destination CID of the first Initial packet; stored in the metadata.
/// * `rem_cid` - CID used to address the peer.
/// * `remote` - peer address; also the routing key when local CIDs are zero-length.
/// * `opts` - client or server specific parameters.
/// * `now` - timestamp handed to the new connection.
///
/// # Errors
///
/// Propagates a failure to start the client crypto session.
///
/// # Panics
///
/// Panics if `opts` is `ConnectionOpts::Server` but the endpoint has no server configuration.
fn add_connection(
    &mut self,
    init_cid: ConnectionId,
    rem_cid: ConnectionId,
    remote: SocketAddr,
    opts: ConnectionOpts<S>,
    now: Instant,
) -> Result<(ConnectionHandle, Connection<S>), ConnectError> {
    let loc_cid = self.new_cid();
    let (server_config, tls, transport_config) = match opts {
        ConnectionOpts::Client {
            config,
            server_name,
        } => {
            let params = TransportParameters::new::<S>(&config.transport, None);
            (
                None,
                // Fixed: this argument was the mis-encoded token `¶ms` in the source.
                config.crypto.start_session(&server_name, &params)?,
                config.transport,
            )
        }
        ConnectionOpts::Server { orig_dst_cid } => {
            let config = self.server_config.as_ref().unwrap();
            let params = TransportParameters::new(&config.transport, Some(config));
            // Servers additionally advertise the reset token for their first CID and, after
            // a retry, the destination CID the client originally used.
            let server_params = TransportParameters {
                stateless_reset_token: Some(reset_token_for(&self.reset_key, &loc_cid)),
                original_connection_id: orig_dst_cid,
                ..params
            };
            (
                Some(config.clone()),
                config.crypto.start_session(&server_params),
                config.transport.clone(),
            )
        }
    };
    let conn = Connection::new(
        Arc::clone(&self.config),
        server_config,
        transport_config,
        init_cid,
        loc_cid,
        rem_cid,
        remote,
        tls,
        now,
    );
    let id = self.connections.insert(ConnectionMeta {
        init_cid,
        cids_issued: 0,
        // The first local CID implicitly has sequence number 0.
        loc_cids: iter::once((0, loc_cid)).collect(),
        initial_remote: remote,
        reset_token: None,
    });
    let ch = ConnectionHandle(id);
    // Route by CID when local CIDs are in use, otherwise by remote address.
    if self.config.local_cid_len > 0 {
        self.connection_ids.insert(loc_cid, ch);
    } else {
        self.connection_remotes.insert(remote, ch);
    }
    Ok((ch, conn))
}
/// Handle an Initial packet that matched no existing connection.
///
/// Decrypts and validates the packet, then either refuses the attempt (queuing a close),
/// answers with a stateless retry, or creates a new server-side connection. Returns the new
/// connection and its handle on success and `None` in every other case.
///
/// # Panics
///
/// Panics if `packet` does not carry an Initial header, or if the endpoint has no server
/// configuration.
fn handle_first_packet(
&mut self,
now: Instant,
remote: SocketAddr,
ecn: Option<EcnCodepoint>,
mut packet: Packet,
rest: Option<BytesMut>,
crypto: &S::Keys,
header_crypto: &<S::Keys as Keys>::HeaderKeys,
) -> Option<(ConnectionHandle, Connection<S>)> {
let (src_cid, dst_cid, token, packet_number) = match packet.header {
Header::Initial {
src_cid,
dst_cid,
ref token,
number,
} => (src_cid, dst_cid, token.clone(), number),
_ => panic!("non-initial packet in handle_initial()"),
};
let packet_number = packet_number.expand(0);
if crypto
.decrypt(
packet_number as u64,
&packet.header_data,
&mut packet.payload,
)
.is_err()
{
debug!(packet_number, "failed to authenticate initial packet");
return None;
};
if !packet.reserved_bits_valid() {
debug!("dropping connection attempt with invalid reserved bits");
return None;
}
// Provisional local CID, used as the source CID of any close or retry sent from here.
let temp_loc_cid = self.new_cid();
let server_config = self.server_config.as_ref().unwrap();
// Refuse when the accept buffer is full, new connections were disabled, or the local
// CID space is nearly exhausted.
if self.incoming_handshakes == server_config.accept_buffer as usize
|| self.reject_new_connections
|| self.is_full()
{
debug!("rejecting connection due to full accept buffer");
self.transmits.push_back(Transmit {
destination: remote,
ecn: None,
contents: initial_close(
crypto,
header_crypto,
&src_cid,
&temp_loc_cid,
0,
TransportError::SERVER_BUSY(""),
),
});
return None;
}
// A destination CID shorter than 8 bytes is only tolerated when stateless retry is
// enabled and its length equals the local CID length.
if dst_cid.len() < 8
&& (!server_config.use_stateless_retry || dst_cid.len() != self.config.local_cid_len)
{
debug!(
"rejecting connection due to invalid DCID length {}",
dst_cid.len()
);
self.transmits.push_back(Transmit {
destination: remote,
ecn: None,
contents: initial_close(
crypto,
header_crypto,
&src_cid,
&temp_loc_cid,
0,
TransportError::PROTOCOL_VIOLATION("invalid destination CID length"),
),
});
return None;
}
// Destination CID recovered from a valid retry token, if the client presented one.
let mut retry_cid = None;
if server_config.use_stateless_retry {
if let Some((token_dst_cid, token_issued)) =
token::check(self.token_key.as_ref().unwrap(), &remote, &token)
{
let expires = token_issued
+ Duration::from_micros(
self.server_config.as_ref().unwrap().retry_token_lifetime,
);
if expires > SystemTime::now() {
retry_cid = Some(token_dst_cid);
} else {
trace!("sending stateless retry due to expired token");
}
} else {
trace!("sending stateless retry due to invalid token");
}
// No usable token: reply with a Retry packet carrying a fresh token and stop here.
if retry_cid.is_none() {
let token = token::generate(
self.token_key.as_ref().unwrap(),
&remote,
&dst_cid,
SystemTime::now(),
);
let mut buf = Vec::new();
let header = Header::Retry {
src_cid: temp_loc_cid,
dst_cid: src_cid,
orig_dst_cid: dst_cid,
};
let encode = header.encode(&mut buf);
encode.finish::<S::Keys, <S::Keys as Keys>::HeaderKeys>(
&mut buf,
header_crypto,
None,
);
buf.put_slice(&token);
self.transmits.push_back(Transmit {
destination: remote,
ecn: None,
contents: buf.into(),
});
return None;
}
}
let (ch, mut conn) = self
.add_connection(
dst_cid,
src_cid,
remote,
ConnectionOpts::Server {
orig_dst_cid: retry_cid,
},
now,
)
.unwrap();
// Remember the client-chosen destination CID so further Initial packets are routed
// to this connection.
if dst_cid.len() != 0 {
self.connection_ids_initial.insert(dst_cid, ch);
}
match conn.handle_first_packet(now, remote, ecn, packet_number as u64, packet, rest) {
Ok(()) => {
trace!(id = ch.0, icid = %dst_cid, "connection incoming");
self.incoming_handshakes += 1;
Some((ch, conn))
}
Err(e) => {
debug!("handshake failed: {}", e);
// Undo the registration performed by `add_connection` above.
self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
// Only transport errors are reported to the peer.
if let ConnectionError::TransportError(e) = e {
self.transmits.push_back(Transmit {
destination: remote,
ecn: None,
contents: initial_close(
crypto,
header_crypto,
&src_cid,
&temp_loc_cid,
0,
e,
),
});
}
None
}
}
}
/// Record that one pending incoming connection has been accepted, freeing a slot in the
/// accept buffer.
///
/// NOTE(review): the counter is decremented unconditionally; presumably callers invoke this
/// exactly once per `DatagramEvent::NewConnection` — confirm.
pub fn accept(&mut self) {
self.incoming_handshakes -= 1;
}
/// Refuse all incoming connection attempts from now on.
///
/// The flag is never cleared by any code in this file.
pub fn reject_new_connections(&mut self) {
self.reject_new_connections = true;
}
/// Number of connections currently tracked (test helper).
///
/// In debug builds this also cross-checks the routing tables against that count.
#[cfg(test)]
pub(crate) fn known_connections(&self) -> usize {
let x = self.connections.len();
debug_assert_eq!(x, self.connection_ids_initial.len());
// Reset tokens and remote-address entries exist for at most every connection.
debug_assert!(x >= self.connection_reset_tokens.0.len());
debug_assert!(x >= self.connection_remotes.len());
x
}
/// Number of locally issued connection IDs currently registered for routing (test helper).
#[cfg(test)]
pub(crate) fn known_cids(&self) -> usize {
self.connection_ids.len()
}
/// Whether fewer than a quarter of all possible local CIDs remain unused.
///
/// Only evaluated for local CID lengths of 1 to 4 bytes; for any other length the endpoint
/// is never considered full.
fn is_full(&self) -> bool {
    let cid_len = self.config.local_cid_len;
    if cid_len == 0 || cid_len > 4 {
        return false;
    }
    let bits = cid_len as u32 * 8;
    let unused = 2usize.pow(bits) - self.connection_ids.len();
    // A quarter of the CID space is 2^(bits - 2).
    unused < 2usize.pow(bits - 2)
}
}
/// Debug output listing the endpoint's state; the `reset_key` and `token_key` fields are
/// left out.
impl<S> fmt::Debug for Endpoint<S>
where
    S: crypto::Session,
{
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = fmt.debug_struct("Endpoint<T>");
        out.field("rng", &self.rng);
        out.field("transmits", &self.transmits);
        out.field("connection_ids_initial", &self.connection_ids_initial);
        out.field("connection_ids", &self.connection_ids);
        out.field("connection_remotes", &self.connection_remotes);
        out.field("connection_reset_tokens", &self.connection_reset_tokens);
        out.field("connections", &self.connections);
        out.field("config", &self.config);
        out.field("server_config", &self.server_config);
        out.field("incoming_handshakes", &self.incoming_handshakes);
        out.field("reject_new_connections", &self.reject_new_connections);
        out.finish()
    }
}
/// Per-connection bookkeeping the endpoint needs to route datagrams and to clean up its
/// tables once the connection is drained.
#[derive(Debug)]
pub(crate) struct ConnectionMeta {
/// Destination CID of the first Initial packet; key into `connection_ids_initial`.
init_cid: ConnectionId,
/// Number of CIDs issued after the first one; also the most recent sequence number.
cids_issued: u64,
/// Local CIDs still registered for this connection, keyed by sequence number.
loc_cids: HashMap<u64, ConnectionId>,
/// Remote address at creation time; key into `connection_remotes`.
initial_remote: SocketAddr,
/// Reset token most recently registered for this connection, with its remote address.
reset_token: Option<(SocketAddr, ResetToken)>,
}
fn reset_token_for<H>(key: &H, id: &ConnectionId) -> ResetToken
where
H: crypto::HmacKey,
{
let signature = key.sign(id);
let mut result = [0; RESET_TOKEN_SIZE];
result.copy_from_slice(&signature.as_ref()[..RESET_TOKEN_SIZE]);
result.into()
}
/// Retry tokens used for stateless address validation.
///
/// A token consists of the length-prefixed original destination CID, the issue time in
/// whole seconds since the Unix epoch, and a signature. The signature additionally covers
/// the client's IP address and port, which are not stored in the token itself.
mod token {
    use std::{
        io,
        net::{IpAddr, SocketAddr},
        time::{Duration, SystemTime, UNIX_EPOCH},
    };
    use bytes::{Buf, BufMut};
    use crate::{
        coding::{BufExt, BufMutExt},
        crypto::HmacKey,
        shared::ConnectionId,
        MAX_CID_SIZE,
    };

    /// Build a token binding `dst_cid` and the issue time `issued` to the client `address`.
    ///
    /// An `issued` time before the Unix epoch is encoded as 0.
    pub fn generate<K>(
        key: &K,
        address: &SocketAddr,
        dst_cid: &ConnectionId,
        issued: SystemTime,
    ) -> Vec<u8>
    where
        K: HmacKey,
    {
        let mut buf = Vec::new();
        buf.write(dst_cid.len() as u8);
        buf.put_slice(dst_cid);
        buf.write::<u64>(
            issued
                .duration_since(UNIX_EPOCH)
                .map(|x| x.as_secs())
                .unwrap_or(0),
        );
        let signature_pos = buf.len();
        // Append the address only for signing; it is cut off again below.
        match address.ip() {
            IpAddr::V4(x) => buf.put_slice(&x.octets()),
            IpAddr::V6(x) => buf.put_slice(&x.octets()),
        }
        buf.write(address.port());
        let signature = key.sign(&buf);
        buf.truncate(signature_pos);
        buf.extend_from_slice(signature.as_ref());
        buf
    }

    /// Validate a token received from `address`.
    ///
    /// Returns the destination CID and issue time stored in the token, or `None` if the
    /// token is malformed or its signature does not verify for `address`. Expiry is not
    /// checked here.
    pub fn check<K>(
        key: &K,
        address: &SocketAddr,
        data: &[u8],
    ) -> Option<(ConnectionId, SystemTime)>
    where
        K: HmacKey,
    {
        let mut reader = io::Cursor::new(data);
        let dst_cid_len = reader.get::<u8>().ok()? as usize;
        if dst_cid_len > reader.remaining() || dst_cid_len > MAX_CID_SIZE {
            return None;
        }
        let dst_cid = ConnectionId::new(&data[1..=dst_cid_len]);
        reader.advance(dst_cid_len);
        // The timestamp is read before the signature is verified, so it is still untrusted:
        // use checked arithmetic, because `SystemTime + Duration` panics on overflow.
        let issued = UNIX_EPOCH.checked_add(Duration::new(reader.get::<u64>().ok()?, 0))?;
        let signature_start = reader.position() as usize;
        // Rebuild the signed message: the token prefix followed by the sender's address.
        let mut buf = Vec::new();
        buf.put_slice(&data[0..signature_start]);
        match address.ip() {
            IpAddr::V4(x) => buf.put_slice(&x.octets()),
            IpAddr::V6(x) => buf.put_slice(&x.octets()),
        }
        buf.write(address.port());
        key.verify(&buf, &data[signature_start..]).ok()?;
        Some((dst_cid, issued))
    }
}
/// Identifier for a connection within an endpoint; wraps the index of the connection's
/// entry in the endpoint's connection table.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct ConnectionHandle(pub usize);

impl From<ConnectionHandle> for usize {
    /// Unwrap the handle into the raw index it carries.
    fn from(x: ConnectionHandle) -> usize {
        let ConnectionHandle(index) = x;
        index
    }
}
impl Index<ConnectionHandle> for Slab<ConnectionMeta> {
type Output = ConnectionMeta;
fn index(&self, ch: ConnectionHandle) -> &ConnectionMeta {
&self[ch.0]
}
}
impl IndexMut<ConnectionHandle> for Slab<ConnectionMeta> {
fn index_mut(&mut self, ch: ConnectionHandle) -> &mut ConnectionMeta {
&mut self[ch.0]
}
}
/// Outcome of `Endpoint::handle` for a datagram that was not dropped.
pub enum DatagramEvent<S>
where
S: crypto::Session,
{
/// The datagram belongs to an existing connection and should be passed on to it.
ConnectionEvent(ConnectionEvent),
/// The datagram started a new incoming connection.
NewConnection(Connection<S>),
}
/// Side-specific inputs for `Endpoint::add_connection`.
enum ConnectionOpts<S: crypto::Session> {
/// Outgoing connection created by `connect`.
Client {
config: ClientConfig<S>,
server_name: String,
},
/// Incoming connection created from a client's Initial packet.
Server {
/// Destination CID recovered from a valid retry token, if any.
orig_dst_cid: Option<ConnectionId>,
},
}
/// Errors that can prevent an outgoing connection from being created.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectError {
/// The endpoint is shutting down. Not produced by any code in this file.
#[error(display = "endpoint stopping")]
EndpointStopping,
/// The local connection ID space is nearly exhausted (see `Endpoint::is_full`).
#[error(display = "too many connections")]
TooManyConnections,
/// NOTE(review): presumably raised by the crypto session for an unusable server name —
/// confirm against the `ClientCryptoConfig` implementation.
#[error(display = "invalid DNS name: {}", _0)]
InvalidDnsName(String),
/// The transport configuration failed validation.
#[error(display = "transport configuration error: {}", _0)]
Config(#[source] ConfigError),
}
/// Reset tokens grouped by remote address, each mapped to the connection that registered it.
#[derive(Default, Debug)]
struct ResetTokenTable(HashMap<SocketAddr, HashMap<ResetToken, ConnectionHandle>>);
impl ResetTokenTable {
fn insert(&mut self, remote: SocketAddr, token: ResetToken, ch: ConnectionHandle) -> bool {
self.0
.entry(remote)
.or_default()
.insert(token, ch)
.is_some()
}
fn remove(&mut self, remote: SocketAddr, token: ResetToken) {
use std::collections::hash_map::Entry;
match self.0.entry(remote) {
Entry::Vacant(_) => {}
Entry::Occupied(mut e) => {
e.get_mut().remove(&token);
if e.get().is_empty() {
e.remove_entry();
}
}
}
}
fn get(&self, remote: SocketAddr, token: &[u8]) -> Option<&ConnectionHandle> {
self.0.get(&remote).and_then(|x| x.get(token))
}
}
#[cfg(test)]
mod test {
/// A freshly generated retry token must validate for the same key and address and yield
/// the connection ID and issue time it was created with.
#[cfg(feature = "ring")]
#[test]
fn token_sanity() {
use super::*;
use crate::crypto::HmacKey;
use ring::hmac;
use std::{
net::Ipv6Addr,
time::{Duration, UNIX_EPOCH},
};
// Random key material for the HMAC key under test.
let mut key = [0; 64];
rand::thread_rng().fill_bytes(&mut key);
let key = <hmac::Key as HmacKey>::new(&key).unwrap();
let addr = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 4433);
let dst_cid = ConnectionId::random(&mut rand::thread_rng(), MAX_CID_SIZE);
// Whole seconds only: the token stores the issue time with one-second resolution.
let issued = UNIX_EPOCH + Duration::new(42, 0);
let token = token::generate(&key, &addr, &dst_cid, issued);
let (dst_cid2, issued2) = token::check(&key, &addr, &token).expect("token didn't validate");
assert_eq!(dst_cid, dst_cid2);
assert_eq!(issued, issued2);
}
}