use std::{net::SocketAddr, sync::Arc, time::Duration, time::Instant};
use anyhow::Result;
use bytes::{Buf, Bytes};
use h3::server::RequestStream;
use quinn::Connection as QuinnConnection;
use quinn::ConnectionError;
use scru128::Id as Scru128Id;
use tokio::time::timeout;
/// Identity of one QUIC peer session: a generated string id plus the peer's address.
///
/// Cheap to clone. `Debug` is derived so sessions can be logged and used in assertions.
#[derive(Clone, Debug)]
pub struct QuicSession {
    /// Session identifier (set by `QuicSession::new`).
    id: String,
    /// Socket address of the remote peer.
    remote_addr: SocketAddr,
}
impl QuicSession {
    /// Creates a session for the peer at `remote_addr` with a freshly generated id.
    ///
    /// The id is a SCRU128 value built from 128 random bits, rendered as a string.
    /// NOTE(review): because every bit is random, the SCRU128 timestamp/counter fields
    /// carry no time information, so ids are unique but not time-ordered — confirm that
    /// ordering is not relied upon anywhere.
    pub fn new(remote_addr: SocketAddr) -> Self {
        let random_bits: u128 = rand::random();
        Self {
            id: Scru128Id::from_u128(random_bits).to_string(),
            remote_addr,
        }
    }

    /// Returns the session identifier.
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Returns the remote peer's socket address.
    pub fn remote_addr(&self) -> SocketAddr {
        self.remote_addr
    }
}
/// A WebTransport session stream together with the limits applied to it.
///
/// Bundles the HTTP/3 request stream with optional frame and datagram size caps, an
/// optional read timeout, and token-bucket state used to rate-limit datagrams.
pub struct WebTransportStream {
    /// Underlying HTTP/3 request stream used by `recv_data`, `send_data` and `finish`.
    inner: RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
    /// QUIC connection used for datagrams; `None` means datagrams are unsupported.
    conn: Option<QuinnConnection>,
    /// Largest payload `recv_data` accepts; `None` means unlimited.
    max_frame_size: Option<usize>,
    /// Timeout applied to each `recv_data` call; `None` means wait indefinitely.
    read_timeout: Option<Duration>,
    /// Largest datagram that may be sent or accepted; `None` means unlimited.
    max_datagram_size: Option<usize>,
    /// Datagram tokens refilled per second, which is also the bucket capacity;
    /// `None` disables datagram rate limiting.
    datagram_per_sec: Option<u64>,
    /// Tokens currently available to datagram sends and receives.
    datagram_tokens: u64,
    /// Reference instant for the next token refill.
    last_refill: Instant,
    /// When set, drops are reported to metrics (with the `metrics` feature) and
    /// oversized or rate-limited incoming datagrams yield `Ok(None)` instead of an error.
    record_drop: bool,
}
impl WebTransportStream {
    /// Wraps an HTTP/3 request stream with WebTransport limits.
    ///
    /// * `max_frame_size` – largest payload `recv_data` accepts (`None` = unlimited).
    /// * `read_timeout` – timeout for each `recv_data` call (`None` = wait forever).
    /// * `max_datagram_size` – largest datagram sent or accepted (`None` = unlimited).
    /// * `datagram_per_sec` – datagram token-bucket rate; the bucket starts full and holds
    ///   at most this many tokens (`None` = no rate limit).
    /// * `record_drop` – when `true`, drops are reported to metrics (with the `metrics`
    ///   feature) and oversized / rate-limited incoming datagrams yield `Ok(None)`
    ///   instead of an error.
    /// * `conn` – QUIC connection used for datagrams; `None` disables datagrams.
    pub(crate) fn new(
        inner: RequestStream<h3_quinn::BidiStream<Bytes>, Bytes>,
        max_frame_size: Option<usize>,
        read_timeout: Option<Duration>,
        max_datagram_size: Option<usize>,
        datagram_per_sec: Option<u64>,
        record_drop: bool,
        conn: Option<QuinnConnection>,
    ) -> Self {
        Self {
            inner,
            max_frame_size,
            read_timeout,
            max_datagram_size,
            datagram_per_sec,
            datagram_tokens: datagram_per_sec.unwrap_or(0),
            last_refill: Instant::now(),
            record_drop,
            conn,
        }
    }

