Skip to main content

modo/geolocation/
middleware.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use axum::body::Body;
5use http::Request;
6use tower::{Layer, Service};
7
8use crate::ip::ClientIp;
9
10use super::locator::GeoLocator;
11
12/// Tower layer that performs geolocation lookup and inserts
13/// [`Location`](super::Location) into request extensions.
14///
15/// Apply this layer after [`ClientIpLayer`](crate::ip::ClientIpLayer) so that
16/// [`ClientIp`] is already present in extensions when `GeoLayer` runs.
17/// If `ClientIp` is absent the request passes through without modification.
18pub struct GeoLayer {
19    locator: GeoLocator,
20}
21
22impl Clone for GeoLayer {
23    fn clone(&self) -> Self {
24        Self {
25            locator: self.locator.clone(),
26        }
27    }
28}
29
30impl GeoLayer {
31    /// Create a new `GeoLayer` backed by `locator`.
32    pub fn new(locator: GeoLocator) -> Self {
33        Self { locator }
34    }
35}
36
37impl<S> Layer<S> for GeoLayer {
38    type Service = GeoMiddleware<S>;
39
40    fn layer(&self, inner: S) -> Self::Service {
41        GeoMiddleware {
42            inner,
43            locator: self.locator.clone(),
44        }
45    }
46}
47
48/// Tower service produced by [`GeoLayer`].
49pub struct GeoMiddleware<S> {
50    inner: S,
51    locator: GeoLocator,
52}
53
54impl<S: Clone> Clone for GeoMiddleware<S> {
55    fn clone(&self) -> Self {
56        Self {
57            inner: self.inner.clone(),
58            locator: self.locator.clone(),
59        }
60    }
61}
62
63impl<S, ReqBody> Service<Request<ReqBody>> for GeoMiddleware<S>
64where
65    S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
66    S::Future: Send + 'static,
67    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
68    ReqBody: Send + 'static,
69{
70    type Response = http::Response<Body>;
71    type Error = S::Error;
72    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
73
74    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75        self.inner.poll_ready(cx)
76    }
77
78    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
79        let locator = self.locator.clone();
80        let mut inner = self.inner.clone();
81        std::mem::swap(&mut self.inner, &mut inner);
82
83        Box::pin(async move {
84            if let Some(client_ip) = request.extensions().get::<ClientIp>().copied() {
85                match locator.lookup(client_ip.0) {
86                    Ok(location) => {
87                        request.extensions_mut().insert(location);
88                    }
89                    Err(e) => {
90                        tracing::warn!(
91                            ip = %client_ip.0,
92                            error = %e,
93                            "geolocation lookup failed"
94                        );
95                    }
96                }
97            }
98
99            inner.call(request).await
100        })
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::geolocation::{GeolocationConfig, Location};
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// MaxMind test database shared by every test in this module.
    const TEST_MMDB: &str = "tests/fixtures/GeoIP2-City-Test.mmdb";

    /// Opens a locator over the bundled test database (panics if it cannot).
    fn test_locator() -> GeoLocator {
        let config = GeolocationConfig {
            mmdb_path: TEST_MMDB.to_string(),
        };
        GeoLocator::from_config(&config).unwrap()
    }

    /// Builds an empty-bodied request, tagged with a `ClientIp` when `ip` is given.
    fn request_with_ip(ip: Option<&str>) -> Request<Body> {
        let mut req = Request::builder().body(Body::empty()).unwrap();
        if let Some(text) = ip {
            let addr: std::net::IpAddr = text.parse().unwrap();
            req.extensions_mut().insert(ClientIp(addr));
        }
        req
    }

    /// Drains a response body into an owned byte vector.
    async fn body_bytes(resp: Response<Body>) -> Vec<u8> {
        axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap()
            .to_vec()
    }

    /// Handler that echoes whether a `Location` extension reached it.
    async fn check_location(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let body = match req.extensions().get::<Location>() {
            Some(_) => "has-location",
            None => "no-location",
        };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn inserts_location_when_client_ip_present() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        let resp = svc
            .oneshot(request_with_ip(Some("81.2.69.142")))
            .await
            .unwrap();
        assert_eq!(body_bytes(resp).await, b"has-location");
    }

    #[tokio::test]
    async fn passes_through_when_no_client_ip() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        let resp = svc.oneshot(request_with_ip(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_bytes(resp).await, b"no-location");
    }

    #[tokio::test]
    async fn private_ip_inserts_default_location() {
        let handler = tower::service_fn(|req: Request<Body>| async move {
            let location = req.extensions().get::<Location>().unwrap();
            let body = match location.country_code {
                Some(_) => "has-data",
                None => "empty",
            };
            Ok::<_, Infallible>(Response::new(Body::from(body)))
        });
        let svc = GeoLayer::new(test_locator()).layer(handler);

        let resp = svc
            .oneshot(request_with_ip(Some("10.0.0.1")))
            .await
            .unwrap();
        assert_eq!(body_bytes(resp).await, b"empty");
    }
}