use crate::bytes::{Bytes, BytesMut};
use crate::net::atp::handshake::state_machine::HandshakeError;
use crate::net::atp::protocol::varint::VarInt;
use crate::types::outcome::Outcome;
use std::collections::HashMap;
use std::time::Duration;
/// QUIC transport parameters recognised by this module.
///
/// Each discriminant is the parameter's wire ID (sent as a QUIC varint). The
/// values follow the RFC 9000 §18.2 registry, plus `max_datagram_frame_size`
/// (RFC 9221).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u64)]
pub enum TransportParamId {
    /// `original_destination_connection_id` — raw connection ID bytes.
    OriginalDestinationConnectionId = 0x00,
    /// `max_idle_timeout` — integer, milliseconds.
    MaxIdleTimeout = 0x01,
    /// `stateless_reset_token` — 16 raw bytes.
    StatelessResetToken = 0x02,
    /// `max_udp_payload_size` — integer.
    MaxUdpPayloadSize = 0x03,
    /// `initial_max_data` — integer.
    InitialMaxData = 0x04,
    /// `initial_max_stream_data_bidi_local` — integer.
    InitialMaxStreamDataBidiLocal = 0x05,
    /// `initial_max_stream_data_bidi_remote` — integer.
    InitialMaxStreamDataBidiRemote = 0x06,
    /// `initial_max_stream_data_uni` — integer.
    InitialMaxStreamDataUni = 0x07,
    /// `initial_max_streams_bidi` — integer.
    InitialMaxStreamsBidi = 0x08,
    /// `initial_max_streams_uni` — integer.
    InitialMaxStreamsUni = 0x09,
    /// `ack_delay_exponent` — integer.
    AckDelayExponent = 0x0A,
    /// `max_ack_delay` — integer.
    MaxAckDelay = 0x0B,
    /// `disable_active_migration` — zero-length flag.
    DisableActiveMigration = 0x0C,
    /// `preferred_address` — structured raw bytes.
    PreferredAddress = 0x0D,
    /// `active_connection_id_limit` — integer.
    ActiveConnectionIdLimit = 0x0E,
    /// `initial_source_connection_id` — raw connection ID bytes.
    InitialSourceConnectionId = 0x0F,
    /// `retry_source_connection_id` — raw connection ID bytes.
    RetrySourceConnectionId = 0x10,
    /// `max_datagram_frame_size` — integer.
    MaxDatagramFrameSize = 0x20,
}
impl TransportParamId {
    /// Every variant, used to map a wire ID back to its enum value.
    const KNOWN: [Self; 18] = [
        Self::OriginalDestinationConnectionId,
        Self::MaxIdleTimeout,
        Self::StatelessResetToken,
        Self::MaxUdpPayloadSize,
        Self::InitialMaxData,
        Self::InitialMaxStreamDataBidiLocal,
        Self::InitialMaxStreamDataBidiRemote,
        Self::InitialMaxStreamDataUni,
        Self::InitialMaxStreamsBidi,
        Self::InitialMaxStreamsUni,
        Self::AckDelayExponent,
        Self::MaxAckDelay,
        Self::DisableActiveMigration,
        Self::PreferredAddress,
        Self::ActiveConnectionIdLimit,
        Self::InitialSourceConnectionId,
        Self::RetrySourceConnectionId,
        Self::MaxDatagramFrameSize,
    ];

    /// Returns this parameter's wire ID as a varint.
    pub fn to_varint(self) -> VarInt {
        let wire_id = self as u64;
        VarInt::from_u64_unchecked(wire_id)
    }

    /// Looks up the variant whose wire ID equals `varint`.
    ///
    /// Returns `None` for IDs this module does not recognise.
    pub fn from_varint(varint: VarInt) -> Option<Self> {
        let wanted = varint.value();
        Self::KNOWN.iter().copied().find(|&id| id as u64 == wanted)
    }

