use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use super::{PolicyError, PortRange};
/// An ordered list of rules deciding whether an address and port are allowed.
///
/// Rules are consulted in the order they were added with [`AddrPolicy::push`];
/// the first rule whose pattern matches decides the outcome.
#[derive(Clone, Debug, Default)]
pub struct AddrPolicy {
    /// The rules of this policy, in evaluation order.
    rules: Vec<AddrPolicyRule>,
}
/// The outcome attached to a rule in an [`AddrPolicy`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
// NOTE(review): presumably exhaustive on purpose, since a rule can only
// accept or reject; confirm before adding variants.
#[allow(clippy::exhaustive_enums)]
pub enum RuleKind {
    /// Addresses/ports matched by the rule are allowed.
    Accept,
    /// Addresses/ports matched by the rule are refused.
    Reject,
}
impl AddrPolicy {
    /// Returns the kind of the first rule matching `addr` and `port`.
    ///
    /// Rules are checked in insertion order. Returns `None` if no rule
    /// matches; the caller decides what an unmatched address means.
    pub fn allows(&self, addr: &IpAddr, port: u16) -> Option<RuleKind> {
        for rule in &self.rules {
            if rule.pattern.matches(addr, port) {
                return Some(rule.kind);
            }
        }
        None
    }

    /// Like [`AddrPolicy::allows`], taking the address and port from a [`SocketAddr`].
    pub fn allows_sockaddr(&self, addr: &SocketAddr) -> Option<RuleKind> {
        let (ip, port) = (addr.ip(), addr.port());
        self.allows(&ip, port)
    }

    /// Creates a policy with no rules.
    ///
    /// An empty policy returns `None` from [`AddrPolicy::allows`] for every input.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a rule to the end of this policy.
    ///
    /// Matching stops at the first hit, so the new rule only affects
    /// addresses and ports that no earlier rule already matches.
    pub fn push(&mut self, kind: RuleKind, pattern: AddrPortPattern) {
        let rule = AddrPolicyRule { kind, pattern };
        self.rules.push(rule);
    }
}
/// One entry of an [`AddrPolicy`]: a pattern plus the outcome when it matches.
#[derive(Clone, Debug)]
struct AddrPolicyRule {
    /// Returned by [`AddrPolicy::allows`] when this rule is the first match.
    kind: RuleKind,
    /// The addresses and ports this rule applies to.
    pattern: AddrPortPattern,
}
/// A pattern matching a set of IP addresses together with a range of ports.
///
/// The textual form is `ADDR:PORTS`, where `PORTS` may be `*` for all ports;
/// (de)serialization goes through the `Display` and `FromStr` implementations.
#[derive(
    Clone, Debug, Eq, PartialEq, serde_with::SerializeDisplay, serde_with::DeserializeFromStr,
)]
pub struct AddrPortPattern {
    /// The addresses this pattern matches.
    pattern: IpPattern,
    /// The ports this pattern matches.
    ports: PortRange,
}
impl AddrPortPattern {
pub fn new_all() -> Self {
Self {
pattern: IpPattern::Star,
ports: PortRange::new_all(),
}
}
pub fn matches(&self, addr: &IpAddr, port: u16) -> bool {
self.pattern.matches(addr) && self.ports.contains(port)
}
pub fn matches_sockaddr(&self, addr: &SocketAddr) -> bool {
self.matches(&addr.ip(), addr.port())
}
}
impl Display for AddrPortPattern {
    /// Formats as `ADDR:PORTS`, writing `*` for the port part when the range
    /// covers every port.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:", self.pattern)?;
        if self.ports.is_all() {
            f.write_str("*")
        } else {
            write!(f, "{}", self.ports)
        }
    }
}
impl FromStr for AddrPortPattern {
    type Err = PolicyError;

