1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors returned when packing or unpacking a [`RawPacket`].
///
/// The enum is `Copy` and comparable so callers can match on or compare specific failures
/// (e.g. `assert_eq!(err, PacketError::TooShort)`) instead of only checking `is_err()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input buffer is shorter than the minimum header required for its header type.
    TooShort,
    /// The serialized packet would be longer than the allowed MTU.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport id.
    MissingTransportId,
    /// The header type is neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 match self {
18 PacketError::TooShort => write!(f, "Packet too short"),
19 PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20 PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21 PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22 }
23 }
24}
25
/// Decoded fields of a packet's first (flags) byte.
///
/// Bit positions, as read by [`PacketFlags::unpack`] and written by [`PacketFlags::pack`]:
/// bit 6 `header_type`, bit 5 `context_flag`, bit 4 `transport_type`,
/// bits 3-2 `destination_type`, bits 1-0 `packet_type`. Bit 7 is not represented here.
/// Each field is expected to fit its bit width; `unpack` only ever produces values that do.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header layout selector, compared against `constants::HEADER_1` / `constants::HEADER_2` (1 bit).
    pub header_type: u8,
    /// Context flag, e.g. `constants::FLAG_SET` / `constants::FLAG_UNSET` (1 bit).
    pub context_flag: u8,
    /// Transport type, e.g. `constants::TRANSPORT_BROADCAST` / `constants::TRANSPORT_TRANSPORT` (1 bit).
    pub transport_type: u8,
    /// Destination type, e.g. `constants::DESTINATION_SINGLE` (2 bits).
    pub destination_type: u8,
    /// Packet type, e.g. `constants::PACKET_TYPE_DATA` / `constants::PACKET_TYPE_ANNOUNCE` (2 bits).
    pub packet_type: u8,
}
38
39impl PacketFlags {
40 pub fn pack(&self) -> u8 {
51 (self.header_type << 6)
52 | (self.context_flag << 5)
53 | (self.transport_type << 4)
54 | (self.destination_type << 2)
55 | self.packet_type
56 }
57
58 pub fn unpack(byte: u8) -> Self {
60 PacketFlags {
61 header_type: (byte & 0b01000000) >> 6,
62 context_flag: (byte & 0b00100000) >> 5,
63 transport_type: (byte & 0b00010000) >> 4,
64 destination_type: (byte & 0b00001100) >> 2,
65 packet_type: byte & 0b00000011,
66 }
67 }
68}
69
/// A packet together with its serialized wire bytes and its hash.
///
/// Produced by [`RawPacket::pack`] (and its `*_with_max_mtu` variant) or by
/// [`RawPacket::unpack`]. The fields are public; nothing in this type recomputes `raw` or
/// `packet_hash` if they are mutated afterwards.
#[derive(Debug, Clone)]
pub struct RawPacket {
    /// Decoded flags byte (`raw[0]`).
    pub flags: PacketFlags,
    /// Hop count (`raw[1]`).
    pub hops: u8,
    /// Transport id. `unpack` sets this only for `HEADER_2` packets, where it directly
    /// follows the hops byte on the wire.
    pub transport_id: Option<[u8; 16]>,
    /// Destination hash.
    pub destination_hash: [u8; 16],
    /// Context byte that follows the destination hash.
    pub context: u8,
    /// Payload bytes that follow the context byte.
    pub data: Vec<u8>,
    /// The full serialized packet.
    pub raw: Vec<u8>,
    /// `hash::full_hash` of the packet's hashable part (see `get_hashable_part`).
    pub packet_hash: [u8; 32],
}
85
86impl RawPacket {
87 pub fn pack(
89 flags: PacketFlags,
90 hops: u8,
91 destination_hash: &[u8; 16],
92 transport_id: Option<&[u8; 16]>,
93 context: u8,
94 data: &[u8],
95 ) -> Result<Self, PacketError> {
96 Self::pack_with_max_mtu(
97 flags,
98 hops,
99 destination_hash,
100 transport_id,
101 context,
102 data,
103 constants::MTU,
104 )
105 }
106
107 pub fn pack_raw_with_hash(
109 flags: PacketFlags,
110 hops: u8,
111 destination_hash: &[u8; 16],
112 transport_id: Option<&[u8; 16]>,
113 context: u8,
114 data: &[u8],
115 ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
116 Self::pack_raw_with_hash_with_max_mtu(
117 flags,
118 hops,
119 destination_hash,
120 transport_id,
121 context,
122 data,
123 constants::MTU,
124 )
125 }
126
127 pub fn pack_with_max_mtu(
129 flags: PacketFlags,
130 hops: u8,
131 destination_hash: &[u8; 16],
132 transport_id: Option<&[u8; 16]>,
133 context: u8,
134 data: &[u8],
135 max_mtu: usize,
136 ) -> Result<Self, PacketError> {
137 let (raw, packet_hash) = Self::pack_raw_with_hash_with_max_mtu(
138 flags,
139 hops,
140 destination_hash,
141 transport_id,
142 context,
143 data,
144 max_mtu,
145 )?;
146
147 Ok(RawPacket {
148 flags,
149 hops,
150 transport_id: transport_id.copied(),
151 destination_hash: *destination_hash,
152 context,
153 data: data.to_vec(),
154 raw,
155 packet_hash,
156 })
157 }
158
159 pub fn pack_raw_with_hash_with_max_mtu(
161 flags: PacketFlags,
162 hops: u8,
163 destination_hash: &[u8; 16],
164 transport_id: Option<&[u8; 16]>,
165 context: u8,
166 data: &[u8],
167 max_mtu: usize,
168 ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
169 if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
170 return Err(PacketError::MissingTransportId);
171 }
172
173 let mut raw = Vec::new();
174 raw.push(flags.pack());
175 raw.push(hops);
176
177 if let Some(transport_id) = transport_id {
178 if flags.header_type == constants::HEADER_2 {
179 raw.extend_from_slice(transport_id);
180 }
181 }
182
183 raw.extend_from_slice(destination_hash);
184 raw.push(context);
185 raw.extend_from_slice(data);
186
187 if raw.len() > max_mtu {
188 return Err(PacketError::ExceedsMtu);
189 }
190
191 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
192 Ok((raw, packet_hash))
193 }
194
195 pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
197 if raw.len() < constants::HEADER_MINSIZE {
198 return Err(PacketError::TooShort);
199 }
200
201 let flags = PacketFlags::unpack(raw[0]);
202 let hops = raw[1];
203
204 let dst_len = constants::TRUNCATED_HASHLENGTH / 8; if flags.header_type == constants::HEADER_2 {
207 let min_len = 2 + dst_len * 2 + 1;
209 if raw.len() < min_len {
210 return Err(PacketError::TooShort);
211 }
212
213 let mut transport_id = [0u8; 16];
214 transport_id.copy_from_slice(&raw[2..2 + dst_len]);
215
216 let mut destination_hash = [0u8; 16];
217 destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
218
219 let context = raw[2 + 2 * dst_len];
220 let data = raw[2 + 2 * dst_len + 1..].to_vec();
221
222 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
223
224 Ok(RawPacket {
225 flags,
226 hops,
227 transport_id: Some(transport_id),
228 destination_hash,
229 context,
230 data,
231 raw: raw.to_vec(),
232 packet_hash,
233 })
234 } else if flags.header_type == constants::HEADER_1 {
235 let min_len = 2 + dst_len + 1;
237 if raw.len() < min_len {
238 return Err(PacketError::TooShort);
239 }
240
241 let mut destination_hash = [0u8; 16];
242 destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
243
244 let context = raw[2 + dst_len];
245 let data = raw[2 + dst_len + 1..].to_vec();
246
247 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
248
249 Ok(RawPacket {
250 flags,
251 hops,
252 transport_id: None,
253 destination_hash,
254 context,
255 data,
256 raw: raw.to_vec(),
257 packet_hash,
258 })
259 } else {
260 Err(PacketError::InvalidHeaderType)
261 }
262 }
263
264 pub fn get_hashable_part(&self) -> Vec<u8> {
271 Self::compute_hashable_part(self.flags.header_type, &self.raw)
272 }
273
274 fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
275 let mut hashable = Vec::new();
276 hashable.push(raw[0] & 0b00001111);
277 if header_type == constants::HEADER_2 {
278 hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
280 } else {
281 hashable.extend_from_slice(&raw[2..]);
282 }
283 hashable
284 }
285
286 pub fn get_hash(&self) -> [u8; 32] {
288 self.packet_hash
289 }
290
291 pub fn get_truncated_hash(&self) -> [u8; 16] {
293 let mut result = [0u8; 16];
294 result.copy_from_slice(&self.packet_hash[..16]);
295 result
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        assert_eq!(flags.pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        assert_eq!(flags.pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data).unwrap();

        // flags (1) + hops (1) + destination (16) + context (1) + data (5)
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00);
        assert_eq!(pkt.raw[1], 0x00);
        assert_eq!(&pkt.raw[2..18], &dest_hash);
        assert_eq!(pkt.raw[18], 0x00);
        assert_eq!(&pkt.raw[19..], b"hello");
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // flags (1) + hops (1) + transport id (16) + destination (16) + context (1) + data (5)
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_header2_shorter_than_transport_header() {
        // Long enough for a HEADER_1 layout, but missing room for the transport id that a
        // HEADER_2 packet carries before its destination hash.
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let mut raw = [0u8; 20];
        raw[0] = flags.pack();
        assert!(matches!(
            RawPacket::unpack(&raw),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        // A 500-byte payload plus the 19-byte HEADER_1 header is expected to exceed MTU.
        let data = [0u8; 500];
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_pack_mtu_boundary() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let data = [0u8; 5];
        // HEADER_1: 2 (flags, hops) + 16 (destination) + 1 (context) + 5 (data) = 24 bytes.
        assert!(RawPacket::pack_with_max_mtu(flags, 0, &[0; 16], None, 0, &data, 24).is_ok());
        assert!(matches!(
            RawPacket::pack_with_max_mtu(flags, 0, &[0; 16], None, 0, &data, 23),
            Err(PacketError::ExceedsMtu)
        ));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_pack_raw_with_hash_matches_pack() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let dest_hash = [0x44; 16];
        let transport_id = [0x55; 16];

        let pkt = RawPacket::pack(
            flags,
            1,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"abc",
        )
        .unwrap();
        let (raw, packet_hash) = RawPacket::pack_raw_with_hash(
            flags,
            1,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"abc",
        )
        .unwrap();

        assert_eq!(raw, pkt.raw);
        assert_eq!(packet_hash, pkt.packet_hash);
        assert_eq!(pkt.get_hash(), pkt.packet_hash);
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.packet_hash[..16]);
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test").unwrap();
        let hashable = pkt.get_hashable_part();

        // Upper flag bits are masked off; the hops byte is skipped.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // Upper flag bits are masked off; hops and the 16-byte transport id are skipped.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }
}