use std::{
cmp,
collections::{BTreeMap, VecDeque},
convert::TryFrom,
fmt, io, mem,
net::{IpAddr, SocketAddr},
sync::Arc,
time::{Duration, Instant},
};
use bytes::{Bytes, BytesMut};
use frame::StreamMetaVec;
use rand::{rngs::StdRng, Rng, SeedableRng};
use thiserror::Error;
use tracing::{debug, error, trace, trace_span, warn};
use crate::{
cid_generator::ConnectionIdGenerator,
cid_queue::CidQueue,
coding::BufMutExt,
config::{ServerConfig, TransportConfig},
crypto::{self, KeyPair, Keys, PacketKey},
frame,
frame::{Close, Datagram, FrameStruct},
packet::{Header, LongType, Packet, PartialDecode, SpaceId},
range_set::RangeSet,
shared::{
ConnectionEvent, ConnectionEventInner, ConnectionId, EcnCodepoint, EndpointEvent,
EndpointEventInner, IssuedCid,
},
transport_parameters::TransportParameters,
Dir, Frame, Side, StreamId, Transmit, TransportError, TransportErrorCode, VarInt,
MAX_STREAM_COUNT, MIN_INITIAL_SIZE, RESET_TOKEN_SIZE, TIMER_GRANULARITY,
};
mod assembler;
pub use assembler::Chunk;
mod cid_state;
use cid_state::CidState;
mod datagrams;
use datagrams::DatagramState;
pub use datagrams::{Datagrams, SendDatagramError};
mod pacing;
mod packet_builder;
use packet_builder::PacketBuilder;
mod paths;
use paths::PathData;
mod send_buffer;
mod spaces;
#[cfg(fuzzing)]
pub use spaces::Retransmits;
#[cfg(not(fuzzing))]
use spaces::Retransmits;
use spaces::{PacketSpace, SentPacket, ThinRetransmits};
mod stats;
pub use stats::ConnectionStats;
mod streams;
#[cfg(fuzzing)]
pub use streams::StreamsState;
#[cfg(not(fuzzing))]
use streams::StreamsState;
pub use streams::{
ByteSlice, BytesArray, BytesSource, Chunks, FinishError, ReadError, ReadableError, RecvStream,
SendStream, ShouldTransmit, StreamEvent, Streams, UnknownStream, WriteError, Written,
};
mod timer;
use timer::{Timer, TimerTable};
/// State of a single QUIC connection, generic over the cryptographic session `S`.
///
/// The connection performs no I/O itself: callers feed it events (`handle_event`,
/// `handle_timeout`) and drain outgoing datagrams and events (`poll_transmit`, `poll`,
/// `poll_endpoint_events`).
pub struct Connection<S>
where
S: crypto::Session,
{
/// Server configuration; `Some` exactly when this is the server side (see `new`).
server_config: Option<Arc<ServerConfig<S>>>,
/// Transport configuration shared with the endpoint.
config: Arc<TransportConfig>,
/// Per-connection RNG; `new` uses it to decide whether the spin bit is enabled.
rng: StdRng,
/// Cryptographic session driving the handshake and key derivation.
crypto: S,
/// Local connection ID supplied at construction (`loc_cid`).
handshake_cid: ConnectionId,
/// Remote connection ID supplied at construction (`rem_cid`).
rem_handshake_cid: ConnectionId,
/// Source IP used for outgoing transmits, if one was specified.
local_ip: Option<IpAddr>,
/// The current network path.
path: PathData,
/// The previous path. It is sent PATH_CHALLENGEs from `poll_transmit` and restored when the
/// `PathValidation` timer fires.
prev_path: Option<PathData>,
/// Overall connection state (handshake, established, closed, draining, drained).
state: State,
/// Whether this endpoint is the client or the server.
side: Side,
/// Set by `init_0rtt` when early keys were available.
zero_rtt_enabled: bool,
/// 0-RTT packet protection keys. Cleared by the `KeyDiscard` timer and, on the client, when
/// 1-RTT keys arrive.
zero_rtt_crypto: Option<ZeroRttCrypto<S>>,
/// Not touched in the visible code; presumably the current 1-RTT key phase bit. TODO confirm.
key_phase: bool,
/// Transport parameters of the peer (or cached ones, for 0-RTT on the client).
peer_params: TransportParameters,
/// Remote connection ID as originally supplied to `new`.
orig_rem_cid: ConnectionId,
/// The `init_cid` supplied to `new`, from which the Initial keys are derived.
initial_dst_cid: ConnectionId,
/// Source CID of a Retry packet, if one was processed (set outside this view).
retry_src_cid: Option<ConnectionId>,
/// Running count of packets declared lost by `detect_lost_packets`.
lost_packets: u64,
/// Application-facing events, drained by `poll`.
events: VecDeque<Event>,
/// Events for the owning endpoint, drained by `poll_endpoint_events`.
endpoint_events: VecDeque<EndpointEventInner>,
/// Whether the spin bit is used (`allow_spin` and a 7-in-8 random draw).
spin_enabled: bool,
/// Outgoing spin bit value, updated in `on_packet_authenticated`.
spin: bool,
/// Per-packet-number-space state, indexed by `SpaceId`.
spaces: [PacketSpace<S>; 3],
/// Highest packet number space for which keys have been installed.
highest_space: SpaceId,
/// Keys of the previous 1-RTT key phase, kept until the `KeyDiscard` timer fires.
prev_crypto: Option<PrevCrypto<S::PacketKey>>,
/// Pre-computed next 1-RTT keys (set in `upgrade_crypto` on reaching the Data space).
next_crypto: Option<KeyPair<S::PacketKey>>,
/// Reported by `accepted_0rtt()`; set outside this view.
accepted_0rtt: bool,
/// Set whenever an authenticated packet is received.
permit_idle_reset: bool,
/// Configured idle timeout; `None` disables the idle timer.
idle_timeout: Option<Duration>,
/// All connection timers.
timers: TimerTable,
/// Count of packets that failed authentication, compared against the AEAD integrity limit.
authentication_failures: u64,
/// Pending PATH_RESPONSE, if any (managed outside this view).
path_response: Option<PathResponse>,
/// Whether a CONNECTION_CLOSE still needs to be sent. Set by `close_inner`, cleared by
/// `poll_transmit`.
close: bool,
/// Number of consecutive PTO expirations, used for exponential backoff.
pto_count: u32,
/// Aggregate bytes/ack-eliciting packets in flight across spaces.
in_flight: InFlight,
/// Whether any received packet carried an ECN codepoint.
receiving_ecn: bool,
/// Total number of successfully authenticated packets.
total_authed_packets: u64,
/// Whether the last `poll_transmit` found nothing to send without being congestion-blocked.
app_limited: bool,
/// Stream state for all streams of this connection.
streams: StreamsState,
/// Queue of connection IDs supplied by the peer.
rem_cids: CidQueue,
/// Bookkeeping for locally issued connection IDs.
local_cid_state: CidState,
/// Unreliable datagram state.
datagrams: DatagramState,
/// Connection statistics, returned (with live path metrics) by `stats()`.
stats: ConnectionStats,
/// QUIC version in use; passed to packet building and decoding.
version: u32,
}
impl<S> Connection<S>
where
S: crypto::Session,
{
/// Creates a connection in the handshake state.
///
/// The side is inferred from `server_config`: a server configuration makes this the server,
/// its absence makes it the client. Clients immediately produce their first handshake data and
/// attempt to enable 0-RTT.
pub(crate) fn new(
    server_config: Option<Arc<ServerConfig<S>>>,
    config: Arc<TransportConfig>,
    init_cid: ConnectionId,
    loc_cid: ConnectionId,
    rem_cid: ConnectionId,
    remote: SocketAddr,
    local_ip: Option<IpAddr>,
    crypto: S,
    cid_gen: &dyn ConnectionIdGenerator,
    now: Instant,
    version: u32,
) -> Self {
    // Only servers are handed a `ServerConfig`.
    let side = match server_config {
        Some(_) => Side::Server,
        None => Side::Client,
    };

    // Initial keys are derived from `init_cid` up front; the other spaces start without keys.
    let initial_space = PacketSpace {
        crypto: Some(S::initial_keys(&init_cid, side)),
        ..PacketSpace::new(now)
    };
    let state = State::Handshake(state::Handshake {
        rem_cid_set: side.is_server(),
        token: None,
        client_hello: None,
    });
    let mut rng = StdRng::from_entropy();

    // A client's path starts validated; a server's only when it uses stateless retry.
    let path_validated = match &server_config {
        Some(server) => server.use_stateless_retry,
        None => true,
    };
    let local_cid_state = CidState::new(cid_gen.cid_len(), cid_gen.cid_lifetime(), now);
    let path = PathData::new(
        remote,
        config.initial_rtt,
        config.congestion_controller_factory.build(now),
        now,
        path_validated,
    );
    // Even when allowed, the spin bit is only enabled with probability 7/8.
    let spin_enabled = config.allow_spin && rng.gen_ratio(7, 8);
    let spaces = [initial_space, PacketSpace::new(now), PacketSpace::new(now)];
    let idle_timeout = config.max_idle_timeout;
    let streams = StreamsState::new(
        side,
        config.max_concurrent_uni_streams,
        config.max_concurrent_bidi_streams,
        config.send_window,
        config.receive_window,
        config.stream_receive_window,
    );

    let mut conn = Self {
        server_config,
        config,
        rng,
        crypto,
        handshake_cid: loc_cid,
        rem_handshake_cid: rem_cid,
        local_ip,
        path,
        prev_path: None,
        state,
        side,
        zero_rtt_enabled: false,
        zero_rtt_crypto: None,
        key_phase: false,
        peer_params: TransportParameters::default(),
        orig_rem_cid: rem_cid,
        initial_dst_cid: init_cid,
        retry_src_cid: None,
        lost_packets: 0,
        events: VecDeque::new(),
        endpoint_events: VecDeque::new(),
        spin_enabled,
        spin: false,
        spaces,
        highest_space: SpaceId::Initial,
        prev_crypto: None,
        next_crypto: None,
        accepted_0rtt: false,
        permit_idle_reset: true,
        idle_timeout,
        timers: TimerTable::default(),
        authentication_failures: 0,
        path_response: None,
        close: false,
        pto_count: 0,
        in_flight: InFlight::new(),
        receiving_ecn: false,
        total_authed_packets: 0,
        app_limited: false,
        streams,
        rem_cids: CidQueue::new(rem_cid),
        local_cid_state,
        datagrams: DatagramState::default(),
        stats: ConnectionStats::default(),
        version,
    };

    // The client speaks first: queue its initial handshake data and try to set up 0-RTT.
    if side.is_client() {
        conn.write_crypto();
        conn.init_0rtt();
    }
    conn
}
/// Returns the instant at which the earliest armed timer fires, if any timer is armed.
///
/// Callers should invoke `handle_timeout` once that instant has passed.
#[must_use]
pub fn poll_timeout(&mut self) -> Option<Instant> {
    let timers = &mut self.timers;
    timers.next_timeout()
}
/// Returns the next application-facing event, if any.
///
/// Stream events are drained before queued connection-level events.
#[must_use]
pub fn poll(&mut self) -> Option<Event> {
    self.streams
        .poll()
        .map(Event::Stream)
        .or_else(|| self.events.pop_front())
}
#[must_use]
pub fn poll_endpoint_events(&mut self) -> Option<EndpointEvent> {
self.endpoint_events.pop_front().map(EndpointEvent)
}
/// Returns a handle for connection-wide stream operations.
#[must_use]
pub fn streams(&mut self) -> Streams<'_> {
    let conn_state = &self.state;
    let state = &mut self.streams;
    Streams { state, conn_state }
}
/// Returns a handle to the receiving half of stream `id`.
///
/// # Panics
///
/// Panics if `id` is a unidirectional stream initiated by this endpoint, since such a stream
/// has no receiving half here.
#[must_use]
pub fn recv_stream(&mut self, id: StreamId) -> RecvStream<'_> {
    assert!(id.dir() == Dir::Bi || id.initiator() != self.side);
    let pending = &mut self.spaces[SpaceId::Data].pending;
    let state = &mut self.streams;
    RecvStream { id, state, pending }
}
/// Returns a handle to the sending half of stream `id`.
///
/// # Panics
///
/// Panics if `id` is a unidirectional stream initiated by the peer, since such a stream has no
/// sending half here.
#[must_use]
pub fn send_stream(&mut self, id: StreamId) -> SendStream<'_> {
    assert!(id.dir() == Dir::Bi || id.initiator() == self.side);
    let conn_state = &self.state;
    let pending = &mut self.spaces[SpaceId::Data].pending;
    let state = &mut self.streams;
    SendStream {
        id,
        state,
        pending,
        conn_state,
    }
}
/// Returns the next UDP payload to send, if any.
///
/// A pending PATH_CHALLENGE for the previous path is sent first, alone in its own datagram.
/// Otherwise packets for the Initial, Handshake and Data spaces are built into one buffer.
/// Long-header packets may be coalesced into the same datagram, but a short-header packet ends
/// coalescing. Building is limited by anti-amplification, the congestion window and pacing.
/// While closing, only a CONNECTION_CLOSE in the highest space is emitted.
/// Returns `None` when there is nothing to send or when packet construction fails.
#[must_use]
pub fn poll_transmit(&mut self, now: Instant) -> Option<Transmit> {
// At most this many datagrams are produced per call; `segment_size` is only set when more than
// one datagram ends up in the buffer.
const MAX_DATAGRAMS: usize = 1;
let mut num_datagrams = 0;
// A PATH_CHALLENGE owed to the previous path pre-empts all other output and is returned alone,
// addressed to that path's remote.
if let Some(ref mut prev_path) = self.prev_path {
if prev_path.challenge_pending {
prev_path.challenge_pending = false;
let token = prev_path
.challenge
.expect("previous path challenge pending without token");
let destination = prev_path.remote;
debug_assert_eq!(
self.highest_space,
SpaceId::Data,
"PATH_CHALLENGE queued without 1-RTT keys"
);
let mut buf = Vec::with_capacity(self.path.mtu as usize);
let buf_capacity = self.path.mtu as usize;
let mut builder = PacketBuilder::new(
now,
SpaceId::Data,
&mut buf,
buf_capacity,
0,
false,
self,
self.version,
)?;
trace!("validating previous path with PATH_CHALLENGE {:08x}", token);
buf.write(frame::Type::PATH_CHALLENGE);
buf.write(token);
self.stats.frame_tx.path_challenge += 1;
// Presumably to satisfy RFC 9000's requirement that datagrams carrying PATH_CHALLENGE be
// expanded to the minimum datagram size.
builder.pad_to(MIN_INITIAL_SIZE);
builder.finish(self, &mut buf);
self.stats.udp_tx.datagrams += 1;
self.stats.udp_tx.transmits += 1;
self.stats.udp_tx.bytes += buf.len() as u64;
return Some(Transmit {
destination,
contents: buf,
ecn: None,
segment_size: None,
src_ip: self.local_ip,
});
}
}
// Give every space a chance to queue frames for any owed loss probes (see `maybe_queue_probe`).
for space in SpaceId::iter() {
self.spaces[space].maybe_queue_probe();
}
// `close` is true when a CONNECTION_CLOSE must be (re)sent; drained or already-announced closed
// connections send nothing.
let close = match self.state {
State::Drained => {
self.app_limited = true;
return None;
}
State::Draining | State::Closed(_) => {
if !self.close {
self.app_limited = true;
return None;
}
true
}
_ => false,
};
let mut buf = Vec::new();
let mut buf_capacity = 0;
let mut coalesce = true;
let mut builder: Option<PacketBuilder> = None;
let mut sent_frames = None;
let mut pad_datagram = false;
let mut congestion_blocked = false;
let mut space_idx = 0;
let spaces = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data];
// `space_idx` is only advanced explicitly, so after building a packet the same space is visited
// again until it has nothing left to send; this is why a plain `for` over spaces is not used.
while space_idx < spaces.len() {
let space_id = spaces[space_idx];
// When closing, only the highest space gets a (CONNECTION_CLOSE) packet.
if close && space_id != self.highest_space {
space_idx += 1;
continue;
}
if !self.space_can_send(space_id) && !close {
space_idx += 1;
continue;
}
let mut ack_eliciting =
!self.spaces[space_id].pending.is_empty() || self.spaces[space_id].ping_pending;
if space_id == SpaceId::Data {
ack_eliciting |= self.can_send_1rtt();
}
// The unfinished packet may still grow to `min_size` and receives `tag_len` more bytes when
// finished (presumably the AEAD tag), so `buf.len()` alone underestimates the used space.
let buf_end = if let Some(builder) = &builder {
buf.len().max(builder.min_size) + builder.tag_len
} else {
buf.len()
};
// Start a new datagram when coalescing is off or the current one lacks room for a packet.
if !coalesce || buf_capacity - buf_end < MIN_PACKET_SPACE {
if buf_capacity >= self.path.mtu as usize * MAX_DATAGRAMS {
break;
}
if self
.path
.anti_amplification_blocked(self.path.mtu as u64 * (num_datagrams + 1) as u64)
{
trace!("blocked by anti-amplification");
break;
}
// Congestion and pacing checks apply only to ack-eliciting packets that are not loss probes.
if ack_eliciting && self.spaces[space_id].loss_probes == 0 {
// Account for the unfinished packet as if it were padded out to the full datagram.
let untracked_bytes = if let Some(builder) = &builder {
buf_capacity - builder.partial_encode.start
} else {
0
} as u64;
debug_assert!(untracked_bytes <= self.path.mtu as u64);
let bytes_to_send = u64::from(self.path.mtu) + untracked_bytes;
if self.in_flight.bytes + bytes_to_send >= self.path.congestion.window() {
// `continue` rather than `break`, so later spaces (e.g. with loss probes) still get a
// chance to send.
space_idx += 1;
congestion_blocked = true;
continue;
}
let smoothed_rtt = self.path.rtt.get();
if let Some(delay) = self.path.pacing.delay(
smoothed_rtt,
bytes_to_send,
self.path.mtu,
self.path.congestion.window(),
now,
) {
self.timers.set(Timer::Pacing, delay);
congestion_blocked = true;
break;
}
}
// Finish the previous datagram padded to a full MTU, so all but the last datagram have size
// `mtu`, matching the `segment_size` reported below.
if let Some(mut builder) = builder.take() {
builder.pad_to(self.path.mtu);
builder.finish_and_track(now, self, sent_frames.take(), &mut buf);
debug_assert_eq!(buf.len(), buf_capacity, "Packet must be padded");
}
buf_capacity += self.path.mtu as usize;
// NOTE(review): `Vec::reserve(n)` guarantees `capacity >= len + n`, not `capacity + n`; the
// intended bound is stated more directly as `buf_capacity - buf.len()`.
if buf.capacity() < buf_capacity {
buf.reserve(buf_capacity - buf.capacity());
}
num_datagrams += 1;
coalesce = true;
pad_datagram = false;
} else {
// Coalesce into the current datagram: finish the previous packet without extra padding.
if let Some(builder) = builder.take() {
builder.finish_and_track(now, self, sent_frames.take(), &mut buf);
}
}
debug_assert!(buf_capacity - buf.len() >= MIN_PACKET_SPACE);
// From here on a packet will definitely be written in `space_id`.
// A client discards its Initial keys once it sends its first Handshake packet.
if self.spaces[SpaceId::Initial].crypto.is_some()
&& space_id == SpaceId::Handshake
&& self.side.is_client()
{
self.discard_space(now, SpaceId::Initial);
}
if let Some(ref mut prev) = self.prev_crypto {
prev.update_unacked = false;
}
debug_assert!(
builder.is_none() && sent_frames.is_none(),
"Previous packet must have been finished"
);
// `?` makes the whole call return `None` if a packet cannot be started.
let builder = builder.get_or_insert(PacketBuilder::new(
now,
space_id,
&mut buf,
buf_capacity,
(num_datagrams - 1) * (self.path.mtu as usize),
ack_eliciting,
self,
self.version,
)?);
// A short-header packet must be last in its datagram.
coalesce = coalesce && !builder.short_header;
// Datagrams containing Initial packets are padded: always for clients, and for servers when
// ack-eliciting (cf. RFC 9000 §14.1).
pad_datagram |=
space_id == SpaceId::Initial && (self.side.is_client() || ack_eliciting);
if close {
trace!("sending CONNECTION_CLOSE");
match self.state {
// Outside the Data space an application close is replaced by a reason-less
// transport-level APPLICATION_ERROR.
State::Closed(state::Closed { ref reason }) => {
if space_id == SpaceId::Data {
reason.encode(&mut buf, builder.max_size)
} else {
frame::ConnectionClose {
error_code: TransportErrorCode::APPLICATION_ERROR,
frame_type: None,
reason: Bytes::new(),
}
.encode(&mut buf, builder.max_size)
}
}
State::Draining => frame::ConnectionClose {
error_code: TransportErrorCode::NO_ERROR,
frame_type: None,
reason: Bytes::new(),
}
.encode(&mut buf, builder.max_size),
_ => unreachable!(
"tried to make a close packet when the connection wasn't closed"
),
}
// The close frame has been written; don't send another until `close` is set again.
self.close = false;
break;
}
let sent = self.populate_packet(space_id, &mut buf, buf_capacity - builder.tag_len);
pad_datagram |= sent.requires_padding;
// `permit_ack_only` remains set only while packets in this space carry no ACK frames.
self.spaces[space_id].permit_ack_only &= sent.acks.is_empty();
sent_frames = Some(sent);
}
// Finish the last packet, padding the datagram only if something in it required it.
if let Some(mut builder) = builder {
if pad_datagram {
builder.pad_to(MIN_INITIAL_SIZE);
}
builder.finish_and_track(now, self, sent_frames, &mut buf);
}
// Application-limited: nothing to send, and not because of congestion control or pacing.
self.app_limited = buf.is_empty() && !congestion_blocked;
if buf.is_empty() {
return None;
}
trace!("sending {} bytes in {} datagrams", buf.len(), num_datagrams);
self.path.total_sent = self.path.total_sent.saturating_add(buf.len() as u64);
self.stats.udp_tx.datagrams += num_datagrams as u64;
self.stats.udp_tx.bytes += buf.len() as u64;
self.stats.udp_tx.transmits += 1;
Some(Transmit {
destination: self.path.remote,
contents: buf,
ecn: if self.path.sending_ecn {
Some(EcnCodepoint::ECT0)
} else {
None
},
segment_size: match num_datagrams {
1 => None,
_ => Some(self.path.mtu as usize),
},
src_ip: self.local_ip,
})
}
/// Whether a packet could be sent in `space_id` right now.
///
/// A space with keys and something to send qualifies. The Data space additionally qualifies
/// when it has keys and `can_send_1rtt()` holds. A client may also use its 0-RTT keys for the
/// Data space.
fn space_can_send(&self, space_id: SpaceId) -> bool {
    let space = &self.spaces[space_id];
    let has_keys = space.crypto.is_some();
    if has_keys && space.can_send() {
        return true;
    }
    if space_id != SpaceId::Data {
        return false;
    }
    if has_keys && self.can_send_1rtt() {
        return true;
    }
    // Without 1-RTT keys, only a client holding 0-RTT keys can send application data.
    self.zero_rtt_crypto.is_some()
        && self.side.is_client()
        && (space.can_send() || self.can_send_1rtt())
}
/// Feeds an event produced by the endpoint into this connection.
///
/// * `Datagram` carries a received UDP datagram, split into its first packet and any coalesced
///   remainder. Datagrams from an unknown remote are dropped unless this is a server whose
///   config permits migration.
/// * `NewIdentifiers` carries freshly issued local connection IDs. They are recorded and
///   queued for advertisement to the peer, and CID retirement is rescheduled if its timer is
///   unset or already due.
pub fn handle_event(&mut self, event: ConnectionEvent) {
    use self::ConnectionEventInner::*;
    match event.0 {
        Datagram {
            now,
            remote,
            ecn,
            first_decode,
            remaining,
        } => {
            if remote != self.path.remote
                && self.server_config.as_ref().map_or(true, |x| !x.migration)
            {
                trace!("discarding packet from unrecognized peer {}", remote);
                return;
            }
            let was_anti_amplification_blocked =
                self.path.anti_amplification_blocked(self.path.mtu as u64);
            let first_len = first_decode.len() as u64;
            self.stats.udp_rx.datagrams += 1;
            self.stats.udp_rx.bytes += first_len;
            self.handle_decode(now, remote, ecn, first_decode);
            // Credit the received bytes only after decoding. Processing the packet can replace
            // `self.path` (migration keeps the old one in `prev_path`), and the bytes must
            // count towards the anti-amplification budget of the path that is current
            // afterwards, not the path being abandoned.
            self.path.total_recvd = self.path.total_recvd.saturating_add(first_len);
            if let Some(data) = remaining {
                self.stats.udp_rx.bytes += data.len() as u64;
                self.handle_coalesced(now, remote, ecn, data);
            }
            if was_anti_amplification_blocked {
                // Arming the loss detection timer may have been skipped while the path was
                // amplification-limited; retry now that more data has been received.
                self.set_loss_detection_timer(now);
            }
        }
        NewIdentifiers(ids, now) => {
            self.local_cid_state.new_cids(&ids, now);
            ids.into_iter().rev().for_each(|frame| {
                self.spaces[SpaceId::Data].pending.new_cids.push(frame);
            });
            if self
                .timers
                .get(Timer::PushNewCid)
                .map_or(true, |x| x <= now)
            {
                self.reset_cid_retirement();
            }
        }
    }
}
/// Runs the handler of every timer that has expired as of `now`.
///
/// Each expired timer is stopped before its handler runs, so a handler may re-arm it.
pub fn handle_timeout(&mut self, now: Instant) {
    for timer in Timer::VALUES.iter().copied() {
        if self.timers.is_expired(timer, now) {
            self.timers.stop(timer);
            trace!(timer = ?timer, "timeout");
            match timer {
                Timer::Close => {
                    // Closing period over: the connection is fully drained.
                    self.state = State::Drained;
                    self.endpoint_events.push_back(EndpointEventInner::Drained);
                }
                Timer::Idle => {
                    self.kill(ConnectionError::TimedOut);
                }
                Timer::KeepAlive => {
                    trace!("sending keep-alive");
                    self.ping();
                }
                Timer::LossDetection => {
                    self.on_loss_detection_timeout(now);
                }
                Timer::KeyDiscard => {
                    // Old 0-RTT and previous-phase 1-RTT keys are no longer needed.
                    self.zero_rtt_crypto = None;
                    self.prev_crypto = None;
                }
                Timer::PathValidation => {
                    debug!("path validation failed");
                    // Fall back to the previous path, if there was one.
                    if let Some(prev) = self.prev_path.take() {
                        self.path = prev;
                    }
                    self.path.challenge = None;
                    self.path.challenge_pending = false;
                }
                Timer::Pacing => trace!("pacing timer expired"),
                Timer::PushNewCid => {
                    let num_new_cid = self.local_cid_state.on_cid_timeout().into();
                    if !self.state.is_closed() {
                        trace!(
                            "push a new cid to peer RETIRE_PRIOR_TO field {}",
                            self.local_cid_state.retire_prior_to()
                        );
                        self.endpoint_events
                            .push_back(EndpointEventInner::NeedIdentifiers(now, num_new_cid));
                    }
                }
            }
        }
    }
}
/// Closes the connection with an application-level error code and reason.
///
/// Does nothing if the connection is already closed.
pub fn close(&mut self, now: Instant, error_code: VarInt, reason: Bytes) {
    let frame = frame::ApplicationClose { error_code, reason };
    self.close_inner(now, Close::Application(frame));
}
/// Enters the closed state with `reason`, arming the close timer and scheduling a
/// CONNECTION_CLOSE. A no-op if the connection is already closed.
fn close_inner(&mut self, now: Instant, reason: Close) {
    if self.state.is_closed() {
        return;
    }
    self.close_common();
    self.set_close_timer(now);
    // Tells `poll_transmit` to emit a CONNECTION_CLOSE.
    self.close = true;
    self.state = State::Closed(state::Closed { reason });
}
/// Returns a handle for sending and receiving unreliable datagrams.
pub fn datagrams(&mut self) -> Datagrams<'_, S> {
    let conn = self;
    Datagrams { conn }
}
/// Returns a snapshot of the connection statistics, with the path's current RTT estimate and
/// congestion window filled in.
pub fn stats(&self) -> ConnectionStats {
    let mut snapshot = self.stats;
    snapshot.path.rtt = self.path.rtt.get();
    snapshot.path.cwnd = self.path.congestion.window();
    snapshot
}
/// Requests a PING in the highest available packet number space.
pub fn ping(&mut self) {
    let space = self.highest_space;
    self.spaces[space].ping_pending = true;
}
/// Forces a 1-RTT key update; intended for tests only.
#[doc(hidden)]
pub fn initiate_key_update(&mut self) {
    let end_packet = None;
    self.update_keys(end_packet, false);
}
pub fn crypto_session(&self) -> &S {
&self.crypto
}
pub fn is_handshaking(&self) -> bool {
self.state.is_handshake()
}
pub fn is_closed(&self) -> bool {
self.state.is_closed()
}
pub fn is_drained(&self) -> bool {
self.state.is_drained()
}
pub fn accepted_0rtt(&self) -> bool {
self.accepted_0rtt
}
pub fn has_0rtt(&self) -> bool {
self.zero_rtt_enabled
}
/// Whether the Data space's queue of pending frames is non-empty.
pub fn has_pending_retransmits(&self) -> bool {
    let pending = &self.spaces[SpaceId::Data].pending;
    !pending.is_empty()
}
pub fn side(&self) -> Side {
self.side
}
/// The peer address of the current path.
pub fn remote_address(&self) -> SocketAddr {
    let path = &self.path;
    path.remote
}
/// The local IP used as the source address of outgoing transmits, if one was specified.
pub fn local_ip(&self) -> Option<IpAddr> {
    self.local_ip.as_ref().copied()
}
/// The current path's RTT estimate.
pub fn rtt(&self) -> Duration {
    let rtt = &self.path.rtt;
    rtt.get()
}
/// Processes an ACK frame received in `space`.
///
/// Records a new largest-acknowledged packet, hands each newly acknowledged packet to
/// `on_packet_acked`, and takes an RTT sample when the largest acknowledged packet is new and
/// something ack-eliciting was acknowledged. It then runs loss detection, resets the PTO count
/// once the peer has validated our address, processes ECN feedback, and re-arms the loss
/// detection timer.
///
/// # Errors
///
/// Returns `PROTOCOL_VIOLATION` if the ACK covers a packet number that was never sent.
fn on_ack_received(
    &mut self,
    now: Instant,
    space: SpaceId,
    ack: frame::Ack,
) -> Result<(), TransportError> {
    if ack.largest >= self.spaces[space].next_packet_number {
        return Err(TransportError::PROTOCOL_VIOLATION("unsent packet acked"));
    }
    let new_largest = {
        let space = &mut self.spaces[space];
        if space
            .largest_acked_packet
            .map_or(true, |pn| ack.largest > pn)
        {
            space.largest_acked_packet = Some(ack.largest);
            // If the packet is not tracked (e.g. a misbehaving peer), keep the old send time.
            if let Some(info) = space.sent_packets.get(&ack.largest) {
                space.largest_acked_packet_sent = info.time_sent;
            }
            true
        } else {
            false
        }
    };
    // Collect only packet numbers still present in `sent_packets`.
    let newly_acked = ack
        .iter()
        .flat_map(|range| {
            self.spaces[space]
                .sent_packets
                .range(range)
                .map(|(&n, _)| n)
        })
        .collect::<Vec<_>>();
    if newly_acked.is_empty() {
        return Ok(());
    }
    let mut ack_eliciting_acked = false;
    for &packet in &newly_acked {
        if let Some(info) = self.spaces[space].sent_packets.remove(&packet) {
            // ACK ranges carried by the acknowledged packet no longer need to be repeated.
            self.spaces[space].pending_acks.subtract(&info.acks);
            ack_eliciting_acked |= info.ack_eliciting;
            self.on_packet_acked(now, space, info);
        }
    }
    if new_largest && ack_eliciting_acked {
        let ack_delay = if space != SpaceId::Data {
            Duration::from_micros(0)
        } else {
            // The peer encodes the delay divided by 2^ack_delay_exponent. A plain `<<` silently
            // discards high bits for absurd peer-supplied values and yields a bogus small delay,
            // so saturate instead; the result is clamped to `max_ack_delay` anyway.
            let scale = 1u64 << self.peer_params.ack_delay_exponent.0;
            cmp::min(
                self.max_ack_delay(),
                Duration::from_micros(ack.delay.saturating_mul(scale)),
            )
        };
        let rtt = instant_saturating_sub(now, self.spaces[space].largest_acked_packet_sent);
        self.path.rtt.update(ack_delay, rtt);
    }
    self.detect_lost_packets(now, space);
    if self.peer_completed_address_validation() {
        self.pto_count = 0;
    }
    if self.path.sending_ecn {
        if let Some(ecn) = ack.ecn {
            // ECN counts are only evaluated for ACKs that advance the largest acknowledged
            // packet.
            if new_largest {
                let sent = self.spaces[space].largest_acked_packet_sent;
                self.process_ecn(now, space, newly_acked.len() as u64, ecn, sent);
            }
        } else {
            debug!("ECN not acknowledged by peer");
            self.path.sending_ecn = false;
        }
    }
    self.set_loss_detection_timer(now);
    Ok(())
}
/// Applies ECN feedback from an ACK in `space`.
///
/// If ECN validation fails, ECN is disabled for the path and the space's ECN feedback is reset.
/// If `detect_ecn` reports congestion, a congestion event keyed to `largest_sent_time` is
/// signalled.
fn process_ecn(
    &mut self,
    now: Instant,
    space: SpaceId,
    newly_acked: u64,
    ecn: frame::EcnCounts,
    largest_sent_time: Instant,
) {
    let verdict = self.spaces[space].detect_ecn(newly_acked, ecn);
    match verdict {
        Ok(true) => {
            self.stats.path.congestion_events += 1;
            self.path
                .congestion
                .on_congestion_event(now, largest_sent_time, false);
        }
        Ok(false) => {}
        Err(e) => {
            debug!("halting ECN due to verification failure: {}", e);
            self.path.sending_ecn = false;
            self.spaces[space].ecn_feedback = frame::EcnCounts::ZERO;
        }
    }
}
/// Updates state for a single acknowledged packet.
///
/// The packet is removed from flight accounting. Ack-eliciting packets feed the congestion
/// controller unless a path challenge is outstanding. Acknowledged stream resets and stream
/// frames are reported to the stream state.
fn on_packet_acked(&mut self, now: Instant, space: SpaceId, info: SentPacket) {
    self.remove_in_flight(space, &info);

    let feeds_congestion = info.ack_eliciting && self.path.challenge.is_none();
    if feeds_congestion {
        self.path
            .congestion
            .on_ack(now, info.time_sent, info.size.into(), self.app_limited);
    }

    if let Some(retransmits) = info.retransmits.get() {
        retransmits.reset_stream.iter().for_each(|(id, _)| {
            self.streams.reset_acked(*id);
        });
    }

    for frame in info.stream_frames {
        self.streams.received_ack_of(frame);
    }
}
/// Arms the `KeyDiscard` timer three PTOs after the relevant start time.
///
/// While 0-RTT keys exist, the start time is `now`. Otherwise it is the instant stored in the
/// previous key phase's `end_packet`.
///
/// # Panics
///
/// Panics if there are no 0-RTT keys and either no previous keys exist or their update has not
/// been acknowledged yet.
fn set_key_discard_timer(&mut self, now: Instant) {
    let start = match self.zero_rtt_crypto {
        Some(_) => now,
        None => {
            let prev = self.prev_crypto.as_ref().expect("no previous keys");
            prev.end_packet
                .as_ref()
                .expect("update not acknowledged yet")
                .1
        }
    };
    let discard_at = start + self.pto() * 3;
    self.timers.set(Timer::KeyDiscard, discard_at);
}
/// Handles expiry of the loss detection timer.
///
/// A pending time-threshold loss deadline is handled first. Otherwise this is a PTO: two loss
/// probes are queued in the space whose PTO fired. If no PTO time can be computed (a client
/// that has not reached 1-RTT), one probe is queued in the highest space. Either way the PTO
/// backoff count grows and the timer is re-armed.
fn on_loss_detection_timeout(&mut self, now: Instant) {
    if let Some((_, pn_space)) = self.loss_time_and_space() {
        self.detect_lost_packets(now, pn_space);
        self.set_loss_detection_timer(now);
        return;
    }
    let (space, count) = match self.pto_time_and_space(now) {
        Some((_, pto_space)) => (pto_space, 2),
        None => {
            debug_assert!(self.side.is_client() && self.highest_space <= SpaceId::Handshake);
            (self.highest_space, 1)
        }
    };
    trace!(
        in_flight = self.in_flight.bytes,
        count = self.pto_count,
        ?space,
        "PTO fired"
    );
    let probes = &mut self.spaces[space].loss_probes;
    *probes = probes.saturating_add(count);
    self.pto_count = self.pto_count.saturating_add(1);
    self.set_loss_detection_timer(now);
}
/// Declares packets in `pn_space` lost and reacts to the loss.
///
/// Packets numbered below the largest acknowledged one are lost if either of these holds:
/// * they were sent at least `loss_delay` ago;
/// * they trail it by at least `packet_threshold` packet numbers.
///
/// Every other such packet contributes to the space's next `loss_time`. Lost packets are
/// removed from flight and their stream frames and other retransmittable frames are requeued.
/// If this reduced the bytes in flight, the congestion controller is notified, including
/// whether persistent congestion was detected.
///
/// # Panics
///
/// Panics if no packet in `pn_space` has been acknowledged yet.
fn detect_lost_packets(&mut self, now: Instant, pn_space: SpaceId) {
    let mut lost_packets = Vec::<u64>::new();
    let rtt = self.path.rtt.conservative();
    let loss_delay = cmp::max(rtt.mul_f32(self.config.time_threshold), TIMER_GRANULARITY);
    // Packets sent at or before this instant are deemed lost. `Instant - Duration` panics if the
    // result would precede the clock's origin, which can happen shortly after boot on some
    // platforms. In that case no packet can be that old, so time-based loss is simply skipped.
    let lost_send_time = now.checked_sub(loss_delay);
    let largest_acked_packet = self.spaces[pn_space].largest_acked_packet.unwrap();
    let packet_threshold = self.config.packet_threshold as u64;
    let space = &mut self.spaces[pn_space];
    space.loss_time = None;
    for (&packet, info) in space.sent_packets.range(0..largest_acked_packet) {
        let lost_by_time = lost_send_time.map_or(false, |t| info.time_sent <= t);
        if lost_by_time || largest_acked_packet >= packet + packet_threshold {
            lost_packets.push(packet);
        } else {
            // Not lost yet: remember the earliest time at which it would be.
            let next_loss_time = info.time_sent + loss_delay;
            space.loss_time = Some(
                space
                    .loss_time
                    .map_or(next_loss_time, |x| cmp::min(x, next_loss_time)),
            );
        }
    }
    if let Some(largest_lost) = lost_packets.last().cloned() {
        let old_bytes_in_flight = self.in_flight.bytes;
        let largest_lost_sent = self.spaces[pn_space].sent_packets[&largest_lost].time_sent;
        self.lost_packets += lost_packets.len() as u64;
        trace!("packets lost: {:?}", lost_packets);
        for packet in &lost_packets {
            // Every entry was collected from `sent_packets` just above.
            let info = self.spaces[pn_space].sent_packets.remove(packet).unwrap();
            self.remove_in_flight(pn_space, &info);
            for frame in info.stream_frames {
                self.streams.retransmit(frame);
            }
            self.spaces[pn_space].pending |= info.retransmits;
        }
        // Losing only ack-only packets (which do not count towards bytes in flight) does not
        // incur a congestion penalty.
        let lost_ack_eliciting = old_bytes_in_flight != self.in_flight.bytes;
        let congestion_period = self.pto() * self.config.persistent_congestion_threshold;
        // Persistent congestion: nothing was acknowledged within `congestion_period` before
        // the newest lost packet. If that window starts before the clock's origin, the
        // condition cannot hold (and plain subtraction would panic).
        let in_persistent_congestion = largest_lost_sent
            .checked_sub(congestion_period)
            .map_or(false, |window_start| {
                self.spaces[pn_space].largest_acked_packet_sent < window_start
            });
        if lost_ack_eliciting {
            self.stats.path.congestion_events += 1;
            self.path.congestion.on_congestion_event(
                now,
                largest_lost_sent,
                in_persistent_congestion,
            );
        }
    }
}
/// Returns the earliest armed time-threshold loss deadline and its space.
///
/// Ties are resolved in favour of the earlier space in `SpaceId::iter()` order.
fn loss_time_and_space(&self) -> Option<(Instant, SpaceId)> {
    let mut earliest: Option<(Instant, SpaceId)> = None;
    for id in SpaceId::iter() {
        if let Some(time) = self.spaces[id].loss_time {
            let is_earlier = match earliest {
                Some((best, _)) => time < best,
                None => true,
            };
            if is_earlier {
                earliest = Some((time, id));
            }
        }
    }
    earliest
}
fn pto_time_and_space(&self, now: Instant) -> Option<(Instant, SpaceId)> {
let backoff = 2u32.pow(self.pto_count.min(MAX_BACKOFF_EXPONENT));
let mut duration = self.path.rtt.pto_base() * backoff;
if self.in_flight.is_empty() {
debug_assert!(!self.peer_completed_address_validation());
let space = match self.highest_space {
SpaceId::Handshake => SpaceId::Handshake,
_ => SpaceId::Initial,
};
return Some((now + duration, space));
}
let mut result = None;
for space in SpaceId::iter() {
if self.spaces[space].in_flight == 0 {
continue;
}
if space == SpaceId::Data {
if self.is_handshaking() {
return result;
}
duration += self.max_ack_delay() * backoff;
}
let last_ack_eliciting = match self.spaces[space].time_of_last_ack_eliciting_packet {
Some(time) => time,
None => continue,
};
let pto = last_ack_eliciting + duration;
if result.map_or(true, |(earliest_pto, _)| pto < earliest_pto) {
result = Some((pto, space));
}
}
result
}
/// Whether the peer is known to have validated our address.
///
/// This always holds for servers and closed connections. A client concludes it once a
/// Handshake or Data packet has been acknowledged, or once it holds Data keys and its
/// Handshake keys are gone.
#[allow(clippy::suspicious_operation_groupings)]
fn peer_completed_address_validation(&self) -> bool {
    if self.side.is_server() || self.state.is_closed() {
        return true;
    }
    let handshake = &self.spaces[SpaceId::Handshake];
    let data = &self.spaces[SpaceId::Data];
    handshake.largest_acked_packet.is_some()
        || data.largest_acked_packet.is_some()
        || (data.crypto.is_some() && handshake.crypto.is_none())
}
/// Arms or stops the loss detection timer.
///
/// A pending time-threshold loss deadline always wins. Otherwise the timer is stopped when the
/// path is amplification-blocked for a full MTU, or when nothing ack-eliciting is in flight
/// and the peer has validated our address. In all other cases it is set to the PTO deadline,
/// if one exists.
fn set_loss_detection_timer(&mut self, now: Instant) {
    if let Some((loss_time, _)) = self.loss_time_and_space() {
        self.timers.set(Timer::LossDetection, loss_time);
        return;
    }
    let suppressed = self.path.anti_amplification_blocked(self.path.mtu.into())
        || (self.in_flight.ack_eliciting == 0 && self.peer_completed_address_validation());
    let deadline = if suppressed {
        None
    } else {
        self.pto_time_and_space(now).map(|(timeout, _)| timeout)
    };
    match deadline {
        Some(timeout) => {
            self.timers.set(Timer::LossDetection, timeout);
        }
        None => {
            self.timers.stop(Timer::LossDetection);
        }
    }
}
/// Un-backed-off probe timeout: the RTT-derived base plus the maximum ACK delay.
fn pto(&self) -> Duration {
    let base = self.path.rtt.pto_base();
    base + self.max_ack_delay()
}
/// Bookkeeping for a packet that passed authentication.
///
/// Refreshes the keep-alive and idle timers and records ECN marks. When a packet number is
/// known, the server-side key transitions are applied: Initial keys are dropped on the first
/// Handshake packet, and discarding of 0-RTT keys is scheduled on the first 1-RTT packet. The
/// packet number is then queued for acknowledgement, and the spin bit is updated from the
/// highest packet number seen.
fn on_packet_authenticated(
    &mut self,
    now: Instant,
    space_id: SpaceId,
    ecn: Option<EcnCodepoint>,
    packet: Option<u64>,
    spin: bool,
    is_1rtt: bool,
) {
    self.total_authed_packets += 1;
    self.reset_keep_alive(now);
    self.reset_idle_timeout(now);
    self.permit_idle_reset = true;
    if let Some(codepoint) = ecn {
        self.receiving_ecn = true;
        self.spaces[space_id].ecn_counters += codepoint;
    }

    let packet = match packet {
        Some(number) => number,
        None => return,
    };
    trace!("authenticated");

    if self.side.is_server() {
        if space_id == SpaceId::Handshake && self.spaces[SpaceId::Initial].crypto.is_some() {
            self.discard_space(now, SpaceId::Initial);
        }
        if is_1rtt && self.zero_rtt_crypto.is_some() {
            self.set_key_discard_timer(now);
        }
    }

    let flip_spin = self.side.is_client();
    let space = &mut self.spaces[space_id];
    space.pending_acks.insert_one(packet);
    // Bound the number of ACK ranges tracked by dropping the oldest.
    if space.pending_acks.len() > MAX_ACK_BLOCKS {
        space.pending_acks.pop_min();
    }
    if packet >= space.rx_packet {
        space.rx_packet = packet;
        // The client inverts the observed spin value; the server reflects it.
        self.spin = flip_spin ^ spin;
    }
}
/// Restarts the idle timer, if idle timeouts are enabled.
///
/// The timeout is never shorter than three PTOs. A closed connection instead has its idle
/// timer stopped.
fn reset_idle_timeout(&mut self, now: Instant) {
    if let Some(timeout) = self.idle_timeout {
        if self.state.is_closed() {
            self.timers.stop(Timer::Idle);
        } else {
            let dt = cmp::max(timeout, 3 * self.pto());
            self.timers.set(Timer::Idle, now + dt);
        }
    }
}
/// Re-arms the keep-alive timer when a keep-alive interval is configured and the connection is
/// established.
fn reset_keep_alive(&mut self, now: Instant) {
    if !self.state.is_established() {
        return;
    }
    if let Some(interval) = self.config.keep_alive_interval {
        self.timers.set(Timer::KeepAlive, now + interval);
    }
}
/// Schedules the `PushNewCid` timer for the next local CID expiry, if there is one.
fn reset_cid_retirement(&mut self) {
    let next = self.local_cid_state.next_timeout();
    if let Some(deadline) = next {
        self.timers.set(Timer::PushNewCid, deadline);
    }
}
/// Processes the already-decrypted first Initial packet received by a server, followed by any
/// packets coalesced with it.
///
/// # Errors
///
/// Propagates any `ConnectionError` from processing the first packet.
pub(crate) fn handle_first_packet(
    &mut self,
    now: Instant,
    remote: SocketAddr,
    ecn: Option<EcnCodepoint>,
    packet_number: u64,
    packet: Packet,
    remaining: Option<BytesMut>,
) -> Result<(), ConnectionError> {
    let span = trace_span!("first recv");
    let _guard = span.enter();
    debug_assert!(self.side.is_server());

    // Received-byte accounting starts with this packet; coalesced data is added later.
    let packet_len = packet.header_data.len() + packet.payload.len();
    self.path.total_recvd = packet_len as u64;

    let number = Some(packet_number);
    self.on_packet_authenticated(now, SpaceId::Initial, ecn, number, false, false);
    self.process_decrypted_packet(now, remote, number, packet)?;

    if let Some(data) = remaining {
        self.handle_coalesced(now, remote, ecn, data);
    }
    Ok(())
}
/// Enables 0-RTT if the crypto session offers early keys.
///
/// On the client, the peer's cached transport parameters are adopted first. Connection-specific
/// values (CIDs, preferred address, reset token) are dropped, and the ACK-delay parameters are
/// reset to their defaults. If the cached parameters are malformed, 0-RTT is not enabled.
fn init_0rtt(&mut self) {
    let (header, packet) = match self.crypto.early_crypto() {
        Some(keys) => keys,
        None => return,
    };

    if self.side.is_client() {
        let params = match self.crypto.transport_parameters() {
            Ok(params) => {
                params.expect("crypto layer didn't supply transport parameters with ticket")
            }
            Err(e) => {
                error!("session ticket has malformed transport parameters: {}", e);
                return;
            }
        };
        let defaults = TransportParameters::default();
        self.set_peer_params(TransportParameters {
            initial_src_cid: None,
            original_dst_cid: None,
            preferred_address: None,
            retry_src_cid: None,
            stateless_reset_token: None,
            ack_delay_exponent: defaults.ack_delay_exponent,
            max_ack_delay: defaults.max_ack_delay,
            ..params
        });
    }

    trace!("0-RTT enabled");
    self.zero_rtt_enabled = true;
    self.zero_rtt_crypto = Some(ZeroRttCrypto { header, packet });
}
/// Ingests a CRYPTO frame received in `space` and feeds newly contiguous bytes to the TLS layer.
///
/// Only retransmitted (already-read) data is accepted from a space below the currently
/// expected one. Data reaching beyond `crypto_buffer_size` past the read offset is rejected.
/// `Event::HandshakeDataReady` is queued whenever the crypto layer reports it.
///
/// # Errors
///
/// Returns `PROTOCOL_VIOLATION` for new data at an outdated level, `CRYPTO_BUFFER_EXCEEDED` for
/// data that would overflow the buffer, and propagates errors from the crypto layer.
fn read_crypto(
    &mut self,
    space: SpaceId,
    crypto: &frame::Crypto,
    payload_len: usize,
) -> Result<(), TransportError> {
    let expected = match (self.state.is_handshake(), self.highest_space) {
        (false, _) => SpaceId::Data,
        (true, SpaceId::Initial) => SpaceId::Initial,
        (true, _) => SpaceId::Handshake,
    };
    debug_assert!(space <= expected, "received out-of-order CRYPTO data");

    let end = crypto.offset + crypto.data.len() as u64;
    let outdated_level = space < expected;
    if outdated_level && end > self.spaces[space].crypto_stream.bytes_read() {
        warn!(
            "received new {:?} CRYPTO data when expecting {:?}",
            space, expected
        );
        return Err(TransportError::PROTOCOL_VIOLATION(
            "new data at unexpected encryption level",
        ));
    }

    let buffer_limit = self.config.crypto_buffer_size as u64;
    let space = &mut self.spaces[space];
    if end > space.crypto_stream.bytes_read() + buffer_limit {
        return Err(TransportError::CRYPTO_BUFFER_EXCEEDED(""));
    }
    space
        .crypto_stream
        .insert(crypto.offset, crypto.data.clone(), payload_len);

    // Hand over every contiguous chunk now available.
    while let Some(chunk) = space.crypto_stream.read(usize::MAX, true) {
        trace!("consumed {} CRYPTO bytes", chunk.bytes.len());
        if self.crypto.read_handshake(&chunk.bytes)? {
            self.events.push_back(Event::HandshakeDataReady);
        }
    }
    Ok(())
}
/// Drains outgoing handshake data from the crypto session into CRYPTO frames.
///
/// When the session yields new keys, the next packet space is installed. The loop keeps going
/// until a round produces no data and no key change. The client's first Initial flight is also
/// remembered as its ClientHello.
fn write_crypto(&mut self) {
    loop {
        let space = self.highest_space;
        let mut outgoing = Vec::new();
        if let Some(crypto) = self.crypto.write_handshake(&mut outgoing) {
            let next = match space {
                SpaceId::Initial => SpaceId::Handshake,
                SpaceId::Handshake => SpaceId::Data,
                _ => unreachable!("got updated secrets during 1-RTT"),
            };
            self.upgrade_crypto(next, crypto);
        }

        if outgoing.is_empty() {
            // No data: stop, unless new keys were installed and may unlock more data.
            if space == self.highest_space {
                break;
            }
            continue;
        }

        let outgoing = Bytes::from(outgoing);
        let offset = self.spaces[space].crypto_offset;
        let is_client_hello = space == SpaceId::Initial && offset == 0 && self.side.is_client();
        if let State::Handshake(ref mut state) = self.state {
            if is_client_hello {
                state.client_hello = Some(outgoing.clone());
            }
        }

        let space_state = &mut self.spaces[space];
        space_state.crypto_offset += outgoing.len() as u64;
        trace!("wrote {} {:?} CRYPTO bytes", outgoing.len(), space);
        space_state.pending.crypto.push_back(frame::Crypto {
            offset,
            data: outgoing,
        });
    }
}
/// Installs `crypto` as the keys of `space`, which becomes the new highest space.
///
/// Entering the Data space also pre-computes the next 1-RTT keys. On the client, it also drops
/// the 0-RTT keys.
fn upgrade_crypto(&mut self, space: SpaceId, crypto: Keys<S>) {
    debug_assert!(
        self.spaces[space].crypto.is_none(),
        "already reached packet space {:?}",
        space
    );
    trace!("{:?} keys ready", space);

    let entering_data = space == SpaceId::Data;
    if entering_data {
        self.next_crypto = Some(self.crypto.next_1rtt_keys());
    }
    self.spaces[space].crypto = Some(crypto);

    debug_assert!(space as usize > self.highest_space as usize);
    self.highest_space = space;

    if entering_data && self.side.is_client() {
        self.zero_rtt_crypto = None;
    }
}
/// Drops the keys and loss-recovery state of the Initial or Handshake space.
///
/// All of the space's sent packets are removed from flight accounting, and the loss detection
/// timer is recomputed.
fn discard_space(&mut self, now: Instant, space_id: SpaceId) {
    debug_assert!(space_id != SpaceId::Data);
    trace!("discarding {:?} keys", space_id);

    let sent_packets = {
        let space = &mut self.spaces[space_id];
        space.crypto = None;
        space.time_of_last_ack_eliciting_packet = None;
        space.loss_time = None;
        mem::replace(&mut space.sent_packets, BTreeMap::new())
    };
    sent_packets.into_iter().for_each(|(_, packet)| {
        self.remove_in_flight(space_id, &packet);
    });

    self.set_loss_detection_timer(now)
}
/// Decodes and handles the packets coalesced in the remainder of a datagram.
///
/// All of `data` is counted as received up front. Processing stops at the first packet whose
/// header cannot be parsed.
fn handle_coalesced(
    &mut self,
    now: Instant,
    remote: SocketAddr,
    ecn: Option<EcnCodepoint>,
    data: BytesMut,
) {
    self.path.total_recvd = self.path.total_recvd.saturating_add(data.len() as u64);
    let mut next = Some(data);
    while let Some(buf) = next.take() {
        let decoded =
            PartialDecode::new(buf, self.local_cid_state.cid_len(), &[self.version]);
        let (partial_decode, rest) = match decoded {
            Ok(parts) => parts,
            Err(e) => {
                trace!("malformed header: {}", e);
                return;
            }
        };
        next = rest;
        self.handle_decode(now, remote, ecn, partial_decode);
    }
}
/// Removes header protection from a partially decoded packet and hands it to `handle_packet`.
///
/// Packets are dropped when no keys exist for their level (0-RTT or the packet's space), or when
/// decoding cannot be completed. Packets without a space are finished without header keys.
fn handle_decode(
    &mut self,
    now: Instant,
    remote: SocketAddr,
    ecn: Option<EcnCodepoint>,
    partial_decode: PartialDecode,
) {
    let header_crypto = if partial_decode.is_0rtt() {
        match self.zero_rtt_crypto {
            Some(ref crypto) => Some(&crypto.header),
            None => {
                debug!("dropping unexpected 0-RTT packet");
                return;
            }
        }
    } else {
        match partial_decode.space() {
            None => None,
            Some(space) => match self.spaces[space].crypto {
                Some(ref crypto) => Some(&crypto.header.remote),
                None => {
                    debug!(
                        "discarding unexpected {:?} packet ({} bytes)",
                        space,
                        partial_decode.len(),
                    );
                    return;
                }
            },
        }
    };

    let packet = match partial_decode.finish(header_crypto) {
        Ok(packet) => packet,
        Err(e) => {
            trace!("unable to complete packet decoding: {}", e);
            return;
        }
    };
    self.handle_packet(now, remote, ecn, packet);
}
/// Processes one fully decoded packet that arrived from `remote`.
///
/// Decrypts the payload, recognizes stateless resets, and drops duplicate packets as well as
/// short-header packets received while handshaking. The plaintext is then passed to
/// `process_decrypted_packet`. A `ConnectionError` from any of these steps is pushed as an event
/// and converted into the next `State`. The function then does the bookkeeping for newly entering
/// the closed or drained states.
fn handle_packet(
&mut self,
now: Instant,
remote: SocketAddr,
ecn: Option<EcnCodepoint>,
mut packet: Packet,
) {
trace!(
"got {:?} packet ({} bytes) from {} using id {}",
packet.header.space(),
packet.payload.len() + packet.header_data.len(),
remote,
packet.header.dst_cid(),
);
// During the handshake, only packets from the current path's remote address are accepted.
if self.is_handshaking() && remote != self.path.remote {
debug!("discarding packet with unexpected remote during handshake");
return;
}
// Snapshot the state so newly entered closed/drained states can be detected at the end.
let was_closed = self.state.is_closed();
let was_drained = self.state.is_drained();
// Compare the trailing bytes with the peer's stateless reset token now, because
// `decrypt_packet` decrypts `packet.payload` in place.
let stateless_reset = self
.peer_params
.stateless_reset_token
.map_or(false, |token| {
packet.payload.len() >= RESET_TOKEN_SIZE
&& packet.payload[packet.payload.len() - RESET_TOKEN_SIZE..] == token[..]
});
let result = match self.decrypt_packet(now, &mut packet) {
// The packet authenticated but violates the protocol.
Err(Some(e)) => {
warn!("illegal packet: {}", e);
Err(e.into())
}
// No packet number could be read, or decryption failed.
Err(None) => {
if stateless_reset {
debug!("got stateless reset");
Err(ConnectionError::Reset)
} else {
debug!("failed to authenticate packet");
// Count the failure against the integrity limit of the highest space's local
// packet key. The connection fails only once that limit is exceeded.
self.authentication_failures += 1;
let integrity_limit = self.spaces[self.highest_space]
.crypto
.as_ref()
.unwrap()
.packet
.local
.integrity_limit();
if self.authentication_failures > integrity_limit {
Err(TransportError::AEAD_LIMIT_REACHED("integrity limit violated").into())
} else {
return;
}
}
}
Ok(number) => {
let span = match number {
Some(pn) => trace_span!("recv", space = ?packet.header.space(), pn),
None => trace_span!("recv", space = ?packet.header.space()),
};
let _guard = span.enter();
// `dedup.insert` returning `true` is treated as "packet number already seen".
let is_duplicate = |n| self.spaces[packet.header.space()].dedup.insert(n);
if number.map_or(false, is_duplicate) {
// A duplicate that ends with the reset token is still treated as a stateless reset.
if stateless_reset {
Err(ConnectionError::Reset)
} else {
warn!("discarding possible duplicate packet");
return;
}
} else if self.state.is_handshake() && packet.header.is_short() {
trace!("dropping short packet during handshake");
return;
} else {
// A closed connection skips the per-packet bookkeeping, but its frames are still
// passed on below.
if !self.state.is_closed() {
// Only short headers carry a spin bit.
let spin = match packet.header {
Header::Short { spin, .. } => spin,
_ => false,
};
self.on_packet_authenticated(
now,
packet.header.space(),
ecn,
number,
spin,
packet.header.is_1rtt(),
);
}
self.process_decrypted_packet(now, remote, number, packet)
}
}
};
// Report the error and choose the next state:
// - peer-initiated closes and transport errors enter `Closed`;
// - a stateless reset or an exhausted AEAD integrity limit goes straight to `Drained`;
// - a version mismatch enters `Draining`.
if let Err(conn_err) = result {
self.events.push_back(conn_err.clone().into());
self.state = match conn_err {
ConnectionError::ApplicationClosed(reason) => State::closed(reason),
ConnectionError::ConnectionClosed(reason) => State::closed(reason),
ConnectionError::Reset
| ConnectionError::TransportError(TransportError {
code: TransportErrorCode::AEAD_LIMIT_REACHED,
..
}) => State::Drained,
ConnectionError::TimedOut => {
unreachable!("timeouts aren't generated by packet processing");
}
ConnectionError::TransportError(err) => {
debug!("closing connection due to transport error: {}", err);
State::closed(err)
}
ConnectionError::VersionMismatch => State::Draining,
ConnectionError::LocallyClosed => {
unreachable!("LocallyClosed isn't generated by packet processing")
}
};
}
// Newly closed: stop all timers and, unless already drained, arm the close timer.
if !was_closed && self.state.is_closed() {
self.close_common();
if !self.state.is_drained() {
self.set_close_timer(now);
}
}
// Newly drained: notify the endpoint; the close timer is no longer needed.
if !was_drained && self.state.is_drained() {
self.endpoint_events.push_back(EndpointEventInner::Drained);
self.timers.stop(Timer::Close);
}
// `self.close` records whether this packet came from the current path's remote.
// NOTE(review): presumably this gates sending a CONNECTION_CLOSE in reply — confirm.
if let State::Closed(_) = self.state {
self.close = remote == self.path.remote;
}
}
/// Dispatches an authenticated, decrypted packet according to the connection state.
///
/// * `Established`: the payload goes to `process_payload` (Data space) or to
///   `process_early_payload` (Initial/Handshake spaces).
/// * `Closed`: frames are only scanned for a CONNECTION_CLOSE, which moves to `Draining`.
/// * `Draining`/`Drained`: the packet is ignored.
/// * `Handshake`: handling depends on the header type (Retry, Handshake, Initial, 0-RTT,
///   Version Negotiation). The handshake completes here and the state becomes `Established`.
///
/// `number` is unwrapped for Data-space and 0-RTT packets, so it must be `Some` for them.
fn process_decrypted_packet(
&mut self,
now: Instant,
remote: SocketAddr,
number: Option<u64>,
packet: Packet,
) -> Result<(), ConnectionError> {
let state = match self.state {
State::Established => {
match packet.header.space() {
SpaceId::Data => {
self.process_payload(now, remote, number.unwrap(), packet.payload.freeze())?
}
_ => self.process_early_payload(now, packet)?,
}
return Ok(());
}
State::Closed(_) => {
// While closed, only look for the peer's CONNECTION_CLOSE.
for frame in frame::Iter::new(packet.payload.freeze()) {
if let Frame::Padding = frame {
continue;
};
self.stats.frame_rx.record(&frame);
if let Frame::Close(_) = frame {
trace!("draining");
self.state = State::Draining;
break;
}
}
return Ok(());
}
State::Draining | State::Drained => return Ok(()),
State::Handshake(ref mut state) => state,
};
match packet.header {
Header::Retry {
src_cid: rem_cid, ..
} => {
// Only servers send Retry.
if self.side.is_server() {
return Err(TransportError::PROTOCOL_VIOLATION("client sent Retry").into());
}
// Ignore the Retry in any of these cases:
// - more than one packet has already been authenticated;
// - the payload is 16 bytes or shorter;
// - `S::is_valid_retry` rejects it.
if self.total_authed_packets > 1
|| packet.payload.len() <= 16
|| !S::is_valid_retry(
&self.rem_cids.active(),
&packet.header_data,
&packet.payload,
)
{
trace!("discarding invalid Retry");
return Ok(());
}
trace!("retrying with CID {}", rem_cid);
// The stored ClientHello is re-queued below in the rebuilt Initial space.
let client_hello = state.client_hello.take().unwrap();
// Adopt the CID chosen by the server in the Retry.
self.retry_src_cid = Some(rem_cid);
self.rem_cids.update_cid(rem_cid);
self.rem_handshake_cid = rem_cid;
// Packet 0 of the Initial space is treated as acknowledged before the space is discarded.
// NOTE(review): presumably this is the first Initial that carried the ClientHello — confirm.
let space = &mut self.spaces[SpaceId::Initial];
if let Some(info) = space.sent_packets.remove(&0) {
space.pending_acks.subtract(&info.acks);
self.on_packet_acked(now, SpaceId::Initial, info);
};
self.discard_space(now, SpaceId::Initial);
// Rebuild the Initial space with keys derived from the new CID, continuing the
// packet-number sequence, and queue the ClientHello again at offset 0.
self.spaces[SpaceId::Initial] = PacketSpace {
crypto: Some(S::initial_keys(&rem_cid, self.side)),
next_packet_number: self.spaces[SpaceId::Initial].next_packet_number,
crypto_offset: client_hello.len() as u64,
..PacketSpace::new(now)
};
self.spaces[SpaceId::Initial]
.pending
.crypto
.push_back(frame::Crypto {
offset: 0,
data: client_hello,
});
// Everything sent in 0-RTT so far leaves flight and is queued for retransmission.
let zero_rtt = mem::replace(
&mut self.spaces[SpaceId::Data].sent_packets,
BTreeMap::new(),
);
for (_, info) in zero_rtt {
self.remove_in_flight(SpaceId::Data, &info);
self.spaces[SpaceId::Data].pending |= info.retransmits;
}
self.streams.retransmit_all_for_0rtt();
// The token is everything except the trailing 16 bytes of the payload.
let token_len = packet.payload.len() - 16;
self.state = State::Handshake(state::Handshake {
token: Some(packet.payload.freeze().split_to(token_len)),
rem_cid_set: false,
client_hello: None,
});
Ok(())
}
Header::Long {
ty: LongType::Handshake,
src_cid: rem_cid,
..
} => {
if rem_cid != self.rem_handshake_cid {
debug!(
"discarding packet with mismatched remote CID: {} != {}",
self.rem_handshake_cid, rem_cid
);
return Ok(());
}
self.path.validated = true;
// Clone the handshake state so `self` can be borrowed mutably below; the clone is
// reinstalled if the handshake is still in progress.
let state = state.clone();
self.process_early_payload(now, packet)?;
// A CONNECTION_CLOSE in the payload may already have closed the connection.
if self.state.is_closed() {
return Ok(());
}
if self.crypto.is_handshaking() {
trace!("handshake ongoing");
self.state = State::Handshake(state::Handshake {
token: None,
..state
});
return Ok(());
}
// The TLS handshake is complete.
if self.side.is_client() {
let params =
self.crypto
.transport_parameters()?
.ok_or_else(|| TransportError {
code: TransportErrorCode::crypto(0x6d),
frame: None,
reason: "transport parameters missing".into(),
})?;
if self.has_0rtt() {
if !self.crypto.early_data_accepted().unwrap() {
// Rejected 0-RTT: drop everything sent or queued in the Data space.
debug_assert!(self.side.is_client());
debug!("0-RTT rejected");
self.accepted_0rtt = false;
self.streams.zero_rtt_rejected();
self.spaces[SpaceId::Data].pending = Retransmits::default();
let sent_packets = mem::replace(
&mut self.spaces[SpaceId::Data].sent_packets,
BTreeMap::new(),
);
for (_, packet) in sent_packets {
self.remove_in_flight(SpaceId::Data, &packet);
}
} else {
// Accepted 0-RTT: the new parameters must be compatible with the
// remembered ones the 0-RTT data relied on.
self.accepted_0rtt = true;
params.validate_resumption_from(&self.peer_params)?;
}
}
if let Some(token) = params.stateless_reset_token {
self.endpoint_events
.push_back(EndpointEventInner::ResetToken(self.path.remote, token));
}
self.handle_peer_params(params)?;
self.issue_cids(now);
} else {
// Server: queue HANDSHAKE_DONE and drop the Handshake space.
self.spaces[SpaceId::Data].pending.handshake_done = true;
self.discard_space(now, SpaceId::Handshake);
}
self.events.push_back(Event::Connected);
self.state = State::Established;
trace!("established");
Ok(())
}
Header::Initial {
src_cid: rem_cid, ..
} => {
// The first Initial from the peer fixes its CID; later Initials must match it.
if !state.rem_cid_set {
trace!("switching remote CID to {}", rem_cid);
let mut state = state.clone();
self.rem_cids.update_cid(rem_cid);
self.rem_handshake_cid = rem_cid;
self.orig_rem_cid = rem_cid;
state.rem_cid_set = true;
self.state = State::Handshake(state);
} else if rem_cid != self.rem_handshake_cid {
debug!(
"discarding packet with mismatched remote CID: {} != {}",
self.rem_handshake_cid, rem_cid
);
return Ok(());
}
let starting_space = self.highest_space;
self.process_early_payload(now, packet)?;
// A server that just moved beyond the Initial space reads the client's transport
// parameters, applies them, and then calls `init_0rtt`.
if self.side.is_server()
&& starting_space == SpaceId::Initial
&& self.highest_space != SpaceId::Initial
{
let params =
self.crypto
.transport_parameters()?
.ok_or_else(|| TransportError {
code: TransportErrorCode::crypto(0x6d),
frame: None,
reason: "transport parameters missing".into(),
})?;
self.handle_peer_params(params)?;
self.issue_cids(now);
self.init_0rtt();
}
Ok(())
}
Header::Long {
ty: LongType::ZeroRtt,
..
} => {
// 0-RTT payloads go through the same frame processing as 1-RTT.
self.process_payload(now, remote, number.unwrap(), packet.payload.freeze())?;
Ok(())
}
Header::VersionNegotiate { .. } => {
// Version Negotiation is ignored once more than one packet has been authenticated.
if self.total_authed_packets > 1 {
return Ok(());
}
// The payload is a list of 4-byte big-endian versions; ignore the packet if ours
// is listed.
let supported = packet
.payload
.chunks(4)
.any(|x| match <[u8; 4]>::try_from(x) {
Ok(version) => self.version == u32::from_be_bytes(version),
Err(_) => false,
});
if supported {
return Ok(());
}
debug!("remote doesn't support our version");
Err(ConnectionError::VersionMismatch)
}
Header::Short { .. } => unreachable!(
"short packets received during handshake are discarded in handle_packet"
),
}
}
/// Processes the frames of an Initial or Handshake packet.
///
/// Only PADDING, PING, CRYPTO, ACK and CONNECTION_CLOSE are accepted. A malformed frame yields
/// FRAME_ENCODING_ERROR and any other frame type yields PROTOCOL_VIOLATION.
///
/// A CONNECTION_CLOSE reports the peer's reason as an event, moves the connection to `Draining`
/// and stops processing immediately; `write_crypto` is not called in that case. Otherwise
/// `write_crypto` is called once all frames have been handled.
fn process_early_payload(
&mut self,
now: Instant,
packet: Packet,
) -> Result<(), TransportError> {
debug_assert_ne!(packet.header.space(), SpaceId::Data);
// Total payload length, passed through to `read_crypto`.
let payload_len = packet.payload.len();
for frame in frame::Iter::new(packet.payload.freeze()) {
let span = match frame {
Frame::Padding => continue,
_ => Some(trace_span!("frame", ty = %frame.ty())),
};
self.stats.frame_rx.record(&frame);
let _guard = span.as_ref().map(|x| x.enter());
// Any frame other than ACK, PADDING or a transport-level CONNECTION_CLOSE sets
// `permit_ack_only` for this space.
// NOTE(review): presumably this allows an ACK-only packet to be sent in reply — confirm.
match frame {
Frame::Ack(_) | Frame::Padding | Frame::Close(Close::Connection(_)) => {}
_ => {
self.spaces[packet.header.space()].permit_ack_only = true;
}
}
match frame {
Frame::Padding | Frame::Ping => {}
Frame::Crypto(frame) => {
self.read_crypto(packet.header.space(), &frame, payload_len)?;
}
Frame::Ack(ack) => {
self.on_ack_received(now, packet.header.space(), ack)?;
}
Frame::Close(reason) => {
self.events.push_back(ConnectionError::from(reason).into());
self.state = State::Draining;
return Ok(());
}
Frame::Invalid { ty, reason } => {
let mut err = TransportError::FRAME_ENCODING_ERROR(reason);
err.frame = Some(ty);
return Err(err);
}
_ => {
let mut err =
TransportError::PROTOCOL_VIOLATION("illegal frame type in handshake");
err.frame = Some(frame.ty());
return Err(err);
}
}
}
self.write_crypto();
Ok(())
}
/// Processes the frames of a 0-RTT or 1-RTT packet with packet number `number` from `remote`.
///
/// Frames are applied in order, and the first error aborts the rest of the packet. After the
/// frames:
/// - pending MAX_STREAMS updates are queued;
/// - a received CONNECTION_CLOSE moves the connection to `Draining`;
/// - migration to `remote` is started if the packet came from an address other than the
///   current path, was not probing-only, and `number` equals `rx_packet`.
fn process_payload(
&mut self,
now: Instant,
remote: SocketAddr,
number: u64,
payload: Bytes,
) -> Result<(), TransportError> {
// Data-space keys are not installed yet, so this packet is 0-RTT.
let is_0rtt = self.spaces[SpaceId::Data].crypto.is_none();
let mut is_probing_packet = true;
// A CONNECTION_CLOSE is applied only after all frames of the packet are processed.
let mut close = None;
let payload_len = payload.len();
for frame in frame::Iter::new(payload) {
let span = match frame {
Frame::Padding => continue,
_ => Some(trace_span!("frame", ty = %frame.ty())),
};
self.stats.frame_rx.record(&frame);
let _guard = span.as_ref().map(|x| x.enter());
// NOTE(review): RFC 9000 §12.5 forbids ACK, CRYPTO, HANDSHAKE_DONE, NEW_TOKEN,
// PATH_RESPONSE and RETIRE_CONNECTION_ID in 0-RTT but allows application CONNECTION_CLOSE;
// this check differs — confirm which specification draft it targets.
if is_0rtt {
match frame {
Frame::Crypto(_) | Frame::Close(Close::Application(_)) => {
return Err(TransportError::PROTOCOL_VIOLATION(
"illegal frame type in 0-RTT",
));
}
_ => {}
}
}
match frame {
Frame::Ack(_) | Frame::Padding | Frame::Close(_) => {}
_ => {
self.spaces[SpaceId::Data].permit_ack_only = true;
}
}
// PADDING, PATH_CHALLENGE, PATH_RESPONSE and NEW_CONNECTION_ID are probing frames; any
// other frame makes this a non-probing packet, which may trigger migration below.
match frame {
Frame::Padding
| Frame::PathChallenge(_)
| Frame::PathResponse(_)
| Frame::NewConnectionId(_) => {}
_ => {
is_probing_packet = false;
}
}
match frame {
Frame::Invalid { ty, reason } => {
let mut err = TransportError::FRAME_ENCODING_ERROR(reason);
err.frame = Some(ty);
return Err(err);
}
Frame::Crypto(frame) => {
self.read_crypto(SpaceId::Data, &frame, payload_len)?;
}
Frame::Stream(frame) => {
if self.streams.received(frame, payload_len)?.should_transmit() {
self.spaces[SpaceId::Data].pending.max_data = true;
}
}
Frame::Ack(ack) => {
self.on_ack_received(now, SpaceId::Data, ack)?;
}
Frame::Padding | Frame::Ping => {}
Frame::Close(reason) => {
close = Some(reason);
}
Frame::PathChallenge(token) => {
// Only the challenge from the highest-numbered packet seen so far is kept for
// the PATH_RESPONSE.
if self
.path_response
.as_ref()
.map_or(true, |x| x.packet <= number)
{
self.path_response = Some(PathResponse {
packet: number,
token,
});
}
if remote == self.path.remote {
self.ping();
}
}
Frame::PathResponse(token) => {
// A response to the current path's challenge validates that path, and any
// previous-path challenge is abandoned.
if self.path.challenge == Some(token) && remote == self.path.remote {
trace!("new path validated");
self.timers.stop(Timer::PathValidation);
self.path.challenge = None;
self.path.validated = true;
if let Some(ref mut prev_path) = self.prev_path {
prev_path.challenge = None;
prev_path.challenge_pending = false;
}
} else if let Some(ref prev_path) = self.prev_path {
// A response to the previous path's challenge reverts to that path.
if prev_path.challenge == Some(token) && remote == prev_path.remote {
warn!("spurious migration detected");
self.timers.stop(Timer::PathValidation);
self.path = self.prev_path.take().unwrap();
self.path.challenge = None;
}
}
}
Frame::MaxData(bytes) => {
self.streams.received_max_data(bytes);
}
Frame::MaxStreamData { id, offset } => {
self.streams.received_max_stream_data(id, offset)?;
}
Frame::MaxStreams { dir, count } => {
self.streams.received_max_streams(dir, count)?;
}
Frame::ResetStream(frame) => {
if self.streams.received_reset(frame)?.should_transmit() {
self.spaces[SpaceId::Data].pending.max_data = true;
}
}
Frame::DataBlocked { offset } => {
debug!(offset, "peer claims to be blocked at connection level");
}
Frame::StreamDataBlocked { id, offset } => {
// A unidirectional stream we initiated is send-only from our side.
if id.initiator() == self.side && id.dir() == Dir::Uni {
debug!("got STREAM_DATA_BLOCKED on send-only {}", id);
return Err(TransportError::STREAM_STATE_ERROR(
"STREAM_DATA_BLOCKED on send-only stream",
));
}
debug!(
stream = %id,
offset, "peer claims to be blocked at stream level"
);
}
Frame::StreamsBlocked { dir, limit } => {
if limit > MAX_STREAM_COUNT {
return Err(TransportError::FRAME_ENCODING_ERROR(
"unrepresentable stream limit",
));
}
debug!(
"peer claims to be blocked opening more than {} {} streams",
limit, dir
);
}
Frame::StopSending(frame::StopSending { id, error_code }) => {
// A peer-initiated unidirectional stream is receive-only for us. A stream we
// initiate must already be open.
if id.initiator() != self.side {
if id.dir() == Dir::Uni {
debug!("got STOP_SENDING on recv-only {}", id);
return Err(TransportError::STREAM_STATE_ERROR(
"STOP_SENDING on recv-only stream",
));
}
} else if self.streams.is_local_unopened(id) {
return Err(TransportError::STREAM_STATE_ERROR(
"STOP_SENDING on unopened stream",
));
}
self.streams.received_stop_sending(id, error_code);
}
Frame::RetireConnectionId { sequence } => {
let allow_more_cids = self
.local_cid_state
.on_cid_retirement(sequence, self.peer_params.issue_cids_limit())?;
self.endpoint_events
.push_back(EndpointEventInner::RetireConnectionId(
now,
sequence,
allow_more_cids,
));
}
Frame::NewConnectionId(frame) => {
trace!(
sequence = frame.sequence,
id = %frame.id,
retire_prior_to = frame.retire_prior_to,
);
if self.rem_cids.active().is_empty() {
return Err(TransportError::PROTOCOL_VIOLATION(
"NEW_CONNECTION_ID when CIDs aren't in use",
));
}
if frame.retire_prior_to > frame.sequence {
return Err(TransportError::PROTOCOL_VIOLATION(
"NEW_CONNECTION_ID retiring unissued CIDs",
));
}
// Queue RETIRE_CONNECTION_ID for every CID below `retire_prior_to`.
let retired = self.rem_cids.retire_prior_to(frame.retire_prior_to);
self.spaces[SpaceId::Data]
.pending
.retire_cids
.extend(retired);
use crate::cid_queue::InsertError;
let new_rem_cid = IssuedCid {
sequence: frame.sequence,
id: frame.id,
reset_token: frame.reset_token,
};
match self.rem_cids.insert(new_rem_cid) {
Ok(()) => {}
Err(InsertError::ExceedsLimit) => {
return Err(TransportError::CONNECTION_ID_LIMIT_ERROR(""));
}
Err(InsertError::Retired) => {
// Already below the retirement threshold: retire it right away.
trace!("discarding already-retired");
self.spaces[SpaceId::Data]
.pending
.retire_cids
.push(frame.sequence);
continue;
}
}
// A server still without a stateless reset token for the peer switches to the new
// CID immediately. Either side switches if its active CID has been retired.
if self.side.is_server() && self.peer_params.stateless_reset_token.is_none() {
debug_assert_eq!(self.rem_cids.active_seq(), 0);
self.update_rem_cid().unwrap();
} else if self.rem_cids.is_active_retired() {
self.update_rem_cid().unwrap();
}
}
Frame::NewToken { token } => {
if self.side.is_server() {
return Err(TransportError::PROTOCOL_VIOLATION("client sent NEW_TOKEN"));
}
if token.is_empty() {
return Err(TransportError::FRAME_ENCODING_ERROR("empty token"));
}
trace!("got new token");
}
Frame::Datagram(datagram) => {
if self
.datagrams
.received(datagram, &self.config.datagram_receive_buffer_size)?
{
self.events.push_back(Event::DatagramReceived);
}
}
Frame::HandshakeDone => {
if self.side.is_server() {
return Err(TransportError::PROTOCOL_VIOLATION(
"client sent HANDSHAKE_DONE",
));
}
if self.spaces[SpaceId::Handshake].crypto.is_some() {
self.discard_space(now, SpaceId::Handshake);
}
}
}
}
// Queue MAX_STREAMS for every direction whose limit changed while processing the frames.
let pending = &mut self.spaces[SpaceId::Data].pending;
for dir in Dir::iter() {
if self.streams.take_max_streams_dirty(dir) {
match dir {
Dir::Uni => pending.max_uni_stream_id = true,
Dir::Bi => pending.max_bi_stream_id = true,
}
}
}
if let Some(reason) = close {
self.events.push_back(ConnectionError::from(reason).into());
self.state = State::Draining;
self.close = true;
}
// Migrate only for a non-probing packet from a new address whose number equals `rx_packet`.
// NOTE(review): `rx_packet` is presumably the highest packet number received — confirm.
if remote != self.path.remote
&& !is_probing_packet
&& number == self.spaces[SpaceId::Data].rx_packet
{
debug_assert!(
self.server_config
.as_ref()
.expect("packets from unknown remote should be dropped by clients")
.migration,
"migration-initiating packets should have been dropped immediately"
);
self.migrate(now, remote);
let _ = self.update_rem_cid();
}
Ok(())
}
/// Switches the connection's active path to `remote` and starts validating it.
///
/// If `remote` is IPv4 with the same IP as the current path, the new path is built from the
/// current one (`PathData::from_previous`). Otherwise it starts from the configured initial RTT
/// and a fresh congestion controller.
///
/// Both the new path and the old path get a random challenge marked as pending. The old path is
/// kept as `prev_path` only if it had no outstanding challenge. The path-validation timer is
/// armed for three times the larger of the PTO and twice the initial RTT.
fn migrate(&mut self, now: Instant, remote: SocketAddr) {
    trace!(%remote, "migration initiated");
    let same_ipv4_host = remote.is_ipv4() && remote.ip() == self.path.remote.ip();
    let mut fresh_path = if same_ipv4_host {
        PathData::from_previous(remote, &self.path, now)
    } else {
        let controller = self.config.congestion_controller_factory.build(now);
        PathData::new(remote, self.config.initial_rtt, controller, now, false)
    };
    fresh_path.challenge = Some(self.rng.gen());
    fresh_path.challenge_pending = true;

    let mut old_path = mem::replace(&mut self.path, fresh_path);
    // An old path with a challenge already outstanding is dropped instead of kept.
    if old_path.challenge.is_none() {
        old_path.challenge = Some(self.rng.gen());
        old_path.challenge_pending = true;
        self.prev_path = Some(old_path);
    }

    let validation_window = 3 * cmp::max(self.pto(), 2 * self.config.initial_rtt);
    self.timers.set(Timer::PathValidation, now + validation_window);
}
/// Switches to the next remote CID held in `rem_cids`.
///
/// On success it does four things:
/// - queues RETIRE_CONNECTION_ID for the sequence numbers reported as retired;
/// - tells the endpoint about the new reset token for the current path's remote;
/// - stores that token as the peer's `stateless_reset_token`;
/// - resets the spin bit.
///
/// Returns `Err(())` when no further remote CID is available.
fn update_rem_cid(&mut self) -> Result<(), ()> {
    let (reset_token, retired) = match self.rem_cids.next() {
        Some(next) => next,
        None => return Err(()),
    };
    self.spaces[SpaceId::Data]
        .pending
        .retire_cids
        .extend(retired);
    let remote = self.path.remote;
    self.endpoint_events
        .push_back(EndpointEventInner::ResetToken(remote, reset_token));
    self.peer_params.stateless_reset_token = Some(reset_token);
    self.spin = false;
    Ok(())
}
/// Asks the endpoint for new local CIDs to hand to the peer.
///
/// Nothing is requested when local CIDs are zero-length. Otherwise the request is for one fewer
/// than `peer_params.issue_cids_limit()`.
fn issue_cids(&mut self, now: Instant) {
    if self.local_cid_state.cid_len() != 0 {
        // NOTE(review): the `- 1` presumably accounts for the CID already in use — confirm.
        let wanted = self.peer_params.issue_cids_limit() - 1;
        self.endpoint_events
            .push_back(EndpointEventInner::NeedIdentifiers(now, wanted));
    }
}
/// Writes as many pending frames for `space_id` into `buf` as fit below `max_size`, and returns
/// a record of what was written.
///
/// Frames are written in this order: HANDSHAKE_DONE, PING, ACK, PATH_CHALLENGE, PATH_RESPONSE,
/// CRYPTO, stream control frames (`streams.write_control_frames`), NEW_CONNECTION_ID,
/// RETIRE_CONNECTION_ID, DATAGRAM, and STREAM data.
///
/// PATH_CHALLENGE, PATH_RESPONSE, stream control frames, DATAGRAM and STREAM frames are only
/// written in the Data space. HANDSHAKE_DONE and CRYPTO are never written in 0-RTT.
fn populate_packet(
&mut self,
space_id: SpaceId,
buf: &mut Vec<u8>,
max_size: usize,
) -> SentFrames {
let mut sent = SentFrames::default();
let space = &mut self.spaces[space_id];
// Data space without keys installed means this packet will be sent as 0-RTT.
let is_0rtt = space_id == SpaceId::Data && space.crypto.is_none();
if !is_0rtt && mem::replace(&mut space.pending.handshake_done, false) {
buf.write(frame::Type::HANDSHAKE_DONE);
sent.retransmits.get_or_create().handshake_done = true;
self.stats.frame_tx.handshake_done =
self.stats.frame_tx.handshake_done.saturating_add(1);
}
if mem::replace(&mut space.ping_pending, false) {
trace!("PING");
buf.write(frame::Type::PING);
self.stats.frame_tx.ping += 1;
}
if !space.pending_acks.is_empty() {
debug_assert!(space.crypto.is_some(), "tried to send ACK in 0-RTT");
trace!("ACK");
let ecn = if self.receiving_ecn {
Some(&space.ecn_counters)
} else {
None
};
frame::Ack::encode(0, &space.pending_acks, ecn, buf);
sent.acks = space.pending_acks.clone();
self.stats.frame_tx.acks += 1;
}
// PATH_CHALLENGE needs 9 bytes: a 1-byte type and 8 bytes of token. While a challenge is
// outstanding it is written into every Data-space packet; `challenge_pending` is only cleared.
if buf.len() + 9 < max_size && space_id == SpaceId::Data {
if let Some(token) = self.path.challenge {
self.path.challenge_pending = false;
sent.requires_padding = true;
trace!("PATH_CHALLENGE {:08x}", token);
buf.write(frame::Type::PATH_CHALLENGE);
buf.write(token);
self.stats.frame_tx.path_challenge += 1;
}
}
if buf.len() + 9 < max_size && space_id == SpaceId::Data {
if let Some(response) = self.path_response.take() {
sent.requires_padding = true;
trace!("PATH_RESPONSE {:08x}", response.token);
buf.write(frame::Type::PATH_RESPONSE);
buf.write(response.token);
self.stats.frame_tx.path_response += 1;
}
}
// Write CRYPTO frames while there is room. A frame that does not fit is split; the unsent
// tail goes back to the front of the queue with its offset advanced.
while buf.len() + frame::Crypto::SIZE_BOUND < max_size && !is_0rtt {
let mut frame = match space.pending.crypto.pop_front() {
Some(x) => x,
None => break,
};
let len = cmp::min(
frame.data.len(),
max_size as usize - buf.len() - frame::Crypto::SIZE_BOUND,
);
let data = frame.data.split_to(len);
let truncated = frame::Crypto {
offset: frame.offset,
data,
};
trace!(
"CRYPTO: off {} len {}",
truncated.offset,
truncated.data.len()
);
truncated.encode(buf);
self.stats.frame_tx.crypto += 1;
sent.retransmits.get_or_create().crypto.push_back(truncated);
if !frame.data.is_empty() {
frame.offset += len as u64;
space.pending.crypto.push_front(frame);
}
}
if space_id == SpaceId::Data {
self.streams.write_control_frames(
buf,
&mut space.pending,
&mut sent.retransmits,
&mut self.stats.frame_tx,
max_size,
);
}
// NOTE(review): a NEW_CONNECTION_ID frame can exceed 44 bytes: 1-byte type, two varints of up
// to 8 bytes, a length byte, up to 20 bytes of CID, and a 16-byte reset token. Confirm that
// the bound is sufficient for the sequence numbers used.
while buf.len() + 44 < max_size {
let issued = match space.pending.new_cids.pop() {
Some(x) => x,
None => break,
};
trace!(
sequence = issued.sequence,
id = %issued.id,
"NEW_CONNECTION_ID"
);
frame::NewConnectionId {
sequence: issued.sequence,
retire_prior_to: self.local_cid_state.retire_prior_to(),
id: issued.id,
reset_token: issued.reset_token,
}
.encode(buf);
sent.retransmits.get_or_create().new_cids.push(issued);
self.stats.frame_tx.new_connection_id += 1;
}
while buf.len() + frame::RETIRE_CONNECTION_ID_SIZE_BOUND < max_size {
let seq = match space.pending.retire_cids.pop() {
Some(x) => x,
None => break,
};
trace!(sequence = seq, "RETIRE_CONNECTION_ID");
buf.write(frame::Type::RETIRE_CONNECTION_ID);
buf.write_var(seq);
sent.retransmits.get_or_create().retire_cids.push(seq);
self.stats.frame_tx.retire_connection_id += 1;
}
// DATAGRAM frames are written until `datagrams.write` reports nothing more was written.
while buf.len() + Datagram::SIZE_BOUND < max_size && space_id == SpaceId::Data {
match self.datagrams.write(buf, max_size) {
true => self.stats.frame_tx.datagram += 1,
false => break,
}
}
// STREAM data fills whatever space remains.
if space_id == SpaceId::Data {
sent.stream_frames = self.streams.write_stream_frames(buf, max_size);
self.stats.frame_tx.stream += sent.stream_frames.len() as u64;
}
sent
}
/// Shared bookkeeping for a connection that has just become closed: stops every timer.
fn close_common(&mut self) {
    trace!("connection closed");
    for timer in Timer::VALUES.iter().copied() {
        self.timers.stop(timer);
    }
}
fn set_close_timer(&mut self, now: Instant) {
self.timers.set(Timer::Close, now + 3 * self.pto());
}
/// Verifies the connection IDs echoed in the peer's transport parameters, then applies them.
///
/// Every peer must report `initial_src_cid` equal to the CID we recorded as `orig_rem_cid`. A
/// client also requires `original_dst_cid` to match the CID of its first Initial and
/// `retry_src_cid` to match the CID from any Retry.
///
/// # Errors
/// Returns TRANSPORT_PARAMETER_ERROR if any of these checks fails.
fn handle_peer_params(&mut self, params: TransportParameters) -> Result<(), TransportError> {
    let cids_authenticated = params.initial_src_cid == Some(self.orig_rem_cid)
        && (!self.side.is_client()
            || (params.original_dst_cid == Some(self.initial_dst_cid)
                && params.retry_src_cid == self.retry_src_cid));
    if !cids_authenticated {
        return Err(TransportError::TRANSPORT_PARAMETER_ERROR(
            "CID authentication failure",
        ));
    }
    self.set_peer_params(params);
    Ok(())
}
/// Applies the peer's transport parameters to this connection.
///
/// It does four things, in order:
/// - passes the parameters to the stream state;
/// - negotiates the idle timeout: the smaller of the local and peer values, where a local `None`
///   or a peer value of `0` means that side sets no limit;
/// - registers the CID carried in a preferred-address parameter as remote CID sequence 1;
/// - stores `params` as `self.peer_params`.
fn set_peer_params(&mut self, params: TransportParameters) {
    // Borrow here; `params` is moved into `self.peer_params` at the end. (The original line read
    // `¶ms`, an HTML-entity corruption of `&params` that did not compile.)
    self.streams.set_params(&params);
    self.idle_timeout = match (self.config.max_idle_timeout, params.max_idle_timeout.0) {
        (None, 0) => None,
        (None, x) => Some(Duration::from_millis(x)),
        (Some(x), 0) => Some(x),
        (Some(x), y) => Some(cmp::min(x, Duration::from_millis(y))),
    };
    if let Some(ref info) = params.preferred_address {
        self.rem_cids
            .insert(IssuedCid {
                sequence: 1,
                id: info.connection_id,
                reset_token: info.stateless_reset_token,
            })
            .expect("preferred address CID is the first received, and hence is guaranteed to be legal");
    }
    self.peer_params = params;
}
/// Removes packet protection from `packet.payload` in place and returns the full packet number.
///
/// * `Ok(None)`: the header is not protected, so nothing is decrypted.
/// * `Ok(Some(pn))`: decryption succeeded.
/// * `Err(None)`: the header has no packet number, or decryption failed.
/// * `Err(Some(e))`: the packet decrypted but violates the protocol (reserved bits set, or an
///   invalid key update).
///
/// In the Data space this also handles 1-RTT key phases. Packets of an older phase are
/// decrypted with the previous keys. The first packet that authenticates with the next keys
/// triggers a key update.
fn decrypt_packet(
&mut self,
now: Instant,
packet: &mut Packet,
) -> Result<Option<u64>, Option<TransportError>> {
if !packet.header.is_protected() {
return Ok(None);
}
let space = packet.header.space();
let rx_packet = self.spaces[space].rx_packet;
// Expand the truncated packet number around `rx_packet + 1`.
let number = packet.header.number().ok_or(None)?.expand(rx_packet + 1);
let key_phase = packet.header.key_phase();
let mut crypto_update = false;
let crypto = if packet.header.is_0rtt() {
&self.zero_rtt_crypto.as_ref().unwrap().packet
} else if key_phase == self.key_phase || space != SpaceId::Data {
// Current keys: the phase matches, or the packet is outside the Data space.
&self.spaces[space].crypto.as_mut().unwrap().packet.remote
} else if let Some(prev) = self.prev_crypto.as_ref().and_then(|crypto| {
if crypto.end_packet.map_or(true, |(pn, _)| number < pn) {
Some(crypto)
} else {
None
}
}) {
// Previous keys: the packet precedes the end of the previous phase, or that end is not
// yet known.
&prev.crypto.remote
} else {
// Otherwise the peer switched to the next phase; try the next keys and update once
// the packet authenticates.
crypto_update = true;
&self.next_crypto.as_ref().unwrap().remote
};
crypto
.decrypt(number, &packet.header_data, &mut packet.payload)
.map_err(|_| {
trace!("decryption failed with packet number {}", number);
None
})?;
// The first packet authenticated in the current phase marks where the previous phase ends;
// arm the key discard timer at that point.
if let Some(ref mut prev) = self.prev_crypto {
if prev.end_packet.is_none() && key_phase == self.key_phase {
prev.end_packet = Some((number, now));
self.set_key_discard_timer(now);
}
}
// Reserved bits are only checked once the packet has decrypted successfully.
if !packet.reserved_bits_valid() {
return Err(Some(TransportError::PROTOCOL_VIOLATION(
"reserved bits set",
)));
}
// A key update is rejected in two cases:
// - the packet is not newer than `rx_packet`;
// - the previous update is still marked `update_unacked`.
if crypto_update {
if number <= rx_packet
|| self
.prev_crypto
.as_ref()
.map_or(false, |x| x.update_unacked)
{
return Err(Some(TransportError::KEY_UPDATE_ERROR("")));
}
trace!("key update authenticated");
self.update_keys(Some((number, now)), true);
self.set_key_discard_timer(now);
}
Ok(Some(number))
}
/// Rotates the 1-RTT packet keys.
///
/// The pre-computed next keys become current and the current keys become `prev_crypto`. A fresh
/// key pair from `crypto.next_1rtt_keys()` becomes the new next keys. `sent_with_keys` is reset
/// and `key_phase` flipped.
///
/// `end_packet` and `remote` are stored on the retired keys as `end_packet` and
/// `update_unacked`; `decrypt_packet` passes `remote = true` for a peer-initiated update.
fn update_keys(&mut self, end_packet: Option<(u64, Instant)>, remote: bool) {
    let following = self.crypto.next_1rtt_keys();
    let current = &mut self.spaces[SpaceId::Data]
        .crypto
        .as_mut()
        .unwrap()
        .packet;
    let promoted = mem::replace(self.next_crypto.as_mut().unwrap(), following);
    let retired = mem::replace(current, promoted);
    self.spaces[SpaceId::Data].sent_with_keys = 0;
    self.prev_crypto = Some(PrevCrypto {
        crypto: retired,
        end_packet,
        update_unacked: remote,
    });
    self.key_phase = !self.key_phase;
}
/// Test-only: bytes currently counted as in flight.
#[cfg(test)]
pub(crate) fn bytes_in_flight(&self) -> u64 {
    let in_flight = &self.in_flight;
    in_flight.bytes
}
/// Test-only: remaining congestion window, i.e. the window minus bytes in flight, floored at 0.
#[cfg(test)]
pub(crate) fn congestion_state(&self) -> u64 {
    let window = self.path.congestion.window();
    window.saturating_sub(self.in_flight.bytes)
}
/// Test-only: whether the connection is idle.
///
/// Idle means that, ignoring the KeepAlive and PushNewCid timers, either no timer is armed or
/// the earliest armed one is the idle timer.
#[cfg(test)]
pub(crate) fn is_idle(&self) -> bool {
    let earliest = Timer::VALUES
        .iter()
        .copied()
        .filter(|&t| t != Timer::KeepAlive && t != Timer::PushNewCid)
        .filter_map(|t| self.timers.get(t).map(|deadline| (t, deadline)))
        .min_by_key(|&(_, deadline)| deadline);
    match earliest {
        None => true,
        Some((timer, _)) => timer == Timer::Idle,
    }
}
/// Test-only: number of packets counted as lost.
#[cfg(test)]
pub(crate) fn lost_packets(&self) -> u64 {
    self.lost_packets
}
/// Test-only: whether ECN marks are being sent on the current path.
#[cfg(test)]
pub(crate) fn using_ecn(&self) -> bool {
    self.path.sending_ecn
}
/// Test-only: the pair reported by `local_cid_state.active_seq()`.
#[cfg(test)]
pub(crate) fn active_local_cid_seq(&self) -> (u64, u64) {
    self.local_cid_state.active_seq()
}
/// Test-only: sets the local retire-prior-to sequence to `v` and asks the endpoint for the
/// resulting number of new identifiers.
#[cfg(test)]
pub(crate) fn rotate_local_cid(&mut self, v: u64, now: Instant) {
    let needed = self.local_cid_state.assign_retire_seq(v);
    let request = EndpointEventInner::NeedIdentifiers(now, needed);
    self.endpoint_events.push_back(request);
}
/// Test-only: sequence number of the active remote CID.
#[cfg(test)]
pub(crate) fn active_rem_cid_seq(&self) -> u64 {
    self.rem_cids.active_seq()
}
/// The peer's advertised `max_ack_delay` transport parameter (in milliseconds) as a `Duration`.
fn max_ack_delay(&self) -> Duration {
    // `from_millis` accepts the full `u64` range. The previous `from_micros(ms * 1000)` gave the
    // same result for small values but overflowed for peer-supplied values above
    // `u64::MAX / 1000`.
    Duration::from_millis(self.peer_params.max_ack_delay.0)
}
fn can_send_1rtt(&self) -> bool {
self.streams.can_send()
|| self.path.challenge_pending
|| self
.prev_path
.as_ref()
.map_or(false, |x| x.challenge_pending)
|| self.path_response.is_some()
|| !self.datagrams.outgoing.is_empty()
}
/// Removes `packet` from the connection-wide and per-space in-flight accounting.
fn remove_in_flight(&mut self, space: SpaceId, packet: &SentPacket) {
    let size = u64::from(packet.size);
    self.in_flight.bytes -= size;
    self.in_flight.ack_eliciting -= u64::from(packet.ack_eliciting);
    self.spaces[space].in_flight -= size;
}
fn kill(&mut self, reason: ConnectionError) {
self.close_common();
self.events.push_back(reason.into());
self.state = State::Drained;
self.endpoint_events.push_back(EndpointEventInner::Drained);
}
}
/// Debug output for a connection shows only its handshake CID.
impl<S> fmt::Debug for Connection<S>
where
    S: crypto::Session,
{
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Connection");
        out.field("handshake_cid", &self.handshake_cid);
        out.finish()
    }
}
/// Reasons a connection ended, as reported in `Event::ConnectionLost`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ConnectionError {
/// The peer's Version Negotiation packet did not list our version.
#[error("peer doesn't implement any supported version")]
VersionMismatch,
/// A transport error, for example one detected while processing received packets.
#[error("{0}")]
TransportError(#[from] TransportError),
/// The peer sent a transport-level CONNECTION_CLOSE (`Close::Connection`).
#[error("aborted by peer: {0}")]
ConnectionClosed(frame::ConnectionClose),
/// The peer sent an application-level CONNECTION_CLOSE (`Close::Application`).
#[error("closed by peer: {0}")]
ApplicationClosed(frame::ApplicationClose),
/// A stateless reset from the peer was received.
#[error("reset by peer")]
Reset,
/// Not produced by packet processing.
/// NOTE(review): presumably the idle timeout — confirm.
#[error("timed out")]
TimedOut,
/// Not produced by packet processing.
/// NOTE(review): presumably raised when the local side closed the connection — confirm.
#[error("closed")]
LocallyClosed,
}
impl From<Close> for ConnectionError {
fn from(x: Close) -> Self {
match x {
Close::Connection(reason) => ConnectionError::ConnectionClosed(reason),
Close::Application(reason) => ConnectionError::ApplicationClosed(reason),
}
}
}
impl From<ConnectionError> for io::Error {
fn from(x: ConnectionError) -> io::Error {
use self::ConnectionError::*;
let kind = match x {
TimedOut => io::ErrorKind::TimedOut,
Reset => io::ErrorKind::ConnectionReset,
ApplicationClosed(_) | ConnectionClosed(_) => io::ErrorKind::ConnectionAborted,
TransportError(_) | VersionMismatch | LocallyClosed => io::ErrorKind::Other,
};
io::Error::new(kind, x)
}
}
/// Lifecycle state of a connection.
#[derive(Clone)]
pub enum State {
/// The handshake is in progress.
Handshake(state::Handshake),
/// The handshake has completed.
Established,
/// Closed, with the reason recorded.
Closed(state::Closed),
/// Entered on receiving the peer's CONNECTION_CLOSE or on a version mismatch.
Draining,
/// Terminal state; the endpoint receives `EndpointEventInner::Drained` on entry.
Drained,
}
impl State {
    /// Builds a `Closed` state carrying `reason`.
    fn closed<R>(reason: R) -> Self
    where
        R: Into<Close>,
    {
        Self::Closed(state::Closed {
            reason: reason.into(),
        })
    }

