use dashmap::DashMap;
use http::Method;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::middleware::RateLimitMiddleware;
use crate::types::{RateLimit, Route, ThrottleBehavior};
/// Accumulates [`Route`]s and converts them into a [`RateLimitMiddleware`].
///
/// Each builder method appends to `routes`, so routes are kept in the order
/// in which they were added.
#[derive(Debug, Default, Clone)]
pub struct RateLimitBuilder {
/// Routes collected so far; handed to the middleware by `build`.
pub(crate) routes: Vec<Route>,
}
impl RateLimitBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn route<F>(mut self, configure: F) -> Self
where
F: FnOnce(RouteBuilder) -> RouteBuilder,
{
let builder = RouteBuilder::new();
let configured = configure(builder);
self.routes.push(configured.into_route());
self
}
#[must_use]
pub fn host<F>(mut self, host: impl Into<String>, configure: F) -> Self
where
F: FnOnce(HostBuilder) -> HostBuilder,
{
let host_str = host.into();
let host_builder = HostBuilder::new(host_str);
let configured = configure(host_builder);
self.routes.extend(configured.routes);
self
}
#[must_use]
pub fn add_route(mut self, route: Route) -> Self {
self.routes.push(route);
self
}
#[must_use]
pub fn build(self) -> RateLimitMiddleware {
#[cfg(feature = "tracing")]
self.warn_catch_all_route_order();
RateLimitMiddleware {
routes: Arc::new(self.routes),
state: Arc::new(DashMap::new()),
start_instant: Instant::now(),
}
}
#[cfg(feature = "tracing")]
fn warn_catch_all_route_order(&self) {
let catch_all_indices: Vec<usize> = self
.routes
.iter()
.enumerate()
.filter(|(_, route)| route.is_catch_all())
.map(|(i, _)| i)
.collect();
for &catch_all_index in &catch_all_indices {
if let Some((specific_index, _)) = self
.routes
.iter()
.enumerate()
.skip(catch_all_index + 1)
.find(|(_, route)| !route.is_catch_all())
{
tracing::warn!(
catch_all_route_index = catch_all_index,
specific_route_index = specific_index,
"Catch-all route (index {}) precedes more specific route (index {}). \
All matching routes' limits are applied, so the catch-all will affect \
requests intended for the specific route. Consider reordering routes \
or using host-scoped builders.",
catch_all_index,
specific_index
);
}
}
}
}
/// Builder for a group of routes that all share one host.
///
/// Instances are created by `RateLimitBuilder::host`, which passes one to the
/// caller's closure and then takes over the accumulated `routes`.
#[derive(Debug, Clone)]
pub struct HostBuilder {
/// Host value copied into every route created by this builder.
host: String,
/// Routes created so far, in the order they were added.
routes: Vec<Route>,
}
impl HostBuilder {
fn new(host: String) -> Self {
Self {
host,
routes: Vec::new(),
}
}
#[must_use]
pub fn route<F>(mut self, configure: F) -> Self
where
F: FnOnce(HostRouteBuilder) -> HostRouteBuilder,
{
let builder = HostRouteBuilder::new();
let configured = configure(builder);
assert!(
!configured.limits.is_empty(),
"route must have at least one limit configured via .limit()"
);
let route = Route {
host: Some(self.host.clone()),
method: configured.method,
path_prefix: configured.path_prefix,
limits: configured.limits,
on_limit: configured.on_limit,
};
self.routes.push(route);
self
}
}
/// Settings for a single route inside a [`HostBuilder`].
///
/// It has no host field of its own; `HostBuilder::route` supplies the host
/// when it assembles the final [`Route`]. All fields start at their
/// `Default` values.
#[derive(Debug, Default, Clone)]
pub struct HostRouteBuilder {
/// HTTP method set via `method`; `None` until it is called.
method: Option<Method>,
/// Path prefix set via `path`; the empty string until it is called.
path_prefix: String,
/// Limits appended via `limit`, in call order.
limits: Vec<RateLimit>,
/// Throttle behavior set via `on_limit`; `ThrottleBehavior::default()` otherwise.
on_limit: ThrottleBehavior,
}
impl HostRouteBuilder {
    /// Creates a route configuration with every field at its default.
    fn new() -> Self {
        Self::default()
    }

