#![allow(dead_code)]
use std::net::IpAddr;
use std::sync::Mutex;
/// Process-wide connection policy string, written by `set` and read by `get`.
///
/// The empty string means "no policy configured"; `get` reports that state as
/// `"allow_all"`.
static POLICY: Mutex<String> = Mutex::new(String::new());
/// Replaces the active connection policy with `value`.
///
/// The string is stored verbatim; it is not validated here. Passing an empty string
/// returns to the unconfigured state, which [`get`] reports as `"allow_all"`.
///
/// # Panics
///
/// Panics if the policy mutex has been poisoned.
pub fn set(value: &str) {
    let replacement = String::from(value);
    let mut guard = POLICY.lock().unwrap();
    *guard = replacement;
}
/// Returns an owned copy of the active connection policy.
///
/// If no policy has been stored (or the empty string was stored), `"allow_all"` is
/// returned.
/// NOTE(review): that default is permissive; callers that need address filtering must
/// call `set` explicitly.
///
/// # Panics
///
/// Panics if the policy mutex has been poisoned.
pub fn get() -> String {
    let current = POLICY.lock().unwrap();
    match current.as_str() {
        "" => String::from("allow_all"),
        configured => configured.to_owned(),
    }
}
pub fn check_addr(addr: std::net::SocketAddr) -> Result<(), String> {
let ip = addr.ip();
let policy = get();
if policy == "allow_all" {
return Ok(());
}
let private = is_private(ip);
if policy == "deny_private" {
return if private {
Err(format!("policy: {ip} is in deny_private range"))
} else {
Ok(())
};
}
if let Some(cidrs) = policy.strip_prefix("denylist:") {
if private {
return Err(format!("policy: {ip} is in deny_private range"));
}
for cidr in cidrs.split(',') {
if cidr_contains(cidr.trim(), ip) {
return Err(format!("policy: {ip} matched denylist {cidr}"));
}
}
return Ok(());
}
if let Some(cidrs) = policy.strip_prefix("allowlist:") {
for cidr in cidrs.split(',') {
if cidr_contains(cidr.trim(), ip) {
return Ok(());
}
}
return Err(format!("policy: {ip} not in allowlist"));
}
Err(format!("policy: unknown {policy:?}"))
}
/// Reports whether `ip` is a non-public address that private-range policies should block.
///
/// IPv4 ranges treated as private:
/// - loopback, RFC 1918 private, link-local, multicast and broadcast;
/// - `0.0.0.0/8`, shared address space `100.64.0.0/10` and `192.0.0.0/24`;
/// - documentation `192.0.2.0/24`, `198.51.100.0/24` and `203.0.113.0/24`;
/// - benchmarking `198.18.0.0/15` and reserved `240.0.0.0/4`.
///
/// IPv6 ranges treated as private:
/// - loopback, unspecified and multicast;
/// - unique-local `fc00::/7`, link-local `fe80::/10` and documentation `2001:db8::/32`;
/// - any IPv4-mapped address (`::ffff:a.b.c.d`) whose embedded IPv4 address is private
///   by the IPv4 rules above.
fn is_private(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(v4) => {
            let o = v4.octets();
            v4.is_loopback()
                || v4.is_private()
                || v4.is_link_local()
                || v4.is_multicast()
                || v4.is_broadcast()
                || o[0] == 0 // 0.0.0.0/8
                || (o[0] == 100 && (o[1] & 0xc0) == 64) // 100.64.0.0/10
                || (o[0] == 192 && o[1] == 0 && o[2] == 0) // 192.0.0.0/24
                || (o[0] == 192 && o[1] == 0 && o[2] == 2) // 192.0.2.0/24
                || (o[0] == 198 && (o[1] == 18 || o[1] == 19)) // 198.18.0.0/15
                || (o[0] == 198 && o[1] == 51 && o[2] == 100) // 198.51.100.0/24
                || (o[0] == 203 && o[1] == 0 && o[2] == 113) // 203.0.113.0/24
                || o[0] >= 240 // 240.0.0.0/4
        }
        IpAddr::V6(v6) => {
            // On a dual-stack socket an IPv4-mapped address reaches the embedded IPv4 host.
            // Classify it by that IPv4 address (so ::ffff:127.0.0.1 counts as loopback).
            if let Some(v4) = v6.to_ipv4_mapped() {
                return is_private(IpAddr::V4(v4));
            }
            let s = v6.segments();
            v6.is_loopback()
                || v6.is_unspecified()
                || v6.is_multicast()
                || (s[0] & 0xfe00) == 0xfc00 // unique local fc00::/7
                || (s[0] & 0xffc0) == 0xfe80 // link-local fe80::/10
                || (s[0] == 0x2001 && s[1] == 0x0db8) // documentation 2001:db8::/32
        }
    }
}
/// Reports whether `ip` lies inside the network described by `cidr`.
///
/// Accepted forms of `cidr`:
/// - `"<addr>/<prefix>"`; whitespace around `<addr>` and `<prefix>` is ignored.
/// - A bare `"<addr>"`, treated as a single host.
///
/// A prefix that is unparsable, or larger than the address width, is clamped to the full
/// width, giving a single-host match.
///
/// Returns `false` in these cases:
/// - `<addr>` cannot be parsed;
/// - the address families differ.
///
/// One exception to the family rule: an IPv4-mapped IPv6 `ip` (`::ffff:a.b.c.d`) is
/// compared against an IPv4 `cidr` using its embedded IPv4 address.
fn cidr_contains(cidr: &str, ip: IpAddr) -> bool {
    let (net_str, prefix) = match cidr.split_once('/') {
        Some((n, p)) => (n.trim(), p.trim().parse::<u32>().unwrap_or(u32::MAX)),
        None => (cidr.trim(), u32::MAX),
    };
    let net: IpAddr = match net_str.parse() {
        Ok(n) => n,
        Err(_) => return false,
    };
    // An IPv4 network should also catch the same host written as ::ffff:a.b.c.d.
    let ip = match (net, ip) {
        (IpAddr::V4(_), IpAddr::V6(v6)) => v6.to_ipv4_mapped().map_or(ip, IpAddr::V4),
        _ => ip,
    };
    match (net, ip) {
        (IpAddr::V4(n4), IpAddr::V4(i4)) => {
            let pfx = prefix.min(32);
            // A shift by the full width (prefix 0) would overflow; that case is an all-zero mask.
            let mask = u32::MAX.checked_shl(32 - pfx).unwrap_or(0);
            (u32::from(n4) & mask) == (u32::from(i4) & mask)
        }
        (IpAddr::V6(n6), IpAddr::V6(i6)) => {
            let pfx = prefix.min(128);
            let mask = u128::MAX.checked_shl(128 - pfx).unwrap_or(0);
            (u128::from(n6) & mask) == (u128::from(i6) & mask)
        }
        _ => false,
    }
}