use std::collections::HashMap;
use std::net::SocketAddr;
/// Errors reported by the simulated QUIC transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuicTransportError {
    /// An operation was attempted on a connection that has been closed.
    ConnectionClosed,
    /// The stream with this ID was reset by the peer.
    StreamReset(u64),
    /// An operation would exceed the available flow-control credit.
    FlowControlViolation,
    /// A protocol-level failure described by the contained message.
    ProtocolError(String),
    /// No open stream has this ID.
    StreamNotFound(u64),
    /// Opening another stream would exceed the configured `max_streams`.
    StreamLimitReached,
}

impl std::fmt::Display for QuicTransportError {
    /// Writes a human-readable, `QUIC`-prefixed description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Fixed messages: no formatting arguments required.
            Self::ConnectionClosed => f.write_str("QUIC connection closed"),
            Self::FlowControlViolation => f.write_str("QUIC flow-control limit violated"),
            Self::StreamLimitReached => f.write_str("QUIC max_streams limit reached"),
            // Messages that embed the variant's payload.
            Self::StreamReset(id) => write!(f, "QUIC stream {id} was reset by peer"),
            Self::StreamNotFound(id) => write!(f, "QUIC stream {id} not found"),
            Self::ProtocolError(msg) => write!(f, "QUIC protocol error: {msg}"),
        }
    }
}

impl std::error::Error for QuicTransportError {}
/// Which way data may flow on a stream, from the local endpoint's viewpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamDirection {
    /// Data may be both sent and received.
    Bidirectional,
    /// Unidirectional stream on which the local endpoint only sends.
    UniSend,
    /// Unidirectional stream on which the local endpoint only receives.
    UniReceive,
}

impl StreamDirection {
    /// Returns `true` when the local endpoint may send data on this stream.
    #[must_use]
    pub fn can_send(&self) -> bool {
        match self {
            Self::Bidirectional | Self::UniSend => true,
            Self::UniReceive => false,
        }
    }

    /// Returns `true` when the local endpoint may receive data on this stream.
    #[must_use]
    pub fn can_receive(&self) -> bool {
        match self {
            Self::Bidirectional | Self::UniReceive => true,
            Self::UniSend => false,
        }
    }
}
/// Per-stream bookkeeping kept by a `QuicConnection`.
#[derive(Debug, Clone)]
pub struct QuicStream {
    /// Identifier assigned when the stream was opened.
    pub stream_id: u64,
    /// Directions in which data may flow on this stream.
    pub direction: StreamDirection,
    /// Caller-chosen priority value; this type does not interpret it.
    pub priority: u8,
    /// Total bytes recorded as sent on this stream.
    pub bytes_sent: u64,
    /// Total bytes recorded as received on this stream.
    pub bytes_received: u64,
    /// Whether a FIN has been sent on this stream; starts out `false`.
    pub fin_sent: bool,
}

impl QuicStream {
    /// Builds a stream with zeroed byte counters and no FIN sent.
    fn new(stream_id: u64, direction: StreamDirection, priority: u8) -> Self {
        let (bytes_sent, bytes_received) = (0, 0);
        Self {
            stream_id,
            direction,
            priority,
            bytes_sent,
            bytes_received,
            fin_sent: false,
        }
    }
}
/// Transport parameters used to construct a `QuicConnection`.
#[derive(Debug, Clone)]
pub struct QuicConfig {
    /// Maximum number of streams that may be open at the same time.
    pub max_streams: u32,
    /// Idle timeout in milliseconds (not read by the connection logic in this file).
    pub idle_timeout_ms: u64,
    /// Keep-alive interval in milliseconds (not read by the connection logic in this file).
    pub keep_alive_interval_ms: u64,
    /// Initial connection-level send credit in bytes.
    pub initial_max_data: u64,
    /// Initial per-stream data limit in bytes (not read by the connection logic in this file).
    pub initial_max_stream_data: u64,
    /// Maximum UDP payload size in bytes; also sets the congestion-window floor.
    pub max_udp_payload_size: u16,
    /// Whether unreliable datagrams are enabled (not read by the connection logic in this file).
    pub enable_datagrams: bool,
}

