use crate::{
packet::{encoding::PacketPayloadLenCursor, number::TruncatedPacketNumber},
varint::VarInt,
};
use s2n_codec::{
decoder_invariant, CheckedRange, DecoderError, Encoder, EncoderBuffer, EncoderValue,
};
/// Bits of the first long-header byte that carry the packet type
/// (bits 4 and 5, i.e. `0x30`).
pub(crate) const PACKET_TYPE_MASK: u8 = 0b0011_0000;
/// How far a raw 2-bit packet type is shifted left to land under `PACKET_TYPE_MASK`.
/// Derived from the mask so the two can never drift apart.
const PACKET_TYPE_OFFSET: u8 = PACKET_TYPE_MASK.trailing_zeros() as u8;
/// Long-header version field, held as a 32-bit unsigned integer.
pub(crate) type Version = u32;
/// Integer type used to encode the destination connection ID length (a single byte).
pub(crate) type DestinationConnectionIdLen = u8;
/// Largest destination connection ID length, in bytes, accepted by
/// `validate_destination_connection_id_len`.
pub(crate) const DESTINATION_CONNECTION_ID_MAX_LEN: usize = 20;
/// Checks that a decoded destination connection ID range is no longer than
/// `DESTINATION_CONNECTION_ID_MAX_LEN` bytes.
///
/// # Errors
/// Propagates the error from `validate_destination_connection_id_len`.
pub(crate) fn validate_destination_connection_id_range(
    range: &CheckedRange,
) -> Result<(), DecoderError> {
    let id_len = range.len();
    validate_destination_connection_id_len(id_len)
}
pub(crate) fn validate_destination_connection_id_len(len: usize) -> Result<(), DecoderError> {
decoder_invariant!(
len <= DESTINATION_CONNECTION_ID_MAX_LEN,
"destination connection exceeds max length"
);
Ok(())
}
/// Integer type used to encode the source connection ID length (a single byte).
pub(crate) type SourceConnectionIdLen = u8;
/// Largest source connection ID length, in bytes, accepted by
/// `validate_source_connection_id_len`.
pub(crate) const SOURCE_CONNECTION_ID_MAX_LEN: usize = 20;
/// Checks that a decoded source connection ID range is no longer than
/// `SOURCE_CONNECTION_ID_MAX_LEN` bytes.
///
/// # Errors
/// Propagates the error from `validate_source_connection_id_len`.
pub(crate) fn validate_source_connection_id_range(
    range: &CheckedRange,
) -> Result<(), DecoderError> {
    let id_len = range.len();
    validate_source_connection_id_len(id_len)
}
pub(crate) fn validate_source_connection_id_len(len: usize) -> Result<(), DecoderError> {
decoder_invariant!(
len <= SOURCE_CONNECTION_ID_MAX_LEN,
"source connection exceeds max length"
);
Ok(())
}
/// Long-header packet type.
///
/// The discriminants are the raw, unshifted 2-bit type values. In the first header
/// byte they live under `PACKET_TYPE_MASK` (bits 4-5). Use [`PacketType::into_bits`]
/// and [`PacketType::from_bits`] to move between the two representations.
#[repr(u8)]
#[derive(Clone, Copy, Debug)]
pub enum PacketType {
    Initial = 0x0,
    ZeroRtt = 0x1,
    Handshake = 0x2,
    Retry = 0x3,
}

impl PacketType {
    /// Returns the packet type shifted into its position in the first header byte.
    ///
    /// All bits outside `PACKET_TYPE_MASK` are zero.
    pub const fn into_bits(self) -> u8 {
        ((self as u8) << PACKET_TYPE_OFFSET) & PACKET_TYPE_MASK
    }

    /// Decodes the packet type from a first header byte.
    ///
    /// This is the inverse of [`PacketType::into_bits`]. Only the bits under
    /// `PACKET_TYPE_MASK` are inspected; every other bit of `bits` is ignored.
    pub fn from_bits(bits: u8) -> Self {
        // The type field sits in bits 4-5, so isolate it with the full mask *before*
        // shifting it down. Masking the unshifted byte with `0x03` would read the low
        // bits instead, which belong to other header fields.
        ((bits & PACKET_TYPE_MASK) >> PACKET_TYPE_OFFSET).into()
    }
}

impl From<u8> for PacketType {
    /// Converts a raw, *unshifted* 2-bit type value (`0..=3`) into a `PacketType`.
    ///
    /// This is not the inverse of `From<PacketType> for u8`, which yields the shifted
    /// header bits; decode header bytes with [`PacketType::from_bits`] instead.
    /// Values above `0x3` fall back to [`PacketType::Initial`].
    fn from(bits: u8) -> Self {
        match bits {
            0x0 => PacketType::Initial,
            0x1 => PacketType::ZeroRtt,
            0x2 => PacketType::Handshake,
            0x3 => PacketType::Retry,
            _ => PacketType::Initial,
        }
    }
}

