use std::ops::Deref;
use bytes::{buf::UninitSlice, BufMut, BytesMut};
use deref_derive::{Deref, DerefMut};
use encrypt::{encode_long_first_byte, encode_short_first_byte, encrypt_packet, protect_header};
use enum_dispatch::enum_dispatch;
use getset::CopyGetters;
use header::io::WriteHeader;
use crate::{
cid::ConnectionId,
frame::{
io::{WriteDataFrame, WriteFrame},
BeFrame, ContainSpec, Spec,
},
util::{DescribeData, WriteData},
varint::{EncodeBytes, VarInt, WriteVarInt},
};
pub mod error;
pub mod signal;
#[doc(hidden)]
pub use signal::{KeyPhaseBit, SpinBit};
pub mod r#type;
#[doc(hidden)]
pub use r#type::{
GetPacketNumberLength, LongSpecificBits, ShortSpecificBits, Type, LONG_RESERVED_MASK,
SHORT_RESERVED_MASK,
};
pub mod header;
#[doc(hidden)]
pub use header::{
long, EncodeHeader, GetDcid, GetType, HandshakeHeader, Header, InitialHeader,
LongHeaderBuilder, OneRttHeader, RetryHeader, VersionNegotiationHeader, ZeroRttHeader,
};
pub mod io;
pub mod number;
#[doc(hidden)]
pub use number::{take_pn_len, PacketNumber, WritePacketNumber};
pub mod decrypt;
pub mod encrypt;
pub mod keys;
/// Header of a packet that is neither Version Negotiation nor Retry (the
/// header half of [`Packet::Data`]).
///
/// `GetDcid` and `GetType` are forwarded to the wrapped header by
/// `enum_dispatch`.
#[derive(Debug, Clone)]
#[enum_dispatch(GetDcid, GetType)]
pub enum DataHeader {
/// A long-header packet, as defined by `header::long::DataHeader`.
Long(long::DataHeader),
/// A short-header (1-RTT) packet.
Short(OneRttHeader),
}
/// A packet that is neither Version Negotiation nor Retry (the payload of
/// [`Packet::Data`]).
///
/// Dereferences to its [`DataHeader`] through the derived `Deref`/`DerefMut`,
/// so header accessors can be called on the packet directly.
#[derive(Debug, Clone, Deref, DerefMut)]
pub struct DataPacket {
/// The packet's header; this is the `Deref`/`DerefMut` target.
#[deref]
pub header: DataHeader,
/// Packet bytes. NOTE(review): presumably the whole packet as received and
/// still protected at this stage — confirm against `io::be_packet`.
pub bytes: BytesMut,
/// NOTE(review): looks like an offset into `bytes` (e.g. where the protected
/// part / packet number starts) — confirm against `io::be_packet`.
pub offset: usize,
}
impl GetType for DataPacket {
fn get_type(&self) -> Type {
self.header.get_type()
}
}
/// One packet parsed from received bytes; this is the item type (inside a
/// `Result`) yielded by [`PacketReader`].
#[derive(Debug, Clone)]
pub enum Packet {
/// A Version Negotiation packet, represented by its header alone.
VN(VersionNegotiationHeader),
/// A Retry packet, represented by its header alone.
Retry(RetryHeader),
/// Any other packet: its header together with its bytes.
Data(DataPacket),
}
/// Iterator that parses consecutive packets out of one buffer of received
/// bytes (e.g. several packets coalesced into one datagram).
///
/// Each `next` call hands `raw` to `io::be_packet`; iteration ends when `raw`
/// is empty, or right after the first parse error.
#[derive(Debug)]
pub struct PacketReader {
/// Bytes not yet parsed.
raw: BytesMut,
/// Forwarded to `io::be_packet` for every packet. NOTE(review): presumably
/// needed because short headers do not encode the DCID length — confirm.
dcid_len: usize,
}
impl PacketReader {
pub fn new(raw: BytesMut, dcid_len: usize) -> Self {
Self { raw, dcid_len }
}
}
impl Iterator for PacketReader {
    type Item = Result<Packet, error::Error>;

