Skip to main content

rns_core/
packet.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while packing or unpacking a [`RawPacket`].
///
/// The variants carry no data, so the type is cheap to copy and can be
/// compared directly (e.g. `assert_eq!(err, PacketError::TooShort)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input is shorter than the header its flags require.
    TooShort,
    /// The packed packet is longer than `constants::MTU`.
    ExceedsMtu,
    /// A HEADER_2 packet was requested without a transport id.
    MissingTransportId,
    /// The header-type flag is neither HEADER_1 nor HEADER_2.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            PacketError::TooShort => write!(f, "Packet too short"),
19            PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20            PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21            PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22        }
23    }
24}
25
26// =============================================================================
27// PacketFlags: packs 5 fields into one byte
28// =============================================================================
29
/// The five header fields that share a packet's first (flags) byte.
///
/// `header_type`, `context_flag` and `transport_type` are 1-bit fields;
/// `destination_type` and `packet_type` are 2-bit fields. Bit 7 of the flags
/// byte is never produced by [`PacketFlags::pack`] and is ignored by
/// [`PacketFlags::unpack`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header format (bit 6); compared against `constants::HEADER_1` / `HEADER_2`.
    pub header_type: u8,
    /// Context flag (bit 5).
    pub context_flag: u8,
    /// Transport type (bit 4).
    pub transport_type: u8,
    /// Destination type (bits 3-2).
    pub destination_type: u8,
    /// Packet type (bits 1-0).
    pub packet_type: u8,
}

impl PacketFlags {
    /// Pack fields into a single flags byte.
    ///
    /// Bit layout:
    /// ```text
    /// Bit 6:     header_type (1 bit)
    /// Bit 5:     context_flag (1 bit)
    /// Bit 4:     transport_type (1 bit)
    /// Bits 3-2:  destination_type (2 bits)
    /// Bits 1-0:  packet_type (2 bits)
    /// ```
    ///
    /// Each field is masked to its width first, so an out-of-range value is
    /// truncated instead of corrupting a neighbouring field or setting bit 7.
    pub fn pack(&self) -> u8 {
        ((self.header_type & 0b1) << 6)
            | ((self.context_flag & 0b1) << 5)
            | ((self.transport_type & 0b1) << 4)
            | ((self.destination_type & 0b11) << 2)
            | (self.packet_type & 0b11)
    }