    /// Sets the HTTP method stored for the route, replacing any earlier one.
    #[must_use]
    pub fn method(self, method: Method) -> Self {
        Self {
            method: Some(method),
            ..self
        }
    }

    /// Sets the route's path prefix, replacing any earlier value.
    #[must_use]
    pub fn path(self, path_prefix: impl Into<String>) -> Self {
        Self {
            path_prefix: path_prefix.into(),
            ..self
        }
    }

    /// Appends a limit built with `RateLimit::new(requests, window)`.
    ///
    /// Repeated calls accumulate; earlier limits are kept.
    #[must_use]
    pub fn limit(mut self, requests: u32, window: Duration) -> Self {
        let rate_limit = RateLimit::new(requests, window);
        self.limits.push(rate_limit);
        self
    }

    /// Sets the [`ThrottleBehavior`] stored for the route.
    #[must_use]
    pub fn on_limit(self, behavior: ThrottleBehavior) -> Self {
        Self {
            on_limit: behavior,
            ..self
        }
    }
}
/// Settings for a single route added through `RateLimitBuilder::route`.
///
/// All fields start at their `Default` values and are changed by the
/// chainable setter methods.
#[derive(Debug, Default, Clone)]
pub struct RouteBuilder {
/// Host set via `host`; `None` until it is called.
host: Option<String>,
/// HTTP method set via `method`; `None` until it is called.
method: Option<Method>,
/// Path prefix set via `path`; the empty string until it is called.
path_prefix: String,
/// Limits appended via `limit`, in call order.
limits: Vec<RateLimit>,
/// Throttle behavior set via `on_limit`; `ThrottleBehavior::default()` otherwise.
on_limit: ThrottleBehavior,
}
impl RouteBuilder {
    /// Creates a route configuration with every field at its default.
    fn new() -> Self {
        Self::default()
    }

    /// Converts the collected settings into a [`Route`].
    ///
    /// # Panics
    ///
    /// Panics if no limit was added via `limit`.
    fn into_route(self) -> Route {
        let Self {
            host,
            method,
            path_prefix,
            limits,
            on_limit,
        } = self;

        assert!(
            !limits.is_empty(),
            "route must have at least one limit configured via .limit()"
        );

        Route {
            host,
            method,
            path_prefix,
            limits,
            on_limit,
        }
    }

    /// Sets the host stored for the route, replacing any earlier one.
    #[must_use]
    pub fn host(self, host: impl Into<String>) -> Self {
        Self {
            host: Some(host.into()),
            ..self
        }
    }

    /// Sets the HTTP method stored for the route, replacing any earlier one.
    #[must_use]
    pub fn method(self, method: Method) -> Self {
        Self {
            method: Some(method),
            ..self
        }
    }

    /// Sets the route's path prefix, replacing any earlier value.
    #[must_use]
    pub fn path(self, path_prefix: impl Into<String>) -> Self {
        Self {
            path_prefix: path_prefix.into(),
            ..self
        }
    }

    /// Appends a limit built with `RateLimit::new(requests, window)`.
    ///
    /// Repeated calls accumulate; earlier limits are kept.
    #[must_use]
    pub fn limit(mut self, requests: u32, window: Duration) -> Self {
        let rate_limit = RateLimit::new(requests, window);
        self.limits.push(rate_limit);
        self
    }

