modo/geolocation/
middleware.rs1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use axum::body::Body;
5use http::Request;
6use tower::{Layer, Service};
7
8use crate::ip::ClientIp;
9
10use super::locator::GeoLocator;
11
12pub struct GeoLayer {
19 locator: GeoLocator,
20}
21
22impl Clone for GeoLayer {
23 fn clone(&self) -> Self {
24 Self {
25 locator: self.locator.clone(),
26 }
27 }
28}
29
30impl GeoLayer {
31 pub fn new(locator: GeoLocator) -> Self {
33 Self { locator }
34 }
35}
36
37impl<S> Layer<S> for GeoLayer {
38 type Service = GeoMiddleware<S>;
39
40 fn layer(&self, inner: S) -> Self::Service {
41 GeoMiddleware {
42 inner,
43 locator: self.locator.clone(),
44 }
45 }
46}
47
48pub struct GeoMiddleware<S> {
50 inner: S,
51 locator: GeoLocator,
52}
53
54impl<S: Clone> Clone for GeoMiddleware<S> {
55 fn clone(&self) -> Self {
56 Self {
57 inner: self.inner.clone(),
58 locator: self.locator.clone(),
59 }
60 }
61}
62
63impl<S, ReqBody> Service<Request<ReqBody>> for GeoMiddleware<S>
64where
65 S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
66 S::Future: Send + 'static,
67 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
68 ReqBody: Send + 'static,
69{
70 type Response = http::Response<Body>;
71 type Error = S::Error;
72 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
73
74 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
75 self.inner.poll_ready(cx)
76 }
77
78 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
79 let locator = self.locator.clone();
80 let mut inner = self.inner.clone();
81 std::mem::swap(&mut self.inner, &mut inner);
82
83 Box::pin(async move {
84 if let Some(client_ip) = request.extensions().get::<ClientIp>().copied() {
85 match locator.lookup(client_ip.0) {
86 Ok(location) => {
87 request.extensions_mut().insert(location);
88 }
89 Err(e) => {
90 tracing::warn!(
91 ip = %client_ip.0,
92 error = %e,
93 "geolocation lookup failed"
94 );
95 }
96 }
97 }
98
99 inner.call(request).await
100 })
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::geolocation::{GeolocationConfig, Location};
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Builds a locator from the `.mmdb` fixture under `tests/fixtures`.
    fn test_locator() -> GeoLocator {
        let config = GeolocationConfig {
            mmdb_path: "tests/fixtures/GeoIP2-City-Test.mmdb".to_string(),
        };
        GeoLocator::from_config(&config).unwrap()
    }

    /// Handler that reports in the response body whether the request carries
    /// a `Location` extension.
    async fn check_location(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let body = match req.extensions().get::<Location>() {
            Some(_) => "has-location",
            None => "no-location",
        };
        Ok(Response::new(Body::from(body)))
    }

    /// Builds an empty-bodied request tagged with `ip` as its `ClientIp`.
    fn request_from(ip: &str) -> Request<Body> {
        let addr: std::net::IpAddr = ip.parse().unwrap();
        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(ClientIp(addr));
        req
    }

    /// Collects the full response body into a byte vector.
    async fn read_body(resp: Response<Body>) -> Vec<u8> {
        axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap()
            .to_vec()
    }

    #[tokio::test]
    async fn inserts_location_when_client_ip_present() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        let resp = svc.oneshot(request_from("81.2.69.142")).await.unwrap();

        assert_eq!(read_body(resp).await, b"has-location");
    }

    #[tokio::test]
    async fn passes_through_when_no_client_ip() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        // No `ClientIp` extension is attached to this request.
        let bare = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(bare).await.unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(read_body(resp).await, b"no-location");
    }

    #[tokio::test]
    async fn private_ip_inserts_default_location() {
        // The handler unwraps the `Location`, so this test also requires that
        // a location is inserted for the private address at all.
        let handler = |req: Request<Body>| async move {
            let loc = req.extensions().get::<Location>().cloned().unwrap();
            let body = if loc.country_code.is_some() {
                "has-data"
            } else {
                "empty"
            };
            Ok::<_, Infallible>(Response::new(Body::from(body)))
        };
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(handler));

        let resp = svc.oneshot(request_from("10.0.0.1")).await.unwrap();

        assert_eq!(read_body(resp).await, b"empty");
    }
}