#![allow(dead_code)]
use crate::cx::Cx;
use crate::net::quic_core::{ConnectionId, LongPacketType, PacketHeader, QuicCoreError};
use crate::net::quic_native::{
NativeQuicConnection, NativeQuicConnectionConfig, OutgoingPacket, ReceivedPacket,
};
use crate::net::quic_native::{NativeQuicConnectionError, PacketNumberSpace};
use crate::time::Sleep;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// Demultiplexes received QUIC datagrams to per-connection state by destination
/// connection ID, and tracks each connection's next timer deadline.
#[derive(Debug)]
pub struct ConnectionRouter {
    /// Registered connections, keyed by the connection ID packets are routed on.
    connections: HashMap<ConnectionId, ConnectionHandle>,
    /// Counter behind `allocate_connection_id`; starts at 1 and advances by one per allocation.
    next_connection_id: u64,
    /// Configuration copied for every new connection (only `role` is overridden per connection).
    config_template: NativeQuicConnectionConfig,
    /// Zero point for converting `Instant`s into the microsecond timestamps given to connections.
    clock_origin: Instant,
}
/// Per-connection state owned by the router.
#[derive(Debug)]
pub struct ConnectionHandle {
    /// The underlying connection state machine.
    connection: NativeQuicConnection,
    /// Peer address recorded when the connection was created; timer-driven output is
    /// addressed here.
    peer_addr: SocketAddr,
    /// When the connection was created or last had a datagram routed to it.
    last_activity: Instant,
    /// NOTE(review): never set to `Some` in the code visible here, so `connection_stats`
    /// counts every connection as pending — confirm whether anything else sets it.
    established_at: Option<Instant>,
    /// Absolute time of the connection's next probe-timeout deadline, as computed by
    /// `ConnectionRouter::refresh_connection_timer`; `None` when nothing is pending.
    next_timer_deadline: Option<Instant>,
}
/// A timer that became due for a specific connection.
///
/// NOTE(review): not constructed anywhere in the code visible here — presumably part of a
/// planned timer-event API; confirm before relying on it.
#[derive(Debug, Clone)]
pub struct ConnectionTimerEvent {
    /// Connection the timer belongs to.
    pub connection_id: ConnectionId,
    /// Which kind of timer this is.
    pub timer_type: TimerType,
    /// Deadline the timer was scheduled for.
    pub deadline: Instant,
}
/// Kinds of per-connection timers.
///
/// NOTE(review): in the code visible here only `ConnectionTimerEvent` refers to this enum;
/// the router itself tracks a single deadline per connection (the probe timeout obtained
/// from `pto_deadline_micros`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TimerType {
    ProbeTimeout,
    AckDelay,
    IdleTimeout,
    DrainTimeout,
    KeepAlive,
}
/// Outcome of `ConnectionRouter::route_packet`.
#[derive(Debug)]
pub enum RoutingResult {
    /// The datagram was delivered to an existing connection.
    Routed {
        /// Connection that processed the datagram.
        connection_id: ConnectionId,
        /// Datagrams the connection produced in response, addressed to the sender.
        outgoing_packets: Vec<OutgoingPacket>,
    },
    /// An Initial packet arrived for an unknown connection ID. The router allocated
    /// `connection_id` but did not create the connection; the caller presumably does so
    /// (e.g. via `ConnectionRouter::create_connection`).
    NewConnection {
        /// Freshly allocated local connection ID.
        connection_id: ConnectionId,
        /// Address the Initial packet came from.
        peer_addr: SocketAddr,
        /// Always empty when produced by `route_packet`.
        outgoing_packets: Vec<OutgoingPacket>,
    },
    /// The datagram was discarded.
    Drop {
        /// Human-readable explanation, suitable for logging.
        reason: String,
    },
}
/// Errors returned by `ConnectionRouter` and `QuicTimerScheduler` operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionRouterError {
    /// The `Cx` checkpoint failed at the start of the operation.
    Cancelled,
    /// No connection is registered under this ID.
    ConnectionNotFound(ConnectionId),
    /// The connection is in a state that does not permit the requested operation.
    InvalidConnectionState {
        connection_id: ConnectionId,
        reason: String,
    },
    /// A connection could not be created.
    ConnectionCreationFailed(String),
    /// A timer could not be scheduled.
    TimerSchedulingFailed(String),
    /// A connection-level operation (packet processing, frame generation/encoding, timer
    /// computation, or closing) failed; `reason` is the underlying error's text.
    PacketProcessingFailed {
        connection_id: ConnectionId,
        reason: String,
    },
}
impl std::fmt::Display for ConnectionRouterError {
    /// Formats the error as a single lowercase, human-readable line.
    ///
    /// Connection IDs are rendered with their `Debug` form; `reason` and message strings
    /// are inserted verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants pairing a connection ID with a free-form reason.
            Self::InvalidConnectionState { connection_id, reason } => write!(
                f,
                "invalid connection state for {connection_id:?}: {reason}"
            ),
            Self::PacketProcessingFailed { connection_id, reason } => write!(
                f,
                "packet processing failed for {connection_id:?}: {reason}"
            ),
            // Variants carrying a single value.
            Self::ConnectionNotFound(cid) => write!(f, "connection not found: {cid:?}"),
            Self::ConnectionCreationFailed(msg) => write!(f, "connection creation failed: {msg}"),
            Self::TimerSchedulingFailed(msg) => write!(f, "timer scheduling failed: {msg}"),
            // Fixed message: no formatting machinery needed.
            Self::Cancelled => f.write_str("operation cancelled"),
        }
    }
}
impl std::error::Error for ConnectionRouterError {}
impl ConnectionRouter {
    /// Creates an empty router.
    ///
    /// Each connection created later starts from a copy of `config_template` with only its
    /// `role` overridden. The router's clock origin is captured here; every `Instant`
    /// handed to a connection is converted to whole microseconds since this origin.
    pub fn new(config_template: NativeQuicConnectionConfig) -> Self {
        Self {
            connections: HashMap::new(),
            next_connection_id: 1,
            config_template,
            clock_origin: Instant::now(),
        }
    }