    /// Parses the next packet out of the remaining bytes.
    ///
    /// Yields `None` once no bytes are left. A parse error is yielded exactly
    /// once; the leftover bytes are then discarded, so the following call
    /// returns `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.raw.is_empty() {
            return None;
        }

        let parsed = io::be_packet(&mut self.raw, self.dcid_len);
        if parsed.is_err() {
            // Stop after the first failure: presumably the remaining bytes
            // can no longer be delimited reliably once parsing has failed.
            self.raw.clear();
        }
        Some(parsed)
    }
}
/// Something a frame of type `F` can be written into (e.g. [`PacketWriter`]).
pub trait MarshalFrame<F> {
/// Writes `frame` into `self`.
///
/// NOTE(review): the meaning of the returned `Option` is not defined here;
/// the `PacketWriter` implementation always returns `Some(frame)`.
fn dump_frame(&mut self, frame: F) -> Option<F>;
}
/// Something a frame of type `F` carrying data of type `D` can be written
/// into (e.g. [`PacketWriter`]).
pub trait MarshalDataFrame<F, D> {
/// Writes `frame` together with `data` into `self`.
///
/// NOTE(review): the meaning of the returned `Option` is not defined here;
/// the `PacketWriter` implementation always returns `Some(frame)`.
fn dump_frame_with_data(&mut self, frame: F, data: D) -> Option<F>;
}
/// Assembles a single packet in place inside a caller-provided buffer.
///
/// Layout while writing: `[header | Length field | packet number | frames...]`,
/// with the last `tag_len` bytes of `buffer` reserved for the AEAD tag.
/// Finish with `encrypt_long_packet` or `encrypt_short_packet`.
pub struct PacketWriter<'b> {
/// Destination buffer; the packet is built at its start.
buffer: &'b mut [u8],
/// Length of the encoded header, not counting the Length field.
hdr_len: usize,
/// Width in bytes of the Length field that follows the header.
/// NOTE(review): presumably 0 for short headers — confirm `length_encoding`.
len_encoding: usize,
/// `(full packet number, encoded packet number)`.
pn: (u64, PacketNumber),
/// Offset of the next byte to write.
cursor: usize,
/// Write limit: `buffer.len() - tag_len`, leaving room for the AEAD tag.
end: usize,
/// Length of the AEAD tag appended by encryption.
tag_len: usize,
/// Set once any ack-eliciting frame has been written.
ack_eliciting: bool,
/// Set once any frame that counts toward congestion control was written.
in_flight: bool,
/// Currently always `false` and never read in this file.
_probe_new_path: bool,
}
impl<'b> PacketWriter<'b> {
    /// Minimum number of bytes from the start of the Packet Number field to
    /// the end of the AEAD tag.
    ///
    /// Header protection samples 16 bytes of ciphertext that start 4 bytes
    /// after the packet-number offset, so the packet number plus the
    /// protected payload must span at least 4 + 16 bytes (RFC 9001 §5.4.2).
    const MIN_PROTECTED_LEN: usize = 20;

    /// Starts assembling a packet in `buffer`.
    ///
    /// Writes `header` at the front and leaves `header.length_encoding()`
    /// bytes for the Length field (filled in by [`Self::encrypt_long_packet`]).
    /// It then writes the encoded packet number `pn.1`. The last `tag_len`
    /// bytes of `buffer` are reserved for the AEAD tag.
    ///
    /// `pn` is `(full packet number, encoded packet number)`.
    ///
    /// Returns `None` when `buffer` cannot hold two things after the Length
    /// field: the header-protection minimum, and the packet number plus the
    /// AEAD tag.
    pub fn new<H>(
        header: &H,
        buffer: &'b mut [u8],
        pn: (u64, PacketNumber),
        tag_len: usize,
    ) -> Option<Self>
    where
        H: EncodeHeader,
        for<'a> &'a mut [u8]: WriteHeader<H>,
    {
        let hdr_len = header.size();
        let len_encoding = header.length_encoding();
        let encoded_pn = pn.1;
        let pn_offset = hdr_len + len_encoding;
        // Besides the header-protection minimum, the packet number and the
        // tag must both fit; otherwise `end` would fall below the initial
        // cursor and `remaining_mut` would underflow.
        let min_tail = Self::MIN_PROTECTED_LEN.max(encoded_pn.size() + tag_len);
        if buffer.len() < pn_offset + min_tail {
            return None;
        }

        let (mut hdr_buf, mut payload_buf) = buffer.split_at_mut(pn_offset);
        hdr_buf.put_header(header);
        payload_buf.put_packet_number(encoded_pn);

        let end = buffer.len() - tag_len;
        Some(Self {
            buffer,
            hdr_len,
            len_encoding,
            pn,
            cursor: pn_offset + encoded_pn.size(),
            end,
            tag_len,
            ack_eliciting: false,
            in_flight: false,
            _probe_new_path: false,
        })
    }

