use anyhow::{Context, Result};
use ipnetwork::IpNetwork;
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Verdict produced by an access-control check for a single address.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AccessPolicy {
    /// The address may proceed.
    Allow,
    /// The address must be rejected.
    Deny,
}
/// In-memory IP allow/block list.
///
/// Entries in `blocked` always take precedence over `allowed`. When `allowed`
/// is non-empty the list operates in whitelist mode and only matching
/// addresses are admitted.
#[derive(Clone, Debug)]
pub struct IpAccessControl {
    /// Networks explicitly permitted; a non-empty list means whitelist mode.
    allowed: Vec<IpNetwork>,
    /// Networks that are always rejected; consulted before `allowed`.
    blocked: Vec<IpNetwork>,
    /// Verdict used when no rule applies; the mutating methods keep it at
    /// `Deny` while `allowed` is non-empty and `Allow` otherwise.
    default_policy: AccessPolicy,
}
impl Default for IpAccessControl {
fn default() -> Self {
Self::new()
}
}
impl IpAccessControl {
pub fn new() -> Self {
Self {
allowed: Vec::new(),
blocked: Vec::new(),
default_policy: AccessPolicy::Allow,
}
}
pub fn from_config(allowed_ips: &[String], blocked_ips: &[String]) -> Result<Self> {
let mut ctrl = Self::new();
for cidr in allowed_ips {
let network: IpNetwork = cidr
.parse()
.with_context(|| format!("Invalid allowed_ips CIDR: {}", cidr))?;
ctrl.allowed.push(network);
}
for cidr in blocked_ips {
let network: IpNetwork = cidr
.parse()
.with_context(|| format!("Invalid blocked_ips CIDR: {}", cidr))?;
ctrl.blocked.push(network);
}
if !ctrl.allowed.is_empty() {
ctrl.default_policy = AccessPolicy::Deny;
}
tracing::info!(
allowed_count = ctrl.allowed.len(),
blocked_count = ctrl.blocked.len(),
default_policy = ?ctrl.default_policy,
"IP access control configured"
);
Ok(ctrl)
}
pub fn check(&self, ip: &IpAddr) -> AccessPolicy {
for network in &self.blocked {
if network.contains(*ip) {
tracing::debug!(
ip = %ip,
network = %network,
"IP blocked by rule"
);
return AccessPolicy::Deny;
}
}
if !self.allowed.is_empty() {
for network in &self.allowed {
if network.contains(*ip) {
tracing::trace!(
ip = %ip,
network = %network,
"IP allowed by rule"
);
return AccessPolicy::Allow;
}
}
tracing::debug!(
ip = %ip,
"IP not in allowed list"
);
return AccessPolicy::Deny;
}
self.default_policy
}
pub fn allow_cidr(&mut self, cidr: &str) -> Result<()> {
let network: IpNetwork = cidr
.parse()
.with_context(|| format!("Invalid CIDR: {}", cidr))?;
self.allow(network);
Ok(())
}
pub fn allow(&mut self, network: IpNetwork) {
if !self.allowed.contains(&network) {
self.allowed.push(network);
if self.default_policy == AccessPolicy::Allow {
self.default_policy = AccessPolicy::Deny;
}
tracing::info!(network = %network, "Added to allowed list");
}
}
pub fn block_cidr(&mut self, cidr: &str) -> Result<()> {
let network: IpNetwork = cidr
.parse()
.with_context(|| format!("Invalid CIDR: {}", cidr))?;
self.block(network);
Ok(())
}
pub fn block(&mut self, network: IpNetwork) {
if !self.blocked.contains(&network) {
self.blocked.push(network);
tracing::info!(network = %network, "Added to blocked list");
}
}
pub fn block_ip(&mut self, ip: IpAddr) {
let network = IpNetwork::from(ip);
self.block(network);
}
pub fn unblock_ip(&mut self, ip: IpAddr) {
let network = IpNetwork::from(ip);
let before = self.blocked.len();
self.blocked.retain(|n| n != &network);
if self.blocked.len() < before {
tracing::info!(ip = %ip, "Removed from blocked list");
}
}
pub fn remove_allowed(&mut self, network: &IpNetwork) {
let before = self.allowed.len();
self.allowed.retain(|n| n != network);
if self.allowed.len() < before {
tracing::info!(network = %network, "Removed from allowed list");
}
if self.allowed.is_empty() {
self.default_policy = AccessPolicy::Allow;
}
}
pub fn remove_blocked(&mut self, network: &IpNetwork) {
let before = self.blocked.len();
self.blocked.retain(|n| n != network);
if self.blocked.len() < before {
tracing::info!(network = %network, "Removed from blocked list");
}
}
pub fn reload(&mut self, allowed_ips: &[String], blocked_ips: &[String]) -> Result<()> {
let new_config = Self::from_config(allowed_ips, blocked_ips)?;
*self = new_config;
tracing::info!("IP access control reloaded");
Ok(())
}
pub fn allowed_count(&self) -> usize {
self.allowed.len()
}
pub fn blocked_count(&self) -> usize {
self.blocked.len()
}
pub fn default_policy(&self) -> AccessPolicy {
self.default_policy
}
pub fn is_whitelist_mode(&self) -> bool {
!self.allowed.is_empty()
}
pub fn allowed_networks(&self) -> Vec<IpNetwork> {
self.allowed.clone()
}
pub fn blocked_networks(&self) -> Vec<IpNetwork> {
self.blocked.clone()
}
}
/// Cloneable handle to an [`IpAccessControl`] kept behind an async `RwLock`.
///
/// Every clone refers to the same underlying rule set through the shared `Arc`.
#[derive(Clone)]
pub struct SharedIpAccessControl {
    inner: Arc<RwLock<IpAccessControl>>,
}
impl SharedIpAccessControl {
pub fn new(access: IpAccessControl) -> Self {
Self {
inner: Arc::new(RwLock::new(access)),
}
}
pub async fn check(&self, ip: &IpAddr) -> AccessPolicy {
self.inner.read().await.check(ip)
}
pub fn check_sync(&self, ip: &IpAddr) -> AccessPolicy {
if let Ok(guard) = self.inner.try_read() {
return guard.check(ip);
}
tracing::warn!(
ip = %ip,
"Access control lock contended, denying for security"
);
AccessPolicy::Deny
}
pub async fn block_ip(&self, ip: IpAddr) {
self.inner.write().await.block_ip(ip);
}
pub async fn unblock_ip(&self, ip: IpAddr) {
self.inner.write().await.unblock_ip(ip);
}
pub async fn reload(&self, allowed_ips: &[String], blocked_ips: &[String]) -> Result<()> {
self.inner.write().await.reload(allowed_ips, blocked_ips)
}
pub async fn stats(&self) -> (usize, usize, AccessPolicy) {
let guard = self.inner.read().await;
(
guard.allowed_count(),
guard.blocked_count(),
guard.default_policy(),
)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Parses an address literal; test inputs are always well-formed.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("test address literal must parse")
    }

