use enum_dispatch::enum_dispatch;
use crate::cid::ConnectionId;
pub mod long;
pub mod short;
#[doc(hidden)]
pub use long::{
io::{LongHeaderBuilder, WriteSpecific},
DataHeader, HandshakeHeader, InitialHeader, LongHeader, RetryHeader, VersionNegotiationHeader,
ZeroRttHeader,
};
#[doc(hidden)]
pub use short::OneRttHeader;
use super::r#type::{
long::{v1, Type as LongType, Version},
short::OneRtt,
Type,
};
/// Reports the packet [`Type`] (long or short) of a header.
///
/// Marked `#[enum_dispatch]` so an enum of headers can forward this trait to the
/// header held by each variant.
#[enum_dispatch]
pub trait GetType {
    /// Returns the packet type of `self`.
    fn get_type(&self) -> Type;
}
/// Encoding-related size hooks for a header (per the trait name).
///
/// Both methods have default bodies that return `0`, so an implementor that does not
/// override them reports zero.
/// NOTE(review): the exact meaning of each value is defined by the implementors, which
/// are not visible here — confirm before relying on it.
#[enum_dispatch]
pub trait EncodeHeader {
    /// Presumably the encoded size of the header in bytes — TODO confirm.
    /// Defaults to `0`.
    fn size(&self) -> usize {
        0
    }
    /// Presumably the number of bytes used to encode the header's length field —
    /// TODO confirm. Defaults to `0`.
    fn length_encoding(&self) -> usize {
        0
    }
}
/// Access to the destination connection ID (DCID) carried by a header.
///
/// [`Header`] forwards this trait to its variants via `#[enum_dispatch(GetDcid)]`.
#[enum_dispatch]
pub trait GetDcid {
    /// Borrows the destination connection ID of `self`.
    fn get_dcid(&self) -> &ConnectionId;
}
/// Access to the source connection ID (SCID) carried by a header.
///
/// Unlike `GetDcid`, this trait is not dispatched on [`Header`]. Presumably this is
/// because the 1-RTT (short) header is parsed without a source connection ID — verify
/// against `short::OneRttHeader`.
#[enum_dispatch]
pub trait GetScid {
    /// Borrows the source connection ID of `self`.
    fn get_scid(&self) -> &ConnectionId;
}
/// A decoded QUIC packet header, with one variant per packet kind.
///
/// `GetDcid` is implemented for this enum through `#[enum_dispatch]`, forwarding to the
/// wrapped header. `GetType`, `GetScid` and `EncodeHeader` are not dispatched here.
#[derive(Debug, Clone)]
#[enum_dispatch(GetDcid)]
pub enum Header {
    /// Version Negotiation long header.
    VN(long::VersionNegotiationHeader),
    /// Retry long header.
    Retry(long::RetryHeader),
    /// Initial long header.
    Initial(long::InitialHeader),
    /// 0-RTT long header.
    ZeroRtt(long::ZeroRttHeader),
    /// Handshake long header.
    Handshake(long::HandshakeHeader),
    /// 1-RTT short header.
    OneRtt(short::OneRttHeader),
}
pub mod io {
    use super::{
        long::{io::LongHeaderBuilder, Handshake, Initial, Retry, VersionNegotiation, ZeroRtt},
        Header, LongHeader, OneRttHeader,
    };
    use crate::{
        cid::be_connection_id,
        packet::{
            header::short::io::be_one_rtt_header,
            r#type::{short::OneRtt, Type},
        },
    };

    /// Parses a packet header from `input`, given the already-decoded `packet_type`.
    ///
    /// * Short (1-RTT) type: the spin bit carried by the type and `dcid_len` are handed
    ///   to `be_one_rtt_header`.
    /// * Long type: `input` is read as the destination connection ID followed by the
    ///   source connection ID (each via `be_connection_id`). The remaining bytes are then
    ///   parsed by `LongHeaderBuilder::parse` for that long packet type. `dcid_len` is not
    ///   used on this path.
    ///
    /// On success, returns the unconsumed input together with the parsed [`Header`].
    /// Errors from the helper parsers are propagated with `?`.
    pub fn be_header(
        packet_type: Type,
        dcid_len: usize,
        input: &[u8],
    ) -> nom::IResult<&[u8], Header> {
        // Short headers are handled up front; every other type falls through to the
        // long-header path below.
        let long_ty = match packet_type {
            Type::Short(OneRtt(spin)) => {
                let (rest, short_header) = be_one_rtt_header(spin, dcid_len, input)?;
                return Ok((rest, Header::OneRtt(short_header)));
            }
            Type::Long(long_ty) => long_ty,
        };

        let (rest, dcid) = be_connection_id(input)?;
        let (rest, scid) = be_connection_id(rest)?;
        let header_builder = LongHeaderBuilder { dcid, scid };
        header_builder.parse(long_ty, rest)
    }