    /// Unpack a flags byte into fields.
    ///
    /// Bit 7 is not part of any field and is discarded.
    pub fn unpack(byte: u8) -> Self {
        PacketFlags {
            header_type: (byte & 0b0100_0000) >> 6,
            context_flag: (byte & 0b0010_0000) >> 5,
            transport_type: (byte & 0b0001_0000) >> 4,
            destination_type: (byte & 0b0000_1100) >> 2,
            packet_type: byte & 0b0000_0011,
        }
    }
}
69
70// =============================================================================
71// RawPacket: wire-level packet representation
72// =============================================================================
73
74#[derive(Debug, Clone)]
75pub struct RawPacket {
76    pub flags: PacketFlags,
77    pub hops: u8,
78    pub transport_id: Option<[u8; 16]>,
79    pub destination_hash: [u8; 16],
80    pub context: u8,
81    pub data: Vec<u8>,
82    pub raw: Vec<u8>,
83    pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87    /// Pack fields into raw bytes.
88    pub fn pack(
89        flags: PacketFlags,
90        hops: u8,
91        destination_hash: &[u8; 16],
92        transport_id: Option<&[u8; 16]>,
93        context: u8,
94        data: &[u8],
95    ) -> Result<Self, PacketError> {
96        if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
97            return Err(PacketError::MissingTransportId);
98        }
99
100        let mut raw = Vec::new();
101        raw.push(flags.pack());
102        raw.push(hops);
103
104        if flags.header_type == constants::HEADER_2 {
105            raw.extend_from_slice(transport_id.unwrap());
106        }
107
108        raw.extend_from_slice(destination_hash);
109        raw.push(context);
110        raw.extend_from_slice(data);
111
112        if raw.len() > constants::MTU {
113            return Err(PacketError::ExceedsMtu);
114        }
115
116        let packet_hash = hash::full_hash(&Self::compute_hashable_part(
117            flags.header_type,
118            &raw,
119        ));
120
121        Ok(RawPacket {
122            flags,
123            hops,
124            transport_id: transport_id.copied(),
125            destination_hash: *destination_hash,
126            context,
127            data: data.to_vec(),
128            raw,
129            packet_hash,
130        })
131    }
132
133    /// Unpack raw bytes into fields.
134    pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
135        if raw.len() < constants::HEADER_MINSIZE {
136            return Err(PacketError::TooShort);
137        }
138
139        let flags = PacketFlags::unpack(raw[0]);
140        let hops = raw[1];
141
142        let dst_len = constants::TRUNCATED_HASHLENGTH / 8; // 16
143
144        if flags.header_type == constants::HEADER_2 {
145            // HEADER_2: [flags:1][hops:1][transport_id:16][dest_hash:16][context:1][data:*]
146            let min_len = 2 + dst_len * 2 + 1;
147            if raw.len() < min_len {
148                return Err(PacketError::TooShort);
149            }
150
151            let mut transport_id = [0u8; 16];
152            transport_id.copy_from_slice(&raw[2..2 + dst_len]);
153
154            let mut destination_hash = [0u8; 16];
155            destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
156
157            let context = raw[2 + 2 * dst_len];
158            let data = raw[2 + 2 * dst_len + 1..].to_vec();
159
160            let packet_hash = hash::full_hash(&Self::compute_hashable_part(
161                flags.header_type,
162                raw,
163            ));
164
165            Ok(RawPacket {
166                flags,
167                hops,
168                transport_id: Some(transport_id),
169                destination_hash,
170                context,
171                data,
172                raw: raw.to_vec(),
173                packet_hash,
174            })
175        } else if flags.header_type == constants::HEADER_1 {
176            // HEADER_1: [flags:1][hops:1][dest_hash:16][context:1][data:*]
177            let min_len = 2 + dst_len + 1;
178            if raw.len() < min_len {
179                return Err(PacketError::TooShort);
180            }
181
182            let mut destination_hash = [0u8; 16];
183            destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
184
185            let context = raw[2 + dst_len];
186            let data = raw[2 + dst_len + 1..].to_vec();
187
188            let packet_hash = hash::full_hash(&Self::compute_hashable_part(
189                flags.header_type,
190                raw,
191            ));
192
193            Ok(RawPacket {
194                flags,
195                hops,
196                transport_id: None,
197                destination_hash,
198                context,
199                data,
200                raw: raw.to_vec(),
201                packet_hash,
202            })
203        } else {
204            Err(PacketError::InvalidHeaderType)
205        }
206    }
207
208    /// Get the hashable part of the packet.
209    ///
210    /// From Python Packet.py:354-361:
211    /// - Take raw[0] & 0x0F (mask out upper 4 bits of flags)
212    /// - For HEADER_1: append raw[2:]
213    /// - For HEADER_2: skip transport_id: append raw[18:]
214    pub fn get_hashable_part(&self) -> Vec<u8> {
215        Self::compute_hashable_part(self.flags.header_type, &self.raw)
216    }
217
218    fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
219        let mut hashable = Vec::new();
220        hashable.push(raw[0] & 0b00001111);
221        if header_type == constants::HEADER_2 {
222            // Skip transport_id: raw[2..18] is transport_id (16 bytes)
223            hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
224        } else {
225            hashable.extend_from_slice(&raw[2..]);
226        }
227        hashable
228    }
229
230    /// Full SHA-256 hash of the hashable part.
231    pub fn get_hash(&self) -> [u8; 32] {
232        self.packet_hash
233    }
234
235    /// Truncated hash (first 16 bytes) of the hashable part.
236    pub fn get_truncated_hash(&self) -> [u8; 16] {
237        let mut result = [0u8; 16];
238        result.copy_from_slice(&self.packet_hash[..16]);
239        result
240    }
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// HEADER_1 / broadcast / single-destination / data flags.
    fn header1_data_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_1,
            context_flag,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        }
    }

    /// HEADER_2 / transport / single-destination / announce flags.
    fn header2_announce_flags(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_2,
            context_flag,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        }
    }

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        assert_eq!(header1_data_flags(constants::FLAG_UNSET).pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        // 0b01010001 = 0x51
        assert_eq!(header2_announce_flags(constants::FLAG_UNSET).pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            assert_eq!(PacketFlags::unpack(byte).pack(), byte);
        }
    }

    #[test]
    fn test_flags_unpack_ignores_bit7() {
        assert_eq!(PacketFlags::unpack(0x80 | 0x51), PacketFlags::unpack(0x51));
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            0,
            &dest_hash,
            None,
            constants::CONTEXT_NONE,
            b"hello",
        )
        .unwrap();

        // Layout: [flags:1][hops:1][dest:16][context:1][data:5] = 24 bytes
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00); // flags
        assert_eq!(pkt.raw[1], 0x00); // hops
        assert_eq!(&pkt.raw[2..18], &dest_hash); // destination hash
        assert_eq!(pkt.raw[18], constants::CONTEXT_NONE); // context
        assert_eq!(&pkt.raw[19..], b"hello"); // data
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let flags = header2_announce_flags(constants::FLAG_UNSET);
        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"world",
        )
        .unwrap();

        // Layout: [flags:1][hops:1][transport:16][dest:16][context:1][data:5] = 40 bytes
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], constants::CONTEXT_NONE);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = header1_data_flags(constants::FLAG_UNSET);

        let pkt =
            RawPacket::pack(flags, 5, &dest_hash, None, constants::CONTEXT_RESOURCE, data).unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = header2_announce_flags(constants::FLAG_SET);

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id, Some(transport_id));
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_header1_empty_data() {
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            1,
            &[0x44; 16],
            None,
            constants::CONTEXT_NONE,
            &[],
        )
        .unwrap();
        // Header only: [flags:1][hops:1][dest:16][context:1]
        assert_eq!(pkt.raw.len(), 19);

        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();
        assert!(unpacked.data.is_empty());
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_header2_too_short() {
        // One byte short of a HEADER_2 header (2 + 16 + 16 + context).
        let mut raw = [0u8; 34];
        raw[0] = header2_announce_flags(constants::FLAG_UNSET).pack();
        assert!(matches!(RawPacket::unpack(&raw), Err(PacketError::TooShort)));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        // The payload alone fills the MTU, so the header pushes it over.
        let data = [0u8; constants::MTU];
        let result = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            0,
            &[0; 16],
            None,
            0,
            &data,
        );
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let result = RawPacket::pack(
            header2_announce_flags(constants::FLAG_UNSET),
            0,
            &[0; 16],
            None,
            0,
            b"data",
        );
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_SET),
            0,
            &[0xCC; 16],
            None,
            constants::CONTEXT_NONE,
            b"test",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte keeps only the lower 4 flag bits.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // The rest is raw[2..] (hops skipped).
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let pkt = RawPacket::pack(
            header2_announce_flags(constants::FLAG_UNSET),
            0,
            &[0xDD; 16],
            Some(&[0xEE; 16]),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte keeps only the lower 4 flag bits.
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Hops and transport id skipped: raw[18..] = dest_hash + context + data.
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_hash_ignores_hops_and_transport_id() {
        let dest_hash = [0x77; 16];
        let flags = header2_announce_flags(constants::FLAG_UNSET);
        let a = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&[0x01; 16]),
            constants::CONTEXT_NONE,
            b"same",
        )
        .unwrap();
        let b = RawPacket::pack(
            flags,
            9,
            &dest_hash,
            Some(&[0x02; 16]),
            constants::CONTEXT_NONE,
            b"same",
        )
        .unwrap();

        assert_ne!(a.raw, b.raw);
        assert_eq!(a.get_hash(), b.get_hash());
    }

    #[test]
    fn test_truncated_hash_is_prefix_of_full_hash() {
        let pkt = RawPacket::pack(
            header1_data_flags(constants::FLAG_UNSET),
            0,
            &[0x55; 16],
            None,
            constants::CONTEXT_NONE,
            b"x",
        )
        .unwrap();
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.get_hash()[..16]);
    }
}
445        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
446        // Should skip transport_id: raw[18:] = dest_hash + context + data
447        assert_eq!(&hashable[1..], &pkt.raw[18..]);
448    }
449}