Skip to main content

proxy_protocol_rs/
policy.rs

1// Copyright (C) 2025-2026 Michael S. Klishin and Contributors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::collections::HashSet;
16use std::net::{IpAddr, SocketAddr};
17
18use ipnet::IpNet;
19
20/// Policy applied to incoming connections before reading the PP header
21pub trait ConnPolicy: Send + Sync + 'static {
22    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision;
23}
24
/// Outcome of a [`ConnPolicy`] evaluation for a single incoming connection.
///
/// `Hash` is derived alongside `Eq` so decisions can be used as keys in
/// maps and sets (e.g. for per-decision counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum PolicyDecision {
    /// Read and parse the PP header; error if absent or malformed.
    Require,
    /// Try to read a PP header; if the first bytes don't match, treat the
    /// connection as direct.
    Use,
    /// Ignore any PP header; pass the bytes through unmodified.
    Ignore,
    /// Reject the connection immediately.
    Reject,
}
37
38/// Require PP headers from all connections
39#[derive(Debug, Clone)]
40pub struct AcceptAll;
41
42impl ConnPolicy for AcceptAll {
43    fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
44        PolicyDecision::Require
45    }
46}
47
48/// Require PP headers only from connections in the trusted set; reject all others
49#[derive(Debug, Clone)]
50pub struct TrustedProxies {
51    exact: HashSet<IpAddr>,
52    cidrs: Vec<IpNet>,
53}
54
55impl TrustedProxies {
56    pub fn new(addrs: impl IntoIterator<Item = IpAddr>) -> Self {
57        Self {
58            exact: addrs.into_iter().collect(),
59            cidrs: Vec::new(),
60        }
61    }
62
63    pub fn with_cidrs(
64        addrs: impl IntoIterator<Item = IpAddr>,
65        cidrs: impl IntoIterator<Item = IpNet>,
66    ) -> Self {
67        Self {
68            exact: addrs.into_iter().collect(),
69            cidrs: cidrs.into_iter().collect(),
70        }
71    }
72
73    pub(crate) fn contains(&self, ip: IpAddr) -> bool {
74        self.exact.contains(&ip) || self.cidrs.iter().any(|net| net.contains(&ip))
75    }
76}
77
78impl ConnPolicy for TrustedProxies {
79    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
80        if self.contains(peer_addr.ip()) {
81            PolicyDecision::Require
82        } else {
83            PolicyDecision::Reject
84        }
85    }
86}
87
88/// Allow both proxied and direct connections
89#[derive(Debug, Clone)]
90pub struct MixedMode {
91    trusted: TrustedProxies,
92}
93
94impl MixedMode {
95    pub fn new(trusted: TrustedProxies) -> Self {
96        Self { trusted }
97    }
98}
99
100impl From<TrustedProxies> for MixedMode {
101    fn from(trusted: TrustedProxies) -> Self {
102        Self::new(trusted)
103    }
104}
105
106impl ConnPolicy for MixedMode {
107    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
108        if self.trusted.contains(peer_addr.ip()) {
109            PolicyDecision::Require
110        } else {
111            PolicyDecision::Ignore
112        }
113    }
114}
115
116/// Try to read a PP header from every connection, but treat it as optional
117///
118/// If the first bytes match a Proxy Protocol signature, the header is parsed.
119/// If not, the connection is passed through as a direct connection.
120/// Unlike `MixedMode`, this does not require a trusted proxy list:
121/// it accepts PP from any peer
122#[derive(Debug, Clone)]
123pub struct OptionalProxy;
124
125impl ConnPolicy for OptionalProxy {
126    fn evaluate(&self, _peer_addr: SocketAddr) -> PolicyDecision {
127        PolicyDecision::Use
128    }
129}
130
131/// Blanket implementation for closures
132impl<F> ConnPolicy for F
133where
134    F: Fn(SocketAddr) -> PolicyDecision + Send + Sync + 'static,
135{
136    fn evaluate(&self, peer_addr: SocketAddr) -> PolicyDecision {
137        self(peer_addr)
138    }
139}