by_loco/controller/middleware/
remote_ip.rs1use std::{
12 fmt,
13 iter::Iterator,
14 net::{IpAddr, SocketAddr},
15 str::FromStr,
16 sync::OnceLock,
17 task::{Context, Poll},
18};
19
20use axum::{
21 body::Body,
22 extract::{ConnectInfo, FromRequestParts, Request},
23 http::{header::HeaderMap, request::Parts},
24 response::Response,
25 Router as AXRouter,
26};
27use futures_util::future::BoxFuture;
28use ipnetwork::IpNetwork;
29use serde::{Deserialize, Serialize};
30use tower::{Layer, Service};
31use tracing::error;
32
33use crate::{app::AppContext, controller::middleware::MiddlewareLayer, Error, Result};
34
35static LOCAL_TRUSTED_PROXIES: OnceLock<Vec<IpNetwork>> = OnceLock::new();
36
37fn get_local_trusted_proxies() -> &'static Vec<IpNetwork> {
38 LOCAL_TRUSTED_PROXIES.get_or_init(|| {
39 [
40 "127.0.0.0/8", "::1", "fc00::/7", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16",
46 ]
47 .iter()
48 .map(|ip| IpNetwork::from_str(ip).unwrap())
49 .collect()
50 })
51}
52
/// Name of the header holding the comma-separated chain of client and proxy
/// addresses; it is looked up via `HeaderMap::get_all`, which matches names
/// case-insensitively.
const X_FORWARDED_FOR: &str = "X-Forwarded-For";
54
55#[derive(Default, Serialize, Deserialize, Debug, Clone)]
96pub struct RemoteIpMiddleware {
97 #[serde(default)]
98 pub enable: bool,
99 pub trusted_proxies: Option<Vec<String>>,
102}
103
104impl MiddlewareLayer for RemoteIpMiddleware {
105 fn name(&self) -> &'static str {
107 "remote_ip"
108 }
109
110 fn is_enabled(&self) -> bool {
112 self.enable
113 && (self.trusted_proxies.is_none()
114 || self.trusted_proxies.as_ref().is_some_and(|t| !t.is_empty()))
115 }
116
117 fn config(&self) -> serde_json::Result<serde_json::Value> {
118 serde_json::to_value(self)
119 }
120
121 fn apply(&self, app: AXRouter<AppContext>) -> Result<AXRouter<AppContext>> {
123 Ok(app.layer(RemoteIPLayer::new(self)?))
124 }
125}
126
127fn maybe_get_forwarded(
129 headers: &HeaderMap,
130 trusted_proxies: Option<&Vec<IpNetwork>>,
131) -> Option<IpAddr> {
132 let xffs = headers
140 .get_all(X_FORWARDED_FOR)
141 .iter()
142 .map(|hdr| hdr.to_str())
143 .filter_map(Result::ok)
144 .collect::<Vec<_>>();
145
146 if xffs.is_empty() {
147 return None;
148 }
149
150 let forwarded = xffs.join(",");
151
152 forwarded
153 .split(',')
154 .map(str::trim)
155 .map(str::parse)
156 .filter_map(Result::ok)
157 .filter(|ip| {
163 let proxies = trusted_proxies.unwrap_or_else(|| get_local_trusted_proxies());
165 !proxies
166 .iter()
167 .any(|trusted_proxy| trusted_proxy.contains(*ip))
168 })
169 .next_back()
179}
180
/// Client address resolved by the remote IP middleware and stored in the
/// request extensions.
///
/// Used as an extractor, it yields [`RemoteIP::None`] when the middleware did
/// not run for the request. It is `Eq`/`Hash` so handlers can compare it or use
/// it as a key (for example, for per-client rate limiting).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RemoteIP {
    /// Right-most `X-Forwarded-For` entry that is not a trusted proxy.
    Forwarded(IpAddr),
    /// Peer address of the connection, taken from axum's `ConnectInfo`.
    Socket(IpAddr),
    /// No address could be determined.
    None,
}
187
188impl<S> FromRequestParts<S> for RemoteIP
189where
190 S: Send + Sync,
191{
192 type Rejection = ();
193
194 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
195 let ip = parts.extensions.get::<Self>();
196 Ok(*ip.unwrap_or(&Self::None))
197 }
198}
199
200impl fmt::Display for RemoteIP {
201 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202 match self {
203 Self::Forwarded(ip) => write!(f, "remote: {ip}"),
204 Self::Socket(ip) => write!(f, "socket: {ip}"),
205 Self::None => write!(f, "--"),
206 }
207 }
208}
209
210#[derive(Clone, Debug)]
211struct RemoteIPLayer {
212 trusted_proxies: Option<Vec<IpNetwork>>,
213}
214
215impl RemoteIPLayer {
216 pub fn new(config: &RemoteIpMiddleware) -> Result<Self> {
221 Ok(Self {
222 trusted_proxies: config
223 .trusted_proxies
224 .as_ref()
225 .map(|proxies| {
226 proxies
227 .iter()
228 .map(|proxy| {
229 IpNetwork::from_str(proxy).map_err(|err| {
230 Error::Message(format!(
231 "remote ip middleare cannot parse trusted proxy \
232 configuration: `{proxy}`, reason: `{err}`",
233 ))
234 })
235 })
236 .collect::<Result<Vec<_>>>()
237 })
238 .transpose()?,
239 })
240 }
241}
242
243impl<S> Layer<S> for RemoteIPLayer {
244 type Service = RemoteIPMiddleware<S>;
245
246 fn layer(&self, inner: S) -> Self::Service {
247 RemoteIPMiddleware {
248 inner,
249 layer: self.clone(),
250 }
251 }
252}
253
254#[derive(Clone, Debug)]
256#[must_use]
257pub struct RemoteIPMiddleware<S> {
258 inner: S,
259 layer: RemoteIPLayer,
260}
261
262impl<S> Service<Request<Body>> for RemoteIPMiddleware<S>
263where
264 S: Service<Request<Body>, Response = Response> + Send + 'static,
265 S::Future: Send + 'static,
266{
267 type Response = S::Response;
268 type Error = S::Error;
269 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
270
271 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272 self.inner.poll_ready(cx)
273 }
274
275 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
276 let layer = self.layer.clone();
277 let xff_ip = maybe_get_forwarded(req.headers(), layer.trusted_proxies.as_ref());
278 let remote_ip = xff_ip.map_or_else(
279 || {
280 let ip = req
281 .extensions()
282 .get::<ConnectInfo<SocketAddr>>()
283 .map_or_else(
284 || {
285 error!(
286 "remote ip middleware cannot get socket IP (not set in axum \
287 extensions): setting IP to `127.0.0.1`"
288 );
289 RemoteIP::None
290 },
291 |info| RemoteIP::Socket(info.ip()),
292 );
293 ip
294 },
295 RemoteIP::Forwarded,
296 );
297
298 req.extensions_mut().insert(remote_ip);
299
300 Box::pin(self.inner.call(req))
301 }
302}
303
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use axum::http::{HeaderMap, HeaderName, HeaderValue};
    use insta::assert_debug_snapshot;
    use ipnetwork::IpNetwork;

    use super::maybe_get_forwarded;

    /// Builds a header map containing a single `x-forwarded-for` value.
    fn xff(val: &str) -> HeaderMap {
        let name = HeaderName::from_static("x-forwarded-for");
        let value = HeaderValue::from_str(val).unwrap();
        let mut headers = HeaderMap::new();
        headers.insert(name, value);
        headers
    }

    /// Parses a network literal used as a custom trusted proxy list below.
    fn network(cidr: &str) -> IpNetwork {
        IpNetwork::from_str(cidr).unwrap()
    }

    #[test]
    pub fn test_parsing() {
        // Each case is (header value, custom trusted proxies). `None` uses the
        // built-in local ranges; `Some` replaces them entirely.
        //
        // insta names this test's snapshots by call order, so keep the
        // existing cases in this order and append new ones at the end.
        let cases: [(&str, Option<Vec<IpNetwork>>); 8] = [
            ("", None),
            ("foobar", None),
            ("192.1.1.1", None),
            ("51.50.51.50,10.0.0.1,192.168.1.1", None),
            ("19.84.19.84,192.168.0.1", None),
            ("b51.50.51.50b,/10.0.0.1-,192.168.1.1", None),
            ("51.50.51.50,192.1.1.1", Some(vec![network("192.1.1.1/8")])),
            // A custom list that does not cover 192.168.1.1 (the built-in
            // private ranges are not consulted when a list is given).
            ("51.50.51.50,192.168.1.1", Some(vec![network("192.1.1.1/16")])),
        ];

        for (header, trusted_proxies) in &cases {
            let res = maybe_get_forwarded(&xff(header), trusted_proxies.as_ref());
            assert_debug_snapshot!(res);
        }
    }
}