1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while serializing or parsing a packet.
///
/// The enum is fieldless, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, PacketError::TooShort)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input buffer is too short to contain the packet header.
    TooShort,
    /// The serialized packet would be longer than the allowed MTU.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport id.
    MissingTransportId,
    /// The header type is neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 match self {
18 PacketError::TooShort => write!(f, "Packet too short"),
19 PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20 PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21 PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22 }
23 }
24}
25
/// Decoded fields of a packet's one-byte flags header.
///
/// Each field holds the value of a single bit field of the flags byte; the
/// encoding is performed by `PacketFlags::pack` and decoding by
/// `PacketFlags::unpack`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct PacketFlags {
    /// Header layout selector, stored in bit 6 (`HEADER_1` or `HEADER_2`).
    pub header_type: u8,
    /// Context flag, stored in bit 5.
    pub context_flag: u8,
    /// Transport type, stored in bit 4.
    pub transport_type: u8,
    /// Destination type, stored in bits 3-2.
    pub destination_type: u8,
    /// Packet type, stored in bits 1-0.
    pub packet_type: u8,
}
38
39impl PacketFlags {
40 pub fn pack(&self) -> u8 {
51 (self.header_type << 6)
52 | (self.context_flag << 5)
53 | (self.transport_type << 4)
54 | (self.destination_type << 2)
55 | self.packet_type
56 }
57
58 pub fn unpack(byte: u8) -> Self {
60 PacketFlags {
61 header_type: (byte & 0b01000000) >> 6,
62 context_flag: (byte & 0b00100000) >> 5,
63 transport_type: (byte & 0b00010000) >> 4,
64 destination_type: (byte & 0b00001100) >> 2,
65 packet_type: byte & 0b00000011,
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
75pub struct RawPacket {
76 pub flags: PacketFlags,
77 pub hops: u8,
78 pub transport_id: Option<[u8; 16]>,
79 pub destination_hash: [u8; 16],
80 pub context: u8,
81 pub data: Vec<u8>,
82 pub raw: Vec<u8>,
83 pub packet_hash: [u8; 32],
84 pub rssi: Option<i16>,
85 pub snr: Option<f32>,
86}
87
88impl RawPacket {
89 pub fn pack(
91 flags: PacketFlags,
92 hops: u8,
93 destination_hash: &[u8; 16],
94 transport_id: Option<&[u8; 16]>,
95 context: u8,
96 data: &[u8],
97 ) -> Result<Self, PacketError> {
98 Self::pack_with_max_mtu(
99 flags,
100 hops,
101 destination_hash,
102 transport_id,
103 context,
104 data,
105 constants::MTU,
106 )
107 }
108
109 pub fn pack_raw_with_hash(
111 flags: PacketFlags,
112 hops: u8,
113 destination_hash: &[u8; 16],
114 transport_id: Option<&[u8; 16]>,
115 context: u8,
116 data: &[u8],
117 ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
118 Self::pack_raw_with_hash_with_max_mtu(
119 flags,
120 hops,
121 destination_hash,
122 transport_id,
123 context,
124 data,
125 constants::MTU,
126 )
127 }
128
129 pub fn pack_with_max_mtu(
131 flags: PacketFlags,
132 hops: u8,
133 destination_hash: &[u8; 16],
134 transport_id: Option<&[u8; 16]>,
135 context: u8,
136 data: &[u8],
137 max_mtu: usize,
138 ) -> Result<Self, PacketError> {
139 let (raw, packet_hash) = Self::pack_raw_with_hash_with_max_mtu(
140 flags,
141 hops,
142 destination_hash,
143 transport_id,
144 context,
145 data,
146 max_mtu,
147 )?;
148
149 Ok(RawPacket {
150 flags,
151 hops,
152 transport_id: transport_id.copied(),
153 destination_hash: *destination_hash,
154 context,
155 data: data.to_vec(),
156 raw,
157 packet_hash,
158 rssi: None,
159 snr: None,
160 })
161 }
162
163 pub fn pack_raw_with_hash_with_max_mtu(
165 flags: PacketFlags,
166 hops: u8,
167 destination_hash: &[u8; 16],
168 transport_id: Option<&[u8; 16]>,
169 context: u8,
170 data: &[u8],
171 max_mtu: usize,
172 ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
173 if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
174 return Err(PacketError::MissingTransportId);
175 }
176
177 let mut raw = Vec::new();
178 raw.push(flags.pack());
179 raw.push(hops);
180
181 if let Some(transport_id) = transport_id {
182 if flags.header_type == constants::HEADER_2 {
183 raw.extend_from_slice(transport_id);
184 }
185 }
186
187 raw.extend_from_slice(destination_hash);
188 raw.push(context);
189 raw.extend_from_slice(data);
190
191 if raw.len() > max_mtu {
192 return Err(PacketError::ExceedsMtu);
193 }
194
195 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
196 Ok((raw, packet_hash))
197 }
198
199 pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
201 if raw.len() < constants::HEADER_MINSIZE {
202 return Err(PacketError::TooShort);
203 }
204
205 let flags = PacketFlags::unpack(raw[0]);
206 let hops = raw[1];
207
208 let dst_len = constants::TRUNCATED_HASHLENGTH / 8; if flags.header_type == constants::HEADER_2 {
211 let min_len = 2 + dst_len * 2 + 1;
213 if raw.len() < min_len {
214 return Err(PacketError::TooShort);
215 }
216
217 let mut transport_id = [0u8; 16];
218 transport_id.copy_from_slice(&raw[2..2 + dst_len]);
219
220 let mut destination_hash = [0u8; 16];
221 destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
222
223 let context = raw[2 + 2 * dst_len];
224 let data = raw[2 + 2 * dst_len + 1..].to_vec();
225
226 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
227
228 Ok(RawPacket {
229 flags,
230 hops,
231 transport_id: Some(transport_id),
232 destination_hash,
233 context,
234 data,
235 raw: raw.to_vec(),
236 packet_hash,
237 rssi: None,
238 snr: None,
239 })
240 } else if flags.header_type == constants::HEADER_1 {
241 let min_len = 2 + dst_len + 1;
243 if raw.len() < min_len {
244 return Err(PacketError::TooShort);
245 }
246
247 let mut destination_hash = [0u8; 16];
248 destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
249
250 let context = raw[2 + dst_len];
251 let data = raw[2 + dst_len + 1..].to_vec();
252
253 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
254
255 Ok(RawPacket {
256 flags,
257 hops,
258 transport_id: None,
259 destination_hash,
260 context,
261 data,
262 raw: raw.to_vec(),
263 packet_hash,
264 rssi: None,
265 snr: None,
266 })
267 } else {
268 Err(PacketError::InvalidHeaderType)
269 }
270 }
271
272 pub fn get_hashable_part(&self) -> Vec<u8> {
279 Self::compute_hashable_part(self.flags.header_type, &self.raw)
280 }
281
282 fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
283 let mut hashable = Vec::new();
284 hashable.push(raw[0] & 0b00001111);
285 if header_type == constants::HEADER_2 {
286 hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
288 } else {
289 hashable.extend_from_slice(&raw[2..]);
290 }
291 hashable
292 }
293
294 pub fn get_hash(&self) -> [u8; 32] {
296 self.packet_hash
297 }
298
299 pub fn get_truncated_hash(&self) -> [u8; 16] {
301 let mut result = [0u8; 16];
302 result.copy_from_slice(&self.packet_hash[..16]);
303 result
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags for a `HEADER_1` broadcast DATA packet to a single destination.
    fn broadcast_data_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_1,
            context_flag,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        }
    }

    /// Flags for a `HEADER_2` transported ANNOUNCE packet to a single destination.
    fn transport_announce_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_2,
            context_flag,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        }
    }

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        assert_eq!(broadcast_data_flags(constants::FLAG_UNSET).pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        assert_eq!(transport_announce_flags(constants::FLAG_UNSET).pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            assert_eq!(PacketFlags::unpack(byte).pack(), byte);
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";

        let pkt = RawPacket::pack(
            broadcast_data_flags(constants::FLAG_UNSET),
            0,
            &dest_hash,
            None,
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // flags (1) + hops (1) + destination hash (16) + context (1) + payload (5)
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00);
        assert_eq!(pkt.raw[1], 0x00);
        assert_eq!(&pkt.raw[2..18], &dest_hash);
        assert_eq!(pkt.raw[18], 0x00);
        assert_eq!(&pkt.raw[19..], b"hello");
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = transport_announce_flags(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // flags (1) + hops (1) + transport id (16) + destination hash (16) + context (1) + payload (5)
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = broadcast_data_flags(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = transport_announce_flags(constants::FLAG_SET);

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        // A payload that alone fills the MTU must overflow it once the header
        // is added, whatever MTU value the crate is configured with.
        let mut data = Vec::new();
        data.resize(constants::MTU, 0u8);

        let result = RawPacket::pack(
            broadcast_data_flags(constants::FLAG_UNSET),
            0,
            &[0; 16],
            None,
            0,
            &data,
        );
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let result = RawPacket::pack(
            transport_announce_flags(constants::FLAG_UNSET),
            0,
            &[0; 16],
            None,
            0,
            b"data",
        );
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];

        let pkt = RawPacket::pack(
            broadcast_data_flags(constants::FLAG_SET),
            0,
            &dest_hash,
            None,
            constants::CONTEXT_NONE,
            b"test",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // Only the low nibble of the flags byte is hashed; hops is skipped.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];

        let pkt = RawPacket::pack(
            transport_announce_flags(constants::FLAG_UNSET),
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // Hops and the 16-byte transport id are excluded from the hash input.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }
}