use std::fmt;
pub const QUIC_VARINT_MAX: u64 = (1u64 << 62) - 1;
pub const TP_MAX_IDLE_TIMEOUT: u64 = 0x01;
pub const TP_MAX_UDP_PAYLOAD_SIZE: u64 = 0x03;
pub const TP_INITIAL_MAX_DATA: u64 = 0x04;
pub const TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL: u64 = 0x05;
pub const TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE: u64 = 0x06;
pub const TP_INITIAL_MAX_STREAM_DATA_UNI: u64 = 0x07;
pub const TP_INITIAL_MAX_STREAMS_BIDI: u64 = 0x08;
pub const TP_INITIAL_MAX_STREAMS_UNI: u64 = 0x09;
pub const TP_ACK_DELAY_EXPONENT: u64 = 0x0a;
pub const TP_MAX_ACK_DELAY: u64 = 0x0b;
pub const TP_DISABLE_ACTIVE_MIGRATION: u64 = 0x0c;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuicCoreError {
UnexpectedEof,
VarIntOutOfRange(u64),
InvalidHeader(&'static str),
InvalidConnectionIdLength(usize),
PacketNumberTooLarge {
packet_number: u32,
width: u8,
},
DuplicateTransportParameter(u64),
InvalidTransportParameter(u64),
}
impl fmt::Display for QuicCoreError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::UnexpectedEof => write!(f, "unexpected EOF"),
Self::VarIntOutOfRange(v) => write!(f, "varint out of range: {v}"),
Self::InvalidHeader(msg) => write!(f, "invalid header: {msg}"),
Self::InvalidConnectionIdLength(len) => {
write!(f, "invalid connection id length: {len}")
}
Self::PacketNumberTooLarge {
packet_number,
width,
} => write!(
f,
"packet number {packet_number} does not fit in {width} bytes"
),
Self::DuplicateTransportParameter(id) => {
write!(f, "duplicate transport parameter: 0x{id:x}")
}
Self::InvalidTransportParameter(id) => {
write!(f, "invalid transport parameter: 0x{id:x}")
}
}
}
}
impl std::error::Error for QuicCoreError {}
#[derive(Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct ConnectionId {
bytes: [u8; 20],
len: u8,
}
impl ConnectionId {
pub const MAX_LEN: usize = 20;
pub fn new(bytes: &[u8]) -> Result<Self, QuicCoreError> {
if bytes.len() > Self::MAX_LEN {
return Err(QuicCoreError::InvalidConnectionIdLength(bytes.len()));
}
let mut out = [0u8; Self::MAX_LEN];
out[..bytes.len()].copy_from_slice(bytes);
Ok(Self {
bytes: out,
len: bytes.len() as u8,
})
}
#[must_use]
pub fn len(&self) -> usize {
self.len as usize
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len == 0
}
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
&self.bytes[..self.len()]
}
}
impl fmt::Debug for ConnectionId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ConnectionId(")?;
for b in self.as_bytes() {
write!(f, "{b:02x}")?;
}
write!(f, ")")
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LongPacketType {
Initial,
ZeroRtt,
Handshake,
Retry,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LongHeader {
pub packet_type: LongPacketType,
pub version: u32,
pub dst_cid: ConnectionId,
pub src_cid: ConnectionId,
pub token: Vec<u8>,
pub payload_length: u64,
pub packet_number: u32,
pub packet_number_len: u8,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RetryHeader {
pub version: u32,
pub dst_cid: ConnectionId,
pub src_cid: ConnectionId,
pub token: Vec<u8>,
pub integrity_tag: [u8; 16],
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ShortHeader {
pub spin: bool,
pub key_phase: bool,
pub dst_cid: ConnectionId,
pub packet_number: u32,
pub packet_number_len: u8,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketHeader {
Long(LongHeader),
Retry(RetryHeader),
Short(ShortHeader),
}
impl PacketHeader {
pub fn encode(&self, out: &mut Vec<u8>) -> Result<(), QuicCoreError> {
match self {
Self::Long(h) => encode_long_header(h, out),
Self::Retry(h) => {
encode_retry_header(h, out);
Ok(())
}
Self::Short(h) => encode_short_header(h, out),
}
}
pub fn decode(input: &[u8], short_dcid_len: usize) -> Result<(Self, usize), QuicCoreError> {
if input.is_empty() {
return Err(QuicCoreError::UnexpectedEof);
}
if input[0] & 0x80 != 0 {
decode_long_header(input)
} else {
decode_short_header(input, short_dcid_len).map(|(h, n)| (Self::Short(h), n))
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownTransportParameter {
pub id: u64,
pub value: Vec<u8>,
}
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TransportParameters {
pub max_idle_timeout: Option<u64>,
pub max_udp_payload_size: Option<u64>,
pub initial_max_data: Option<u64>,
pub initial_max_stream_data_bidi_local: Option<u64>,
pub initial_max_stream_data_bidi_remote: Option<u64>,
pub initial_max_stream_data_uni: Option<u64>,
pub initial_max_streams_bidi: Option<u64>,
pub initial_max_streams_uni: Option<u64>,
pub ack_delay_exponent: Option<u64>,
pub max_ack_delay: Option<u64>,
pub disable_active_migration: bool,
pub unknown: Vec<UnknownTransportParameter>,
}
impl TransportParameters {
pub fn encode(&self, out: &mut Vec<u8>) -> Result<(), QuicCoreError> {
encode_known_u64(out, TP_MAX_IDLE_TIMEOUT, self.max_idle_timeout)?;
encode_known_u64(out, TP_MAX_UDP_PAYLOAD_SIZE, self.max_udp_payload_size)?;
encode_known_u64(out, TP_INITIAL_MAX_DATA, self.initial_max_data)?;
encode_known_u64(
out,
TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL,
self.initial_max_stream_data_bidi_local,
)?;
encode_known_u64(
out,
TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE,
self.initial_max_stream_data_bidi_remote,
)?;
encode_known_u64(
out,
TP_INITIAL_MAX_STREAM_DATA_UNI,
self.initial_max_stream_data_uni,
)?;
encode_known_u64(
out,
TP_INITIAL_MAX_STREAMS_BIDI,
self.initial_max_streams_bidi,
)?;
encode_known_u64(
out,
TP_INITIAL_MAX_STREAMS_UNI,
self.initial_max_streams_uni,
)?;
encode_known_u64(out, TP_ACK_DELAY_EXPONENT, self.ack_delay_exponent)?;
encode_known_u64(out, TP_MAX_ACK_DELAY, self.max_ack_delay)?;
if self.disable_active_migration {
encode_parameter(out, TP_DISABLE_ACTIVE_MIGRATION, &[])?;
}
for p in &self.unknown {
encode_parameter(out, p.id, &p.value)?;
}
Ok(())
}
pub fn decode(input: &[u8]) -> Result<Self, QuicCoreError> {
let mut tp = Self::default();
let mut seen_ids: Vec<u64> = Vec::new();
let mut pos = 0usize;
while pos < input.len() {
let (id, id_len) = decode_varint(&input[pos..])?;
pos += id_len;
let (len, len_len) = decode_varint(&input[pos..])?;
pos += len_len;
let len = len as usize;
if input.len().saturating_sub(pos) < len {
return Err(QuicCoreError::UnexpectedEof);
}
let value = &input[pos..pos + len];
pos += len;
if seen_ids.contains(&id) {
return Err(QuicCoreError::DuplicateTransportParameter(id));
}
seen_ids.push(id);
match id {
TP_MAX_IDLE_TIMEOUT => set_unique_u64(&mut tp.max_idle_timeout, id, value)?,
TP_MAX_UDP_PAYLOAD_SIZE => {
set_unique_u64(&mut tp.max_udp_payload_size, id, value)?;
if tp.max_udp_payload_size.is_some_and(|v| v < 1200) {
return Err(QuicCoreError::InvalidTransportParameter(id));
}
}
TP_INITIAL_MAX_DATA => set_unique_u64(&mut tp.initial_max_data, id, value)?,
TP_INITIAL_MAX_STREAM_DATA_BIDI_LOCAL => {
set_unique_u64(&mut tp.initial_max_stream_data_bidi_local, id, value)?;
}
TP_INITIAL_MAX_STREAM_DATA_BIDI_REMOTE => {
set_unique_u64(&mut tp.initial_max_stream_data_bidi_remote, id, value)?;
}
TP_INITIAL_MAX_STREAM_DATA_UNI => {
set_unique_u64(&mut tp.initial_max_stream_data_uni, id, value)?;
}
TP_INITIAL_MAX_STREAMS_BIDI => {
set_unique_u64(&mut tp.initial_max_streams_bidi, id, value)?;
}
TP_INITIAL_MAX_STREAMS_UNI => {
set_unique_u64(&mut tp.initial_max_streams_uni, id, value)?;
}
TP_ACK_DELAY_EXPONENT => {
set_unique_u64(&mut tp.ack_delay_exponent, id, value)?;
if tp.ack_delay_exponent.is_some_and(|v| v > 20) {
return Err(QuicCoreError::InvalidTransportParameter(id));
}
}
TP_MAX_ACK_DELAY => set_unique_u64(&mut tp.max_ack_delay, id, value)?,
TP_DISABLE_ACTIVE_MIGRATION => {
if tp.disable_active_migration {
return Err(QuicCoreError::DuplicateTransportParameter(id));
}
if !value.is_empty() {
return Err(QuicCoreError::InvalidTransportParameter(id));
}
tp.disable_active_migration = true;
}
_ => tp.unknown.push(UnknownTransportParameter {
id,
value: value.to_vec(),
}),
}
}
Ok(tp)
}
}
pub fn encode_varint(value: u64, out: &mut Vec<u8>) -> Result<(), QuicCoreError> {
if value > QUIC_VARINT_MAX {
return Err(QuicCoreError::VarIntOutOfRange(value));
}
if value < (1 << 6) {
out.push(value as u8);
return Ok(());
}
if value < (1 << 14) {
let x = value as u16;
out.push(((x >> 8) as u8 & 0x3f) | 0x40);
out.push(x as u8);
return Ok(());
}
if value < (1 << 30) {
let x = value as u32;
out.push(((x >> 24) as u8 & 0x3f) | 0x80);
out.push((x >> 16) as u8);
out.push((x >> 8) as u8);
out.push(x as u8);
return Ok(());
}
let x = value;
out.push(((x >> 56) as u8 & 0x3f) | 0xc0);
out.push((x >> 48) as u8);
out.push((x >> 40) as u8);
out.push((x >> 32) as u8);
out.push((x >> 24) as u8);
out.push((x >> 16) as u8);
out.push((x >> 8) as u8);
out.push(x as u8);
Ok(())
}
pub fn decode_varint(input: &[u8]) -> Result<(u64, usize), QuicCoreError> {
if input.is_empty() {
return Err(QuicCoreError::UnexpectedEof);
}
let first = input[0];
let len = 1usize << (first >> 6);
if input.len() < len {
return Err(QuicCoreError::UnexpectedEof);
}
let mut value = u64::from(first & 0x3f);
for b in &input[1..len] {
value = (value << 8) | u64::from(*b);
}
Ok((value, len))
}
fn encode_long_header(header: &LongHeader, out: &mut Vec<u8>) -> Result<(), QuicCoreError> {
let pn_len = validate_pn_len(header.packet_number_len)?;
ensure_pn_fits(header.packet_number, pn_len)?;
if !matches!(header.packet_type, LongPacketType::Initial) && !header.token.is_empty() {
return Err(QuicCoreError::InvalidHeader(
"token only valid for Initial packets",
));
}
if header.payload_length < u64::from(pn_len) {
return Err(QuicCoreError::InvalidHeader(
"payload length smaller than packet number length",
));
}
let type_bits = match header.packet_type {
LongPacketType::Initial => 0u8,
LongPacketType::ZeroRtt => 1u8,
LongPacketType::Handshake => 2u8,
LongPacketType::Retry => {
return Err(QuicCoreError::InvalidHeader(
"retry packets must use PacketHeader::Retry",
));
}
};
let first = 0b1100_0000u8 | (type_bits << 4) | (pn_len - 1);
out.push(first);
out.extend_from_slice(&header.version.to_be_bytes());
out.push(header.dst_cid.len() as u8);
out.extend_from_slice(header.dst_cid.as_bytes());
out.push(header.src_cid.len() as u8);
out.extend_from_slice(header.src_cid.as_bytes());
if matches!(header.packet_type, LongPacketType::Initial) {
encode_varint(header.token.len() as u64, out)?;
out.extend_from_slice(&header.token);
}
encode_varint(header.payload_length, out)?;
write_packet_number(header.packet_number, pn_len, out);
Ok(())
}
fn encode_retry_header(header: &RetryHeader, out: &mut Vec<u8>) {
out.push(0b1111_0000u8);
out.extend_from_slice(&header.version.to_be_bytes());
out.push(header.dst_cid.len() as u8);
out.extend_from_slice(header.dst_cid.as_bytes());
out.push(header.src_cid.len() as u8);
out.extend_from_slice(header.src_cid.as_bytes());
out.extend_from_slice(&header.token);
out.extend_from_slice(&header.integrity_tag);
}
fn encode_short_header(header: &ShortHeader, out: &mut Vec<u8>) -> Result<(), QuicCoreError> {
let pn_len = validate_pn_len(header.packet_number_len)?;
ensure_pn_fits(header.packet_number, pn_len)?;
let mut first = 0b0100_0000u8 | (pn_len - 1);
if header.spin {
first |= 0b0010_0000;
}
if header.key_phase {
first |= 0b0000_0100;
}
out.push(first);
out.extend_from_slice(header.dst_cid.as_bytes());
write_packet_number(header.packet_number, pn_len, out);
Ok(())
}
fn decode_long_header(input: &[u8]) -> Result<(PacketHeader, usize), QuicCoreError> {
if input.len() < 6 {
return Err(QuicCoreError::UnexpectedEof);
}
let first = input[0];
if first & 0x40 == 0 {
return Err(QuicCoreError::InvalidHeader("long header fixed bit unset"));
}
let packet_type = match (first >> 4) & 0x03 {
0 => LongPacketType::Initial,
1 => LongPacketType::ZeroRtt,
2 => LongPacketType::Handshake,
3 => LongPacketType::Retry,
_ => unreachable!("2-bit pattern"),
};
if matches!(packet_type, LongPacketType::Retry) {
if first & 0x0f != 0 {
return Err(QuicCoreError::InvalidHeader(
"retry header reserved bits set",
));
}
} else if first & 0x0c != 0 {
return Err(QuicCoreError::InvalidHeader(
"long header reserved bits set",
));
}
let pn_len = (first & 0x03) + 1;
let mut pos = 1usize;
let version = u32::from_be_bytes([input[pos], input[pos + 1], input[pos + 2], input[pos + 3]]);
pos += 4;
let dcid_len = input[pos] as usize;
pos += 1;
let dst_cid = read_cid(input, &mut pos, dcid_len)?;
if pos >= input.len() {
return Err(QuicCoreError::UnexpectedEof);
}
let scid_len = input[pos] as usize;
pos += 1;
let src_cid = read_cid(input, &mut pos, scid_len)?;
if matches!(packet_type, LongPacketType::Retry) {
if input.len().saturating_sub(pos) < 16 {
return Err(QuicCoreError::UnexpectedEof);
}
let token_end = input.len() - 16;
let token = input[pos..token_end].to_vec();
let integrity_tag = input[token_end..]
.try_into()
.map_err(|_| QuicCoreError::UnexpectedEof)?;
return Ok((
PacketHeader::Retry(RetryHeader {
version,
dst_cid,
src_cid,
token,
integrity_tag,
}),
input.len(),
));
}
let token = if matches!(packet_type, LongPacketType::Initial) {
let (token_len, consumed) = decode_varint(&input[pos..])?;
pos += consumed;
let token_len = token_len as usize;
if input.len().saturating_sub(pos) < token_len {
return Err(QuicCoreError::UnexpectedEof);
}
let token = input[pos..pos + token_len].to_vec();
pos += token_len;
token
} else {
Vec::new()
};
let (payload_length, consumed) = decode_varint(&input[pos..])?;
pos += consumed;
if payload_length < u64::from(pn_len) {
return Err(QuicCoreError::InvalidHeader(
"payload length smaller than packet number length",
));
}
let packet_number = read_packet_number(input, &mut pos, pn_len)?;
Ok((
PacketHeader::Long(LongHeader {
packet_type,
version,
dst_cid,
src_cid,
token,
payload_length,
packet_number,
packet_number_len: pn_len,
}),
pos,
))
}
fn decode_short_header(
input: &[u8],
short_dcid_len: usize,
) -> Result<(ShortHeader, usize), QuicCoreError> {
if input.is_empty() {
return Err(QuicCoreError::UnexpectedEof);
}
if input[0] & 0x40 == 0 {
return Err(QuicCoreError::InvalidHeader("short header fixed bit unset"));
}
let first = input[0];
if first & 0x18 != 0 {
return Err(QuicCoreError::InvalidHeader(
"short header reserved bits set",
));
}
let pn_len = (first & 0x03) + 1;
let spin = first & 0b0010_0000 != 0;
let key_phase = first & 0b0000_0100 != 0;
let mut pos = 1usize;
let dst_cid = read_cid(input, &mut pos, short_dcid_len)?;
let packet_number = read_packet_number(input, &mut pos, pn_len)?;
Ok((
ShortHeader {
spin,
key_phase,
dst_cid,
packet_number,
packet_number_len: pn_len,
},
pos,
))
}
fn encode_parameter(out: &mut Vec<u8>, id: u64, value: &[u8]) -> Result<(), QuicCoreError> {
encode_varint(id, out)?;
encode_varint(value.len() as u64, out)?;
out.extend_from_slice(value);
Ok(())
}
fn encode_known_u64(out: &mut Vec<u8>, id: u64, value: Option<u64>) -> Result<(), QuicCoreError> {
if let Some(value) = value {
let mut body = Vec::with_capacity(8);
encode_varint(value, &mut body)?;
encode_parameter(out, id, &body)?;
}
Ok(())
}
fn set_unique_u64(slot: &mut Option<u64>, id: u64, value: &[u8]) -> Result<(), QuicCoreError> {
if slot.is_some() {
return Err(QuicCoreError::DuplicateTransportParameter(id));
}
let (decoded, consumed) = decode_varint(value)?;
if consumed != value.len() {
return Err(QuicCoreError::InvalidTransportParameter(id));
}
*slot = Some(decoded);
Ok(())
}
fn read_cid(input: &[u8], pos: &mut usize, cid_len: usize) -> Result<ConnectionId, QuicCoreError> {
if cid_len > ConnectionId::MAX_LEN {
return Err(QuicCoreError::InvalidConnectionIdLength(cid_len));
}
if input.len().saturating_sub(*pos) < cid_len {
return Err(QuicCoreError::UnexpectedEof);
}
let cid = ConnectionId::new(&input[*pos..*pos + cid_len])?;
*pos += cid_len;
Ok(cid)
}
fn write_packet_number(packet_number: u32, width: u8, out: &mut Vec<u8>) {
let bytes = packet_number.to_be_bytes();
let take = width as usize;
out.extend_from_slice(&bytes[4 - take..]);
}
fn read_packet_number(input: &[u8], pos: &mut usize, width: u8) -> Result<u32, QuicCoreError> {
let width = validate_pn_len(width)?;
let width = width as usize;
if input.len().saturating_sub(*pos) < width {
return Err(QuicCoreError::UnexpectedEof);
}
let mut out = [0u8; 4];
out[4 - width..].copy_from_slice(&input[*pos..*pos + width]);
*pos += width;
Ok(u32::from_be_bytes(out))
}
fn validate_pn_len(packet_number_len: u8) -> Result<u8, QuicCoreError> {
if (1..=4).contains(&packet_number_len) {
Ok(packet_number_len)
} else {
Err(QuicCoreError::InvalidHeader(
"packet number length must be 1..=4",
))
}
}
fn ensure_pn_fits(packet_number: u32, packet_number_len: u8) -> Result<(), QuicCoreError> {
let max = match packet_number_len {
1 => 0xff,
2 => 0xffff,
3 => 0x00ff_ffff,
4 => u32::MAX,
_ => return Err(QuicCoreError::InvalidHeader("invalid packet number length")),
};
if packet_number <= max {
Ok(())
} else {
Err(QuicCoreError::PacketNumberTooLarge {
packet_number,
width: packet_number_len,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
fn reference_encode_varint_rfc9000(value: u64) -> Result<Vec<u8>, QuicCoreError> {
if value > QUIC_VARINT_MAX {
return Err(QuicCoreError::VarIntOutOfRange(value));
}
let encoded = if value <= 63 {
vec![value as u8]
} else if value <= 16_383 {
((value as u16) | 0x4000).to_be_bytes().to_vec()
} else if value <= ((1 << 30) - 1) {
((value as u32) | 0x8000_0000).to_be_bytes().to_vec()
} else {
(value | 0xc000_0000_0000_0000).to_be_bytes().to_vec()
};
Ok(encoded)
}
fn reference_decode_varint_rfc9000(input: &[u8]) -> Result<(u64, usize), QuicCoreError> {
let Some(&first) = input.first() else {
return Err(QuicCoreError::UnexpectedEof);
};
let prefix = first >> 6;
let len = 1usize << usize::from(prefix);
if input.len() < len {
return Err(QuicCoreError::UnexpectedEof);
}
let value = match len {
1 => u64::from(first & 0x3f),
2 => u64::from(u16::from_be_bytes([first & 0x3f, input[1]])),
4 => u64::from(u32::from_be_bytes([
first & 0x3f,
input[1],
input[2],
input[3],
])),
8 => u64::from_be_bytes([
first & 0x3f,
input[1],
input[2],
input[3],
input[4],
input[5],
input[6],
input[7],
]),
_ => unreachable!("QUIC varints are only 1, 2, 4, or 8 bytes"),
};
Ok((value, len))
}
#[test]
fn varint_roundtrip_boundaries() {
let values = [
0u64,
63,
64,
16_383,
16_384,
(1 << 30) - 1,
1 << 30,
QUIC_VARINT_MAX,
];
for value in values {
let mut encoded = Vec::new();
encode_varint(value, &mut encoded).expect("encode");
let (decoded, consumed) = decode_varint(&encoded).expect("decode");
assert_eq!(decoded, value);
assert_eq!(consumed, encoded.len());
}
}
#[test]
fn varint_rejects_out_of_range() {
let mut out = Vec::new();
let err = encode_varint(QUIC_VARINT_MAX + 1, &mut out).expect_err("should fail");
assert_eq!(err, QuicCoreError::VarIntOutOfRange(QUIC_VARINT_MAX + 1));
}
#[test]
fn varint_detects_truncation() {
let encoded = [0b01_000000u8];
let err = decode_varint(&encoded).expect_err("should fail");
assert_eq!(err, QuicCoreError::UnexpectedEof);
}
#[test]
fn rfc9000_varint_examples_match_reference_codec() {
let cases = [
(37u64, vec![0x25]),
(15_293, vec![0x7b, 0xbd]),
(494_878_333, vec![0x9d, 0x7f, 0x3e, 0x7d]),
(
151_288_809_941_952_652,
vec![0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c],
),
];
for (value, expected_wire) in cases {
let reference_wire = reference_encode_varint_rfc9000(value).expect("reference encode");
assert_eq!(
reference_wire, expected_wire,
"reference encoder must match RFC 9000 example bytes for {value}"
);
let mut ours = Vec::new();
encode_varint(value, &mut ours).expect("encode");
assert_eq!(
ours, reference_wire,
"implementation diverged from RFC 9000 example encoding for {value}"
);
let ours_decoded = decode_varint(&expected_wire).expect("decode");
let reference_decoded =
reference_decode_varint_rfc9000(&expected_wire).expect("reference decode");
assert_eq!(
ours_decoded, reference_decoded,
"implementation diverged from reference decoder for RFC 9000 bytes {expected_wire:02x?}"
);
assert_eq!(ours_decoded, (value, expected_wire.len()));
}
}
#[test]
fn connection_id_bounds() {
assert!(ConnectionId::new(&[0u8; 20]).is_ok());
let err = ConnectionId::new(&[0u8; 21]).expect_err("should fail");
assert_eq!(err, QuicCoreError::InvalidConnectionIdLength(21));
}
#[test]
fn long_initial_header_roundtrip() {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Initial,
version: 1,
dst_cid: ConnectionId::new(&[1, 2, 3, 4]).expect("cid"),
src_cid: ConnectionId::new(&[9, 8, 7]).expect("cid"),
token: vec![0xaa, 0xbb],
payload_length: 1234,
packet_number: 0x1234,
packet_number_len: 2,
});
let mut buf = Vec::new();
header.encode(&mut buf).expect("encode");
let (decoded, consumed) = PacketHeader::decode(&buf, 0).expect("decode");
assert_eq!(decoded, header);
assert_eq!(consumed, buf.len());
}
#[test]
fn long_header_rejects_reserved_bits() {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Initial,
version: 1,
dst_cid: ConnectionId::new(&[1, 2, 3, 4]).expect("cid"),
src_cid: ConnectionId::new(&[9, 8, 7]).expect("cid"),
token: vec![],
payload_length: 2,
packet_number: 1,
packet_number_len: 2,
});
let mut buf = Vec::new();
header.encode(&mut buf).expect("encode");
buf[0] |= 0x0c;
let err = PacketHeader::decode(&buf, 0).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidHeader("long header reserved bits set")
);
}
#[test]
fn long_header_rejects_non_initial_token() {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Handshake,
version: 1,
dst_cid: ConnectionId::new(&[1, 2, 3, 4]).expect("cid"),
src_cid: ConnectionId::new(&[9, 8, 7]).expect("cid"),
token: vec![1],
payload_length: 2,
packet_number: 1,
packet_number_len: 2,
});
let mut buf = Vec::new();
let err = header.encode(&mut buf).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidHeader("token only valid for Initial packets")
);
}
#[test]
fn short_header_roundtrip() {
let header = PacketHeader::Short(ShortHeader {
spin: true,
key_phase: true,
dst_cid: ConnectionId::new(&[0xde, 0xad, 0xbe, 0xef]).expect("cid"),
packet_number: 0x00ab_cdef,
packet_number_len: 3,
});
let mut buf = Vec::new();
header.encode(&mut buf).expect("encode");
let (decoded, consumed) = PacketHeader::decode(&buf, 4).expect("decode");
assert_eq!(decoded, header);
assert_eq!(consumed, buf.len());
}
#[test]
fn short_header_rejects_reserved_bits() {
let header = PacketHeader::Short(ShortHeader {
spin: false,
key_phase: false,
dst_cid: ConnectionId::new(&[0xde, 0xad, 0xbe, 0xef]).expect("cid"),
packet_number: 1,
packet_number_len: 1,
});
let mut buf = Vec::new();
header.encode(&mut buf).expect("encode");
buf[0] |= 0x18;
let err = PacketHeader::decode(&buf, 4).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidHeader("short header reserved bits set")
);
}
#[test]
fn retry_header_roundtrip() {
let header = PacketHeader::Retry(RetryHeader {
version: 0x0000_0001,
dst_cid: ConnectionId::new(&[0xaa, 0xbb, 0xcc]).expect("cid"),
src_cid: ConnectionId::new(&[0x10, 0x20]).expect("cid"),
token: vec![0xde, 0xad, 0xbe, 0xef],
integrity_tag: [
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54,
0x32, 0x10,
],
});
let mut buf = Vec::new();
header.encode(&mut buf).expect("encode");
let (decoded, consumed) = PacketHeader::decode(&buf, 0).expect("decode");
assert_eq!(decoded, header);
assert_eq!(consumed, buf.len());
}
#[test]
fn retry_header_rejects_reserved_bits() {
let raw = [
0b1111_0001,
0,
0,
0,
1,
1,
0xaa,
1,
0xbb,
0x01,
0x23,
0x45,
0x67,
0x89,
0xab,
0xcd,
0xef,
0xfe,
0xdc,
0xba,
0x98,
0x76,
0x54,
0x32,
0x10,
];
let err = PacketHeader::decode(&raw, 0).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidHeader("retry header reserved bits set")
);
}
#[test]
fn transport_params_roundtrip_with_unknown() {
let params = TransportParameters {
max_idle_timeout: Some(10_000),
initial_max_data: Some(1_000_000),
disable_active_migration: true,
unknown: vec![UnknownTransportParameter {
id: 0xface,
value: vec![1, 2, 3, 4],
}],
..TransportParameters::default()
};
let mut encoded = Vec::new();
params.encode(&mut encoded).expect("encode");
let decoded = TransportParameters::decode(&encoded).expect("decode");
assert_eq!(decoded, params);
}
#[test]
fn transport_params_reject_duplicate_known() {
let mut encoded = Vec::new();
encode_parameter(&mut encoded, TP_MAX_ACK_DELAY, &[0x19]).expect("encode");
encode_parameter(&mut encoded, TP_MAX_ACK_DELAY, &[0x1a]).expect("encode");
let err = TransportParameters::decode(&encoded).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::DuplicateTransportParameter(TP_MAX_ACK_DELAY)
);
}
#[test]
fn transport_params_reject_nonempty_disable_active_migration() {
let mut encoded = Vec::new();
encode_parameter(&mut encoded, TP_DISABLE_ACTIVE_MIGRATION, &[0x01]).expect("encode");
let err = TransportParameters::decode(&encoded).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidTransportParameter(TP_DISABLE_ACTIVE_MIGRATION)
);
}
#[test]
fn transport_params_reject_duplicate_unknown() {
let mut encoded = Vec::new();
encode_parameter(&mut encoded, 0x1337, &[0x01]).expect("encode");
encode_parameter(&mut encoded, 0x1337, &[0x02]).expect("encode");
let err = TransportParameters::decode(&encoded).expect_err("should fail");
assert_eq!(err, QuicCoreError::DuplicateTransportParameter(0x1337));
}
#[test]
fn transport_params_reject_small_udp_payload() {
let mut encoded = Vec::new();
let mut body = Vec::new();
encode_varint(1199, &mut body).expect("varint");
encode_parameter(&mut encoded, TP_MAX_UDP_PAYLOAD_SIZE, &body).expect("encode");
let err = TransportParameters::decode(&encoded).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidTransportParameter(TP_MAX_UDP_PAYLOAD_SIZE)
);
}
#[test]
fn transport_params_reject_large_ack_delay_exponent() {
let mut encoded = Vec::new();
let mut body = Vec::new();
encode_varint(21, &mut body).expect("varint");
encode_parameter(&mut encoded, TP_ACK_DELAY_EXPONENT, &body).expect("encode");
let err = TransportParameters::decode(&encoded).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::InvalidTransportParameter(TP_ACK_DELAY_EXPONENT)
);
}
#[test]
fn varint_decode_empty_input() {
let err = decode_varint(&[]).expect_err("empty should fail");
assert_eq!(err, QuicCoreError::UnexpectedEof);
}
#[test]
fn varint_decode_truncated_4byte() {
let err = decode_varint(&[0x80, 0x01]).expect_err("truncated 4-byte should fail");
assert_eq!(err, QuicCoreError::UnexpectedEof);
}
#[test]
fn varint_decode_truncated_8byte() {
let err = decode_varint(&[0xc0, 0x00, 0x00]).expect_err("truncated 8-byte should fail");
assert_eq!(err, QuicCoreError::UnexpectedEof);
}
#[test]
fn varint_encoding_sizes() {
let mut buf = Vec::new();
encode_varint(0, &mut buf).unwrap();
assert_eq!(buf.len(), 1);
buf.clear();
encode_varint(63, &mut buf).unwrap();
assert_eq!(buf.len(), 1);
buf.clear();
encode_varint(64, &mut buf).unwrap();
assert_eq!(buf.len(), 2);
buf.clear();
encode_varint(16383, &mut buf).unwrap();
assert_eq!(buf.len(), 2);
buf.clear();
encode_varint(16384, &mut buf).unwrap();
assert_eq!(buf.len(), 4);
buf.clear();
encode_varint((1 << 30) - 1, &mut buf).unwrap();
assert_eq!(buf.len(), 4);
buf.clear();
encode_varint(1 << 30, &mut buf).unwrap();
assert_eq!(buf.len(), 8);
buf.clear();
encode_varint(QUIC_VARINT_MAX, &mut buf).unwrap();
assert_eq!(buf.len(), 8);
}
#[test]
fn transport_params_empty_roundtrip() {
let params = TransportParameters::default();
let mut encoded = Vec::new();
params.encode(&mut encoded).unwrap();
assert!(encoded.is_empty());
let decoded = TransportParameters::decode(&encoded).unwrap();
assert_eq!(decoded, params);
}
#[test]
fn transport_params_single_param_roundtrip() {
let params = TransportParameters {
max_idle_timeout: Some(30_000),
..TransportParameters::default()
};
let mut encoded = Vec::new();
params.encode(&mut encoded).unwrap();
let decoded = TransportParameters::decode(&encoded).unwrap();
assert_eq!(decoded, params);
}
#[test]
fn transport_params_all_known_fields_roundtrip() {
let params = TransportParameters {
max_idle_timeout: Some(30_000),
max_udp_payload_size: Some(1400),
initial_max_data: Some(1_000_000),
initial_max_stream_data_bidi_local: Some(256_000),
initial_max_stream_data_bidi_remote: Some(256_000),
initial_max_stream_data_uni: Some(128_000),
initial_max_streams_bidi: Some(100),
initial_max_streams_uni: Some(50),
ack_delay_exponent: Some(3),
max_ack_delay: Some(25),
disable_active_migration: true,
unknown: vec![],
};
let mut encoded = Vec::new();
params.encode(&mut encoded).unwrap();
let decoded = TransportParameters::decode(&encoded).unwrap();
assert_eq!(decoded, params);
}
#[test]
fn transport_params_unknown_preserved() {
let params = TransportParameters {
unknown: vec![
UnknownTransportParameter {
id: 0xff00,
value: vec![0x01, 0x02, 0x03],
},
UnknownTransportParameter {
id: 0xff01,
value: vec![],
},
],
..TransportParameters::default()
};
let mut encoded = Vec::new();
params.encode(&mut encoded).unwrap();
let decoded = TransportParameters::decode(&encoded).unwrap();
assert_eq!(decoded.unknown.len(), 2);
assert_eq!(decoded.unknown[0].id, 0xff00);
assert_eq!(decoded.unknown[0].value, vec![0x01, 0x02, 0x03]);
assert_eq!(decoded.unknown[1].id, 0xff01);
assert!(decoded.unknown[1].value.is_empty());
}
#[test]
fn quic_core_error_display_all_variants() {
let cases: Vec<(QuicCoreError, &str)> = vec![
(QuicCoreError::UnexpectedEof, "unexpected EOF"),
(
QuicCoreError::VarIntOutOfRange(99),
"varint out of range: 99",
),
(
QuicCoreError::InvalidHeader("test msg"),
"invalid header: test msg",
),
(
QuicCoreError::InvalidConnectionIdLength(25),
"invalid connection id length: 25",
),
(
QuicCoreError::PacketNumberTooLarge {
packet_number: 1000,
width: 1,
},
"packet number 1000 does not fit in 1 bytes",
),
(
QuicCoreError::DuplicateTransportParameter(0x01),
"duplicate transport parameter: 0x1",
),
(
QuicCoreError::InvalidTransportParameter(0x03),
"invalid transport parameter: 0x3",
),
];
for (err, expected) in cases {
assert_eq!(format!("{err}"), expected);
}
}
#[test]
fn quic_core_error_is_std_error() {
let err: Box<dyn std::error::Error> = Box::new(QuicCoreError::UnexpectedEof);
assert!(err.source().is_none());
}
#[test]
fn connection_id_empty_and_max() {
let empty = ConnectionId::new(&[]).unwrap();
assert!(empty.is_empty());
assert_eq!(empty.len(), 0);
assert_eq!(empty.as_bytes(), &[] as &[u8]);
let max = ConnectionId::new(&[0xab; 20]).unwrap();
assert!(!max.is_empty());
assert_eq!(max.len(), 20);
assert_eq!(max.as_bytes().len(), 20);
let debug = format!("{empty:?}");
assert!(debug.contains("ConnectionId("));
}
#[test]
fn packet_header_decode_empty_input() {
let err = PacketHeader::decode(&[], 0).expect_err("empty should fail");
assert_eq!(err, QuicCoreError::UnexpectedEof);
}
#[test]
fn long_header_handshake_roundtrip() {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Handshake,
version: 0x0000_0001,
dst_cid: ConnectionId::new(&[0x01, 0x02]).unwrap(),
src_cid: ConnectionId::new(&[0x03]).unwrap(),
token: vec![],
payload_length: 100,
packet_number: 42,
packet_number_len: 1,
});
let mut buf = Vec::new();
header.encode(&mut buf).unwrap();
let (decoded, consumed) = PacketHeader::decode(&buf, 0).unwrap();
assert_eq!(decoded, header);
assert_eq!(consumed, buf.len());
}
#[test]
fn long_header_zerortt_roundtrip() {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::ZeroRtt,
version: 0xff00_001d,
dst_cid: ConnectionId::new(&[0xaa, 0xbb, 0xcc]).unwrap(),
src_cid: ConnectionId::new(&[]).unwrap(),
token: vec![],
payload_length: 50,
packet_number: 7,
packet_number_len: 1,
});
let mut buf = Vec::new();
header.encode(&mut buf).unwrap();
let (decoded, consumed) = PacketHeader::decode(&buf, 0).unwrap();
assert_eq!(decoded, header);
assert_eq!(consumed, buf.len());
}
#[test]
fn packet_number_too_large_for_width() {
let header = PacketHeader::Short(ShortHeader {
spin: false,
key_phase: false,
dst_cid: ConnectionId::new(&[0x01]).unwrap(),
packet_number: 256, packet_number_len: 1,
});
let mut buf = Vec::new();
let err = header.encode(&mut buf).expect_err("should fail");
assert_eq!(
err,
QuicCoreError::PacketNumberTooLarge {
packet_number: 256,
width: 1,
}
);
}
#[test]
fn packet_number_length_invalid() {
let header = PacketHeader::Short(ShortHeader {
spin: false,
key_phase: false,
dst_cid: ConnectionId::new(&[0x01]).unwrap(),
packet_number: 1,
packet_number_len: 0, });
let mut buf = Vec::new();
let err = header.encode(&mut buf).expect_err("should fail");
assert!(matches!(err, QuicCoreError::InvalidHeader(_)));
}
#[test]
fn long_header_payload_length_too_small() {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Initial,
version: 1,
dst_cid: ConnectionId::new(&[]).unwrap(),
src_cid: ConnectionId::new(&[]).unwrap(),
token: vec![],
payload_length: 0, packet_number: 1,
packet_number_len: 1,
});
let mut buf = Vec::new();
let err = header.encode(&mut buf).expect_err("should fail");
assert!(matches!(err, QuicCoreError::InvalidHeader(_)));
}
#[test]
fn transport_params_truncated_value() {
let mut encoded = Vec::new();
encode_varint(TP_MAX_IDLE_TIMEOUT, &mut encoded).unwrap();
encode_varint(10, &mut encoded).unwrap(); encoded.extend_from_slice(&[0x01, 0x02, 0x03]); let err = TransportParameters::decode(&encoded).expect_err("should fail");
assert_eq!(err, QuicCoreError::UnexpectedEof);
}
#[test]
fn long_packet_type_debug_clone_eq() {
let types = [
LongPacketType::Initial,
LongPacketType::ZeroRtt,
LongPacketType::Handshake,
LongPacketType::Retry,
];
for t in &types {
let clone = *t;
assert_eq!(clone, *t);
assert!(!format!("{t:?}").is_empty());
}
}
#[test]
fn unknown_transport_parameter_debug_clone_eq() {
let p = UnknownTransportParameter {
id: 42,
value: vec![1, 2, 3],
};
let p2 = p.clone();
assert_eq!(p, p2);
assert!(format!("{p:?}").contains("UnknownTransportParameter"));
}
#[test]
fn quic_core_error_debug_clone_eq_display() {
let e1 = QuicCoreError::UnexpectedEof;
let e2 = QuicCoreError::VarIntOutOfRange(999);
assert!(format!("{e1:?}").contains("UnexpectedEof"));
assert!(format!("{e1}").contains("unexpected EOF"));
assert!(format!("{e2}").contains("varint out of range"));
assert_eq!(e1.clone(), e1);
assert_ne!(e1, e2);
let err: &dyn std::error::Error = &e1;
assert!(err.source().is_none());
}
#[test]
fn connection_id_debug_clone_copy_eq_hash_default() {
use std::collections::HashSet;
let def = ConnectionId::default();
assert!(def.is_empty());
assert_eq!(def.len(), 0);
let dbg = format!("{def:?}");
assert!(dbg.contains("ConnectionId"), "{dbg}");
let cid = ConnectionId::new(&[0xab, 0xcd]).unwrap();
let copied = cid;
let cloned = cid;
assert_eq!(copied, cloned);
assert_ne!(cid, def);
let mut set = HashSet::new();
set.insert(cid);
set.insert(def);
set.insert(cid);
assert_eq!(set.len(), 2);
}
#[test]
fn transport_parameters_debug_clone_default_eq() {
let def = TransportParameters::default();
let dbg = format!("{def:?}");
assert!(dbg.contains("TransportParameters"), "{dbg}");
assert_eq!(def.max_idle_timeout, None);
assert!(!def.disable_active_migration);
let tp = TransportParameters {
max_idle_timeout: Some(5000),
..TransportParameters::default()
};
let cloned = tp.clone();
assert_eq!(cloned, tp);
assert_ne!(cloned, def);
}
#[test]
fn rfc9000_packet_number_encoding_length() {
let test_cases = [
(0, 1, 1), (1, 1, 1),
(255, 1, 1), (256, 2, 2), (65535, 2, 2), (65536, 3, 3), (16777215, 3, 3), (16777216, 4, 4), (0xFFFFFFFF, 4, 4), ];
for (packet_number, min_width, _max_width) in test_cases {
assert!(
ensure_pn_fits(packet_number, min_width).is_ok(),
"Packet number {packet_number} should fit in {min_width} bytes"
);
let mut buf = Vec::new();
write_packet_number(packet_number, min_width, &mut buf);
assert_eq!(
buf.len(),
min_width as usize,
"Packet number {packet_number} should encode to {min_width} bytes"
);
let mut pos = 0;
let decoded = read_packet_number(&buf, &mut pos, min_width).unwrap();
assert_eq!(
decoded, packet_number,
"Packet number {packet_number} failed round-trip"
);
assert_eq!(pos, buf.len(), "Should consume all encoded bytes");
if min_width > 1 {
assert!(
ensure_pn_fits(packet_number, min_width - 1).is_err(),
"Packet number {packet_number} should NOT fit in {} bytes",
min_width - 1
);
}
}
}
#[test]
fn rfc9000_packet_number_truncation_algorithm() {
let test_cases = [
(0, 1, 1), (0, 255, 1), (0, 256, 2), (100, 101, 1), (100, 356, 2), (1000, 1001, 1), (1000, 2024, 2), (50000, 50001, 1), (50000, 51024, 2), ];
for (largest_acked, full_pn, expected_min_width) in test_cases {
let num_unacked = (full_pn - largest_acked) + 1;
let space_needed = 2 * num_unacked + 1;
let calculated_width = if space_needed <= 256 {
1
} else if space_needed <= 65536 {
2
} else if space_needed <= 16777216 {
3
} else {
4
};
assert_eq!(
calculated_width, expected_min_width,
"Truncation algorithm: largest_acked={largest_acked}, full_pn={full_pn}, \
unacked_range={num_unacked}, space_needed={space_needed}"
);
assert!(
ensure_pn_fits(full_pn, calculated_width).is_ok(),
"Calculated width {calculated_width} should accommodate packet number {full_pn}"
);
}
}
#[test]
fn rfc9000_packet_number_edge_cases() {
let mut buf = Vec::new();
write_packet_number(0, 1, &mut buf);
assert_eq!(buf, vec![0x00]);
buf.clear();
write_packet_number(255, 1, &mut buf);
assert_eq!(buf, vec![0xFF]);
buf.clear();
write_packet_number(256, 2, &mut buf);
assert_eq!(buf, vec![0x01, 0x00]);
buf.clear();
write_packet_number(65535, 2, &mut buf);
assert_eq!(buf, vec![0xFF, 0xFF]);
buf.clear();
write_packet_number(65536, 3, &mut buf);
assert_eq!(buf, vec![0x01, 0x00, 0x00]);
buf.clear();
write_packet_number(16777215, 3, &mut buf);
assert_eq!(buf, vec![0xFF, 0xFF, 0xFF]);
buf.clear();
write_packet_number(16777216, 4, &mut buf);
assert_eq!(buf, vec![0x01, 0x00, 0x00, 0x00]);
buf.clear();
write_packet_number(0xFFFFFFFF, 4, &mut buf);
assert_eq!(buf, vec![0xFF, 0xFF, 0xFF, 0xFF]);
}
#[test]
fn rfc9000_packet_number_width_validation() {
for valid_width in [1, 2, 3, 4] {
assert!(
validate_pn_len(valid_width).is_ok(),
"Width {valid_width} should be valid"
);
}
for invalid_width in [0, 5, 6, 255] {
assert!(
validate_pn_len(invalid_width).is_err(),
"Width {invalid_width} should be invalid"
);
}
let err = validate_pn_len(0).unwrap_err();
assert!(matches!(err, QuicCoreError::InvalidHeader(_)));
let err = validate_pn_len(5).unwrap_err();
assert!(matches!(err, QuicCoreError::InvalidHeader(_)));
}
#[test]
fn rfc9000_packet_number_overflow() {
let overflow_cases = [
(256, 1), (65536, 1), (65536, 2), (16777216, 1), (16777216, 2), (16777216, 3), ];
for (packet_number, width) in overflow_cases {
let err = ensure_pn_fits(packet_number, width).unwrap_err();
assert!(
matches!(err, QuicCoreError::PacketNumberTooLarge { .. }),
"Should get PacketNumberTooLarge for pn={packet_number}, width={width}"
);
if let QuicCoreError::PacketNumberTooLarge {
packet_number: pn,
width: w,
} = err
{
assert_eq!(pn, packet_number);
assert_eq!(w, width);
}
}
}
#[test]
fn rfc9000_packet_number_truncated_decode() {
let truncated_cases = [
(2, vec![0x12]), (3, vec![0x12, 0x34]), (4, vec![0x12, 0x34, 0x56]), ];
for (declared_width, truncated_data) in truncated_cases {
let mut pos = 0;
let err = read_packet_number(&truncated_data, &mut pos, declared_width).unwrap_err();
assert_eq!(
err,
QuicCoreError::UnexpectedEof,
"Should get UnexpectedEof for width={declared_width}, data={truncated_data:?}"
);
}
}
#[test]
fn rfc9000_packet_number_in_headers() {
let long_header_cases = [
(1, 1),
(255, 1), (256, 2),
(65535, 2), (65536, 3),
(16777215, 3), (16777216, 4),
(0x12345678, 4), ];
for (packet_number, width) in long_header_cases {
let header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Initial,
version: 1,
dst_cid: ConnectionId::new(&[1, 2, 3, 4]).unwrap(),
src_cid: ConnectionId::new(&[5, 6, 7]).unwrap(),
token: vec![],
payload_length: 100,
packet_number,
packet_number_len: width,
});
let mut buf = Vec::new();
header.encode(&mut buf).unwrap();
let (decoded, consumed) = PacketHeader::decode(&buf, 0).unwrap();
if let PacketHeader::Long(decoded_header) = decoded {
assert_eq!(
decoded_header.packet_number, packet_number,
"Long header packet number round-trip failed"
);
assert_eq!(
decoded_header.packet_number_len, width,
"Long header packet number width mismatch"
);
} else {
panic!("Expected long header");
}
assert_eq!(consumed, buf.len());
}
let short_header_cases = [
(42, 1),
(200, 1), (300, 2),
(50000, 2), (70000, 3),
(1000000, 3), (20000000, 4),
(0x87654321, 4), ];
for (packet_number, width) in short_header_cases {
let header = PacketHeader::Short(ShortHeader {
spin: false,
key_phase: false,
dst_cid: ConnectionId::new(&[0xAA, 0xBB]).unwrap(),
packet_number,
packet_number_len: width,
});
let mut buf = Vec::new();
header.encode(&mut buf).unwrap();
let (decoded, consumed) = PacketHeader::decode(&buf, 2).unwrap();
if let PacketHeader::Short(decoded_header) = decoded {
assert_eq!(
decoded_header.packet_number, packet_number,
"Short header packet number round-trip failed"
);
assert_eq!(
decoded_header.packet_number_len, width,
"Short header packet number width mismatch"
);
} else {
panic!("Expected short header");
}
assert_eq!(consumed, buf.len());
}
}
#[test]
fn rfc9000_packet_number_wire_format() {
let wire_format_cases = [
(0x1234, 2, vec![0x12, 0x34]),
(0x123456, 3, vec![0x12, 0x34, 0x56]),
(0x12345678, 4, vec![0x12, 0x34, 0x56, 0x78]),
];
for (packet_number, width, expected_bytes) in wire_format_cases {
let mut buf = Vec::new();
write_packet_number(packet_number, width, &mut buf);
assert_eq!(
buf, expected_bytes,
"Packet number {packet_number:#x} width {width} wire format mismatch"
);
let mut pos = 0;
let decoded = read_packet_number(&buf, &mut pos, width).unwrap();
assert_eq!(decoded, packet_number);
}
}
#[test]
fn rfc9000_packet_number_space_isolation() {
let packet_number = 1234;
let width = 2;
let initial_header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Initial,
version: 1,
dst_cid: ConnectionId::new(&[1, 2, 3]).unwrap(),
src_cid: ConnectionId::new(&[4, 5, 6]).unwrap(),
token: vec![0xAA, 0xBB],
payload_length: 100,
packet_number,
packet_number_len: width,
});
let handshake_header = PacketHeader::Long(LongHeader {
packet_type: LongPacketType::Handshake,
version: 1,
dst_cid: ConnectionId::new(&[1, 2, 3]).unwrap(),
src_cid: ConnectionId::new(&[4, 5, 6]).unwrap(),
token: vec![], payload_length: 100,
packet_number,
packet_number_len: width,
});
let app_data_header = PacketHeader::Short(ShortHeader {
spin: true,
key_phase: false,
dst_cid: ConnectionId::new(&[1, 2, 3]).unwrap(),
packet_number,
packet_number_len: width,
});
for header in [initial_header, handshake_header, app_data_header] {
let mut buf = Vec::new();
header.encode(&mut buf).unwrap();
let dcid_len = match &header {
PacketHeader::Short(_) => 3, _ => 0, };
let (decoded, _) = PacketHeader::decode(&buf, dcid_len).unwrap();
let decoded_pn = match decoded {
PacketHeader::Long(h) => h.packet_number,
PacketHeader::Short(h) => h.packet_number,
PacketHeader::Retry(_) => panic!("Unexpected retry header"),
};
assert_eq!(
decoded_pn, packet_number,
"Packet number should be preserved across packet type boundaries"
);
}
}
}