1#![doc = include_str!("../readme.md")]
2
3use std::collections::HashMap;
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9use axum::http::Request;
10use axum::response::Response;
11use futures_util::future::BoxFuture;
12use maxminddb::geoip2;
13use tower::Service;
14
15pub use maxminddb::Reader;
16
17#[derive(Debug)]
22pub struct Config {
23 tariffs: HashMap<Box<str>, Duration>,
25 reader: Reader<Vec<u8>>,
27}
28
29impl Config {
30 pub fn new(reader: Reader<Vec<u8>>) -> Self {
43 Self {
44 tariffs: Default::default(),
45 reader,
46 }
47 }
48
49 pub fn with(mut self, code: &str, delay: Duration) -> Self {
67 self.tariffs.insert(Box::from(code.to_uppercase()), delay);
68 self
69 }
70
71 pub fn into_layer(self) -> TariffLayer {
91 TariffLayer {
92 config: Arc::new(self),
93 }
94 }
95
96 fn get_delay_for_ip(&self, ip: IpAddr) -> Option<Duration> {
102 self.reader
103 .lookup::<geoip2::Country>(ip)
104 .ok()
105 .flatten()
106 .and_then(|geo| geo.country)
107 .and_then(|country| country.iso_code)
108 .and_then(|code| self.tariffs.get(code.to_uppercase().as_str()))
109 .cloned()
110 }
111}
112
113#[derive(Clone)]
117pub struct TariffLayer {
118 config: Arc<Config>,
119}
120
121impl<S> tower::Layer<S> for TariffLayer {
122 type Service = TariffService<S>;
123
124 fn layer(&self, inner: S) -> Self::Service {
125 TariffService {
126 inner,
127 config: self.config.clone(),
128 }
129 }
130}
131
132#[derive(Clone)]
137pub struct TariffService<S> {
138 inner: S,
139 config: Arc<Config>,
140}
141
142impl<S, B> Service<Request<B>> for TariffService<S>
143where
144 B: Send + 'static,
145 S: Clone,
146 S: Service<Request<B>, Response = Response<B>> + Send + 'static,
147 S::Future: Send + 'static,
148{
149 type Response = S::Response;
150 type Error = S::Error;
151 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
152
153 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154 self.inner.poll_ready(cx)
155 }
156
157 fn call(&mut self, req: Request<B>) -> Self::Future {
158 let mut inner = self.inner.clone();
159 let config = Arc::clone(&self.config);
160 let client_ip = extract_client_ip(&req);
161
162 Box::pin(async move {
163 if let Some(delay) = client_ip.and_then(|ip| config.get_delay_for_ip(ip)) {
164 tokio::time::sleep(delay).await;
165 }
166
167 inner.call(req).await
168 })
169 }
170}
171
172fn extract_client_ip<B>(req: &Request<B>) -> Option<IpAddr> {
176 if let Some(header) = req.headers().get("x-forwarded-for") {
177 if let Ok(ip_str) = header.to_str() {
178 if let Some(ip_str) = ip_str.split(',').next() {
179 return ip_str.trim().parse().ok();
180 }
181 }
182 }
183
184 req.extensions()
185 .get::<axum::extract::connect_info::ConnectInfo<SocketAddr>>()
186 .map(|info| info.0.ip())
187}
188
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::time::Instant;

    use axum::Router;
    use axum::body::Body;
    use axum::extract::connect_info::ConnectInfo;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use tower::ServiceExt;

    use super::*;

    /// Country that `IP_TEST` resolves to in the MaxMind test database.
    const IP_REGION: &str = "GB";
    /// Address that resolves to `IP_REGION` in the MaxMind test database.
    const IP_TEST: &str = "2.125.160.218";

    /// Opens the MaxMind country test database expected under `assets/`.
    fn test_reader() -> Reader<Vec<u8>> {
        Reader::open_readfile("assets/GeoLite2-Country-Test.mmdb")
            .expect("You need the test MaxMind DB at assets/GeoLite2-Country-Test.mmdb")
    }

    #[test]
    fn test_tariff_config_basic_mapping() {
        let config = Config::new(test_reader()).with(IP_REGION, Duration::from_millis(1234));

        let ip: IpAddr = IP_TEST.parse().unwrap();
        let delay = config.get_delay_for_ip(ip);

        assert_eq!(delay, Some(Duration::from_millis(1234)));
    }

    #[test]
    fn test_tariff_code_is_case_insensitive() {
        let config = Config::new(test_reader())
            .with(&IP_REGION.to_lowercase(), Duration::from_millis(42));

        let ip: IpAddr = IP_TEST.parse().unwrap();

        assert_eq!(config.get_delay_for_ip(ip), Some(Duration::from_millis(42)));
    }

    #[test]
    fn test_unmapped_country_has_no_delay() {
        let config = Config::new(test_reader());

        let ip: IpAddr = IP_TEST.parse().unwrap();

        assert_eq!(config.get_delay_for_ip(ip), None);
    }

    #[tokio::test]
    async fn test_middleware_applies_delay() {
        let layer = Config::new(test_reader())
            .with(IP_REGION, Duration::from_millis(200))
            .into_layer();

        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(layer)
            .with_state(());

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        let req = Request::builder()
            .uri("/")
            .header("x-forwarded-for", IP_TEST)
            .extension(ConnectInfo(addr))
            .body(Body::empty())
            .unwrap();

        let start = Instant::now();
        let response = app.oneshot(req).await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(response.status(), StatusCode::OK);
        // tokio sleeps for at least the requested duration; keep some slack regardless.
        assert!(elapsed >= Duration::from_millis(180));
    }

    #[test]
    fn test_extract_ip_header_and_fallback() {
        let req = Request::builder()
            .header("x-forwarded-for", "8.8.8.8")
            .body(())
            .unwrap();

        assert_eq!(
            extract_client_ip(&req),
            Some("8.8.8.8".parse::<IpAddr>().unwrap())
        );

        let mut req = Request::builder().body(()).unwrap();
        let addr: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        req.extensions_mut().insert(ConnectInfo(addr));

        assert_eq!(extract_client_ip(&req), Some(addr.ip()));
    }

    #[test]
    fn test_extract_ip_uses_first_forwarded_entry() {
        let req = Request::builder()
            .header("x-forwarded-for", " 8.8.4.4 , 10.0.0.1")
            .body(())
            .unwrap();

        assert_eq!(
            extract_client_ip(&req),
            Some("8.8.4.4".parse::<IpAddr>().unwrap())
        );
    }
}