    /// Parses `ADDR:PORTS`, where `PORTS` may be `*` for every port.
    ///
    /// The split happens at the *last* colon, so bracketed IPv6 addresses
    /// (which contain colons themselves) stay intact. Input without any colon
    /// is rejected with [`PolicyError::InvalidPolicy`].
    fn from_str(s: &str) -> Result<Self, PolicyError> {
        let (addr_part, port_part) = s.rsplit_once(':').ok_or(PolicyError::InvalidPolicy)?;
        // The address is parsed before the ports, so its error wins if both are bad.
        let pattern = addr_part.parse::<IpPattern>()?;
        let ports = match port_part {
            "*" => PortRange::new_all(),
            other => other.parse::<PortRange>()?,
        };
        Ok(Self { pattern, ports })
    }
}
/// A pattern matching a set of IP addresses.
#[derive(Clone, Debug, Eq, PartialEq)]
enum IpPattern {
    /// Matches every address, IPv4 or IPv6.
    Star,
    /// Matches every IPv4 address.
    V4Star,
    /// Matches every IPv6 address.
    V6Star,
    /// Matches IPv4 addresses sharing the first `u8` bits with the address.
    ///
    /// `from_addr_and_mask` only produces this with a prefix length in
    /// `1..=32`; `matches` relies on that (a length of 0 would overflow its shift).
    V4(Ipv4Addr, u8),
    /// Matches IPv6 addresses sharing the first `u8` bits with the address.
    ///
    /// `from_addr_and_mask` only produces this with a prefix length in
    /// `1..=128`; `matches` relies on that (a length of 0 would overflow its shift).
    V6(Ipv6Addr, u8),
}
impl IpPattern {
    /// Builds a pattern from a base address and a prefix length.
    ///
    /// A prefix length of 0 becomes the family-wide star
    /// (`V4Star` / `V6Star`).
    ///
    /// # Errors
    ///
    /// Returns [`PolicyError::InvalidMask`] if `mask` is longer than the
    /// address (32 bits for IPv4, 128 for IPv6).
    fn from_addr_and_mask(addr: IpAddr, mask: u8) -> Result<Self, PolicyError> {
        match addr {
            IpAddr::V4(a) => match mask {
                0 => Ok(IpPattern::V4Star),
                1..=32 => Ok(IpPattern::V4(a, mask)),
                _ => Err(PolicyError::InvalidMask),
            },
            IpAddr::V6(a) => match mask {
                0 => Ok(IpPattern::V6Star),
                1..=128 => Ok(IpPattern::V6(a, mask)),
                _ => Err(PolicyError::InvalidMask),
            },
        }
    }

    /// Returns true if `addr` is inside this pattern.
    ///
    /// Prefix patterns compare the leading `mask` bits of both addresses;
    /// an address of the other family never matches.
    fn matches(&self, addr: &IpAddr) -> bool {
        match self {
            IpPattern::Star => true,
            IpPattern::V4Star => addr.is_ipv4(),
            IpPattern::V6Star => addr.is_ipv6(),
            IpPattern::V4(net, bits) => match addr {
                IpAddr::V4(candidate) => {
                    // Differing bits, shifted so that only the prefix survives.
                    let diff = u32::from(*net) ^ u32::from(*candidate);
                    (diff >> (32 - bits)) == 0
                }
                IpAddr::V6(_) => false,
            },
            IpPattern::V6(net, bits) => match addr {
                IpAddr::V6(candidate) => {
                    let diff = u128::from(*net) ^ u128::from(*candidate);
                    (diff >> (128 - bits)) == 0
                }
                IpAddr::V4(_) => false,
            },
        }
    }
}
impl Display for IpPattern {
    /// Formats the pattern so that the `FromStr` impl parses it back to an
    /// equal value.
    ///
    /// The family-wide stars are written `*4` and `*6` (Tor's exit-policy
    /// spelling). Writing them as a plain `*` would re-parse as `Star`. A
    /// serialize/deserialize round trip would then silently widen an IPv4-only
    /// or IPv6-only pattern into one that matches every address.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use IpPattern::*;
        match self {
            Star => write!(f, "*"),
            V4Star => write!(f, "*4"),
            V6Star => write!(f, "*6"),
            // A full-length prefix is a single address: omit the mask.
            V4(a, 32) => write!(f, "{}", a),
            V4(a, m) => write!(f, "{}/{}", a, m),
            // IPv6 is bracketed so the `:` before the port stays unambiguous.
            V6(a, 128) => write!(f, "[{}]", a),
            V6(a, m) => write!(f, "[{}]/{}", a, m),
        }
    }
}