    /// Whether the parameter carries a non-empty value.
    ///
    /// Only `DisableActiveMigration` is value-less.
    pub fn requires_value(self) -> bool {
        self != Self::DisableActiveMigration
    }
}
/// The value carried by a single transport parameter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportParamValue {
    /// An integer value, written on the wire as a single QUIC varint.
    Integer(u64),
    /// An opaque byte string, copied to the wire verbatim.
    Bytes(Bytes),
    /// A zero-length value, as used by flag parameters.
    Empty,
}
impl TransportParamValue {
    /// Returns the wire form of this value (without the ID/length prefix).
    ///
    /// - `Integer` values are written as a single varint.
    /// - `Bytes` values are copied verbatim.
    /// - `Empty` produces zero bytes.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) if encoding the varint for an `Integer`
    /// fails. The value is wrapped with `VarInt::from_u64_unchecked`, so callers
    /// should keep integers within the varint range.
    /// NOTE(review): confirm what `from_u64_unchecked` does for out-of-range input.
    pub fn encode(&self) -> Bytes {
        match self {
            Self::Integer(value) => {
                let mut buf = BytesMut::new();
                let varint = VarInt::from_u64_unchecked(*value);
                match varint.encode(&mut buf) {
                    Outcome::Ok(()) => {}
                    Outcome::Err(_) | Outcome::Cancelled(_) | Outcome::Panicked(_) => {
                        unreachable!("validated varint encoding must succeed")
                    }
                }
                buf.freeze()
            }
            Self::Bytes(bytes) => bytes.clone(),
            Self::Empty => Bytes::new(),
        }
    }

    /// Interprets the value as an integer.
    ///
    /// `Integer` values are returned directly. A non-empty `Bytes` value is
    /// decoded as a single varint, which must span the entire buffer. Anything
    /// else (including `Empty`) yields `None`.
    pub fn as_integer(&self) -> Option<u64> {
        match self {
            Self::Integer(value) => Some(*value),
            Self::Bytes(bytes) if !bytes.is_empty() => {
                let mut buf = BytesMut::from(&bytes[..]);
                match VarInt::decode(&mut buf) {
                    // Leftover bytes after the varint mean this is not an integer
                    // encoding (e.g. a reset token); do not report its prefix.
                    Outcome::Ok(Some(value)) if buf.is_empty() => Some(value.value()),
                    _ => None,
                }
            }
            _ => None,
        }
    }

