Skip to main content

proxy_protocol_rs/
policy.rs

1// Copyright (C) 2025-2026 Michael S. Klishin and Contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::net::{IpAddr, SocketAddr};
17
18use ipnet::IpNet;
19
20/// Policy applied to incoming connections before reading the PP header
21pub trait ConnPolicy: Send + Sync + 'static {
22    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision;
23}
24
/// Outcome of a [`ConnPolicy`] evaluation: how an incoming connection should
/// be handled with respect to the PROXY protocol (PP) header.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate that matches on
/// it must include a wildcard arm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum PolicyDecision {
    /// A PP header is mandatory: read and parse it, and fail if it is missing
    /// or malformed.
    Require,
    /// A PP header is optional: try to read one, and if the leading bytes do
    /// not match, handle the connection as a direct one.
    Use,
    /// Do not look for a PP header; the stream is passed through unmodified.
    Ignore,
    /// Refuse the connection immediately.
    Reject,
}
37
38/// Require PP headers from all connections
39#[derive(Debug, Clone)]
40pub struct AcceptAll;
41
42impl ConnPolicy for AcceptAll {
43    fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
44        PolicyDecision::Require
45    }
46}
47
48/// Require PP headers only from connections in the trusted set; reject all others
49#[derive(Debug, Clone)]
50pub struct TrustedProxies {
51    exact: HashSet<IpAddr>,
52    cidrs: Vec<IpNet>,
53}
54
55impl TrustedProxies {
56    pub fn new(addrs: impl IntoIterator<Item = IpAddr>) -> Self {
57        Self {
58            exact: addrs.into_iter().collect(),
59            cidrs: Vec::new(),
60        }
61    }
62
63    pub fn with_cidrs(
64        addrs: impl IntoIterator<Item = IpAddr>,
65        cidrs: impl IntoIterator<Item = IpNet>,
66    ) -> Self {
67        Self {
68            exact: addrs.into_iter().collect(),
69            cidrs: cidrs.into_iter().collect(),
70        }
71    }
72
73    /// Builds from a mixed set of `IpNet`s, automatically splitting
74    /// host-length prefixes (/32 for IPv4, /128 for IPv6) into exact
75    /// matches and the rest into CIDR ranges.
76    pub fn from_ipnets(nets: impl IntoIterator<Item = IpNet>) -> Self {
77        let mut exact = HashSet::new();
78        let mut cidrs = Vec::new();
79        for net in nets {
80            if net.prefix_len() == net.max_prefix_len() {
81                exact.insert(net.addr());
82            } else {
83                cidrs.push(net);
84            }
85        }
86        Self { exact, cidrs }
87    }
88
89    pub fn contains(&self, ip: IpAddr) -> bool {
90        self.exact.contains(&ip) || self.cidrs.iter().any(|net| net.contains(&ip))
91    }
92}
93
94impl ConnPolicy for TrustedProxies {
95    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
96        if self.contains(peer_addr.ip()) {
97            PolicyDecision::Require
98        } else {
99            PolicyDecision::Reject
100        }
101    }
102}
103
104/// Allow both proxied and direct connections
105#[derive(Debug, Clone)]
106pub struct MixedMode {
107    trusted: TrustedProxies,
108}
109
110impl MixedMode {
111    pub fn new(trusted: TrustedProxies) -> Self {
112        Self { trusted }
113    }
114}
115
116impl From<TrustedProxies> for MixedMode {
117    fn from(trusted: TrustedProxies) -> Self {
118        Self::new(trusted)
119    }
120}
121
122impl ConnPolicy for MixedMode {
123    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
124        if self.trusted.contains(peer_addr.ip()) {
125            PolicyDecision::Require
126        } else {
127            PolicyDecision::Ignore
128        }
129    }
130}
131
132/// Try to read a PP header from every connection, but treat it as optional
133///
134/// If the first bytes match a Proxy Protocol signature, the header is parsed.
135/// If not, the connection is passed through as a direct connection.
136/// Unlike `MixedMode`, this does not require a trusted proxy list:
137/// it accepts PP from any peer
138#[derive(Debug, Clone)]
139pub struct OptionalProxy;
140
141impl ConnPolicy for OptionalProxy {
142    fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
143        PolicyDecision::Use
144    }
145}
146
147/// Blanket implementation for closures
148impl<F> ConnPolicy for F
149where
150    F: Fn(SocketAddr) -> PolicyDecision + Send + Sync + 'static,
151{
152    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
153        self(peer_addr)
154    }
155}