use std::net::IpAddr;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use prefix_trie::{Prefix, PrefixSet};
/// IP-based access control list with independent rule sets per address family.
///
/// Networks are added with [`AccessControl::insert_allow`] /
/// [`AccessControl::insert_deny`] and addresses are checked with
/// [`AccessControl::allow`]. The default value contains no rules.
#[derive(Default)]
pub(crate) struct AccessControl {
    /// Allow/deny rules that apply to IPv4 addresses.
    ipv4: InnerAccessControl<Ipv4Net>,
    /// Allow/deny rules that apply to IPv6 addresses.
    ipv6: InnerAccessControl<Ipv6Net>,
}
impl AccessControl {
    /// Adds every network in `networks` to the deny list of its address family.
    pub(crate) fn insert_deny(&mut self, networks: impl IntoIterator<Item = IpNet>) {
        for network in networks {
            match network {
                IpNet::V4(v4) => {
                    self.ipv4.deny.insert(v4);
                }
                IpNet::V6(v6) => {
                    self.ipv6.deny.insert(v6);
                }
            }
        }
    }

    /// Adds every network in `networks` to the allow list of its address family.
    pub(crate) fn insert_allow(&mut self, networks: impl IntoIterator<Item = IpNet>) {
        for network in networks {
            match network {
                IpNet::V4(v4) => {
                    self.ipv4.allow.insert(v4);
                }
                IpNet::V6(v6) => {
                    self.ipv6.allow.insert(v6);
                }
            }
        }
    }

    /// Returns `true` if `ip` is permitted by the configured rules.
    ///
    /// IPv4 addresses are checked against the IPv4 rules and IPv6 addresses
    /// against the IPv6 rules. An IPv4-mapped IPv6 address (`::ffff:a.b.c.d`)
    /// is unwrapped and checked against the IPv4 rules instead.
    #[must_use]
    pub(crate) fn allow(&self, ip: IpAddr) -> bool {
        match ip {
            IpAddr::V4(v4) => self.ipv4.allow(&Ipv4Net::from(v4)),
            IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
                // A dual-stack (IPv6) socket reports IPv4 peers as `::ffff:a.b.c.d`.
                // Judge them by the IPv4 rules so IPv4 deny/allow entries cannot be
                // bypassed just because the listener happens to be IPv6.
                Some(v4) => self.ipv4.allow(&Ipv4Net::from(v4)),
                None => self.ipv6.allow(&Ipv6Net::from(v6)),
            },
        }
    }
}
/// Allow and deny rule sets for a single address family, keyed by prefix type `I`.
#[derive(Default)]
struct InnerAccessControl<I: Prefix> {
    /// Networks explicitly permitted.
    allow: PrefixSet<I>,
    /// Networks explicitly blocked.
    deny: PrefixSet<I>,
}
impl<I: Prefix> InnerAccessControl<I> {
    /// Decides whether `ip` is permitted by this family's rule sets.
    ///
    /// Decision rules:
    /// - The longest-prefix match in each set is found.
    /// - If both sets match, allow wins only when its prefix is strictly longer;
    ///   on a tie, deny wins.
    /// - A match in only one set decides the outcome by that set.
    /// - If neither set matches, `ip` is permitted when any deny rule exists.
    ///   It is also permitted when no rules exist at all.
    /// - It is rejected when only allow rules exist.
    #[must_use]
    fn allow(&self, ip: &I) -> bool {
        let deny_match = self.deny.get_lpm(ip);
        let allow_match = self.allow.get_lpm(ip);

        if let Some(permit) = allow_match {
            // A more specific allow rule overrides a broader deny rule.
            return match deny_match {
                Some(block) => permit.prefix_len() > block.prefix_len(),
                None => true,
            };
        }
        if deny_match.is_some() {
            return false;
        }

        // No rule covers `ip`: an allow-only configuration acts as a whitelist,
        // anything else defaults to permitting the address.
        let has_deny_rules = self.deny.iter().next().is_some();
        let has_allow_rules = self.allow.iter().next().is_some();
        has_deny_rules || !has_allow_rules
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an address literal used as test input.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test input must be a valid IP address")
    }

    /// Parses a CIDR literal used as test input.
    fn net(text: &str) -> IpNet {
        text.parse().expect("test input must be a valid CIDR network")
    }

    /// Asserts that every address in `ips` is permitted by `access`.
    fn assert_allowed(access: &AccessControl, ips: &[&str]) {
        for ip in ips {
            assert!(access.allow(addr(ip)), "expected {ip} to be allowed");
        }
    }

    /// Asserts that every address in `ips` is rejected by `access`.
    fn assert_denied(access: &AccessControl, ips: &[&str]) {
        for ip in ips {
            assert!(!access.allow(addr(ip)), "expected {ip} to be denied");
        }
    }

    #[test]
    fn test_none() {
        let access = AccessControl::default();
        assert_allowed(&access, &["192.168.1.1", "fd00::1"]);
    }

    #[test]
    fn test_v4() {
        let mut access = AccessControl::default();
        access.insert_allow([net("192.168.1.0/24")]);
        assert_allowed(&access, &["192.168.1.1", "192.168.1.255"]);
        assert_denied(&access, &["192.168.2.1", "192.168.0.0"]);
    }

    #[test]
    fn test_v6() {
        let mut access = AccessControl::default();
        access.insert_allow([net("fd00::/120")]);
        assert_allowed(&access, &["fd00::1", "fd00::00ff"]);
        assert_denied(&access, &["fd00::ffff", "fd00::1:1"]);
    }

    #[test]
    fn test_deny_v4() {
        let mut access = AccessControl::default();
        access.insert_deny([net("192.168.1.0/24")]);
        assert_denied(&access, &["192.168.1.1", "192.168.1.255"]);
        assert_allowed(&access, &["192.168.2.1", "192.168.0.0"]);
    }

    #[test]
    fn test_deny_v6() {
        let mut access = AccessControl::default();
        access.insert_deny([net("fd00::/120")]);
        assert_denied(&access, &["fd00::1", "fd00::00ff"]);
        assert_allowed(&access, &["fd00::ffff", "fd00::1:1"]);
    }

    #[test]
    fn test_deny_allow_v4() {
        let mut access = AccessControl::default();
        access.insert_deny([net("192.168.0.0/16")]);
        access.insert_allow([net("192.168.1.0/24")]);
        assert_allowed(&access, &["192.168.1.1", "192.168.1.255"]);
        assert_denied(&access, &["192.168.2.1", "192.168.0.0"]);
        assert_allowed(&access, &["10.0.0.1"]);
    }
}