    /// Appends `cnt` zero bytes to the payload; zero bytes decode as PADDING
    /// frames (RFC 9000 §19.1).
    ///
    /// # Panics
    /// Panics if fewer than `cnt` bytes of payload capacity remain.
    pub fn pad(&mut self, cnt: usize) {
        self.put_bytes(0, cnt);
    }

    /// Whether any ack-eliciting frame has been written.
    #[inline]
    pub fn is_ack_eliciting(&self) -> bool {
        self.ack_eliciting
    }

    /// Whether any frame counting toward congestion control has been written.
    #[inline]
    pub fn in_flight(&self) -> bool {
        self.in_flight
    }

    /// Whether nothing has been written after the packet number yet.
    pub fn is_empty(&self) -> bool {
        self.cursor == self.hdr_len + self.len_encoding + self.pn.1.size()
    }

    /// Finishes a long-header packet.
    ///
    /// Pads the packet if needed for header protection and fills in the
    /// Length field. It then encodes the first byte, encrypts the payload with
    /// `pk` and applies header protection with `hpk`.
    ///
    /// # Panics
    /// Panics if the packet length does not fit in a QUIC variable-length
    /// integer.
    pub fn encrypt_long_packet(
        mut self,
        hpk: &dyn rustls::quic::HeaderProtectionKey,
        pk: &dyn rustls::quic::PacketKey,
    ) -> AssembledPacket<'b> {
        self.pad_for_header_protection();

        let pn_offset = self.hdr_len + self.len_encoding;
        // Length covers the Packet Number and the *protected* Payload, which
        // includes the AEAD tag appended by encryption (RFC 9000 §17.2).
        let length = self.cursor - pn_offset + self.tag_len;
        // NOTE(review): the field is always encoded in two bytes while the
        // slice is `len_encoding` wide; presumably `len_encoding == 2` for
        // every long header — confirm.
        let mut len_buf = &mut self.buffer[self.hdr_len..pn_offset];
        len_buf.encode_varint(
            &VarInt::try_from(length).expect("packet length must fit in a QUIC varint"),
            EncodeBytes::Two,
        );
        encode_long_first_byte(&mut self.buffer[0], self.pn.1.size());

