1use bytes::{Buf, BufMut};
2
3use crate::constants::{self, HEADER_LEN, MARKER, MARKER_LEN, MIN_MESSAGE_LEN};
4use crate::error::DecodeError;
5
/// Message type code carried in the type byte of a BGP message header.
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so casting a
/// variant to `u8` yields its numeric type code.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MessageType {
    /// OPEN message (type code 1).
    Open = 1,
    /// UPDATE message (type code 2).
    Update = 2,
    /// NOTIFICATION message (type code 3).
    Notification = 3,
    /// KEEPALIVE message (type code 4).
    Keepalive = 4,
    /// ROUTE-REFRESH message (type code 5).
    RouteRefresh = 5,
}
21
22impl MessageType {
23 #[must_use]
25 pub fn from_u8(value: u8) -> Option<Self> {
26 match value {
27 constants::message_type::OPEN => Some(Self::Open),
28 constants::message_type::UPDATE => Some(Self::Update),
29 constants::message_type::NOTIFICATION => Some(Self::Notification),
30 constants::message_type::KEEPALIVE => Some(Self::Keepalive),
31 constants::message_type::ROUTE_REFRESH => Some(Self::RouteRefresh),
32 _ => None,
33 }
34 }
35
36 #[must_use]
38 pub fn as_u8(self) -> u8 {
39 self as u8
40 }
41}
42
43impl std::fmt::Display for MessageType {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 match self {
46 Self::Open => write!(f, "OPEN"),
47 Self::Update => write!(f, "UPDATE"),
48 Self::Notification => write!(f, "NOTIFICATION"),
49 Self::Keepalive => write!(f, "KEEPALIVE"),
50 Self::RouteRefresh => write!(f, "ROUTE-REFRESH"),
51 }
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq)]
57pub struct BgpHeader {
58 pub length: u16,
60 pub message_type: MessageType,
62}
63
64impl BgpHeader {
65 pub fn decode(buf: &mut impl Buf, max_message_len: u16) -> Result<Self, DecodeError> {
77 if buf.remaining() < HEADER_LEN {
78 return Err(DecodeError::Incomplete {
79 needed: HEADER_LEN,
80 available: buf.remaining(),
81 });
82 }
83
84 let mut marker = [0u8; MARKER_LEN];
86 buf.copy_to_slice(&mut marker);
87 if marker != MARKER {
88 return Err(DecodeError::InvalidMarker);
89 }
90
91 let length = buf.get_u16();
92 if !(MIN_MESSAGE_LEN..=max_message_len).contains(&length) {
93 return Err(DecodeError::InvalidLength { length });
94 }
95
96 let type_byte = buf.get_u8();
97 let message_type =
98 MessageType::from_u8(type_byte).ok_or(DecodeError::UnknownMessageType(type_byte))?;
99
100 Ok(Self {
101 length,
102 message_type,
103 })
104 }
105
106 pub fn encode(&self, buf: &mut impl BufMut) {
108 buf.put_slice(&MARKER);
109 buf.put_u16(self.length);
110 buf.put_u8(self.message_type.as_u8());
111 }
112}
113
114pub fn peek_message_length(buf: &[u8], max_message_len: u16) -> Result<Option<u16>, DecodeError> {
130 if buf.len() < HEADER_LEN {
131 return Ok(None);
132 }
133
134 if buf[..MARKER_LEN] != MARKER {
136 return Err(DecodeError::InvalidMarker);
137 }
138
139 let length = u16::from_be_bytes([buf[16], buf[17]]);
140 if !(MIN_MESSAGE_LEN..=max_message_len).contains(&length) {
141 return Err(DecodeError::InvalidLength { length });
142 }
143
144 let type_byte = buf[18];
145 if MessageType::from_u8(type_byte).is_none() {
146 return Err(DecodeError::UnknownMessageType(type_byte));
147 }
148
149 Ok(Some(length))
150}
151
#[cfg(test)]
mod tests {
    use bytes::BytesMut;

    use super::*;
    use crate::constants::{EXTENDED_MAX_MESSAGE_LEN, MAX_MESSAGE_LEN};

    /// Builds raw header bytes: marker, big-endian length, then the type byte.
    fn make_header(length: u16, msg_type: u8) -> BytesMut {
        let mut raw = BytesMut::with_capacity(HEADER_LEN);
        raw.put_slice(&MARKER);
        raw.put_u16(length);
        raw.put_u8(msg_type);
        raw
    }

