Skip to main content

rns_core/
packet.rs

1use alloc::vec::Vec;
2use core::fmt;
3
4use crate::constants;
5use crate::hash;
6
/// Errors produced while packing or unpacking a [`RawPacket`].
///
/// The type is `Copy` and comparable so callers (and tests) can match on the
/// exact failure instead of only checking `is_err()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PacketError {
    /// The input is shorter than the minimum length for its header layout.
    TooShort,
    /// The encoded packet would be longer than the allowed MTU.
    ExceedsMtu,
    /// A `HEADER_2` packet was requested without a transport id.
    MissingTransportId,
    /// The header type is neither `HEADER_1` nor `HEADER_2`.
    InvalidHeaderType,
}

impl fmt::Display for PacketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            PacketError::TooShort => "Packet too short",
            PacketError::ExceedsMtu => "Packet exceeds MTU",
            PacketError::MissingTransportId => "HEADER_2 requires transport_id",
            PacketError::InvalidHeaderType => "Invalid header type",
        };
        f.write_str(message)
    }
}
25
26// =============================================================================
27// PacketFlags: packs 5 fields into one byte
28// =============================================================================
29
/// The five header fields that share the first byte of every packet.
///
/// `header_type`, `context_flag` and `transport_type` are 1 bit wide;
/// `destination_type` and `packet_type` are 2 bits wide. See
/// [`PacketFlags::pack`] for the exact bit layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PacketFlags {
    /// Header layout selector (1 bit, stored in bit 6).
    pub header_type: u8,
    /// Context flag (1 bit, stored in bit 5).
    pub context_flag: u8,
    /// Transport type (1 bit, stored in bit 4).
    pub transport_type: u8,
    /// Destination type (2 bits, stored in bits 3-2).
    pub destination_type: u8,
    /// Packet type (2 bits, stored in bits 1-0).
    pub packet_type: u8,
}

impl PacketFlags {
    /// Pack fields into a single flags byte.
    ///
    /// Bit layout:
    /// ```text
    /// Bit 6:     header_type (1 bit)
    /// Bit 5:     context_flag (1 bit)
    /// Bit 4:     transport_type (1 bit)
    /// Bits 3-2:  destination_type (2 bits)
    /// Bits 1-0:  packet_type (2 bits)
    /// ```
    ///
    /// Each field is masked to its width before shifting, so an out-of-range
    /// value can never spill into a neighbouring field or into bit 7, which
    /// this function never sets.
    pub fn pack(&self) -> u8 {
        ((self.header_type & 0b1) << 6)
            | ((self.context_flag & 0b1) << 5)
            | ((self.transport_type & 0b1) << 4)
            | ((self.destination_type & 0b11) << 2)
            | (self.packet_type & 0b11)
    }

