Skip to main content

modo/ip/
middleware.rs

1use std::net::IpAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::extract::connect_info::ConnectInfo;
8use http::Request;
9use tower::{Layer, Service};
10
11use super::client_ip::ClientIp;
12use super::extract::extract_client_ip;
13
14/// Tower layer that extracts the client IP address and inserts
15/// [`ClientIp`] into request extensions.
16///
17/// Apply with `Router::layer()`. When trusted proxies are configured,
18/// `X-Forwarded-For` and `X-Real-IP` headers are only honoured for
19/// connections originating from a trusted CIDR range.
20pub struct ClientIpLayer {
21    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
22}
23
24impl Clone for ClientIpLayer {
25    fn clone(&self) -> Self {
26        Self {
27            trusted_proxies: self.trusted_proxies.clone(),
28        }
29    }
30}
31
32impl ClientIpLayer {
33    /// Create a layer with no trusted proxies.
34    /// Headers are trusted unconditionally; `ConnectInfo` is the final fallback.
35    pub fn new() -> Self {
36        Self {
37            trusted_proxies: Arc::new(Vec::new()),
38        }
39    }
40
41    /// Create a layer with pre-parsed trusted proxy CIDR ranges.
42    pub fn with_trusted_proxies(proxies: Vec<ipnet::IpNet>) -> Self {
43        Self {
44            trusted_proxies: Arc::new(proxies),
45        }
46    }
47}
48
49impl Default for ClientIpLayer {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl<S> Layer<S> for ClientIpLayer {
56    type Service = ClientIpMiddleware<S>;
57
58    fn layer(&self, inner: S) -> Self::Service {
59        ClientIpMiddleware {
60            inner,
61            trusted_proxies: self.trusted_proxies.clone(),
62        }
63    }
64}
65
66/// Tower service produced by [`ClientIpLayer`].
67///
68/// Resolves the client IP on every request and inserts it as a [`ClientIp`]
69/// extension before delegating to the inner service.
70pub struct ClientIpMiddleware<S> {
71    inner: S,
72    trusted_proxies: Arc<Vec<ipnet::IpNet>>,
73}
74
75impl<S: Clone> Clone for ClientIpMiddleware<S> {
76    fn clone(&self) -> Self {
77        Self {
78            inner: self.inner.clone(),
79            trusted_proxies: self.trusted_proxies.clone(),
80        }
81    }
82}
83
84impl<S, ReqBody> Service<Request<ReqBody>> for ClientIpMiddleware<S>
85where
86    S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
87    S::Future: Send + 'static,
88    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
89    ReqBody: Send + 'static,
90{
91    type Response = http::Response<Body>;
92    type Error = S::Error;
93    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
94
95    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96        self.inner.poll_ready(cx)
97    }
98
99    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
100        let trusted_proxies = self.trusted_proxies.clone();
101        let mut inner = self.inner.clone();
102        std::mem::swap(&mut self.inner, &mut inner);
103
104        Box::pin(async move {
105            let connect_ip: Option<IpAddr> = request
106                .extensions()
107                .get::<ConnectInfo<std::net::SocketAddr>>()
108                .map(|ci| ci.0.ip());
109
110            let ip = extract_client_ip(request.headers(), &trusted_proxies, connect_ip);
111            request.extensions_mut().insert(ClientIp(ip));
112
113            inner.call(request).await
114        })
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Inner service that echoes the resolved `ClientIp` (or "missing") as the body.
    async fn echo_ip(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let ip = req
            .extensions()
            .get::<ClientIp>()
            .map(|c| c.0.to_string())
            .unwrap_or_else(|| "missing".to_string());
        Ok(Response::new(Body::from(ip)))
    }

    /// Collects a response body into a UTF-8 string.
    async fn body_text(resp: Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    /// Builds the `ConnectInfo` extension axum attaches for a peer at `octets:1234`.
    fn peer(octets: [u8; 4]) -> ConnectInfo<std::net::SocketAddr> {
        ConnectInfo(std::net::SocketAddr::from((octets, 1234)))
    }

    /// Builds a request carrying a single `X-Forwarded-For` value.
    fn xff_request(ip: &str) -> Request<Body> {
        Request::builder()
            .header("x-forwarded-for", ip)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn inserts_client_ip_from_xff() {
        let svc = ClientIpLayer::new().layer(tower::service_fn(echo_ip));

        let resp = svc.oneshot(xff_request("8.8.8.8")).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_text(resp).await, "8.8.8.8");
    }

    #[tokio::test]
    async fn falls_back_to_localhost_when_no_info() {
        let svc = ClientIpLayer::new().layer(tower::service_fn(echo_ip));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "127.0.0.1");
    }

    #[tokio::test]
    async fn falls_back_to_connect_info_without_headers() {
        let svc = ClientIpLayer::new().layer(tower::service_fn(echo_ip));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(peer([203, 0, 113, 5]));

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "203.0.113.5");
    }

    #[tokio::test]
    async fn respects_trusted_proxies() {
        let trusted: Vec<ipnet::IpNet> = vec!["10.0.0.0/8".parse().unwrap()];
        let svc =
            ClientIpLayer::with_trusted_proxies(trusted).layer(tower::service_fn(echo_ip));

        let mut req = xff_request("1.2.3.4");
        req.extensions_mut().insert(peer([10, 0, 0, 1]));

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "1.2.3.4");
    }

    #[tokio::test]
    async fn untrusted_source_ignores_xff() {
        let trusted: Vec<ipnet::IpNet> = vec!["10.0.0.0/8".parse().unwrap()];
        let svc =
            ClientIpLayer::with_trusted_proxies(trusted).layer(tower::service_fn(echo_ip));

        let mut req = xff_request("1.2.3.4");
        req.extensions_mut().insert(peer([203, 0, 113, 5]));

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "203.0.113.5");
    }

    #[tokio::test]
    async fn service_is_reusable_across_requests() {
        // One middleware instance must keep working across repeated
        // `poll_ready` / `call` cycles, not just a single `oneshot`.
        let mut svc = ClientIpLayer::new().layer(tower::service_fn(echo_ip));

        for expected in ["8.8.8.8", "1.1.1.1"] {
            let resp = ServiceExt::<Request<Body>>::ready(&mut svc)
                .await
                .unwrap()
                .call(xff_request(expected))
                .await
                .unwrap();
            assert_eq!(body_text(resp).await, expected);
        }
    }
}