route_ratelimit/
lib.rs

//! Route-based rate limiting middleware for reqwest.
//!
//! This crate provides a [`RateLimitMiddleware`] that can be used with
//! [`reqwest_middleware`] to enforce rate limits based on endpoint matching.
//!
//! # Features
//!
//! - **Endpoint matching**: Match requests by host, HTTP method, and path prefix
//! - **Multiple rate limits**: Stack burst and sustained limits on the same endpoint
//! - **Configurable behavior**: Choose to delay requests or return errors per endpoint
//! - **Lock-free performance**: Uses the Generic Cell Rate Algorithm (GCRA) with atomic operations
//! - **Shared state**: Rate limits are tracked across all clones of the client
//!
//! # Route Matching Behavior
//!
//! Routes are checked in the order they are defined, and **all matching routes'
//! limits are applied**. This means you can layer general limits with more specific ones:
//!
//! ```rust,no_run
//! use route_ratelimit::RateLimitMiddleware;
//! use std::time::Duration;
//!
//! let middleware = RateLimitMiddleware::builder()
//!     // General limit: 9000 requests per 10 seconds for all endpoints on this host
//!     .host("api.example.com", |host| {
//!         host.route(|r| r.limit(9000, Duration::from_secs(10)))
//!             // Specific limit: /book endpoints are also limited to 1500/10s.
//!             // Both limits are enforced - a request to /book must pass BOTH.
//!             .route(|r| r.path("/book").limit(1500, Duration::from_secs(10)))
//!     })
//!     .build();
//! ```
//!
//! # Host Matching
//!
//! Host matching uses only the hostname portion of the URL, **excluding the port**.
//! For example, `host("api.example.com")` will match `https://api.example.com:8443/path`.
//!
//! # Path Matching
//!
//! Path matching respects **segment boundaries** rather than doing a plain string-prefix match:
//! - `/order` matches `/order`, `/order/`, and `/order/123`
//! - `/order` does **NOT** match `/orders` or `/order-test`
//!
//! # Example
//!
//! ```rust,no_run
//! use route_ratelimit::{RateLimitMiddleware, ThrottleBehavior};
//! use reqwest_middleware::ClientBuilder;
//! use std::time::Duration;
//! use http::Method;
//!
//! # async fn example() {
//! let middleware = RateLimitMiddleware::builder()
//!     // Configure routes by host for clean organization
//!     .host("clob.polymarket.com", |host| {
//!         host.route(|r| r.limit(9000, Duration::from_secs(10)))  // General limit
//!             .route(|r| r.path("/book").limit(1500, Duration::from_secs(10)))
//!             .route(|r| r.path("/price").limit(1500, Duration::from_secs(10)))
//!             .route(|r| {
//!                 r.method(Method::POST)
//!                     .path("/order")
//!                     .limit(3500, Duration::from_secs(10))   // Burst
//!                     .limit(36000, Duration::from_secs(600)) // Sustained
//!             })
//!     })
//!     .build();
//!
//! let client = ClientBuilder::new(reqwest::Client::new())
//!     .with(middleware)
//!     .build();
//!
//! // Requests will be automatically rate-limited
//! client.get("https://clob.polymarket.com/book").send().await.unwrap();
//! # }
//! ```

78mod builder;
79mod error;
80mod gcra;
81mod middleware;
82mod types;
83
84// Public re-exports
85pub use builder::{HostBuilder, HostRouteBuilder, RateLimitBuilder, RouteBuilder};
86pub use error::RateLimitError;
87pub use middleware::RateLimitMiddleware;
88pub use types::{RateLimit, Route, ThrottleBehavior};
89
#[cfg(test)]
mod tests {
    use super::*;
    use http::Method;
    use std::time::Duration;

