Skip to main content

rns_core/
packet.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while packing or unpacking a [`RawPacket`].
///
/// The enum carries no payload, so it is cheap to copy and can be compared
/// directly (for example `assert_eq!(err, PacketError::TooShort)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input buffer is shorter than the header it claims to contain.
    TooShort,
    /// The packed packet would be larger than the allowed MTU.
    ExceedsMtu,
    /// A HEADER_2 packet was requested without a transport id.
    MissingTransportId,
    /// The header type is not one of the known header layouts.
    InvalidHeaderType,
}
14
15impl fmt::Display for PacketError {
16    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17        match self {
18            PacketError::TooShort => write!(f, "Packet too short"),
19            PacketError::ExceedsMtu => write!(f, "Packet exceeds MTU"),
20            PacketError::MissingTransportId => write!(f, "HEADER_2 requires transport_id"),
21            PacketError::InvalidHeaderType => write!(f, "Invalid header type"),
22        }
23    }
24}
25
26// =============================================================================
27// PacketFlags: packs 5 fields into one byte
28// =============================================================================
29
/// The five header fields that share the first byte of a packet.
///
/// Every field is stored in its own `u8`, but only the low bits that fit the
/// field's slot in the flags byte are significant: one bit each for
/// `header_type`, `context_flag` and `transport_type`, two bits each for
/// `destination_type` and `packet_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header layout selector (bit 6 of the flags byte).
    pub header_type: u8,
    /// Context flag (bit 5).
    pub context_flag: u8,
    /// Transport type (bit 4).
    pub transport_type: u8,
    /// Destination type (bits 3-2).
    pub destination_type: u8,
    /// Packet type (bits 1-0).
    pub packet_type: u8,
}

impl PacketFlags {
    // Position of each field inside the flags byte.
    const HEADER_TYPE_SHIFT: u8 = 6;
    const CONTEXT_FLAG_SHIFT: u8 = 5;
    const TRANSPORT_TYPE_SHIFT: u8 = 4;
    const DESTINATION_TYPE_SHIFT: u8 = 2;

    // Width masks, applied to a field value before it is shifted into place.
    const ONE_BIT: u8 = 0b1;
    const TWO_BITS: u8 = 0b11;

    /// Pack fields into a single flags byte.
    ///
    /// Bit layout:
    /// ```text
    /// Bit 6:     header_type (1 bit)
    /// Bit 5:     context_flag (1 bit)
    /// Bit 4:     transport_type (1 bit)
    /// Bits 3-2:  destination_type (2 bits)
    /// Bits 1-0:  packet_type (2 bits)
    /// ```
    ///
    /// Each field is masked to its width before being shifted, so a value
    /// that does not fit its slot cannot spill into a neighbouring field or
    /// into bit 7, which this method never sets.
    pub fn pack(&self) -> u8 {
        ((self.header_type & Self::ONE_BIT) << Self::HEADER_TYPE_SHIFT)
            | ((self.context_flag & Self::ONE_BIT) << Self::CONTEXT_FLAG_SHIFT)
            | ((self.transport_type & Self::ONE_BIT) << Self::TRANSPORT_TYPE_SHIFT)
            | ((self.destination_type & Self::TWO_BITS) << Self::DESTINATION_TYPE_SHIFT)
            | (self.packet_type & Self::TWO_BITS)
    }

