1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while building or parsing a [`RawPacket`].
///
/// The type is `Copy` and comparable so callers (and tests) can match on or
/// compare a specific failure instead of only checking `is_err()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input is too short to contain every header field.
    TooShort,
    /// The serialised packet would be longer than `constants::MTU` bytes.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport id.
    MissingTransportId,
    /// The header type is neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 match self {
18 PacketError::TooShort => write!(f, "Packet too short"),
19 PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20 PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21 PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22 }
23 }
24}
25
/// Decoded form of a packet's first (flags) byte.
///
/// Bit layout of the packed byte, MSB to LSB: bit 7 is not represented here,
/// bit 6 `header_type`, bit 5 `context_flag`, bit 4 `transport_type`,
/// bits 3-2 `destination_type`, bits 1-0 `packet_type`.
///
/// `Default` yields all-zero fields, i.e. a packed flags byte of `0x00`.
/// `Hash` allows flags to be used as set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct PacketFlags {
    /// Header type (1 bit), compared against `HEADER_1` / `HEADER_2`.
    pub header_type: u8,
    /// Context flag (1 bit).
    pub context_flag: u8,
    /// Transport type (1 bit).
    pub transport_type: u8,
    /// Destination type (2 bits).
    pub destination_type: u8,
    /// Packet type (2 bits).
    pub packet_type: u8,
}
38
39impl PacketFlags {
40 pub fn pack(&self) -> u8 {
51 (self.header_type << 6)
52 | (self.context_flag << 5)
53 | (self.transport_type << 4)
54 | (self.destination_type << 2)
55 | self.packet_type
56 }
57
58 pub fn unpack(byte: u8) -> Self {
60 PacketFlags {
61 header_type: (byte & 0b01000000) >> 6,
62 context_flag: (byte & 0b00100000) >> 5,
63 transport_type: (byte & 0b00010000) >> 4,
64 destination_type: (byte & 0b00001100) >> 2,
65 packet_type: byte & 0b00000011,
66 }
67 }
68}
69
70#[derive(Debug, Clone)]
75pub struct RawPacket {
76 pub flags: PacketFlags,
77 pub hops: u8,
78 pub transport_id: Option<[u8; 16]>,
79 pub destination_hash: [u8; 16],
80 pub context: u8,
81 pub data: Vec<u8>,
82 pub raw: Vec<u8>,
83 pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87 pub fn pack(
89 flags: PacketFlags,
90 hops: u8,
91 destination_hash: &[u8; 16],
92 transport_id: Option<&[u8; 16]>,
93 context: u8,
94 data: &[u8],
95 ) -> Result<Self, PacketError> {
96 if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
97 return Err(PacketError::MissingTransportId);
98 }
99
100 let mut raw = Vec::new();
101 raw.push(flags.pack());
102 raw.push(hops);
103
104 if flags.header_type == constants::HEADER_2 {
105 raw.extend_from_slice(transport_id.unwrap());
106 }
107
108 raw.extend_from_slice(destination_hash);
109 raw.push(context);
110 raw.extend_from_slice(data);
111
112 if raw.len() > constants::MTU {
113 return Err(PacketError::ExceedsMtu);
114 }
115
116 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
117
118 Ok(RawPacket {
119 flags,
120 hops,
121 transport_id: transport_id.copied(),
122 destination_hash: *destination_hash,
123 context,
124 data: data.to_vec(),
125 raw,
126 packet_hash,
127 })
128 }
129
130 pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
132 if raw.len() < constants::HEADER_MINSIZE {
133 return Err(PacketError::TooShort);
134 }
135
136 let flags = PacketFlags::unpack(raw[0]);
137 let hops = raw[1];
138
139 let dst_len = constants::TRUNCATED_HASHLENGTH / 8; if flags.header_type == constants::HEADER_2 {
142 let min_len = 2 + dst_len * 2 + 1;
144 if raw.len() < min_len {
145 return Err(PacketError::TooShort);
146 }
147
148 let mut transport_id = [0u8; 16];
149 transport_id.copy_from_slice(&raw[2..2 + dst_len]);
150
151 let mut destination_hash = [0u8; 16];
152 destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
153
154 let context = raw[2 + 2 * dst_len];
155 let data = raw[2 + 2 * dst_len + 1..].to_vec();
156
157 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
158
159 Ok(RawPacket {
160 flags,
161 hops,
162 transport_id: Some(transport_id),
163 destination_hash,
164 context,
165 data,
166 raw: raw.to_vec(),
167 packet_hash,
168 })
169 } else if flags.header_type == constants::HEADER_1 {
170 let min_len = 2 + dst_len + 1;
172 if raw.len() < min_len {
173 return Err(PacketError::TooShort);
174 }
175
176 let mut destination_hash = [0u8; 16];
177 destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
178
179 let context = raw[2 + dst_len];
180 let data = raw[2 + dst_len + 1..].to_vec();
181
182 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
183
184 Ok(RawPacket {
185 flags,
186 hops,
187 transport_id: None,
188 destination_hash,
189 context,
190 data,
191 raw: raw.to_vec(),
192 packet_hash,
193 })
194 } else {
195 Err(PacketError::InvalidHeaderType)
196 }
197 }
198
199 pub fn get_hashable_part(&self) -> Vec<u8> {
206 Self::compute_hashable_part(self.flags.header_type, &self.raw)
207 }
208
209 fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
210 let mut hashable = Vec::new();
211 hashable.push(raw[0] & 0b00001111);
212 if header_type == constants::HEADER_2 {
213 hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
215 } else {
216 hashable.extend_from_slice(&raw[2..]);
217 }
218 hashable
219 }
220
221 pub fn get_hash(&self) -> [u8; 32] {
223 self.packet_hash
224 }
225
226 pub fn get_truncated_hash(&self) -> [u8; 16] {
228 let mut result = [0u8; 16];
229 result.copy_from_slice(&self.packet_hash[..16]);
230 result
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        assert_eq!(flags.pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        assert_eq!(flags.pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data).unwrap();

        // flags + hops + destination hash + context + payload
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00);
        assert_eq!(pkt.raw[1], 0x00);
        assert_eq!(&pkt.raw[2..18], &dest_hash);
        assert_eq!(pkt.raw[18], 0x00);
        assert_eq!(&pkt.raw[19..], b"hello");
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // flags + hops + transport id + destination hash + context + payload
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_header2_empty_payload_and_truncation() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            0,
            &[0x44; 16],
            Some(&[0x55; 16]),
            constants::CONTEXT_NONE,
            b"",
        )
        .unwrap();
        assert_eq!(pkt.raw.len(), 35);

        // The minimum-length HEADER_2 packet (no payload) must still parse.
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();
        assert_eq!(unpacked.transport_id, Some([0x55; 16]));
        assert_eq!(unpacked.destination_hash, [0x44; 16]);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert!(unpacked.data.is_empty());

        // Dropping the context byte must be reported as too short, not panic.
        let truncated = &pkt.raw[..pkt.raw.len() - 1];
        assert!(matches!(
            RawPacket::unpack(truncated),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let data = [0u8; 500];
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test").unwrap();
        let hashable = pkt.get_hashable_part();

        // Only the low nibble of the flags byte is hashed; hops (raw[1]) is skipped.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // Hops and the 16-byte transport id (raw[1..18]) are excluded from the hash.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_truncated_hash_is_prefix_of_full_hash() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &[0x77; 16], None, constants::CONTEXT_NONE, b"x").unwrap();

        assert_eq!(pkt.get_hash(), pkt.packet_hash);
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.get_hash()[..16]);
    }
}