/// Parses an address as written inside a policy pattern.
///
/// IPv6 addresses must be enclosed in `[...]`, and IPv4 addresses must not be.
///
/// # Errors
///
/// Returns [`PolicyError::InvalidAddress`] if the text is not an address or
/// the bracketing does not match the address family.
fn parse_addr(s: &str) -> Result<IpAddr, PolicyError> {
    let (inner, bracketed) = match s.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
        Some(inner) => (inner, true),
        None => (s, false),
    };
    let addr: IpAddr = inner.parse().map_err(|_| PolicyError::InvalidAddress)?;
    // Rejects both `[1.2.3.4]` and unbracketed IPv6.
    if addr.is_ipv6() != bracketed {
        return Err(PolicyError::InvalidAddress);
    }
    Ok(addr)
}

impl FromStr for IpPattern {
    type Err = PolicyError;

    /// Parses `*`, `*4`, `*6`, `ADDR`, or `ADDR/MASK`.
    ///
    /// `*4` and `*6` are what `Display` writes for the IPv4-only and IPv6-only
    /// stars; `0.0.0.0/0` and `[::]/0` are accepted as equivalent spellings.
    /// A bare address is treated as a full-length prefix.
    ///
    /// # Errors
    ///
    /// * [`PolicyError::MaskWithStar`] if a star form carries a `/MASK`.
    /// * [`PolicyError::InvalidAddress`] for an unparsable or wrongly-bracketed address.
    /// * [`PolicyError::InvalidMask`] for a non-numeric or too-long mask.
    fn from_str(s: &str) -> Result<Self, PolicyError> {
        let (ip_s, mask_s) = match s.find('/') {
            Some(slash_idx) => (&s[..slash_idx], Some(&s[slash_idx + 1..])),
            None => (s, None),
        };
        match (ip_s, mask_s) {
            ("*", Some(_)) | ("*4", Some(_)) | ("*6", Some(_)) => Err(PolicyError::MaskWithStar),
            ("*", None) => Ok(IpPattern::Star),
            ("*4", None) => Ok(IpPattern::V4Star),
            ("*6", None) => Ok(IpPattern::V6Star),
            (s, Some(m)) => {
                let a = parse_addr(s)?;
                let m: u8 = m.parse().map_err(|_| PolicyError::InvalidMask)?;
                IpPattern::from_addr_and_mask(a, m)
            }
            (s, None) => {
                let a = parse_addr(s)?;
                let m = if a.is_ipv4() { 32 } else { 128 };
                IpPattern::from_addr_and_mask(a, m)
            }
        }
    }
}

#[cfg(test)]
mod test {
    #![allow(clippy::unwrap_used)]
    use super::*;

    #[test]
    fn test_roundtrip_rules() {
        fn check(inp: &str, outp: &str) {
            let policy = inp.parse::<AddrPortPattern>().unwrap();
            assert_eq!(format!("{}", policy), outp);
        }
        check("127.0.0.2/32:77-10000", "127.0.0.2:77-10000");
        check("127.0.0.2/32:*", "127.0.0.2:*");
        check("127.0.0.0/16:9-100", "127.0.0.0/16:9-100");
        check("127.0.0.0/0:443", "*4:443");
        check("[::]/0:443", "*6:443");
        check("*:443", "*:443");
        check("*4:443", "*4:443");
        check("*6:443", "*6:443");
        check("[::1]:443", "[::1]:443");
        check("[ffaa::]/16:80", "[ffaa::]/16:80");
        check("[ffaa::77]/128:80", "[ffaa::77]:80");
    }