    /// Unpack a flags byte into fields.
    ///
    /// Bit 7 is not represented in [`PacketFlags`] and is ignored.
    pub fn unpack(byte: u8) -> Self {
        PacketFlags {
            header_type: (byte & 0b01000000) >> 6,
            context_flag: (byte & 0b00100000) >> 5,
            transport_type: (byte & 0b00010000) >> 4,
            destination_type: (byte & 0b00001100) >> 2,
            packet_type: byte & 0b00000011,
        }
    }
}
69
70// =============================================================================
71// RawPacket: wire-level packet representation
72// =============================================================================
73
74#[derive(Debug, Clone)]
75pub struct RawPacket {
76    pub flags: PacketFlags,
77    pub hops: u8,
78    pub transport_id: Option<[u8; 16]>,
79    pub destination_hash: [u8; 16],
80    pub context: u8,
81    pub data: Vec<u8>,
82    pub raw: Vec<u8>,
83    pub packet_hash: [u8; 32],
84}
85
86impl RawPacket {
87    /// Pack fields into raw bytes.
88    pub fn pack(
89        flags: PacketFlags,
90        hops: u8,
91        destination_hash: &[u8; 16],
92        transport_id: Option<&[u8; 16]>,
93        context: u8,
94        data: &[u8],
95    ) -> Result<Self, PacketError> {
96        Self::pack_with_max_mtu(
97            flags,
98            hops,
99            destination_hash,
100            transport_id,
101            context,
102            data,
103            constants::MTU,
104        )
105    }
106
107    /// Pack fields into raw bytes and packet hash without constructing a full RawPacket.
108    pub fn pack_raw_with_hash(
109        flags: PacketFlags,
110        hops: u8,
111        destination_hash: &[u8; 16],
112        transport_id: Option<&[u8; 16]>,
113        context: u8,
114        data: &[u8],
115    ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
116        Self::pack_raw_with_hash_with_max_mtu(
117            flags,
118            hops,
119            destination_hash,
120            transport_id,
121            context,
122            data,
123            constants::MTU,
124        )
125    }
126
127    /// Pack fields into raw bytes with a caller-provided MTU limit.
128    pub fn pack_with_max_mtu(
129        flags: PacketFlags,
130        hops: u8,
131        destination_hash: &[u8; 16],
132        transport_id: Option<&[u8; 16]>,
133        context: u8,
134        data: &[u8],
135        max_mtu: usize,
136    ) -> Result<Self, PacketError> {
137        let (raw, packet_hash) = Self::pack_raw_with_hash_with_max_mtu(
138            flags,
139            hops,
140            destination_hash,
141            transport_id,
142            context,
143            data,
144            max_mtu,
145        )?;
146
147        Ok(RawPacket {
148            flags,
149            hops,
150            transport_id: transport_id.copied(),
151            destination_hash: *destination_hash,
152            context,
153            data: data.to_vec(),
154            raw,
155            packet_hash,
156        })
157    }
158
159    /// Pack fields into raw bytes and packet hash with a caller-provided MTU limit.
160    pub fn pack_raw_with_hash_with_max_mtu(
161        flags: PacketFlags,
162        hops: u8,
163        destination_hash: &[u8; 16],
164        transport_id: Option<&[u8; 16]>,
165        context: u8,
166        data: &[u8],
167        max_mtu: usize,
168    ) -> Result<(Vec<u8>, [u8; 32]), PacketError> {
169        if flags.header_type == constants::HEADER_2 && transport_id.is_none() {
170            return Err(PacketError::MissingTransportId);
171        }
172
173        let mut raw = Vec::new();
174        raw.push(flags.pack());
175        raw.push(hops);
176
177        if let Some(transport_id) = transport_id {
178            if flags.header_type == constants::HEADER_2 {
179                raw.extend_from_slice(transport_id);
180            }
181        }
182
183        raw.extend_from_slice(destination_hash);
184        raw.push(context);
185        raw.extend_from_slice(data);
186
187        if raw.len() > max_mtu {
188            return Err(PacketError::ExceedsMtu);
189        }
190
191        let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, &raw));
192        Ok((raw, packet_hash))
193    }
194
195    /// Unpack raw bytes into fields.
196    pub fn unpack(raw: &[u8]) -> Result<Self, PacketError> {
197        if raw.len() < constants::HEADER_MINSIZE {
198            return Err(PacketError::TooShort);
199        }
200
201        let flags = PacketFlags::unpack(raw[0]);
202        let hops = raw[1];
203
204        let dst_len = constants::TRUNCATED_HASHLENGTH / 8; // 16
205
206        if flags.header_type == constants::HEADER_2 {
207            // HEADER_2: [flags:1][hops:1][transport_id:16][dest_hash:16][context:1][data:*]
208            let min_len = 2 + dst_len * 2 + 1;
209            if raw.len() < min_len {
210                return Err(PacketError::TooShort);
211            }
212
213            let mut transport_id = [0u8; 16];
214            transport_id.copy_from_slice(&raw[2..2 + dst_len]);
215
216            let mut destination_hash = [0u8; 16];
217            destination_hash.copy_from_slice(&raw[2 + dst_len..2 + 2 * dst_len]);
218
219            let context = raw[2 + 2 * dst_len];
220            let data = raw[2 + 2 * dst_len + 1..].to_vec();
221
222            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
223
224            Ok(RawPacket {
225                flags,
226                hops,
227                transport_id: Some(transport_id),
228                destination_hash,
229                context,
230                data,
231                raw: raw.to_vec(),
232                packet_hash,
233            })
234        } else if flags.header_type == constants::HEADER_1 {
235            // HEADER_1: [flags:1][hops:1][dest_hash:16][context:1][data:*]
236            let min_len = 2 + dst_len + 1;
237            if raw.len() < min_len {
238                return Err(PacketError::TooShort);
239            }
240
241            let mut destination_hash = [0u8; 16];
242            destination_hash.copy_from_slice(&raw[2..2 + dst_len]);
243
244            let context = raw[2 + dst_len];
245            let data = raw[2 + dst_len + 1..].to_vec();
246
247            let packet_hash = hash::full_hash(&Self::compute_hashable_part(flags.header_type, raw));
248
249            Ok(RawPacket {
250                flags,
251                hops,
252                transport_id: None,
253                destination_hash,
254                context,
255                data,
256                raw: raw.to_vec(),
257                packet_hash,
258            })
259        } else {
260            Err(PacketError::InvalidHeaderType)
261        }
262    }
263
264    /// Get the hashable part of the packet.
265    ///
266    /// From Python Packet.py:354-361:
267    /// - Take raw[0] & 0x0F (mask out upper 4 bits of flags)
268    /// - For HEADER_1: append raw[2:]
269    /// - For HEADER_2: skip transport_id: append raw[18:]
270    pub fn get_hashable_part(&self) -> Vec<u8> {
271        Self::compute_hashable_part(self.flags.header_type, &self.raw)
272    }
273
274    fn compute_hashable_part(header_type: u8, raw: &[u8]) -> Vec<u8> {
275        let mut hashable = Vec::new();
276        hashable.push(raw[0] & 0b00001111);
277        if header_type == constants::HEADER_2 {
278            // Skip transport_id: raw[2..18] is transport_id (16 bytes)
279            hashable.extend_from_slice(&raw[(constants::TRUNCATED_HASHLENGTH / 8 + 2)..]);
280        } else {
281            hashable.extend_from_slice(&raw[2..]);
282        }
283        hashable
284    }
285
286    /// Full SHA-256 hash of the hashable part.
287    pub fn get_hash(&self) -> [u8; 32] {
288        self.packet_hash
289    }
290
291    /// Truncated hash (first 16 bytes) of the hashable part.
292    pub fn get_truncated_hash(&self) -> [u8; 16] {
293        let mut result = [0u8; 16];
294        result.copy_from_slice(&self.packet_hash[..16]);
295        result
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags_pack_header1_data_single_broadcast() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        assert_eq!(flags.pack(), 0x00);
    }

    #[test]
    fn test_flags_pack_header2_announce_single_transport() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        // 0b01010001 = 0x51
        assert_eq!(flags.pack(), 0x51);
    }

    #[test]
    fn test_flags_roundtrip() {
        for byte in 0..=0x7Fu8 {
            let flags = PacketFlags::unpack(byte);
            assert_eq!(flags.pack(), byte);
        }
    }

    #[test]
    fn test_pack_header1() {
        let dest_hash = [0xAA; 16];
        let data = b"hello";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, data).unwrap();

        // Verify layout: [flags:1][hops:1][dest:16][context:1][data:5] = 24 bytes
        assert_eq!(pkt.raw.len(), 24);
        assert_eq!(pkt.raw[0], 0x00); // flags
        assert_eq!(pkt.raw[1], 0x00); // hops
        assert_eq!(&pkt.raw[2..18], &dest_hash); // dest hash
        assert_eq!(pkt.raw[18], 0x00); // context
        assert_eq!(&pkt.raw[19..], b"hello"); // data
    }

    #[test]
    fn test_pack_header2() {
        let dest_hash = [0xAA; 16];
        let transport_id = [0xBB; 16];
        let data = b"world";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            3,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();

        // Layout: [flags:1][hops:1][transport:16][dest:16][context:1][data:5] = 40 bytes
        assert_eq!(pkt.raw.len(), 40);
        assert_eq!(pkt.raw[0], flags.pack());
        assert_eq!(pkt.raw[1], 3);
        assert_eq!(&pkt.raw[2..18], &transport_id);
        assert_eq!(&pkt.raw[18..34], &dest_hash);
        assert_eq!(pkt.raw[34], 0x00);
        assert_eq!(&pkt.raw[35..], b"world");
    }

    #[test]
    fn test_unpack_roundtrip_header1() {
        let dest_hash = [0x11; 16];
        let data = b"test data";
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt = RawPacket::pack(
            flags,
            5,
            &dest_hash,
            None,
            constants::CONTEXT_RESOURCE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 5);
        assert!(unpacked.transport_id.is_none());
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_RESOURCE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_roundtrip_header2() {
        let dest_hash = [0x22; 16];
        let transport_id = [0x33; 16];
        let data = b"transported";
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            2,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            data,
        )
        .unwrap();
        let unpacked = RawPacket::unpack(&pkt.raw).unwrap();

        assert_eq!(unpacked.flags, flags);
        assert_eq!(unpacked.hops, 2);
        assert_eq!(unpacked.transport_id.unwrap(), transport_id);
        assert_eq!(unpacked.destination_hash, dest_hash);
        assert_eq!(unpacked.context, constants::CONTEXT_NONE);
        assert_eq!(unpacked.data, data);
        assert_eq!(unpacked.packet_hash, pkt.packet_hash);
    }

    #[test]
    fn test_unpack_too_short() {
        assert!(RawPacket::unpack(&[0x00; 5]).is_err());
    }

    #[test]
    fn test_pack_exceeds_mtu() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let data = [0u8; 500]; // way too much data
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, &data);
        assert!(result.is_err());
    }

    #[test]
    fn test_header2_missing_transport_id() {
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let result = RawPacket::pack(flags, 0, &[0; 16], None, 0, b"data");
        assert!(result.is_err());
    }

    #[test]
    fn test_hashable_part_header1_masks_upper_flags() {
        let dest_hash = [0xCC; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_SET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let pkt =
            RawPacket::pack(flags, 0, &dest_hash, None, constants::CONTEXT_NONE, b"test").unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte should have upper 4 bits masked out
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Rest should be raw[2:]
        assert_eq!(&hashable[1..], &pkt.raw[2..]);
    }

    #[test]
    fn test_hashable_part_header2_strips_transport_id() {
        let dest_hash = [0xDD; 16];
        let transport_id = [0xEE; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            0,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"data",
        )
        .unwrap();
        let hashable = pkt.get_hashable_part();

        // First byte: flags masked
        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
        // Should skip transport_id: raw[18:] = dest_hash + context + data
        assert_eq!(&hashable[1..], &pkt.raw[18..]);
    }

    #[test]
    fn test_unpack_header2_too_short() {
        // A HEADER_2 flags byte needs transport_id + dest_hash + context after
        // flags/hops (35 bytes total); 20 bytes is not enough.
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };
        let mut raw = [0u8; 20];
        raw[0] = flags.pack();
        assert!(RawPacket::unpack(&raw).is_err());
    }

    #[test]
    fn test_pack_with_max_mtu_boundary() {
        let flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        // HEADER_1 overhead is flags(1) + hops(1) + dest(16) + context(1) = 19 bytes.
        let fits = RawPacket::pack_with_max_mtu(
            flags,
            0,
            &[0; 16],
            None,
            constants::CONTEXT_NONE,
            &[0u8; 4],
            23,
        );
        assert_eq!(fits.unwrap().raw.len(), 23);

        let too_big = RawPacket::pack_with_max_mtu(
            flags,
            0,
            &[0; 16],
            None,
            constants::CONTEXT_NONE,
            &[0u8; 5],
            23,
        );
        assert!(too_big.is_err());
    }

    #[test]
    fn test_pack_raw_with_hash_matches_pack() {
        let dest_hash = [0x5A; 16];
        let transport_id = [0xA5; 16];
        let flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_ANNOUNCE,
        };

        let pkt = RawPacket::pack(
            flags,
            1,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"payload",
        )
        .unwrap();
        let (raw, packet_hash) = RawPacket::pack_raw_with_hash(
            flags,
            1,
            &dest_hash,
            Some(&transport_id),
            constants::CONTEXT_NONE,
            b"payload",
        )
        .unwrap();

        assert_eq!(raw, pkt.raw);
        assert_eq!(packet_hash, pkt.packet_hash);
        assert_eq!(pkt.get_hash(), pkt.packet_hash);
        assert_eq!(&pkt.get_truncated_hash()[..], &pkt.packet_hash[..16]);
    }

    #[test]
    fn test_hash_independent_of_hops_and_transport() {
        let dest_hash = [0x42; 16];
        let direct_flags = PacketFlags {
            header_type: constants::HEADER_1,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_BROADCAST,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };
        let transported_flags = PacketFlags {
            header_type: constants::HEADER_2,
            context_flag: constants::FLAG_UNSET,
            transport_type: constants::TRANSPORT_TRANSPORT,
            destination_type: constants::DESTINATION_SINGLE,
            packet_type: constants::PACKET_TYPE_DATA,
        };

        let direct = RawPacket::pack(
            direct_flags,
            0,
            &dest_hash,
            None,
            constants::CONTEXT_NONE,
            b"same",
        )
        .unwrap();
        let transported = RawPacket::pack(
            transported_flags,
            4,
            &dest_hash,
            Some(&[0x99; 16]),
            constants::CONTEXT_NONE,
            b"same",
        )
        .unwrap();

        // The lower 4 flag bits match, and hops / transport_id are excluded from
        // the hashable part, so both encodings hash identically.
        assert_eq!(direct.get_hashable_part(), transported.get_hashable_part());
        assert_eq!(direct.packet_hash, transported.packet_hash);
    }
}
535        assert_eq!(hashable[0], pkt.raw[0] & 0x0F);
536        // Should skip transport_id: raw[18:] = dest_hash + context + data
537        assert_eq!(&hashable[1..], &pkt.raw[18..]);
538    }
539}