1use std::net::IpAddr;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::extract::connect_info::ConnectInfo;
8use http::Request;
9use tower::{Layer, Service};
10
11use super::client_ip::ClientIp;
12use super::extract::extract_client_ip;
13
14pub struct ClientIpLayer {
21 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
22}
23
24impl Clone for ClientIpLayer {
25 fn clone(&self) -> Self {
26 Self {
27 trusted_proxies: self.trusted_proxies.clone(),
28 }
29 }
30}
31
32impl ClientIpLayer {
33 pub fn new() -> Self {
36 Self {
37 trusted_proxies: Arc::new(Vec::new()),
38 }
39 }
40
41 pub fn with_trusted_proxies(proxies: Vec<ipnet::IpNet>) -> Self {
43 Self {
44 trusted_proxies: Arc::new(proxies),
45 }
46 }
47}
48
49impl Default for ClientIpLayer {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl<S> Layer<S> for ClientIpLayer {
56 type Service = ClientIpMiddleware<S>;
57
58 fn layer(&self, inner: S) -> Self::Service {
59 ClientIpMiddleware {
60 inner,
61 trusted_proxies: self.trusted_proxies.clone(),
62 }
63 }
64}
65
66pub struct ClientIpMiddleware<S> {
71 inner: S,
72 trusted_proxies: Arc<Vec<ipnet::IpNet>>,
73}
74
75impl<S: Clone> Clone for ClientIpMiddleware<S> {
76 fn clone(&self) -> Self {
77 Self {
78 inner: self.inner.clone(),
79 trusted_proxies: self.trusted_proxies.clone(),
80 }
81 }
82}
83
84impl<S, ReqBody> Service<Request<ReqBody>> for ClientIpMiddleware<S>
85where
86 S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
87 S::Future: Send + 'static,
88 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
89 ReqBody: Send + 'static,
90{
91 type Response = http::Response<Body>;
92 type Error = S::Error;
93 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
94
95 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96 self.inner.poll_ready(cx)
97 }
98
99 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
100 let trusted_proxies = self.trusted_proxies.clone();
101 let mut inner = self.inner.clone();
102 std::mem::swap(&mut self.inner, &mut inner);
103
104 Box::pin(async move {
105 let connect_ip: Option<IpAddr> = request
106 .extensions()
107 .get::<ConnectInfo<std::net::SocketAddr>>()
108 .map(|ci| ci.0.ip());
109
110 let ip = extract_client_ip(request.headers(), &trusted_proxies, connect_ip);
111 request.extensions_mut().insert(ClientIp(ip));
112
113 inner.call(request).await
114 })
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Inner service that echoes the `ClientIp` extension back as the response
    /// body, or "missing" when the extension is absent.
    async fn echo_ip(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let text = match req.extensions().get::<ClientIp>() {
            Some(client) => client.0.to_string(),
            None => "missing".to_string(),
        };
        Ok(Response::new(Body::from(text)))
    }

    /// Trusted-proxy list used by the proxy-aware tests: 10.0.0.0/8.
    fn ten_slash_eight() -> Vec<ipnet::IpNet> {
        vec!["10.0.0.0/8".parse().unwrap()]
    }

    /// Builds an empty-bodied request carrying an `x-forwarded-for` header.
    fn forwarded_for(value: &str) -> Request<Body> {
        Request::builder()
            .header("x-forwarded-for", value)
            .body(Body::empty())
            .unwrap()
    }

    /// Attaches a `ConnectInfo` peer address, on port 1234, to `req`.
    fn from_peer(mut req: Request<Body>, octets: [u8; 4]) -> Request<Body> {
        let peer = std::net::SocketAddr::from((octets, 1234));
        req.extensions_mut().insert(ConnectInfo(peer));
        req
    }

    /// Sends `req` through `layer` wrapped around `echo_ip`.
    ///
    /// Returns the response status and the raw body bytes.
    async fn run(layer: ClientIpLayer, req: Request<Body>) -> (StatusCode, Vec<u8>) {
        let response = layer
            .layer(tower::service_fn(echo_ip))
            .oneshot(req)
            .await
            .unwrap();
        let status = response.status();
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        (status, bytes.to_vec())
    }

    #[tokio::test]
    async fn inserts_client_ip_from_xff() {
        let (status, body) = run(ClientIpLayer::new(), forwarded_for("8.8.8.8")).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body.as_slice(), b"8.8.8.8");
    }

    #[tokio::test]
    async fn falls_back_to_localhost_when_no_info() {
        let bare = Request::builder().body(Body::empty()).unwrap();
        let (_, body) = run(ClientIpLayer::new(), bare).await;
        assert_eq!(body.as_slice(), b"127.0.0.1");
    }

    #[tokio::test]
    async fn respects_trusted_proxies() {
        let layer = ClientIpLayer::with_trusted_proxies(ten_slash_eight());
        let req = from_peer(forwarded_for("1.2.3.4"), [10, 0, 0, 1]);
        let (_, body) = run(layer, req).await;
        assert_eq!(body.as_slice(), b"1.2.3.4");
    }

    #[tokio::test]
    async fn untrusted_source_ignores_xff() {
        let layer = ClientIpLayer::with_trusted_proxies(ten_slash_eight());
        let req = from_peer(forwarded_for("1.2.3.4"), [203, 0, 113, 5]);
        let (_, body) = run(layer, req).await;
        assert_eq!(body.as_slice(), b"203.0.113.5");
    }
}