impl Default for QuicConfig {
    /// 100 streams, 30 s idle timeout, 10 s keep-alive, 10 MiB connection
    /// credit, 1 MiB per-stream limit, 1200-byte UDP payloads, datagrams off.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;
        Self {
            max_streams: 100,
            idle_timeout_ms: 30_000,
            keep_alive_interval_ms: 10_000,
            initial_max_data: 10 * MIB,
            initial_max_stream_data: MIB,
            // 1200 bytes is the smallest datagram size every QUIC path must
            // support (RFC 9000 §14), so it is a safe default.
            max_udp_payload_size: 1200,
            enable_datagrams: false,
        }
    }
}
/// Connection-level transport statistics; all fields start at zero.
#[derive(Debug, Clone, Default)]
pub struct QuicStats {
    /// Smoothed round-trip time in microseconds (0 until the first sample).
    pub rtt_us: u64,
    /// Round-trip-time variation estimate in microseconds.
    pub rtt_var_us: u64,
    /// Loss events divided by packets sent; may exceed 1.0.
    pub packet_loss_rate: f64,
    /// Congestion window in bytes.
    pub cwnd_bytes: u64,
    /// Number of send operations recorded.
    pub packets_sent: u64,
    /// Number of receive operations recorded.
    pub packets_received: u64,
    /// Total bytes recorded as sent.
    pub bytes_sent: u64,
    /// Total bytes recorded as received.
    pub bytes_received: u64,
    /// Number of loss events signalled.
    pub loss_events: u32,
}

impl QuicStats {
    /// Smoothed RTT as a [`std::time::Duration`].
    #[must_use]
    pub fn rtt(&self) -> std::time::Duration {
        let micros = self.rtt_us;
        std::time::Duration::from_micros(micros)
    }

    /// Lower-bound RTT estimate: smoothed RTT minus its variation, clamped at 0.
    #[must_use]
    pub fn min_rtt_us(&self) -> u64 {
        // A variation larger than the estimate would underflow; clamp to zero instead.
        self.rtt_us.checked_sub(self.rtt_var_us).unwrap_or(0)
    }
}
/// A simulated QUIC connection that tracks streams, connection-level send
/// credit, RTT estimates and a congestion window.
///
/// This type performs no network I/O itself; callers drive it by reporting
/// sends, receives, RTT samples and loss events.
#[derive(Debug)]
pub struct QuicConnection {
    /// Transport parameters supplied to [`QuicConnection::new`].
    config: QuicConfig,
    /// Local socket address.
    pub local_addr: SocketAddr,
    /// Peer socket address.
    pub remote_addr: SocketAddr,
    /// Hash of the remote IP address's text form (the port is not included).
    pub connection_id: u64,
    /// Currently open streams keyed by stream ID.
    streams: HashMap<u64, QuicStream>,
    /// ID assigned by the next successful `open_stream` call.
    next_stream_id: u64,
    /// Set once `close` has been called.
    closed: bool,
    /// Running statistics; a copy is returned by `connection_stats`.
    stats: QuicStats,
    /// Remaining connection-level send credit in bytes.
    flow_credit: u64,
}

impl QuicConnection {
    /// Creates an open connection with no streams and send credit equal to
    /// `config.initial_max_data`.
    ///
    /// NOTE(review): `connection_id` depends only on the remote IP, so two
    /// connections to the same host (on any port) get the same ID.
    #[must_use]
    pub fn new(config: QuicConfig, local_addr: SocketAddr, remote_addr: SocketAddr) -> Self {
        let flow_credit = config.initial_max_data;
        // Polynomial (x31) hash over the bytes of the textual IP address.
        let connection_id = remote_addr
            .ip()
            .to_string()
            .bytes()
            .fold(0u64, |acc, byte| acc.wrapping_mul(31).wrapping_add(u64::from(byte)));
        Self {
            config,
            local_addr,
            remote_addr,
            connection_id,
            streams: HashMap::new(),
            next_stream_id: 0,
            closed: false,
            stats: QuicStats::default(),
            flow_credit,
        }
    }

