use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use crate::error::DecodeError;
/// Tunnel Type octet of a PMSI Tunnel attribute (RFC 6514 §5).
///
/// Wire codes 0–7 map to the named variants below. Any other code is kept
/// verbatim in [`PmsiTunnelType::Other`] so it re-encodes unchanged.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PmsiTunnelType {
    /// Wire code 0.
    NoTunnelInfo,
    /// Wire code 1.
    RsvpTeP2mp,
    /// Wire code 2.
    MldpP2mp,
    /// Wire code 3.
    PimSsm,
    /// Wire code 4.
    PimSm,
    /// Wire code 5.
    BidirPim,
    /// Wire code 6; used by `PmsiTunnel::for_evpn_ingress_replication`.
    IngressReplication,
    /// Wire code 7.
    MldpMp2mp,
    /// Any other wire code, preserved as-is.
    ///
    /// `from_u8` never yields `Other(0..=7)`. A hand-built `Other(6)` encodes
    /// as code 6 but decodes back as `IngressReplication`, and the two values
    /// compare unequal.
    Other(u8),
}

impl PmsiTunnelType {
    /// Returns the one-octet wire code for this tunnel type.
    #[must_use]
    pub fn as_u8(self) -> u8 {
        match self {
            Self::Other(code) => code,
            Self::NoTunnelInfo => 0,
            Self::RsvpTeP2mp => 1,
            Self::MldpP2mp => 2,
            Self::PimSsm => 3,
            Self::PimSm => 4,
            Self::BidirPim => 5,
            Self::IngressReplication => 6,
            Self::MldpMp2mp => 7,
        }
    }

    /// Maps a wire code to its tunnel type.
    ///
    /// Codes without a named variant become [`PmsiTunnelType::Other`].
    #[must_use]
    pub fn from_u8(v: u8) -> Self {
        // Index `i` holds the variant whose wire code is `i`.
        const ASSIGNED: [PmsiTunnelType; 8] = [
            PmsiTunnelType::NoTunnelInfo,
            PmsiTunnelType::RsvpTeP2mp,
            PmsiTunnelType::MldpP2mp,
            PmsiTunnelType::PimSsm,
            PmsiTunnelType::PimSm,
            PmsiTunnelType::BidirPim,
            PmsiTunnelType::IngressReplication,
            PmsiTunnelType::MldpMp2mp,
        ];
        ASSIGNED
            .get(usize::from(v))
            .copied()
            .unwrap_or(Self::Other(v))
    }
}
/// Tunnel Identifier part of a PMSI Tunnel attribute value: every octet after
/// the 3-octet label field.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PmsiTunnelIdentifier {
    /// No octets follow the label.
    ///
    /// `PmsiTunnel::decode` yields this whenever the value is exactly 5 octets
    /// long, whatever the tunnel type.
    Empty,
    /// A 4-octet IPv4 address.
    ///
    /// `PmsiTunnel::decode` only yields this for ingress replication. Other
    /// tunnel types keep 4 trailing octets as `Raw`.
    Ipv4(Ipv4Addr),
    /// A 16-octet IPv6 address. Same decoding rule as `Ipv4`.
    Ipv6(Ipv6Addr),
    /// Opaque identifier octets, written verbatim by `PmsiTunnel::encode`.
    ///
    /// `Raw(vec![])` encodes the same as `Empty`, so it decodes back as `Empty`.
    Raw(Vec<u8>),
}
/// Decoded value of a BGP PMSI Tunnel path attribute.
///
/// Wire layout handled by `encode`/`decode`:
/// - flags (1 octet)
/// - tunnel type (1 octet)
/// - label field (3 octets, big-endian)
/// - tunnel identifier (all remaining octets)
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PmsiTunnel {
    /// Flags octet. It is carried through verbatim; no bits are interpreted here.
    pub flags: u8,
    /// Tunnel type octet.
    pub tunnel_type: PmsiTunnelType,
    /// Raw contents of the 3-octet label field.
    ///
    /// Only the low 24 bits are encoded; `encode` silently drops higher bits.
    /// No MPLS-label shifting is applied. Callers store the value exactly as
    /// it should appear on the wire (e.g. the VNI for EVPN ingress replication).
    pub mpls_label: u32,
    /// Identifier octets that follow the label field.
    pub tunnel_identifier: PmsiTunnelIdentifier,
}
impl PmsiTunnel {
    /// Builds an EVPN ingress-replication PMSI Tunnel value.
    ///
    /// The result has:
    /// - flags 0
    /// - tunnel type [`PmsiTunnelType::IngressReplication`]
    /// - the low 24 bits of `vni` in the label field
    /// - `originator` as the identifier (IPv4 or IPv6 to match its family)
    #[must_use]
    pub fn for_evpn_ingress_replication(vni: u32, originator: IpAddr) -> Self {
        Self {
            flags: 0,
            tunnel_type: PmsiTunnelType::IngressReplication,
            // The wire only has 3 octets of label space; bits 24..32 are discarded.
            mpls_label: vni & 0x00FF_FFFF,
            tunnel_identifier: match originator {
                IpAddr::V4(addr) => PmsiTunnelIdentifier::Ipv4(addr),
                IpAddr::V6(addr) => PmsiTunnelIdentifier::Ipv6(addr),
            },
        }
    }

