use http::Method;
use reqwest::Request;
use std::time::Duration;
/// Reaction chosen when a request would exceed a route's rate limit.
///
/// The variants only name the policy; the code acting on them lives outside
/// this type. NOTE(review): by name, `Delay` presumably waits for capacity and
/// `Error` presumably fails the request — confirm against the caller.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ThrottleBehavior {
    /// Default policy (selected by `ThrottleBehavior::default()`).
    #[default]
    Delay,
    /// Alternative, non-default policy.
    Error,
}
/// A single limit: at most `requests` requests per `window`.
///
/// Prefer [`RateLimit::new`], which validates both fields; the fields are
/// public, so a struct literal can bypass those checks.
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Number of requests permitted per window (`new` rejects 0).
    pub requests: u32,
    /// Length of the window (`new` rejects zero and anything above
    /// `u64::MAX` nanoseconds).
    pub window: Duration,
}

impl RateLimit {
    /// Builds a limit of `requests` requests per `window`.
    ///
    /// # Panics
    ///
    /// - if `requests` is 0;
    /// - if `window` is zero;
    /// - if `window` is longer than `u64::MAX` nanoseconds (~585 years).
    pub fn new(requests: u32, window: Duration) -> Self {
        assert!(requests != 0, "requests must be greater than 0");
        assert!(window > Duration::ZERO, "window must be greater than 0");
        // Bound the window so its nanosecond count fits in a u64.
        let max_nanos = u128::from(u64::MAX);
        assert!(
            window.as_nanos() <= max_nanos,
            "window must not exceed u64::MAX nanoseconds (~585 years)"
        );
        RateLimit { window, requests }
    }

    /// The window divided evenly among its requests (`window / requests`).
    ///
    /// The division truncates to whole nanoseconds. The result is therefore
    /// zero whenever `requests` exceeds the window's length in nanoseconds.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `requests` is 0. That can only happen
    /// when the struct was built without [`RateLimit::new`].
    #[inline]
    pub(crate) fn emission_interval(&self) -> Duration {
        self.window / self.requests
    }
}
/// A rule that selects outgoing requests and attaches rate limits to them.
///
/// A request belongs to the route when every configured criterion matches
/// (see [`Route::matches`]). Criteria left unset (`None` / empty prefix)
/// match every request.
#[derive(Debug, Clone)]
pub struct Route {
    /// Host to match, compared ASCII case-insensitively. `None` matches any
    /// host; when set, requests whose URL has no host never match.
    pub host: Option<String>,
    /// HTTP method to match. `None` matches any method.
    pub method: Option<Method>,
    /// Path prefix matched on a segment boundary. An empty string matches
    /// every path.
    pub path_prefix: String,
    /// Rate limits associated with this route.
    pub limits: Vec<RateLimit>,
    /// Policy applied when a limit is hit (see [`ThrottleBehavior`]).
    pub on_limit: ThrottleBehavior,
}

impl Route {
    /// Returns `true` when the route constrains neither host, method, nor path.
    ///
    /// An empty prefix and `"/"` both match every path, so both count as
    /// "no path constraint".
    #[cfg(feature = "tracing")]
    #[inline]
    pub(crate) fn is_catch_all(&self) -> bool {
        self.host.is_none()
            && self.method.is_none()
            && (self.path_prefix.is_empty() || self.path_prefix == "/")
    }

    /// Returns `true` if `req` satisfies every criterion configured on this route.
    ///
    /// - `host`: equal ignoring ASCII case. A request URL without a host never
    ///   matches a route that sets one.
    /// - `method`: must be equal.
    /// - `path_prefix`: the path must start with the prefix, and the match must
    ///   end on a segment boundary. For example, `/api` matches `/api` and
    ///   `/api/users` but not `/apiary`. A prefix that already ends in `/`
    ///   (including `/` itself) is its own boundary and matches everything
    ///   beneath it.
    #[inline]
    pub(crate) fn matches(&self, req: &Request) -> bool {
        let url = req.url();

        if let Some(host) = self.host.as_deref() {
            // Host names are case-insensitive (RFC 3986 §3.2.2). The URL parser
            // is expected to normalize http(s) hosts to lowercase, so an exact
            // comparison would silently ignore a mixed-case configured host.
            match url.host_str() {
                Some(req_host) if req_host.eq_ignore_ascii_case(host) => {}
                _ => return false,
            }
        }

        if let Some(method) = self.method.as_ref() {
            if req.method() != method {
                return false;
            }
        }

        if !self.path_prefix.is_empty() {
            let prefix = self.path_prefix.as_str();
            let on_boundary = match url.path().strip_prefix(prefix) {
                // `rest` must start a new segment, unless the prefix itself
                // already ended with a separator (e.g. "/api/" or "/").
                Some(rest) => {
                    rest.is_empty() || rest.starts_with('/') || prefix.ends_with('/')
                }
                None => false,
            };
            if !on_boundary {
                return false;
            }
        }

        true
    }
}
/// Hashable identifier for one rate limit of one route.
///
/// NOTE(review): presumably `route_index` is a position in the configured
/// route list and `limit_index` a position in that route's `limits`. Confirm
/// against the code that builds these keys.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub(crate) struct RouteKey {
    /// Index of the route.
    pub route_index: usize,
    /// Index of the limit within the route.
    pub limit_index: usize,
}