    /// Converts string literals into the owned form `from_config` expects.
    fn cidrs(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_default_allow() {
        let access = IpAccessControl::new();
        assert_eq!(access.check(&addr("192.168.1.100")), AccessPolicy::Allow);
    }

    #[test]
    fn test_cidr_matching() {
        let mut access = IpAccessControl::new();
        access.allow_cidr("192.168.0.0/16").unwrap();
        assert_eq!(access.check(&addr("192.168.1.100")), AccessPolicy::Allow);
        assert_eq!(access.check(&addr("10.0.0.1")), AccessPolicy::Deny);
    }

    #[test]
    fn test_whitelist_mode() {
        let access =
            IpAccessControl::from_config(&cidrs(&["10.0.0.0/8", "192.168.0.0/16"]), &[]).unwrap();
        assert!(access.is_whitelist_mode());
        assert_eq!(access.default_policy(), AccessPolicy::Deny);
        let cases = [
            ("10.1.2.3", AccessPolicy::Allow),
            ("192.168.100.1", AccessPolicy::Allow),
            ("8.8.8.8", AccessPolicy::Deny),
        ];
        for &(text, expected) in cases.iter() {
            assert_eq!(access.check(&addr(text)), expected, "address {}", text);
        }
    }

    #[test]
    fn test_blacklist_priority() {
        let access = IpAccessControl::from_config(
            &cidrs(&["192.168.0.0/16"]),
            &cidrs(&["192.168.100.0/24"]),
        )
        .unwrap();
        assert_eq!(access.check(&addr("192.168.1.1")), AccessPolicy::Allow);
        assert_eq!(access.check(&addr("192.168.100.50")), AccessPolicy::Deny);
    }

    #[test]
    fn test_single_ip_blocking() {
        let mut access = IpAccessControl::new();
        let target = addr("10.10.10.10");
        assert_eq!(access.check(&target), AccessPolicy::Allow);
        access.block_ip(target);
        assert_eq!(access.check(&target), AccessPolicy::Deny);
        assert_eq!(access.check(&addr("10.10.10.11")), AccessPolicy::Allow);
        access.unblock_ip(target);
        assert_eq!(access.check(&target), AccessPolicy::Allow);
    }

    #[test]
    fn test_ipv6_support() {
        let mut access = IpAccessControl::new();
        access.allow_cidr("2001:db8::/32").unwrap();
        let inside = IpAddr::V6("2001:db8::1".parse::<Ipv6Addr>().unwrap());
        let outside = IpAddr::V6("2001:db9::1".parse::<Ipv6Addr>().unwrap());
        assert_eq!(access.check(&inside), AccessPolicy::Allow);
        assert_eq!(access.check(&outside), AccessPolicy::Deny);
    }

    #[test]
    fn test_mixed_ipv4_ipv6() {
        let access =
            IpAccessControl::from_config(&cidrs(&["192.168.0.0/16", "2001:db8::/32"]), &[])
                .unwrap();
        assert_eq!(access.check(&addr("192.168.1.1")), AccessPolicy::Allow);
        assert_eq!(access.check(&addr("2001:db8::1")), AccessPolicy::Allow);
        assert_eq!(access.check(&addr("10.0.0.1")), AccessPolicy::Deny);
    }

    #[test]
    fn test_invalid_cidr() {
        assert!(IpAccessControl::from_config(&cidrs(&["not-a-cidr"]), &[]).is_err());
    }

    #[test]
    fn test_reload() {
        let mut access = IpAccessControl::from_config(&cidrs(&["10.0.0.0/8"]), &[]).unwrap();
        assert_eq!(access.check(&addr("192.168.1.1")), AccessPolicy::Deny);
        access.reload(&cidrs(&["192.168.0.0/16"]), &[]).unwrap();
        assert_eq!(access.check(&addr("192.168.1.1")), AccessPolicy::Allow);
        assert_eq!(access.check(&addr("10.1.1.1")), AccessPolicy::Deny);
    }

    #[test]
    fn test_remove_networks() {
        let mut access = IpAccessControl::new();
        access.allow_cidr("10.0.0.0/8").unwrap();
        access.block_cidr("192.168.0.0/16").unwrap();
        assert_eq!((access.allowed_count(), access.blocked_count()), (1, 1));

        let allowed_net: IpNetwork = "10.0.0.0/8".parse().unwrap();
        access.remove_allowed(&allowed_net);
        assert_eq!(access.allowed_count(), 0);
        assert_eq!(access.default_policy(), AccessPolicy::Allow);

        let blocked_net: IpNetwork = "192.168.0.0/16".parse().unwrap();
        access.remove_blocked(&blocked_net);
        assert_eq!(access.blocked_count(), 0);
    }

    #[test]
    fn test_localhost_allowed_by_default() {
        let access = IpAccessControl::new();
        let loopbacks = [
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
            IpAddr::V6(Ipv6Addr::LOCALHOST),
        ];
        for loopback in loopbacks.iter() {
            assert_eq!(access.check(loopback), AccessPolicy::Allow);
        }
    }

    #[test]
    fn test_empty_config() {
        let access = IpAccessControl::from_config(&[], &[]).unwrap();
        assert!(!access.is_whitelist_mode());
        assert_eq!(access.default_policy(), AccessPolicy::Allow);
        assert_eq!(access.allowed_count(), 0);
        assert_eq!(access.blocked_count(), 0);
        assert_eq!(access.check(&addr("8.8.8.8")), AccessPolicy::Allow);
    }

    #[test]
    fn test_get_networks() {
        let access = IpAccessControl::from_config(
            &cidrs(&["10.0.0.0/8"]),
            &cidrs(&["192.168.0.0/16"]),
        )
        .unwrap();
        let allowed: Vec<String> = access
            .allowed_networks()
            .iter()
            .map(|net| net.to_string())
            .collect();
        let blocked: Vec<String> = access
            .blocked_networks()
            .iter()
            .map(|net| net.to_string())
            .collect();
        assert_eq!(allowed, vec!["10.0.0.0/8".to_string()]);
        assert_eq!(blocked, vec!["192.168.0.0/16".to_string()]);
    }

    #[tokio::test]
    async fn test_shared_access_control() {
        let access = IpAccessControl::from_config(&cidrs(&["10.0.0.0/8"]), &[]).unwrap();
        let shared = SharedIpAccessControl::new(access);
        assert_eq!(shared.check(&addr("10.1.2.3")).await, AccessPolicy::Allow);
        assert_eq!(shared.check(&addr("192.168.1.1")).await, AccessPolicy::Deny);

        let target = addr("10.100.100.100");
        assert_eq!(shared.check(&target).await, AccessPolicy::Allow);
        shared.block_ip(target).await;
        assert_eq!(shared.check(&target).await, AccessPolicy::Deny);
        shared.unblock_ip(target).await;
        assert_eq!(shared.check(&target).await, AccessPolicy::Allow);

        assert_eq!(shared.stats().await, (1, 0, AccessPolicy::Deny));
    }
}