    /// Returns the raw bytes of a `Bytes` value, or `None` for other variants.
    pub fn as_bytes(&self) -> Option<&Bytes> {
        match self {
            Self::Bytes(bytes) => Some(bytes),
            _ => None,
        }
    }
}
/// A set of QUIC transport parameters, keyed by raw wire ID.
///
/// Keys are plain `u64` rather than [`TransportParamId`], so parameters this
/// module does not recognise can still be stored after decoding.
#[derive(Debug, Clone)]
pub struct TransportParameters {
    /// Parameter values indexed by their wire parameter ID.
    params: HashMap<u64, TransportParamValue>,
}
impl TransportParameters {
pub fn new() -> Self {
Self {
params: HashMap::new(),
}
}
pub fn client_defaults() -> Self {
let mut params = Self::new();
params.set_integer(TransportParamId::MaxIdleTimeout, 30_000); params.set_integer(TransportParamId::MaxUdpPayloadSize, 65527); params.set_integer(TransportParamId::InitialMaxData, 1024 * 1024); params.set_integer(TransportParamId::InitialMaxStreamDataBidiLocal, 256 * 1024); params.set_integer(TransportParamId::InitialMaxStreamDataBidiRemote, 256 * 1024); params.set_integer(TransportParamId::InitialMaxStreamDataUni, 256 * 1024); params.set_integer(TransportParamId::InitialMaxStreamsBidi, 100);
params.set_integer(TransportParamId::InitialMaxStreamsUni, 100);
params.set_integer(TransportParamId::AckDelayExponent, 3);
params.set_integer(TransportParamId::MaxAckDelay, 25); params.set_integer(TransportParamId::ActiveConnectionIdLimit, 8);
params
}
pub fn server_defaults() -> Self {
let mut params = Self::client_defaults();
let mut stateless_reset_token = [0_u8; 16];
getrandom::fill(&mut stateless_reset_token)
.expect("OS entropy is required for QUIC stateless reset tokens");
params.set_bytes(
TransportParamId::StatelessResetToken,
Bytes::copy_from_slice(&stateless_reset_token),
);
params
}
pub fn set_integer(&mut self, id: TransportParamId, value: u64) {
self.params
.insert(id as u64, TransportParamValue::Integer(value));
}
pub fn set_bytes(&mut self, id: TransportParamId, value: Bytes) {
self.params
.insert(id as u64, TransportParamValue::Bytes(value));
}
pub fn set_flag(&mut self, id: TransportParamId) {
self.params.insert(id as u64, TransportParamValue::Empty);
}
pub fn get(&self, id: TransportParamId) -> Option<&TransportParamValue> {
self.params.get(&(id as u64))
}
pub fn get_integer(&self, id: TransportParamId) -> Option<u64> {
self.get(id)?.as_integer()
}
pub fn get_bytes(&self, id: TransportParamId) -> Option<&Bytes> {
self.get(id)?.as_bytes()
}
pub fn has_flag(&self, id: TransportParamId) -> bool {
matches!(self.get(id), Some(TransportParamValue::Empty))
}
pub fn encode(&self) -> Outcome<Bytes, HandshakeError> {
let mut buf = BytesMut::new();
for (¶m_id, param_value) in &self.params {
let id_varint = match VarInt::new(param_id) {
Outcome::Ok(varint) => varint,
Outcome::Err(_) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "parameter ID too large".to_string(),
});
}
Outcome::Cancelled(reason) => return Outcome::Cancelled(reason),
Outcome::Panicked(payload) => return Outcome::Panicked(payload),
};
match id_varint.encode(&mut buf) {
Outcome::Ok(()) => {}
Outcome::Err(_) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "failed to encode parameter ID".to_string(),
});
}
Outcome::Cancelled(reason) => return Outcome::Cancelled(reason),
Outcome::Panicked(payload) => return Outcome::Panicked(payload),
}
let value_bytes = param_value.encode();
let length_varint = match VarInt::new(value_bytes.len() as u64) {
Outcome::Ok(varint) => varint,
Outcome::Err(_) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "parameter value too large".to_string(),
});
}
Outcome::Cancelled(reason) => return Outcome::Cancelled(reason),
Outcome::Panicked(payload) => return Outcome::Panicked(payload),
};
match length_varint.encode(&mut buf) {
Outcome::Ok(()) => {}
Outcome::Err(_) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "failed to encode parameter length".to_string(),
});
}
Outcome::Cancelled(reason) => return Outcome::Cancelled(reason),
Outcome::Panicked(payload) => return Outcome::Panicked(payload),
}
buf.put_slice(&value_bytes);
}
Outcome::ok(buf.freeze())
}
pub fn decode(data: &[u8]) -> Outcome<Self, HandshakeError> {
let mut params = Self::new();
let mut buf = BytesMut::from(data);
while !buf.is_empty() {
let id_varint = match VarInt::decode(&mut buf) {
Outcome::Ok(Some(varint)) => varint,
Outcome::Ok(None) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: 0,
reason: "truncated parameter ID".to_string(),
});
}
Outcome::Err(_) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: 0,
reason: "failed to decode parameter ID".to_string(),
});
}
Outcome::Cancelled(reason) => return Outcome::Cancelled(reason),
Outcome::Panicked(payload) => return Outcome::Panicked(payload),
};
let param_id = id_varint.value();
if params.params.contains_key(¶m_id) {
return Outcome::err(HandshakeError::DuplicateTransportParam { param_id });
}
let length_varint = match VarInt::decode(&mut buf) {
Outcome::Ok(Some(varint)) => varint,
Outcome::Ok(None) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "truncated parameter length".to_string(),
});
}
Outcome::Err(_) => {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "failed to decode parameter length".to_string(),
});
}
Outcome::Cancelled(reason) => return Outcome::Cancelled(reason),
Outcome::Panicked(payload) => return Outcome::Panicked(payload),
};
let length = length_varint.value() as usize;
if buf.len() < length {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id,
reason: "truncated parameter value".to_string(),
});
}
let value_bytes = if length == 0 {
TransportParamValue::Empty
} else {
let bytes = buf.split_to(length).freeze();
if let Some(param_type) = TransportParamId::from_varint(id_varint) {
if param_type.requires_value()
&& param_type != TransportParamId::StatelessResetToken
&& param_type != TransportParamId::OriginalDestinationConnectionId
&& param_type != TransportParamId::InitialSourceConnectionId
&& param_type != TransportParamId::RetrySourceConnectionId
&& param_type != TransportParamId::PreferredAddress
{
let mut value_buf = BytesMut::from(&bytes[..]);
if let Outcome::Ok(Some(int_varint)) = VarInt::decode(&mut value_buf) {
TransportParamValue::Integer(int_varint.value())
} else {
TransportParamValue::Bytes(bytes)
}
} else {
TransportParamValue::Bytes(bytes)
}
} else {
TransportParamValue::Bytes(bytes)
}
};
params.params.insert(param_id, value_bytes);
}
Outcome::ok(params)
}
pub fn validate(&self) -> Outcome<(), HandshakeError> {
if let Some(exp) = self.get_integer(TransportParamId::AckDelayExponent) {
if exp > 20 {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: TransportParamId::AckDelayExponent as u64,
reason: "ACK delay exponent too large".to_string(),
});
}
}
if let Some(delay) = self.get_integer(TransportParamId::MaxAckDelay) {
if delay >= (1u64 << 14) {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: TransportParamId::MaxAckDelay as u64,
reason: "maximum ACK delay too large".to_string(),
});
}
}
if let Some(size) = self.get_integer(TransportParamId::MaxUdpPayloadSize) {
if size < 1200 {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: TransportParamId::MaxUdpPayloadSize as u64,
reason: "maximum UDP payload size too small".to_string(),
});
}
}
if let Some(limit) = self.get_integer(TransportParamId::ActiveConnectionIdLimit) {
if limit < 2 {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: TransportParamId::ActiveConnectionIdLimit as u64,
reason: "active connection ID limit too small".to_string(),
});
}
}
if let Some(token) = self.get_bytes(TransportParamId::StatelessResetToken) {
if token.len() != 16 {
return Outcome::err(HandshakeError::InvalidTransportParam {
param_id: TransportParamId::StatelessResetToken as u64,
reason: "stateless reset token must be 16 bytes".to_string(),
});
}
}
Outcome::ok(())
}
pub fn max_idle_timeout(&self) -> Option<Duration> {
self.get_integer(TransportParamId::MaxIdleTimeout)
.map(Duration::from_millis)
}
}
impl Default for TransportParameters {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bytes::BufMut;

