1use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::{OpCode, KeyId, ProtocolError, Result};
11
12pub type SessionId = [u8; 8];
14
15pub type PacketId = u32;
17
18#[derive(Debug, Clone)]
20pub struct PacketHeader {
21 pub opcode: OpCode,
23 pub key_id: KeyId,
25 pub session_id: Option<SessionId>,
27 pub hmac: Option<[u8; 32]>,
29 pub packet_id: Option<PacketId>,
31 pub timestamp: Option<u32>,
33}
34
35impl PacketHeader {
36 pub const MIN_SIZE: usize = 1;
38
39 pub const CONTROL_HEADER_SIZE: usize = 1 + 8; #[inline]
44 pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<(Self, usize)> {
45 if data.is_empty() {
46 return Err(ProtocolError::PacketTooShort {
47 expected: 1,
48 got: 0,
49 });
50 }
51
52 let opcode = OpCode::from_byte(data[0])?;
53 let key_id = KeyId::from_byte(data[0]);
54
55 if opcode.is_data() {
56 return Ok((
58 Self {
59 opcode,
60 key_id,
61 session_id: None,
62 hmac: None,
63 packet_id: None,
64 timestamp: None,
65 },
66 1,
67 ));
68 }
69
70 let mut offset = 1;
72 let mut hmac = None;
73 let mut packet_id = None;
74 let mut timestamp = None;
75
76 if has_tls_auth {
78 if data.len() < offset + 32 {
79 return Err(ProtocolError::PacketTooShort {
80 expected: offset + 32,
81 got: data.len(),
82 });
83 }
84 let mut h = [0u8; 32];
85 h.copy_from_slice(&data[offset..offset + 32]);
86 hmac = Some(h);
87 offset += 32;
88
89 if data.len() < offset + 4 {
91 return Err(ProtocolError::PacketTooShort {
92 expected: offset + 4,
93 got: data.len(),
94 });
95 }
96 packet_id = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
97 offset += 4;
98
99 if data.len() < offset + 4 {
101 return Err(ProtocolError::PacketTooShort {
102 expected: offset + 4,
103 got: data.len(),
104 });
105 }
106 timestamp = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
107 offset += 4;
108 }
109
110 if data.len() < offset + 8 {
112 return Err(ProtocolError::PacketTooShort {
113 expected: offset + 8,
114 got: data.len(),
115 });
116 }
117 let mut session_id = [0u8; 8];
118 session_id.copy_from_slice(&data[offset..offset + 8]);
119 offset += 8;
120
121 Ok((
122 Self {
123 opcode,
124 key_id,
125 session_id: Some(session_id),
126 hmac,
127 packet_id,
128 timestamp,
129 },
130 offset,
131 ))
132 }
133
134 #[inline]
136 pub fn serialize(&self, buf: &mut BytesMut) {
137 buf.put_u8(self.opcode.to_byte(self.key_id));
138
139 if let Some(hmac) = &self.hmac {
140 buf.put_slice(hmac);
141 }
142
143 if let Some(packet_id) = self.packet_id {
144 buf.put_u32(packet_id);
145 }
146
147 if let Some(timestamp) = self.timestamp {
148 buf.put_u32(timestamp);
149 }
150
151 if let Some(session_id) = &self.session_id {
152 buf.put_slice(session_id);
153 }
154 }
155}
156
157#[derive(Debug, Clone)]
159pub enum Packet {
160 Control(ControlPacketData),
162 Data(DataPacketData),
164}
165
166#[derive(Debug, Clone)]
168pub struct ControlPacketData {
169 pub header: PacketHeader,
171 pub remote_session_id: Option<SessionId>,
173 pub acks: Vec<PacketId>,
175 pub message_packet_id: Option<PacketId>,
177 pub payload: Bytes,
179}
180
181#[derive(Debug, Clone)]
183pub struct DataPacketData {
184 pub header: PacketHeader,
186 pub peer_id: Option<u32>,
188 pub payload: Bytes,
190}
191
192impl Packet {
193 #[inline]
195 pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<Self> {
196 let (header, mut offset) = PacketHeader::parse(data, has_tls_auth)?;
197
198 if header.opcode.is_data() {
199 let peer_id = if header.opcode == OpCode::DataV2 {
201 if data.len() < offset + 3 {
202 return Err(ProtocolError::PacketTooShort {
203 expected: offset + 3,
204 got: data.len(),
205 });
206 }
207 let pid = ((data[offset] as u32) << 16)
209 | ((data[offset + 1] as u32) << 8)
210 | (data[offset + 2] as u32);
211 offset += 3;
212 Some(pid)
213 } else {
214 None
215 };
216
217 if offset > data.len() {
219 return Err(ProtocolError::PacketTooShort {
220 expected: offset,
221 got: data.len(),
222 });
223 }
224 return Ok(Packet::Data(DataPacketData {
225 header,
226 peer_id,
227 payload: Bytes::copy_from_slice(&data[offset..]),
228 }));
229 }
230
231 let mut remote_session_id = None;
233 let mut acks = Vec::new();
234
235 if data.len() < offset + 1 {
237 return Err(ProtocolError::PacketTooShort {
238 expected: offset + 1,
239 got: data.len(),
240 });
241 }
242 const MAX_ACK_COUNT: usize = 16; let ack_count = data[offset] as usize;
244 if ack_count > MAX_ACK_COUNT {
245 return Err(ProtocolError::InvalidPacket(
246 format!("ACK count {} exceeds maximum {}", ack_count, MAX_ACK_COUNT).into(),
247 ));
248 }
249 offset += 1;
250
251 if ack_count > 0 {
253 for _ in 0..ack_count {
255 if data.len() < offset + 4 {
256 return Err(ProtocolError::PacketTooShort {
257 expected: offset + 4,
258 got: data.len(),
259 });
260 }
261 acks.push(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
262 offset += 4;
263 }
264
265 if data.len() < offset + 8 {
267 return Err(ProtocolError::PacketTooShort {
268 expected: offset + 8,
269 got: data.len(),
270 });
271 }
272 let mut rsid = [0u8; 8];
273 rsid.copy_from_slice(&data[offset..offset + 8]);
274 remote_session_id = Some(rsid);
275 offset += 8;
276 }
277
278 let message_packet_id = if header.opcode != OpCode::AckV1 && data.len() >= offset + 4 {
280 let id = u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap());
281 offset += 4;
282 Some(id)
283 } else {
284 None
285 };
286
287 let payload = if offset < data.len() {
289 Bytes::copy_from_slice(&data[offset..])
290 } else {
291 Bytes::new()
292 };
293
294 Ok(Packet::Control(ControlPacketData {
295 header,
296 remote_session_id,
297 acks,
298 message_packet_id,
299 payload,
300 }))
301 }
302
303 #[inline]
305 pub fn serialize(&self) -> BytesMut {
306 let mut buf = BytesMut::with_capacity(1500);
308
309 match self {
310 Packet::Control(ctrl) => {
311 ctrl.header.serialize(&mut buf);
312
313 buf.put_u8(ctrl.acks.len() as u8);
315
316 for ack in &ctrl.acks {
318 buf.put_u32(*ack);
319 }
320
321 if !ctrl.acks.is_empty() {
323 if let Some(rsid) = &ctrl.remote_session_id {
324 buf.put_slice(rsid);
325 }
326 }
327
328 if let Some(mpid) = ctrl.message_packet_id {
330 buf.put_u32(mpid);
331 }
332
333 buf.put_slice(&ctrl.payload);
335 }
336 Packet::Data(data) => {
337 data.header.serialize(&mut buf);
338
339 if let Some(pid) = data.peer_id {
341 buf.put_u8((pid >> 16) as u8);
342 buf.put_u8((pid >> 8) as u8);
343 buf.put_u8(pid as u8);
344 }
345
346 buf.put_slice(&data.payload);
348 }
349 }
350
351 buf
352 }
353
354 pub fn opcode(&self) -> OpCode {
356 match self {
357 Packet::Control(c) => c.header.opcode,
358 Packet::Data(d) => d.header.opcode,
359 }
360 }
361
362 pub fn key_id(&self) -> KeyId {
364 match self {
365 Packet::Control(c) => c.header.key_id,
366 Packet::Data(d) => d.header.key_id,
367 }
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hard_reset_parse() {
        let data = [
            0x38, // opcode/key-id
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // session ID
            0x00, // ACK count
        ];

        let packet = Packet::parse(&data, false).unwrap();
        if let Packet::Control(ctrl) = packet {
            assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
            assert_eq!(ctrl.header.key_id, KeyId::new(0));
            assert_eq!(
                ctrl.header.session_id,
                Some([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08])
            );
            assert!(ctrl.acks.is_empty());
        } else {
            panic!("Expected control packet");
        }
    }

    #[test]
    fn test_data_packet_v2() {
        let data = [
            0x48, // opcode/key-id
            0x00, 0x00, 0x01, // peer ID
            0xDE, 0xAD, 0xBE, 0xEF, // payload
        ];

        let packet = Packet::parse(&data, false).unwrap();
        if let Packet::Data(d) = packet {
            assert_eq!(d.header.opcode, OpCode::DataV2);
            assert_eq!(d.peer_id, Some(1));
            assert_eq!(&d.payload[..], &[0xDE, 0xAD, 0xBE, 0xEF]);
        } else {
            panic!("Expected data packet");
        }
    }

    #[test]
    fn test_empty_input_rejected() {
        assert!(matches!(
            Packet::parse(&[], false),
            Err(ProtocolError::PacketTooShort { expected: 1, got: 0 })
        ));
    }

    #[test]
    fn test_truncated_session_id_rejected() {
        let data = [0x38, 0x01, 0x02, 0x03];
        assert!(matches!(
            Packet::parse(&data, false),
            Err(ProtocolError::PacketTooShort { expected: 9, got: 4 })
        ));
    }

    #[test]
    fn test_ack_count_limit() {
        let mut data = vec![0x38, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 17];
        data.resize(data.len() + 17 * 4 + 8, 0);
        assert!(matches!(
            Packet::parse(&data, false),
            Err(ProtocolError::InvalidPacket(_))
        ));
    }

    #[test]
    fn test_control_packet_with_acks() {
        let data = [
            0x38, // opcode/key-id
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // local session ID
            0x01, // ACK count
            0x00, 0x00, 0x00, 0x05, // ACKed packet ID
            0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, // remote session ID
            0x00, 0x00, 0x00, 0x09, // message packet ID
            0xAB, // payload
        ];

        match Packet::parse(&data, false).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.acks, vec![5]);
                assert_eq!(
                    ctrl.remote_session_id,
                    Some([0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18])
                );
                assert_eq!(ctrl.message_packet_id, Some(9));
                assert_eq!(&ctrl.payload[..], &[0xAB]);
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    #[test]
    fn test_control_round_trip_with_tls_auth() {
        let original = Packet::Control(ControlPacketData {
            header: PacketHeader {
                opcode: OpCode::HardResetClientV2,
                key_id: KeyId::new(0),
                session_id: Some([1, 2, 3, 4, 5, 6, 7, 8]),
                hmac: Some([0xAA; 32]),
                packet_id: Some(5),
                timestamp: Some(6),
            },
            remote_session_id: Some([9; 8]),
            acks: vec![3, 4],
            message_packet_id: Some(7),
            payload: Bytes::from_static(b"hello"),
        });

        let wire = original.serialize();
        match Packet::parse(&wire, true).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
                assert_eq!(ctrl.header.session_id, Some([1, 2, 3, 4, 5, 6, 7, 8]));
                assert_eq!(ctrl.header.hmac, Some([0xAA; 32]));
                assert_eq!(ctrl.header.packet_id, Some(5));
                assert_eq!(ctrl.header.timestamp, Some(6));
                assert_eq!(ctrl.acks, vec![3, 4]);
                assert_eq!(ctrl.remote_session_id, Some([9; 8]));
                assert_eq!(ctrl.message_packet_id, Some(7));
                assert_eq!(&ctrl.payload[..], b"hello");
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    #[test]
    fn test_data_v2_round_trip() {
        let original = Packet::Data(DataPacketData {
            header: PacketHeader {
                opcode: OpCode::DataV2,
                key_id: KeyId::new(0),
                session_id: None,
                hmac: None,
                packet_id: None,
                timestamp: None,
            },
            peer_id: Some(0x00AB_CDEF),
            payload: Bytes::from_static(&[1, 2, 3]),
        });

        let wire = original.serialize();
        assert_eq!(&wire[1..], &[0xAB, 0xCD, 0xEF, 1, 2, 3]);

        let parsed = Packet::parse(&wire, false).unwrap();
        assert_eq!(parsed.opcode(), OpCode::DataV2);
        match parsed {
            Packet::Data(d) => {
                assert_eq!(d.peer_id, Some(0x00AB_CDEF));
                assert_eq!(&d.payload[..], &[1, 2, 3]);
            }
            Packet::Control(_) => panic!("Expected data packet"),
        }
    }
}