1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors returned when building ([`RawPacket::pack`]) or parsing ([`RawPacket::unpack`])
/// a packet.
///
/// The enum is `Copy + PartialEq` so callers can compare against a specific variant
/// (e.g. `assert_eq!(err, PacketError::TooShort)`) instead of only checking `is_err()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input buffer is shorter than the header its flags require.
    TooShort,
    /// The packed packet would be longer than the permitted MTU.
    ExceedsMtu,
    /// A HEADER_2 packet was requested but no transport id was supplied.
    MissingTransportId,
    /// The flags byte encodes a header type that is neither HEADER_1 nor HEADER_2.
    InvalidHeaderType,
}

impl fmt::Display for PacketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the static message first, then emit it verbatim.
        let message = match self {
            PacketError::TooShort => "Packet too short",
            PacketError::ExceedsMtu => "Packet exceeds MTU",
            PacketError::MissingTransportId => "HEADER_2 requires transport_id",
            PacketError::InvalidHeaderType => "Invalid header type",
        };
        f.write_str(message)
    }
}
25
/// Decoded form of a packet's first (flags) byte.
///
/// Bit layout, matching the masks used by [`PacketFlags::unpack`]:
///
/// | bits | field              | width |
/// |------|--------------------|-------|
/// | 6    | `header_type`      | 1     |
/// | 5    | `context_flag`     | 1     |
/// | 4    | `transport_type`   | 1     |
/// | 3-2  | `destination_type` | 2     |
/// | 1-0  | `packet_type`      | 2     |
///
/// Bit 7 is not part of these flags; `unpack` ignores it and `pack` never sets it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header format (1 bit).
    pub header_type: u8,
    /// Context flag (1 bit).
    pub context_flag: u8,
    /// Transport type (1 bit).
    pub transport_type: u8,
    /// Destination type (2 bits).
    pub destination_type: u8,
    /// Packet type (2 bits).
    pub packet_type: u8,
}

impl PacketFlags {
    /// Encodes the flags into a single byte.
    ///
    /// Each field is masked to its bit width before shifting, so an out-of-range value
    /// cannot bleed into a neighbouring field. `pack` is therefore the exact inverse of
    /// [`PacketFlags::unpack`] for every byte in `0x00..=0x7F`.
    pub fn pack(&self) -> u8 {
        ((self.header_type & 0b1) << 6)
            | ((self.context_flag & 0b1) << 5)
            | ((self.transport_type & 0b1) << 4)
            | ((self.destination_type & 0b11) << 2)
            | (self.packet_type & 0b11)
    }

