use core::net::IpAddr;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use prefix_trie::PrefixSet;
use tracing::debug;
use crate::ProtoError;
/// Builder for an [`AccessControlSet`].
///
/// Networks accumulate across calls to [`deny`](Self::deny) and [`allow`](Self::allow).
/// The built set denies an address that lies inside a denied network unless it also lies
/// inside an allowed network. [`build`](Self::build) validates the configuration.
#[must_use = "builder methods return the updated builder; call `build()` to obtain the set"]
pub struct AccessControlSetBuilder(AccessControlSet);

impl AccessControlSetBuilder {
    /// Starts a builder for an empty set; `name` identifies the set in log and error messages.
    pub fn new(name: &'static str) -> Self {
        Self(AccessControlSet::new(name))
    }

    /// Adds `allow` networks as exceptions to the denied networks.
    ///
    /// Accepts any iterable of `&IpNet` (for example `slice.iter()` or `&vec`). Each network
    /// is stored in the IPv4 or IPv6 allow set according to its address family, in addition
    /// to any networks added by earlier calls.
    pub fn allow<'a>(mut self, allow: impl IntoIterator<Item = &'a IpNet>) -> Self {
        for network in allow {
            debug!(name = self.0.name, ?network, "appending to allow list");
            match network {
                IpNet::V4(network) => {
                    self.0.v4_allow.insert(*network);
                }
                IpNet::V6(network) => {
                    self.0.v6_allow.insert(*network);
                }
            }
        }
        self
    }

    /// Adds `deny` networks to the set.
    ///
    /// Accepts any iterable of `&IpNet` (for example `slice.iter()` or `&vec`). Each network
    /// is stored in the IPv4 or IPv6 deny set according to its address family, in addition
    /// to any networks added by earlier calls.
    pub fn deny<'a>(mut self, deny: impl IntoIterator<Item = &'a IpNet>) -> Self {
        for network in deny {
            debug!(name = self.0.name, ?network, "appending to deny list");
            match network {
                IpNet::V4(network) => {
                    self.0.v4_deny.insert(*network);
                }
                IpNet::V6(network) => {
                    self.0.v6_deny.insert(*network);
                }
            }
        }
        self
    }

    /// Removes every allowed network (both IPv4 and IPv6) added so far.
    pub fn clear_allow(mut self) -> Self {
        self.0.v4_allow.clear();
        self.0.v6_allow.clear();
        self
    }

    /// Removes every denied network (both IPv4 and IPv6) added so far.
    pub fn clear_deny(mut self) -> Self {
        self.0.v4_deny.clear();
        self.0.v6_deny.clear();
        self
    }

    /// Finishes the builder and returns the configured set.
    ///
    /// # Errors
    ///
    /// Returns an error if allowed networks were added but no denied networks were. Allow
    /// entries only act as exceptions to deny entries, so on their own they would have no
    /// effect.
    pub fn build(self) -> Result<AccessControlSet, ProtoError> {
        let set = self.0;
        let has_allowed = !set.v4_allow.is_empty() || !set.v6_allow.is_empty();
        // With no deny entries in either family, the set allows everything, so any allow
        // entries are dead configuration. Only count them when we need the error message.
        if set.allows_all() && has_allowed {
            let allowed_count = set.v4_allow.iter().count() + set.v6_allow.iter().count();
            return Err(format!(
                "access control set {name:?} has {allowed_count} allowed overrides, but no denied networks to override",
                name = set.name
            )
            .into());
        }
        Ok(set)
    }
}
#[derive(Clone, Debug)]
pub struct AccessControlSet {
name: &'static str,
v4_allow: PrefixSet<Ipv4Net>,
v4_deny: PrefixSet<Ipv4Net>,
v6_allow: PrefixSet<Ipv6Net>,
v6_deny: PrefixSet<Ipv6Net>,
}
impl AccessControlSet {
fn new(name: &'static str) -> Self {
Self {
name,
v4_allow: PrefixSet::new(),
v4_deny: PrefixSet::new(),
v6_allow: PrefixSet::new(),
v6_deny: PrefixSet::new(),
}
}
pub fn empty(name: &'static str) -> Self {
Self::new(name)
}
pub fn allows_all(&self) -> bool {
self.v4_deny.is_empty() && self.v6_deny.is_empty()
}
pub fn denied(&self, ip: IpAddr) -> bool {
if self.allows_all() {
return false;
}
match ip {
IpAddr::V4(ip) => {
self.v4_allow.get_spm(&ip.into()).is_none()
&& self.v4_deny.get_spm(&ip.into()).is_some()
}
IpAddr::V6(ip) => {
self.v6_allow.get_spm(&ip.into()).is_none()
&& self.v6_deny.get_spm(&ip.into()).is_some()
}
}
}
}
#[cfg(test)]
mod tests {
    use crate::access_control::{AccessControlSet, AccessControlSetBuilder};