    /// Builds a request directly, without constructing a `reqwest::Client`.
    ///
    /// `Client::new()` sets up a connection pool and TLS backend (and panics if those
    /// cannot be initialized). Route matching only inspects the method and URL, so a
    /// bare `Request` keeps these tests cheap and independent of the host environment.
    fn request(method: reqwest::Method, url: &str) -> reqwest::Request {
        let url = reqwest::Url::parse(url).expect("test URL must be valid");
        reqwest::Request::new(method, url)
    }

    /// Shorthand for a `GET` request to `url`.
    fn get(url: &str) -> reqwest::Request {
        request(reqwest::Method::GET, url)
    }

    /// Shorthand for a `POST` request to `url`.
    fn post(url: &str) -> reqwest::Request {
        request(reqwest::Method::POST, url)
    }

    /// Builds a `Route` with the given matching criteria and no limits.
    ///
    /// Only the matching fields matter for these tests, so `limits` is empty and
    /// `on_limit` is fixed to `ThrottleBehavior::Delay`.
    fn test_route(host: Option<&str>, method: Option<Method>, path_prefix: &str) -> Route {
        Route {
            host: host.map(str::to_string),
            method,
            path_prefix: path_prefix.to_string(),
            limits: vec![],
            on_limit: ThrottleBehavior::Delay,
        }
    }

    #[test]
    fn test_route_matching_all() {
        let route = test_route(None, None, "");

        assert!(route.matches(&get("https://example.com/test")));
    }

    #[test]
    fn test_route_matching_host() {
        let route = test_route(Some("api.example.com"), None, "");

        assert!(route.matches(&get("https://api.example.com/test")));
        assert!(!route.matches(&get("https://other.example.com/test")));
    }

    #[test]
    fn test_route_matching_host_ignores_port() {
        // Host matching compares only the hostname, so an explicit port in the
        // request URL must not prevent a match.
        let route = test_route(Some("api.example.com"), None, "");

        assert!(route.matches(&get("https://api.example.com:8443/path")));
    }

    #[test]
    fn test_route_matching_method() {
        let route = test_route(None, Some(Method::POST), "");

        assert!(route.matches(&post("https://example.com/test")));
        assert!(!route.matches(&get("https://example.com/test")));
    }

    #[test]
    fn test_route_matching_path_prefix() {
        let route = test_route(None, None, "/api/v1");

        assert!(route.matches(&get("https://example.com/api/v1/users")));
        assert!(!route.matches(&get("https://example.com/api/v2/users")));
    }

    #[test]
    fn test_route_matching_path_segment_boundary() {
        let route = test_route(None, None, "/order");

        // Should match: exact, with trailing slash, with sub-path
        for path in ["/order", "/order/", "/order/123"] {
            let url = format!("https://example.com{path}");
            assert!(route.matches(&get(&url)), "/order should match {path}");
        }

        // Should NOT match: different path that starts with same chars
        for path in ["/orders", "/order-test"] {
            let url = format!("https://example.com{path}");
            assert!(!route.matches(&get(&url)), "/order should NOT match {path}");
        }
    }

    #[test]
    fn test_emission_interval() {
        let limit = RateLimit::new(100, Duration::from_secs(10));
        assert_eq!(limit.emission_interval(), Duration::from_millis(100));

        let limit = RateLimit::new(1000, Duration::from_secs(60));
        assert_eq!(limit.emission_interval(), Duration::from_millis(60));
    }

    #[test]
    #[should_panic(expected = "requests must be greater than 0")]
    fn test_zero_requests_panics() {
        RateLimit::new(0, Duration::from_secs(10));
    }

    #[test]
    #[should_panic(expected = "window must be greater than 0")]
    fn test_zero_window_panics() {
        RateLimit::new(100, Duration::ZERO);
    }

    #[test]
    #[should_panic(expected = "window must not exceed u64::MAX nanoseconds")]
    fn test_overflow_window_panics() {
        // u64::MAX nanoseconds is ~585 years, so 600 years should overflow
        RateLimit::new(100, Duration::from_secs(600 * 365 * 24 * 60 * 60));
    }
}
261}