    /// Unpack a flags byte into fields.
    ///
    /// Bit 7 of `byte` does not belong to any field and is ignored.
    pub fn unpack(byte: u8) -> Self {
        PacketFlags {
            header_type: (byte >> Self::HEADER_TYPE_SHIFT) & Self::ONE_BIT,
            context_flag: (byte >> Self::CONTEXT_FLAG_SHIFT) & Self::ONE_BIT,
            transport_type: (byte >> Self::TRANSPORT_TYPE_SHIFT) & Self::ONE_BIT,
            destination_type: (byte >> Self::DESTINATION_TYPE_SHIFT) & Self::TWO_BITS,
            packet_type: byte & Self::TWO_BITS,
        }
    }
}
69
70// =============================================================================
71// RawPacket: wire-level packet representation
72// =============================================================================
73
74#[derive(Debug, Clone)]
75pub struct RawPacket {
76    pub flags: PacketFlags,
77    pub hops: u8,
78    pub transport_id: Option<[u8; 16]>,
79    pub destination_hash: [u8; 16],
80    pub context: u8,
81    pub data: Vec<u8>,
82    pub raw: Vec<u8>,
83    pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87    /// Pack fields into raw bytes.
88    pub fn pack(
89        flags: PacketFlags,
90        hops: u8,
91        destination_hash: &[u8; 16],
92        transport_id: Option<&[u8; 16]>,
93        context: u8,
94        data: &[u8],
95    ) -> Result<Self, PacketError> {
96        Self::pack_with_max_mtu(
97            flags,
98            hops,
99            destination_hash,
100            transport_id,
101            context,
102            data,
103            constants::MTU,
104        )
105    }
106
107    /// Pack fields into raw bytes with a caller-provided MTU limit.
108    pub fn pack_with_max_mtu(
109        flags: PacketFlags,
110        hops: u8,
111        destination_hash: &[u8; 16],
112        transport_id: Option<&[u8; 16]>,
113        context: u8,
114        data: &[u8],
115        max_mtu: usize,
116    ) -> Result<Self, PacketError> {
117        if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
118            return Err(PacketError::MissingTransportId);
119        }
120
121        let mut raw = Vec::new();
122        raw.push(flags.pack());
123        raw.push(hops);
124
125        if flags.header_type == constants::HEADER_2 {
126            raw.extend_from_slice(transport_id.unwrap());
127        }
128
129        raw.extend_from_slice(destination_hash);
130        raw.push(context);
131        raw.extend_from_slice(data);
132
133        if raw.len() > max_mtu {
134            return Err(PacketError::ExceedsMtu);
135        }
136
137        let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
138
139        Ok(RawPacket {
140            flags,
141            hops,
142            transport_id: transport_id.copied(),
143            destination_hash: *destination_hash,
144            context,
145            data: data.to_vec(),
146            raw,
147            packet_hash,
148        })
149    }
150
151    /// Unpack raw bytes into fields.
152    pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
153        if raw.len() < constants::HEADER_MINSIZE {
154            return Err(PacketError::TooShort);
155        }
156
157        let flags = PacketFlags::unpack(raw[0]);
158        let hops = raw[1];
159
160        let dst_len = constants::TRUNCATED_HASHLENGTH / 8; // 16
161
162        if flags.header_type == constants::HEADER_2 {
163            // HEADER_2: [flags:1][hops:1][transport_id:16][dest_hash:16][context:1][data:*]
164            let min_len = 2 + dst_len * 2 + 1;
165            if raw.len() < min_len {
166                return Err(PacketError::TooShort);
167            }
168
169            let mut transport_id = [0u8; 16];
170            transport_id.copy_from_slice(&raw[2..2 + dst_len]);
171
172            let mut destination_hash = [0u8; 16];
173            destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
174
175            let context = raw[2 + 2 * dst_len];
176            let data = raw[2 + 2 * dst_len + 1..].to_vec();
177
178            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
179
180            Ok(RawPacket {
181                flags,
182                hops,
183                transport_id: Some(transport_id),
184                destination_hash,
185                context,
186                data,
187                raw: raw.to_vec(),
188                packet_hash,
189            })
190        } else if flags.header_type == constants::HEADER_1 {
191            // HEADER_1: [flags:1][hops:1][dest_hash:16][context:1][data:*]
192            let min_len = 2 + dst_len + 1;
193            if raw.len() < min_len {
194                return Err(PacketError::TooShort);
195            }
196
197            let mut destination_hash = [0u8; 16];
198            destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
199
200            let context = raw[2 + dst_len];
201            let data = raw[2 + dst_len + 1..].to_vec();
202
203            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
204
205            Ok(RawPacket {
206                flags,
207                hops,
208                transport_id: None,
209                destination_hash,
210                context,
211                data,
212                raw: raw.to_vec(),
213                packet_hash,
214            })
215        } else {
216            Err(PacketError::InvalidHeaderType)
217        }
218    }
219
220    /// Get the hashable part of the packet.
221    ///
222    /// From Python Packet.py:354-361:
223    /// - Take raw[0] & 0x0F (mask out upper 4 bits of flags)
224    /// - For HEADER_1: append raw[2:]
225    /// - For HEADER_2: skip transport_id: append raw[18:]
226    pub fn get_hashable_part(&self) -> Vec<u8> {
227        Self::compute_hashable_part(self.flags.header_type, &self.raw)
228    }
229
230    fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
231        let mut hashable = Vec::new();
232        hashable.push(raw[0] & 0b00001111);
233        if header_type == constants::HEADER_2 {
234            // Skip transport_id: raw[2..18] is transport_id (16 bytes)
235            hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
236        } else {
237            hashable.extend_from_slice(&raw[2..]);
238        }
239        hashable
240    }
241
242    /// Full SHA-256 hash of the hashable part.
243    pub fn get_hash(&self) -> [u8; 32] {
244        self.packet_hash
245    }
246
247    /// Truncated hash (first 16 bytes) of the hashable part.
248    pub fn get_truncated_hash(&self) -> [u8; 16] {
249        let mut result = [0u8; 16];
250        result.copy_from_slice(&self.packet_hash[..16]);
251        result
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Flags for a HEADER_1 / broadcast / single-destination data packet.
    fn header1_data(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_1,
            context_flag,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        }
    }

    /// Flags for a HEADER_2 / transport / single-destination announce packet.
    fn header2_announce(context_flag: u8) -> PacketFlags {
        PacketFlags {
            header_type: constants::HEADER_2,
            context_flag,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        }
    }

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        assert_eq!(header1_data(constants::FLAG_UNSET).pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        // 0b01010001 = 0x51
        assert_eq!(header2_announce(constants::FLAG_UNSET).pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_flags_unpack_ignores_bit7() {
        for byte in 0..=0x7Fu8 {
            assert_eq!(PacketFlags::unpack(byte | 0x80), PacketFlags::unpack(byte));
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";
        let flags = header1_data(constants::FLAG_UNSET);

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data).unwrap();

        // Verify layout: [flags:1][hops:1][dest:16][context:1][data:5] = 24 bytes
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00); // flags
        assert_eq!(pkt.raw[1], 0x00); // hops
        assert_eq!(&pkt.raw[2..18], &dest_hash); // dest hash
        assert_eq!(pkt.raw[18], 0x00); // context
        assert_eq!(&pkt.raw[19..], b"hello"); // data
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = header2_announce(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // Layout: [flags:1][hops:1][transport:16][dest:16][context:1][data:5] = 40 bytes
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = header1_data(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = header2_announce(constants::FLAG_SET);

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        // Pin the variant: a bare `is_err()` would also pass on the wrong error.
        assert!(matches!(
            RawPacket::unpack(&[0x00; 5]),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_unpack_header2_too_short() {
        // 20 bytes would hold a HEADER_1 header, but the flags byte selects
        // HEADER_2, which needs 35 bytes up to and including the context byte.
        let mut raw = [0u8; 20];
        raw[0] = header2_announce(constants::FLAG_UNSET).pack();
        assert!(matches!(
            RawPacket::unpack(&raw),
            Err(PacketError::TooShort)
        ));
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = header1_data(constants::FLAG_UNSET);
        let data = [0u8; 500]; // way too much data
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);
        assert!(matches!(result, Err(PacketError::ExceedsMtu)));
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = header2_announce(constants::FLAG_UNSET);
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");
        assert!(matches!(result, Err(PacketError::MissingTransportId)));
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = header1_data(constants::FLAG_SET);

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test").unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte should have upper 4 bits masked out
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Rest should be raw[2:]
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = header2_announce(constants::FLAG_UNSET);

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte: flags masked
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Should skip transport_id: raw[18:] = dest_hash + context + data
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_truncated_hash_is_prefix_of_full_hash() {
        let flags = header1_data(constants::FLAG_UNSET);
        let pkt =
            RawPacket::pack(flags, 0, &[0x01; 16], None, constants::CONTEXT_NONE, b"x").unwrap();

        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.get_hash()[..16]);
    }
}