mod addrpolicy;
mod portpolicy;
use std::fmt::Display;
use std::str::FromStr;
use thiserror::Error;
pub use addrpolicy::{AddrPolicy, AddrPortPattern, RuleKind};
pub use portpolicy::PortPolicy;
/// An error encountered while parsing a policy or one of its components.
///
/// In this module, `PortRange::from_str` returns `InvalidPort` and
/// `InvalidRange`. The other variants are not constructed in this file;
/// presumably they come from the address/policy parsers in the `addrpolicy`
/// and `portpolicy` submodules (NOTE(review): confirm there).
#[derive(Debug, Error, Clone, PartialEq, Eq)]
// Marked non-exhaustive so that new variants can be added later without
// breaking downstream `match` statements.
#[non_exhaustive]
pub enum PolicyError {
    /// A port number could not be parsed.
    ///
    /// `PortRange::from_str` returns this when a port (or either side of a
    /// `lo-hi` pair) is not a valid `u16`.
    #[error("Invalid port")]
    InvalidPort,
    /// The ports parsed, but do not form an acceptable range.
    ///
    /// `PortRange::from_str` returns this when `PortRange::new` rejects the
    /// bounds, i.e. the low port is 0 or is greater than the high port.
    #[error("Invalid port range")]
    InvalidRange,
    /// An address could not be interpreted (not constructed in this file).
    #[error("Invalid address")]
    InvalidAddress,
    /// A mask was combined with a `*` address (judging by the name and
    /// message; NOTE(review): confirm against the address-pattern parser).
    #[error("mask with star")]
    MaskWithStar,
    /// An address mask was not acceptable (not constructed in this file).
    #[error("invalid mask")]
    InvalidMask,
    /// A policy could not be parsed for a reason not covered above
    /// (not constructed in this file).
    #[error("Invalid policy")]
    InvalidPolicy,
}
/// A contiguous, inclusive range of port numbers.
///
/// Ranges obtained from [`PortRange::new`] or [`PortRange::new_all`] never
/// start at port 0 and always satisfy `lo <= hi`. Because both fields are
/// public, code that builds the struct directly is responsible for upholding
/// those conditions itself.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[allow(clippy::exhaustive_structs)]
pub struct PortRange {
    /// The lowest port in the range (inclusive).
    pub lo: u16,
    /// The highest port in the range (inclusive).
    pub hi: u16,
}

impl PortRange {
    /// Build a range from bounds the caller already knows to be valid.
    ///
    /// # Panics
    ///
    /// Panics if `lo` is 0, or if `lo` is greater than `hi`.
    fn new_unchecked(lo: u16, hi: u16) -> Self {
        assert!(lo != 0);
        assert!(lo <= hi);
        Self { lo, hi }
    }

    /// Return the range covering every nonzero port, `1-65535`.
    pub fn new_all() -> Self {
        Self::new_unchecked(1, u16::MAX)
    }

    /// Return the range `lo..=hi`.
    ///
    /// Returns `None` when `lo` is 0 or when `lo` is greater than `hi`.
    pub fn new(lo: u16, hi: u16) -> Option<Self> {
        if lo == 0 || lo > hi {
            return None;
        }
        Some(Self { lo, hi })
    }

    /// Return true if `port` lies in this range, both endpoints included.
    pub fn contains(&self, port: u16) -> bool {
        (self.lo..=self.hi).contains(&port)
    }

    /// Return true if this range is exactly `1-65535`.
    pub fn is_all(&self) -> bool {
        matches!(self, PortRange { lo: 1, hi: 65535 })
    }

    /// Order this range relative to a single `port`.
    ///
    /// The result is:
    /// - `Greater` when the whole range lies above `port`;
    /// - `Less` when the whole range lies below `port`;
    /// - `Equal` when `port` falls inside the range.
    ///
    /// That makes it usable as a `binary_search_by` comparator over a slice of
    /// sorted, non-overlapping ranges.
    fn compare_to_port(&self, port: u16) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        if port < self.lo {
            Ordering::Greater
        } else if port > self.hi {
            Ordering::Less
        } else {
            Ordering::Equal
        }
    }
}
impl Display for PortRange {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.lo == self.hi {
write!(f, "{}", self.lo)
} else {
write!(f, "{}-{}", self.lo, self.hi)
}
}
}
impl FromStr for PortRange {
    type Err = PolicyError;