    /// Whether the handshake is still in progress.
    fn is_handshake(&self) -> bool {
        matches!(self, Self::Handshake(_))
    }

    /// Whether the connection is established.
    fn is_established(&self) -> bool {
        matches!(self, Self::Established)
    }

    /// Whether the connection is closed, draining, or drained.
    fn is_closed(&self) -> bool {
        matches!(self, Self::Closed(_) | Self::Draining | Self::Drained)
    }

    /// Whether the connection has reached the terminal `Drained` state.
    fn is_drained(&self) -> bool {
        matches!(self, Self::Drained)
    }
}
/// Data carried by the `State` variants.
mod state {
use super::*;
/// Data kept while the handshake is in progress.
#[derive(Clone)]
pub struct Handshake {
/// Whether the peer's CID has already been taken from its first Initial packet.
pub rem_cid_set: bool,
/// Token taken from a Retry packet; cleared once a Handshake packet arrives.
/// NOTE(review): presumably sent in later Initial packets — confirm at its use site.
pub token: Option<Bytes>,
/// The ClientHello, kept so it can be re-queued after a Retry.
pub client_hello: Option<Bytes>,
}
/// Data kept once the connection is closed.
#[derive(Clone)]
pub struct Closed {
/// Why the connection was closed.
pub reason: Close,
}
}
// Not referenced in this part of the file.
// NOTE(review): presumably caps the number of ACK ranges handled — confirm at its use site.
const MAX_ACK_BLOCKS: usize = 64;
/// 1-RTT packet keys of the previous key phase, kept after a key update.
struct PrevCrypto<K>
where
K: crypto::PacketKey,
{
/// The retired key pair.
crypto: KeyPair<K>,
/// Number and receive time of the first packet authenticated in the new phase. Packets
/// numbered below it still use these keys; `None` until such a packet arrives.
end_packet: Option<(u64, Instant)>,
/// Set when the peer initiated the update. While it is set, a further key update is rejected
/// with KEY_UPDATE_ERROR.
update_unacked: bool,
}
/// Connection-wide accounting of sent packets that are still in flight.
struct InFlight {
/// Total size in bytes of the packets in flight.
bytes: u64,
/// Number of in-flight packets that are ack-eliciting.
ack_eliciting: u64,
}
impl InFlight {
    /// Empty accounting: nothing in flight.
    pub fn new() -> Self {
        Self {
            ack_eliciting: 0,
            bytes: 0,
        }
    }