    /// Opens a new stream and returns its ID.
    ///
    /// IDs are allocated sequentially from 0 and never reused. `max_streams`
    /// bounds the number of *concurrently* open streams, so closing a stream
    /// frees a slot.
    ///
    /// # Errors
    /// - [`QuicTransportError::ConnectionClosed`] if the connection was closed.
    /// - [`QuicTransportError::StreamLimitReached`] if `max_streams` streams are open.
    pub fn open_stream(
        &mut self,
        direction: StreamDirection,
        priority: u8,
    ) -> Result<u64, QuicTransportError> {
        if self.closed {
            return Err(QuicTransportError::ConnectionClosed);
        }
        // Lossless on 32/64-bit targets; saturate rather than truncate elsewhere.
        let limit = usize::try_from(self.config.max_streams).unwrap_or(usize::MAX);
        if self.streams.len() >= limit {
            return Err(QuicTransportError::StreamLimitReached);
        }
        let stream_id = self.next_stream_id;
        self.next_stream_id += 1;
        self.streams
            .insert(stream_id, QuicStream::new(stream_id, direction, priority));
        Ok(stream_id)
    }

    /// Removes the stream with `stream_id`.
    ///
    /// # Errors
    /// - [`QuicTransportError::ConnectionClosed`] if the connection was closed.
    /// - [`QuicTransportError::StreamNotFound`] if no open stream has this ID.
    pub fn close_stream(&mut self, stream_id: u64) -> Result<(), QuicTransportError> {
        if self.closed {
            return Err(QuicTransportError::ConnectionClosed);
        }
        if self.streams.remove(&stream_id).is_none() {
            return Err(QuicTransportError::StreamNotFound(stream_id));
        }
        Ok(())
    }

    /// Returns all open streams, ordered by ascending stream ID.
    #[must_use]
    pub fn active_streams(&self) -> Vec<&QuicStream> {
        let mut open: Vec<&QuicStream> = self.streams.values().collect();
        // HashMap iteration order is unspecified; sort so callers get a stable order.
        open.sort_unstable_by_key(|s| s.stream_id);
        open
    }

    /// Mutable access to an open stream, or `None` if it does not exist.
    pub fn stream_mut(&mut self, stream_id: u64) -> Option<&mut QuicStream> {
        self.streams.get_mut(&stream_id)
    }

    /// Records `bytes` sent on `stream_id`, consuming connection-level credit.
    ///
    /// On any error no state is modified.
    ///
    /// # Errors
    /// - [`QuicTransportError::ConnectionClosed`] if the connection was closed.
    /// - [`QuicTransportError::StreamNotFound`] if no open stream has this ID.
    /// - [`QuicTransportError::ProtocolError`] if the stream is receive-only or
    ///   has already sent FIN.
    /// - [`QuicTransportError::FlowControlViolation`] if `bytes` exceeds the
    ///   remaining credit.
    pub fn send_bytes(&mut self, stream_id: u64, bytes: u64) -> Result<(), QuicTransportError> {
        if self.closed {
            return Err(QuicTransportError::ConnectionClosed);
        }
        let stream = self
            .streams
            .get_mut(&stream_id)
            .ok_or(QuicTransportError::StreamNotFound(stream_id))?;
        if !stream.direction.can_send() {
            return Err(QuicTransportError::ProtocolError(format!(
                "cannot send on receive-only stream {stream_id}"
            )));
        }
        if stream.fin_sent {
            return Err(QuicTransportError::ProtocolError(format!(
                "cannot send on stream {stream_id} after FIN"
            )));
        }
        if bytes > self.flow_credit {
            return Err(QuicTransportError::FlowControlViolation);
        }
        self.flow_credit -= bytes;
        stream.bytes_sent = stream.bytes_sent.saturating_add(bytes);
        self.stats.bytes_sent = self.stats.bytes_sent.saturating_add(bytes);
        self.stats.packets_sent = self.stats.packets_sent.saturating_add(1);
        Ok(())
    }

