use std::collections::HashMap;
use std::sync::Arc;
use crate::algorithm::Algorithm;
use crate::decision::Decision;
use crate::error::Result;
use crate::key::Key;
use crate::quota::Quota;
use crate::storage::Storage;
/// Rate-limit settings attached to a single route or route pattern.
#[derive(Debug, Clone)]
pub struct RouteConfig {
/// Quota applied to requests that resolve to this configuration.
pub quota: Quota,
/// Optional suffix for the rate-limit key. When `None`, the manager
/// appends the request path to the extracted key instead.
pub key_suffix: Option<String>,
}
impl RouteConfig {
pub fn new(quota: Quota) -> Self {
Self {
quota,
key_suffix: None,
}
}
pub fn with_key_suffix(mut self, suffix: impl Into<String>) -> Self {
self.key_suffix = Some(suffix.into());
self
}
}
impl From<Quota> for RouteConfig {
fn from(quota: Quota) -> Self {
Self::new(quota)
}
}
/// Resolves the quota that applies to a request path and delegates the
/// actual limit check to an algorithm `A` backed by storage `S`, using `K`
/// to derive the per-request key.
pub struct RateLimitManager<A, S, K> {
// Algorithm that every check/record/reset call is forwarded to.
algorithm: A,
// Storage handed to the algorithm on each call.
storage: Arc<S>,
// Produces the base key for a request; "unknown" is used when it yields nothing.
key_extractor: K,
// Quota used when neither an exact route nor a pattern matches the path.
default_quota: Option<Quota>,
// Exact-path configurations; consulted before `patterns`.
routes: HashMap<String, RouteConfig>,
// Pattern configurations; scanned in order and the first match wins.
patterns: Vec<(String, RouteConfig)>,
}
impl<A, S, K> RateLimitManager<A, S, K>
where
A: Algorithm,
S: Storage,
{
pub fn builder() -> RateLimitManagerBuilder<K> {
RateLimitManagerBuilder::new()
}
pub async fn check_and_record<R>(&self, path: &str, request: &R) -> Result<Decision>
where
K: Key<R>,
{
let config = self.get_config(path);
let Some(quota) = config.map(|c| &c.quota).or(self.default_quota.as_ref()) else {
return Ok(Decision::allowed(crate::decision::RateLimitInfo::new(
u64::MAX,
u64::MAX,
std::time::Instant::now() + std::time::Duration::from_secs(3600),
std::time::Instant::now(),
)));
};
let base_key = self.key_extractor.extract(request).unwrap_or_else(|| "unknown".to_string());
let key = if let Some(suffix) = config.and_then(|c| c.key_suffix.as_ref()) {
format!("{}:{}", base_key, suffix)
} else {
format!("{}:{}", base_key, path)
};
self.algorithm
.check_and_record(&*self.storage, &key, quota)
.await
}
pub async fn check<R>(&self, path: &str, request: &R) -> Result<Decision>
where
K: Key<R>,
{
let config = self.get_config(path);
let Some(quota) = config.map(|c| &c.quota).or(self.default_quota.as_ref()) else {
return Ok(Decision::allowed(crate::decision::RateLimitInfo::new(
u64::MAX,
u64::MAX,
std::time::Instant::now() + std::time::Duration::from_secs(3600),
std::time::Instant::now(),
)));
};
let base_key = self.key_extractor.extract(request).unwrap_or_else(|| "unknown".to_string());
let key = if let Some(suffix) = config.and_then(|c| c.key_suffix.as_ref()) {
format!("{}:{}", base_key, suffix)
} else {
format!("{}:{}", base_key, path)
};
self.algorithm.check(&*self.storage, &key, quota).await
}
fn get_config(&self, path: &str) -> Option<&RouteConfig> {
if let Some(config) = self.routes.get(path) {
return Some(config);
}
for (pattern, config) in &self.patterns {
if pattern_matches(pattern, path) {
return Some(config);
}
}
None
}
pub async fn reset(&self, key: &str) -> Result<()> {
self.algorithm.reset(&*self.storage, key).await
}
}
/// Reports whether `path` matches the route `pattern`, comparing the two
/// `/`-separated segment by segment.
///
/// Empty segments are discarded on both sides, so leading, trailing, and
/// repeated slashes are ignored. Pattern segments are interpreted as:
///
/// * `*` matches exactly one path segment;
/// * `**` matches one or more path segments, and any pattern segments that
///   follow it must still match the remainder of the path;
/// * anything else must equal the corresponding path segment exactly.
fn pattern_matches(pattern: &str, path: &str) -> bool {
    // Splits on `/` and drops the empty pieces produced by extra slashes.
    fn segments(raw: &str) -> Vec<&str> {
        raw.split('/').filter(|part| !part.is_empty()).collect()
    }

    // Matches the remaining pattern segments against the remaining path segments.
    fn matches_from(pattern: &[&str], path: &[&str]) -> bool {
        let Some((&segment, rest)) = pattern.split_first() else {
            // Pattern exhausted: only a fully consumed path is a match.
            return path.is_empty();
        };

        if segment == "**" {
            // Try every non-empty run of segments for `**` so that whatever
            // follows it in the pattern is still checked against the path.
            return (1..=path.len()).any(|consumed| matches_from(rest, &path[consumed..]));
        }

        match path.split_first() {
            Some((&head, tail)) if segment == "*" || segment == head => matches_from(rest, tail),
            _ => false,
        }
    }

    matches_from(&segments(pattern), &segments(path))
}
/// Collects the quota, route, pattern, and key-extractor settings that are
/// moved into a [`RateLimitManager`] when it is built.
pub struct RateLimitManagerBuilder<K> {
// Quota handed to the manager as its fallback; `None` until `default_quota` is called.
default_quota: Option<Quota>,
// Exact-path configurations registered through `route`.
routes: HashMap<String, RouteConfig>,
// Pattern configurations registered through `route_pattern`, in registration order.
patterns: Vec<(String, RouteConfig)>,
// Extractor registered through `key_extractor`; `build` falls back to `K::default()` when unset.
key_extractor: Option<K>,
}
impl<K> Default for RateLimitManagerBuilder<K> {
fn default() -> Self {
Self::new()
}
}
impl<K> RateLimitManagerBuilder<K> {
pub fn new() -> Self {
Self {
default_quota: None,
routes: HashMap::new(),
patterns: Vec::new(),
key_extractor: None,
}
}
pub fn default_quota(mut self, quota: Quota) -> Self {
self.default_quota = Some(quota);
self
}
pub fn route(mut self, path: impl Into<String>, config: impl Into<RouteConfig>) -> Self {
self.routes.insert(path.into(), config.into());
self
}
pub fn route_pattern(
mut self,
pattern: impl Into<String>,
config: impl Into<RouteConfig>,
) -> Self {
self.patterns.push((pattern.into(), config.into()));
self
}
pub fn key_extractor(mut self, extractor: K) -> Self {
self.key_extractor = Some(extractor);
self
}
pub fn build<A, S>(self, algorithm: A, storage: S) -> RateLimitManager<A, S, K>
where
K: Default,
{
RateLimitManager {
algorithm,
storage: Arc::new(storage),
key_extractor: self.key_extractor.unwrap_or_default(),
default_quota: self.default_quota,
routes: self.routes,
patterns: self.patterns,
}
}
pub fn build_with_key<A, S>(
self,
algorithm: A,
storage: S,
key_extractor: K,
) -> RateLimitManager<A, S, K> {
RateLimitManager {
algorithm,
storage: Arc::new(storage),
key_extractor,
default_quota: self.default_quota,
routes: self.routes,
patterns: self.patterns,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Asserts that `pattern` matches every path in `accepted` and none of
    // the paths in `rejected`.
    fn check_pattern(pattern: &str, accepted: &[&str], rejected: &[&str]) {
        for path in accepted {
            assert!(pattern_matches(pattern, path));
        }
        for path in rejected {
            assert!(!pattern_matches(pattern, path));
        }
    }

    // A literal pattern matches only the identical path.
    #[test]
    fn test_pattern_matches_exact() {
        check_pattern("/api/users", &["/api/users"], &["/api/posts"]);
    }

    // `*` stands in for exactly one segment.
    #[test]
    fn test_pattern_matches_single_wildcard() {
        check_pattern(
            "/api/*/posts",
            &["/api/users/posts", "/api/admins/posts"],
            &["/api/users/comments"],
        );
    }

    // A trailing `**` accepts any deeper path under the prefix.
    #[test]
    fn test_pattern_matches_double_wildcard() {
        check_pattern(
            "/api/**",
            &["/api/users", "/api/users/123/posts"],
            &["/v2/api/users"],
        );
    }

    // Converting a quota yields a config with that quota and no key suffix.
    #[test]
    fn test_route_config_from_quota() {
        let config = RouteConfig::from(Quota::per_minute(60));
        assert_eq!(config.quota.max_requests(), 60);
        assert!(config.key_suffix.is_none());
    }
}