Skip to main content

fraiseql_auth/
proxy.rs

1//! Proxy and IP address extraction with security validation
2
3use std::net::IpAddr;
4
/// Parse `ip_str` as an IPv4 or IPv6 address.
///
/// # SECURITY
/// This is the single gate that every address string in this module passes
/// through before it is trusted or returned. The entire string must be an
/// address; anything extra (padding, a `:port` suffix, free text) makes it
/// invalid and yields `None`.
fn validate_ip_format(ip_str: &str) -> Option<IpAddr> {
    <IpAddr as std::str::FromStr>::from_str(ip_str).ok()
}
13
/// Proxy configuration controlling when client-IP forwarding headers
/// (`X-Forwarded-For`, `X-Real-IP`) may be honoured.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Addresses of trusted proxies (e.g. load balancer, Nginx, HAProxy).
    /// Forwarding headers are only honoured on connections whose direct peer
    /// address is in this list.
    pub trusted_proxies:       Vec<IpAddr>,
    /// Intended to require that forwarding headers only be honoured for
    /// requests arriving from a trusted proxy.
    ///
    /// NOTE: `ProxyConfig::extract_client_ip` does not read this flag.
    /// Forwarding headers are honoured only when the peer is in
    /// `trusted_proxies`, whatever this flag's value.
    pub require_trusted_proxy: bool,
}
23
24impl ProxyConfig {
25    /// Create a new proxy configuration
26    pub const fn new(trusted_proxies: Vec<IpAddr>, require_trusted_proxy: bool) -> Self {
27        Self {
28            trusted_proxies,
29            require_trusted_proxy,
30        }
31    }
32
33    /// Create a proxy config that trusts all local proxies (127.0.0.1 only).
34    ///
35    /// # Panics
36    ///
37    /// Cannot panic — the IP literal `"127.0.0.1"` is always valid.
38    pub fn localhost_only() -> Self {
39        Self {
40            trusted_proxies:       vec!["127.0.0.1".parse().expect("valid IP")], /* Reason: "127.0.0.1" is a compile-time literal and always parses successfully */
41            require_trusted_proxy: true,
42        }
43    }
44
45    /// Create a proxy config with no trusted proxies
46    pub const fn none() -> Self {
47        Self {
48            trusted_proxies:       Vec::new(),
49            require_trusted_proxy: false,
50        }
51    }
52
53    /// Check if an IP address is a trusted proxy
54    ///
55    /// # SECURITY
56    /// Validates IP format before checking against trusted list.
57    /// Returns false for any invalid IP format, preventing bypass attempts.
58    pub fn is_trusted_proxy(&self, ip: &str) -> bool {
59        if self.trusted_proxies.is_empty() {
60            return false;
61        }
62
63        // Validate IP format and check against trusted list
64        match validate_ip_format(ip) {
65            Some(addr) => self.trusted_proxies.contains(&addr),
66            None => false, // Invalid IP format is not trusted
67        }
68    }
69
70    /// Extract client IP from headers with security validation
71    ///
72    /// # SECURITY
73    /// Only trusts X-Forwarded-For if the request comes from a trusted proxy.
74    /// Falls back to direct connection IP if X-Forwarded-For cannot be validated.
75    /// Validates all extracted IPs to ensure proper format.
76    ///
77    /// This prevents IP spoofing attacks where an attacker sends a malicious
78    /// X-Forwarded-For header to bypass rate limiting or access controls.
79    pub fn extract_client_ip(
80        &self,
81        headers: &axum::http::HeaderMap,
82        socket_addr: Option<std::net::SocketAddr>,
83    ) -> Option<String> {
84        let direct_ip = socket_addr.map(|addr| addr.ip().to_string());
85
86        // If no direct IP available, return early
87        let direct_ip_str = direct_ip.as_deref().unwrap_or("");
88
89        // Check X-Forwarded-For if proxy is trusted
90        if let Some(forwarded_for) = headers.get("x-forwarded-for").and_then(|v| v.to_str().ok()) {
91            if self.is_trusted_proxy(direct_ip_str) {
92                // Extract first IP from X-Forwarded-For (client IP in chain)
93                if let Some(ip_str) = forwarded_for.split(',').next().map(|ip| ip.trim()) {
94                    // SECURITY: Validate IP format before returning
95                    if validate_ip_format(ip_str).is_some() {
96                        return Some(ip_str.to_string());
97                    }
98                    // Invalid IP format - fall through to use direct IP
99                }
100            }
101            // X-Forwarded-For present but from untrusted proxy - ignore it and use direct IP
102            if let Some(ip) = direct_ip {
103                return Some(ip);
104            }
105        }
106
107        // Check X-Real-IP if proxy is trusted
108        if let Some(real_ip) = headers.get("x-real-ip").and_then(|v| v.to_str().ok()) {
109            if self.is_trusted_proxy(direct_ip_str) {
110                // SECURITY: Validate IP format before returning
111                if validate_ip_format(real_ip).is_some() {
112                    return Some(real_ip.to_string());
113                }
114                // Invalid IP format - fall through to use direct IP
115            }
116            // X-Real-IP present but from untrusted proxy - ignore it and use direct IP
117            if let Some(ip) = direct_ip {
118                return Some(ip);
119            }
120        }
121
122        // Fall back to direct connection IP (already validated by Axum)
123        direct_ip
124    }
125}
126
#[allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    /// Socket address of a direct peer with the given IP (port is arbitrary).
    fn peer(ip: &str) -> Option<std::net::SocketAddr> {
        Some(std::net::SocketAddr::new(ip.parse().unwrap(), 8000))
    }

    /// Config that trusts exactly one proxy address.
    fn trusting(ip: &str) -> ProxyConfig {
        ProxyConfig::new(vec![ip.parse().unwrap()], true)
    }

    #[test]
    fn test_proxy_config_localhost_only() {
        let config = ProxyConfig::localhost_only();
        assert!(config.is_trusted_proxy("127.0.0.1"));
        assert!(!config.is_trusted_proxy("192.168.1.1"));
    }

    #[test]
    fn test_proxy_config_none_trusts_nothing() {
        let config = ProxyConfig::none();
        assert!(!config.is_trusted_proxy("127.0.0.1"));

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap());

        let result = config.extract_client_ip(&headers, peer("127.0.0.1"));
        assert_eq!(result, Some("127.0.0.1".to_string()));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_valid_ip() {
        let config = trusting("10.0.0.1");
        assert!(config.is_trusted_proxy("10.0.0.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_untrusted_ip() {
        let config = trusting("10.0.0.1");
        assert!(!config.is_trusted_proxy("192.168.1.1"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_invalid_ip() {
        let config = trusting("10.0.0.1");
        assert!(!config.is_trusted_proxy("invalid_ip"));
    }

    #[test]
    fn test_proxy_config_is_trusted_proxy_rejects_port_suffix() {
        let config = trusting("10.0.0.1");
        assert!(!config.is_trusted_proxy("10.0.0.1:8080"));
    }

    #[test]
    fn test_extract_client_ip_from_trusted_proxy_x_forwarded_for() {
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());

        let result = config.extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(result, Some("192.0.2.1".to_string()));
    }

    #[test]
    fn test_extract_client_ip_single_entry_x_forwarded_for() {
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.5".parse().unwrap());

        let result = config.extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(result, Some("203.0.113.5".to_string()));
    }

    #[test]
    fn test_extract_client_ip_from_untrusted_proxy_x_forwarded_for() {
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());

        // Should ignore X-Forwarded-For and use direct IP
        let result = config.extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(result, Some("192.168.1.100".to_string()));
    }

    #[test]
    fn test_extract_client_ip_from_trusted_proxy_x_real_ip() {
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-real-ip", "192.0.2.7".parse().unwrap());

        let result = config.extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(result, Some("192.0.2.7".to_string()));
    }

    #[test]
    fn test_extract_client_ip_from_untrusted_proxy_x_real_ip() {
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-real-ip", "192.0.2.7".parse().unwrap());

        // Should ignore X-Real-IP and use direct IP
        let result = config.extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(result, Some("192.168.1.100".to_string()));
    }

    #[test]
    fn test_extract_client_ip_x_forwarded_for_precedes_x_real_ip() {
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "192.0.2.1, 10.0.0.1".parse().unwrap());
        headers.insert("x-real-ip", "198.51.100.1".parse().unwrap());

        let result = config.extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(result, Some("192.0.2.1".to_string()));
    }

    #[test]
    fn test_extract_client_ip_no_headers() {
        let config = ProxyConfig::localhost_only();
        let headers = axum::http::HeaderMap::new();

        let result = config.extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(result, Some("192.168.1.100".to_string()));
    }

    #[test]
    fn test_extract_client_ip_empty_headers() {
        let config = ProxyConfig::localhost_only();
        let headers = axum::http::HeaderMap::new();

        let result = config.extract_client_ip(&headers, None);
        assert_eq!(result, None);
    }

    #[test]
    fn test_extract_client_ip_spoofing_attempt() {
        // Attacker tries to spoof IP from untrusted source
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        // Attacker sends malicious X-Forwarded-For header
        headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap());

        // Request comes from untrusted IP (attacker direct IP).
        // Should use attacker's direct IP, not the spoofed X-Forwarded-For
        let result = config.extract_client_ip(&headers, peer("192.168.1.100"));
        assert_eq!(result, Some("192.168.1.100".to_string()));
    }

    #[test]
    fn test_extract_client_ip_invalid_format_x_forwarded_for() {
        // SECURITY: Invalid IP format in X-Forwarded-For header should be rejected
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        // Attacker sends malformed IP that's not a valid IP
        headers.insert("x-forwarded-for", "not-a-valid-ip-address, 10.0.0.1".parse().unwrap());

        // Should reject invalid format and fall back to direct IP
        let result = config.extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(result, Some("10.0.0.1".to_string()));
    }

    #[test]
    fn test_extract_client_ip_invalid_format_x_real_ip() {
        // SECURITY: Invalid IP format in X-Real-IP header should be rejected
        let config = trusting("10.0.0.1");

        let mut headers = axum::http::HeaderMap::new();
        // Attacker sends malformed IP
        headers.insert("x-real-ip", "256.256.256.256".parse().unwrap());

        // Should reject invalid format and fall back to direct IP
        let result = config.extract_client_ip(&headers, peer("10.0.0.1"));
        assert_eq!(result, Some("10.0.0.1".to_string()));
    }

    #[test]
    fn test_extract_client_ip_valid_ipv6() {
        // Test with valid IPv6 address
        let config = trusting("::1");

        let mut headers = axum::http::HeaderMap::new();
        headers.insert("x-forwarded-for", "2001:db8::1, ::1".parse().unwrap());

        let result = config.extract_client_ip(&headers, peer("::1"));
        assert_eq!(result, Some("2001:db8::1".to_string()));
    }
}