    /// Parse either a single port (`"80"`) or a hyphenated pair (`"1-1024"`).
    ///
    /// Only the first `-` separates the two halves. As a result, input such
    /// as `"1-2-3"` fails, because `"2-3"` is not a valid port.
    ///
    /// # Errors
    ///
    /// - [`PolicyError::InvalidPort`] if the port, or either half of a pair,
    ///   does not parse as a `u16`.
    /// - [`PolicyError::InvalidRange`] if the parsed bounds are rejected by
    ///   [`PortRange::new`] (the low port is 0, or low is greater than high).
    fn from_str(s: &str) -> Result<Self, PolicyError> {
        // Every port in the input goes through the same parse-and-map step.
        let parse_port = |text: &str| -> Result<u16, PolicyError> {
            text.parse::<u16>().map_err(|_| PolicyError::InvalidPort)
        };
        let (lo, hi) = match s.split_once('-') {
            Some((before, after)) => (parse_port(before)?, parse_port(after)?),
            None => {
                // No hyphen: the input names a single port.
                let only = parse_port(s)?;
                (only, only)
            }
        };
        PortRange::new(lo, hi).ok_or(PolicyError::InvalidRange)
    }
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    // Unit tests for `PortRange` parsing, predicates, ordering, and formatting.
    use super::*;
    use crate::Result;

    /// Well-formed strings parse to the expected range; malformed ones fail.
    #[test]
    fn parse_portrange() -> Result<()> {
        // Each entry is (input text, expected low port, expected high port).
        let accepted: &[(&str, u16, u16)] = &[
            ("1-100", 1, 100),
            ("01-100", 1, 100),
            ("10-30", 10, 30),
            ("9001", 9001, 9001),
            ("9001-9001", 9001, 9001),
        ];
        for &(text, lo, hi) in accepted {
            let expected = PortRange::new(lo, hi).unwrap();
            assert_eq!(text.parse::<PortRange>()?, expected);
        }
        assert_eq!("1-65535".parse::<PortRange>()?, PortRange::new_all());

        let rejected = [
            "hello", "0", "65536", "65537", "1-2-3", "10-5", "1-", "-2", "-", "*",
        ];
        for text in rejected {
            assert!(
                text.parse::<PortRange>().is_err(),
                "{text:?} should be rejected"
            );
        }
        Ok(())
    }

    /// Membership, `is_all`, and port ordering on a few sample ranges.
    #[test]
    fn pr_manip() {
        use std::cmp::Ordering::*;

        let all = PortRange::new_all();
        assert!(all.is_all());
        assert!(!PortRange::new(2, 65535).unwrap().is_all());
        for port in [1, 65535, 7777] {
            assert!(all.contains(port));
        }

        let r = PortRange::new(20, 30).unwrap();
        for port in [20, 25, 30] {
            assert!(r.contains(port));
        }
        for port in [19, 31] {
            assert!(!r.contains(port));
        }

        let orderings = [
            (7, Greater),
            (20, Equal),
            (25, Equal),
            (30, Equal),
            (100, Less),
        ];
        for (port, expected) in orderings {
            assert_eq!(r.compare_to_port(port), expected);
        }
    }

    /// `Display` writes `lo-hi`, collapsing to one number when `lo == hi`.
    #[test]
    fn pr_fmt() {
        let cases = [(1, 65535, "1-65535"), (10, 20, "10-20"), (20, 20, "20")];
        for (lo, hi, want) in cases {
            let pr = PortRange::new(lo, hi).unwrap();
            assert_eq!(pr.to_string(), want);
        }
    }
}