    #[test]
    fn test_display_parse_preserves_pattern() {
        // Display is also the serde format, so format-then-parse must be
        // lossless or a round trip could change which addresses match.
        for s in &[
            "*:*",
            "*4:80",
            "*6:80",
            "0.0.0.0/0:80",
            "[::]/0:1-5",
            "10.0.0.0/8:22",
            "192.168.1.1:443",
            "[fe80::]/10:*",
            "[::1]:80",
        ] {
            let parsed: AddrPortPattern = s.parse().unwrap();
            let reparsed: AddrPortPattern = parsed.to_string().parse().unwrap();
            assert_eq!(parsed, reparsed, "round trip changed {}", s);
        }
    }

    #[test]
    fn test_bad_rules() {
        fn check(s: &str) {
            assert!(s.parse::<AddrPortPattern>().is_err());
        }
        check("marzipan:80");
        check("1.2.3.4:90-80");
        check("1.2.3.4/100:8888");
        check("[1.2.3.4]/16:80");
        check("[::1]/130:8888");
        check("*/0:80");
        check("*4/8:80");
        check("*6/0:80");
        check("*5:80");
        check("1.2.3.4");
    }

    #[test]
    fn test_rule_matches() {
        fn check(addr: &str, yes: &[&str], no: &[&str]) {
            use std::net::SocketAddr;
            let policy = addr.parse::<AddrPortPattern>().unwrap();
            for s in yes {
                let sa = s.parse::<SocketAddr>().unwrap();
                assert!(policy.matches_sockaddr(&sa));
            }
            for s in no {
                let sa = s.parse::<SocketAddr>().unwrap();
                assert!(!policy.matches_sockaddr(&sa));
            }
        }
        check(
            "1.2.3.4/16:80",
            &["1.2.3.4:80", "1.2.44.55:80"],
            &["9.9.9.9:80", "1.3.3.4:80", "1.2.3.4:81"],
        );
        check(
            "*:443-8000",
            &["1.2.3.4:443", "[::1]:500"],
            &["9.0.0.0:80", "[::1]:80"],
        );
        check(
            "[face::]/8:80",
            &["[fab0::7]:80"],
            &["[dd00::]:80", "[face::7]:443"],
        );
        check("0.0.0.0/0:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
        check("[::]/0:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
        check("*4:*", &["127.0.0.1:80"], &["[f00b::]:80"]);
        check("*6:*", &["[f00b::]:80"], &["127.0.0.1:80"]);
    }

    #[test]
    fn test_policy_matches() -> Result<(), PolicyError> {
        let mut policy = AddrPolicy::default();
        policy.push(RuleKind::Accept, "*:443".parse()?);
        policy.push(RuleKind::Accept, "[::1]:80".parse()?);
        policy.push(RuleKind::Reject, "*:80".parse()?);
        let policy = policy;

        assert_eq!(
            policy.allows_sockaddr(&"[::6]:443".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            policy.allows_sockaddr(&"127.0.0.1:443".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            policy.allows_sockaddr(&"[::1]:80".parse().unwrap()),
            Some(RuleKind::Accept)
        );
        assert_eq!(
            policy.allows_sockaddr(&"[::2]:80".parse().unwrap()),
            Some(RuleKind::Reject)
        );
        assert_eq!(
            policy.allows_sockaddr(&"127.0.0.1:80".parse().unwrap()),
            Some(RuleKind::Reject)
        );
        assert_eq!(
            policy.allows_sockaddr(&"127.0.0.1:66".parse().unwrap()),
            None
        );
        Ok(())
    }

    #[test]
    fn serde() {
        #[derive(Clone, Debug, serde::Serialize, serde::Deserialize, Eq, PartialEq)]
        struct X {
            p1: AddrPortPattern,
            p2: AddrPortPattern,
        }
        let x = X {
            p1: "127.0.0.1/8:9-10".parse().unwrap(),
            p2: "*:80".parse().unwrap(),
        };
        let encoded = serde_json::to_string(&x).unwrap();
        let expected = r#"{"p1":"127.0.0.1/8:9-10","p2":"*:80"}"#;
        let x2: X = serde_json::from_str(&encoded).unwrap();
        let x3: X = serde_json::from_str(expected).unwrap();
        assert_eq!(&x2, &x3);
        assert_eq!(&x2, &x);
    }
}