Skip to main content

nnrp_core/
header.rs

1use crate::{HeaderFlags, MessageType, NnrpError, CURRENT_WIRE_FORMAT};
2
/// Size in bytes of the fixed common header that begins every packet.
pub const COMMON_HEADER_LEN: usize = 40;
/// Protocol major version stamped by this crate; parsing rejects any other value.
pub const CURRENT_VERSION_MAJOR: u8 = 1;
/// ALPN token for this protocol (not referenced elsewhere in this file).
pub const ALPN: &str = "nnrp/1";
/// Four-byte signature (ASCII `NNRP`) stored at offset 0 of every header.
const MAGIC: [u8; 4] = [b'N', b'N', b'R', b'P'];
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq)]
9pub struct CommonHeader {
10    pub version_major: u8,
11    pub wire_format: u8,
12    pub message_type: MessageType,
13    pub header_len: u8,
14    pub flags: HeaderFlags,
15    pub meta_len: u32,
16    pub body_len: u32,
17    pub session_id: u32,
18    pub frame_id: u32,
19    pub view_id: u16,
20    pub route_id: u16,
21    pub trace_id: u64,
22}
23
24impl CommonHeader {
25    pub fn new(message_type: MessageType, meta_len: u32, body_len: u32) -> Self {
26        Self {
27            version_major: CURRENT_VERSION_MAJOR,
28            wire_format: CURRENT_WIRE_FORMAT,
29            message_type,
30            header_len: COMMON_HEADER_LEN as u8,
31            flags: HeaderFlags::NONE,
32            meta_len,
33            body_len,
34            session_id: 0,
35            frame_id: 0,
36            view_id: 0,
37            route_id: 0,
38            trace_id: 0,
39        }
40    }
41
42    pub fn packet_len(&self) -> Result<usize, NnrpError> {
43        let payload_len = self
44            .meta_len
45            .checked_add(self.body_len)
46            .ok_or(NnrpError::MessageLengthOverflow)? as usize;
47        COMMON_HEADER_LEN
48            .checked_add(payload_len)
49            .ok_or(NnrpError::MessageLengthOverflow)
50    }
51
52    pub fn write(&self, destination: &mut [u8]) -> Result<(), NnrpError> {
53        if destination.len() < COMMON_HEADER_LEN {
54            return Err(NnrpError::DestinationTooShort {
55                expected: COMMON_HEADER_LEN,
56                actual: destination.len(),
57            });
58        }
59
60        if self.header_len != COMMON_HEADER_LEN as u8 {
61            return Err(NnrpError::InvalidHeaderLength(self.header_len));
62        }
63
64        self.flags.validate_known()?;
65
66        destination[0..4].copy_from_slice(&MAGIC);
67        destination[4] = self.version_major;
68        destination[5] = self.wire_format;
69        destination[6] = self.message_type as u8;
70        destination[7] = self.header_len;
71        destination[8..12].copy_from_slice(&self.flags.0.to_le_bytes());
72        destination[12..16].copy_from_slice(&self.meta_len.to_le_bytes());
73        destination[16..20].copy_from_slice(&self.body_len.to_le_bytes());
74        destination[20..24].copy_from_slice(&self.session_id.to_le_bytes());
75        destination[24..28].copy_from_slice(&self.frame_id.to_le_bytes());
76        destination[28..30].copy_from_slice(&self.view_id.to_le_bytes());
77        destination[30..32].copy_from_slice(&self.route_id.to_le_bytes());
78        destination[32..40].copy_from_slice(&self.trace_id.to_le_bytes());
79        Ok(())
80    }
81
82    pub fn to_bytes(&self) -> Result<[u8; COMMON_HEADER_LEN], NnrpError> {
83        let mut bytes = [0u8; COMMON_HEADER_LEN];
84        self.write(&mut bytes)?;
85        Ok(bytes)
86    }
87
88    pub fn parse(source: &[u8]) -> Result<Self, NnrpError> {
89        if source.len() < COMMON_HEADER_LEN {
90            return Err(NnrpError::SourceTooShort {
91                expected: COMMON_HEADER_LEN,
92                actual: source.len(),
93            });
94        }
95
96        if source[0..4] != MAGIC {
97            return Err(NnrpError::InvalidMagic);
98        }
99
100        let version_major = source[4];
101        if version_major != CURRENT_VERSION_MAJOR {
102            return Err(NnrpError::UnsupportedVersionMajor(version_major));
103        }
104
105        let wire_format = source[5];
106        if wire_format != CURRENT_WIRE_FORMAT {
107            return Err(NnrpError::UnsupportedWireFormat(wire_format));
108        }
109
110        let header_len = source[7];
111        if header_len != COMMON_HEADER_LEN as u8 {
112            return Err(NnrpError::InvalidHeaderLength(header_len));
113        }
114
115        let flags = HeaderFlags(u32::from_le_bytes(
116            source[8..12].try_into().expect("slice length"),
117        ));
118        flags.validate_known()?;
119
120        Ok(Self {
121            version_major,
122            wire_format,
123            message_type: MessageType::try_from_u8(source[6])?,
124            header_len,
125            flags,
126            meta_len: u32::from_le_bytes(source[12..16].try_into().expect("slice length")),
127            body_len: u32::from_le_bytes(source[16..20].try_into().expect("slice length")),
128            session_id: u32::from_le_bytes(source[20..24].try_into().expect("slice length")),
129            frame_id: u32::from_le_bytes(source[24..28].try_into().expect("slice length")),
130            view_id: u16::from_le_bytes(source[28..30].try_into().expect("slice length")),
131            route_id: u16::from_le_bytes(source[30..32].try_into().expect("slice length")),
132            trace_id: u64::from_le_bytes(source[32..40].try_into().expect("slice length")),
133        })
134    }
135
136    pub fn parse_packet(source: &[u8]) -> Result<(Self, &[u8], &[u8]), NnrpError> {
137        let header = Self::parse(source)?;
138        let declared = header.packet_len()?;
139        if declared != source.len() {
140            return Err(NnrpError::PacketLengthMismatch {
141                declared,
142                actual: source.len(),
143            });
144        }
145
146        let meta_start = COMMON_HEADER_LEN;
147        let meta_end = meta_start + header.meta_len as usize;
148        Ok((header, &source[meta_start..meta_end], &source[meta_end..]))
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::{CommonHeader, COMMON_HEADER_LEN, CURRENT_VERSION_MAJOR};
    use crate::{HeaderFlags, MessageType, NnrpError};

    /// Parses a reference FlowUpdate packet and checks every header field,
    /// then confirms re-encoding reproduces the original 40 header bytes.
    #[test]
    fn common_header_round_trips_flow_update_vector_header() {
        let packet = hex_to_bytes("4e4e5250010017280000000020000000000000002a000000000000000000090088776655443322110104020000000200000000000000000000000000780000000700000003000000");

        let header = CommonHeader::parse(&packet).expect("header should parse");

        assert_eq!(header.version_major, 1);
        assert_eq!(header.message_type, MessageType::FlowUpdate);
        assert_eq!(header.header_len, COMMON_HEADER_LEN as u8);
        assert_eq!(header.flags.0, 0);
        assert_eq!(header.meta_len, 32);
        assert_eq!(header.body_len, 0);
        assert_eq!(header.session_id, 42);
        assert_eq!(header.frame_id, 0);
        assert_eq!(header.view_id, 0);
        assert_eq!(header.route_id, 9);
        assert_eq!(header.trace_id, 0x1122_3344_5566_7788);
        assert_eq!(
            header.to_bytes().unwrap().as_slice(),
            &packet[..COMMON_HEADER_LEN]
        );
    }

    /// A packet shorter than the header-declared length is rejected.
    #[test]
    fn common_header_rejects_length_mismatch() {
        let mut packet = CommonHeader::new(MessageType::Ping, 4, 0)
            .to_bytes()
            .expect("header writes")
            .to_vec();
        packet.extend_from_slice(&[1, 2]);

        assert_eq!(
            CommonHeader::parse_packet(&packet),
            Err(NnrpError::PacketLengthMismatch {
                declared: 44,
                actual: 42
            })
        );
    }

    /// Writing refuses flag bits outside the known mask.
    #[test]
    fn common_header_rejects_reserved_flags() {
        let mut header = CommonHeader::new(MessageType::Ping, 0, 0);
        header.flags = HeaderFlags(0x40);

        assert_eq!(
            header.to_bytes(),
            Err(NnrpError::ReservedBitsSet {
                value: 0x40,
                allowed: HeaderFlags::KNOWN_MASK as u64
            })
        );
    }

    /// A well-formed packet splits into header, metadata and body, and the
    /// decoded header equals the one that was encoded.
    #[test]
    fn common_header_parse_packet_splits_meta_and_body() {
        let header = CommonHeader::new(MessageType::Ping, 3, 2);
        let mut packet = header.to_bytes().expect("header writes").to_vec();
        packet.extend_from_slice(&[1, 2, 3, 4, 5]);

        let (parsed, meta, body) =
            CommonHeader::parse_packet(&packet).expect("packet should parse");

        assert_eq!(parsed, header);
        assert_eq!(meta, &[1u8, 2, 3][..]);
        assert_eq!(body, &[4u8, 5][..]);
    }

    /// Input shorter than the fixed header is reported with both lengths.
    #[test]
    fn common_header_parse_rejects_short_source() {
        assert_eq!(
            CommonHeader::parse(&[0u8; 12]),
            Err(NnrpError::SourceTooShort {
                expected: COMMON_HEADER_LEN,
                actual: 12
            })
        );
    }

    /// Writing into a buffer that cannot hold the header fails cleanly.
    #[test]
    fn common_header_write_rejects_short_destination() {
        let mut buffer = [0u8; COMMON_HEADER_LEN - 1];

        assert_eq!(
            CommonHeader::new(MessageType::Ping, 0, 0).write(&mut buffer),
            Err(NnrpError::DestinationTooShort {
                expected: COMMON_HEADER_LEN,
                actual: COMMON_HEADER_LEN - 1
            })
        );
    }

    /// Corrupting the magic or bumping the major version is rejected on parse.
    #[test]
    fn common_header_rejects_bad_magic_and_version() {
        let valid = CommonHeader::new(MessageType::Ping, 0, 0)
            .to_bytes()
            .expect("header writes");

        let mut bad_magic = valid;
        bad_magic[0] = b'X';
        assert_eq!(CommonHeader::parse(&bad_magic), Err(NnrpError::InvalidMagic));

        let mut bad_version = valid;
        bad_version[4] = CURRENT_VERSION_MAJOR + 1;
        assert_eq!(
            CommonHeader::parse(&bad_version),
            Err(NnrpError::UnsupportedVersionMajor(CURRENT_VERSION_MAJOR + 1))
        );
    }

    /// `meta_len + body_len` overflowing `u32` is reported, not wrapped.
    #[test]
    fn common_header_packet_len_reports_overflow() {
        let header = CommonHeader::new(MessageType::Ping, u32::MAX, 1);

        assert_eq!(header.packet_len(), Err(NnrpError::MessageLengthOverflow));
    }

    /// Decodes an even-length ASCII hex string into bytes (test helper).
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        assert_eq!(hex.len() % 2, 0);
        (0..hex.len())
            .step_by(2)
            .map(|index| u8::from_str_radix(&hex[index..index + 2], 16).unwrap())
            .collect()
    }
}