1use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::{OpCode, KeyId, ProtocolError, Result};
11
/// Session identifier: the 8 raw bytes carried in control-packet headers.
pub type SessionId = [u8; 8];

/// Packet identifier, read from and written to the wire as a big-endian `u32`.
pub type PacketId = u32;
17
/// Parsed packet header.
///
/// Data packets populate only `opcode` and `key_id`. Control packets also carry
/// a local `session_id` and, when parsed with tls-auth enabled, the `hmac`,
/// `packet_id` and `timestamp` fields.
#[derive(Debug, Clone)]
pub struct PacketHeader {
    /// Packet opcode, decoded from the first byte.
    pub opcode: OpCode,
    /// Key id, decoded from the same first byte as `opcode`.
    pub key_id: KeyId,
    /// Sender's session id; `None` for data packets.
    pub session_id: Option<SessionId>,
    /// 32-byte HMAC; `Some` only for control packets parsed with tls-auth.
    pub hmac: Option<[u8; 32]>,
    /// Big-endian `u32` that follows the HMAC when tls-auth is used
    /// (presumably the replay-protection packet id — confirm against callers).
    pub packet_id: Option<PacketId>,
    /// Big-endian `u32` that follows `packet_id` when tls-auth is used
    /// (presumably the sender's time value — confirm against callers).
    pub timestamp: Option<u32>,
}
34
35impl PacketHeader {
36 pub const MIN_SIZE: usize = 1;
38
39 pub const CONTROL_HEADER_SIZE: usize = 1 + 8; #[inline]
44 pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<(Self, usize)> {
45 if data.is_empty() {
46 return Err(ProtocolError::PacketTooShort {
47 expected: 1,
48 got: 0,
49 });
50 }
51
52 let opcode = OpCode::from_byte(data[0])?;
53 let key_id = KeyId::from_byte(data[0]);
54
55 if opcode.is_data() {
56 return Ok((
58 Self {
59 opcode,
60 key_id,
61 session_id: None,
62 hmac: None,
63 packet_id: None,
64 timestamp: None,
65 },
66 1,
67 ));
68 }
69
70 let mut offset = 1;
72 let mut hmac = None;
73 let mut packet_id = None;
74 let mut timestamp = None;
75
76 if has_tls_auth {
78 if data.len() < offset + 32 {
79 return Err(ProtocolError::PacketTooShort {
80 expected: offset + 32,
81 got: data.len(),
82 });
83 }
84 let mut h = [0u8; 32];
85 h.copy_from_slice(&data[offset..offset + 32]);
86 hmac = Some(h);
87 offset += 32;
88
89 if data.len() < offset + 4 {
91 return Err(ProtocolError::PacketTooShort {
92 expected: offset + 4,
93 got: data.len(),
94 });
95 }
96 packet_id = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
97 offset += 4;
98
99 if data.len() < offset + 4 {
101 return Err(ProtocolError::PacketTooShort {
102 expected: offset + 4,
103 got: data.len(),
104 });
105 }
106 timestamp = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
107 offset += 4;
108 }
109
110 if data.len() < offset + 8 {
112 return Err(ProtocolError::PacketTooShort {
113 expected: offset + 8,
114 got: data.len(),
115 });
116 }
117 let mut session_id = [0u8; 8];
118 session_id.copy_from_slice(&data[offset..offset + 8]);
119 offset += 8;
120
121 Ok((
122 Self {
123 opcode,
124 key_id,
125 session_id: Some(session_id),
126 hmac,
127 packet_id,
128 timestamp,
129 },
130 offset,
131 ))
132 }
133
134 #[inline]
136 pub fn serialize(&self, buf: &mut BytesMut) {
137 buf.put_u8(self.opcode.to_byte(self.key_id));
138
139 if let Some(hmac) = &self.hmac {
140 buf.put_slice(hmac);
141 }
142
143 if let Some(packet_id) = self.packet_id {
144 buf.put_u32(packet_id);
145 }
146
147 if let Some(timestamp) = self.timestamp {
148 buf.put_u32(timestamp);
149 }
150
151 if let Some(session_id) = &self.session_id {
152 buf.put_slice(session_id);
153 }
154 }
155}
156
/// A fully parsed packet: either a control packet or a data packet, as decided
/// by `OpCode::is_data` on the header's opcode.
#[derive(Debug, Clone)]
pub enum Packet {
    /// Control packet (header, ACK section, optional message id, payload).
    Control(ControlPacketData),
    /// Data packet (header, optional peer id, payload).
    Data(DataPacketData),
}
165
/// Body of a control packet, following its [`PacketHeader`].
#[derive(Debug, Clone)]
pub struct ControlPacketData {
    /// Parsed packet header.
    pub header: PacketHeader,
    /// Peer's session id; on the wire it is present only when `acks` is non-empty.
    pub remote_session_id: Option<SessionId>,
    /// Packet ids acknowledged by this packet. On the wire they are preceded by
    /// a one-byte count, and each id is a big-endian `u32`.
    pub acks: Vec<PacketId>,
    /// This message's packet id. `None` for `OpCode::AckV1`, or when fewer than
    /// four bytes remained after the ACK section during parsing.
    pub message_packet_id: Option<PacketId>,
    /// All remaining bytes after the control fields.
    pub payload: Bytes,
}
180
/// Body of a data packet, following its [`PacketHeader`].
#[derive(Debug, Clone)]
pub struct DataPacketData {
    /// Parsed packet header.
    pub header: PacketHeader,
    /// 24-bit big-endian peer id; present only for `OpCode::DataV2`.
    pub peer_id: Option<u32>,
    /// All remaining bytes after the header (and peer id, if any).
    pub payload: Bytes,
}
191
192impl Packet {
193 #[inline]
195 pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<Self> {
196 let (header, mut offset) = PacketHeader::parse(data, has_tls_auth)?;
197
198 if header.opcode.is_data() {
199 let peer_id = if header.opcode == OpCode::DataV2 {
201 if data.len() < offset + 3 {
202 return Err(ProtocolError::PacketTooShort {
203 expected: offset + 3,
204 got: data.len(),
205 });
206 }
207 let pid = ((data[offset] as u32) << 16)
209 | ((data[offset + 1] as u32) << 8)
210 | (data[offset + 2] as u32);
211 offset += 3;
212 Some(pid)
213 } else {
214 None
215 };
216
217 return Ok(Packet::Data(DataPacketData {
218 header,
219 peer_id,
220 payload: Bytes::copy_from_slice(&data[offset..]),
221 }));
222 }
223
224 let mut remote_session_id = None;
226 let mut acks = Vec::new();
227
228 if data.len() < offset + 1 {
230 return Err(ProtocolError::PacketTooShort {
231 expected: offset + 1,
232 got: data.len(),
233 });
234 }
235 let ack_count = data[offset] as usize;
236 offset += 1;
237
238 if ack_count > 0 {
240 for _ in 0..ack_count {
242 if data.len() < offset + 4 {
243 return Err(ProtocolError::PacketTooShort {
244 expected: offset + 4,
245 got: data.len(),
246 });
247 }
248 acks.push(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
249 offset += 4;
250 }
251
252 if data.len() < offset + 8 {
254 return Err(ProtocolError::PacketTooShort {
255 expected: offset + 8,
256 got: data.len(),
257 });
258 }
259 let mut rsid = [0u8; 8];
260 rsid.copy_from_slice(&data[offset..offset + 8]);
261 remote_session_id = Some(rsid);
262 offset += 8;
263 }
264
265 let message_packet_id = if header.opcode != OpCode::AckV1 && data.len() >= offset + 4 {
267 let id = u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap());
268 offset += 4;
269 Some(id)
270 } else {
271 None
272 };
273
274 let payload = if offset < data.len() {
276 Bytes::copy_from_slice(&data[offset..])
277 } else {
278 Bytes::new()
279 };
280
281 Ok(Packet::Control(ControlPacketData {
282 header,
283 remote_session_id,
284 acks,
285 message_packet_id,
286 payload,
287 }))
288 }
289
290 #[inline]
292 pub fn serialize(&self) -> BytesMut {
293 let mut buf = BytesMut::with_capacity(1500);
295
296 match self {
297 Packet::Control(ctrl) => {
298 ctrl.header.serialize(&mut buf);
299
300 buf.put_u8(ctrl.acks.len() as u8);
302
303 for ack in &ctrl.acks {
305 buf.put_u32(*ack);
306 }
307
308 if !ctrl.acks.is_empty() {
310 if let Some(rsid) = &ctrl.remote_session_id {
311 buf.put_slice(rsid);
312 }
313 }
314
315 if let Some(mpid) = ctrl.message_packet_id {
317 buf.put_u32(mpid);
318 }
319
320 buf.put_slice(&ctrl.payload);
322 }
323 Packet::Data(data) => {
324 data.header.serialize(&mut buf);
325
326 if let Some(pid) = data.peer_id {
328 buf.put_u8((pid >> 16) as u8);
329 buf.put_u8((pid >> 8) as u8);
330 buf.put_u8(pid as u8);
331 }
332
333 buf.put_slice(&data.payload);
335 }
336 }
337
338 buf
339 }
340
341 pub fn opcode(&self) -> OpCode {
343 match self {
344 Packet::Control(c) => c.header.opcode,
345 Packet::Data(d) => d.header.opcode,
346 }
347 }
348
349 pub fn key_id(&self) -> KeyId {
351 match self {
352 Packet::Control(c) => c.header.key_id,
353 Packet::Data(d) => d.header.key_id,
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// A client hard-reset without ACKs parses into a control packet that
    /// carries the sender's session id and an empty ACK list.
    #[test]
    fn test_hard_reset_parse() {
        // First byte 0x38 is expected to decode as HardResetClientV2 with key
        // id 0. Then come the 8-byte session id and an ACK count of zero.
        let wire: [u8; 10] = [0x38, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00];

        match Packet::parse(&wire, false).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
                assert_eq!(ctrl.header.key_id, KeyId::new(0));
                let expected_sid = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
                assert_eq!(ctrl.header.session_id, Some(expected_sid));
                assert!(ctrl.acks.is_empty());
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    /// A DataV2 packet yields its 24-bit peer id and the trailing payload.
    #[test]
    fn test_data_packet_v2() {
        // First byte 0x48 is expected to decode as DataV2. Three peer-id bytes
        // (value 1) follow, then a four-byte payload.
        let wire: [u8; 8] = [0x48, 0x00, 0x00, 0x01, 0xDE, 0xAD, 0xBE, 0xEF];

        match Packet::parse(&wire, false).unwrap() {
            Packet::Data(d) => {
                assert_eq!(d.header.opcode, OpCode::DataV2);
                assert_eq!(d.peer_id, Some(1));
                assert_eq!(&d.payload[..], &[0xDE, 0xAD, 0xBE, 0xEF]);
            }
            Packet::Control(_) => panic!("Expected data packet"),
        }
    }
}