    /// Serialization of a header of type `H` into a [`bytes::BufMut`] buffer.
    pub trait WriteHeader<H>: bytes::BufMut {
        /// Writes `header` into `self`.
        fn put_header(&mut self, header: &H);
    }

    /// Writes any [`Header`] by delegating to the `WriteHeader` impl for the concrete
    /// header type held by the matched variant.
    impl<T> WriteHeader<Header> for T
    where
        T: bytes::BufMut
            + WriteHeader<LongHeader<VersionNegotiation>>
            + WriteHeader<LongHeader<Retry>>
            + WriteHeader<LongHeader<Initial>>
            + WriteHeader<LongHeader<ZeroRtt>>
            + WriteHeader<LongHeader<Handshake>>
            + WriteHeader<OneRttHeader>,
    {
        fn put_header(&mut self, header: &Header) {
            match header {
                Header::VN(inner) => self.put_header(inner),
                Header::Retry(inner) => self.put_header(inner),
                Header::Initial(inner) => self.put_header(inner),
                Header::ZeroRtt(inner) => self.put_header(inner),
                Header::Handshake(inner) => self.put_header(inner),
                Header::OneRtt(inner) => self.put_header(inner),
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::{
        io::be_header,
        long::{Handshake, Initial, Retry, VersionNegotiation, ZeroRtt},
        Header, LongHeaderBuilder,
    };
    use crate::{
        cid::ConnectionId,
        packet::{
            header::io::WriteHeader,
            r#type::{long, long::Ver1, short::OneRtt, Type},
            OneRttHeader, SpinBit,
        },
    };

    /// Decodes hand-built byte buffers with `be_header` for every header kind and checks
    /// the decoded fields. Each successful case encodes empty connection IDs, which decode
    /// to `ConnectionId::default()`, and must consume the whole buffer.
    #[test]
    fn test_read_header() {
        // Version Negotiation: empty DCID and SCID, then two 4-byte versions (1 and 2).
        let buf = vec![0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02];
        let (remain, vn_long_header) =
            be_header(Type::Long(long::Type::VersionNegotiation), 0, &buf).unwrap();
        assert_eq!(remain.len(), 0);
        match vn_long_header {
            Header::VN(vn) => {
                assert_eq!(vn.dcid, ConnectionId::default());
                assert_eq!(vn.scid, ConnectionId::default());
                assert_eq!(vn.specific.versions, vec![0x01, 0x02]);
            }
            _ => panic!("unexpected header type"),
        }
        // Retry: empty DCID and SCID, a 3-byte token, then the trailing 16-byte
        // integrity tag (0x00..=0x0f).
        let buf = vec![
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
            0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
        ];
        let (remain, retry_long_header) =
            be_header(Type::Long(long::Type::V1(Ver1::RETRY)), 0, &buf).unwrap();
        assert_eq!(remain.len(), 0);
        match retry_long_header {
            Header::Retry(retry) => {
                assert_eq!(retry.dcid, ConnectionId::default());
                assert_eq!(retry.scid, ConnectionId::default());
                assert_eq!(retry.token, [0x00, 0x00, 0x00]);
                assert_eq!(
                    retry.integrity,
                    [
                        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
                        0x0c, 0x0d, 0x0e, 0x0f
                    ]
                );
            }
            _ => panic!("unexpected header type"),
        }
        // A truncated Retry buffer must fail with `nom::Err::Incomplete` rather than
        // mis-parse.
        let buf = vec![
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
            0x0f,
        ];
        match be_header(Type::Long(long::Type::V1(Ver1::RETRY)), 0, &buf) {
            Err(e) => assert_eq!(e, nom::Err::Incomplete(nom::Needed::new(16))),
            _ => panic!("unexpected result"),
        }
        // Initial: empty DCID and SCID, token length 3, then the 3 token bytes.
        let buf = vec![0x00, 0x00, 0x03, 0x01, 0x02, 0x03];
        let (remain, initial_long_header) =
            be_header(Type::Long(long::Type::V1(Ver1::INITIAL)), 0, &buf).unwrap();
        assert_eq!(remain.len(), 0);
        match initial_long_header {
            Header::Initial(initial) => {
                assert_eq!(initial.dcid, ConnectionId::default());
                assert_eq!(initial.scid, ConnectionId::default());
                assert_eq!(initial.token, [0x01, 0x02, 0x03,]);
            }
            _ => panic!("unexpected header type"),
        }
        // 0-RTT: only the empty DCID and SCID; no type-specific bytes are read here.
        let buf = vec![0x00, 0x00];
        let (remain, zero_rtt_long_header) =
            be_header(Type::Long(long::Type::V1(Ver1::ZERO_RTT)), 0, &buf).unwrap();
        assert_eq!(remain.len(), 0);
        match zero_rtt_long_header {
            Header::ZeroRtt(zero_rtt) => {
                assert_eq!(zero_rtt.dcid, ConnectionId::default());
                assert_eq!(zero_rtt.scid, ConnectionId::default());
            }
            _ => panic!("unexpected header type"),
        }
        // Handshake: only the empty DCID and SCID, same as 0-RTT.
        let buf = vec![0x00, 0x00];
        let (remain, handshake_long_header) =
            be_header(Type::Long(long::Type::V1(Ver1::HANDSHAKE)), 0, &buf).unwrap();
        assert_eq!(remain.len(), 0);
        match handshake_long_header {
            Header::Handshake(handshake) => {
                assert_eq!(handshake.dcid, ConnectionId::default());
                assert_eq!(handshake.scid, ConnectionId::default());
            }
            _ => panic!("unexpected header type"),
        }
        // 1-RTT: with `dcid_len` = 0 the short header consumes no bytes at all.
        let buf = vec![];
        let (remain, one_rtt_header) =
            be_header(Type::Short(OneRtt(SpinBit::One)), 0, &buf).unwrap();
        assert_eq!(remain.len(), 0);
        match one_rtt_header {
            Header::OneRtt(one_rtt) => {
                assert_eq!(
                    one_rtt,
                    OneRttHeader::new(SpinBit::One, ConnectionId::default())
                );
            }
            _ => panic!("unexpected header type"),
        }
    }

    /// Encodes each header kind with `put_header` and compares the exact bytes. Long
    /// headers are laid out as: first byte, 4-byte version, DCID length, SCID length
    /// (both 0 here), then the type-specific fields.
    #[test]
    fn test_write_header() {
        // Version Negotiation: first byte 0x80, version 0, then versions 1 and 2.
        let mut buf = vec![];
        let vn_long_header = Header::VN(
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()).wrap(
                VersionNegotiation {
                    versions: vec![0x01, 0x02],
                },
            ),
        );
        buf.put_header(&vn_long_header);
        assert_eq!(
            buf,
            [
                0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
                0x02
            ]
        );
        // Retry: first byte 0xf0, version 1, then the raw token and the 16-byte
        // integrity tag.
        let mut buf = vec![];
        let retry_long_header = Header::Retry(
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()).wrap(
                Retry {
                    token: vec![0x00, 0x00, 0x00],
                    integrity: [
                        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
                        0x0c, 0x0d, 0x0e, 0x0f,
                    ],
                },
            ),
        );
        buf.put_header(&retry_long_header);
        assert_eq!(
            buf,
            [
                0xf0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03,
                0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
            ]
        );
        // Initial: first byte 0xc0, version 1, token length 3, then the token bytes.
        let mut buf = vec![];
        let initial_header = Header::Initial(
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default()).wrap(
                Initial {
                    token: vec![0x01, 0x02, 0x03],
                },
            ),
        );
        buf.put_header(&initial_header);
        assert_eq!(
            buf,
            [0xc0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x03, 0x01, 0x02, 0x03]
        );
        // 0-RTT: first byte 0xd0, version 1, empty connection IDs.
        let mut buf = vec![];
        let zero_rtt_header = Header::ZeroRtt(
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default())
                .wrap(ZeroRtt),
        );
        buf.put_header(&zero_rtt_header);
        assert_eq!(buf, [0xd0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]);
        // Handshake: first byte 0xe0, version 1, empty connection IDs.
        let mut buf = vec![];
        let handshake_header = Header::Handshake(
            LongHeaderBuilder::with_cid(ConnectionId::default(), ConnectionId::default())
                .wrap(Handshake),
        );
        buf.put_header(&handshake_header);
        assert_eq!(buf, [0xe0, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00]);
        // 1-RTT with an empty DCID is a single byte; the two cases differ only in bit
        // 0x20, which tracks the spin bit (0x60 for One, 0x40 for Zero).
        let mut buf = vec![];
        let one_rtt_header =
            Header::OneRtt(OneRttHeader::new(SpinBit::One, ConnectionId::default()));
        buf.put_header(&one_rtt_header);
        assert_eq!(buf, [0x60]);
        let mut buf = vec![];
        let one_rtt_header =
            Header::OneRtt(OneRttHeader::new(SpinBit::Zero, ConnectionId::default()));
        buf.put_header(&one_rtt_header);
        assert_eq!(buf, [0x40]);
    }
}