Skip to main content

rns_core/
packet.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while packing or unpacking a [`RawPacket`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The byte slice is shorter than the minimum length for its header type.
    TooShort,
    /// The packed packet would be longer than the allowed MTU.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport ID.
    MissingTransportId,
    /// The header type is neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}

impl fmt::Display for PacketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            PacketError::TooShort => "Packet too short",
            PacketError::ExceedsMtu => "Packet exceeds MTU",
            PacketError::MissingTransportId => "HEADER_2 requires transport_id",
            PacketError::InvalidHeaderType => "Invalid header type",
        };
        f.write_str(message)
    }
}
25
26// =============================================================================
27// PacketFlags: packs 5 fields into one byte
28// =============================================================================
29
/// The five header fields packed into the first byte of every packet.
///
/// Only the low bits of each field are significant: one bit for
/// `header_type`, `context_flag` and `transport_type`, two bits for
/// `destination_type` and `packet_type`. [`PacketFlags::pack`] discards any
/// higher bits so an out-of-range value cannot corrupt a neighbouring field.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header layout selector, stored in bit 6.
    pub header_type: u8,
    /// Context flag, stored in bit 5.
    pub context_flag: u8,
    /// Transport type, stored in bit 4.
    pub transport_type: u8,
    /// Destination type, stored in bits 3-2.
    pub destination_type: u8,
    /// Packet type, stored in bits 1-0.
    pub packet_type: u8,
}

impl PacketFlags {
    /// Pack fields into a single flags byte.
    ///
    /// Bit layout:
    /// ```text
    /// Bit 7:     unused (always 0)
    /// Bit 6:     header_type (1 bit)
    /// Bit 5:     context_flag (1 bit)
    /// Bit 4:     transport_type (1 bit)
    /// Bits 3-2:  destination_type (2 bits)
    /// Bits 1-0:  packet_type (2 bits)
    /// ```
    ///
    /// Each field is masked to its width before shifting, so for example a
    /// `destination_type` of `0b111` contributes only `0b11` and never spills
    /// into `transport_type`.
    pub fn pack(&self) -> u8 {
        ((self.header_type & 0b1) << 6)
            | ((self.context_flag & 0b1) << 5)
            | ((self.transport_type & 0b1) << 4)
            | ((self.destination_type & 0b11) << 2)
            | (self.packet_type & 0b11)
    }

