modo/geolocation/
middleware.rs1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use axum::body::Body;
5use http::Request;
6use tower::{Layer, Service};
7
8use crate::ip::ClientIp;
9
10use super::locator::GeoLocator;
11
12pub struct GeoLayer {
19 locator: GeoLocator,
20}
21
22impl Clone for GeoLayer {
23 fn clone(&self) -> Self {
24 Self {
25 locator: self.locator.clone(),
26 }
27 }
28}
29
30impl GeoLayer {
31 pub fn new(locator: GeoLocator) -> Self {
36 Self { locator }
37 }
38}
39
40impl<S> Layer<S> for GeoLayer {
41 type Service = GeoMiddleware<S>;
42
43 fn layer(&self, inner: S) -> Self::Service {
44 GeoMiddleware {
45 inner,
46 locator: self.locator.clone(),
47 }
48 }
49}
50
51pub struct GeoMiddleware<S> {
53 inner: S,
54 locator: GeoLocator,
55}
56
57impl<S: Clone> Clone for GeoMiddleware<S> {
58 fn clone(&self) -> Self {
59 Self {
60 inner: self.inner.clone(),
61 locator: self.locator.clone(),
62 }
63 }
64}
65
66impl<S, ReqBody> Service<Request<ReqBody>> for GeoMiddleware<S>
67where
68 S: Service<Request<ReqBody>, Response = http::Response<Body>> + Clone + Send + 'static,
69 S::Future: Send + 'static,
70 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
71 ReqBody: Send + 'static,
72{
73 type Response = http::Response<Body>;
74 type Error = S::Error;
75 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
76
77 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78 self.inner.poll_ready(cx)
79 }
80
81 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
82 let locator = self.locator.clone();
83 let mut inner = self.inner.clone();
84 std::mem::swap(&mut self.inner, &mut inner);
85
86 Box::pin(async move {
87 if let Some(client_ip) = request.extensions().get::<ClientIp>().copied() {
88 match locator.lookup(client_ip.0) {
89 Ok(location) => {
90 request.extensions_mut().insert(location);
91 }
92 Err(e) => {
93 tracing::warn!(
94 ip = %client_ip.0,
95 error = %e,
96 "geolocation lookup failed"
97 );
98 }
99 }
100 }
101
102 inner.call(request).await
103 })
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::geolocation::{GeolocationConfig, Location};
    use axum::body::Body;
    use http::{Request, Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Builds a locator from the `GeoIP2-City-Test.mmdb` fixture.
    fn test_locator() -> GeoLocator {
        let config = GeolocationConfig {
            mmdb_path: "tests/fixtures/GeoIP2-City-Test.mmdb".to_string(),
        };
        GeoLocator::from_config(&config).unwrap()
    }

    /// Handler that reports whether a `Location` extension reached it.
    async fn check_location(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let marker = match req.extensions().get::<Location>() {
            Some(_) => "has-location",
            None => "no-location",
        };
        Ok(Response::new(Body::from(marker)))
    }

    /// Builds an empty-bodied request, optionally tagged with a `ClientIp`
    /// extension parsed from `client_ip`.
    fn request_with_ip(client_ip: Option<&str>) -> Request<Body> {
        let mut req = Request::builder().body(Body::empty()).unwrap();
        if let Some(raw) = client_ip {
            let ip: std::net::IpAddr = raw.parse().unwrap();
            req.extensions_mut().insert(ClientIp(ip));
        }
        req
    }

    /// Collects the full response body into a byte vector.
    async fn read_body(resp: Response<Body>) -> Vec<u8> {
        axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap()
            .to_vec()
    }

    #[tokio::test]
    async fn inserts_location_when_client_ip_present() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        let resp = svc
            .oneshot(request_with_ip(Some("81.2.69.142")))
            .await
            .unwrap();

        assert_eq!(read_body(resp).await, b"has-location");
    }

    #[tokio::test]
    async fn passes_through_when_no_client_ip() {
        let svc = GeoLayer::new(test_locator()).layer(tower::service_fn(check_location));

        let resp = svc.oneshot(request_with_ip(None)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        assert_eq!(read_body(resp).await, b"no-location");
    }

    #[tokio::test]
    async fn private_ip_inserts_default_location() {
        // The handler unwraps the extension, so this also asserts that a
        // `Location` is present for the private address.
        let handler = tower::service_fn(|req: Request<Body>| async move {
            let location = req.extensions().get::<Location>().cloned().unwrap();
            let body = if location.country_code.is_some() {
                "has-data"
            } else {
                "empty"
            };
            Ok::<_, Infallible>(Response::new(Body::from(body)))
        });
        let svc = GeoLayer::new(test_locator()).layer(handler);

        let resp = svc
            .oneshot(request_with_ip(Some("10.0.0.1")))
            .await
            .unwrap();

        assert_eq!(read_body(resp).await, b"empty");
    }
}