    /// Decodes a flags byte into its fields. Bit 7 is ignored.
    pub fn unpack(byte: u8) -> Self {
        PacketFlags {
            header_type: (byte & 0b0100_0000) >> 6,
            context_flag: (byte & 0b0010_0000) >> 5,
            transport_type: (byte & 0b0001_0000) >> 4,
            destination_type: (byte & 0b0000_1100) >> 2,
            packet_type: byte & 0b0000_0011,
        }
    }
}
69
70#[derive(Debug, Clone)]
75pub struct RawPacket {
76 pub flags: PacketFlags,
77 pub hops: u8,
78 pub transport_id: Option<[u8; 16]>,
79 pub destination_hash: [u8; 16],
80 pub context: u8,
81 pub data: Vec<u8>,
82 pub raw: Vec<u8>,
83 pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87 pub fn pack(
89 flags: PacketFlags,
90 hops: u8,
91 destination_hash: &[u8; 16],
92 transport_id: Option<&[u8; 16]>,
93 context: u8,
94 data: &[u8],
95 ) -> Result<Self, PacketError> {
96 Self::pack_with_max_mtu(
97 flags,
98 hops,
99 destination_hash,
100 transport_id,
101 context,
102 data,
103 constants::MTU,
104 )
105 }
106
107 pub fn pack_with_max_mtu(
109 flags: PacketFlags,
110 hops: u8,
111 destination_hash: &[u8; 16],
112 transport_id: Option<&[u8; 16]>,
113 context: u8,
114 data: &[u8],
115 max_mtu: usize,
116 ) -> Result<Self, PacketError> {
117 if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
118 return Err(PacketError::MissingTransportId);
119 }
120
121 let mut raw = Vec::new();
122 raw.push(flags.pack());
123 raw.push(hops);
124
125 if flags.header_type == constants::HEADER_2 {
126 raw.extend_from_slice(transport_id.unwrap());
127 }
128
129 raw.extend_from_slice(destination_hash);
130 raw.push(context);
131 raw.extend_from_slice(data);
132
133 if raw.len() > max_mtu {
134 return Err(PacketError::ExceedsMtu);
135 }
136
137 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
138
139 Ok(RawPacket {
140 flags,
141 hops,
142 transport_id: transport_id.copied(),
143 destination_hash: *destination_hash,
144 context,
145 data: data.to_vec(),
146 raw,
147 packet_hash,
148 })
149 }
150
151 pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
153 if raw.len() < constants::HEADER_MINSIZE {
154 return Err(PacketError::TooShort);
155 }
156
157 let flags = PacketFlags::unpack(raw[0]);
158 let hops = raw[1];
159
160 let dst_len = constants::TRUNCATED_HASHLENGTH / 8; if flags.header_type == constants::HEADER_2 {
163 let min_len = 2 + dst_len * 2 + 1;
165 if raw.len() < min_len {
166 return Err(PacketError::TooShort);
167 }
168
169 let mut transport_id = [0u8; 16];
170 transport_id.copy_from_slice(&raw[2..2 + dst_len]);
171
172 let mut destination_hash = [0u8; 16];
173 destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
174
175 let context = raw[2 + 2 * dst_len];
176 let data = raw[2 + 2 * dst_len + 1..].to_vec();
177
178 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
179
180 Ok(RawPacket {
181 flags,
182 hops,
183 transport_id: Some(transport_id),
184 destination_hash,
185 context,
186 data,
187 raw: raw.to_vec(),
188 packet_hash,
189 })
190 } else if flags.header_type == constants::HEADER_1 {
191 let min_len = 2 + dst_len + 1;
193 if raw.len() < min_len {
194 return Err(PacketError::TooShort);
195 }
196
197 let mut destination_hash = [0u8; 16];
198 destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
199
200 let context = raw[2 + dst_len];
201 let data = raw[2 + dst_len + 1..].to_vec();
202
203 let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
204
205 Ok(RawPacket {
206 flags,
207 hops,
208 transport_id: None,
209 destination_hash,
210 context,
211 data,
212 raw: raw.to_vec(),
213 packet_hash,
214 })
215 } else {
216 Err(PacketError::InvalidHeaderType)
217 }
218 }
219
220 pub fn get_hashable_part(&self) -> Vec<u8> {
227 Self::compute_hashable_part(self.flags.header_type, &self.raw)
228 }
229
230 fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
231 let mut hashable = Vec::new();
232 hashable.push(raw[0] & 0b00001111);
233 if header_type == constants::HEADER_2 {
234 hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
236 } else {
237 hashable.extend_from_slice(&raw[2..]);
238 }
239 hashable
240 }
241
242 pub fn get_hash(&self) -> [u8; 32] {
244 self.packet_hash
245 }
246
247 pub fn get_truncated_hash(&self) -> [u8; 16] {
249 let mut result = [0u8; 16];
250 result.copy_from_slice(&self.packet_hash[..16]);
251 result
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        assert_eq!(flags.pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        assert_eq!(flags.pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_flags_unpack_ignores_bit7() {
        assert_eq!(PacketFlags::unpack(0x80 | 0x51), PacketFlags::unpack(0x51));
        assert_eq!(PacketFlags::unpack(0x80).pack(), 0x00);
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data).unwrap();

        // flags(1) + hops(1) + destination(16) + context(1) + "hello"(5)
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00);
        assert_eq!(pkt.raw[1], 0x00);
        assert_eq!(&pkt.raw[2..18], &dest_hash);
        assert_eq!(pkt.raw[18], 0x00);
        assert_eq!(&pkt.raw[19..], b"hello");
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // flags(1) + hops(1) + transport(16) + destination(16) + context(1) + "world"(5)
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.raw, pkt.raw);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.raw, pkt.raw);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_header2_truncated() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        // flags + hops + transport id + destination hash, but no context byte.
        let mut raw: Vec<u8> = Vec::new();
        raw.resize(2 + 16 + 16, 0);
        raw[0] = flags.pack();
        assert!(matches!(RawPacket::unpack(&raw), Err(PacketError::TooShort)));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        // A payload as large as the MTU can never fit once the header is added.
        let mut data: Vec<u8> = Vec::new();
        data.resize(constants::MTU, 0);
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_pack_with_max_mtu_boundary() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let data = [0x5Au8; 10];
        // flags + hops + destination hash + context + payload
        let exact = 2 + 16 + 1 + data.len();

        let pkt =
            RawPacket::pack_with_max_mtu(flags, 0, &[0; 16], None, 0, &data, exact).unwrap();
        assert_eq!(pkt.raw.len(), exact);

        let result = RawPacket::pack_with_max_mtu(flags, 0, &[0; 16], None, 0, &data, exact - 1);
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test").unwrap();
        let hashable = pkt.get_hashable_part();

        // Only the low nibble of the flags byte is hashed; the hops byte is skipped.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // Flags low nibble, then everything after hops + transport id.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_truncated_hash_is_prefix_of_full_hash() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let pkt = RawPacket::pack(flags, 0, &[0x42; 16], None, 0, b"abc").unwrap();
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.get_hash()[..16]);
    }
}