    /// Appends the attribute value to `buf`.
    ///
    /// No path-attribute flags/type/length header is written.
    ///
    /// Octets written, in order:
    /// - flags
    /// - tunnel type code
    /// - the low 24 bits of `mpls_label`, big-endian
    /// - the identifier: nothing for `Empty`, the 4 or 16 address octets,
    ///   or the raw bytes as-is
    pub fn encode(&self, buf: &mut Vec<u8>) {
        buf.extend_from_slice(&[self.flags, self.tunnel_type.as_u8()]);
        // `to_be_bytes` gives 4 octets; skipping the first keeps the low 24 bits.
        let label_octets = (self.mpls_label & 0x00FF_FFFF).to_be_bytes();
        buf.extend_from_slice(&label_octets[1..]);
        match self.tunnel_identifier {
            PmsiTunnelIdentifier::Empty => {}
            PmsiTunnelIdentifier::Ipv4(addr) => buf.extend_from_slice(&addr.octets()),
            PmsiTunnelIdentifier::Ipv6(addr) => buf.extend_from_slice(&addr.octets()),
            PmsiTunnelIdentifier::Raw(ref raw) => buf.extend_from_slice(raw),
        }
    }

    /// Parses a PMSI Tunnel attribute value (without the path-attribute header).
    ///
    /// How the identifier is typed:
    /// - no trailing octets: `Empty`, for any tunnel type
    /// - ingress replication with exactly 4 or 16 trailing octets: `Ipv4` or `Ipv6`
    /// - anything else: `Raw`
    ///
    /// # Errors
    ///
    /// Returns `DecodeError::MalformedField` when `value` is shorter than the
    /// 5-octet fixed part (flags, type, label).
    pub fn decode(value: &[u8]) -> Result<Self, DecodeError> {
        let (flags, raw_type, label_octets, rest) = match value {
            [flags, raw_type, hi, mid, lo, rest @ ..] => {
                (*flags, *raw_type, [0, *hi, *mid, *lo], rest)
            }
            _ => {
                return Err(DecodeError::MalformedField {
                    message_type: "UPDATE",
                    detail: format!(
                        "PMSI Tunnel attribute truncated: need ≥5 bytes (flags+type+label), got {}",
                        value.len()
                    ),
                });
            }
        };
        let tunnel_type = PmsiTunnelType::from_u8(raw_type);
        let ingress = matches!(tunnel_type, PmsiTunnelType::IngressReplication);
        let tunnel_identifier = match rest {
            [] => PmsiTunnelIdentifier::Empty,
            [a, b, c, d] if ingress => PmsiTunnelIdentifier::Ipv4(Ipv4Addr::new(*a, *b, *c, *d)),
            _ if ingress && rest.len() == 16 => {
                let mut octets = [0u8; 16];
                octets.copy_from_slice(rest);
                PmsiTunnelIdentifier::Ipv6(Ipv6Addr::from(octets))
            }
            _ => PmsiTunnelIdentifier::Raw(rest.to_vec()),
        };
        Ok(Self {
            flags,
            tunnel_type,
            mpls_label: u32::from_be_bytes(label_octets),
            tunnel_identifier,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `t`, decodes the produced bytes, and asserts the result equals `t`.
    fn roundtrip(t: &PmsiTunnel) {
        let mut buf = Vec::new();
        t.encode(&mut buf);
        let decoded = PmsiTunnel::decode(&buf).expect("decode");
        assert_eq!(&decoded, t);
    }

    #[test]
    fn ingress_replication_ipv4_roundtrip() {
        let t = PmsiTunnel::for_evpn_ingress_replication(100, "10.0.0.1".parse().unwrap());
        roundtrip(&t);
        assert_eq!(t.mpls_label, 100);
        assert_eq!(t.tunnel_type, PmsiTunnelType::IngressReplication);
        assert_eq!(
            t.tunnel_identifier,
            PmsiTunnelIdentifier::Ipv4(Ipv4Addr::new(10, 0, 0, 1))
        );
    }

    #[test]
    fn ingress_replication_ipv6_roundtrip() {
        let t = PmsiTunnel::for_evpn_ingress_replication(50, "2001:db8::1".parse().unwrap());
        roundtrip(&t);
        assert_eq!(t.mpls_label, 50);
    }

    /// The VNI fills the whole 3-octet label field, big-endian.
    #[test]
    fn ingress_replication_ipv4_wire_bytes_match_rfc_8365() {
        let t = PmsiTunnel::for_evpn_ingress_replication(100, "10.0.0.1".parse().unwrap());
        let mut buf = Vec::new();
        t.encode(&mut buf);
        assert_eq!(
            buf,
            vec![
                0x00, // flags
                0x06, // tunnel type: ingress replication
                0x00, 0x00, 0x64, // label field: VNI 100
                10, 0, 0, 1, // tunnel identifier: originator 10.0.0.1
            ]
        );
    }

    /// IPv6 originator: the 5-octet fixed part followed by the 16 address octets.
    #[test]
    fn ingress_replication_ipv6_wire_bytes() {
        let originator = Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0, 1);
        let t = PmsiTunnel::for_evpn_ingress_replication(50, IpAddr::V6(originator));
        let mut buf = Vec::new();
        t.encode(&mut buf);
        let mut expected = vec![0x00u8, 0x06, 0x00, 0x00, 0x32];
        expected.extend_from_slice(&originator.octets());
        assert_eq!(buf, expected);
        let decoded = PmsiTunnel::decode(&buf).expect("decode");
        assert_eq!(decoded.tunnel_identifier, PmsiTunnelIdentifier::Ipv6(originator));
    }

    #[test]
    fn no_tunnel_info_with_no_identifier_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::NoTunnelInfo,
            mpls_label: 0,
            tunnel_identifier: PmsiTunnelIdentifier::Empty,
        };
        roundtrip(&t);
    }