    /// Unpack a flags byte into fields.
    ///
    /// Bit 7 is ignored, so `PacketFlags::unpack(b).pack() == b & 0x7F` for
    /// every byte `b`.
    pub fn unpack(byte: u8) -> Self {
        PacketFlags {
            header_type: (byte & 0b0100_0000) >> 6,
            context_flag: (byte & 0b0010_0000) >> 5,
            transport_type: (byte & 0b0001_0000) >> 4,
            destination_type: (byte & 0b0000_1100) >> 2,
            packet_type: byte & 0b0000_0011,
        }
    }
}
69
70// =============================================================================
71// RawPacket: wire-level packet representation
72// =============================================================================
73
74#[derive(Debug, Clone)]
75pub struct RawPacket {
76    pub flags: PacketFlags,
77    pub hops: u8,
78    pub transport_id: Option<[u8; 16]>,
79    pub destination_hash: [u8; 16],
80    pub context: u8,
81    pub data: Vec<u8>,
82    pub raw: Vec<u8>,
83    pub packet_hash: [u8; 32],
84    pub rssi: Option<i16>,
85    pub snr: Option<f32>,
86}
87
88impl RawPacket {
89    /// Pack fields into raw bytes.
90    pub fn pack(
91        flags: PacketFlags,
92        hops: u8,
93        destination_hash: &[u8; 16],
94        transport_id: Option<&[u8; 16]>,
95        context: u8,
96        data: &[u8],
97    ) -> Result<Self, PacketError> {
98        Self::pack_with_max_mtu(
99            flags,
100            hops,
101            destination_hash,
102            transport_id,
103            context,
104            data,
105            constants::MTU,
106        )
107    }
108
109    /// Pack fields into raw bytes and packet hash without constructing a full RawPacket.
110    pub fn pack_raw_with_hash(
111        flags: PacketFlags,
112        hops: u8,
113        destination_hash: &[u8; 16],
114        transport_id: Option<&[u8; 16]>,
115        context: u8,
116        data: &[u8],
117    ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
118        Self::pack_raw_with_hash_with_max_mtu(
119            flags,
120            hops,
121            destination_hash,
122            transport_id,
123            context,
124            data,
125            constants::MTU,
126        )
127    }
128
129    /// Pack fields into raw bytes with a caller-provided MTU limit.
130    pub fn pack_with_max_mtu(
131        flags: PacketFlags,
132        hops: u8,
133        destination_hash: &[u8; 16],
134        transport_id: Option<&[u8; 16]>,
135        context: u8,
136        data: &[u8],
137        max_mtu: usize,
138    ) -> Result<Self, PacketError> {
139        let (raw, packet_hash) = Self::pack_raw_with_hash_with_max_mtu(
140            flags,
141            hops,
142            destination_hash,
143            transport_id,
144            context,
145            data,
146            max_mtu,
147        )?;
148
149        Ok(RawPacket {
150            flags,
151            hops,
152            transport_id: transport_id.copied(),
153            destination_hash: *destination_hash,
154            context,
155            data: data.to_vec(),
156            raw,
157            packet_hash,
158            rssi: None,
159            snr: None,
160        })
161    }
162
163    /// Pack fields into raw bytes and packet hash with a caller-provided MTU limit.
164    pub fn pack_raw_with_hash_with_max_mtu(
165        flags: PacketFlags,
166        hops: u8,
167        destination_hash: &[u8; 16],
168        transport_id: Option<&[u8; 16]>,
169        context: u8,
170        data: &[u8],
171        max_mtu: usize,
172    ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
173        if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
174            return Err(PacketError::MissingTransportId);
175        }
176
177        let mut raw = Vec::new();
178        raw.push(flags.pack());
179        raw.push(hops);
180
181        if let Some(transport_id) = transport_id {
182            if flags.header_type == constants::HEADER_2 {
183                raw.extend_from_slice(transport_id);
184            }
185        }
186
187        raw.extend_from_slice(destination_hash);
188        raw.push(context);
189        raw.extend_from_slice(data);
190
191        if raw.len() > max_mtu {
192            return Err(PacketError::ExceedsMtu);
193        }
194
195        let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
196        Ok((raw, packet_hash))
197    }
198
199    /// Unpack raw bytes into fields.
200    pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
201        if raw.len() < constants::HEADER_MINSIZE {
202            return Err(PacketError::TooShort);
203        }
204
205        let flags = PacketFlags::unpack(raw[0]);
206        let hops = raw[1];
207
208        let dst_len = constants::TRUNCATED_HASHLENGTH / 8; // 16
209
210        if flags.header_type == constants::HEADER_2 {
211            // HEADER_2: [flags:1][hops:1][transport_id:16][dest_hash:16][context:1][data:*]
212            let min_len = 2 + dst_len * 2 + 1;
213            if raw.len() < min_len {
214                return Err(PacketError::TooShort);
215            }
216
217            let mut transport_id = [0u8; 16];
218            transport_id.copy_from_slice(&raw[2..2 + dst_len]);
219
220            let mut destination_hash = [0u8; 16];
221            destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
222
223            let context = raw[2 + 2 * dst_len];
224            let data = raw[2 + 2 * dst_len + 1..].to_vec();
225
226            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
227
228            Ok(RawPacket {
229                flags,
230                hops,
231                transport_id: Some(transport_id),
232                destination_hash,
233                context,
234                data,
235                raw: raw.to_vec(),
236                packet_hash,
237                rssi: None,
238                snr: None,
239            })
240        } else if flags.header_type == constants::HEADER_1 {
241            // HEADER_1: [flags:1][hops:1][dest_hash:16][context:1][data:*]
242            let min_len = 2 + dst_len + 1;
243            if raw.len() < min_len {
244                return Err(PacketError::TooShort);
245            }
246
247            let mut destination_hash = [0u8; 16];
248            destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
249
250            let context = raw[2 + dst_len];
251            let data = raw[2 + dst_len + 1..].to_vec();
252
253            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
254
255            Ok(RawPacket {
256                flags,
257                hops,
258                transport_id: None,
259                destination_hash,
260                context,
261                data,
262                raw: raw.to_vec(),
263                packet_hash,
264                rssi: None,
265                snr: None,
266            })
267        } else {
268            Err(PacketError::InvalidHeaderType)
269        }
270    }
271
272    /// Get the hashable part of the packet.
273    ///
274    /// From Python Packet.py:354-361:
275    /// - Take raw[0] & 0x0F (mask out upper 4 bits of flags)
276    /// - For HEADER_1: append raw[2:]
277    /// - For HEADER_2: skip transport_id: append raw[18:]
278    pub fn get_hashable_part(&self) -> Vec<u8> {
279        Self::compute_hashable_part(self.flags.header_type, &self.raw)
280    }
281
282    fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
283        let mut hashable = Vec::new();
284        hashable.push(raw[0] & 0b00001111);
285        if header_type == constants::HEADER_2 {
286            // Skip transport_id: raw[2..18] is transport_id (16 bytes)
287            hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
288        } else {
289            hashable.extend_from_slice(&raw[2..]);
290        }
291        hashable
292    }
293
294    /// Full SHA-256 hash of the hashable part.
295    pub fn get_hash(&self) -> [u8; 32] {
296        self.packet_hash
297    }
298
299    /// Truncated hash (first 16 bytes) of the hashable part.
300    pub fn get_truncated_hash(&self) -> [u8; 16] {
301        let mut result = [0u8; 16];
302        result.copy_from_slice(&self.packet_hash[..16]);
303        result
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// HEADER_1 / broadcast / single-destination / data flags.
    fn header1_data_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_1,
            context_flag,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        }
    }

    /// HEADER_2 / transport / single-destination / announce flags.
    fn header2_announce_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_2,
            context_flag,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        }
    }

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        assert_eq!(header1_data_flags(constants::FLAG_UNSET).pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        // 0b01010001 = 0x51
        assert_eq!(header2_announce_flags(constants::FLAG_UNSET).pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            assert_eq!(PacketFlags::unpack(byte).pack(), byte);
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            0,
            &dest_hash,
            None,
            constants::CONTEXT_NONE,
            b"hello",
        )
        .unwrap();

        // Layout: [flags:1][hops:1][dest:16][context:1][data:5] = 24 bytes
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00); // flags
        assert_eq!(pkt.raw[1], 0x00); // hops
        assert_eq!(&pkt.raw[2..18], &dest_hash); // dest hash
        assert_eq!(pkt.raw[18], 0x00); // context
        assert_eq!(&pkt.raw[19..], b"hello"); // data
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let flags = header2_announce_flags(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"world",
        )
        .unwrap();

        // Layout: [flags:1][hops:1][transport:16][dest:16][context:1][data:5] = 40 bytes
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = header1_data_flags(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = header2_announce_flags(constants::FLAG_SET);

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0u8; 0]),
            Err(PacketError::TooShort)
        ));
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_truncated_header2() {
        let pkt = RawPacket::pack(
            header2_announce_flags(constants::FLAG_UNSET),
            0,
            &[0x01; 16],
            Some(&[0x02; 16]),
            constants::CONTEXT_NONE,
            b"",
        )
        .unwrap();
        // Drop the context byte so the HEADER_2 fixed header is incomplete.
        let truncated = &pkt.raw[..pkt.raw.len() - 1];
        assert!(matches!(
            RawPacket::unpack(truncated),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let data = [0u8; 500]; // way too much data
        let result = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            0,
            &[0; 16],
            None,
            0,
            &data,
        );
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_pack_with_max_mtu_boundary() {
        let flags = header1_data_flags(constants::FLAG_UNSET);
        // [flags:1][hops:1][dest:16][context:1][data:4] = 23 bytes
        let fits = RawPacket::pack_with_max_mtu(flags, 0, &[0; 16], None, 0, b"data", 23);
        assert_eq!(fits.unwrap().raw.len(), 23);

        let too_big = RawPacket::pack_with_max_mtu(flags, 0, &[0; 16], None, 0, b"data", 22);
        assert!(matches!(too_big, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let result = RawPacket::pack(
            header2_announce_flags(constants::FLAG_UNSET),
            0,
            &[0; 16],
            None,
            0,
            b"data",
        );
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_pack_raw_with_hash_matches_pack() {
        let flags = header2_announce_flags(constants::FLAG_SET);
        let dest_hash = [0x44; 16];
        let transport_id = [0x55; 16];

        let pkt = RawPacket::pack(
            flags,
            1,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"payload",
        )
        .unwrap();
        let (raw, packet_hash) = RawPacket::pack_raw_with_hash(
            flags,
            1,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"payload",
        )
        .unwrap();

        assert_eq!(raw, pkt.raw);
        assert_eq!(packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_packet_hash_is_hash_of_hashable_part() {
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            0,
            &[0x66; 16],
            None,
            constants::CONTEXT_NONE,
            b"hash me",
        )
        .unwrap();

        assert_eq!(pkt.get_hash(), hash::full_hash(&pkt.get_hashable_part()));
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.get_hash()[..16]);
    }

    #[test]
    fn test_hash_ignores_hops_and_transport_id() {
        let flags = header2_announce_flags(constants::FLAG_UNSET);
        let dest_hash = [0x77; 16];

        let first = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&[0x01; 16]),
            constants::CONTEXT_NONE,
            b"x",
        )
        .unwrap();
        let second = RawPacket::pack(
            flags,
            9,
            &dest_hash,
            Some(&[0x02; 16]),
            constants::CONTEXT_NONE,
            b"x",
        )
        .unwrap();

        assert_ne!(first.raw, second.raw);
        assert_eq!(first.packet_hash, second.packet_hash);
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_SET),
            0,
            &[0xCC; 16],
            None,
            constants::CONTEXT_NONE,
            b"test",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte should have upper 4 bits masked out
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Rest should be raw[2:]
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let pkt = RawPacket::pack(
            header2_announce_flags(constants::FLAG_UNSET),
            0,
            &[0xDD; 16],
            Some(&[0xEE; 16]),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte: flags masked
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Should skip transport_id: raw[18:] = dest_hash + context + data
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }
}