    /// Records `bytes` received on `stream_id`.
    ///
    /// Counters saturate at `u64::MAX` instead of overflowing.
    ///
    /// # Errors
    /// - [`QuicTransportError::ConnectionClosed`] if the connection was closed.
    /// - [`QuicTransportError::StreamNotFound`] if no open stream has this ID.
    /// - [`QuicTransportError::ProtocolError`] if the stream is send-only.
    pub fn receive_bytes(&mut self, stream_id: u64, bytes: u64) -> Result<(), QuicTransportError> {
        if self.closed {
            return Err(QuicTransportError::ConnectionClosed);
        }
        let stream = self
            .streams
            .get_mut(&stream_id)
            .ok_or(QuicTransportError::StreamNotFound(stream_id))?;
        if !stream.direction.can_receive() {
            return Err(QuicTransportError::ProtocolError(format!(
                "cannot receive on send-only stream {stream_id}"
            )));
        }
        stream.bytes_received = stream.bytes_received.saturating_add(bytes);
        self.stats.bytes_received = self.stats.bytes_received.saturating_add(bytes);
        self.stats.packets_received = self.stats.packets_received.saturating_add(1);
        Ok(())
    }

    /// Folds an RTT sample (microseconds) into the smoothed RTT and variation.
    ///
    /// First sample: `rtt = sample`, `rtt_var = sample / 2`. Afterwards:
    /// `rtt_var = (3*rtt_var + |rtt - sample|) / 4`, then
    /// `rtt = (7*rtt + sample) / 8` (the RFC 9002 §5.3 weights).
    ///
    /// NOTE(review): "no sample yet" is detected by `rtt_us == 0`, so after a
    /// 0 µs sample the next sample re-seeds the estimator.
    pub fn update_rtt(&mut self, sample_us: u64) {
        let stats = &mut self.stats;
        if stats.rtt_us == 0 {
            stats.rtt_us = sample_us;
            stats.rtt_var_us = sample_us / 2;
            return;
        }
        let deviation = sample_us.abs_diff(stats.rtt_us);
        // Weighted sums are computed in u128 so large samples cannot overflow;
        // each result is an average of u64 values and therefore fits in u64.
        let var = (3 * u128::from(stats.rtt_var_us) + u128::from(deviation)) / 4;
        let srtt = (7 * u128::from(stats.rtt_us) + u128::from(sample_us)) / 8;
        stats.rtt_var_us = var as u64;
        stats.rtt_us = srtt as u64;
    }

    /// Records a loss event: updates the loss rate and halves the congestion
    /// window, never going below two maximum-size UDP payloads.
    pub fn signal_loss_event(&mut self) {
        self.stats.loss_events = self.stats.loss_events.saturating_add(1);
        // `max(1)` avoids division by zero before any packet was sent.
        self.stats.packet_loss_rate =
            f64::from(self.stats.loss_events) / self.stats.packets_sent.max(1) as f64;
        let min_cwnd = 2 * u64::from(self.config.max_udp_payload_size);
        self.stats.cwnd_bytes = (self.stats.cwnd_bytes / 2).max(min_cwnd);
    }

    /// Grows the congestion window by `bytes`, saturating at `u64::MAX`.
    pub fn increase_cwnd(&mut self, bytes: u64) {
        self.stats.cwnd_bytes = self.stats.cwnd_bytes.saturating_add(bytes);
    }

    /// Closes the connection and drops every open stream.
    pub fn close(&mut self) {
        self.closed = true;
        self.streams.clear();
    }

    /// Returns `true` until [`QuicConnection::close`] has been called.
    #[must_use]
    pub fn is_open(&self) -> bool {
        !self.closed
    }

    /// Returns a snapshot copy of the current statistics.
    #[must_use]
    pub fn connection_stats(&self) -> QuicStats {
        self.stats.clone()
    }

    /// Remaining connection-level send credit in bytes.
    #[must_use]
    pub fn flow_credit(&self) -> u64 {
        self.flow_credit
    }

