axum_tariff/
lib.rs

1#![doc = include_str!("../readme.md")]
2
3use std::collections::HashMap;
4use std::net::{IpAddr, SocketAddr};
5use std::sync::Arc;
6use std::task::{Context, Poll};
7use std::time::Duration;
8
9use axum::http::Request;
10use axum::response::Response;
11use futures_util::future::BoxFuture;
12use maxminddb::geoip2;
13use tower::Service;
14
15pub use maxminddb::Reader;
16
17/// Configuration for applying request delays (tariffs) based on IP country.
18///
19/// This struct maps ISO country codes to delay durations,
20/// and uses a MaxMind DB to determine the country for a given IP address.
21#[derive(Debug)]
22pub struct Config {
23    // Mapping of ISO country codes (e.g., "US", "FR") to delay durations
24    tariffs: HashMap<Box<str>, Duration>,
25    // MaxMind database reader used to look up IP address locations
26    reader: Reader<Vec<u8>>,
27}
28
29impl Config {
30    /// Create a new `Config` with an empty tariff map and a provided MaxMind DB reader.
31    ///
32    /// # Arguments
33    ///
34    /// * `reader` - A MaxMind DB reader, e.g., from GeoLite2-Country.mmdb
35    ///
36    /// # Example
37    ///
38    /// ```
39    /// let reader = axum_tariff::Reader::open_readfile("assets/GeoLite2-Country-Test.mmdb").unwrap();
40    /// let config = axum_tariff::Config::new(reader);
41    /// ```
42    pub fn new(reader: Reader<Vec<u8>>) -> Self {
43        Self {
44            tariffs: Default::default(),
45            reader,
46        }
47    }
48
49    /// Add a country code and associated delay to the tariff configuration.
50    ///
51    /// This uses the ISO alpha-2 country code (e.g., "US", "DE", "IN").
52    ///
53    /// # Arguments
54    ///
55    /// * `code` - A 2-letter ISO country code.
56    /// * `delay` - A duration representing how long to delay requests from that country.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// let reader = axum_tariff::Reader::open_readfile("assets/GeoLite2-Country-Test.mmdb").unwrap();
62    /// let config = axum_tariff::Config::new(reader)
63    ///     .with("US", tokio::time::Duration::from_secs(2))  // Delay US traffic by 2 seconds
64    ///     .with("CN", tokio::time::Duration::from_millis(500)); // Delay CN traffic by 500ms
65    /// ```
66    pub fn with(mut self, code: &str, delay: Duration) -> Self {
67        self.tariffs.insert(Box::from(code.to_uppercase()), delay);
68        self
69    }
70
71    /// Convert the configuration into a middleware `TariffLayer`
72    /// that can be applied to an Axum router.
73    ///
74    /// # Example
75    ///
76    /// ```
77    /// let reader = axum_tariff::Reader::open_readfile("assets/GeoLite2-Country-Test.mmdb").unwrap();
78    /// let layer = axum_tariff::Config::new(reader)
79    ///     .with("FR", tokio::time::Duration::from_secs(1))
80    ///     .into_layer();
81    ///
82    /// async fn handler() -> axum::http::StatusCode {
83    ///     axum::http::StatusCode::NO_CONTENT
84    /// }
85    ///
86    /// let app: axum::Router<()> = axum::Router::new()
87    ///     .route("/", axum::routing::get(handler))
88    ///     .layer(layer);
89    /// ```
90    pub fn into_layer(self) -> TariffLayer {
91        TariffLayer {
92            config: Arc::new(self),
93        }
94    }
95
96    /// Get the configured delay duration for a given IP address,
97    /// based on its resolved country code.
98    ///
99    /// Returns `Some(duration)` if the country has a configured tariff,
100    /// otherwise returns `None`.
101    fn get_delay_for_ip(&self, ip: IpAddr) -> Option<Duration> {
102        self.reader
103            .lookup::<geoip2::Country>(ip)
104            .ok()
105            .flatten()
106            .and_then(|geo| geo.country)
107            .and_then(|country| country.iso_code)
108            .and_then(|code| self.tariffs.get(code.to_uppercase().as_str()))
109            .cloned()
110    }
111}
112
113/// A `tower::Layer` that wraps services to apply country-based request delays.
114///
115/// Can be applied to an Axum router using `.layer(...)`.
116#[derive(Clone)]
117pub struct TariffLayer {
118    config: Arc<Config>,
119}
120
121impl<S> tower::Layer<S> for TariffLayer {
122    type Service = TariffService<S>;
123
124    fn layer(&self, inner: S) -> Self::Service {
125        TariffService {
126            inner,
127            config: self.config.clone(),
128        }
129    }
130}
131
132/// A `tower::Service` that introduces delay based on the client IP address's country.
133///
134/// It uses the MaxMind GeoIP database to look up the country, and delays the request
135/// if the country has a configured tariff.
136#[derive(Clone)]
137pub struct TariffService<S> {
138    inner: S,
139    config: Arc<Config>,
140}
141
142impl<S, B> Service<Request<B>> for TariffService<S>
143where
144    B: Send + 'static,
145    S: Clone,
146    S: Service<Request<B>, Response = Response<B>> + Send + 'static,
147    S::Future: Send + 'static,
148{
149    type Response = S::Response;
150    type Error = S::Error;
151    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
152
153    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154        self.inner.poll_ready(cx)
155    }
156
157    fn call(&mut self, req: Request<B>) -> Self::Future {
158        let mut inner = self.inner.clone();
159        let config = Arc::clone(&self.config);
160        let client_ip = extract_client_ip(&req);
161
162        Box::pin(async move {
163            if let Some(delay) = client_ip.and_then(|ip| config.get_delay_for_ip(ip)) {
164                tokio::time::sleep(delay).await;
165            }
166
167            inner.call(req).await
168        })
169    }
170}
171
172/// Extract the client's IP address from headers or socket address.
173///
174/// Tries `X-Forwarded-For` header first, then falls back to `ConnectInfo`.
175fn extract_client_ip<B>(req: &Request<B>) -> Option<IpAddr> {
176    if let Some(header) = req.headers().get("x-forwarded-for") {
177        if let Ok(ip_str) = header.to_str() {
178            if let Some(ip_str) = ip_str.split(',').next() {
179                return ip_str.trim().parse().ok();
180            }
181        }
182    }
183
184    req.extensions()
185        .get::<axum::extract::connect_info::ConnectInfo<SocketAddr>>()
186        .map(|info| info.0.ip())
187}
188
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;
    use std::time::Instant;

    use axum::Router;
    use axum::body::Body;
    use axum::extract::connect_info::ConnectInfo;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use tower::ServiceExt; // for `oneshot`

    use super::*;

    /// Country that `IP_TEST` resolves to in the MaxMind test database.
    const IP_REGION: &str = "GB";
    const IP_TEST: &str = "2.125.160.218";

    fn test_reader() -> Reader<Vec<u8>> {
        Reader::open_readfile("assets/GeoLite2-Country-Test.mmdb")
            .expect("You need the test MaxMind DB at assets/GeoLite2-Country-Test.mmdb")
    }

    #[test]
    fn test_tariff_config_basic_mapping() {
        let config = Config::new(test_reader()).with(IP_REGION, Duration::from_millis(1234));

        let ip: IpAddr = IP_TEST.parse().unwrap();
        let delay = config.get_delay_for_ip(ip);

        assert_eq!(delay, Some(Duration::from_millis(1234)));
    }

    #[test]
    fn test_tariff_code_is_case_insensitive() {
        let config = Config::new(test_reader())
            .with(&IP_REGION.to_lowercase(), Duration::from_millis(42));

        let ip: IpAddr = IP_TEST.parse().unwrap();

        assert_eq!(config.get_delay_for_ip(ip), Some(Duration::from_millis(42)));
    }

    #[test]
    fn test_unconfigured_country_has_no_delay() {
        let config = Config::new(test_reader()).with("US", Duration::from_millis(500));

        let ip: IpAddr = IP_TEST.parse().unwrap();

        assert_eq!(config.get_delay_for_ip(ip), None);
    }

    #[test]
    fn test_address_not_in_database_has_no_delay() {
        let config = Config::new(test_reader()).with(IP_REGION, Duration::from_millis(500));

        let ip: IpAddr = "127.0.0.1".parse().unwrap();

        assert_eq!(config.get_delay_for_ip(ip), None);
    }

    #[tokio::test]
    async fn test_middleware_applies_delay() {
        let layer = Config::new(test_reader())
            .with(IP_REGION, Duration::from_millis(200))
            .into_layer();

        let app = Router::new()
            .route("/", get(|| async { "ok" }))
            .layer(layer)
            .with_state(());

        let addr: SocketAddr = "127.0.0.1:12345".parse().unwrap();
        let req = Request::builder()
            .uri("/")
            .header("x-forwarded-for", IP_TEST) // resolves to IP_REGION (GB) in the test DB
            .extension(ConnectInfo(addr))
            .body(Body::empty())
            .unwrap();

        let start = Instant::now();
        let response = app.oneshot(req).await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(response.status(), StatusCode::OK);
        // Allow for small timer/scheduling overhead.
        assert!(
            elapsed >= Duration::from_millis(180),
            "expected a ~200ms tariff delay, got {elapsed:?}"
        );
    }

    #[test]
    fn test_extract_ip_header_and_fallback() {
        // Header parsing
        let req = Request::builder()
            .header("x-forwarded-for", "8.8.8.8")
            .body(())
            .unwrap();

        assert_eq!(
            extract_client_ip(&req),
            Some("8.8.8.8".parse::<IpAddr>().unwrap())
        );

        // Fallback to ConnectInfo
        let mut req = Request::builder().body(()).unwrap();
        let addr: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        req.extensions_mut().insert(ConnectInfo(addr));

        let ip = extract_client_ip(&req);
        assert_eq!(ip, Some(addr.ip()));

        // Neither source available
        let req = Request::builder().body(()).unwrap();
        assert_eq!(extract_client_ip(&req), None);
    }
}