Skip to main content

rns_core/
packet.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while encoding ([`RawPacket::pack`]) or decoding
/// ([`RawPacket::unpack`]) a packet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input is shorter than the fixed header required for its layout.
    TooShort,
    /// The encoded packet would be longer than `constants::MTU` bytes.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport id.
    MissingTransportId,
    /// The header type is neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            PacketError::TooShort => write!(f, "Packet too short"),
19            PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20            PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21            PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22        }
23    }
24}
25
26// =============================================================================
27// PacketFlags: packs 5 fields into one byte
28// =============================================================================
29
/// The five header fields that share the first (flags) byte of a packet.
///
/// Each field is held unshifted in its own `u8`; `PacketFlags::pack` and
/// `PacketFlags::unpack` convert to and from the single wire byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct PacketFlags {
    /// Header layout selector (`constants::HEADER_1` / `constants::HEADER_2`); 1 bit on the wire.
    pub header_type: u8,
    /// Context flag (`constants::FLAG_SET` / `constants::FLAG_UNSET`); 1 bit on the wire.
    pub context_flag: u8,
    /// Transport type (e.g. `constants::TRANSPORT_BROADCAST`); 1 bit on the wire.
    pub transport_type: u8,
    /// Destination type (e.g. `constants::DESTINATION_SINGLE`); 2 bits on the wire.
    pub destination_type: u8,
    /// Packet type (e.g. `constants::PACKET_TYPE_DATA`); 2 bits on the wire.
    pub packet_type: u8,
}
38
39impl PacketFlags {
40    /// Pack fields into a single flags byte.
41    ///
42    /// Bit layout:
43    /// ```text
44    /// Bit 6:     header_type (1 bit)
45    /// Bit 5:     context_flag (1 bit)
46    /// Bit 4:     transport_type (1 bit)
47    /// Bits 3-2:  destination_type (2 bits)
48    /// Bits 1-0:  packet_type (2 bits)
49    /// ```
50    pub fn pack(&self) -> u8 {
51        (self.header_type << 6)
52            | (self.context_flag << 5)
53            | (self.transport_type << 4)
54            | (self.destination_type << 2)
55            | self.packet_type
56    }
57
58    /// Unpack a flags byte into fields.
59    pub fn unpack(byte: u8) -> Self {
60        PacketFlags {
61            header_type: (byte & 0b01000000) >> 6,
62            context_flag: (byte & 0b00100000) >> 5,
63            transport_type: (byte & 0b00010000) >> 4,
64            destination_type: (byte & 0b00001100) >> 2,
65            packet_type: byte & 0b00000011,
66        }
67    }
68}
69
70// =============================================================================
71// RawPacket: wire-level packet representation
72// =============================================================================
73
74#[derive(Debug, Clone)]
75pub struct RawPacket {
76    pub flags: PacketFlags,
77    pub hops: u8,
78    pub transport_id: Option<[u8; 16]>,
79    pub destination_hash: [u8; 16],
80    pub context: u8,
81    pub data: Vec<u8>,
82    pub raw: Vec<u8>,
83    pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87    /// Pack fields into raw bytes.
88    pub fn pack(
89        flags: PacketFlags,
90        hops: u8,
91        destination_hash: &[u8; 16],
92        transport_id: Option<&[u8; 16]>,
93        context: u8,
94        data: &[u8],
95    ) -> Result<Self, PacketError> {
96        if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
97            return Err(PacketError::MissingTransportId);
98        }
99
100        let mut raw = Vec::new();
101        raw.push(flags.pack());
102        raw.push(hops);
103
104        if flags.header_type == constants::HEADER_2 {
105            raw.extend_from_slice(transport_id.unwrap());
106        }
107
108        raw.extend_from_slice(destination_hash);
109        raw.push(context);
110        raw.extend_from_slice(data);
111
112        if raw.len() > constants::MTU {
113            return Err(PacketError::ExceedsMtu);
114        }
115
116        let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
117
118        Ok(RawPacket {
119            flags,
120            hops,
121            transport_id: transport_id.copied(),
122            destination_hash: *destination_hash,
123            context,
124            data: data.to_vec(),
125            raw,
126            packet_hash,
127        })
128    }
129
130    /// Unpack raw bytes into fields.
131    pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
132        if raw.len() < constants::HEADER_MINSIZE {
133            return Err(PacketError::TooShort);
134        }
135
136        let flags = PacketFlags::unpack(raw[0]);
137        let hops = raw[1];
138
139        let dst_len = constants::TRUNCATED_HASHLENGTH / 8; // 16
140
141        if flags.header_type == constants::HEADER_2 {
142            // HEADER_2: [flags:1][hops:1][transport_id:16][dest_hash:16][context:1][data:*]
143            let min_len = 2 + dst_len * 2 + 1;
144            if raw.len() < min_len {
145                return Err(PacketError::TooShort);
146            }
147
148            let mut transport_id = [0u8; 16];
149            transport_id.copy_from_slice(&raw[2..2 + dst_len]);
150
151            let mut destination_hash = [0u8; 16];
152            destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
153
154            let context = raw[2 + 2 * dst_len];
155            let data = raw[2 + 2 * dst_len + 1..].to_vec();
156
157            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
158
159            Ok(RawPacket {
160                flags,
161                hops,
162                transport_id: Some(transport_id),
163                destination_hash,
164                context,
165                data,
166                raw: raw.to_vec(),
167                packet_hash,
168            })
169        } else if flags.header_type == constants::HEADER_1 {
170            // HEADER_1: [flags:1][hops:1][dest_hash:16][context:1][data:*]
171            let min_len = 2 + dst_len + 1;
172            if raw.len() < min_len {
173                return Err(PacketError::TooShort);
174            }
175
176            let mut destination_hash = [0u8; 16];
177            destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
178
179            let context = raw[2 + dst_len];
180            let data = raw[2 + dst_len + 1..].to_vec();
181
182            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
183
184            Ok(RawPacket {
185                flags,
186                hops,
187                transport_id: None,
188                destination_hash,
189                context,
190                data,
191                raw: raw.to_vec(),
192                packet_hash,
193            })
194        } else {
195            Err(PacketError::InvalidHeaderType)
196        }
197    }
198
199    /// Get the hashable part of the packet.
200    ///
201    /// From Python Packet.py:354-361:
202    /// - Take raw[0] & 0x0F (mask out upper 4 bits of flags)
203    /// - For HEADER_1: append raw[2:]
204    /// - For HEADER_2: skip transport_id: append raw[18:]
205    pub fn get_hashable_part(&self) -> Vec<u8> {
206        Self::compute_hashable_part(self.flags.header_type, &self.raw)
207    }
208
209    fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
210        let mut hashable = Vec::new();
211        hashable.push(raw[0] & 0b00001111);
212        if header_type == constants::HEADER_2 {
213            // Skip transport_id: raw[2..18] is transport_id (16 bytes)
214            hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
215        } else {
216            hashable.extend_from_slice(&raw[2..]);
217        }
218        hashable
219    }
220
221    /// Full SHA-256 hash of the hashable part.
222    pub fn get_hash(&self) -> [u8; 32] {
223        self.packet_hash
224    }
225
226    /// Truncated hash (first 16 bytes) of the hashable part.
227    pub fn get_truncated_hash(&self) -> [u8; 16] {
228        let mut result = [0u8; 16];
229        result.copy_from_slice(&self.packet_hash[..16]);
230        result
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        assert_eq!(flags.pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        // 0b01010001 = 0x51
        assert_eq!(flags.pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_flags_unpack_ignores_bit7() {
        // Bit 7 is not part of any PacketFlags field.
        assert_eq!(PacketFlags::unpack(0xFF), PacketFlags::unpack(0x7F));
        assert_eq!(PacketFlags::unpack(0x80), PacketFlags::unpack(0x00));
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data).unwrap();

        // Verify layout: [flags:1][hops:1][dest:16][context:1][data:5] = 24 bytes
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00); // flags
        assert_eq!(pkt.raw[1], 0x00); // hops
        assert_eq!(&pkt.raw[2..18], &dest_hash); // dest hash
        assert_eq!(pkt.raw[18], 0x00); // context
        assert_eq!(&pkt.raw[19..], b"hello"); // data
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // Layout: [flags:1][hops:1][transport:16][dest:16][context:1][data:5] = 40 bytes
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_header2_too_short() {
        // A HEADER_2 flags byte followed by too few bytes for
        // [hops][transport_id][dest_hash][context].
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let mut raw = [0u8; 20];
        raw[0] = flags.pack();
        assert!(matches!(RawPacket::unpack(&raw), Err(PacketError::TooShort)));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let data = [0u8; 500]; // way too much data
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test").unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte should have upper 4 bits masked out
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Rest should be raw[2:]
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte: flags masked
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Should skip transport_id: raw[18:] = dest_hash + context + data
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_truncated_hash_is_prefix_of_full_hash() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let pkt =
            RawPacket::pack(flags, 0, &[0x5A; 16], None, constants::CONTEXT_NONE, b"x").unwrap();
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.get_hash()[..16]);
    }
}