1use crate::error::Error;
4use crate::varint::read_varint;
5
/// Cleartext fields of a QUIC Initial packet's long header, as produced by [`parse_initial`].
///
/// Every slice borrows from the packet buffer that was parsed; nothing is copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InitialHeader<'a> {
    /// QUIC version, read big-endian from the four bytes after the first byte.
    pub version: u32,
    /// Destination Connection ID (0 to 20 bytes).
    pub dcid: &'a [u8],
    /// Source Connection ID (0 to 20 bytes).
    pub scid: &'a [u8],
    /// Token carried in the Initial packet; empty when the token length is zero.
    pub token: &'a [u8],
    /// The bytes counted by the `Length` field, starting immediately after that field.
    /// The packet number is not parsed out, so this still begins with it.
    pub payload: &'a [u8],
    /// Every byte from the first byte up to and including the `Length` varint.
    pub header_bytes: &'a [u8],
    /// The first byte exactly as received. Only its top four bits (form, fixed, and
    /// type bits) are validated by the parser.
    pub first_byte: u8,
}
27
28pub fn parse_initial(packet: &[u8]) -> Result<InitialHeader<'_>, Error> {
38 if packet.len() < 7 {
39 return Err(Error::BufferTooShort {
40 need: 7,
41 have: packet.len(),
42 });
43 }
44
45 let first_byte = packet[0];
46
47 if (first_byte & 0x80) == 0 {
48 return Err(Error::NotLongHeader);
49 }
50
51 if (first_byte & 0x40) == 0 {
52 return Err(Error::InvalidFixedBit);
53 }
54
55 let packet_type = (first_byte & 0x30) >> 4;
56 if packet_type != 0 {
57 return Err(Error::NotInitialPacket(packet_type));
58 }
59
60 let mut cursor = 1;
61
62 let version = u32::from_be_bytes([
63 packet[cursor],
64 packet[cursor + 1],
65 packet[cursor + 2],
66 packet[cursor + 3],
67 ]);
68 cursor += 4;
69
70 let (dcid, cursor) = read_cid(packet, cursor)?;
71 let (scid, cursor) = read_cid(packet, cursor)?;
72
73 let (token_len, varint_len) =
74 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
75 need: cursor + 1,
76 have: packet.len(),
77 })?)?;
78 let cursor = cursor + varint_len;
79 let token_len = usize::try_from(token_len).map_err(|_| Error::BufferTooShort {
80 need: usize::MAX,
81 have: packet.len(),
82 })?;
83
84 if cursor + token_len > packet.len() {
85 return Err(Error::BufferTooShort {
86 need: cursor + token_len,
87 have: packet.len(),
88 });
89 }
90 let token = &packet[cursor..cursor + token_len];
91 let cursor = cursor + token_len;
92
93 let (remaining_len, varint_len) =
94 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
95 need: cursor + 1,
96 have: packet.len(),
97 })?)?;
98 let cursor = cursor + varint_len;
99 let remaining_len = usize::try_from(remaining_len).map_err(|_| Error::BufferTooShort {
100 need: usize::MAX,
101 have: packet.len(),
102 })?;
103
104 if cursor + remaining_len > packet.len() {
105 return Err(Error::BufferTooShort {
106 need: cursor + remaining_len,
107 have: packet.len(),
108 });
109 }
110
111 let header_bytes = &packet[..cursor];
112 let payload = &packet[cursor..cursor + remaining_len];
113
114 Ok(InitialHeader {
115 version,
116 dcid,
117 scid,
118 token,
119 payload,
120 header_bytes,
121 first_byte,
122 })
123}
124
/// Peeks at the Destination Connection ID of a long-header packet without fully
/// parsing it.
///
/// The DCID length byte is read at offset 5, just after the first byte and the 4-byte
/// version, and the DCID bytes follow it. The header-form bit is not checked; the
/// caller is expected to know the packet uses a long header.
///
/// Returns `None` in any of these cases:
/// - the packet is shorter than 6 bytes;
/// - the length byte is 0 or greater than 20;
/// - the packet is too short to hold the declared DCID.
#[must_use]
pub fn peek_long_header_dcid(packet: &[u8]) -> Option<&[u8]> {
    packet
        .get(5)
        .map(|&len_byte| usize::from(len_byte))
        .filter(|len| (1..=20).contains(len))
        .and_then(|len| packet.get(6..6 + len))
}
141
/// Peeks at the Destination Connection ID of a short-header packet.
///
/// Short headers do not encode the DCID length, so the caller supplies `cid_len`.
/// The DCID is assumed to start right after the first byte. The header-form bit is not
/// checked.
///
/// Returns `None` if the packet is too short to hold `cid_len` bytes after the first
/// byte, including when `cid_len` is so large that the end offset would overflow.
#[must_use]
pub fn peek_short_header_dcid(packet: &[u8], cid_len: usize) -> Option<&[u8]> {
    // A plain `1 + cid_len` would panic in debug builds for `cid_len == usize::MAX`.
    // `checked_add` turns that into `None` instead.
    let end = cid_len.checked_add(1)?;
    packet.get(1..end)
}
151
152fn read_cid(packet: &[u8], offset: usize) -> Result<(&[u8], usize), Error> {
153 let &cid_len_byte = packet.get(offset).ok_or(Error::BufferTooShort {
154 need: offset + 1,
155 have: packet.len(),
156 })?;
157 if cid_len_byte > 20 {
158 return Err(Error::InvalidCidLength(cid_len_byte));
159 }
160 let cid_len = cid_len_byte as usize;
161 let start = offset + 1;
162 let end = start + cid_len;
163 if end > packet.len() {
164 return Err(Error::BufferTooShort {
165 need: end,
166 have: packet.len(),
167 });
168 }
169 Ok((&packet[start..end], end))
170}