    #[test]
    fn rsvp_te_p2mp_with_opaque_id_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::RsvpTeP2mp,
            mpls_label: 0x1234 << 4,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(vec![1, 2, 3, 4, 5, 6, 7, 8]),
        };
        roundtrip(&t);
    }

    #[test]
    fn mldp_p2mp_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::MldpP2mp,
            mpls_label: 42 << 4,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(vec![0xaa; 12]),
        };
        roundtrip(&t);
    }

    #[test]
    fn pim_ssm_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::PimSsm,
            mpls_label: 0,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(vec![10, 0, 0, 1, 224, 0, 0, 1]),
        };
        roundtrip(&t);
    }

    #[test]
    fn pim_sm_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::PimSm,
            mpls_label: 0,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(vec![0xff; 8]),
        };
        roundtrip(&t);
    }

    #[test]
    fn bidir_pim_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::BidirPim,
            mpls_label: 0,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(vec![1, 2, 3]),
        };
        roundtrip(&t);
    }

    #[test]
    fn mldp_mp2mp_roundtrip() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::MldpMp2mp,
            mpls_label: 0,
            tunnel_identifier: PmsiTunnelIdentifier::Empty,
        };
        roundtrip(&t);
    }

    #[test]
    fn unknown_tunnel_type_round_trips_without_loss() {
        let t = PmsiTunnel {
            flags: 0x01, // non-zero flags must survive the round trip untouched
            tunnel_type: PmsiTunnelType::Other(99),
            mpls_label: 0x00ab_cdef,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(vec![0xde, 0xad, 0xbe, 0xef]),
        };
        roundtrip(&t);
    }

    /// Four octets: flags, type, and only two of the three label octets.
    #[test]
    fn decode_rejects_truncated_value() {
        let buf = [0u8, 6u8, 0u8, 0u8];
        let err = PmsiTunnel::decode(&buf).unwrap_err();
        assert!(matches!(err, DecodeError::MalformedField { .. }));
    }

    #[test]
    fn decode_rejects_empty_value() {
        let err = PmsiTunnel::decode(&[]).unwrap_err();
        assert!(matches!(err, DecodeError::MalformedField { .. }));
    }

    /// Exactly the 5-octet fixed part (RSVP-TE type, label 0x10), with no identifier.
    #[test]
    fn decode_zero_length_tunnel_id_after_label_yields_empty() {
        let buf = [0u8, 1u8, 0u8, 0u8, 0x10];
        let t = PmsiTunnel::decode(&buf).unwrap();
        assert_eq!(t.tunnel_identifier, PmsiTunnelIdentifier::Empty);
    }

    /// Ingress replication is only typed for 4- or 16-octet identifiers.
    #[test]
    fn ingress_replication_with_8_byte_id_treated_as_raw() {
        let buf = [
            0u8, // flags
            6u8, // tunnel type: ingress replication
            0u8, 0u8, 0u8, // label field
            1, 2, 3, 4, 5, 6, 7, 8, // 8-octet identifier
        ];
        let t = PmsiTunnel::decode(&buf).unwrap();
        assert!(matches!(t.tunnel_identifier, PmsiTunnelIdentifier::Raw(_)));
    }

    /// An empty `Raw` identifier is indistinguishable on the wire from `Empty`.
    #[test]
    fn empty_raw_identifier_decodes_as_empty() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::PimSm,
            mpls_label: 0,
            tunnel_identifier: PmsiTunnelIdentifier::Raw(Vec::new()),
        };
        let mut buf = Vec::new();
        t.encode(&mut buf);
        assert_eq!(buf.len(), 5);
        let decoded = PmsiTunnel::decode(&buf).unwrap();
        assert_eq!(decoded.tunnel_identifier, PmsiTunnelIdentifier::Empty);
    }

    #[test]
    fn for_evpn_ingress_replication_masks_vni_at_24_bits() {
        let t = PmsiTunnel::for_evpn_ingress_replication(0xFF00_1234, "10.0.0.1".parse().unwrap());
        assert_eq!(t.mpls_label, 0x0000_1234);
    }

    /// `encode` writes only the low 24 bits of a directly set label.
    #[test]
    fn encode_drops_label_bits_above_24() {
        let t = PmsiTunnel {
            flags: 0,
            tunnel_type: PmsiTunnelType::IngressReplication,
            mpls_label: 0xFF12_3456,
            tunnel_identifier: PmsiTunnelIdentifier::Empty,
        };
        let mut buf = Vec::new();
        t.encode(&mut buf);
        assert_eq!(buf, vec![0x00, 0x06, 0x12, 0x34, 0x56]);
        assert_eq!(PmsiTunnel::decode(&buf).unwrap().mpls_label, 0x0012_3456);
    }
}