impl From<PacketType> for u8 {
    /// Returns the packet type already shifted into header position
    /// (same as [`PacketType::into_bits`]).
    fn from(v: PacketType) -> Self {
        v.into_bits()
    }
}
/// Bits 2-3 of the first long-header byte (`0x0c`), reserved by the long header format.
pub const RESERVED_BITS_MASK: u8 = 0b0000_1100;
/// Pairs a truncated packet number with a payload so the two are encoded back to
/// back: packet number first, then payload.
pub(crate) struct LongPayloadEncoder<Payload> {
    /// Truncated packet number; encoded before the payload.
    pub packet_number: TruncatedPacketNumber,
    /// Payload (a shared or mutable reference in the `EncoderValue` impls); encoded
    /// immediately after the packet number.
    pub payload: Payload,
}
impl<Payload: EncoderValue> EncoderValue for LongPayloadEncoder<&Payload> {
    /// Writes the truncated packet number, then the borrowed payload.
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        let packet_number = &self.packet_number;
        // Copy the shared reference out so the call dispatches on `Payload` directly.
        let payload: &Payload = self.payload;

        packet_number.encode(encoder);
        payload.encode(encoder);
    }
}
/// Encodes a packet number followed by a mutably borrowed payload.
///
/// Unlike the shared-reference impl, this one also overrides `encode_mut`, so both
/// the packet number and the payload are driven through their mutable encoding path.
impl<Payload: EncoderValue> EncoderValue for LongPayloadEncoder<&mut Payload> {
    // Immutable path: packet number, then payload.
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        self.packet_number.encode(encoder);
        self.payload.encode(encoder);
    }

    // Mutable path: forwards `encode_mut` to each part, in the same order, so values
    // that update their own state while encoding get the chance to do so.
    fn encode_mut<E: Encoder>(&mut self, encoder: &mut E) {
        self.packet_number.encode_mut(encoder);
        self.payload.encode_mut(encoder);
    }
}
/// Placeholder for a long header's payload length field that can be patched later.
///
/// `encode_mut` writes a placeholder value and records where it was written;
/// `update` later seeks back to that offset and rewrites it with the real length.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LongPayloadLenCursor {
    /// Encoder offset where the placeholder starts. It stays 0 until `encode_mut`
    /// records a position.
    position: usize,
    /// Placeholder value that was (or will be) encoded. `update` passes the real
    /// length to `max_value.encode_updated`, which presumably reuses this value's
    /// encoded width — confirm in `VarInt`.
    max_value: VarInt,
}
impl PacketPayloadLenCursor for LongPayloadLenCursor {
    /// Creates a cursor with no recorded position and a `VarInt::MAX` placeholder.
    fn new() -> Self {
        Self {
            position: 0,
            max_value: VarInt::MAX,
        }
    }

    /// Rewrites the previously reserved length field with `actual_len`, leaving the
    /// buffer's write position where it was.
    ///
    /// # Panics
    /// Panics if `actual_len` does not fit in a `VarInt`. In debug builds, also
    /// panics if no position was recorded (i.e. `encode` was used instead of
    /// `encode_mut`).
    fn update(&self, buffer: &mut EncoderBuffer, actual_len: usize) {
        debug_assert!(
            self.position != 0,
            "position cursor was not updated. encode_mut should be called instead of encode"
        );

        let len_value =
            VarInt::try_from(actual_len).expect("packets should not be larger than VarInt::MAX");

        // Seek back to the reserved slot, overwrite it, then return to where we were.
        let resume_at = buffer.len();
        buffer.set_position(self.position);
        self.max_value.encode_updated(len_value, buffer);
        buffer.set_position(resume_at);
    }
}
impl EncoderValue for LongPayloadLenCursor {
    /// Writes the current placeholder value without recording its position.
    fn encode<E: Encoder>(&self, encoder: &mut E) {
        self.max_value.encode(encoder)
    }

    /// Records where the length field starts, then writes a placeholder sized from
    /// the encoder's remaining capacity so `update` can patch it in place later.
    fn encode_mut<E: Encoder>(&mut self, encoder: &mut E) {
        self.position = encoder.len();

        // The final payload length is presumably bounded by the space still left, so
        // encoding that bound reserves a wide enough slot. Capacities that do not fit
        // in a VarInt fall back to the largest representable value.
        let remaining = encoder.remaining_capacity();
        self.max_value = match VarInt::try_from(remaining) {
            Ok(limit) => limit,
            Err(_) => VarInt::MAX,
        };

        self.max_value.encode(encoder)
    }
}