Skip to main content

modo/geolocation/
middleware.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use axum::body::Body;
5use http::Request;
6use tower::{Layer, Service};
7
8use crate::ip::ClientIp;
9
10use super::locator::GeoLocator;
11
12/// Tower layer that performs geolocation lookup and inserts
13/// [`Location`](super::Location) into request extensions.
14///
15/// Apply this layer after [`ClientIpLayer`](crate::ip::ClientIpLayer) so that
16/// [`ClientIp`] is already present in extensions when `GeoLayer` runs.
17/// If `ClientIp` is absent the request passes through without modification.
18pub struct GeoLayer {
19    locator: GeoLocator,
20}
21
22impl Clone for GeoLayer {
23    fn clone(&self) -> Self {
24        Self {
25            locator: self.locator.clone(),
26        }
27    }
28}
29
30impl GeoLayer {
31    /// Create a new `GeoLayer` backed by `locator`.
32    ///
33    /// The layer clones `locator` per-request; cloning is cheap because
34    /// [`GeoLocator`] holds the mmdb reader behind an `Arc`.
35    pub fn new(locator: GeoLocator) -> Self {
36        Self { locator }
37    }
38}
39
40impl<S> Layer<S> for GeoLayer {
41    type Service = GeoMiddleware<S>;
42
43    fn layer(&self, inner: S) -> Self::Service {
44        GeoMiddleware {
45            inner,
46            locator: self.locator.clone(),
47        }
48    }
49}
50
51/// Tower service produced by [`GeoLayer`].
52pub struct GeoMiddleware<S> {
53    inner: S,
54    locator: GeoLocator,
55}
56
57impl<S: Clone> Clone for GeoMiddleware<S> {
58    fn clone(&self) -> Self {
59        Self {
60            inner: self.inner.clone(),
61            locator: self.locator.clone(),
62        }
63    }
64}
65
66impl<S, ReqBody> Service<Request<ReqBody>> for GeoMiddleware<S>
67where
68    S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
69    S::Future: Send + 'static,
70    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
71    ReqBody: Send + 'static,
72{
73    type Response = http::Response<Body>;
74    type Error = S::Error;
75    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
76
77    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78        self.inner.poll_ready(cx)
79    }
80
81    fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
82        let locator = self.locator.clone();
83        let mut inner = self.inner.clone();
84        std::mem::swap(&mut self.inner, &mut inner);
85
86        Box::pin(async move {
87            if let Some(client_ip) = request.extensions().get::<ClientIp>().copied() {
88                match locator.lookup(client_ip.0) {
89                    Ok(location) => {
90                        request.extensions_mut().insert(location);
91                    }
92                    Err(e) => {
93                        tracing::warn!(
94                            ip = %client_ip.0,
95                            error = %e,
96                            "geolocation lookup failed"
97                        );
98                    }
99                }
100            }
101
102            inner.call(request).await
103        })
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::geolocation::{GeolocationConfig, Location};
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Locator backed by the MaxMind test fixture database.
    fn test_locator() -> GeoLocator {
        let config = GeolocationConfig {
            mmdb_path: "tests/fixtures/GeoIP2-City-Test.mmdb".to_string(),
        };
        GeoLocator::from_config(&config).unwrap()
    }

    /// Empty request tagged with a `ClientIp` extension for `ip`.
    fn request_from(ip: &str) -> Request<Body> {
        let addr: std::net::IpAddr = ip.parse().unwrap();
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(ClientIp(addr));
        req
    }

    /// Collects the whole response body into a byte vector.
    async fn read_body(resp: Response<Body>) -> Vec<u8> {
        axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap()
            .to_vec()
    }

    /// Inner handler that reports whether a `Location` extension was attached.
    async fn check_location(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let body = match req.extensions().get::<Location>() {
            Some(_) => "has-location",
            None => "no-location",
        };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn inserts_location_when_client_ip_present() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        let resp = svc.oneshot(request_from("81.2.69.142")).await.unwrap();
        assert_eq!(read_body(resp).await, b"has-location");
    }

    #[tokio::test]
    async fn passes_through_when_no_client_ip() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));
        let req = Request::builder().body(Body::empty()).unwrap();

        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(read_body(resp).await, b"no-location");
    }

    #[tokio::test]
    async fn private_ip_inserts_default_location() {
        let handler = |req: Request<Body>| async move {
            let location = req.extensions().get::<Location>().cloned().unwrap();
            let body = if location.country_code.is_some() {
                "has-data"
            } else {
                "empty"
            };
            Ok::<_, Infallible>(Response::new(Body::from(body)))
        };
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(handler));

        let resp = svc.oneshot(request_from("10.0.0.1")).await.unwrap();
        assert_eq!(read_body(resp).await, b"empty");
    }
}