1use crate::{HeaderFlags, MessageType, NnrpError, CURRENT_WIRE_FORMAT};
2
/// Size in bytes of the fixed common header that starts every packet.
///
/// Written as the sum of the field widths used by `CommonHeader::write` and
/// `CommonHeader::parse`, in wire order.
pub const COMMON_HEADER_LEN: usize = 4 // magic
    + 1 + 1 + 1 + 1 // version_major, wire_format, message_type, header_len
    + 4 + 4 + 4 // flags, meta_len, body_len
    + 4 + 4 // session_id, frame_id
    + 2 + 2 // view_id, route_id
    + 8; // trace_id

/// The only major protocol version this crate writes by default and accepts
/// when parsing.
pub const CURRENT_VERSION_MAJOR: u8 = 0x01;

/// Protocol identifier string for this version. Presumably used for
/// TLS/QUIC ALPN negotiation, as the name suggests; confirm at call sites.
pub const ALPN: &str = "nnrp/1";

/// Four-byte marker ("NNRP" in ASCII) that opens every common header.
const MAGIC: [u8; 4] = [b'N', b'N', b'R', b'P'];
7
/// Decoded form of the fixed-size common header (`COMMON_HEADER_LEN` bytes)
/// that precedes each packet's metadata and body sections.
///
/// On the wire the header starts with the 4-byte magic, followed by the fields
/// below at fixed offsets. Multi-byte integers are little-endian (see
/// `CommonHeader::write` / `CommonHeader::parse`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CommonHeader {
    /// Major protocol version (byte 4). `parse` rejects anything other than
    /// `CURRENT_VERSION_MAJOR`.
    pub version_major: u8,
    /// Wire-format revision (byte 5). `parse` rejects anything other than
    /// `CURRENT_WIRE_FORMAT`.
    pub wire_format: u8,
    /// Message kind, encoded as a single byte (byte 6).
    pub message_type: MessageType,
    /// Header length as stored on the wire (byte 7). Both `write` and `parse`
    /// require it to equal `COMMON_HEADER_LEN`.
    pub header_len: u8,
    /// Header flag bits (little-endian `u32`, bytes 8..12). Unknown bits are
    /// rejected by `HeaderFlags::validate_known` on both write and parse.
    pub flags: HeaderFlags,
    /// Length of the metadata section that directly follows the header
    /// (little-endian `u32`, bytes 12..16).
    pub meta_len: u32,
    /// Length of the body section that follows the metadata
    /// (little-endian `u32`, bytes 16..20).
    pub body_len: u32,
    /// Little-endian `u32` at bytes 20..24. This module copies it through
    /// without interpreting it.
    pub session_id: u32,
    /// Little-endian `u32` at bytes 24..28. This module copies it through
    /// without interpreting it.
    pub frame_id: u32,
    /// Little-endian `u16` at bytes 28..30. This module copies it through
    /// without interpreting it.
    pub view_id: u16,
    /// Little-endian `u16` at bytes 30..32. This module copies it through
    /// without interpreting it.
    pub route_id: u16,
    /// Little-endian `u64` at bytes 32..40. This module copies it through
    /// without interpreting it.
    pub trace_id: u64,
}
23
24impl CommonHeader {
25 pub fn new(message_type: MessageType, meta_len: u32, body_len: u32) -> Self {
26 Self {
27 version_major: CURRENT_VERSION_MAJOR,
28 wire_format: CURRENT_WIRE_FORMAT,
29 message_type,
30 header_len: COMMON_HEADER_LEN as u8,
31 flags: HeaderFlags::NONE,
32 meta_len,
33 body_len,
34 session_id: 0,
35 frame_id: 0,
36 view_id: 0,
37 route_id: 0,
38 trace_id: 0,
39 }
40 }
41
42 pub fn packet_len(&self) -> Result<usize, NnrpError> {
43 let payload_len = self
44 .meta_len
45 .checked_add(self.body_len)
46 .ok_or(NnrpError::MessageLengthOverflow)? as usize;
47 COMMON_HEADER_LEN
48 .checked_add(payload_len)
49 .ok_or(NnrpError::MessageLengthOverflow)
50 }
51
52 pub fn write(&self, destination: &mut [u8]) -> Result<(), NnrpError> {
53 if destination.len() < COMMON_HEADER_LEN {
54 return Err(NnrpError::DestinationTooShort {
55 expected: COMMON_HEADER_LEN,
56 actual: destination.len(),
57 });
58 }
59
60 if self.header_len != COMMON_HEADER_LEN as u8 {
61 return Err(NnrpError::InvalidHeaderLength(self.header_len));
62 }
63
64 self.flags.validate_known()?;
65
66 destination[0..4].copy_from_slice(&MAGIC);
67 destination[4] = self.version_major;
68 destination[5] = self.wire_format;
69 destination[6] = self.message_type as u8;
70 destination[7] = self.header_len;
71 destination[8..12].copy_from_slice(&self.flags.0.to_le_bytes());
72 destination[12..16].copy_from_slice(&self.meta_len.to_le_bytes());
73 destination[16..20].copy_from_slice(&self.body_len.to_le_bytes());
74 destination[20..24].copy_from_slice(&self.session_id.to_le_bytes());
75 destination[24..28].copy_from_slice(&self.frame_id.to_le_bytes());
76 destination[28..30].copy_from_slice(&self.view_id.to_le_bytes());
77 destination[30..32].copy_from_slice(&self.route_id.to_le_bytes());
78 destination[32..40].copy_from_slice(&self.trace_id.to_le_bytes());
79 Ok(())
80 }
81
82 pub fn to_bytes(&self) -> Result<[u8; COMMON_HEADER_LEN], NnrpError> {
83 let mut bytes = [0u8; COMMON_HEADER_LEN];
84 self.write(&mut bytes)?;
85 Ok(bytes)
86 }
87
88 pub fn parse(source: &[u8]) -> Result<Self, NnrpError> {
89 if source.len() < COMMON_HEADER_LEN {
90 return Err(NnrpError::SourceTooShort {
91 expected: COMMON_HEADER_LEN,
92 actual: source.len(),
93 });
94 }
95
96 if source[0..4] != MAGIC {
97 return Err(NnrpError::InvalidMagic);
98 }
99
100 let version_major = source[4];
101 if version_major != CURRENT_VERSION_MAJOR {
102 return Err(NnrpError::UnsupportedVersionMajor(version_major));
103 }
104
105 let wire_format = source[5];
106 if wire_format != CURRENT_WIRE_FORMAT {
107 return Err(NnrpError::UnsupportedWireFormat(wire_format));
108 }
109
110 let header_len = source[7];
111 if header_len != COMMON_HEADER_LEN as u8 {
112 return Err(NnrpError::InvalidHeaderLength(header_len));
113 }
114
115 let flags = HeaderFlags(u32::from_le_bytes(
116 source[8..12].try_into().expect("slice length"),
117 ));
118 flags.validate_known()?;
119
120 Ok(Self {
121 version_major,
122 wire_format,
123 message_type: MessageType::try_from_u8(source[6])?,
124 header_len,
125 flags,
126 meta_len: u32::from_le_bytes(source[12..16].try_into().expect("slice length")),
127 body_len: u32::from_le_bytes(source[16..20].try_into().expect("slice length")),
128 session_id: u32::from_le_bytes(source[20..24].try_into().expect("slice length")),
129 frame_id: u32::from_le_bytes(source[24..28].try_into().expect("slice length")),
130 view_id: u16::from_le_bytes(source[28..30].try_into().expect("slice length")),
131 route_id: u16::from_le_bytes(source[30..32].try_into().expect("slice length")),
132 trace_id: u64::from_le_bytes(source[32..40].try_into().expect("slice length")),
133 })
134 }
135
136 pub fn parse_packet(source: &[u8]) -> Result<(Self, &[u8], &[u8]), NnrpError> {
137 let header = Self::parse(source)?;
138 let declared = header.packet_len()?;
139 if declared != source.len() {
140 return Err(NnrpError::PacketLengthMismatch {
141 declared,
142 actual: source.len(),
143 });
144 }
145
146 let meta_start = COMMON_HEADER_LEN;
147 let meta_end = meta_start + header.meta_len as usize;
148 Ok((header, &source[meta_start..meta_end], &source[meta_end..]))
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::{CommonHeader, COMMON_HEADER_LEN};
    use crate::{HeaderFlags, MessageType, NnrpError};

    /// Decodes a reference FlowUpdate packet (32 metadata bytes, no body).
    /// Checks every decoded field, the meta/body split, and a byte-exact
    /// re-serialisation of the header.
    #[test]
    fn common_header_round_trips_flow_update_vector_header() {
        let packet = hex_to_bytes("4e4e5250010017280000000020000000000000002a000000000000000000090088776655443322110104020000000200000000000000000000000000780000000700000003000000");

        let header = CommonHeader::parse(&packet).expect("header should parse");

        assert_eq!(header.version_major, 1);
        assert_eq!(header.message_type, MessageType::FlowUpdate);
        assert_eq!(header.header_len, COMMON_HEADER_LEN as u8);
        assert_eq!(header.flags.0, 0);
        assert_eq!(header.meta_len, 32);
        assert_eq!(header.body_len, 0);
        assert_eq!(header.session_id, 42);
        assert_eq!(header.frame_id, 0);
        assert_eq!(header.view_id, 0);
        assert_eq!(header.route_id, 9);
        assert_eq!(header.trace_id, 0x1122_3344_5566_7788);
        assert_eq!(
            header.to_bytes().unwrap().as_slice(),
            &packet[..COMMON_HEADER_LEN]
        );

        let (packet_header, meta, body) =
            CommonHeader::parse_packet(&packet).expect("packet should parse");
        assert_eq!(packet_header, header);
        assert_eq!(meta, &packet[COMMON_HEADER_LEN..]);
        assert!(body.is_empty());
    }

    /// A packet with both metadata and body is split at `meta_len`.
    #[test]
    fn common_header_splits_packet_into_meta_and_body() {
        let mut packet = CommonHeader::new(MessageType::Ping, 3, 2)
            .to_bytes()
            .expect("header writes")
            .to_vec();
        packet.extend_from_slice(&[1, 2, 3, 4, 5]);

        let (header, meta, body) = CommonHeader::parse_packet(&packet).expect("packet parses");

        assert_eq!(header.meta_len, 3);
        assert_eq!(header.body_len, 2);
        assert_eq!(meta, &[1, 2, 3]);
        assert_eq!(body, &[4, 5]);
    }

    #[test]
    fn common_header_rejects_length_mismatch() {
        let mut packet = CommonHeader::new(MessageType::Ping, 4, 0)
            .to_bytes()
            .expect("header writes")
            .to_vec();
        packet.extend_from_slice(&[1, 2]);

        assert_eq!(
            CommonHeader::parse_packet(&packet),
            Err(NnrpError::PacketLengthMismatch {
                declared: 44,
                actual: 42
            })
        );
    }

    /// `meta_len + body_len` overflowing `u32` is reported, not wrapped.
    #[test]
    fn common_header_rejects_overflowing_payload_length() {
        let header = CommonHeader::new(MessageType::Ping, u32::MAX, 1);

        assert_eq!(header.packet_len(), Err(NnrpError::MessageLengthOverflow));
    }

    #[test]
    fn common_header_rejects_short_source() {
        assert_eq!(
            CommonHeader::parse(&[0u8; 10]),
            Err(NnrpError::SourceTooShort {
                expected: COMMON_HEADER_LEN,
                actual: 10
            })
        );
    }

    #[test]
    fn common_header_rejects_reserved_flags() {
        let mut header = CommonHeader::new(MessageType::Ping, 0, 0);
        header.flags = HeaderFlags(0x40);

        assert_eq!(
            header.to_bytes(),
            Err(NnrpError::ReservedBitsSet {
                value: 0x40,
                allowed: HeaderFlags::KNOWN_MASK as u64
            })
        );
    }

    /// Decodes an even-length ASCII hex string into bytes; panics on bad input.
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        assert_eq!(hex.len() % 2, 0);
        (0..hex.len())
            .step_by(2)
            .map(|index| u8::from_str_radix(&hex[index..index + 2], 16).unwrap())
            .collect()
    }
}