use anyhow::{Result, anyhow};
use std::collections::BTreeMap;
/// Header fields and best-effort decrypted content of a QUIC Initial packet,
/// as built by `parse_initial_packet`.
#[derive(Debug, Default, Clone)]
pub struct QuicInitialData {
    /// Version field rendered as `0x` followed by 8 zero-padded lowercase hex digits.
    pub version: String,
    /// Destination Connection ID, hex-encoded.
    pub dcid: String,
    /// Source Connection ID, hex-encoded (empty string for a zero-length SCID).
    pub scid: String,
    /// Token bytes, hex-encoded; `None` when the token length field is zero.
    pub token: Option<String>,
    /// Server-name hint obtained from the decrypted payload; `None` when
    /// decryption/extraction fails. NOTE(review): exact meaning is defined by
    /// `crypto::extract_decrypted_content` — confirm there.
    pub sni_hint: Option<String>,
    /// Decrypted data returned by the crypto module; empty when decryption fails.
    /// Presumably keyed by CRYPTO frame offset — TODO confirm against
    /// `crypto::extract_decrypted_content`.
    pub crypto_frames: BTreeMap<usize, Vec<u8>>,
}
/// Parses the long header of a QUIC Initial packet and attempts to decrypt its payload.
///
/// Fields are read in wire order (RFC 9000 §17.2.2): first byte, Version,
/// Destination and Source Connection IDs, Token, and Length. The bytes after
/// the Length field are handed to `super::crypto::extract_decrypted_content`
/// together with the DCID and the version number.
///
/// Initial packets use long-header type `0b00` in QUIC v1 but `0b01` in QUIC v2
/// (RFC 9369), so the expected type is chosen from the Version field. Version 0
/// denotes a Version Negotiation packet, which has no Initial layout and is rejected.
///
/// # Errors
///
/// Returns an error if `payload` is shorter than 10 bytes, is not a long header,
/// has the fixed bit clear, is a Version Negotiation packet, is not an Initial
/// packet for its version, declares a connection ID longer than 20 bytes, or
/// declares a field (CID, token, varint or packet payload) that extends past the
/// end of `payload`.
///
/// Decryption is best-effort: if it fails, `sni_hint` is `None` and
/// `crypto_frames` is empty instead of an error being returned.
pub fn parse_initial_packet(payload: &[u8]) -> Result<QuicInitialData> {
    /// Version number of QUIC v2 (RFC 9369).
    const QUIC_V2: u32 = 0x6b33_43cf;

    /// Returns `buf[start..start + len]`, or `None` if the end offset overflows
    /// `usize` or lies past the end of `buf`. Lengths come from untrusted wire
    /// data, so the addition must never wrap or panic.
    fn field(buf: &[u8], start: usize, len: usize) -> Option<&[u8]> {
        buf.get(start..start.checked_add(len)?)
    }

    if payload.len() < 10 {
        return Err(anyhow!("Buffer too short for QUIC Long Header"));
    }
    let first_byte = payload[0];
    if (first_byte & 0x80) == 0 {
        return Err(anyhow!("Not a Long Header packet"));
    }
    if (first_byte & 0x40) == 0 {
        return Err(anyhow!("Fixed bit not set"));
    }

    let mut cursor = 1;
    let version_bytes =
        field(payload, cursor, 4).ok_or_else(|| anyhow!("Truncated Version"))?;
    let version_val = u32::from_be_bytes([
        version_bytes[0],
        version_bytes[1],
        version_bytes[2],
        version_bytes[3],
    ]);
    cursor += 4;
    // Version 0 is Version Negotiation (RFC 9000 §17.2.1): no token/length fields.
    if version_val == 0 {
        return Err(anyhow!("Version Negotiation packet is not an Initial Packet"));
    }
    let packet_type = (first_byte & 0x30) >> 4;
    // QUIC v2 renumbers long-header types; in v2, type 0b00 is Retry.
    let initial_type = match version_val {
        QUIC_V2 => 0b01,
        _ => 0b00,
    };
    if packet_type != initial_type {
        return Err(anyhow!("Not an Initial Packet (Type: {packet_type})"));
    }
    let version = format!("0x{version_val:08x}");

    let dcid_len = payload
        .get(cursor)
        .map(|&b| usize::from(b))
        .ok_or_else(|| anyhow!("Truncated DCID Length"))?;
    cursor += 1;
    if dcid_len > 20 {
        return Err(anyhow!("DCID length {dcid_len} exceeds 20"));
    }
    let dcid_bytes = field(payload, cursor, dcid_len).ok_or_else(|| anyhow!("Truncated DCID"))?;
    let dcid = hex::encode(dcid_bytes);
    cursor += dcid_len;

    let scid_len = payload
        .get(cursor)
        .map(|&b| usize::from(b))
        .ok_or_else(|| anyhow!("Truncated SCID Length"))?;
    cursor += 1;
    if scid_len > 20 {
        return Err(anyhow!("SCID length {scid_len} exceeds 20"));
    }
    let scid_bytes = field(payload, cursor, scid_len).ok_or_else(|| anyhow!("Truncated SCID"))?;
    let scid = hex::encode(scid_bytes);
    cursor += scid_len;

    // Invariant: every successful `field`/`read_varint` keeps `cursor <= payload.len()`,
    // so the `payload[cursor..]` slices below cannot panic.
    let (token_len, varint_len) = read_varint(&payload[cursor..])?;
    cursor += varint_len;
    let token = if token_len > 0 {
        let token_bytes =
            field(payload, cursor, token_len).ok_or_else(|| anyhow!("Truncated Token"))?;
        cursor += token_len;
        Some(hex::encode(token_bytes))
    } else {
        None
    };

    let (remaining_len, varint_len) = read_varint(&payload[cursor..])?;
    cursor += varint_len;
    if field(payload, cursor, remaining_len).is_none() {
        return Err(anyhow!("Truncated packet payload"));
    }

    let header_start = 0;
    let protected_payload_start = cursor;
    // Best-effort: any decryption failure degrades to "no SNI, no frames".
    let (sni_hint, crypto_frames) = super::crypto::extract_decrypted_content(
        payload,
        header_start,
        protected_payload_start,
        remaining_len,
        dcid_bytes,
        version_val,
    )
    .unwrap_or((None, BTreeMap::new()));

    Ok(QuicInitialData {
        version,
        dcid,
        scid,
        token,
        sni_hint,
        crypto_frames,
    })
}
/// Extracts the Destination Connection ID from a long-header QUIC packet.
///
/// Byte 5 holds the DCID length (after the first byte and the 4-byte version),
/// and the DCID itself starts at byte 6. This only peeks at those offsets; the
/// header-form bit and the version are not inspected.
///
/// Returns `None` when the packet is too short to contain the length byte or
/// the full DCID, or when the declared length is 0 or greater than 20.
#[must_use]
pub fn peek_long_header_dcid(packet: &[u8]) -> Option<Vec<u8>> {
    const CID_LEN_OFFSET: usize = 5;
    const CID_OFFSET: usize = CID_LEN_OFFSET + 1;

    let cid_len = usize::from(*packet.get(CID_LEN_OFFSET)?);
    if !(1..=20).contains(&cid_len) {
        return None;
    }
    packet
        .get(CID_OFFSET..CID_OFFSET + cid_len)
        .map(<[u8]>::to_vec)
}
/// Returns a copy of the Destination Connection ID of a QUIC short-header packet.
///
/// Short headers do not encode the DCID length, so the caller supplies it as
/// `len`; the DCID is taken from the bytes immediately after the first byte.
///
/// Returns `None` if `packet` holds fewer than `1 + len` bytes. Never panics,
/// even for `len` values near `usize::MAX`.
#[must_use]
pub fn peek_short_header_dcid(packet: &[u8], len: usize) -> Option<Vec<u8>> {
    // Skip the first byte, then bound the DCID without computing `1 + len`,
    // which could overflow and panic for very large `len`.
    packet.get(1..)?.get(..len).map(<[u8]>::to_vec)
}
/// Decodes a QUIC variable-length integer (RFC 9000 §16) from the start of `buf`.
///
/// The two most significant bits of the first byte select an encoded length of
/// 1, 2, 4 or 8 bytes; the remaining bits of that byte and the following bytes
/// form a big-endian value of up to 62 bits.
///
/// Returns `(value, bytes_consumed)`.
///
/// # Errors
///
/// Returns an error if `buf` is empty or shorter than the encoded length, or if
/// the decoded value does not fit in `usize` (only possible on targets narrower
/// than 64 bits). Without that last check, an `as` cast would silently truncate
/// an attacker-supplied length.
pub fn read_varint(buf: &[u8]) -> Result<(usize, usize)> {
    let first = *buf.first().ok_or_else(|| anyhow!("Buffer empty"))?;
    let len = 1usize << (first >> 6);
    let encoded = buf
        .get(..len)
        .ok_or_else(|| anyhow!("Buffer too short for VarInt"))?;
    let raw = encoded[1..]
        .iter()
        .fold(u64::from(first & 0x3f), |acc, &b| (acc << 8) | u64::from(b));
    let val = usize::try_from(raw)
        .map_err(|_| anyhow!("VarInt value {raw} does not fit in usize"))?;
    Ok((val, len))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read_varint() {
        assert_eq!(read_varint(&[0x25]).unwrap(), (37, 1));
        assert_eq!(read_varint(&[0x40, 0x40]).unwrap(), (64, 2));
        // Non-minimal two-byte encoding of 37 (RFC 9000 Appendix A.1).
        assert_eq!(read_varint(&[0x40, 0x25]).unwrap(), (37, 2));
        assert_eq!(read_varint(&[0x7b, 0xbd]).unwrap(), (15293, 2));
        assert_eq!(
            read_varint(&[0x9d, 0x7f, 0x3e, 0x7d]).unwrap(),
            (494878333, 4)
        );
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_read_varint_eight_bytes() {
        // Eight-byte example from RFC 9000 Appendix A.1.
        assert_eq!(
            read_varint(&[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c]).unwrap(),
            (151_288_809_941_952_652, 8)
        );
    }

    #[test]
    fn test_read_varint_errors() {
        assert!(read_varint(&[]).is_err());
        assert!(read_varint(&[0x40]).is_err());
        assert!(read_varint(&[0x80, 0x00, 0x00]).is_err());
        assert!(read_varint(&[0xc0, 0, 0, 0, 0, 0, 0]).is_err());
    }

    #[test]
    fn test_peek_long_header_dcid() {
        let packet = vec![
            0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
        ];
        let dcid = peek_long_header_dcid(&packet).unwrap();
        assert_eq!(dcid, vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert!(peek_long_header_dcid(&[0xc0, 0x00]).is_none());
        assert!(peek_long_header_dcid(&[0xc0, 0x00, 0x00, 0x00, 0x01, 0x00]).is_none());
    }

    #[test]
    fn test_peek_short_header_dcid() {
        let packet = vec![0x40, 0xaa, 0xbb, 0xcc, 0xdd];
        let dcid = peek_short_header_dcid(&packet, 4).unwrap();
        assert_eq!(dcid, vec![0xaa, 0xbb, 0xcc, 0xdd]);
        assert!(peek_short_header_dcid(&packet, 10).is_none());
    }

    #[test]
    fn test_parse_initial_packet_basic_header() {
        let packet = vec![
            0xc0, 0x00, 0x00, 0x00, 0x01, // first byte, version 1
            0x04, 0x11, 0x22, 0x33, 0x44, // DCID
            0x00, // SCID length
            0x00, // token length
            0x00, // Length
        ];
        let res = parse_initial_packet(&packet).unwrap();
        assert_eq!(res.version, "0x00000001");
        assert_eq!(res.dcid, "11223344");
        assert_eq!(res.scid, "");
        assert_eq!(res.token, None);
    }

    #[test]
    fn test_parse_initial_packet_with_scid_and_token() {
        let packet = vec![
            0xc0, 0x00, 0x00, 0x00, 0x01, // first byte, version 1
            0x02, 0xaa, 0xbb, // DCID
            0x01, 0xcc, // SCID
            0x02, 0xde, 0xad, // token
            0x00, // Length
        ];
        let res = parse_initial_packet(&packet).unwrap();
        assert_eq!(res.dcid, "aabb");
        assert_eq!(res.scid, "cc");
        assert_eq!(res.token.as_deref(), Some("dead"));
    }

    #[test]
    fn test_parse_initial_packet_rejects_malformed() {
        // Shorter than the 10-byte minimum.
        assert!(parse_initial_packet(&[0xc0, 0, 0, 0, 1]).is_err());
        // Short header (header-form bit clear).
        assert!(parse_initial_packet(&[0x40; 16]).is_err());
        // Fixed bit clear.
        assert!(parse_initial_packet(&[0x80, 0, 0, 0, 1, 0, 0, 0, 0, 0]).is_err());
        // Handshake packet (type 0b10) for version 1.
        assert!(parse_initial_packet(&[0xe0, 0, 0, 0, 1, 0, 0, 0, 0, 0]).is_err());
        // DCID length 21 exceeds the limit of 20.
        assert!(parse_initial_packet(&[0xc0, 0, 0, 0, 1, 21, 0, 0, 0, 0]).is_err());
        // DCID length 16 but only 4 bytes follow.
        assert!(parse_initial_packet(&[0xc0, 0, 0, 0, 1, 16, 1, 2, 3, 4]).is_err());
        // Token length 8 but no token bytes follow.
        assert!(parse_initial_packet(&[0xc0, 0, 0, 0, 1, 0, 0, 0x08, 0, 0]).is_err());
    }
}