proxy_protocol_rs/
policy.rs1use std::collections::HashSet;
16use std::net::{IpAddr, SocketAddr};
17
18use ipnet::IpNet;
19
20pub trait ConnPolicy: Send + Sync + 'static {
22 fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision;
23}
24
/// What to do about the PROXY protocol header on a connection, as returned
/// by [`ConnPolicy::evaluate`].
///
/// Each decision is enforced by the code that consumes the policy. The notes
/// below describe the intent implied by the built-in policies that return
/// each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum PolicyDecision {
    /// A PROXY header is expected from this peer. Returned by `AcceptAll` for
    /// every peer, and by `TrustedProxies` / `MixedMode` for trusted peers.
    Require,
    /// A PROXY header is optional: presumably used when present, with the
    /// connection accepted without one. Returned by `OptionalProxy`.
    Use,
    /// No PROXY header is expected; presumably the peer is treated as a direct
    /// client. Returned by `MixedMode` for untrusted peers.
    Ignore,
    /// The connection should be refused. Returned by `TrustedProxies` for
    /// untrusted peers.
    Reject,
}
37
/// Policy that returns [`PolicyDecision::Require`] for every peer, regardless
/// of its address.
///
/// It carries no state, so it is `Copy`, `Default`, and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AcceptAll;
41
42impl ConnPolicy for AcceptAll {
43 fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
44 PolicyDecision::Require
45 }
46}
47
48#[derive(Debug, Clone)]
50pub struct TrustedProxies {
51 exact: HashSet<IpAddr>,
52 cidrs: Vec<IpNet>,
53}
54
55impl TrustedProxies {
56 pub fn new(addrs: impl IntoIterator<Item = IpAddr>) -> Self {
57 Self {
58 exact: addrs.into_iter().collect(),
59 cidrs: Vec::new(),
60 }
61 }
62
63 pub fn with_cidrs(
64 addrs: impl IntoIterator<Item = IpAddr>,
65 cidrs: impl IntoIterator<Item = IpNet>,
66 ) -> Self {
67 Self {
68 exact: addrs.into_iter().collect(),
69 cidrs: cidrs.into_iter().collect(),
70 }
71 }
72
73 pub fn from_ipnets(nets: impl IntoIterator<Item = IpNet>) -> Self {
77 let mut exact = HashSet::new();
78 let mut cidrs = Vec::new();
79 for net in nets {
80 if net.prefix_len() == net.max_prefix_len() {
81 exact.insert(net.addr());
82 } else {
83 cidrs.push(net);
84 }
85 }
86 Self { exact, cidrs }
87 }
88
89 pub fn contains(&self, ip: IpAddr) -> bool {
90 self.exact.contains(&ip) || self.cidrs.iter().any(|net| net.contains(&ip))
91 }
92}
93
94impl ConnPolicy for TrustedProxies {
95 fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
96 if self.contains(peer_addr.ip()) {
97 PolicyDecision::Require
98 } else {
99 PolicyDecision::Reject
100 }
101 }
102}
103
104#[derive(Debug, Clone)]
106pub struct MixedMode {
107 trusted: TrustedProxies,
108}
109
110impl MixedMode {
111 pub fn new(trusted: TrustedProxies) -> Self {
112 Self { trusted }
113 }
114}
115
116impl From<TrustedProxies> for MixedMode {
117 fn from(trusted: TrustedProxies) -> Self {
118 Self::new(trusted)
119 }
120}
121
122impl ConnPolicy for MixedMode {
123 fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
124 if self.trusted.contains(peer_addr.ip()) {
125 PolicyDecision::Require
126 } else {
127 PolicyDecision::Ignore
128 }
129 }
130}
131
/// Policy that returns [`PolicyDecision::Use`] for every peer, regardless of
/// its address.
///
/// It carries no state, so it is `Copy`, `Default`, and comparable.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct OptionalProxy;
140
141impl ConnPolicy for OptionalProxy {
142 fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
143 PolicyDecision::Use
144 }
145}
146
147impl<F> ConnPolicy for F
149where
150 F: Fn(SocketAddr) -> PolicyDecision + Send + Sync + 'static,
151{
152 fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
153 self(peer_addr)
154 }
155}