Skip to main content

tower_real_ip/
lib.rs

//! # Tower Real IP
//!
//! A robust middleware for extracting the real client IP address from HTTP requests,
//! designed for environments behind trusted proxies (load balancers, CDNs, Nginx).
//!
//! ## Features
//! - Supports `X-Forwarded-For` parsing (right-to-left security traversal).
//! - Supports CIDR ranges (IPv4 & IPv6).
//! - Auto-configuration from environment variables (entries separated by `;`).
//! - Axum 0.8 extractor support.

12use axum::extract::{ConnectInfo, FromRequestParts};
13use http::{request::Parts, Request, Response};
14use ipnetwork::IpNetwork;
15use std::{
16    env,
17    future::Future,
18    net::{IpAddr, SocketAddr},
19    pin::Pin,
20    str::FromStr,
21    sync::Arc,
22    task::{Context, Poll},
23};
24use tower::{Layer, Service};
25use tracing::{debug, warn};
26
27// ============================================================================
28//  1. Configuration Logic
29// ============================================================================
30
31/// Configuration holding the list of trusted networks.
32#[derive(Clone, Debug)]
33pub struct TrustedProxyConfig {
34    trusted_networks: Arc<Vec<IpNetwork>>,
35}
36
37impl TrustedProxyConfig {
38    /// Creates a new config from a list of IP networks.
39    pub fn new(networks: Vec<IpNetwork>) -> Self {
40        Self {
41            trusted_networks: Arc::new(networks),
42        }
43    }
44
45    /// Loads configuration from an environment variable.
46    ///
47    /// Expected format: "127.0.0.1;10.0.0.0/8;::1"
48    pub fn from_env(env_key: &str) -> Result<Self, String> {
49        let val =
50            env::var(env_key).map_err(|_| format!("Environment variable {} not found", env_key))?;
51        Self::parse_str(&val)
52    }
53
54    /// Parses a string separated by `;` into trusted networks.
55    pub fn parse_str(input: &str) -> Result<Self, String> {
56        let mut networks = Vec::new();
57        for part in input.split(';') {
58            let part = part.trim();
59            if part.is_empty() {
60                continue;
61            }
62
63            // Try parsing as CIDR first, then as single IP
64            match part.parse::<IpNetwork>() {
65                Ok(net) => networks.push(net),
66                Err(_) => match part.parse::<IpAddr>() {
67                    Ok(ip) => networks.push(IpNetwork::from(ip)),
68                    Err(_) => return Err(format!("Invalid IP or CIDR: {}", part)),
69                },
70            }
71        }
72
73        debug!("Loaded {} trusted proxy networks", networks.len());
74        Ok(Self::new(networks))
75    }
76
77    /// Checks if an IP is trusted.
78    pub fn is_trusted(&self, ip: &IpAddr) -> bool {
79        self.trusted_networks.iter().any(|net| net.contains(*ip))
80    }
81}
82
83// ============================================================================
84//  2. The Result Struct (What the user gets)
85// ============================================================================
86
/// The resolved real IP address of the client.
///
/// This value is inserted into the request extensions by `RealIpService`.
/// `Hash` is derived so the value can key per-client maps (e.g. rate limiting).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RealIp(pub IpAddr);

/// Unwraps the resolved address, e.g. `let ip: IpAddr = real_ip.into();`.
impl From<RealIp> for IpAddr {
    fn from(real_ip: RealIp) -> Self {
        real_ip.0
    }
}
90
91// ============================================================================
92//  3. Tower Middleware Implementation
93// ============================================================================
94
95#[derive(Clone)]
96pub struct RealIpLayer {
97    config: TrustedProxyConfig,
98}
99
100impl RealIpLayer {
101    pub fn new(config: TrustedProxyConfig) -> Self {
102        Self { config }
103    }
104}
105
106impl<S> Layer<S> for RealIpLayer {
107    type Service = RealIpService<S>;
108
109    fn layer(&self, inner: S) -> Self::Service {
110        RealIpService {
111            inner,
112            config: self.config.clone(),
113        }
114    }
115}
116
117#[derive(Clone)]
118pub struct RealIpService<S> {
119    inner: S,
120    config: TrustedProxyConfig,
121}
122
123impl<S, B> Service<Request<B>> for RealIpService<S>
124where
125    S: Service<Request<B>, Response = Response<B>> + Send + Clone + 'static,
126    S::Future: Send + 'static,
127    B: Send + 'static,
128{
129    type Response = S::Response;
130    type Error = S::Error;
131    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
132
133    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134        self.inner.poll_ready(cx)
135    }
136
137    fn call(&mut self, mut req: Request<B>) -> Self::Future {
138        // 1. Extract the direct connection IP (Peer Address)
139        // Axum/Tower usually provides this via ConnectInfo extension
140        let remote_addr = req
141            .extensions()
142            .get::<ConnectInfo<SocketAddr>>()
143            .map(|ci| ci.0.ip());
144
145        let config = self.config.clone();
146        let headers = req.headers().clone(); // Clone headers to use in async block
147
148        let mut inner = self.inner.clone();
149
150        Box::pin(async move {
151            let mut resolved_ip = remote_addr.unwrap_or_else(|| {
152                // Fallback if no underlying TCP info is present (shouldn't happen in normal HTTP serving)
153                IpAddr::from([0, 0, 0, 0])
154            });
155
156            // 2. The Core Algorithm: Trusted Proxy Traversal
157            if let Some(peer_ip) = remote_addr {
158                // Only attempt to parse headers if the direct peer is trusted
159                if config.is_trusted(&peer_ip)
160                    && let Some(xff_val) = headers.get("x-forwarded-for")
161                    && let Ok(xff_str) = xff_val.to_str()
162                {
163                    // Parse the comma-separated list
164                    // List: Client, Proxy1, Proxy2
165                    // We reverse iterate: Proxy2 -> Proxy1 -> Client
166                    let ips: Vec<&str> = xff_str.split(',').map(|s| s.trim()).collect();
167
168                    for ip_str in ips.iter().rev() {
169                        if let Ok(ip) = IpAddr::from_str(ip_str) {
170                            if !config.is_trusted(&ip) {
171                                // Found the first untrusted IP (looking backwards)
172                                // This is the Client.
173                                resolved_ip = ip;
174                                break;
175                            }
176                            // If trusted, continue strictly to the left
177                        } else {
178                            warn!("Skipping invalid IP in X-Forwarded-For: {}", ip_str);
179                        }
180                    }
181                    // Edge case: If all IPs in header are trusted, the loop finishes.
182                    // The `resolved_ip` remains the last trusted one (or peer),
183                    // but technically if strictly all are trusted, the request originates
184                    // from your internal network. We keep the peer or last logic.
185                }
186            }
187
188            // 3. Inject the result into extensions
189            req.extensions_mut().insert(RealIp(resolved_ip));
190
191            // 4. Forward request
192            inner.call(req).await
193        })
194    }
195}
196
197// ============================================================================
198//  4. Axum Extractor Support
199// ============================================================================
200
201/// Allows using `RealIp` directly in Axum handlers arguments.
202///
203/// Example:
204/// `async fn handler(RealIp(ip): RealIp) -> ...`
205impl<S> FromRequestParts<S> for RealIp
206where
207    S: Send + Sync,
208{
209    type Rejection = (http::StatusCode, &'static str);
210
211    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
212        parts.extensions.get::<RealIp>().cloned().ok_or((
213            http::StatusCode::INTERNAL_SERVER_ERROR,
214            "RealIp middleware is not configured correctly. Missing RealIp extension.",
215        ))
216    }
217}