    /// Builds a header with `make_header` and runs `BgpHeader::decode` on it.
    fn decode_header(length: u16, msg_type: u8, max_len: u16) -> Result<BgpHeader, DecodeError> {
        let mut wire = make_header(length, msg_type).freeze();
        BgpHeader::decode(&mut wire, max_len)
    }

    #[test]
    fn decode_valid_keepalive_header() {
        let mut wire = make_header(19, 4).freeze();
        let header = BgpHeader::decode(&mut wire, MAX_MESSAGE_LEN).unwrap();
        assert_eq!(header.length, 19);
        assert_eq!(header.message_type, MessageType::Keepalive);
        // The whole header must have been consumed.
        assert_eq!(wire.remaining(), 0);
    }

    #[test]
    fn decode_valid_open_header() {
        let header = decode_header(29, 1, MAX_MESSAGE_LEN).unwrap();
        assert_eq!(header.message_type, MessageType::Open);
    }

    #[test]
    fn reject_invalid_marker() {
        let mut raw = make_header(19, 4);
        raw[0] = 0x00;
        let mut wire = raw.freeze();
        let outcome = BgpHeader::decode(&mut wire, MAX_MESSAGE_LEN);
        assert!(matches!(outcome, Err(DecodeError::InvalidMarker)));
    }

    #[test]
    fn reject_length_too_small() {
        let outcome = decode_header(18, 4, MAX_MESSAGE_LEN);
        assert!(matches!(
            outcome,
            Err(DecodeError::InvalidLength { length: 18 })
        ));
    }

    #[test]
    fn reject_length_too_large() {
        let outcome = decode_header(4097, 4, MAX_MESSAGE_LEN);
        assert!(matches!(
            outcome,
            Err(DecodeError::InvalidLength { length: 4097 })
        ));
    }

    #[test]
    fn reject_unknown_type() {
        let outcome = decode_header(19, 99, MAX_MESSAGE_LEN);
        assert!(matches!(outcome, Err(DecodeError::UnknownMessageType(99))));
    }

    #[test]
    fn reject_incomplete_buffer() {
        let mut wire = bytes::Bytes::from_static(&[0xFF; 10]);
        let outcome = BgpHeader::decode(&mut wire, MAX_MESSAGE_LEN);
        assert!(matches!(outcome, Err(DecodeError::Incomplete { .. })));
    }

    #[test]
    fn roundtrip_header() {
        let original = BgpHeader {
            length: 100,
            message_type: MessageType::Update,
        };

        let mut sink = BytesMut::with_capacity(HEADER_LEN);
        original.encode(&mut sink);

        let mut wire = sink.freeze();
        let decoded = BgpHeader::decode(&mut wire, MAX_MESSAGE_LEN).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn peek_returns_none_for_short_buffer() {
        let peeked = peek_message_length(&[0xFF; 10], MAX_MESSAGE_LEN).unwrap();
        assert_eq!(peeked, None);
    }

    #[test]
    fn peek_returns_length_for_valid_header() {
        let raw = make_header(42, 1);
        let peeked = peek_message_length(&raw, MAX_MESSAGE_LEN).unwrap();
        assert_eq!(peeked, Some(42));
    }

    #[test]
    fn peek_rejects_bad_marker() {
        let mut raw = make_header(19, 4);
        raw[15] = 0x00;
        let outcome = peek_message_length(&raw, MAX_MESSAGE_LEN);
        assert!(matches!(outcome, Err(DecodeError::InvalidMarker)));
    }

    #[test]
    fn extended_accepts_4097() {
        let header = decode_header(4097, 2, EXTENDED_MAX_MESSAGE_LEN).unwrap();
        assert_eq!(header.length, 4097);
    }

    #[test]
    fn standard_rejects_4097() {
        let outcome = decode_header(4097, 2, MAX_MESSAGE_LEN);
        assert!(matches!(
            outcome,
            Err(DecodeError::InvalidLength { length: 4097 })
        ));
    }

    #[test]
    fn peek_extended_accepts_large() {
        let raw = make_header(5000, 2);
        let peeked = peek_message_length(&raw, EXTENDED_MAX_MESSAGE_LEN).unwrap();
        assert_eq!(peeked, Some(5000));
    }

    #[test]
    fn peek_standard_rejects_large() {
        let raw = make_header(5000, 2);
        let outcome = peek_message_length(&raw, MAX_MESSAGE_LEN);
        assert!(matches!(
            outcome,
            Err(DecodeError::InvalidLength { length: 5000 })
        ));
    }
}