    /// True when no bytes are in flight.
    fn is_empty(&self) -> bool {
        0 == self.bytes
    }

    /// Adds `packet`'s size and ack-eliciting flag to the totals.
    fn insert(&mut self, packet: &SentPacket) {
        let size = u64::from(packet.size);
        let eliciting = u64::from(packet.ack_eliciting);
        self.bytes += size;
        self.ack_eliciting += eliciting;
    }
}
/// Application-facing events produced by a connection.
#[derive(Debug)]
pub enum Event {
/// Not produced in this part of the file.
/// NOTE(review): presumably signals that handshake data is available — confirm.
HandshakeDataReady,
/// The handshake completed and the connection is established.
Connected,
/// The connection ended for the given reason.
ConnectionLost {
reason: ConnectionError,
},
/// An event concerning a stream.
Stream(StreamEvent),
/// A DATAGRAM frame was received and buffered.
DatagramReceived,
}
impl From<ConnectionError> for Event {
fn from(x: ConnectionError) -> Self {
Event::ConnectionLost { reason: x }
}
}
/// A PATH_RESPONSE waiting to be sent in reply to a received PATH_CHALLENGE.
struct PathResponse {
/// Number of the packet that carried the challenge. Only the challenge from the
/// highest-numbered packet is kept.
packet: u64,
/// Challenge data to echo back.
token: u64,
}
/// Returns the time elapsed from `y` to `x`, or a zero duration if `y` is not earlier than `x`.
fn instant_saturating_sub(x: Instant, y: Instant) -> Duration {
    x.saturating_duration_since(y)
}
// Neither constant is referenced in this part of the file.
// NOTE(review): presumably caps the exponential backoff applied to timers — confirm.
const MAX_BACKOFF_EXPONENT: u32 = 16;
// NOTE(review): presumably the minimum room, in bytes, worth starting a new packet for — confirm.
const MIN_PACKET_SPACE: usize = 40;
/// Keys used for 0-RTT packets.
struct ZeroRttCrypto<S: crypto::Session> {
/// Header-protection key for 0-RTT packets.
header: S::HeaderKey,
/// Packet-protection key for 0-RTT payloads.
packet: S::PacketKey,
}
/// Record of the frames `populate_packet` wrote into one outgoing packet.
#[derive(Default)]
struct SentFrames {
/// Frames recorded for possible retransmission (HANDSHAKE_DONE, CRYPTO, NEW_CONNECTION_ID,
/// RETIRE_CONNECTION_ID, and the stream control frames).
retransmits: ThinRetransmits,
/// ACK ranges included in the packet.
acks: RangeSet,
/// Metadata of the STREAM frames written.
stream_frames: StreamMetaVec,
/// Set when a PATH_CHALLENGE or PATH_RESPONSE was written.
/// NOTE(review): presumably so the datagram is padded as RFC 9000 §8.2 requires — confirm.
requires_padding: bool,
}