#![allow(dead_code)]
use alloc::vec::Vec;
use crate::cipher::{Aes128, Gcm};
use crate::quic::varint;
use crate::tls::Error;
/// Wire value of QUIC version 1 (RFC 9000) as carried in the 32-bit long-header
/// Version field.
pub(crate) const QUIC_V1: u32 = 1;
/// Long Packet Type, carried in bits 4-5 of a long header's first byte.
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum LongType {
    Initial = 0x00,
    ZeroRtt = 0x01,
    Handshake = 0x02,
    Retry = 0x03,
}
impl LongType {
    /// Every variant, indexed by its two-bit wire value.
    const BY_WIRE_BITS: [LongType; 4] = [
        LongType::Initial,
        LongType::ZeroRtt,
        LongType::Handshake,
        LongType::Retry,
    ];

    /// Maps the low two bits of `bits` to a packet type.
    ///
    /// Higher bits are ignored, so every input maps to some variant. The
    /// `Option` return never actually yields `None`.
    fn from_bits(bits: u8) -> Option<Self> {
        Self::BY_WIRE_BITS.get(usize::from(bits & 0x03)).copied()
    }

    /// First-byte template for this packet type.
    ///
    /// Sets the header-form bit (0x80) and the fixed bit (0x40) and places the
    /// type in bits 4-5. The low nibble is left zero for the caller to fill in.
    fn first_byte_template(self) -> u8 {
        const HEADER_FORM_LONG: u8 = 0x80;
        const FIXED_BIT: u8 = 0x40;
        let type_bits = (self as u8) << 4;
        HEADER_FORM_LONG | FIXED_BIT | type_bits
    }
}
/// Fields decoded by `LongHeader::parse` from the unprotected part of a
/// long-header packet. All slices borrow from the parsed buffer.
#[derive(Debug)]
pub(crate) struct LongHeader<'a> {
    /// Long Packet Type from bits 4-5 of the first byte. For Version Negotiation
    /// packets (version 0), `parse` fills in `LongType::Initial` as a placeholder.
    pub typ: LongType,
    /// Version field. A value of 0 marks a Version Negotiation packet.
    pub version: u32,
    /// Destination Connection ID (at most 20 bytes).
    pub dcid: &'a [u8],
    /// Source Connection ID (at most 20 bytes).
    pub scid: &'a [u8],
    /// Contents depend on the packet type:
    /// - Initial: the Token field.
    /// - Retry: the Retry Token, i.e. the bytes between the SCID and the trailing
    ///   16-byte integrity tag.
    /// - Other types: empty.
    pub token: &'a [u8],
    /// Decoded Length field for Initial/0-RTT/Handshake, and 0 for Retry and
    /// Version Negotiation. `parse` does not check it against the buffer length.
    pub length: u64,
    /// Offset of the (still header-protected) packet number field. For Retry and
    /// Version Negotiation packets, this is the offset just past the SCID.
    pub pn_offset: usize,
    /// Always equal to `pn_offset` as produced by `parse`.
    pub payload_off: usize,
}
impl<'a> LongHeader<'a> {
pub(crate) fn parse(buf: &'a [u8]) -> Result<Self, Error> {
if buf.is_empty() {
return Err(Error::Decode);
}
let b0 = buf[0];
if b0 & 0x80 == 0 {
return Err(Error::Decode); }
if buf.len() < 6 {
return Err(Error::Decode);
}
let version = u32::from_be_bytes([buf[1], buf[2], buf[3], buf[4]]);
if version == 0 {
let mut p = 5usize;
let dcid_len = *buf.get(p).ok_or(Error::Decode)? as usize;
p += 1;
if dcid_len > 20 || buf.len() < p + dcid_len + 1 {
return Err(Error::Decode);
}
let dcid = &buf[p..p + dcid_len];
p += dcid_len;
let scid_len = *buf.get(p).ok_or(Error::Decode)? as usize;
p += 1;
if scid_len > 20 || buf.len() < p + scid_len {
return Err(Error::Decode);
}
let scid = &buf[p..p + scid_len];
p += scid_len;
return Ok(LongHeader {
typ: LongType::Initial, version: 0,
dcid,
scid,
token: &[],
length: 0,
pn_offset: p,
payload_off: p,
});
}
if b0 & 0x40 == 0 {
return Err(Error::Decode);
}
let typ = LongType::from_bits((b0 >> 4) & 0x03).ok_or(Error::Decode)?;
let mut p = 5usize;
let dcid_len = *buf.get(p).ok_or(Error::Decode)? as usize;
p += 1;
if dcid_len > 20 {
return Err(Error::Decode);
}
if buf.len() < p + dcid_len + 1 {
return Err(Error::Decode);
}
let dcid = &buf[p..p + dcid_len];
p += dcid_len;
let scid_len = *buf.get(p).ok_or(Error::Decode)? as usize;
p += 1;
if scid_len > 20 {
return Err(Error::Decode);
}
if buf.len() < p + scid_len {
return Err(Error::Decode);
}
let scid = &buf[p..p + scid_len];
p += scid_len;
match typ {
LongType::Retry => {
if buf.len() < p + 16 {
return Err(Error::Decode);
}
Ok(LongHeader {
typ,
version,
dcid,
scid,
token: &buf[p..buf.len() - 16],
length: 0,
pn_offset: p,
payload_off: p,
})
}
LongType::Initial => {
let (tlen, n) = varint::decode(&buf[p..])?;
p += n;
let tlen = tlen as usize;
if buf.len() < p + tlen {
return Err(Error::Decode);
}
let token = &buf[p..p + tlen];
p += tlen;
let (length, n) = varint::decode(&buf[p..])?;
p += n;
Ok(LongHeader {
typ,
version,
dcid,
scid,
token,
length,
pn_offset: p,
payload_off: p,
})
}
LongType::ZeroRtt | LongType::Handshake => {
let (length, n) = varint::decode(&buf[p..])?;
p += n;
Ok(LongHeader {
typ,
version,
dcid,
scid,
token: &[],
length,
pn_offset: p,
payload_off: p,
})
}
}
}
}
/// Serialises an unprotected Initial, 0-RTT or Handshake long header, ending
/// with the truncated packet number.
///
/// Arguments:
/// * `token` is written, together with its varint length prefix, only for
///   `LongType::Initial`.
/// * `payload_len_for_field` is written verbatim into the Length field. It is
///   the caller's responsibility to make it cover the packet number and
///   protected payload.
/// * `pn_value` is truncated to its low `pn_len` bytes, big-endian.
///
/// Returns the header bytes and the offset of the packet number within them,
/// which is where header protection is later applied.
///
/// # Panics
/// Panics if `pn_len` is not in `1..=4`, in all build profiles. The CID-length
/// and type preconditions are checked only in debug builds.
#[allow(clippy::too_many_arguments)]
pub(crate) fn build_long_header(
    typ: LongType,
    version: u32,
    dcid: &[u8],
    scid: &[u8],
    token: &[u8],
    pn_value: u64,
    pn_len: u8,
    payload_len_for_field: u64,
) -> (Vec<u8>, usize) {
    // Hard assert: in release builds pn_len == 0 would make `pn_len - 1` wrap to
    // 0xff. That overwrites the type bits (every packet would claim to be Retry).
    assert!((1..=4).contains(&pn_len), "packet number length must be 1..=4");
    debug_assert!(dcid.len() <= 20);
    debug_assert!(scid.len() <= 20);
    debug_assert!(token.is_empty() || matches!(typ, LongType::Initial));
    debug_assert!(!matches!(typ, LongType::Retry), "use build_retry()");
    let pn_len = usize::from(pn_len);
    // Upper bound on the header size, so the buffer is allocated exactly once.
    // QUIC variable-length integers occupy at most 8 bytes (RFC 9000 §16).
    const MAX_VARINT_LEN: usize = 8;
    let capacity = 1
        + 4
        + 1
        + dcid.len()
        + 1
        + scid.len()
        + MAX_VARINT_LEN
        + token.len()
        + MAX_VARINT_LEN
        + pn_len;
    let mut out = Vec::with_capacity(capacity);
    // Low two bits of the first byte encode (packet number length - 1).
    out.push(typ.first_byte_template() | (pn_len as u8 - 1));
    out.extend_from_slice(&version.to_be_bytes());
    out.push(dcid.len() as u8);
    out.extend_from_slice(dcid);
    out.push(scid.len() as u8);
    out.extend_from_slice(scid);
    if matches!(typ, LongType::Initial) {
        varint::encode(token.len() as u64, &mut out);
        out.extend_from_slice(token);
    }
    varint::encode(payload_len_for_field, &mut out);
    let pn_offset = out.len();
    out.extend_from_slice(&pn_value.to_be_bytes()[8 - pn_len..]);
    (out, pn_offset)
}
#[derive(Debug)]
pub(crate) struct ShortHeader<'a> {
pub dcid: &'a [u8],
pub key_phase: bool,
pub spin: bool,
pub pn_offset: usize,
}
impl<'a> ShortHeader<'a> {
pub(crate) fn parse(buf: &'a [u8], dcid_len: usize) -> Result<Self, Error> {
if buf.is_empty() {
return Err(Error::Decode);
}
let b0 = buf[0];
if b0 & 0x80 != 0 || b0 & 0x40 == 0 {
return Err(Error::Decode);
}
if dcid_len > 20 {
return Err(Error::Decode);
}
if buf.len() < 1 + dcid_len {
return Err(Error::Decode);
}
let dcid = &buf[1..1 + dcid_len];
Ok(ShortHeader {
dcid,
spin: b0 & 0x20 != 0,
key_phase: b0 & 0x04 != 0,
pn_offset: 1 + dcid_len,
})
}
}
/// Serialises an unprotected short (1-RTT) header.
///
/// The header is the first byte, the DCID, and then `pn_value` truncated to its
/// low `pn_len` bytes (big-endian). Returns the header and the offset of the
/// packet number field.
///
/// # Panics
/// Panics if `pn_len` is not in `1..=4`, in all build profiles. The DCID length
/// is checked only in debug builds.
pub(crate) fn build_short_header(
    dcid: &[u8],
    spin: bool,
    key_phase: bool,
    pn_value: u64,
    pn_len: u8,
) -> (Vec<u8>, usize) {
    // Hard assert: in release builds pn_len == 0 would make `pn_len - 1` wrap to
    // 0xff. That sets the long-header form bit and silently emits a bogus header.
    assert!((1..=4).contains(&pn_len), "packet number length must be 1..=4");
    debug_assert!(dcid.len() <= 20);
    // Fixed bit set; low two bits encode (packet number length - 1).
    let mut b0 = 0x40u8 | (pn_len - 1);
    if spin {
        b0 |= 0x20;
    }
    if key_phase {
        b0 |= 0x04;
    }
    let pn_len = usize::from(pn_len);
    let mut out = Vec::with_capacity(1 + dcid.len() + pn_len);
    out.push(b0);
    out.extend_from_slice(dcid);
    let pn_offset = out.len();
    out.extend_from_slice(&pn_value.to_be_bytes()[8 - pn_len..]);
    (out, pn_offset)
}
/// XORs a 5-byte header-protection `mask` into `packet` in place.
///
/// Which bits are masked:
/// * `mask[0]` is applied to the low 4 bits of the first byte for long headers,
///   or the low 5 bits for short headers.
/// * `mask[1..=pn_len]` is applied to the packet number bytes starting at
///   `pn_offset`.
///
/// # Panics
/// Panics on out-of-range indexing if `pn_len > 4` or if `packet` is shorter
/// than `pn_offset + pn_len`. These conditions are also `debug_assert`ed.
pub(crate) fn apply_header_protection(
    packet: &mut [u8],
    pn_offset: usize,
    pn_len: u8,
    mask: &[u8; 5],
    long_header: bool,
) {
    debug_assert!((1..=4).contains(&pn_len));
    debug_assert!(packet.len() >= pn_offset + pn_len as usize);
    let width = usize::from(pn_len);
    let first_byte_bits: u8 = if !long_header { 0x1f } else { 0x0f };
    packet[0] ^= mask[0] & first_byte_bits;
    let pn_field = &mut packet[pn_offset..pn_offset + width];
    let pn_mask = &mask[1..1 + width];
    for (byte, m) in pn_field.iter_mut().zip(pn_mask) {
        *byte ^= *m;
    }
}
pub(crate) fn remove_header_protection(
packet: &mut [u8],
pn_offset: usize,
mask: &[u8; 5],
long_header: bool,
) -> Result<u8, Error> {
let first_byte_mask = if long_header { 0x0f } else { 0x1f };
packet[0] ^= mask[0] & first_byte_mask;
let pn_len = (packet[0] & 0x03) + 1;
if (pn_len as usize) > 4 || packet.len() < pn_offset + pn_len as usize {
return Err(Error::Decode);
}
for i in 0..pn_len as usize {
packet[pn_offset + i] ^= mask[1 + i];
}
Ok(pn_len)
}
/// AES-128-GCM key for the QUIC v1 Retry Integrity Tag (RFC 9001 §5.8).
/// It is a fixed public constant, not a secret.
const RETRY_INTEGRITY_KEY_V1: [u8; 16] = [
    0xbe, 0x0c, 0x69, 0x0b, 0x9f, 0x66, 0x57, 0x5a, 0x1d, 0x76, 0x6b, 0x54, 0xe3, 0x68, 0xc8, 0x4e,
];
/// Fixed 96-bit AEAD nonce for the QUIC v1 Retry Integrity Tag (RFC 9001 §5.8).
const RETRY_INTEGRITY_NONCE_V1: [u8; 12] = [
    0x46, 0x15, 0x99, 0xd3, 0x5d, 0x63, 0x2b, 0xf2, 0x23, 0x98, 0x25, 0xbb,
];
/// Computes the 16-byte Retry Integrity Tag for a Retry packet.
///
/// `odcid` is the Original Destination Connection ID the client used.
/// `retry_unauth` is the Retry packet without its tag. Both are authenticated
/// as AEAD associated data under the fixed v1 key and nonce, with an empty
/// plaintext. The AEAD output therefore consists of the tag alone.
pub(crate) fn retry_integrity_tag(odcid: &[u8], retry_unauth: &[u8]) -> [u8; 16] {
    // Retry pseudo-packet: ODCID length || ODCID || Retry packet without tag.
    let mut pseudo_packet = Vec::with_capacity(1 + odcid.len() + retry_unauth.len());
    pseudo_packet.push(odcid.len() as u8);
    pseudo_packet.extend_from_slice(odcid);
    pseudo_packet.extend_from_slice(retry_unauth);
    let cipher: Gcm<Aes128> = Gcm::new(Aes128::new(&RETRY_INTEGRITY_KEY_V1));
    let mut no_plaintext: [u8; 0] = [];
    cipher.encrypt(&RETRY_INTEGRITY_NONCE_V1, &pseudo_packet, &mut no_plaintext)
}
/// Serialises a Retry packet *without* its trailing Retry Integrity Tag.
///
/// Layout, in order:
/// * First byte `0xf0`: long form, fixed bit, type Retry, low nibble zero.
/// * The 4-byte big-endian `version`.
/// * The length-prefixed DCID, then the length-prefixed SCID.
/// * `retry_token`, verbatim.
///
/// `build_retry` feeds this byte string to `retry_integrity_tag`.
pub(crate) fn build_retry_unauth(
    version: u32,
    dcid: &[u8],
    scid: &[u8],
    retry_token: &[u8],
) -> Vec<u8> {
    debug_assert!(dcid.len() <= 20);
    debug_assert!(scid.len() <= 20);
    const RETRY_FIRST_BYTE: u8 = 0x80 | 0x40 | 0x30;
    let total = 1 + 4 + (1 + dcid.len()) + (1 + scid.len()) + retry_token.len();
    let mut packet = Vec::with_capacity(total);
    packet.push(RETRY_FIRST_BYTE);
    packet.extend_from_slice(&version.to_be_bytes());
    for &cid in [dcid, scid].iter() {
        packet.push(cid.len() as u8);
        packet.extend_from_slice(cid);
    }
    packet.extend_from_slice(retry_token);
    packet
}
/// Builds a complete Retry packet.
///
/// This is the unauthenticated body from `build_retry_unauth`, followed by the
/// 16-byte Retry Integrity Tag computed over it and `odcid`, the client's
/// Original Destination Connection ID.
pub(crate) fn build_retry(
    version: u32,
    dcid: &[u8],
    scid: &[u8],
    retry_token: &[u8],
    odcid: &[u8],
) -> Vec<u8> {
    let unauthenticated = build_retry_unauth(version, dcid, scid, retry_token);
    let integrity_tag = retry_integrity_tag(odcid, &unauthenticated);
    let mut packet = unauthenticated;
    packet.extend(integrity_tag.iter().copied());
    packet
}
/// Builds a Version Negotiation packet.
///
/// Layout, in order:
/// * First byte `0xc0`: long form plus the 0x40 bit, other bits zero.
/// * A zero Version field, which is what marks the packet as Version
///   Negotiation.
/// * The length-prefixed DCID, then the length-prefixed SCID.
/// * Each entry of `versions` as a 4-byte big-endian value.
pub(crate) fn build_version_negotiation(dcid: &[u8], scid: &[u8], versions: &[u32]) -> Vec<u8> {
    debug_assert!(dcid.len() <= 20);
    debug_assert!(scid.len() <= 20);
    const FIRST_BYTE: u8 = 0x80 | 0x40;
    const VN_VERSION: [u8; 4] = [0; 4];
    let total = 1 + VN_VERSION.len() + (1 + dcid.len()) + (1 + scid.len()) + 4 * versions.len();
    let mut packet = Vec::with_capacity(total);
    packet.push(FIRST_BYTE);
    packet.extend_from_slice(&VN_VERSION);
    for &cid in [dcid, scid].iter() {
        packet.push(cid.len() as u8);
        packet.extend_from_slice(cid);
    }
    for version in versions {
        packet.extend_from_slice(&version.to_be_bytes());
    }
    packet
}
#[cfg(test)]
mod tests {
    // Unit tests for long/short header building and parsing, header protection,
    // and Retry / Version Negotiation construction. Several tests reproduce the
    // RFC 9001 Appendix A test vectors byte-for-byte.
    use super::*;
    use crate::quic::crypto::{AeadAlg, derive_dir_keys, derive_initial_secrets};
    /// Decodes a hex string two characters at a time. Panics on an invalid digit
    /// or an odd length.
    fn hex(s: &str) -> Vec<u8> {
        (0..s.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&s[i..i + 2], 16).expect("hex"))
            .collect()
    }
    // Client Initial header from RFC 9001 A.2: build it, check the exact bytes,
    // then parse it back.
    #[test]
    fn long_header_initial_roundtrip() {
        let dcid = [0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08];
        let scid: [u8; 0] = [];
        let token: [u8; 0] = [];
        let (hdr, pn_off) = build_long_header(
            LongType::Initial,
            QUIC_V1,
            &dcid,
            &scid,
            &token,
            2,    // pn_value
            4,    // pn_len
            1182, // Length field value
        );
        assert_eq!(
            hdr.as_slice(),
            hex("c300000001088394c8f03e5157080000449e00000002").as_slice(),
        );
        assert_eq!(pn_off, 18);
        let parsed = LongHeader::parse(&hdr).expect("parse");
        assert_eq!(parsed.typ, LongType::Initial);
        assert_eq!(parsed.version, QUIC_V1);
        assert_eq!(parsed.dcid, &dcid);
        assert_eq!(parsed.scid, &scid);
        assert_eq!(parsed.token, &token);
        assert_eq!(parsed.length, 1182);
        assert_eq!(parsed.pn_offset, 18);
    }
    // Server Initial header from RFC 9001 A.3: empty DCID, 8-byte SCID, 2-byte PN.
    #[test]
    fn long_header_server_initial_layout() {
        let dcid: [u8; 0] = [];
        let scid = hex("f067a5502a4262b5");
        let (hdr, pn_off) =
            build_long_header(LongType::Initial, QUIC_V1, &dcid, &scid, &[], 1, 2, 117);
        assert_eq!(
            hdr.as_slice(),
            hex("c1000000010008f067a5502a4262b50040750001").as_slice(),
        );
        assert_eq!(pn_off, 18);
        let parsed = LongHeader::parse(&hdr).expect("parse");
        assert_eq!(parsed.scid, scid.as_slice());
        assert_eq!(parsed.length, 117);
    }
    // Every packet-number length 1..=4 is encoded in the first byte's low bits,
    // and parse agrees with the builder on pn_offset and both CIDs.
    #[test]
    fn varint_pn_roundtrip_at_boundaries() {
        for pn_len in 1u8..=4 {
            let (hdr, pn_off) = build_long_header(
                LongType::Handshake,
                QUIC_V1,
                &[1, 2, 3, 4],
                &[5, 6, 7, 8],
                &[],
                0x0a,
                pn_len,
                50,
            );
            let parsed = LongHeader::parse(&hdr).expect("parse");
            assert_eq!(hdr[0] & 0x03, pn_len - 1);
            assert_eq!(parsed.pn_offset, pn_off);
            assert_eq!(parsed.dcid, &[1, 2, 3, 4][..]);
            assert_eq!(parsed.scid, &[5, 6, 7, 8][..]);
        }
    }
    // Short header with and without the spin / key-phase bits.
    #[test]
    fn short_header_roundtrip() {
        let dcid = hex("f067a5502a4262b5");
        let (hdr, pn_off) = build_short_header(&dcid, false, false, 7, 2);
        assert_eq!(hdr[0], 0x41);
        assert_eq!(pn_off, 9);
        let parsed = ShortHeader::parse(&hdr, dcid.len()).expect("parse");
        assert_eq!(parsed.dcid, dcid.as_slice());
        assert!(!parsed.key_phase);
        assert!(!parsed.spin);
        assert_eq!(parsed.pn_offset, 9);
        // 0x67 = fixed bit | spin | key phase | (pn_len - 1 = 3).
        let (hdr2, _) = build_short_header(&dcid, true, true, 0x0102_0304, 4);
        assert_eq!(hdr2[0], 0x67);
        let p2 = ShortHeader::parse(&hdr2, dcid.len()).expect("parse");
        assert!(p2.key_phase);
        assert!(p2.spin);
    }
    // Apply then remove header protection on a long header. The header is first
    // padded with 16 zero bytes beyond its original length.
    #[test]
    fn header_protection_long_apply_remove_roundtrip() {
        let (mut hdr, pn_off) = build_long_header(
            LongType::Initial,
            QUIC_V1,
            &hex("8394c8f03e515708"),
            &[],
            &[],
            2,
            4,
            1182,
        );
        let orig = hdr.clone();
        hdr.resize(orig.len() + 16, 0);
        let mask = [0x43, 0x7b, 0x9a, 0xec, 0x36];
        apply_header_protection(&mut hdr, pn_off, 4, &mask, true);
        let pn_len = remove_header_protection(&mut hdr, pn_off, &mask, true).expect("ok");
        assert_eq!(pn_len, 4);
        assert_eq!(&hdr[..orig.len()], &orig[..]);
    }
    // Same round trip for a short header (5 masked bits in the first byte).
    #[test]
    fn header_protection_short_apply_remove_roundtrip() {
        let dcid = hex("f067a5502a4262b5");
        let (mut hdr, pn_off) = build_short_header(&dcid, false, false, 0x010203, 3);
        let orig = hdr.clone();
        hdr.resize(orig.len() + 16, 0);
        let mask = [0x5a, 0x11, 0x22, 0x33, 0x44];
        apply_header_protection(&mut hdr, pn_off, 3, &mask, false);
        let pn_len = remove_header_protection(&mut hdr, pn_off, &mask, false).expect("ok");
        assert_eq!(pn_len, 3);
        assert_eq!(&hdr[..orig.len()], &orig[..]);
    }
    // RFC 9001 A.2 header-protection vector (client Initial).
    #[test]
    fn rfc9001_a2_apply_header_protection() {
        let mut wire = hex("c300000001088394c8f03e5157080000449e00000002");
        let mask = [0x43, 0x7b, 0x9a, 0xec, 0x36];
        apply_header_protection(&mut wire, 18, 4, &mask, true);
        assert_eq!(
            wire.as_slice(),
            hex("c000000001088394c8f03e5157080000449e7b9aec34").as_slice(),
        );
    }
    // RFC 9001 A.3 header-protection vector (server Initial).
    #[test]
    fn rfc9001_a3_apply_header_protection() {
        let mut wire = hex("c1000000010008f067a5502a4262b50040750001");
        let mask = [0x2e, 0xc0, 0xd8, 0x35, 0x6a];
        apply_header_protection(&mut wire, 18, 2, &mask, true);
        assert_eq!(
            wire.as_slice(),
            hex("cf000000010008f067a5502a4262b5004075c0d9").as_slice(),
        );
    }
    // RFC 9001 A.5 header-protection vector (short header, empty DCID).
    #[test]
    fn rfc9001_a5_apply_header_protection() {
        let mut wire = hex("4200bff4");
        let mask = [0xae, 0xfe, 0xfe, 0x7d, 0x03];
        apply_header_protection(&mut wire, 1, 3, &mask, false);
        assert_eq!(wire.as_slice(), hex("4cfe4189").as_slice());
    }
    // RFC 9001 A.4 Retry vector. The Retry body is built by hand (first byte
    // 0xff) to check the integrity tag and the full packet bytes.
    #[test]
    fn rfc9001_a4_retry_integrity_tag() {
        let odcid = hex("8394c8f03e515708");
        let mut unauth = Vec::new();
        unauth.push(0xff);
        unauth.extend_from_slice(&QUIC_V1.to_be_bytes());
        unauth.push(0x00); // DCID length
        unauth.push(0x08); // SCID length
        unauth.extend_from_slice(&hex("f067a5502a4262b5"));
        unauth.extend_from_slice(b"token");
        let tag = retry_integrity_tag(&odcid, &unauth);
        assert_eq!(
            tag.as_slice(),
            hex("04a265ba2eff4d829058fb3f0f2496ba").as_slice()
        );
        let mut full = unauth.clone();
        full.extend_from_slice(&tag);
        assert_eq!(
            full.as_slice(),
            hex("ff000000010008f067a5502a4262b5746f6b656e04a265ba2eff4d829058fb3f0f2496ba")
                .as_slice(),
        );
    }
    // build_retry's trailing 16 bytes equal the tag recomputed over the rest.
    #[test]
    fn build_retry_tag_matches_helper() {
        let odcid = hex("8394c8f03e515708");
        let scid = hex("f067a5502a4262b5");
        let pkt = build_retry(QUIC_V1, &[], &scid, b"token", &odcid);
        let unauth_len = pkt.len() - 16;
        let tag_field: [u8; 16] = pkt[unauth_len..].try_into().expect("16");
        let computed = retry_integrity_tag(&odcid, &pkt[..unauth_len]);
        assert_eq!(tag_field, computed);
    }
    // Version Negotiation byte layout, and parse's version-0 path.
    #[test]
    fn build_vn_layout() {
        let dcid = hex("0102030405");
        let scid = hex("aabbcc");
        let vns = [QUIC_V1, 0x0a0a_0a0a]; // v1 plus a greasing-style version
        let pkt = build_version_negotiation(&dcid, &scid, &vns);
        assert_eq!(pkt[0], 0xc0);
        assert_eq!(&pkt[1..5], &[0, 0, 0, 0]);
        assert_eq!(pkt[5], 5);
        assert_eq!(&pkt[6..11], dcid.as_slice());
        assert_eq!(pkt[11], 3);
        assert_eq!(&pkt[12..15], scid.as_slice());
        assert_eq!(&pkt[15..19], &QUIC_V1.to_be_bytes());
        assert_eq!(&pkt[19..23], &0x0a0a_0a0au32.to_be_bytes());
        assert_eq!(pkt.len(), 23);
        let parsed = LongHeader::parse(&pkt).expect("parse");
        assert_eq!(parsed.version, 0);
        assert_eq!(parsed.dcid, dcid.as_slice());
        assert_eq!(parsed.scid, scid.as_slice());
    }
    // A 21-byte DCID exceeds the 20-byte limit and must be rejected.
    #[test]
    fn long_header_rejects_oversized_cid() {
        let mut buf = Vec::new();
        buf.push(0xc0);
        buf.extend_from_slice(&QUIC_V1.to_be_bytes());
        buf.push(21);
        buf.resize(buf.len() + 21, 0);
        buf.push(0);
        assert!(LongHeader::parse(&buf).is_err());
    }
    // First byte 0x80: long form without the fixed bit, for a non-zero version.
    #[test]
    fn long_header_rejects_missing_fixed_bit() {
        let mut buf = Vec::new();
        buf.push(0x80);
        buf.extend_from_slice(&QUIC_V1.to_be_bytes());
        buf.push(0); // DCID length
        buf.push(0); // SCID length
        assert!(LongHeader::parse(&buf).is_err());
    }
    // End-to-end RFC 9001 A.2 check of the protected header:
    // 1. Build the client Initial header.
    // 2. Zero-pad the CRYPTO frame to 1162 bytes (4-byte PN + 1162 + 16-byte tag
    //    = 1182, the Length field value).
    // 3. Seal the payload with keys from the first secret returned by
    //    derive_initial_secrets (presumably the client's).
    // 4. Sample 16 bytes starting 4 bytes after pn_offset and derive the mask.
    // 5. Apply header protection and compare the protected header bytes.
    #[test]
    fn rfc9001_a2_end_to_end_protected_header_prefix() {
        let dcid = hex("8394c8f03e515708");
        let (hdr, pn_off) =
            build_long_header(LongType::Initial, QUIC_V1, &dcid, &[], &[], 2, 4, 1182);
        let crypto_frame = hex(
            "060040f1010000ed0303ebf8fa56f12939b9584a3896472ec40bb863cfd3e868\
04fe3a47f06a2b69484c000004130113\
02010000c000000010000e00000b6578\
616d706c652e636f6dff01000100000a\
00080006001d00170018001000070005\
04616c706e000500050100000000\
003300260024001d00209370b2c9caa47fba\
baf4559fedba753de171fa71f50f1ce1\
5d43e994ec74d748002b00030203040\
00d0010000e040305030603020308040\
8050806002d00020101001c00024001\
003900320408ffffffffffffffff050480\
00ffff07048000ffff080110010480\
0075300901100f088394c8f03e515708\
06048000ffff",
        );
        let mut payload = crypto_frame;
        payload.resize(1162, 0);
        let (cs, _) = derive_initial_secrets(&dcid);
        let dk = derive_dir_keys(AeadAlg::Aes128Gcm, &cs);
        let tag = crate::quic::crypto::aead_seal(&dk, 2, &hdr, &mut payload);
        let mut wire = hdr.clone();
        wire.extend_from_slice(&payload);
        wire.extend_from_slice(&tag);
        let sample: [u8; 16] = wire[pn_off + 4..pn_off + 4 + 16].try_into().expect("16");
        let mask = dk.hp.mask(&sample).expect("mask");
        apply_header_protection(&mut wire, pn_off, 4, &mask, true);
        let expected_header = hex("c000000001088394c8f03e5157080000449e7b9aec34");
        assert_eq!(&wire[..expected_header.len()], expected_header.as_slice());
    }
}