    /// Integers, a flag and a byte string all survive an encode/decode cycle.
    #[test]
    fn test_transport_params_roundtrip() {
        let token = Bytes::from_static(b"0123456789abcdef");
        let mut original = TransportParameters::new();
        original.set_integer(TransportParamId::MaxIdleTimeout, 30000);
        original.set_integer(TransportParamId::InitialMaxData, 1048576);
        original.set_flag(TransportParamId::DisableActiveMigration);
        original.set_bytes(TransportParamId::StatelessResetToken, token.clone());

        let wire = original.encode().unwrap();
        let parsed = TransportParameters::decode(&wire).unwrap();

        assert_eq!(parsed.get_integer(TransportParamId::MaxIdleTimeout), Some(30000));
        assert_eq!(parsed.get_integer(TransportParamId::InitialMaxData), Some(1048576));
        assert!(parsed.has_flag(TransportParamId::DisableActiveMigration));
        assert_eq!(
            parsed.get_bytes(TransportParamId::StatelessResetToken),
            Some(&token)
        );
    }

    /// Client defaults populate core limits and do not disable migration.
    #[test]
    fn test_client_defaults() {
        let defaults = TransportParameters::client_defaults();
        for id in [
            TransportParamId::MaxIdleTimeout,
            TransportParamId::InitialMaxData,
        ] {
            assert!(defaults.get_integer(id).is_some());
        }
        assert!(!defaults.has_flag(TransportParamId::DisableActiveMigration));
    }

    /// An out-of-range ACK delay exponent is rejected; a sane one passes.
    #[test]
    fn test_validation() {
        let mut params = TransportParameters::new();
        params.set_integer(TransportParamId::AckDelayExponent, 25);
        assert!(params.validate().is_err());

        params.set_integer(TransportParamId::AckDelayExponent, 3);
        assert!(params.validate().is_ok());
    }

    /// The same parameter ID appearing twice is reported as a duplicate.
    #[test]
    fn test_duplicate_parameter() {
        let mut wire = BytesMut::new();
        for payload in [30000_u16, 60000] {
            VarInt::new(0x01).unwrap().encode(&mut wire).unwrap();
            VarInt::new(2).unwrap().encode(&mut wire).unwrap();
            wire.put_u16(payload);
        }

        let outcome = TransportParameters::decode(&wire);
        assert!(matches!(
            outcome,
            Outcome::Err(HandshakeError::DuplicateTransportParam { .. })
        ));
    }

    /// `max_idle_timeout` is exposed as a millisecond `Duration`.
    #[test]
    fn test_max_idle_timeout_duration() {
        let mut params = TransportParameters::new();
        params.set_integer(TransportParamId::MaxIdleTimeout, 5000);
        let expected = Some(Duration::from_millis(5000));
        assert_eq!(params.max_idle_timeout(), expected);
    }

    /// The RFC 9221 datagram parameter round-trips as an integer.
    #[test]
    fn test_max_datagram_frame_size_roundtrip() {
        let mut original = TransportParameters::new();
        original.set_integer(TransportParamId::MaxDatagramFrameSize, 1200);

        let wire = original.encode().unwrap();
        let parsed = TransportParameters::decode(&wire).unwrap();

        assert_eq!(
            parsed.get_integer(TransportParamId::MaxDatagramFrameSize),
            Some(1200)
        );
    }
}