        self.seal(hpk, pk)
    }

    /// Finishes a short-header (1-RTT) packet.
    ///
    /// Pads the packet if needed for header protection and encodes the first
    /// byte with `key_phase`. It then encrypts the payload with `pk` and
    /// applies header protection with `hpk`.
    pub fn encrypt_short_packet(
        mut self,
        key_phase: KeyPhaseBit,
        hpk: &dyn rustls::quic::HeaderProtectionKey,
        pk: &dyn rustls::quic::PacketKey,
    ) -> AssembledPacket<'b> {
        self.pad_for_header_protection();
        encode_short_first_byte(&mut self.buffer[0], self.pn.1.size(), key_phase);
        self.seal(hpk, pk)
    }

    /// Pads with zero bytes until the packet number, the payload and the AEAD
    /// tag together span at least [`Self::MIN_PROTECTED_LEN`] bytes.
    fn pad_for_header_protection(&mut self) {
        // Packet number plus payload written so far.
        let written = self.cursor - (self.hdr_len + self.len_encoding);
        debug_assert!(written > 0);
        let protected_len = written + self.tag_len;
        if protected_len < Self::MIN_PROTECTED_LEN {
            self.pad(Self::MIN_PROTECTED_LEN - protected_len);
        }
    }

    /// Encrypts the payload in place and applies header protection. It then
    /// hands back the finished packet. The first byte must already be encoded.
    fn seal(
        self,
        hpk: &dyn rustls::quic::HeaderProtectionKey,
        pk: &dyn rustls::quic::PacketKey,
    ) -> AssembledPacket<'b> {
        let (actual_pn, encoded_pn) = self.pn;
        let pkt_size = self.cursor + self.tag_len;
        let payload_offset = self.hdr_len + self.len_encoding + encoded_pn.size();

        encrypt_packet(pk, actual_pn, &mut self.buffer[..pkt_size], payload_offset);
        // NOTE(review): this passes `hdr_len`, not the packet-number offset
        // `hdr_len + len_encoding`; the two differ for long headers — confirm
        // which offset `protect_header` expects.
        protect_header(
            hpk,
            &mut self.buffer[..pkt_size],
            self.hdr_len,
            encoded_pn.size(),
        );

        AssembledPacket {
            buffer: self.buffer,
            pn: actual_pn,
            size: pkt_size,
            is_ack_eliciting: self.ack_eliciting,
            in_flight: self.in_flight,
        }
    }
}
/// A fully encrypted and header-protected packet produced by `PacketWriter`.
///
/// Dereferences to the packet bytes (`buffer[..size]`).
#[derive(Debug, CopyGetters)]
pub struct AssembledPacket<'b> {
/// The writer's buffer; only the first `size` bytes belong to the packet.
buffer: &'b mut [u8],
/// Full (not encoded) packet number of this packet.
#[getset(get_copy = "pub")]
pn: u64,
/// Total packet length in bytes, including the AEAD tag.
#[getset(get_copy = "pub")]
size: usize,
/// Whether the packet contains an ack-eliciting frame.
#[getset(get_copy = "pub")]
is_ack_eliciting: bool,
/// Whether the packet contains a frame counting toward congestion control.
#[getset(get_copy = "pub")]
in_flight: bool,
}
impl Deref for AssembledPacket<'_> {
    type Target = [u8];

    /// Exposes only the bytes that make up the assembled packet.
    fn deref(&self) -> &[u8] {
        let assembled_len = self.size;
        &self.buffer[0..assembled_len]
    }
}
impl<F> MarshalFrame<F> for PacketWriter<'_>
where
    F: BeFrame,
    Self: WriteFrame<F>,
{
    /// Writes `frame` and folds its specs into the packet flags.
    ///
    /// The packet becomes ack-eliciting unless the frame is `NonAckEliciting`.
    /// It becomes in flight unless the frame is `CongestionControlFree`.
    /// Always returns `Some(frame)`.
    fn dump_frame(&mut self, frame: F) -> Option<F> {
        let frame_specs = frame.frame_type().specs();
        if !frame_specs.contain(Spec::NonAckEliciting) {
            self.ack_eliciting = true;
        }
        if !frame_specs.contain(Spec::CongestionControlFree) {
            self.in_flight = true;
        }
        self.put_frame(&frame);
        Some(frame)
    }
}
impl<F, D> MarshalDataFrame<F, D> for PacketWriter<'_>
where
    F: BeFrame,
    D: DescribeData,
    Self: WriteData<D> + WriteDataFrame<F, D>,
{
    /// Writes `frame` followed by `data`.
    ///
    /// Frames written through this path unconditionally mark the packet as
    /// both in flight and ack-eliciting. Always returns `Some(frame)`.
    fn dump_frame_with_data(&mut self, frame: F, data: D) -> Option<F> {
        self.in_flight = true;
        self.ack_eliciting = true;
        self.put_data_frame(&frame, &data);
        Some(frame)
    }
}
// SAFETY: `chunk_mut` only hands out `buffer[cursor..end]`, and `advance_mut`
// refuses to move `cursor` past `end`. `end` is fixed at construction to
// `buffer.len() - tag_len`, so every byte reported as written lies inside
// `buffer` and in front of the space reserved for the AEAD tag.
unsafe impl BufMut for PacketWriter<'_> {
    /// Payload bytes that can still be written before the AEAD tag area.
    fn remaining_mut(&self) -> usize {
        self.end - self.cursor
    }

    /// Marks the next `cnt` bytes as written.
    ///
    /// # Safety
    /// The caller must have initialized the first `cnt` bytes of the slice
    /// last returned by `chunk_mut` (the `BufMut::advance_mut` contract).
    ///
    /// # Panics
    /// Panics if `cnt` exceeds `remaining_mut()`.
    unsafe fn advance_mut(&mut self, cnt: usize) {
        let remaining = self.remaining_mut();
        if remaining < cnt {
            panic!(
                "advance out of bounds: the len is {} but advancing by {}",
                remaining, cnt
            );
        }
        self.cursor += cnt;
    }

    /// The writable region between the cursor and the AEAD tag area.
    fn chunk_mut(&mut self) -> &mut UninitSlice {
        UninitSlice::new(&mut self.buffer[self.cursor..self.end])
    }
}