use std::{collections::BTreeMap, sync::Arc};
use anyhow::bail;
use ipnet::IpNet;
use scion_proto::address::{IsdAsn, ScionAddr};
use crate::network::local::receivers::Receiver;
#[derive(Default, Debug, Clone)]
pub struct NetworkReceiverRegistry {
receivers: BTreeMap<IsdAsn, LocalNetworkReceivers>,
}
impl NetworkReceiverRegistry {
pub fn new() -> Self {
Self {
receivers: BTreeMap::new(),
}
}
pub fn add_wildcard_receiver(
&mut self,
ias: IsdAsn,
receiver: Arc<dyn Receiver>,
) -> anyhow::Result<()> {
if self.receivers.contains_key(&ias) {
bail!("A Receiver for ISD-AS {} already exists", ias);
}
self.receivers.insert(
ias,
LocalNetworkReceivers::WildcardReceiver {
receivers: receiver,
},
);
Ok(())
}
pub fn add_receiver(
&mut self,
ias: IsdAsn,
ipnet: IpNet,
receiver: Arc<dyn Receiver>,
) -> anyhow::Result<()> {
let recvs = self.receivers.entry(ias).or_insert_with(|| {
LocalNetworkReceivers::ByAddressRanges {
receivers: Vec::new(),
}
});
let LocalNetworkReceivers::ByAddressRanges { receivers } = recvs else {
bail!("Receiver for ISD-AS {} is already a wildcard receiver", ias);
};
if let Some((overlap_net, _)) = receivers.iter().find(|(net, _)| net.contains(&ipnet)) {
bail!(
"ISD-AS {ias} has a receiver with overlapping IP range. existing: {overlap_net} overlaps with {ipnet}",
);
};
receivers.push((ipnet, receiver));
Ok(())
}
pub fn by_addr(&self, address: ScionAddr) -> Option<&Arc<dyn Receiver>> {
let dest_ip = address.local_address()?;
let ias = address.isd_asn();
self.receivers.get(&ias).and_then(|registration| {
match registration {
LocalNetworkReceivers::WildcardReceiver {
receivers: receiver,
} => Some(receiver),
LocalNetworkReceivers::ByAddressRanges { receivers } => {
receivers.iter().find_map(|(ipnet, receiver)| {
if ipnet.contains(&dest_ip) {
Some(receiver)
} else {
None
}
})
}
}
})
}
}
#[derive(Clone)]
enum LocalNetworkReceivers {
ByAddressRanges {
receivers: Vec<(IpNet, Arc<dyn Receiver>)>,
},
WildcardReceiver { receivers: Arc<dyn Receiver> },
}
impl std::fmt::Debug for LocalNetworkReceivers {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ByAddressRanges { receivers } => {
f.debug_struct("ByAddressRanges")
.field("receivers", &receivers.len())
.finish()
}
Self::WildcardReceiver { .. } => f.debug_struct("Wildcard").finish(),
}
}
}
#[cfg(test)]
mod tests {
use std::{net::Ipv4Addr, str::FromStr};
use scion_proto::packet::ScionPacketRaw;
use super::*;
#[test]
fn should_get_wildcard_receiver_by_isd_as() {
let mut receivers = NetworkReceiverRegistry::new();
let ias = IsdAsn::from_str("1-2").unwrap();
let receiver: Arc<dyn Receiver> = Arc::new(MockReceiver);
receivers
.add_wildcard_receiver(ias, receiver.clone())
.unwrap();
let addr = ScionAddr::new(ias, Ipv4Addr::from_str("10.0.0.1").unwrap().into());
let found = receivers.by_addr(addr).unwrap();
assert!(Arc::ptr_eq(found, &receiver));
}
#[test]
fn should_get_receiver_by_ip() {
let mut receivers = NetworkReceiverRegistry::new();
let ias = IsdAsn::from_str("1-2").unwrap();
let ipnet: IpNet = "10.0.0.0/24".parse().unwrap();
let receiver: Arc<dyn Receiver> = Arc::new(MockReceiver);
receivers
.add_receiver(ias, ipnet, receiver.clone())
.unwrap();
let addr = ScionAddr::new(ias, Ipv4Addr::from_str("10.0.0.42").unwrap().into());
let found = receivers.by_addr(addr).unwrap();
assert!(Arc::ptr_eq(found, &receiver));
}
#[test]
fn should_get_receiver_by_ip_multiple_ranges() {
let mut receivers = NetworkReceiverRegistry::new();
let ias = IsdAsn::from_str("1-2").unwrap();
let ipnet1: IpNet = "10.0.0.0/24".parse().unwrap();
let ipnet2: IpNet = "10.0.1.0/24".parse().unwrap();
let receiver1: Arc<dyn Receiver> = Arc::new(MockReceiver);
let receiver2: Arc<dyn Receiver> = Arc::new(MockReceiver);
receivers
.add_receiver(ias, ipnet1, receiver1.clone())
.unwrap();
receivers
.add_receiver(ias, ipnet2, receiver2.clone())
.unwrap();
let addr1 = ScionAddr::new(ias, Ipv4Addr::from_str("10.0.0.42").unwrap().into());
let addr2 = ScionAddr::new(ias, Ipv4Addr::from_str("10.0.1.99").unwrap().into());
let found1 = receivers.by_addr(addr1).unwrap();
let found2 = receivers.by_addr(addr2).unwrap();
assert!(Arc::ptr_eq(found1, &receiver1));
assert!(Arc::ptr_eq(found2, &receiver2));
}
#[test]
fn should_return_none_for_ip_with_no_receiver() {
let mut receivers = NetworkReceiverRegistry::new();
let ias = IsdAsn::from_str("1-2").unwrap();
let ipnet: IpNet = "10.0.0.0/24".parse().unwrap();
let receiver: Arc<dyn Receiver> = Arc::new(MockReceiver);
receivers.add_receiver(ias, ipnet, receiver).unwrap();
let addr = ScionAddr::new(ias, Ipv4Addr::from_str("10.0.1.42").unwrap().into()); let found = receivers.by_addr(addr);
assert!(found.is_none());
}
#[test]
fn should_fail_to_add_receiver_with_overlapping_ip_ranges() {
let mut receivers = NetworkReceiverRegistry::new();
let ias = IsdAsn::from_str("1-2").unwrap();
let ipnet1: IpNet = "10.0.0.0/24".parse().unwrap();
let ipnet2: IpNet = "10.0.0.128/25".parse().unwrap();
let receiver1: Arc<dyn Receiver> = Arc::new(MockReceiver);
let receiver2: Arc<dyn Receiver> = Arc::new(MockReceiver);
receivers.add_receiver(ias, ipnet1, receiver1).unwrap();
let result = receivers.add_receiver(ias, ipnet2, receiver2);
assert!(result.is_err());
}
#[test]
fn should_fail_to_add_receiver_if_wildcard_receiver_exists() {
let mut receivers = NetworkReceiverRegistry::new();
let ias = IsdAsn::from_str("1-2").unwrap();
let receiver1: Arc<dyn Receiver> = Arc::new(MockReceiver);
let receiver2: Arc<dyn Receiver> = Arc::new(MockReceiver);
receivers.add_wildcard_receiver(ias, receiver1).unwrap();
let result = receivers.add_wildcard_receiver(ias, receiver2.clone());
assert!(result.is_err());
let ipnet1: IpNet = "10.0.0.0/24".parse().unwrap();
let result = receivers.add_receiver(ias, ipnet1, receiver2);
assert!(result.is_err());
}
#[derive(Default)]
struct MockReceiver;
impl Receiver for MockReceiver {
fn receive_packet(&self, _packet: ScionPacketRaw) {
}
}
}