use bytes::{Buf, BufMut};
use crate::constants::{self, HEADER_LEN, MARKER, MARKER_LEN, MIN_MESSAGE_LEN};
use crate::error::DecodeError;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum MessageType {
Open = 1,
Update = 2,
Notification = 3,
Keepalive = 4,
RouteRefresh = 5,
}
impl MessageType {
#[must_use]
pub fn from_u8(value: u8) -> Option<Self> {
match value {
constants::message_type::OPEN => Some(Self::Open),
constants::message_type::UPDATE => Some(Self::Update),
constants::message_type::NOTIFICATION => Some(Self::Notification),
constants::message_type::KEEPALIVE => Some(Self::Keepalive),
constants::message_type::ROUTE_REFRESH => Some(Self::RouteRefresh),
_ => None,
}
}
#[must_use]
pub fn as_u8(self) -> u8 {
self as u8
}
}
impl std::fmt::Display for MessageType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Open => write!(f, "OPEN"),
Self::Update => write!(f, "UPDATE"),
Self::Notification => write!(f, "NOTIFICATION"),
Self::Keepalive => write!(f, "KEEPALIVE"),
Self::RouteRefresh => write!(f, "ROUTE-REFRESH"),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BgpHeader {
pub length: u16,
pub message_type: MessageType,
}
impl BgpHeader {
pub fn decode(buf: &mut impl Buf, max_message_len: u16) -> Result<Self, DecodeError> {
if buf.remaining() < HEADER_LEN {
return Err(DecodeError::Incomplete {
needed: HEADER_LEN,
available: buf.remaining(),
});
}
let mut marker = [0u8; MARKER_LEN];
buf.copy_to_slice(&mut marker);
if marker != MARKER {
return Err(DecodeError::InvalidMarker);
}
let length = buf.get_u16();
if !(MIN_MESSAGE_LEN..=max_message_len).contains(&length) {
return Err(DecodeError::InvalidLength { length });
}
let type_byte = buf.get_u8();
let message_type =
MessageType::from_u8(type_byte).ok_or(DecodeError::UnknownMessageType(type_byte))?;
Ok(Self {
length,
message_type,
})
}
pub fn encode(&self, buf: &mut impl BufMut) {
buf.put_slice(&MARKER);
buf.put_u16(self.length);
buf.put_u8(self.message_type.as_u8());
}
}
pub fn peek_message_length(buf: &[u8], max_message_len: u16) -> Result<Option<u16>, DecodeError> {
if buf.len() < HEADER_LEN {
return Ok(None);
}
if buf[..MARKER_LEN] != MARKER {
return Err(DecodeError::InvalidMarker);
}
let length = u16::from_be_bytes([buf[16], buf[17]]);
if !(MIN_MESSAGE_LEN..=max_message_len).contains(&length) {
return Err(DecodeError::InvalidLength { length });
}
let type_byte = buf[18];
if MessageType::from_u8(type_byte).is_none() {
return Err(DecodeError::UnknownMessageType(type_byte));
}
Ok(Some(length))
}
#[cfg(test)]
mod tests {
use bytes::BytesMut;
use super::*;
use crate::constants::{EXTENDED_MAX_MESSAGE_LEN, MAX_MESSAGE_LEN};
fn make_header(length: u16, msg_type: u8) -> BytesMut {
let mut buf = BytesMut::with_capacity(HEADER_LEN);
buf.put_slice(&MARKER);
buf.put_u16(length);
buf.put_u8(msg_type);
buf
}
#[test]
fn decode_valid_keepalive_header() {
let mut buf = make_header(19, 4).freeze();
let hdr = BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN).unwrap();
assert_eq!(hdr.length, 19);
assert_eq!(hdr.message_type, MessageType::Keepalive);
assert_eq!(buf.remaining(), 0);
}
#[test]
fn decode_valid_open_header() {
let mut buf = make_header(29, 1).freeze();
let hdr = BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN).unwrap();
assert_eq!(hdr.message_type, MessageType::Open);
}
#[test]
fn reject_invalid_marker() {
let mut data = make_header(19, 4);
data[0] = 0x00; let mut buf = data.freeze();
assert!(matches!(
BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN),
Err(DecodeError::InvalidMarker)
));
}
#[test]
fn reject_length_too_small() {
let mut buf = make_header(18, 4).freeze();
assert!(matches!(
BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN),
Err(DecodeError::InvalidLength { length: 18 })
));
}
#[test]
fn reject_length_too_large() {
let mut buf = make_header(4097, 4).freeze();
assert!(matches!(
BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN),
Err(DecodeError::InvalidLength { length: 4097 })
));
}
#[test]
fn reject_unknown_type() {
let mut buf = make_header(19, 99).freeze();
assert!(matches!(
BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN),
Err(DecodeError::UnknownMessageType(99))
));
}
#[test]
fn reject_incomplete_buffer() {
let mut buf = bytes::Bytes::from_static(&[0xFF; 10]);
assert!(matches!(
BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN),
Err(DecodeError::Incomplete { .. })
));
}
#[test]
fn roundtrip_header() {
let original = BgpHeader {
length: 100,
message_type: MessageType::Update,
};
let mut encoded = BytesMut::with_capacity(HEADER_LEN);
original.encode(&mut encoded);
let mut buf = encoded.freeze();
let decoded = BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN).unwrap();
assert_eq!(original, decoded);
}
#[test]
fn peek_returns_none_for_short_buffer() {
assert_eq!(
peek_message_length(&[0xFF; 10], MAX_MESSAGE_LEN).unwrap(),
None
);
}
#[test]
fn peek_returns_length_for_valid_header() {
let buf = make_header(42, 1);
assert_eq!(
peek_message_length(&buf, MAX_MESSAGE_LEN).unwrap(),
Some(42)
);
}
#[test]
fn peek_rejects_bad_marker() {
let mut data = make_header(19, 4);
data[15] = 0x00;
assert!(matches!(
peek_message_length(&data, MAX_MESSAGE_LEN),
Err(DecodeError::InvalidMarker)
));
}
#[test]
fn extended_accepts_4097() {
let mut buf = make_header(4097, 2).freeze();
let hdr = BgpHeader::decode(&mut buf, EXTENDED_MAX_MESSAGE_LEN).unwrap();
assert_eq!(hdr.length, 4097);
}
#[test]
fn standard_rejects_4097() {
let mut buf = make_header(4097, 2).freeze();
assert!(matches!(
BgpHeader::decode(&mut buf, MAX_MESSAGE_LEN),
Err(DecodeError::InvalidLength { length: 4097 })
));
}
#[test]
fn peek_extended_accepts_large() {
let buf = make_header(5000, 2);
assert_eq!(
peek_message_length(&buf, EXTENDED_MAX_MESSAGE_LEN).unwrap(),
Some(5000)
);
}
#[test]
fn peek_standard_rejects_large() {
let buf = make_header(5000, 2);
assert!(matches!(
peek_message_length(&buf, MAX_MESSAGE_LEN),
Err(DecodeError::InvalidLength { length: 5000 })
));
}
}