1use crate::error::Error;
4use crate::varint::read_varint;
5
/// Borrowed view of the fields of a parsed QUIC Initial packet long header.
///
/// Every slice borrows from the buffer that was passed to [`parse_initial`];
/// nothing is copied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InitialHeader<'a> {
    /// The 4-byte version field following the first byte, decoded big-endian.
    pub version: u32,
    /// Destination Connection ID (at most 20 bytes; may be empty).
    pub dcid: &'a [u8],
    /// Source Connection ID (at most 20 bytes; may be empty).
    pub scid: &'a [u8],
    /// Token bytes delimited by the token-length varint; empty when that length is 0.
    pub token: &'a [u8],
    /// The bytes covered by the `Length` varint, starting right after it.
    ///
    /// NOTE(review): per RFC 9000 the `Length` field also covers the packet
    /// number, so this slice presumably begins with the (still protected)
    /// packet number rather than the frame payload — confirm with callers.
    pub payload: &'a [u8],
    /// Every byte from the first byte through the end of the `Length` varint.
    pub header_bytes: &'a [u8],
    /// The first byte of the packet exactly as it appeared in the buffer.
    pub first_byte: u8,
}
27
28pub fn parse_initial(packet: &[u8]) -> Result<InitialHeader<'_>, Error> {
38 if packet.len() < 7 {
39 return Err(Error::BufferTooShort {
40 need: 7,
41 have: packet.len(),
42 });
43 }
44
45 let first_byte = packet[0];
46
47 if (first_byte & 0x80) == 0 {
48 return Err(Error::NotLongHeader);
49 }
50
51 let packet_type = (first_byte & 0x30) >> 4;
52 if packet_type != 0 {
53 return Err(Error::NotInitialPacket(packet_type));
54 }
55
56 let mut cursor = 1;
57
58 let version = u32::from_be_bytes([
59 packet[cursor],
60 packet[cursor + 1],
61 packet[cursor + 2],
62 packet[cursor + 3],
63 ]);
64 cursor += 4;
65
66 let (dcid, cursor) = read_cid(packet, cursor)?;
67 let (scid, cursor) = read_cid(packet, cursor)?;
68
69 let (token_len, varint_len) =
70 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
71 need: cursor + 1,
72 have: packet.len(),
73 })?)?;
74 let cursor = cursor + varint_len;
75 let token_len = usize::try_from(token_len).map_err(|_| Error::BufferTooShort {
76 need: usize::MAX,
77 have: packet.len(),
78 })?;
79
80 if cursor + token_len > packet.len() {
81 return Err(Error::BufferTooShort {
82 need: cursor + token_len,
83 have: packet.len(),
84 });
85 }
86 let token = &packet[cursor..cursor + token_len];
87 let cursor = cursor + token_len;
88
89 let (remaining_len, varint_len) =
90 read_varint(packet.get(cursor..).ok_or(Error::BufferTooShort {
91 need: cursor + 1,
92 have: packet.len(),
93 })?)?;
94 let cursor = cursor + varint_len;
95 let remaining_len = usize::try_from(remaining_len).map_err(|_| Error::BufferTooShort {
96 need: usize::MAX,
97 have: packet.len(),
98 })?;
99
100 if cursor + remaining_len > packet.len() {
101 return Err(Error::BufferTooShort {
102 need: cursor + remaining_len,
103 have: packet.len(),
104 });
105 }
106
107 let header_bytes = &packet[..cursor];
108 let payload = &packet[cursor..cursor + remaining_len];
109
110 Ok(InitialHeader {
111 version,
112 dcid,
113 scid,
114 token,
115 payload,
116 header_bytes,
117 first_byte,
118 })
119}
120
/// Extracts the Destination Connection ID from a long-header packet without
/// parsing the rest of the header.
///
/// Byte 5 (after the first byte and the 4-byte version) is read as the DCID
/// length, and the DCID is the bytes that follow it.
///
/// Returns `None` when the buffer is too short to contain the length byte or
/// the full DCID, or when the length byte is 0 or greater than 20. The first
/// byte is not inspected.
#[must_use]
pub fn peek_long_header_dcid(packet: &[u8]) -> Option<&[u8]> {
    let dcid_len = usize::from(*packet.get(5)?);
    if !(1..=20).contains(&dcid_len) {
        return None;
    }
    packet.get(6..6 + dcid_len)
}
137
/// Extracts the Destination Connection ID from a short-header packet without
/// parsing the rest of the header.
///
/// The DCID length is not read from the packet; the caller supplies `cid_len`,
/// and the DCID is the `cid_len` bytes that follow the first byte.
///
/// Returns `None` if `packet` holds fewer than `1 + cid_len` bytes. The first
/// byte is not inspected.
#[must_use]
pub fn peek_short_header_dcid(packet: &[u8], cid_len: usize) -> Option<&[u8]> {
    // Slice in two steps so a huge `cid_len` cannot overflow `1 + cid_len`.
    packet.get(1..)?.get(..cid_len)
}
147
148fn read_cid(packet: &[u8], offset: usize) -> Result<(&[u8], usize), Error> {
149 let &cid_len_byte = packet.get(offset).ok_or(Error::BufferTooShort {
150 need: offset + 1,
151 have: packet.len(),
152 })?;
153 if cid_len_byte > 20 {
154 return Err(Error::InvalidCidLength(cid_len_byte));
155 }
156 let cid_len = cid_len_byte as usize;
157 let start = offset + 1;
158 let end = start + cid_len;
159 if end > packet.len() {
160 return Err(Error::BufferTooShort {
161 need: end,
162 have: packet.len(),
163 });
164 }
165 Ok((&packet[start..end], end))
166}