    /// Sets the [`ThrottleBehavior`] stored for the route.
    #[must_use]
    pub fn on_limit(self, behavior: ThrottleBehavior) -> Self {
        Self {
            on_limit: behavior,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Window used by most limits in these tests.
    const TEN_SECS: Duration = Duration::from_secs(10);

    /// Two `route` calls yield two routes carrying the configured limit counts.
    #[test]
    fn test_builder_api() {
        let middleware = RateLimitMiddleware::builder()
            .route(|r| {
                r.host("api.example.com")
                    .method(Method::POST)
                    .path("/order")
                    .limit(100, TEN_SECS)
                    .limit(1000, Duration::from_secs(60))
                    .on_limit(ThrottleBehavior::Delay)
            })
            .route(|r| {
                r.path("/data")
                    .limit(50, TEN_SECS)
                    .on_limit(ThrottleBehavior::Error)
            })
            .build();

        // One entry per route, so this also pins the route count to two.
        let limit_counts: Vec<usize> = middleware
            .routes
            .iter()
            .map(|route| route.limits.len())
            .collect();
        assert_eq!(limit_counts, [2_usize, 1]);
    }

    /// Host-scoped routes keep their insertion order and carry their host.
    #[test]
    fn test_host_scoped_builder() {
        let middleware = RateLimitMiddleware::builder()
            .host("clob.polymarket.com", |host| {
                host.route(|r| r.limit(9000, TEN_SECS))
                    .route(|r| r.path("/book").limit(1500, TEN_SECS))
                    .route(|r| r.path("/price").limit(1500, TEN_SECS))
                    .route(|r| {
                        r.method(Method::POST)
                            .path("/order")
                            .limit(3500, TEN_SECS)
                            .limit(36000, Duration::from_secs(600))
                            .on_limit(ThrottleBehavior::Delay)
                    })
            })
            .host("data-api.polymarket.com", |host| {
                host.route(|r| r.limit(1000, TEN_SECS))
                    .route(|r| r.path("/trades").limit(200, TEN_SECS))
            })
            .build();

        assert_eq!(middleware.routes.len(), 6);

        // The first four routes come from the first `host` call, the rest
        // from the second.
        let (clob_routes, data_api_routes) = middleware.routes.split_at(4);
        for route in clob_routes {
            assert_eq!(route.host.as_deref(), Some("clob.polymarket.com"));
        }
        for route in data_api_routes {
            assert_eq!(route.host.as_deref(), Some("data-api.polymarket.com"));
        }

        let order_route = &clob_routes[3];
        assert_eq!(order_route.path_prefix, "/order");
        assert_eq!(order_route.method, Some(Method::POST));
        assert_eq!(order_route.limits.len(), 2);
    }

    /// `route` and `host` can be mixed on the same builder.
    #[test]
    fn test_mixed_builder_styles() {
        let middleware = RateLimitMiddleware::builder()
            .route(|r| r.limit(15000, TEN_SECS))
            .host("api.example.com", |host| {
                host.route(|r| r.path("/data").limit(100, TEN_SECS))
            })
            .build();

        assert_eq!(middleware.routes.len(), 2);

        let unscoped = &middleware.routes[0];
        let scoped = &middleware.routes[1];
        assert!(unscoped.host.is_none());
        assert_eq!(scoped.host.as_deref(), Some("api.example.com"));
    }

    /// Several single-expression routes under one host are all recorded.
    #[test]
    fn test_single_line_routes() {
        let middleware = RateLimitMiddleware::builder()
            .host("api.example.com", |host| {
                host.route(|r| r.path("/a").limit(100, TEN_SECS))
                    .route(|r| r.path("/b").limit(200, TEN_SECS))
                    .route(|r| r.path("/c").limit(300, TEN_SECS))
            })
            .build();

        let route_count = middleware.routes.len();
        assert_eq!(route_count, 3);
    }

    /// A route without any limit is rejected with a panic.
    #[test]
    #[should_panic(expected = "route must have at least one limit")]
    fn test_route_without_limit_panics() {
        let builder = RateLimitMiddleware::builder().route(|r| r.path("/test"));
        let _middleware = builder.build();
    }

    /// A host-scoped route without any limit is rejected with a panic.
    #[test]
    #[should_panic(expected = "route must have at least one limit")]
    fn test_host_route_without_limit_panics() {
        let builder = RateLimitMiddleware::builder()
            .host("api.example.com", |host| host.route(|r| r.path("/test")));
        let _middleware = builder.build();
    }
}