1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while building ([`RawPacket::pack`]) or parsing
/// ([`RawPacket::unpack`]) a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input is shorter than the minimum length required for its header type.
    TooShort,
    /// The serialised packet would be longer than `constants::MTU` bytes.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport ID.
    MissingTransportId,
    /// The header-type field matches neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 match self {
18 PacketError::TooShort => write!(f, "Packet too short"),
19 PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20 PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21 PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22 }
23 }
24}
25
/// The five header sub-fields packed into the first byte of a packet.
///
/// Layout of the packed byte (bit 7 is the most significant):
///
/// | bits | field              | width |
/// |------|--------------------|-------|
/// | 7    | (unused by this type) | 1  |
/// | 6    | `header_type`      | 1     |
/// | 5    | `context_flag`     | 1     |
/// | 4    | `transport_type`   | 1     |
/// | 3-2  | `destination_type` | 2     |
/// | 1-0  | `packet_type`      | 2     |
///
/// Each field is held in its own `u8`, but only its low `width` bits are
/// meaningful.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header format selector (1 bit), e.g. `constants::HEADER_1` / `constants::HEADER_2`.
    pub header_type: u8,
    /// Context flag (1 bit).
    pub context_flag: u8,
    /// Transport type (1 bit).
    pub transport_type: u8,
    /// Destination type (2 bits).
    pub destination_type: u8,
    /// Packet type (2 bits).
    pub packet_type: u8,
}

impl PacketFlags {
    /// Packs the fields into a single header byte.
    ///
    /// Each field is masked to its bit width before it is shifted into place.
    /// An out-of-range value therefore cannot spill into a neighbouring field
    /// or into bit 7, and bit 7 of the result is always clear.
    pub fn pack(&self) -> u8 {
        ((self.header_type & 0b1) << 6)
            | ((self.context_flag & 0b1) << 5)
            | ((self.transport_type & 0b1) << 4)
            | ((self.destination_type & 0b11) << 2)
            | (self.packet_type & 0b11)
    }

