1use std::net::IpAddr;
4
/// Parses `ip_str` as an IPv4 or IPv6 address, yielding `None` on failure.
///
/// No trimming is performed: `ip_str` must be the bare address text.
fn validate_ip_format(ip_str: &str) -> Option<IpAddr> {
    let parsed: Result<IpAddr, std::net::AddrParseError> = ip_str.parse();
    parsed.ok()
}
13
/// Policy deciding when client-address forwarding headers
/// (`X-Forwarded-For`, `X-Real-IP`) may be believed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProxyConfig {
    /// Peer addresses whose forwarding headers are honoured. An empty list
    /// means no peer is trusted.
    pub trusted_proxies: Vec<IpAddr>,
    /// Stored policy flag.
    /// NOTE(review): not read by any method in this file — confirm whether
    /// callers elsewhere enforce it.
    pub require_trusted_proxy: bool,
}
23
24impl ProxyConfig {
25 pub const fn new(trusted_proxies: Vec<IpAddr>, require_trusted_proxy: bool) -> Self {
27 Self {
28 trusted_proxies,
29 require_trusted_proxy,
30 }
31 }
32
33 pub fn localhost_only() -> Self {
39 Self {
40 trusted_proxies: vec!["127.0.0.1".parse().expect("valid IP")], require_trusted_proxy: true,
42 }
43 }
44
45 pub const fn none() -> Self {
47 Self {
48 trusted_proxies: Vec::new(),
49 require_trusted_proxy: false,
50 }
51 }
52
53 pub fn is_trusted_proxy(&self, ip: &str) -> bool {
59 if self.trusted_proxies.is_empty() {
60 return false;
61 }
62
63 match validate_ip_format(ip) {
65 Some(addr) => self.trusted_proxies.contains(&addr),
66 None => false, }
68 }
69
70 pub fn extract_client_ip(
80 &self,
81 headers: &axum::http::HeaderMap,
82 socket_addr: Option<std::net::SocketAddr>,
83 ) -> Option<String> {
84 let direct_ip = socket_addr.map(|addr| addr.ip().to_string());
85
86 let direct_ip_str = direct_ip.as_deref().unwrap_or("");
88
89 if let Some(forwarded_for) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
91 if self.is_trusted_proxy(direct_ip_str) {
92 if let Some(ip_str) = forwarded_for.split(',').next().map(|ip| ip.trim()) {
94 if validate_ip_format(ip_str).is_some() {
96 return Some(ip_str.to_string());
97 }
98 }
100 }
101 if let Some(ip) = direct_ip {
103 return Some(ip);
104 }
105 }
106
107 if let Some(real_ip) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
109 if self.is_trusted_proxy(direct_ip_str) {
110 if validate_ip_format(real_ip).is_some() {
112 return Some(real_ip.to_string());
113 }
114 }
116 if let Some(ip) = direct_ip {
118 return Some(ip);
119 }
120 }
121
122 direct_ip
124 }
125}
126
// Fixtures are string literals known to parse, so `unwrap` cannot fire on bad input.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    // Tests exercise private helpers of the parent module.
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Parses an IP literal used as a fixture.
    fn addr(text: &str) -> IpAddr {
        text.parse().unwrap()
    }

    /// The connecting peer's socket address (port 8000) for an IP literal.
    fn peer(text: &str) -> Option<std::net::SocketAddr> {
        Some(std::net::SocketAddr::new(addr(text), 8000))
    }

    /// A header map holding exactly one header.
    fn one_header(name: &'static str, value: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    /// A config that trusts a single proxy address.
    fn trusting(proxy: &str) -> ProxyConfig {
        ProxyConfig::new(vec![addr(proxy)], true)
    }

    #[test]
    fn test_proxy_config_localhost_only() {
        let config = ProxyConfig::localhost_only();
        assert!(config.is_trusted_proxy("127.0.0.1"));
        assert!(!config.is_trusted_proxy("192.168.1.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_valid_ip() {
        assert!(trusting("10.0.0.1").is_trusted_proxy("10.0.0.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_untrusted_ip() {
        assert!(!trusting("10.0.0.1").is_trusted_proxy("192.168.1.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_invalid_ip() {
        assert!(!trusting("10.0.0.1").is_trusted_proxy("invalid_ip"));
    }

    #[test]
    fn test_extract_client_ip_from_trusted_proxy_x_forwarded_for() {
        let headers = one_header("x-forwarded-for", "192.0.2.1, 10.0.0.1");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(client.as_deref(), Some("192.0.2.1"));
    }

    #[test]
    fn test_extract_client_ip_from_untrusted_proxy_x_forwarded_for() {
        let headers = one_header("x-forwarded-for", "192.0.2.1, 10.0.0.1");
        // The peer is not trusted, so the header is ignored.
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(client.as_deref(), Some("192.168.1.100"));
    }

    #[test]
    fn test_extract_client_ip_no_headers() {
        let headers = axum::http::HeaderMap::new();
        let client =
            ProxyConfig::localhost_only().extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(client.as_deref(), Some("192.168.1.100"));
    }

    #[test]
    fn test_extract_client_ip_empty_headers() {
        let headers = axum::http::HeaderMap::new();
        let client = ProxyConfig::localhost_only().extract_client_ip(&headers, None);
        assert_eq!(client, None);
    }

    #[test]
    fn test_extract_client_ip_spoofing_attempt() {
        // An untrusted peer claims to be forwarding for 1.2.3.4.
        let headers = one_header("x-forwarded-for", "1.2.3.4");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("192.168.1.100"));
        // The claim is ignored and the real peer address is reported.
        assert_eq!(client.as_deref(), Some("192.168.1.100"));
    }

    #[test]
    fn test_extract_client_ip_invalid_format_x_forwarded_for() {
        let headers = one_header("x-forwarded-for", "not-a-valid-ip-address, 10.0.0.1");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("10.0.0.1"));
        // A malformed entry falls back to the connecting peer.
        assert_eq!(client.as_deref(), Some("10.0.0.1"));
    }

    #[test]
    fn test_extract_client_ip_invalid_format_x_real_ip() {
        let headers = one_header("x-real-ip", "256.256.256.256");
        let client = trusting("10.0.0.1").extract_client_ip(&headers, peer("10.0.0.1"));
        // An out-of-range address falls back to the connecting peer.
        assert_eq!(client.as_deref(), Some("10.0.0.1"));
    }

    #[test]
    fn test_extract_client_ip_valid_ipv6() {
        let headers = one_header("x-forwarded-for", "2001:db8::1, ::1");
        let client = trusting("::1").extract_client_ip(&headers, peer("::1"));
        assert_eq!(client.as_deref(), Some("2001:db8::1"));
    }
}