    /// Adds `additional` bytes of send credit, saturating at `u64::MAX`.
    pub fn extend_flow_credit(&mut self, additional: u64) {
        self.flow_credit = self.flow_credit.saturating_add(additional);
    }
}
/// Maps HLS segment and partial-segment (presumably LL-HLS parts — confirm)
/// requests onto receive-only streams of one [`QuicConnection`], reusing a
/// single stream per URI while that stream remains open.
pub struct HlsOverQuic {
    conn: QuicConnection,
    /// Segment URI -> ID of the stream carrying it.
    segment_streams: HashMap<String, u64>,
    /// Partial-segment URI -> ID of the stream carrying it.
    part_streams: HashMap<String, u64>,
}

impl HlsOverQuic {
    /// Priority value given to full-segment streams.
    const SEGMENT_PRIORITY: u8 = 0;
    /// Priority value given to partial-segment streams (larger than segments').
    const PART_PRIORITY: u8 = 64;

    /// Wraps `conn` with empty URI-to-stream caches.
    #[must_use]
    pub fn new(conn: QuicConnection) -> Self {
        Self {
            conn,
            segment_streams: HashMap::new(),
            part_streams: HashMap::new(),
        }
    }

    /// Returns the stream carrying segment `uri`, opening one if needed.
    ///
    /// A cached stream that has since been closed (individually or by closing
    /// the connection) is discarded and a fresh stream is opened.
    ///
    /// # Errors
    /// Propagates [`QuicConnection::open_stream`] errors.
    pub fn request_segment(&mut self, uri: &str) -> Result<u64, QuicTransportError> {
        Self::request_stream(
            &mut self.conn,
            &mut self.segment_streams,
            uri,
            Self::SEGMENT_PRIORITY,
        )
    }

    /// Returns the stream carrying partial segment `uri`, opening one if needed.
    ///
    /// Same caching rules as [`HlsOverQuic::request_segment`], but with a larger
    /// priority value.
    ///
    /// # Errors
    /// Propagates [`QuicConnection::open_stream`] errors.
    pub fn request_part(&mut self, uri: &str) -> Result<u64, QuicTransportError> {
        Self::request_stream(&mut self.conn, &mut self.part_streams, uri, Self::PART_PRIORITY)
    }

    /// Shared lookup-or-open logic: reuse a still-open cached stream, otherwise
    /// open a receive-only stream with `priority` and cache it under `uri`.
    fn request_stream(
        conn: &mut QuicConnection,
        cache: &mut HashMap<String, u64>,
        uri: &str,
        priority: u8,
    ) -> Result<u64, QuicTransportError> {
        if let Some(&id) = cache.get(uri) {
            if conn.stream_mut(id).is_some() {
                return Ok(id);
            }
            // The cached stream is gone; evict it so a failed reopen below
            // does not leave a dead ID in the cache.
            cache.remove(uri);
        }
        let id = conn.open_stream(StreamDirection::UniReceive, priority)?;
        cache.insert(uri.to_owned(), id);
        Ok(id)
    }

    /// Shared access to the underlying connection.
    #[must_use]
    pub fn connection(&self) -> &QuicConnection {
        &self.conn
    }

    /// Mutable access to the underlying connection.
    pub fn connection_mut(&mut self) -> &mut QuicConnection {
        &mut self.conn
    }

    /// Cached segment URI -> stream ID mappings.
    #[must_use]
    pub fn segment_stream_ids(&self) -> &HashMap<String, u64> {
        &self.segment_streams
    }