    /// Decodes a header byte into its fields. Bit 7 is ignored.
    ///
    /// For every `byte` in `0..=0x7F`, `PacketFlags::unpack(byte).pack() == byte`.
    pub fn unpack(byte: u8) -> Self {
        PacketFlags {
            header_type: (byte & 0b0100_0000) >> 6,
            context_flag: (byte & 0b0010_0000) >> 5,
            transport_type: (byte & 0b0001_0000) >> 4,
            destination_type: (byte & 0b0000_1100) >> 2,
            packet_type: byte & 0b0000_0011,
        }
    }
}
69
70#[derive(Debug, Clone)]
75pub struct RawPacket {
76 pub flags: PacketFlags,
77 pub hops: u8,
78 pub transport_id: Option<[u8; 16]>,
79 pub destination_hash: [u8; 16],
80 pub context: u8,
81 pub data: Vec<u8>,
82 pub raw: Vec<u8>,
83 pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87 pub fn pack(
89 flags: PacketFlags,
90 hops: u8,
91 destination_hash: &[u8; 16],
92 transport_id: Option<&[u8; 16]>,
93 context: u8,
94 data: &[u8],
95 ) -> Result<Self, PacketError> {
96 if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
97 return Err(PacketError::MissingTransportId);
98 }
99
100 let mut raw = Vec::new();
101 raw.push(flags.pack());
102 raw.push(hops);
103
104 if flags.header_type == constants::HEADER_2 {
105 raw.extend_from_slice(transport_id.unwrap());
106 }
107
108 raw.extend_from_slice(destination_hash);
109 raw.push(context);
110 raw.extend_from_slice(data);
111
112 if raw.len() > constants::MTU {
113 return Err(PacketError::ExceedsMtu);
114 }
115
116 let packet_hash = hash::full_hash(&Self::compute_hashable_part(
117 flags.header_type,
118 &raw,
119 ));
120
121 Ok(RawPacket {
122 flags,
123 hops,
124 transport_id: transport_id.copied(),
125 destination_hash: *destination_hash,
126 context,
127 data: data.to_vec(),
128 raw,
129 packet_hash,
130 })
131 }
132
133 pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
135 if raw.len() < constants::HEADER_MINSIZE {
136 return Err(PacketError::TooShort);
137 }
138
139 let flags = PacketFlags::unpack(raw[0]);
140 let hops = raw[1];
141
142 let dst_len = constants::TRUNCATED_HASHLENGTH / 8; if flags.header_type == constants::HEADER_2 {
145 let min_len = 2 + dst_len * 2 + 1;
147 if raw.len() < min_len {
148 return Err(PacketError::TooShort);
149 }
150
151 let mut transport_id = [0u8; 16];
152 transport_id.copy_from_slice(&raw[2..2 + dst_len]);
153
154 let mut destination_hash = [0u8; 16];
155 destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
156
157 let context = raw[2 + 2 * dst_len];
158 let data = raw[2 + 2 * dst_len + 1..].to_vec();
159
160 let packet_hash = hash::full_hash(&Self::compute_hashable_part(
161 flags.header_type,
162 raw,
163 ));
164
165 Ok(RawPacket {
166 flags,
167 hops,
168 transport_id: Some(transport_id),
169 destination_hash,
170 context,
171 data,
172 raw: raw.to_vec(),
173 packet_hash,
174 })
175 } else if flags.header_type == constants::HEADER_1 {
176 let min_len = 2 + dst_len + 1;
178 if raw.len() < min_len {
179 return Err(PacketError::TooShort);
180 }
181
182 let mut destination_hash = [0u8; 16];
183 destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
184
185 let context = raw[2 + dst_len];
186 let data = raw[2 + dst_len + 1..].to_vec();
187
188 let packet_hash = hash::full_hash(&Self::compute_hashable_part(
189 flags.header_type,
190 raw,
191 ));
192
193 Ok(RawPacket {
194 flags,
195 hops,
196 transport_id: None,
197 destination_hash,
198 context,
199 data,
200 raw: raw.to_vec(),
201 packet_hash,
202 })
203 } else {
204 Err(PacketError::InvalidHeaderType)
205 }
206 }
207
208 pub fn get_hashable_part(&self) -> Vec<u8> {
215 Self::compute_hashable_part(self.flags.header_type, &self.raw)
216 }
217
218 fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
219 let mut hashable = Vec::new();
220 hashable.push(raw[0] & 0b00001111);
221 if header_type == constants::HEADER_2 {
222 hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
224 } else {
225 hashable.extend_from_slice(&raw[2..]);
226 }
227 hashable
228 }
229
230 pub fn get_hash(&self) -> [u8; 32] {
232 self.packet_hash
233 }
234
235 pub fn get_truncated_hash(&self) -> [u8; 16] {
237 let mut result = [0u8; 16];
238 result.copy_from_slice(&self.packet_hash[..16]);
239 result
240 }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags for a HEADER_1 broadcast data packet to a single destination.
    fn header1_data_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_1,
            context_flag,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        }
    }

    /// Flags for a HEADER_2 transported announce to a single destination.
    fn header2_announce_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_2,
            context_flag,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        }
    }

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        assert_eq!(header1_data_flags(constants::FLAG_UNSET).pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        assert_eq!(header2_announce_flags(constants::FLAG_UNSET).pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let flags = header1_data_flags(constants::FLAG_UNSET);

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"hello")
                .unwrap();

        // flags (1) + hops (1) + destination (16) + context (1) + "hello" (5)
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00);
        assert_eq!(pkt.raw[1], 0x00);
        assert_eq!(&pkt.raw[2..18], &dest_hash);
        assert_eq!(pkt.raw[18], 0x00);
        assert_eq!(&pkt.raw[19..], b"hello");
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let flags = header2_announce_flags(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"world",
        )
        .unwrap();

        // flags (1) + hops (1) + transport (16) + destination (16) + context (1) + "world" (5)
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = header1_data_flags(constants::FLAG_UNSET);

        let pkt =
            RawPacket::pack(flags, 5, &dest_hash, None, constants::CONTEXT_RESOURCE, data)
                .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = header2_announce_flags(constants::FLAG_SET);

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_header2_truncated_is_too_short() {
        let pkt = RawPacket::pack(
            header2_announce_flags(constants::FLAG_UNSET),
            0,
            &[0x66; 16],
            Some(&[0x77; 16]),
            constants::CONTEXT_NONE,
            b"",
        )
        .unwrap();

        // With an empty payload the full packet is exactly the fixed header.
        let full = RawPacket::unpack(&pkt.raw).unwrap();
        assert!(full.data.is_empty());

        // Dropping the context byte must be rejected, not panic.
        let truncated = &pkt.raw[..pkt.raw.len() - 1];
        assert!(matches!(
            RawPacket::unpack(truncated),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = header1_data_flags(constants::FLAG_UNSET);
        let data = [0u8; 500];

        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);

        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = header2_announce_flags(constants::FLAG_UNSET);

        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");

        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = header1_data_flags(constants::FLAG_SET);

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test")
                .unwrap();
        let hashable = pkt.get_hashable_part();

        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = header2_announce_flags(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_packet_hash_ignores_hops_and_transport_id() {
        let dest_hash = [0x44; 16];
        let transport_id = [0x55; 16];
        let data = b"payload";
        let h1_flags = header1_data_flags(constants::FLAG_UNSET);
        // Same destination/packet type; differs only in bits masked out of the hash.
        let h2_flags = PacketFlags {
            header_type: constants::HEADER_2,
            transport_type: constants::TRANSPORT_TRANSPORT,
            ..h1_flags
        };

        let direct =
            RawPacket::pack(h1_flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data)
                .unwrap();
        let relayed =
            RawPacket::pack(h1_flags, 7, &dest_hash, None, constants::CONTEXT_NONE, data)
                .unwrap();
        let transported = RawPacket::pack(
            h2_flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        assert_eq!(direct.packet_hash, relayed.packet_hash);
        assert_eq!(direct.packet_hash, transported.packet_hash);
    }

    #[test]
    fn test_hash_accessors() {
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            1,
            &[0x99; 16],
            None,
            constants::CONTEXT_NONE,
            b"abc",
        )
        .unwrap();

        assert_eq!(pkt.get_hash(), pkt.packet_hash);
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.packet_hash[..16]);
    }
}