1use crate::error::Error;
4use crate::varint::read_varint;
5
/// Borrowed view of the fields of a QUIC Initial packet's long header.
///
/// As produced by [`parse_initial`], every slice borrows from the packet
/// buffer that was parsed. Nothing is copied, and no header protection or
/// packet protection is removed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InitialHeader<'a> {
    /// Version field, decoded big-endian from bytes 1..5 of the packet.
    pub version: u32,
    /// Destination Connection ID (0 to 20 bytes when produced by `parse_initial`).
    pub dcid: &'a [u8],
    /// Source Connection ID (0 to 20 bytes when produced by `parse_initial`).
    pub scid: &'a [u8],
    /// Contents of the Token field; empty when the Token Length is 0.
    pub token: &'a [u8],
    /// The bytes counted by the Length field. Per RFC 9000 §17.2 the Length
    /// covers the packet number as well as the payload, so this slice starts
    /// with the still-protected packet number. Bytes that follow it in the
    /// buffer (e.g. a coalesced packet) are not included.
    pub payload: &'a [u8],
    /// Everything from the first byte through the end of the Length field,
    /// i.e. all bytes that precede `payload`.
    pub header_bytes: &'a [u8],
    /// The first byte exactly as it appears in the buffer (not unprotected).
    pub first_byte: u8,
}
27
28pub fn parse_initial(packet: &[u8]) -> Result<InitialHeader<'_>, Error> {
38 if packet.len() < 7 {
39 return Err(Error::BufferTooShort {
40 need: 7,
41 have: packet.len(),
42 });
43 }
44
45 let first_byte = packet[0];
46
47 if (first_byte & 0x80) == 0 {
48 return Err(Error::NotLongHeader);
49 }
50
51 let packet_type = (first_byte & 0x30) >> 4;
52 if packet_type != 0 {
53 return Err(Error::NotInitialPacket(packet_type));
54 }
55
56 let mut cursor = 1;
57
58 let version = u32::from_be_bytes([
59 packet[cursor],
60 packet[cursor + 1],
61 packet[cursor + 2],
62 packet[cursor + 3],
63 ]);
64 cursor += 4;
65
66 let (dcid, cursor) = read_cid(packet, cursor)?;
67 let (scid, cursor) = read_cid(packet, cursor)?;
68
69 let (token_len, varint_len) =
70 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
71 need: cursor + 1,
72 have: packet.len(),
73 })?)?;
74 let cursor = cursor + varint_len;
75 let token_len = token_len as usize;
76
77 if cursor + token_len > packet.len() {
78 return Err(Error::BufferTooShort {
79 need: cursor + token_len,
80 have: packet.len(),
81 });
82 }
83 let token = &packet[cursor..cursor + token_len];
84 let cursor = cursor + token_len;
85
86 let (remaining_len, varint_len) =
87 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
88 need: cursor + 1,
89 have: packet.len(),
90 })?)?;
91 let cursor = cursor + varint_len;
92 let remaining_len = remaining_len as usize;
93
94 if cursor + remaining_len > packet.len() {
95 return Err(Error::BufferTooShort {
96 need: cursor + remaining_len,
97 have: packet.len(),
98 });
99 }
100
101 let header_bytes = &packet[..cursor];
102 let payload = &packet[cursor..cursor + remaining_len];
103
104 Ok(InitialHeader {
105 version,
106 dcid,
107 scid,
108 token,
109 payload,
110 header_bytes,
111 first_byte,
112 })
113}
114
/// Extracts the Destination Connection ID from a long-header packet without
/// parsing the rest of the header.
///
/// The DCID length is read from byte 5, after the first byte and the 4-byte
/// version. The DCID bytes follow immediately. The header-form bit of the
/// first byte is not checked.
///
/// Returns `None` in any of these cases:
/// - the packet has fewer than 6 bytes;
/// - the length byte is 0 (a zero-length DCID is reported as `None`, not as
///   an empty slice) or greater than 20;
/// - the packet ends before the DCID does.
#[must_use]
pub fn peek_long_header_dcid(packet: &[u8]) -> Option<&[u8]> {
    const LEN_OFFSET: usize = 5;

    let dcid_len = usize::from(*packet.get(LEN_OFFSET)?);
    if !(1..=20).contains(&dcid_len) {
        return None;
    }

    let dcid_start = LEN_OFFSET + 1;
    packet.get(dcid_start..dcid_start + dcid_len)
}
131
/// Extracts the Destination Connection ID from a short-header packet.
///
/// Short headers carry no DCID length field (RFC 9000 §17.3), so the caller
/// must supply the expected `cid_len`. The DCID is taken from the bytes
/// immediately after the first byte. The header-form bit is not checked.
///
/// Returns `None` if fewer than `cid_len` bytes follow the first byte,
/// including when `packet` is empty. A `cid_len` of 0 on a non-empty packet
/// yields `Some` of an empty slice.
#[must_use]
pub fn peek_short_header_dcid(packet: &[u8], cid_len: usize) -> Option<&[u8]> {
    // Slice in two steps so an oversized `cid_len` (e.g. `usize::MAX`) cannot
    // overflow `1 + cid_len`; it simply yields `None`.
    packet.get(1..)?.get(..cid_len)
}
141
142fn read_cid(packet: &[u8], offset: usize) -> Result<(&[u8], usize), Error> {
143 let &cid_len_byte = packet.get(offset).ok_or(Error::BufferTooShort {
144 need: offset + 1,
145 have: packet.len(),
146 })?;
147 if cid_len_byte > 20 {
148 return Err(Error::InvalidCidLength(cid_len_byte));
149 }
150 let cid_len = cid_len_byte as usize;
151 let start = offset + 1;
152 let end = start + cid_len;
153 if end > packet.len() {
154 return Err(Error::BufferTooShort {
155 need: end,
156 have: packet.len(),
157 });
158 }
159 Ok((&packet[start..end], end))
160}