    /// Cached partial-segment URI -> stream ID mappings.
    #[must_use]
    pub fn part_stream_ids(&self) -> &HashMap<String, u64> {
        &self.part_streams
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a loopback connection using `cfg`.
    fn make_conn_with(cfg: QuicConfig) -> QuicConnection {
        let local: SocketAddr = "127.0.0.1:0".parse().expect("valid addr");
        let remote: SocketAddr = "127.0.0.1:4433".parse().expect("valid addr");
        QuicConnection::new(cfg, local, remote)
    }

    /// Builds a loopback connection using the default configuration.
    fn make_conn() -> QuicConnection {
        make_conn_with(QuicConfig::default())
    }

    #[test]
    fn test_default_config_max_streams() {
        let cfg = QuicConfig::default();
        assert_eq!(cfg.max_streams, 100);
    }

    #[test]
    fn test_new_connection_is_open() {
        let conn = make_conn();
        assert!(conn.is_open());
    }

    #[test]
    fn test_open_stream_unique_ids() {
        let mut conn = make_conn();
        let id0 = conn.open_stream(StreamDirection::Bidirectional, 0).expect("ok");
        let id1 = conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        assert_ne!(id0, id1);
    }

    #[test]
    fn test_active_streams_count() {
        let mut conn = make_conn();
        conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        conn.open_stream(StreamDirection::UniReceive, 0).expect("ok");
        assert_eq!(conn.active_streams().len(), 2);
    }

    #[test]
    fn test_close_stream_removes_it() {
        let mut conn = make_conn();
        let id = conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        conn.close_stream(id).expect("close ok");
        assert_eq!(conn.active_streams().len(), 0);
    }

    #[test]
    fn test_close_unknown_stream_error() {
        let mut conn = make_conn();
        let err = conn.close_stream(999).expect_err("should fail");
        assert_eq!(err, QuicTransportError::StreamNotFound(999));
    }

    #[test]
    fn test_close_stream_on_closed_connection() {
        let mut conn = make_conn();
        let id = conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        conn.close();
        let err = conn.close_stream(id).expect_err("should fail");
        assert_eq!(err, QuicTransportError::ConnectionClosed);
    }

    #[test]
    fn test_send_bytes_deducts_flow_credit() {
        let mut conn = make_conn();
        let id = conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        let initial = conn.flow_credit();
        conn.send_bytes(id, 1024).expect("ok");
        assert_eq!(conn.flow_credit(), initial - 1024);
        assert_eq!(conn.connection_stats().bytes_sent, 1024);
        assert_eq!(conn.connection_stats().packets_sent, 1);
    }

    #[test]
    fn test_send_bytes_flow_control_violation() {
        let mut conn = make_conn();
        let id = conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        let initial = conn.flow_credit();
        let err = conn.send_bytes(id, u64::MAX).expect_err("must fail");
        assert_eq!(err, QuicTransportError::FlowControlViolation);
        // A rejected send must not consume credit or touch the counters.
        assert_eq!(conn.flow_credit(), initial);
        assert_eq!(conn.connection_stats().bytes_sent, 0);
    }

    #[test]
    fn test_send_bytes_unknown_stream() {
        let mut conn = make_conn();
        let err = conn.send_bytes(999, 1).expect_err("must fail");
        assert_eq!(err, QuicTransportError::StreamNotFound(999));
    }

    #[test]
    fn test_receive_bytes_updates_counters() {
        let mut conn = make_conn();
        let id = conn.open_stream(StreamDirection::UniReceive, 0).expect("ok");
        conn.receive_bytes(id, 500).expect("ok");
        conn.receive_bytes(id, 250).expect("ok");
        let stats = conn.connection_stats();
        assert_eq!(stats.bytes_received, 750);
        assert_eq!(stats.packets_received, 2);
        let stream = conn.stream_mut(id).expect("stream exists");
        assert_eq!(stream.bytes_received, 750);
    }

    #[test]
    fn test_update_rtt_first_sample() {
        let mut conn = make_conn();
        conn.update_rtt(20_000);
        assert_eq!(conn.connection_stats().rtt_us, 20_000);
        assert_eq!(conn.connection_stats().rtt_var_us, 10_000);
    }

    #[test]
    fn test_update_rtt_convergence() {
        let mut conn = make_conn();
        conn.update_rtt(20_000);
        for _ in 0..16 {
            conn.update_rtt(30_000);
        }
        let stats = conn.connection_stats();
        assert!(stats.rtt_us > 25_000, "rtt={}", stats.rtt_us);
        // Integer EWMA approaches the sample from below and never overshoots.
        assert!(stats.rtt_us <= 30_000, "rtt={}", stats.rtt_us);
    }

    #[test]
    fn test_signal_loss_halves_cwnd() {
        let mut conn = make_conn();
        conn.increase_cwnd(100_000);
        let before = conn.connection_stats().cwnd_bytes;
        conn.signal_loss_event();
        let after = conn.connection_stats().cwnd_bytes;
        assert_eq!(after, before / 2);
    }

    #[test]
    fn test_signal_loss_respects_min_cwnd() {
        let mut conn = make_conn();
        conn.signal_loss_event();
        let floor = 2 * u64::from(QuicConfig::default().max_udp_payload_size);
        let stats = conn.connection_stats();
        assert_eq!(stats.cwnd_bytes, floor);
        assert_eq!(stats.loss_events, 1);
    }

    #[test]
    fn test_close_connection() {
        let mut conn = make_conn();
        conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        conn.close();
        assert!(!conn.is_open());
        assert_eq!(conn.active_streams().len(), 0);
    }

    #[test]
    fn test_open_stream_on_closed_connection() {
        let mut conn = make_conn();
        conn.close();
        let err = conn.open_stream(StreamDirection::UniSend, 0).expect_err("fail");
        assert_eq!(err, QuicTransportError::ConnectionClosed);
    }

    #[test]
    fn test_stream_direction_semantics() {
        assert!(StreamDirection::Bidirectional.can_send());
        assert!(StreamDirection::Bidirectional.can_receive());
        assert!(StreamDirection::UniSend.can_send());
        assert!(!StreamDirection::UniSend.can_receive());
        assert!(!StreamDirection::UniReceive.can_send());
        assert!(StreamDirection::UniReceive.can_receive());
    }

    #[test]
    fn test_stats_rtt_duration() {
        let stats = QuicStats {
            rtt_us: 15_000,
            ..Default::default()
        };
        assert_eq!(stats.rtt(), std::time::Duration::from_micros(15_000));
    }

    #[test]
    fn test_extend_flow_credit() {
        let mut conn = make_conn();
        let initial = conn.flow_credit();
        conn.extend_flow_credit(1_000_000);
        assert_eq!(conn.flow_credit(), initial + 1_000_000);
    }

    #[test]
    fn test_max_streams_limit() {
        let mut conn = make_conn_with(QuicConfig {
            max_streams: 2,
            ..QuicConfig::default()
        });
        conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        let err = conn.open_stream(StreamDirection::UniSend, 0).expect_err("must fail");
        assert_eq!(err, QuicTransportError::StreamLimitReached);
    }

    #[test]
    fn test_max_streams_slot_freed_by_close() {
        let mut conn = make_conn_with(QuicConfig {
            max_streams: 1,
            ..QuicConfig::default()
        });
        let id = conn.open_stream(StreamDirection::UniSend, 0).expect("ok");
        conn.close_stream(id).expect("close ok");
        conn.open_stream(StreamDirection::UniSend, 0).expect("slot freed");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            QuicTransportError::StreamNotFound(3).to_string(),
            "QUIC stream 3 not found"
        );
        assert_eq!(
            QuicTransportError::ConnectionClosed.to_string(),
            "QUIC connection closed"
        );
    }