    /// Credits the datagram token bucket for the time elapsed since `last_refill`.
    ///
    /// Tokens accrue continuously at `datagram_per_sec` per second, capped at `rate`.
    /// Only the time that actually produced whole tokens is consumed from the clock, so
    /// frequent calls keep accumulating fractional progress instead of discarding it
    /// (crediting whole seconds while resetting the clock on every call would starve a
    /// caller that polls more than once per second).
    fn refill(&mut self) {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let Some(rate) = self.datagram_per_sec else {
            return;
        };
        let now = Instant::now();
        let elapsed_ns = now.saturating_duration_since(self.last_refill).as_nanos();
        let capacity = u128::from(rate);
        let earned = elapsed_ns.saturating_mul(capacity) / NANOS_PER_SEC;
        if earned == 0 {
            // Less than one token's worth of time has passed (or the rate is 0). Leave
            // `last_refill` untouched so the partial interval keeps accumulating.
            return;
        }
        let filled = (u128::from(self.datagram_tokens) + earned).min(capacity);
        // `filled <= capacity == rate`, so the conversion cannot fail.
        self.datagram_tokens = u64::try_from(filled).unwrap_or(rate);
        if filled == capacity {
            // Bucket is full; any leftover fractional time would be discarded anyway.
            self.last_refill = now;
        } else {
            // Advance only by the time that produced `earned` tokens so the remainder
            // carries over. Here `earned < rate`, so this is under one second and
            // always fits in a u64.
            let spent_ns = earned * NANOS_PER_SEC / capacity;
            self.last_refill += Duration::from_nanos(spent_ns as u64);
        }
    }

    /// Refills the bucket and consumes one datagram token.
    ///
    /// Returns `false` when rate limiting is enabled and no token is available; always
    /// returns `true` (without touching the bucket) when no rate is configured.
    fn take_datagram_token(&mut self) -> bool {
        if self.datagram_per_sec.is_none() {
            return true;
        }
        self.refill();
        match self.datagram_tokens.checked_sub(1) {
            Some(left) => {
                self.datagram_tokens = left;
                true
            }
            None => false,
        }
    }

    /// Reports a dropped datagram when `record_drop` is set (no-op without `metrics`).
    fn record_datagram_dropped(&self) {
        #[cfg(feature = "metrics")]
        if self.record_drop {
            crate::server::metrics::record_webtransport_datagram_dropped();
        }
    }

    /// Reports a rate-limited datagram when `record_drop` is set (no-op without `metrics`).
    fn record_rate_limited(&self) {
        #[cfg(feature = "metrics")]
        if self.record_drop {
            crate::server::metrics::record_webtransport_rate_limited();
        }
    }

    /// Receives the next chunk of data from the request stream.
    ///
    /// Returns `Ok(None)` when the underlying stream reports no more data.
    ///
    /// # Errors
    /// Fails if `read_timeout` elapses, the stream errors, or the chunk is larger than
    /// `max_frame_size`.
    pub async fn recv_data(&mut self) -> Result<Option<Bytes>> {
        let fut = self.inner.recv_data();
        let maybe = match self.read_timeout {
            Some(t) => timeout(t, fut).await??,
            None => fut.await?,
        };
        let Some(mut buf) = maybe else {
            return Ok(None);
        };
        // Check the size before copying so oversized frames are rejected without
        // materializing them.
        let len = buf.remaining();
        if let Some(max) = self.max_frame_size
            && len > max
        {
            anyhow::bail!("WebTransport frame exceeds limit");
        }
        Ok(Some(buf.copy_to_bytes(len)))
    }

    /// Sends `data` as a datagram on the QUIC connection without waiting.
    ///
    /// A rate-limit token is consumed once the size check passes.
    /// NOTE(review): the payload is handed to `conn.send_datagram` as-is, with no
    /// WebTransport session prefix added here — confirm peers expect that framing.
    ///
    /// # Errors
    /// Fails if the datagram exceeds `max_datagram_size`, no token is available, the
    /// stream has no connection, or the connection rejects the datagram.
    pub fn try_send_datagram(&mut self, data: Bytes) -> Result<()> {
        if let Some(max) = self.max_datagram_size
            && data.len() > max
        {
            self.record_datagram_dropped();
            anyhow::bail!("Datagram frame exceeds limit");
        }
        if !self.take_datagram_token() {
            self.record_rate_limited();
            anyhow::bail!("Datagram rate limited");
        }
        let Some(conn) = &self.conn else {
            anyhow::bail!("Datagram not supported by connection");
        };
        if let Err(err) = conn.send_datagram(data) {
            self.record_datagram_dropped();
            anyhow::bail!("Datagram send failed: {err}");
        }
        Ok(())
    }

