1use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::{OpCode, KeyId, ProtocolError, Result};
11
/// 8-byte session identifier carried in the header of every control packet
/// (and, after a non-empty ACK array, as the remote session ID).
pub type SessionId = [u8; 8];

/// 32-bit packet identifier; read and written big-endian on the wire.
pub type PacketId = u32;
17
/// Header found at the start of every packet.
///
/// Data packets consist of the single opcode/key-id byte only, so every
/// `Option` field is `None` for them. Control packets always carry a session
/// ID; the `hmac`, `packet_id` and `timestamp` fields are populated only when
/// the header was parsed with tls-auth enabled.
#[derive(Debug, Clone)]
pub struct PacketHeader {
    /// Packet opcode, decoded from the first byte.
    pub opcode: OpCode,
    /// Key ID, decoded from the same first byte as `opcode`.
    pub key_id: KeyId,
    /// 8-byte session ID following the first byte; `None` for data packets.
    pub session_id: Option<SessionId>,
    /// 32-byte HMAC following the session ID (tls-auth only).
    pub hmac: Option<[u8; 32]>,
    /// Big-endian packet ID following the HMAC (tls-auth only).
    pub packet_id: Option<PacketId>,
    /// Big-endian timestamp following the packet ID (tls-auth only).
    pub timestamp: Option<u32>,
}
34
35impl PacketHeader {
36 pub const MIN_SIZE: usize = 1;
38
39 pub const CONTROL_HEADER_SIZE: usize = 1 + 8; #[inline]
44 pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<(Self, usize)> {
45 if data.is_empty() {
46 return Err(ProtocolError::PacketTooShort {
47 expected: 1,
48 got: 0,
49 });
50 }
51
52 let opcode = OpCode::from_byte(data[0])?;
53 let key_id = KeyId::from_byte(data[0]);
54
55 if opcode.is_data() {
56 return Ok((
58 Self {
59 opcode,
60 key_id,
61 session_id: None,
62 hmac: None,
63 packet_id: None,
64 timestamp: None,
65 },
66 1,
67 ));
68 }
69
70 let mut offset = 1;
74 let mut hmac = None;
75 let mut packet_id = None;
76 let mut timestamp = None;
77
78 if data.len() < offset + 8 {
80 return Err(ProtocolError::PacketTooShort {
81 expected: offset + 8,
82 got: data.len(),
83 });
84 }
85 let mut session_id = [0u8; 8];
86 session_id.copy_from_slice(&data[offset..offset + 8]);
87 offset += 8;
88
89 if has_tls_auth {
91 if data.len() < offset + 32 {
93 return Err(ProtocolError::PacketTooShort {
94 expected: offset + 32,
95 got: data.len(),
96 });
97 }
98 let mut h = [0u8; 32];
99 h.copy_from_slice(&data[offset..offset + 32]);
100 hmac = Some(h);
101 offset += 32;
102
103 if data.len() < offset + 4 {
105 return Err(ProtocolError::PacketTooShort {
106 expected: offset + 4,
107 got: data.len(),
108 });
109 }
110 packet_id = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
111 offset += 4;
112
113 if data.len() < offset + 4 {
115 return Err(ProtocolError::PacketTooShort {
116 expected: offset + 4,
117 got: data.len(),
118 });
119 }
120 timestamp = Some(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
121 offset += 4;
122 }
123
124 Ok((
125 Self {
126 opcode,
127 key_id,
128 session_id: Some(session_id),
129 hmac,
130 packet_id,
131 timestamp,
132 },
133 offset,
134 ))
135 }
136
137 #[inline]
140 pub fn serialize(&self, buf: &mut BytesMut) {
141 buf.put_u8(self.opcode.to_byte(self.key_id));
142
143 if let Some(session_id) = &self.session_id {
145 buf.put_slice(session_id);
146 }
147
148 if let Some(hmac) = &self.hmac {
149 buf.put_slice(hmac);
150 }
151
152 if let Some(packet_id) = self.packet_id {
153 buf.put_u32(packet_id);
154 }
155
156 if let Some(timestamp) = self.timestamp {
157 buf.put_u32(timestamp);
158 }
159 }
160}
161
/// A fully parsed packet: either a control-channel or a data-channel packet.
#[derive(Debug, Clone)]
pub enum Packet {
    /// Any packet whose opcode does not report `is_data()`.
    Control(ControlPacketData),
    /// A packet whose opcode reports `is_data()`.
    Data(DataPacketData),
}
170
/// Contents of a control packet.
#[derive(Debug, Clone)]
pub struct ControlPacketData {
    /// Header: opcode, key ID, session ID and the optional tls-auth fields.
    pub header: PacketHeader,
    /// Remote session ID; present on the wire only when `acks` is non-empty.
    pub remote_session_id: Option<SessionId>,
    /// Packet IDs acknowledged by this packet, in wire order.
    pub acks: Vec<PacketId>,
    /// Message packet ID; `None` for `AckV1` packets or when no message ID is
    /// present on the wire.
    pub message_packet_id: Option<PacketId>,
    /// All bytes remaining after the fields above (possibly empty).
    pub payload: Bytes,
}
185
/// Contents of a data packet.
#[derive(Debug, Clone)]
pub struct DataPacketData {
    /// Header; for data packets only `opcode` and `key_id` are populated.
    pub header: PacketHeader,
    /// 24-bit peer ID carried by `DataV2` packets; `None` for other data opcodes.
    pub peer_id: Option<u32>,
    /// All bytes following the opcode byte (and the peer ID, if any).
    pub payload: Bytes,
}
196
197impl Packet {
198 #[inline]
200 pub fn parse(data: &[u8], has_tls_auth: bool) -> Result<Self> {
201 let (header, mut offset) = PacketHeader::parse(data, has_tls_auth)?;
202
203 if header.opcode.is_data() {
204 let peer_id = if header.opcode == OpCode::DataV2 {
206 if data.len() < offset + 3 {
207 return Err(ProtocolError::PacketTooShort {
208 expected: offset + 3,
209 got: data.len(),
210 });
211 }
212 let pid = ((data[offset] as u32) << 16)
214 | ((data[offset + 1] as u32) << 8)
215 | (data[offset + 2] as u32);
216 offset += 3;
217 Some(pid)
218 } else {
219 None
220 };
221
222 if offset > data.len() {
224 return Err(ProtocolError::PacketTooShort {
225 expected: offset,
226 got: data.len(),
227 });
228 }
229 return Ok(Packet::Data(DataPacketData {
230 header,
231 peer_id,
232 payload: Bytes::copy_from_slice(&data[offset..]),
233 }));
234 }
235
236 let mut remote_session_id = None;
238 let mut acks = Vec::new();
239
240 if data.len() < offset + 1 {
242 return Err(ProtocolError::PacketTooShort {
243 expected: offset + 1,
244 got: data.len(),
245 });
246 }
247 const MAX_ACK_COUNT: usize = 16; let ack_count = data[offset] as usize;
249 if ack_count > MAX_ACK_COUNT {
250 return Err(ProtocolError::InvalidPacket(
251 format!("ACK count {} exceeds maximum {}", ack_count, MAX_ACK_COUNT).into(),
252 ));
253 }
254 offset += 1;
255
256 if ack_count > 0 {
258 for _ in 0..ack_count {
260 if data.len() < offset + 4 {
261 return Err(ProtocolError::PacketTooShort {
262 expected: offset + 4,
263 got: data.len(),
264 });
265 }
266 acks.push(u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap()));
267 offset += 4;
268 }
269
270 if data.len() < offset + 8 {
272 return Err(ProtocolError::PacketTooShort {
273 expected: offset + 8,
274 got: data.len(),
275 });
276 }
277 let mut rsid = [0u8; 8];
278 rsid.copy_from_slice(&data[offset..offset + 8]);
279 remote_session_id = Some(rsid);
280 offset += 8;
281 }
282
283 let message_packet_id = if header.opcode != OpCode::AckV1 && data.len() >= offset + 4 {
285 let id = u32::from_be_bytes(data[offset..offset + 4].try_into().unwrap());
286 offset += 4;
287 Some(id)
288 } else {
289 None
290 };
291
292 let payload = if offset < data.len() {
294 Bytes::copy_from_slice(&data[offset..])
295 } else {
296 Bytes::new()
297 };
298
299 Ok(Packet::Control(ControlPacketData {
300 header,
301 remote_session_id,
302 acks,
303 message_packet_id,
304 payload,
305 }))
306 }
307
308 #[inline]
310 pub fn serialize(&self) -> BytesMut {
311 let mut buf = BytesMut::with_capacity(1500);
313
314 match self {
315 Packet::Control(ctrl) => {
316 ctrl.header.serialize(&mut buf);
317
318 buf.put_u8(ctrl.acks.len() as u8);
320
321 for ack in &ctrl.acks {
323 buf.put_u32(*ack);
324 }
325
326 if !ctrl.acks.is_empty() {
328 if let Some(rsid) = &ctrl.remote_session_id {
329 buf.put_slice(rsid);
330 }
331 }
332
333 if let Some(mpid) = ctrl.message_packet_id {
335 buf.put_u32(mpid);
336 }
337
338 buf.put_slice(&ctrl.payload);
340 }
341 Packet::Data(data) => {
342 data.header.serialize(&mut buf);
343
344 if let Some(pid) = data.peer_id {
346 buf.put_u8((pid >> 16) as u8);
347 buf.put_u8((pid >> 8) as u8);
348 buf.put_u8(pid as u8);
349 }
350
351 buf.put_slice(&data.payload);
353 }
354 }
355
356 buf
357 }
358
359 pub fn opcode(&self) -> OpCode {
361 match self {
362 Packet::Control(c) => c.header.opcode,
363 Packet::Data(d) => d.header.opcode,
364 }
365 }
366
367 pub fn key_id(&self) -> KeyId {
369 match self {
370 Packet::Control(c) => c.header.key_id,
371 Packet::Data(d) => d.header.key_id,
372 }
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;

    /// A HARD_RESET_CLIENT_V2 (key 0) with a session ID and an empty ACK array
    /// parses as a control packet exposing those fields.
    #[test]
    fn test_hard_reset_parse() {
        let session: [u8; 8] = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let mut wire = Vec::with_capacity(1 + session.len() + 1);
        wire.push(0x38); // opcode/key-id byte
        wire.extend_from_slice(&session);
        wire.push(0x00); // ACK count

        match Packet::parse(&wire, false).unwrap() {
            Packet::Control(ctrl) => {
                assert_eq!(ctrl.header.opcode, OpCode::HardResetClientV2);
                assert_eq!(ctrl.header.key_id, KeyId::new(0));
                assert_eq!(ctrl.header.session_id, Some(session));
                assert!(ctrl.acks.is_empty());
            }
            Packet::Data(_) => panic!("Expected control packet"),
        }
    }

    /// A DATA_V2 packet with 24-bit peer ID 1 and a four-byte payload.
    #[test]
    fn test_data_packet_v2() {
        let payload: [u8; 4] = [0xDE, 0xAD, 0xBE, 0xEF];
        let mut wire = vec![0x48, 0x00, 0x00, 0x01]; // opcode byte + peer ID
        wire.extend_from_slice(&payload);

        match Packet::parse(&wire, false).unwrap() {
            Packet::Data(d) => {
                assert_eq!(d.header.opcode, OpCode::DataV2);
                assert_eq!(d.peer_id, Some(1));
                assert_eq!(&d.payload[..], &payload[..]);
            }
            Packet::Control(_) => panic!("Expected data packet"),
        }
    }
}