    #[test]
    fn test_hls_over_quic_request_segment() {
        let mut hoq = HlsOverQuic::new(make_conn());
        let id = hoq.request_segment("seg0.ts").expect("ok");
        assert_eq!(hoq.segment_stream_ids().len(), 1);
        let id2 = hoq.request_segment("seg0.ts").expect("ok");
        assert_eq!(id, id2);
    }

    #[test]
    fn test_hls_over_quic_request_part_priority() {
        let mut hoq = HlsOverQuic::new(make_conn());
        let seg_id = hoq.request_segment("seg0.ts").expect("ok");
        let part_id = hoq.request_part("part0.mp4").expect("ok");
        assert_ne!(seg_id, part_id);
        let streams = hoq.connection().active_streams();
        let seg_stream = streams
            .iter()
            .find(|s| s.stream_id == seg_id)
            .expect("found");
        let part_stream = streams
            .iter()
            .find(|s| s.stream_id == part_id)
            .expect("found");
        assert!(seg_stream.priority < part_stream.priority);
    }

    #[test]
    fn test_stats_min_rtt_saturates() {
        let stats = QuicStats {
            rtt_us: 5_000,
            rtt_var_us: 10_000,
            ..Default::default()
        };
        assert_eq!(stats.min_rtt_us(), 0);
    }
}