use std::collections::HashSet;
use std::net::{IpAddr, SocketAddr};
use ipnet::IpNet;
/// A policy that maps a peer's socket address to a [`PolicyDecision`].
///
/// Every implementor has to be `Send + Sync + 'static`, as required by the
/// supertrait bounds on this trait.
pub trait ConnPolicy: Send + Sync + 'static {
    /// Produces the decision this policy makes for the peer at `peer_addr`.
    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision;
}
/// The result of evaluating a [`ConnPolicy`] against a peer address.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. It is also `#[must_use]`: computing a
/// decision and then dropping it is almost certainly a mistake.
//
// NOTE(review): what the consumer does for each variant is not visible in this
// file; the per-variant notes below only record which built-in policies return
// them. Presumably the names describe how a proxy header from the peer is
// treated (require / use if present / ignore / reject) — confirm against the
// code that consumes the decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
#[must_use]
pub enum PolicyDecision {
    /// Returned by `AcceptAll` for every peer, and by `TrustedProxies` and
    /// `MixedMode` for peers inside their trusted set.
    Require,
    /// Returned by `OptionalProxy` for every peer.
    Use,
    /// Returned by `MixedMode` for peers outside its trusted set.
    Ignore,
    /// Returned by `TrustedProxies` for peers outside its trusted set.
    Reject,
}
/// Policy that answers [`PolicyDecision::Require`] for every peer.
///
/// It is a field-less marker type, so it is `Copy` and has a `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AcceptAll;

impl ConnPolicy for AcceptAll {
    /// Always yields [`PolicyDecision::Require`]; the peer address is never
    /// inspected.
    fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
        PolicyDecision::Require
    }
}
/// A set of trusted peer IP addresses, held as individual addresses plus CIDR
/// ranges.
#[derive(Debug, Clone)]
pub struct TrustedProxies {
    /// Individually trusted addresses, looked up by hash.
    exact: HashSet<IpAddr>,
    /// Trusted networks, scanned in order on lookup.
    cidrs: Vec<IpNet>,
}
impl TrustedProxies {
    /// Builds a set that trusts exactly the given addresses and no CIDR ranges.
    #[must_use]
    pub fn new(addrs: impl IntoIterator<Item = IpAddr>) -> Self {
        Self {
            exact: addrs.into_iter().collect(),
            cidrs: Vec::new(),
        }
    }

    /// Builds a set that trusts the given individual addresses as well as every
    /// address inside any of the given CIDR ranges.
    #[must_use]
    pub fn with_cidrs(
        addrs: impl IntoIterator<Item = IpAddr>,
        cidrs: impl IntoIterator<Item = IpNet>,
    ) -> Self {
        Self {
            exact: addrs.into_iter().collect(),
            cidrs: cidrs.into_iter().collect(),
        }
    }

    /// Builds a set from networks only.
    ///
    /// A network whose prefix length equals the maximum for its address family
    /// describes a single host, so its address is stored in the exact-match
    /// set; every other network is kept as a CIDR range.
    #[must_use]
    pub fn from_ipnets(nets: impl IntoIterator<Item = IpNet>) -> Self {
        let mut exact = HashSet::new();
        let mut cidrs = Vec::new();
        for net in nets {
            if net.prefix_len() == net.max_prefix_len() {
                exact.insert(net.addr());
            } else {
                cidrs.push(net);
            }
        }
        Self { exact, cidrs }
    }

    /// Returns `true` when `ip` is trusted.
    ///
    /// The address is first checked as given. If it is an IPv4-mapped IPv6
    /// address (`::ffff:a.b.c.d`) — the form in which IPv4 clients commonly
    /// show up on dual-stack listeners — the embedded IPv4 address is checked
    /// as well, so IPv4 entries keep matching such peers.
    #[must_use]
    pub fn contains(&self, ip: IpAddr) -> bool {
        if self.matches(ip) {
            return true;
        }
        match ip {
            IpAddr::V6(v6) => v6
                .to_ipv4_mapped()
                .map_or(false, |v4| self.matches(IpAddr::V4(v4))),
            IpAddr::V4(_) => false,
        }
    }

    /// Checks `ip` verbatim against the exact set, then against each CIDR range.
    fn matches(&self, ip: IpAddr) -> bool {
        self.exact.contains(&ip) || self.cidrs.iter().any(|net| net.contains(&ip))
    }
}
impl ConnPolicy for TrustedProxies {
    /// Yields [`PolicyDecision::Require`] when the peer's IP is in the trusted
    /// set and [`PolicyDecision::Reject`] otherwise. The port is not considered.
    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
        let peer_ip = peer_addr.ip();
        if !self.contains(peer_ip) {
            return PolicyDecision::Reject;
        }
        PolicyDecision::Require
    }
}
/// Policy wrapping a [`TrustedProxies`] set; see its [`ConnPolicy`]
/// implementation for the decisions it makes.
#[derive(Debug, Clone)]
pub struct MixedMode {
    /// Peers that are considered trusted by this policy.
    trusted: TrustedProxies,
}
impl MixedMode {
pub fn new(trusted: TrustedProxies) -> Self {
Self { trusted }
}
}
impl From<TrustedProxies> for MixedMode {
fn from(trusted: TrustedProxies) -> Self {
Self::new(trusted)
}
}
impl ConnPolicy for MixedMode {
    /// Yields [`PolicyDecision::Require`] when the peer's IP is in the trusted
    /// set and [`PolicyDecision::Ignore`] for every other peer.
    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
        let is_trusted = self.trusted.contains(peer_addr.ip());
        if !is_trusted {
            PolicyDecision::Ignore
        } else {
            PolicyDecision::Require
        }
    }
}
/// Policy that answers [`PolicyDecision::Use`] for every peer.
///
/// It is a field-less marker type, so it is `Copy` and has a `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct OptionalProxy;

impl ConnPolicy for OptionalProxy {
    /// Always yields [`PolicyDecision::Use`]; the peer address is never
    /// inspected.
    fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
        PolicyDecision::Use
    }
}
/// Lets any `Fn(SocketAddr) -> PolicyDecision` closure or function that is
/// `Send + Sync + 'static` serve directly as a [`ConnPolicy`].
impl<F: Fn(SocketAddr) -> PolicyDecision + Send + Sync + 'static> ConnPolicy for F {
    /// Invokes the wrapped callable with `peer_addr` and returns its result.
    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
        (self)(peer_addr)
    }
}