    /// Routes one received datagram by its destination connection ID.
    ///
    /// * Known connection: the datagram is accounted with `on_datagram_received`.
    ///   - The bytes after the decoded header go to `process_packet_payload`.
    ///   - Any frames the connection then generates for the same packet-number space
    ///     are returned as one datagram addressed to `packet.src_addr`.
    ///   - The connection's timer deadline is refreshed.
    /// * Unknown connection, Initial packet: a new local connection ID is allocated and
    ///   [`RoutingResult::NewConnection`] is returned. No connection is created here; the
    ///   caller presumably follows up with [`Self::create_connection`].
    /// * Anything else (undecodable header, or an unknown ID on a non-Initial packet) is
    ///   reported as [`RoutingResult::Drop`].
    ///
    /// All timestamps given to the connection are derived from `packet.receive_time`.
    ///
    /// # Errors
    ///
    /// * [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    /// * [`ConnectionRouterError::PacketProcessingFailed`] if the connection rejects the
    ///   datagram or fails while generating frames or recomputing its timer.
    pub async fn route_packet(
        &mut self,
        cx: &Cx,
        packet: ReceivedPacket,
    ) -> Result<RoutingResult, ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        let routing_info = match self.decode_routing_info(&packet) {
            Ok(info) => info,
            Err(err) => {
                return Ok(RoutingResult::Drop {
                    reason: format!("invalid QUIC header: {err}"),
                });
            }
        };
        let connection_id = routing_info.destination_cid;
        let now_micros = self.instant_micros(packet.receive_time);
        if let Some(handle) = self.connections.get_mut(&connection_id) {
            // Stamp activity with the datagram's own receive time so it agrees with the
            // `now_micros` clock handed to the connection below (previously wall-clock).
            handle.last_activity = packet.receive_time;
            handle
                .connection
                .on_datagram_received(cx, packet.data.len() as u64)
                .map_err(|err| ConnectionRouterError::PacketProcessingFailed {
                    connection_id,
                    reason: err.to_string(),
                })?;
            // The packet payload is everything after the decoded header.
            let payload = packet.data.get(routing_info.header_len..).ok_or_else(|| {
                ConnectionRouterError::PacketProcessingFailed {
                    connection_id,
                    reason: "header length exceeded datagram length".to_string(),
                }
            })?;
            handle
                .connection
                .process_packet_payload(
                    cx,
                    routing_info.space,
                    routing_info.packet_number,
                    payload,
                    now_micros,
                )
                .map_err(|err| ConnectionRouterError::PacketProcessingFailed {
                    connection_id,
                    reason: err.to_string(),
                })?;
            // Replies go back to wherever this datagram came from, which may differ from
            // the address recorded at creation.
            let outgoing_packets = drain_connection_frames(
                cx,
                connection_id,
                handle,
                routing_info.space,
                packet.src_addr,
                packet.receive_time,
            )?;
            Self::refresh_connection_timer(
                cx,
                connection_id,
                handle,
                self.clock_origin,
                now_micros,
                packet.receive_time,
            )?;
            cx.trace(&format!(
                "Routed packet from {} to connection {connection_id:?}",
                packet.src_addr
            ));
            Ok(RoutingResult::Routed {
                connection_id,
                outgoing_packets,
            })
        } else if routing_info.kind == PacketRoutingKind::Initial {
            // NOTE(review): the allocated ID differs from the client-chosen DCID in this
            // packet. If the caller registers the connection under the allocated ID, a
            // retransmitted Initial carrying the same DCID is treated as yet another new
            // connection — confirm how callers map the original DCID.
            let new_connection_id = self.allocate_connection_id();
            cx.trace(&format!(
                "New connection attempt from {} assigned ID {new_connection_id:?}",
                packet.src_addr
            ));
            Ok(RoutingResult::NewConnection {
                connection_id: new_connection_id,
                peer_addr: packet.src_addr,
                outgoing_packets: Vec::new(),
            })
        } else {
            Ok(RoutingResult::Drop {
                reason: format!(
                    "unknown connection ID {connection_id:?} for {:?} packet",
                    routing_info.kind
                ),
            })
        }
    }

    /// Registers a new connection under `connection_id` for `peer_addr`.
    ///
    /// The connection is built from a copy of the router's config template with `role`
    /// set to server or client according to `is_server`.
    ///
    /// # Errors
    ///
    /// * [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    /// * [`ConnectionRouterError::ConnectionCreationFailed`] if a connection is already
    ///   registered under `connection_id`. The existing connection is left untouched rather
    ///   than being silently replaced, which would discard its state.
    pub async fn create_connection(
        &mut self,
        cx: &Cx,
        connection_id: ConnectionId,
        peer_addr: SocketAddr,
        is_server: bool,
    ) -> Result<(), ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        if self.connections.contains_key(&connection_id) {
            return Err(ConnectionRouterError::ConnectionCreationFailed(format!(
                "connection ID {connection_id:?} is already in use"
            )));
        }
        let mut config = self.config_template;
        config.role = if is_server {
            crate::net::quic_native::StreamRole::Server
        } else {
            crate::net::quic_native::StreamRole::Client
        };
        let connection = NativeQuicConnection::new(config);
        let handle = ConnectionHandle {
            connection,
            peer_addr,
            last_activity: Instant::now(),
            established_at: None,
            next_timer_deadline: None,
        };
        self.connections.insert(connection_id, handle);
        cx.trace(&format!(
            "Created new connection {connection_id:?} for peer {peer_addr}"
        ));
        Ok(())
    }

    /// Unregisters the connection stored under `connection_id`, dropping its state.
    ///
    /// # Errors
    ///
    /// * [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    /// * [`ConnectionRouterError::ConnectionNotFound`] if no such connection exists.
    pub fn remove_connection(
        &mut self,
        cx: &Cx,
        connection_id: ConnectionId,
    ) -> Result<(), ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        if self.connections.remove(&connection_id).is_some() {
            cx.trace(&format!("Removed connection {connection_id:?}"));
            Ok(())
        } else {
            Err(ConnectionRouterError::ConnectionNotFound(connection_id))
        }
    }

    /// Starts closing every connection with `app_error_code`, then forgets all of them.
    ///
    /// Each connection first attempts `begin_close`. If that fails, `close_immediately` is
    /// tried instead. On success, returns the number of connections removed.
    ///
    /// NOTE(review): the map is cleared right after closing, so anything the connections
    /// queue while closing (e.g. close frames, if `begin_close` queues them) is never
    /// drained or sent from here — confirm this is intended.
    ///
    /// # Errors
    ///
    /// * [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    /// * [`ConnectionRouterError::PacketProcessingFailed`] if both close attempts fail for
    ///   some connection.
    ///   - The loop stops at that connection and the map is not cleared.
    ///   - Connections visited earlier have already been told to close but remain
    ///     registered.
    pub fn close_all(
        &mut self,
        cx: &Cx,
        now: Instant,
        app_error_code: u64,
    ) -> Result<usize, ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        let now_micros = self.instant_micros(now);
        for (connection_id, handle) in &mut self.connections {
            handle
                .connection
                .begin_close(cx, now_micros, app_error_code)
                .or_else(|_| handle.connection.close_immediately(cx, app_error_code))
                .map_err(|err| ConnectionRouterError::PacketProcessingFailed {
                    connection_id: *connection_id,
                    reason: err.to_string(),
                })?;
        }
        let closed = self.connections.len();
        self.connections.clear();
        Ok(closed)
    }

    /// Recomputes `handle.next_timer_deadline` from the connection's PTO deadline.
    ///
    /// The deadline reported by `pto_deadline_micros` is treated as microseconds since
    /// `origin` and converted back to an `Instant` as `origin + deadline`. If that addition
    /// overflows, the fallback is `now_instant + (deadline - now_micros)`, where the
    /// subtraction saturates at zero. A `None` PTO deadline clears the stored deadline.
    ///
    /// # Errors
    ///
    /// [`ConnectionRouterError::PacketProcessingFailed`] if the connection cannot compute
    /// its deadline.
    fn refresh_connection_timer(
        cx: &Cx,
        connection_id: ConnectionId,
        handle: &mut ConnectionHandle,
        origin: Instant,
        now_micros: u64,
        now_instant: Instant,
    ) -> Result<(), ConnectionRouterError> {
        handle.next_timer_deadline = handle
            .connection
            .pto_deadline_micros(cx, now_micros)
            .map_err(|err| ConnectionRouterError::PacketProcessingFailed {
                connection_id,
                reason: err.to_string(),
            })?
            .and_then(|deadline| {
                let delta = deadline.saturating_sub(now_micros);
                origin
                    .checked_add(Duration::from_micros(deadline))
                    .or_else(|| now_instant.checked_add(Duration::from_micros(delta)))
            });
        Ok(())
    }

    /// Earliest pending timer deadline across all connections, if any.
    pub fn next_timer_deadline(&self) -> Option<Instant> {
        self.connections
            .values()
            .filter_map(|handle| handle.next_timer_deadline)
            .min()
    }

    /// Fires every connection timer whose deadline is at or before `current_time`.
    ///
    /// For each due connection:
    /// 1. its stored deadline is cleared;
    /// 2. `on_probe_timeout` is invoked;
    /// 3. any frames generated for the application-data space are collected into a
    ///    datagram addressed to the connection's recorded `peer_addr`;
    /// 4. the deadline is recomputed.
    ///
    /// Returns every datagram produced, in map-iteration (unspecified) order.
    ///
    /// NOTE(review): frames are always drained from `PacketNumberSpace::ApplicationData`,
    /// even for connections that may still be in the Initial/Handshake spaces — confirm.
    ///
    /// # Errors
    ///
    /// * [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    /// * [`ConnectionRouterError::PacketProcessingFailed`] if a due connection fails.
    ///   - Processing stops there.
    ///   - Datagrams already produced for other connections during this call are
    ///     dropped with the partial result.
    ///   - If `on_probe_timeout` itself fails, that connection is left with no deadline.
    pub async fn process_timer_events(
        &mut self,
        cx: &Cx,
        current_time: Instant,
    ) -> Result<Vec<OutgoingPacket>, ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        let mut outgoing_packets = Vec::new();
        // Copied out so it can be used while `self.connections` is mutably borrowed.
        let origin = self.clock_origin;
        for (connection_id, handle) in &mut self.connections {
            if let Some(deadline) = handle.next_timer_deadline {
                if current_time >= deadline {
                    cx.trace(&format!(
                        "Timer fired for connection {connection_id:?} at {current_time:?}"
                    ));
                    handle.next_timer_deadline = None;
                    handle.connection.on_probe_timeout(cx).map_err(|err| {
                        ConnectionRouterError::PacketProcessingFailed {
                            connection_id: *connection_id,
                            reason: err.to_string(),
                        }
                    })?;
                    let peer_addr = handle.peer_addr;
                    outgoing_packets.extend(drain_connection_frames(
                        cx,
                        *connection_id,
                        handle,
                        PacketNumberSpace::ApplicationData,
                        peer_addr,
                        current_time,
                    )?);
                    Self::refresh_connection_timer(
                        cx,
                        *connection_id,
                        handle,
                        origin,
                        instant_micros_from(origin, current_time),
                        current_time,
                    )?;
                }
            }
        }
        Ok(outgoing_packets)
    }

    /// Counts registered connections.
    ///
    /// A connection counts as established when its `established_at` is set; all others
    /// are reported as pending.
    pub fn connection_stats(&self) -> ConnectionRouterStats {
        let active_connections = self.connections.len();
        let established_connections = self
            .connections
            .values()
            .filter(|h| h.established_at.is_some())
            .count();
        ConnectionRouterStats {
            active_connections,
            established_connections,
            // Cannot underflow: the established count is a filtered subset of all connections.
            pending_connections: active_connections - established_connections,
        }
    }

    /// Decodes the header fields needed to route `packet`.
    ///
    /// Long headers (first byte has the 0x80 bit set) are decoded directly. Short headers
    /// do not encode their DCID length, so each CID length currently registered is tried,
    /// longest first. The first decode whose DCID names a known connection wins. If none
    /// matches, the header is decoded assuming a zero-length DCID.
    ///
    /// # Errors
    ///
    /// Returns the `QuicCoreError` from the final decode attempt when the header is invalid.
    fn decode_routing_info(
        &self,
        packet: &ReceivedPacket,
    ) -> Result<PacketRoutingInfo, QuicCoreError> {
        // The second `decode` argument is presumably the short-header DCID length; for long
        // headers the length is carried in the header itself.
        if packet.data.first().is_some_and(|first| first & 0x80 != 0) {
            let (header, header_len) = PacketHeader::decode(&packet.data, 0)?;
            return PacketRoutingInfo::from_header(header, header_len);
        }
        for cid_len in self.known_connection_id_lengths() {
            if let Ok((header, header_len)) = PacketHeader::decode(&packet.data, cid_len) {
                let info = PacketRoutingInfo::from_header(header, header_len)?;
                if self.connections.contains_key(&info.destination_cid) {
                    return Ok(info);
                }
            }
        }
        let (header, header_len) = PacketHeader::decode(&packet.data, 0)?;
        PacketRoutingInfo::from_header(header, header_len)
    }

    /// Distinct lengths of the registered connection IDs, longest first, always ending
    /// with `0` so a zero-length DCID is tried last.
    fn known_connection_id_lengths(&self) -> Vec<usize> {
        let mut lengths = self
            .connections
            .keys()
            .map(ConnectionId::len)
            .collect::<Vec<_>>();
        // Descending order makes equal values adjacent, so `dedup` removes all duplicates.
        lengths.sort_unstable_by(|a, b| b.cmp(a));
        lengths.dedup();
        if !lengths.contains(&0) {
            lengths.push(0);
        }
        lengths
    }

    /// Microseconds elapsed between the router's clock origin and `instant`.
    fn instant_micros(&self, instant: Instant) -> u64 {
        instant_micros_from(self.clock_origin, instant)
    }

    /// Allocates the next local connection ID: the 8-byte big-endian encoding of an
    /// internal counter.
    ///
    /// IDs are unique among those allocated by this router; they are not checked
    /// against connections registered under externally chosen IDs.
    pub(crate) fn allocate_connection_id(&mut self) -> ConnectionId {
        let id = self.next_connection_id;
        self.next_connection_id += 1;
        let id_bytes = id.to_be_bytes();
        ConnectionId::new(&id_bytes).expect("Connection ID from counter should always be valid")
    }
}
/// Coarse packet category derived from the decoded header type.
///
/// `ConnectionRouter::route_packet` only opens a new connection for `Initial`; packets
/// of any other kind addressed to an unknown connection ID are dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PacketRoutingKind {
    Initial,
    Handshake,
    ZeroRtt,
    OneRtt,
    Retry,
}
/// Header fields extracted by `ConnectionRouter::decode_routing_info` to route a datagram.
#[derive(Debug, Clone)]
struct PacketRoutingInfo {
    /// Destination connection ID, used as the routing key.
    destination_cid: ConnectionId,
    /// Packet category; only `Initial` may lead to a new connection.
    kind: PacketRoutingKind,
    /// Packet-number space the packet belongs to; reply frames are drained from it too.
    space: PacketNumberSpace,
    /// Decoded packet number (fixed at 0 for `PacketHeader::Retry`).
    packet_number: u64,
    /// Length in bytes of the decoded header; the payload starts at this offset.
    header_len: usize,
}
impl PacketRoutingInfo {
    /// Classifies a decoded header into routing fields.
    ///
    /// The mapping is:
    /// * Long headers map their packet type to a routing kind and packet-number space:
    ///   - Initial → Initial space;
    ///   - 0-RTT → application-data space;
    ///   - Handshake → Handshake space;
    ///   - Retry → Initial space.
    /// * Retry headers carry no packet number, so 0 is used.
    /// * Short headers are 1-RTT application data.
    ///
    /// Every branch currently succeeds; the `Result` leaves room for stricter validation.
    fn from_header(header: PacketHeader, header_len: usize) -> Result<Self, QuicCoreError> {
        let (destination_cid, kind, space, packet_number) = match header {
            PacketHeader::Short(short) => (
                short.dst_cid,
                PacketRoutingKind::OneRtt,
                PacketNumberSpace::ApplicationData,
                short.packet_number,
            ),
            PacketHeader::Retry(retry) => (
                retry.dst_cid,
                PacketRoutingKind::Retry,
                PacketNumberSpace::Initial,
                0,
            ),
            PacketHeader::Long(long) => {
                let (kind, space) = match long.packet_type {
                    LongPacketType::Initial => {
                        (PacketRoutingKind::Initial, PacketNumberSpace::Initial)
                    }
                    LongPacketType::ZeroRtt => {
                        (PacketRoutingKind::ZeroRtt, PacketNumberSpace::ApplicationData)
                    }
                    LongPacketType::Handshake => {
                        (PacketRoutingKind::Handshake, PacketNumberSpace::Handshake)
                    }
                    LongPacketType::Retry => {
                        (PacketRoutingKind::Retry, PacketNumberSpace::Initial)
                    }
                };
                (long.dst_cid, kind, space, long.packet_number)
            }
        };
        Ok(Self {
            destination_cid,
            kind,
            space,
            packet_number,
            header_len,
        })
    }
}
/// Whole microseconds from `origin` to `instant`.
///
/// An `instant` earlier than `origin` yields 0. An elapsed time too large for a `u64`
/// saturates to `u64::MAX`.
fn instant_micros_from(origin: Instant, instant: Instant) -> u64 {
    let elapsed = instant.checked_duration_since(origin).unwrap_or_default();
    u64::try_from(elapsed.as_micros()).unwrap_or(u64::MAX)
}
/// Collects the frames the connection wants to send in `space` and encodes them into a
/// single outgoing datagram for `dst_addr`, stamped with `now` as its send time.
///
/// Returns an empty `Vec` when the connection has nothing to send. The value `1_200` is
/// passed to `generate_frames` as its size budget. It is presumably a byte limit
/// matching QUIC's minimum datagram size; confirm against `generate_frames`.
///
/// NOTE(review): the datagram holds only the encoded frames. No packet header or packet
/// protection is applied here; presumably that happens downstream, which should be
/// confirmed.
///
/// # Errors
///
/// `PacketProcessingFailed` if frame generation or encoding fails.
fn drain_connection_frames(
    cx: &Cx,
    connection_id: ConnectionId,
    handle: &mut ConnectionHandle,
    space: PacketNumberSpace,
    dst_addr: SocketAddr,
    now: Instant,
) -> Result<Vec<OutgoingPacket>, ConnectionRouterError> {
    // Tags any connection-level error with the connection it came from.
    fn processing_failed(
        connection_id: ConnectionId,
        err: impl std::fmt::Display,
    ) -> ConnectionRouterError {
        ConnectionRouterError::PacketProcessingFailed {
            connection_id,
            reason: err.to_string(),
        }
    }

    let frames = handle
        .connection
        .generate_frames(cx, space, 1_200)
        .map_err(|err| processing_failed(connection_id, err))?;
    if frames.is_empty() {
        return Ok(Vec::new());
    }

    let mut encoded = crate::bytes::BytesMut::new();
    NativeQuicConnection::encode_frames(&frames, &mut encoded)
        .map_err(|err: NativeQuicConnectionError| processing_failed(connection_id, err))?;

    let datagram = OutgoingPacket {
        dst_addr,
        data: encoded.to_vec(),
        send_time: Some(now),
    };
    Ok(vec![datagram])
}
/// Snapshot of connection counts returned by `ConnectionRouter::connection_stats`.
#[derive(Debug, Clone)]
pub struct ConnectionRouterStats {
    /// Number of registered connections.
    pub active_connections: usize,
    /// Registered connections whose `established_at` is set.
    pub established_connections: usize,
    /// `active_connections - established_connections`.
    pub pending_connections: usize,
}
/// Single-slot timer holding the earliest deadline it has been asked to schedule.
#[derive(Debug)]
pub struct QuicTimerScheduler {
    /// Sleep future for the armed deadline; `None` when no timer is armed.
    current_sleep: Option<Sleep>,
    /// The `Instant` the armed sleep was created for.
    current_deadline: Option<Instant>,
}
impl QuicTimerScheduler {
    /// Creates a scheduler with no timer armed.
    pub fn new() -> Self {
        Self {
            current_sleep: None,
            current_deadline: None,
        }
    }

