1use dashmap::DashMap;
4use http::Method;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use crate::middleware::RateLimitMiddleware;
9use crate::types::{RateLimit, Route, ThrottleBehavior};
10
11#[derive(Debug, Default, Clone)]
13pub struct RateLimitBuilder {
14 pub(crate) routes: Vec<Route>,
15}
16
17impl RateLimitBuilder {
18 #[must_use]
20 pub fn new() -> Self {
21 Self::default()
22 }
23
24 #[must_use]
42 pub fn route<F>(mut self, configure: F) -> Self
43 where
44 F: FnOnce(RouteBuilder) -> RouteBuilder,
45 {
46 let builder = RouteBuilder::new();
47 let configured = configure(builder);
48 self.routes.push(configured.into_route());
49 self
50 }
51
52 #[must_use]
85 pub fn host<F>(mut self, host: impl Into<String>, configure: F) -> Self
86 where
87 F: FnOnce(HostBuilder) -> HostBuilder,
88 {
89 let host_str = host.into();
90 let host_builder = HostBuilder::new(host_str);
91 let configured = configure(host_builder);
92 self.routes.extend(configured.routes);
93 self
94 }
95
96 #[must_use]
98 pub fn add_route(mut self, route: Route) -> Self {
99 self.routes.push(route);
100 self
101 }
102
103 #[must_use]
112 pub fn build(self) -> RateLimitMiddleware {
113 #[cfg(feature = "tracing")]
114 self.warn_catch_all_route_order();
115
116 RateLimitMiddleware {
117 routes: Arc::new(self.routes),
118 state: Arc::new(DashMap::new()),
119 start_instant: Instant::now(),
120 }
121 }
122
123 #[cfg(feature = "tracing")]
125 fn warn_catch_all_route_order(&self) {
126 let catch_all_indices: Vec<usize> = self
128 .routes
129 .iter()
130 .enumerate()
131 .filter(|(_, route)| route.is_catch_all())
132 .map(|(i, _)| i)
133 .collect();
134
135 for &catch_all_index in &catch_all_indices {
137 if let Some((specific_index, _)) = self
139 .routes
140 .iter()
141 .enumerate()
142 .skip(catch_all_index + 1)
143 .find(|(_, route)| !route.is_catch_all())
144 {
145 tracing::warn!(
146 catch_all_route_index = catch_all_index,
147 specific_route_index = specific_index,
148 "Catch-all route (index {}) precedes more specific route (index {}). \
149 All matching routes' limits are applied, so the catch-all will affect \
150 requests intended for the specific route. Consider reordering routes \
151 or using host-scoped builders.",
152 catch_all_index,
153 specific_index
154 );
155 }
156 }
157 }
158}
159
160#[derive(Debug, Clone)]
165pub struct HostBuilder {
166 host: String,
167 routes: Vec<Route>,
168}
169
170impl HostBuilder {
171 fn new(host: String) -> Self {
172 Self {
173 host,
174 routes: Vec::new(),
175 }
176 }
177
178 #[must_use]
186 pub fn route<F>(mut self, configure: F) -> Self
187 where
188 F: FnOnce(HostRouteBuilder) -> HostRouteBuilder,
189 {
190 let builder = HostRouteBuilder::new();
191 let configured = configure(builder);
192 assert!(
193 !configured.limits.is_empty(),
194 "route must have at least one limit configured via .limit()"
195 );
196 let route = Route {
197 host: Some(self.host.clone()),
198 method: configured.method,
199 path_prefix: configured.path_prefix,
200 limits: configured.limits,
201 on_limit: configured.on_limit,
202 };
203 self.routes.push(route);
204 self
205 }
206}
207
208#[derive(Debug, Default, Clone)]
213pub struct HostRouteBuilder {
214 method: Option<Method>,
215 path_prefix: String,
216 limits: Vec<RateLimit>,
217 on_limit: ThrottleBehavior,
218}
219
220impl HostRouteBuilder {
221 fn new() -> Self {
222 Self::default()
223 }
224
225 #[must_use]
227 pub fn method(mut self, method: Method) -> Self {
228 self.method = Some(method);
229 self
230 }
231
232 #[must_use]
234 pub fn path(mut self, path_prefix: impl Into<String>) -> Self {
235 self.path_prefix = path_prefix.into();
236 self
237 }
238
239 #[must_use]
241 pub fn limit(mut self, requests: u32, window: Duration) -> Self {
242 self.limits.push(RateLimit::new(requests, window));
243 self
244 }
245
246 #[must_use]
248 pub fn on_limit(mut self, behavior: ThrottleBehavior) -> Self {
249 self.on_limit = behavior;
250 self
251 }
252}
253
254#[derive(Debug, Default, Clone)]
259pub struct RouteBuilder {
260 host: Option<String>,
261 method: Option<Method>,
262 path_prefix: String,
263 limits: Vec<RateLimit>,
264 on_limit: ThrottleBehavior,
265}
266
267impl RouteBuilder {
268 fn new() -> Self {
269 Self::default()
270 }
271
272 fn into_route(self) -> Route {
273 assert!(
274 !self.limits.is_empty(),
275 "route must have at least one limit configured via .limit()"
276 );
277 Route {
278 host: self.host,
279 method: self.method,
280 path_prefix: self.path_prefix,
281 limits: self.limits,
282 on_limit: self.on_limit,
283 }
284 }
285
286 #[must_use]
291 pub fn host(mut self, host: impl Into<String>) -> Self {
292 self.host = Some(host.into());
293 self
294 }
295
296 #[must_use]
298 pub fn method(mut self, method: Method) -> Self {
299 self.method = Some(method);
300 self
301 }
302
303 #[must_use]
305 pub fn path(mut self, path_prefix: impl Into<String>) -> Self {
306 self.path_prefix = path_prefix.into();
307 self
308 }
309
310 #[must_use]
312 pub fn limit(mut self, requests: u32, window: Duration) -> Self {
313 self.limits.push(RateLimit::new(requests, window));
314 self
315 }
316
317 #[must_use]
319 pub fn on_limit(mut self, behavior: ThrottleBehavior) -> Self {
320 self.on_limit = behavior;
321 self
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_builder_api() {
        let middleware = RateLimitMiddleware::builder()
            .route(|r| {
                r.host("api.example.com")
                    .method(Method::POST)
                    .path("/order")
                    .limit(100, Duration::from_secs(10))
                    .limit(1000, Duration::from_secs(60))
                    .on_limit(ThrottleBehavior::Delay)
            })
            .route(|r| {
                r.path("/data")
                    .limit(50, Duration::from_secs(10))
                    .on_limit(ThrottleBehavior::Error)
            })
            .build();

        assert_eq!(middleware.routes.len(), 2);
        assert_eq!(middleware.routes[0].limits.len(), 2);
        assert_eq!(middleware.routes[1].limits.len(), 1);
    }

    #[test]
    fn test_host_scoped_builder() {
        let middleware = RateLimitMiddleware::builder()
            .host("clob.polymarket.com", |host| {
                host.route(|r| r.limit(9000, Duration::from_secs(10)))
                    .route(|r| r.path("/book").limit(1500, Duration::from_secs(10)))
                    .route(|r| r.path("/price").limit(1500, Duration::from_secs(10)))
                    .route(|r| {
                        r.method(Method::POST)
                            .path("/order")
                            .limit(3500, Duration::from_secs(10))
                            .limit(36000, Duration::from_secs(600))
                            .on_limit(ThrottleBehavior::Delay)
                    })
            })
            .host("data-api.polymarket.com", |host| {
                host.route(|r| r.limit(1000, Duration::from_secs(10)))
                    .route(|r| r.path("/trades").limit(200, Duration::from_secs(10)))
            })
            .build();

        assert_eq!(middleware.routes.len(), 6);

        // Host groups are flattened in declaration order: four clob routes,
        // then two data-api routes.
        let (clob_routes, data_routes) = middleware.routes.split_at(4);
        for route in clob_routes {
            assert_eq!(route.host.as_deref(), Some("clob.polymarket.com"));
        }
        for route in data_routes {
            assert_eq!(route.host.as_deref(), Some("data-api.polymarket.com"));
        }

        let order_route = &middleware.routes[3];
        assert_eq!(order_route.path_prefix, "/order");
        assert_eq!(order_route.method, Some(Method::POST));
        assert_eq!(order_route.limits.len(), 2);
    }

    #[test]
    fn test_mixed_builder_styles() {
        let middleware = RateLimitMiddleware::builder()
            .route(|r| r.limit(15000, Duration::from_secs(10)))
            .host("api.example.com", |host| {
                host.route(|r| r.path("/data").limit(100, Duration::from_secs(10)))
            })
            .build();

        assert_eq!(middleware.routes.len(), 2);
        assert!(middleware.routes[0].host.is_none());
        assert_eq!(
            middleware.routes[1].host.as_deref(),
            Some("api.example.com")
        );
    }

    #[test]
    fn test_single_line_routes() {
        let middleware = RateLimitMiddleware::builder()
            .host("api.example.com", |host| {
                host.route(|r| r.path("/a").limit(100, Duration::from_secs(10)))
                    .route(|r| r.path("/b").limit(200, Duration::from_secs(10)))
                    .route(|r| r.path("/c").limit(300, Duration::from_secs(10)))
            })
            .build();

        assert_eq!(middleware.routes.len(), 3);
    }

    #[test]
    fn test_add_route_preserves_order() {
        let manual = Route {
            host: None,
            method: Some(Method::GET),
            path_prefix: "/manual".to_string(),
            limits: vec![RateLimit::new(5, Duration::from_secs(1))],
            on_limit: ThrottleBehavior::Error,
        };

        let middleware = RateLimitMiddleware::builder()
            .route(|r| r.path("/first").limit(10, Duration::from_secs(1)))
            .add_route(manual)
            .build();

        assert_eq!(middleware.routes.len(), 2);
        assert_eq!(middleware.routes[0].path_prefix, "/first");
        assert_eq!(middleware.routes[1].path_prefix, "/manual");
        assert_eq!(middleware.routes[1].method, Some(Method::GET));
    }

    #[test]
    #[should_panic(expected = "route must have at least one limit")]
    fn test_route_without_limit_panics() {
        let _middleware = RateLimitMiddleware::builder()
            .route(|r| r.path("/test"))
            .build();
    }

    #[test]
    #[should_panic(expected = "route must have at least one limit")]
    fn test_host_route_without_limit_panics() {
        let _middleware = RateLimitMiddleware::builder()
            .host("api.example.com", |host| host.route(|r| r.path("/test")))
            .build();
    }
}