    /// Denies private ranges with narrower allowed exceptions and checks addresses on both
    /// sides of each exception. `10.1.0.3/29` has host bits set; the assertions expect it to
    /// cover 10.1.0.0 through 10.1.0.7.
    #[test]
    fn access_control_set_networks_test() {
        let acs = AccessControlSetBuilder::new("test acs")
            .deny(
                [
                    "10.0.0.0/8".parse().unwrap(),
                    "172.16.0.0/12".parse().unwrap(),
                    "192.168.0.0/16".parse().unwrap(),
                    "fe80::/10".parse().unwrap(),
                ]
                .iter(),
            )
            .allow(
                [
                    "10.1.0.3/29".parse().unwrap(),
                    "192.168.1.10/32".parse().unwrap(),
                    "fe80::200/128".parse().unwrap(),
                ]
                .iter(),
            )
            .build()
            .unwrap();
        assert!(acs.denied([10, 0, 254, 254].into()));
        assert!(!acs.denied([10, 1, 0, 0].into()));
        assert!(!acs.denied([10, 1, 0, 3].into()));
        assert!(!acs.denied([10, 1, 0, 7].into()));
        assert!(acs.denied([10, 1, 0, 8].into()));
        assert!(acs.denied([192, 168, 1, 1].into()));
        assert!(!acs.denied([192, 168, 1, 10].into()));
        assert!(!acs.denied([0xfe80, 0, 0, 0, 0, 0, 0, 0x200].into()));
        assert!(acs.denied([0xfe80, 0, 0, 0, 0, 0, 0, 1].into()));
    }

    /// Covers every combination of an address appearing in the deny and/or allow lists. That
    /// includes the builder rejecting allow entries when nothing is denied.
    #[test]
    fn access_control_semantics_test() {
        struct TestCase {
            name: &'static str,
            in_deny: bool,
            in_allow: bool,
            expected_build_err: bool,
            expected_denied: bool,
        }
        let test_cases = [
            TestCase {
                name: "deny=true, allow=false -> denied",
                in_deny: true,
                in_allow: false,
                expected_build_err: false,
                expected_denied: true,
            },
            TestCase {
                name: "deny=false, allow=false -> allowed",
                in_deny: false,
                in_allow: false,
                expected_build_err: false,
                expected_denied: false,
            },
            TestCase {
                name: "deny=true, allow=true -> allowed",
                in_deny: true,
                in_allow: true,
                expected_build_err: false,
                expected_denied: false,
            },
            TestCase {
                name: "deny=false, allow=true -> build error",
                in_deny: false,
                in_allow: true,
                expected_build_err: true,
                expected_denied: false,
            },
        ];
        let test_v4 = [192, 0, 2, 1].into();
        let test_v4_net = "192.0.2.0/24".parse().unwrap();
        let test_v6 = [0x2001, 0xdb8, 0, 0, 0, 0, 0, 1].into();
        let test_v6_net = "2001:db8::/32".parse().unwrap();
        for tc in &test_cases {
            let mut builder = AccessControlSetBuilder::new(tc.name);
            if tc.in_deny {
                builder = builder.deny([test_v4_net, test_v6_net].iter());
            }
            if tc.in_allow {
                builder = builder.allow([test_v4_net, test_v6_net].iter());
            }
            let result = builder.build();
            // Check the build outcome in both directions. An unexpected `Ok` must fail the
            // test, not silently fall through to the `denied` checks.
            assert_eq!(
                result.is_err(),
                tc.expected_build_err,
                "build outcome for case '{}' was unexpected",
                tc.name
            );
            let Ok(acs) = result else {
                continue;
            };
            assert_eq!(
                acs.denied(test_v4),
                tc.expected_denied,
                "IPv4 case '{}' failed",
                tc.name
            );
            assert_eq!(
                acs.denied(test_v6),
                tc.expected_denied,
                "IPv6 case '{}' failed",
                tc.name
            );
        }
    }

    /// `allows_all` is true only when no family has a deny entry.
    #[test]
    fn allows_all_test() {
        let empty = AccessControlSet::empty("empty");
        assert!(empty.allows_all());
        let v4_only = AccessControlSetBuilder::new("v4 only")
            .deny(["10.0.0.0/8".parse().unwrap()].iter())
            .build()
            .unwrap();
        assert!(!v4_only.allows_all());
        let v6_only = AccessControlSetBuilder::new("v6 only")
            .deny(["fe80::/10".parse().unwrap()].iter())
            .build()
            .unwrap();
        assert!(!v6_only.allows_all());
        let both = AccessControlSetBuilder::new("both")
            .deny(["10.0.0.0/8".parse().unwrap(), "fe80::/10".parse().unwrap()].iter())
            .build()
            .unwrap();
        assert!(!both.allows_all());
    }

    /// An IPv4-only deny list must not affect IPv6 addresses.
    #[test]
    fn v4_only_deny_test() {
        let acs = AccessControlSetBuilder::new("v4 only deny")
            .deny(["10.0.0.0/8".parse().unwrap()].iter())
            .build()
            .unwrap();
        assert!(!acs.allows_all());
        assert!(acs.denied([10, 0, 0, 1].into()));
        assert!(acs.denied([10, 255, 255, 255].into()));
        assert!(!acs.denied([11, 0, 0, 1].into()));
        assert!(!acs.denied([0xfe80, 0, 0, 0, 0, 0, 0, 1].into()));
    }

    /// An IPv6-only deny list must not affect IPv4 addresses.
    #[test]
    fn v6_only_deny_test() {
        let acs = AccessControlSetBuilder::new("v6 only deny")
            .deny(["fe80::/10".parse().unwrap()].iter())
            .build()
            .unwrap();
        assert!(!acs.allows_all());
        assert!(acs.denied([0xfe80, 0, 0, 0, 0, 0, 0, 1].into()));
        assert!(
            acs.denied(
                [
                    0xfebf, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff
                ]
                .into()
            )
        );
        assert!(!acs.denied([0xfec0, 0, 0, 0, 0, 0, 0, 1].into()));
        assert!(!acs.denied([10, 0, 0, 1].into()));
    }
}