1use crate::error::Error;
4use crate::varint::read_varint;
5
/// Borrowed view of a parsed QUIC Initial packet header, as produced by `parse_initial`.
///
/// Every slice borrows from the packet buffer passed to the parser; nothing is copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InitialHeader<'a> {
    /// The 32-bit QUIC version from bytes 1..5 of the packet (big-endian).
    pub version: u32,
    /// Destination Connection ID (at most 20 bytes; may be empty).
    pub dcid: &'a [u8],
    /// Source Connection ID (at most 20 bytes; may be empty).
    pub scid: &'a [u8],
    /// The Initial token; empty when the token-length varint is 0.
    pub token: &'a [u8],
    /// The bytes covered by the `Length` field. Per RFC 9000 this region starts with the
    /// packet number and continues with the packet payload. The parser does not decrypt it
    /// or strip anything from it.
    pub payload: &'a [u8],
    /// Every byte from the start of the packet up to and including the `Length` varint.
    pub header_bytes: &'a [u8],
    /// The first byte of the packet exactly as received (header protection not removed).
    pub first_byte: u8,
}
27
28pub fn parse_initial(packet: &[u8]) -> Result<InitialHeader<'_>, Error> {
38 if packet.len() < 7 {
39 return Err(Error::BufferTooShort {
40 need: 7,
41 have: packet.len(),
42 });
43 }
44
45 let first_byte = packet[0];
46
47 if (first_byte & 0x80) == 0 {
48 return Err(Error::NotLongHeader);
49 }
50
51 if (first_byte & 0x40) == 0 {
52 return Err(Error::InvalidFixedBit);
53 }
54
55 let packet_type = (first_byte & 0x30) >> 4;
56
57 let mut cursor = 1;
58
59 let version = u32::from_be_bytes([
60 packet[cursor],
61 packet[cursor + 1],
62 packet[cursor + 2],
63 packet[cursor + 3],
64 ]);
65 cursor += 4;
66
67 let expected_initial_type = match version {
69 0x6b33_43cf => 1,
70 _ => 0,
71 };
72 if packet_type != expected_initial_type {
73 return Err(Error::NotInitialPacket(packet_type));
74 }
75
76 let (dcid, cursor) = read_cid(packet, cursor)?;
77 let (scid, cursor) = read_cid(packet, cursor)?;
78
79 let (token_len, varint_len) =
80 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
81 need: cursor + 1,
82 have: packet.len(),
83 })?)?;
84 let cursor = cursor + varint_len;
85 let token_len = usize::try_from(token_len).map_err(|_| Error::BufferTooShort {
86 need: usize::MAX,
87 have: packet.len(),
88 })?;
89
90 let token_end = cursor.checked_add(token_len).ok_or(Error::BufferTooShort {
91 need: usize::MAX,
92 have: packet.len(),
93 })?;
94 if token_end > packet.len() {
95 return Err(Error::BufferTooShort {
96 need: token_end,
97 have: packet.len(),
98 });
99 }
100 let token = &packet[cursor..token_end];
101 let cursor = token_end;
102
103 let (remaining_len, varint_len) =
104 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
105 need: cursor + 1,
106 have: packet.len(),
107 })?)?;
108 let cursor = cursor + varint_len;
109 let remaining_len = usize::try_from(remaining_len).map_err(|_| Error::BufferTooShort {
110 need: usize::MAX,
111 have: packet.len(),
112 })?;
113
114 let payload_end = cursor
115 .checked_add(remaining_len)
116 .ok_or(Error::BufferTooShort {
117 need: usize::MAX,
118 have: packet.len(),
119 })?;
120 if payload_end > packet.len() {
121 return Err(Error::BufferTooShort {
122 need: payload_end,
123 have: packet.len(),
124 });
125 }
126
127 let header_bytes = &packet[..cursor];
128 let payload = &packet[cursor..payload_end];
129
130 Ok(InitialHeader {
131 version,
132 dcid,
133 scid,
134 token,
135 payload,
136 header_bytes,
137 first_byte,
138 })
139}
140
/// Extracts the Destination Connection ID from a long-header packet without full parsing.
///
/// Returns `None` in any of these cases:
///
/// - the packet is shorter than six bytes;
/// - the header-form bit (`0x80`) is clear;
/// - the DCID length byte is 0 or greater than 20;
/// - the packet ends before the declared DCID does.
#[must_use]
pub fn peek_long_header_dcid(packet: &[u8]) -> Option<&[u8]> {
    // Bytes 0..6 are: first byte, 4-byte version, DCID length.
    let prefix = packet.get(..6)?;
    let is_long_header = prefix[0] & 0x80 != 0;
    let dcid_len = usize::from(prefix[5]);
    if is_long_header && (1..=20).contains(&dcid_len) {
        packet.get(6..6 + dcid_len)
    } else {
        None
    }
}
160
/// Extracts the Destination Connection ID from a short-header packet.
///
/// Short headers do not encode the DCID length, so the caller supplies `cid_len`. The DCID
/// is assumed to start right after the first byte.
///
/// Returns `None` if the packet is empty or too short to hold `cid_len` bytes of DCID.
/// Never panics, even for very large `cid_len`.
#[must_use]
pub fn peek_short_header_dcid(packet: &[u8], cid_len: usize) -> Option<&[u8]> {
    // Slice in two steps so that computing `1 + cid_len` can never overflow.
    packet.get(1..)?.get(..cid_len)
}
170
171fn read_cid(packet: &[u8], offset: usize) -> Result<(&[u8], usize), Error> {
172 let &cid_len_byte = packet.get(offset).ok_or(Error::BufferTooShort {
173 need: offset + 1,
174 have: packet.len(),
175 })?;
176 if cid_len_byte > 20 {
177 return Err(Error::InvalidCidLength(cid_len_byte));
178 }
179 let cid_len = cid_len_byte as usize;
180 let start = offset + 1;
181 let end = start + cid_len;
182 if end > packet.len() {
183 return Err(Error::BufferTooShort {
184 need: end,
185 have: packet.len(),
186 });
187 }
188 Ok((&packet[start..end], end))
189}