    /// Waits for the next datagram on the QUIC connection.
    ///
    /// Returns `Ok(None)` when the connection was closed by the application or locally.
    /// With `record_drop` set, oversized or rate-limited datagrams are also discarded
    /// with `Ok(None)`.
    /// NOTE(review): that makes a dropped datagram indistinguishable from a closed
    /// connection for callers — confirm this is intended.
    ///
    /// # Errors
    /// Fails if the stream has no connection, on other connection errors, and (without
    /// `record_drop`) on oversized or rate-limited datagrams.
    pub async fn recv_datagram(&mut self) -> Result<Option<Bytes>> {
        let Some(conn) = &self.conn else {
            anyhow::bail!("Datagram not supported by connection");
        };
        let raw = match conn.read_datagram().await {
            Ok(bytes) => bytes,
            Err(ConnectionError::ApplicationClosed { .. } | ConnectionError::LocallyClosed) => {
                return Ok(None);
            }
            Err(err) => anyhow::bail!("Datagram recv failed: {err}"),
        };
        if let Some(max) = self.max_datagram_size
            && raw.len() > max
        {
            self.record_datagram_dropped();
            if self.record_drop {
                return Ok(None);
            }
            anyhow::bail!("Datagram frame exceeds limit");
        }
        // The refill happens here, after the await, so time spent waiting for this
        // datagram is credited before the rate-limit check.
        if !self.take_datagram_token() {
            self.record_rate_limited();
            if self.record_drop {
                return Ok(None);
            }
            anyhow::bail!("Datagram rate limited");
        }
        Ok(Some(raw))
    }

    /// Sends `data` on the request stream.
    pub async fn send_data(&mut self, data: Bytes) -> Result<()> {
        Ok(self.inner.send_data(data).await?)
    }

    /// Finishes the sending side of the request stream.
    pub async fn finish(&mut self) -> Result<()> {
        Ok(self.inner.finish().await?)
    }
}
/// Application callback that services a WebTransport session.
///
/// Implementors must be `Send + Sync`.
#[async_trait::async_trait]
pub trait WebTransportHandler: Send + Sync {
    /// Handles a session.
    ///
    /// `session` carries the session id and peer address. `stream` gives exclusive
    /// access to the session's stream and datagram helpers for the duration of the call.
    async fn handle(
        &self,
        session: Arc<QuicSession>,
        stream: &mut WebTransportStream,
    ) -> Result<()>;
}
#[cfg(test)]
mod tests {
    //! Tests for the types defined in this module.
    //!
    //! `WebTransportStream` can only be built from a live h3 request stream, so its
    //! runtime behavior (frame limits, datagram rate limiting) needs an integration
    //! test. The checks here pin `QuicSession` behavior and compile-time trait
    //! guarantees, and only assert things that exercise this module's code.

    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_quic_session_basics() {
        let addr1: SocketAddr = "127.0.0.1:1111".parse().unwrap();
        let addr2: SocketAddr = "127.0.0.1:2222".parse().unwrap();
        let s1 = QuicSession::new(addr1);
        let s2 = QuicSession::new(addr2);
        assert!(!s1.id().is_empty());
        assert_ne!(s1.id(), s2.id());
        assert_eq!(s1.remote_addr(), addr1);
        assert_eq!(s2.remote_addr(), addr2);
    }

    #[test]
    fn test_quic_session_clone() {
        let addr: SocketAddr = "192.168.1.1:8080".parse().unwrap();
        let s1 = QuicSession::new(addr);
        let s2 = s1.clone();
        assert_eq!(s1.id(), s2.id());
        assert_eq!(s1.remote_addr(), s2.remote_addr());
    }

    #[test]
    fn test_quic_session_id_format() {
        let addr: SocketAddr = "10.0.0.1:443".parse().unwrap();
        let session = QuicSession::new(addr);
        let id = session.id();
        assert!(id.len() > 20);
        // SCRU128 ids render as base-36 text.
        assert!(id.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_quic_session_id_uniqueness() {
        let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap();
        let sessions: Vec<_> = (0..100).map(|_| QuicSession::new(addr)).collect();
        let mut unique_ids = HashSet::new();
        for id in sessions.iter().map(QuicSession::id) {
            assert!(unique_ids.insert(id), "发现重复的 ID: {}", id);
        }
    }

    #[test]
    fn test_quic_session_ipv6_support() {
        let ipv6: SocketAddr = "[::1]:443".parse().unwrap();
        let session = QuicSession::new(ipv6);
        assert_eq!(session.remote_addr(), ipv6);
    }

    #[tokio::test]
    async fn test_quic_session_in_async_context() {
        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        let session = QuicSession::new(addr);
        assert!(!session.id().is_empty());
    }

    #[test]
    fn test_quic_session_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<QuicSession>();
        assert_sync::<QuicSession>();
    }

    #[test]
    fn test_webtransport_stream_is_send() {
        // `WebTransportHandler::handle` futures are `Send` (async_trait's default) and
        // hold `&mut WebTransportStream`, which requires the stream itself to be `Send`.
        fn assert_send<T: Send>() {}
        assert_send::<WebTransportStream>();
    }

    #[test]
    fn test_webtransport_handler_trait_exists() {
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn WebTransportHandler>();
    }
}