    /// Arms the timer for `deadline` if it is earlier than the one already armed.
    ///
    /// Deadlines that are not strictly in the future are ignored. Deadlines that are not
    /// earlier than the armed one are ignored too. As a result, the armed deadline only
    /// moves earlier until the timer is taken by [`Self::wait_for_timer`] or cleared by
    /// [`Self::cancel`].
    ///
    /// # Errors
    ///
    /// [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    pub async fn schedule_timer(
        &mut self,
        cx: &Cx,
        deadline: Instant,
    ) -> Result<(), ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        let now = Instant::now();
        if deadline <= now {
            return Ok(());
        }
        let should_reschedule = match self.current_deadline {
            Some(current) => deadline < current,
            None => true,
        };
        if should_reschedule {
            // Read the clock once so the duration that gets armed is exactly the one logged
            // (previously a second `Instant::now()` produced a slightly different value).
            let duration = deadline.saturating_duration_since(now);
            // `as_nanos` is a u128; saturate instead of silently truncating with `as`.
            let nanos = u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX);
            // NOTE(review): `Time::from_nanos` receives a *relative* duration here. If
            // `Sleep::new` expects an absolute deadline on the runtime clock, this sleep
            // would complete too early — confirm against `crate::time::Sleep`.
            let time_deadline = crate::Time::from_nanos(nanos);
            self.current_sleep = Some(Sleep::new(time_deadline));
            self.current_deadline = Some(deadline);
            cx.trace(&format!(
                "Scheduled QUIC timer for {deadline:?} (in {duration:?})"
            ));
        }
        Ok(())
    }

    /// Waits for the armed timer and returns the deadline it was armed for.
    ///
    /// Returns `Ok(None)` immediately when no timer is armed. The sleep and its deadline
    /// are removed from the scheduler *before* awaiting. Consequently, no timer remains
    /// armed once this returns. It also means the future is not cancel-safe: if it is
    /// dropped mid-wait, the armed timer is lost and must be scheduled again.
    ///
    /// # Errors
    ///
    /// [`ConnectionRouterError::Cancelled`] if `cx` is cancelled on entry.
    pub async fn wait_for_timer(
        &mut self,
        cx: &Cx,
    ) -> Result<Option<Instant>, ConnectionRouterError> {
        if cx.checkpoint().is_err() {
            return Err(ConnectionRouterError::Cancelled);
        }
        if let Some(sleep) = self.current_sleep.take() {
            let deadline = self.current_deadline.take();
            sleep.await;
            cx.trace(&format!("QUIC timer fired for {deadline:?}"));
            Ok(deadline)
        } else {
            Ok(None)
        }
    }

    /// Whether a timer is currently armed.
    pub fn has_pending_timer(&self) -> bool {
        self.current_sleep.is_some()
    }

    /// The deadline of the armed timer, if any.
    pub fn current_deadline(&self) -> Option<Instant> {
        self.current_deadline
    }

    /// Disarms the timer, dropping its sleep and deadline.
    pub fn cancel(&mut self) {
        self.current_sleep = None;
        self.current_deadline = None;
    }
}
impl Default for QuicTimerScheduler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::BytesMut;
    use crate::net::atp::protocol::quic_frames::QuicFrame;
    use crate::net::quic_core::{LongHeader, LongPacketType, PacketHeader};
    use crate::test_utils::run_test_with_cx;

    /// A fresh router has no connections and starts its ID counter at 1.
    #[test]
    fn test_connection_router_creation() {
        let config = NativeQuicConnectionConfig::default();
        let router = ConnectionRouter::new(config);
        assert_eq!(router.connections.len(), 0);
        assert_eq!(router.next_connection_id, 1);
    }

    /// Allocated IDs are distinct and the counter advances by exactly one per allocation.
    #[test]
    fn test_connection_id_allocation() {
        run_test_with_cx(|_cx| async move {
            let config = NativeQuicConnectionConfig::default();
            let mut router = ConnectionRouter::new(config);
            let id1 = router.allocate_connection_id();
            let id2 = router.allocate_connection_id();
            assert_ne!(id1, id2);
            assert_eq!(router.next_connection_id, 3);
        });
    }

    /// Creating a connection registers it under the given ID.
    #[test]
    fn test_connection_creation() {
        run_test_with_cx(|cx| async move {
            let config = NativeQuicConnectionConfig::default();
            let mut router = ConnectionRouter::new(config);
            let connection_id = router.allocate_connection_id();
            let peer_addr = "127.0.0.1:12345".parse().unwrap();
            router
                .create_connection(&cx, connection_id, peer_addr, false)
                .await
                .expect("connection creation should succeed");
            assert_eq!(router.connections.len(), 1);
            assert!(router.connections.contains_key(&connection_id));
        });
    }

    /// Removing an unknown ID reports `ConnectionNotFound`; removing a registered one works.
    #[test]
    fn test_remove_connection() {
        run_test_with_cx(|cx| async move {
            let mut router = ConnectionRouter::new(NativeQuicConnectionConfig::default());
            let connection_id = router.allocate_connection_id();
            assert_eq!(
                router.remove_connection(&cx, connection_id),
                Err(ConnectionRouterError::ConnectionNotFound(connection_id))
            );
            let peer_addr: SocketAddr = "127.0.0.1:4435".parse().unwrap();
            router
                .create_connection(&cx, connection_id, peer_addr, false)
                .await
                .expect("connection creation should succeed");
            assert_eq!(router.remove_connection(&cx, connection_id), Ok(()));
            assert!(router.connections.is_empty());
        });
    }

    /// Connections without `established_at` are reported as pending.
    #[test]
    fn test_connection_stats_counts_unestablished_as_pending() {
        run_test_with_cx(|cx| async move {
            let mut router = ConnectionRouter::new(NativeQuicConnectionConfig::default());
            let connection_id = router.allocate_connection_id();
            let peer_addr: SocketAddr = "127.0.0.1:4436".parse().unwrap();
            router
                .create_connection(&cx, connection_id, peer_addr, false)
                .await
                .expect("connection creation should succeed");
            let stats = router.connection_stats();
            assert_eq!(stats.active_connections, 1);
            assert_eq!(stats.established_connections, 0);
            assert_eq!(stats.pending_connections, 1);
        });
    }

    /// CID lengths are tried longest-first and always end with the zero-length fallback.
    #[test]
    fn test_known_connection_id_lengths() {
        run_test_with_cx(|cx| async move {
            let mut router = ConnectionRouter::new(NativeQuicConnectionConfig::default());
            assert_eq!(router.known_connection_id_lengths(), vec![0]);
            let connection_id = router.allocate_connection_id();
            let peer_addr: SocketAddr = "127.0.0.1:4437".parse().unwrap();
            router
                .create_connection(&cx, connection_id, peer_addr, false)
                .await
                .expect("connection creation should succeed");
            assert_eq!(
                router.known_connection_id_lengths(),
                vec![connection_id.len(), 0]
            );
        });
    }

    /// Conversion to microseconds is exact, and instants before the origin clamp to 0.
    #[test]
    fn test_instant_micros_from_clamps_and_converts() {
        let origin = Instant::now();
        assert_eq!(instant_micros_from(origin, origin), 0);
        assert_eq!(
            instant_micros_from(origin, origin + Duration::from_micros(1_500)),
            1_500
        );
        if let Some(earlier) = origin.checked_sub(Duration::from_millis(5)) {
            assert_eq!(instant_micros_from(origin, earlier), 0);
        }
    }

    /// An Initial packet for an unknown CID yields `NewConnection` for the sender.
    #[test]
    fn test_long_header_initial_routes_as_new_connection() {
        run_test_with_cx(|cx| async move {
            let config = NativeQuicConnectionConfig::default();
            let mut router = ConnectionRouter::new(config);
            let dst_cid = ConnectionId::new(&[0xaa, 0xbb, 0xcc]).expect("cid");
            let src_addr: SocketAddr = "127.0.0.1:4433".parse().unwrap();
            let packet = ReceivedPacket {
                src_addr,
                data: encode_long_packet(dst_cid, LongPacketType::Initial, 0, QuicFrame::Ping),
                receive_time: Instant::now(),
                transmit_time: None,
            };
            match router.route_packet(&cx, packet).await.expect("route") {
                RoutingResult::NewConnection { peer_addr, .. } => assert_eq!(peer_addr, src_addr),
                other => panic!("expected new connection, got {other:?}"),
            }
        });
    }

    /// A PING routed to an existing connection produces one reply datagram to the sender.
    #[test]
    fn test_existing_connection_processes_ping_and_emits_ack_frame() {
        run_test_with_cx(|cx| async move {
            let config = NativeQuicConnectionConfig::default();
            let mut router = ConnectionRouter::new(config);
            let connection_id = router.allocate_connection_id();
            let peer_addr: SocketAddr = "127.0.0.1:4434".parse().unwrap();
            router
                .create_connection(&cx, connection_id, peer_addr, false)
                .await
                .expect("connection creation should succeed");
            let packet = ReceivedPacket {
                src_addr: peer_addr,
                data: encode_long_packet(
                    connection_id,
                    LongPacketType::Initial,
                    42,
                    QuicFrame::Ping,
                ),
                receive_time: Instant::now(),
                transmit_time: None,
            };
            match router.route_packet(&cx, packet).await.expect("route") {
                RoutingResult::Routed {
                    outgoing_packets, ..
                } => {
                    assert_eq!(outgoing_packets.len(), 1);
                    assert_eq!(outgoing_packets[0].dst_addr, peer_addr);
                    assert!(!outgoing_packets[0].data.is_empty());
                }
                other => panic!("expected routed packet, got {other:?}"),
            }
        });
    }

    /// Scheduling a future deadline arms the timer with exactly that deadline.
    #[test]
    fn test_timer_scheduler_basic() {
        run_test_with_cx(|cx| async move {
            let mut scheduler = QuicTimerScheduler::new();
            assert!(!scheduler.has_pending_timer());
            assert_eq!(scheduler.current_deadline(), None);
            let deadline = Instant::now() + std::time::Duration::from_millis(10);
            scheduler
                .schedule_timer(&cx, deadline)
                .await
                .expect("timer scheduling should succeed");
            assert!(scheduler.has_pending_timer());
            assert_eq!(scheduler.current_deadline(), Some(deadline));
        });
    }

    /// Builds an unprotected long-header packet carrying a single frame.
    ///
    /// `payload_length` includes the 1-byte packet number (`packet_number_len: 1`) in
    /// addition to the encoded frame.
    fn encode_long_packet(
        dst_cid: ConnectionId,
        packet_type: LongPacketType,
        packet_number: u64,
        frame: QuicFrame,
    ) -> Vec<u8> {
        let mut payload = BytesMut::new();
        frame.encode(&mut payload).expect("frame encode");
        let header = PacketHeader::Long(LongHeader {
            packet_type,
            version: 1,
            dst_cid,
            src_cid: ConnectionId::new(&[0x01, 0x02, 0x03, 0x04]).expect("src cid"),
            token: Vec::new(),
            payload_length: payload.len() as u64 + 1,
            packet_number,
            packet_number_len: 1,
        });
        let mut out = Vec::new();
        header.encode(&mut out).expect("header encode");
        out.extend_from_slice(&payload);
        out
    }
}