use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
/// A token bucket holding up to `capacity` tokens that regains them
/// continuously at `refill_rate` tokens per second.
///
/// Refilling is lazy: nothing runs in the background, and `tokens` is only
/// brought up to date when a method needing the current level is called.
struct TokenBucket {
    /// Token level as of `last_refill` (fractional values are allowed).
    tokens: f64,
    /// Maximum number of tokens the bucket can hold.
    capacity: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// When `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: f64, refill_rate: f64) -> Self {
        Self {
            tokens: capacity,
            capacity,
            refill_rate,
            last_refill: Instant::now(),
        }
    }

    /// Returns the level the bucket would have at `now` without mutating it:
    /// the stored level plus what has accrued since `last_refill`, clamped to
    /// `capacity`.
    fn tokens_at(&self, now: Instant) -> f64 {
        let elapsed_secs = now.duration_since(self.last_refill).as_secs_f64();
        (self.tokens + self.refill_rate * elapsed_secs).min(self.capacity)
    }

    /// Brings `tokens` up to date with the time elapsed since the last refill.
    fn refill(&mut self) {
        let now = Instant::now();
        self.tokens = self.tokens_at(now);
        self.last_refill = now;
    }

    /// Refills, then removes `n` tokens if at least `n` are available.
    ///
    /// Returns `true` if the tokens were taken; on `false` the bucket is left
    /// as the refill made it.
    fn try_consume(&mut self, n: f64) -> bool {
        self.refill();
        if self.tokens >= n {
            self.tokens -= n;
            true
        } else {
            false
        }
    }

    /// Refills and returns the current token level.
    // Within this file only the unit tests call this; the allow silences
    // `dead_code` in non-test builds.
    #[allow(dead_code)]
    fn available(&mut self) -> f64 {
        self.refill();
        self.tokens
    }

    /// Milliseconds until `n` tokens will be available (0 if they already are).
    ///
    /// Includes tokens accrued since the last refill, so the answer is right
    /// even when the caller has not called `refill` first. Returns `u64::MAX`
    /// when the bucket never refills (non-positive `refill_rate`).
    ///
    /// NOTE: if `n` exceeds `capacity` the level is capped below `n`, so the
    /// returned wait is finite but will never actually be satisfied.
    fn ms_until_available(&self, n: f64) -> u64 {
        let current = self.tokens_at(Instant::now());
        if current >= n {
            return 0;
        }
        if self.refill_rate <= 0.0 {
            return u64::MAX;
        }
        let deficit = n - current;
        // Float-to-int `as` saturates, so an enormous wait clamps to u64::MAX.
        (deficit / self.refill_rate * 1000.0).ceil() as u64
    }
}
/// Tunables for [`RateLimiter`].
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Per-client sustained rate in requests per second (each client bucket's refill rate).
    pub rps: f64,
    /// Per-client bucket capacity: how many requests a client may make back-to-back.
    pub burst: f64,
    /// Client-table size at which entries idle longer than `client_ttl` are evicted.
    pub max_clients: usize,
    /// How long a client may go unseen before its entry counts as stale.
    pub client_ttl: Duration,
    /// Optional limit shared by all clients, in requests per second; `None` disables it.
    pub global_rps: Option<f64>,
}

impl Default for RateLimitConfig {
    /// 10 req/s per client with bursts of 20, a 10 000-client table whose
    /// entries go stale after five minutes, and no global limit.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        RateLimitConfig {
            global_rps: None,
            client_ttl: FIVE_MINUTES,
            max_clients: 10_000,
            burst: 20.0,
            rps: 10.0,
        }
    }
}
/// Verdict returned by the rate limiter for a single request.
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitDecision {
    /// The request may proceed.
    Allow,
    /// The request must be rejected.
    Deny {
        /// Suggested wait before retrying, in milliseconds.
        retry_after_ms: u64,
    },
}

impl RateLimitDecision {
    /// `true` exactly for [`RateLimitDecision::Allow`].
    pub fn is_allowed(&self) -> bool {
        match self {
            Self::Allow => true,
            Self::Deny { .. } => false,
        }
    }

    /// The retry hint carried by a denial; `None` when the request was allowed.
    pub fn retry_after_ms(&self) -> Option<u64> {
        match *self {
            Self::Allow => None,
            Self::Deny { retry_after_ms: wait } => Some(wait),
        }
    }
}
/// Token-bucket rate limiter keyed by client id, with an optional global cap.
///
/// Each client gets its own bucket holding up to `config.burst` tokens and
/// refilling at `config.rps` tokens per second; an allowed request costs one
/// token. When `config.global_rps` is set, an allowed request must also take
/// one token from a bucket shared by all clients.
///
/// All methods take `&self`; state lives behind internal mutexes. Whenever
/// both locks are held at once, `clients` is locked before `global`.
pub struct RateLimiter {
    config: RateLimitConfig,
    /// Client id -> (that client's bucket, when the client was last seen).
    clients: Mutex<HashMap<String, (TokenBucket, Instant)>>,
    /// Bucket shared by all clients; `None` when no global limit is configured.
    global: Option<Mutex<TokenBucket>>,
}

impl RateLimiter {
    /// Creates a limiter with an empty client table.
    ///
    /// The global bucket, if enabled, starts full with room for two seconds of
    /// traffic at `global_rps` (capacity `2 * global_rps`).
    pub fn new(config: RateLimitConfig) -> Self {
        let global = config
            .global_rps
            .map(|rps| Mutex::new(TokenBucket::new(rps * 2.0, rps)));
        Self {
            config,
            clients: Mutex::new(HashMap::new()),
            global,
        }
    }

    /// Reports whether a request from `client_id` would currently be allowed,
    /// without spending any tokens.
    ///
    /// Buckets are refilled to the current time as a side effect. A client with
    /// no entry yet is treated as allowed (subject to the global limit) and is
    /// not added to the table.
    ///
    /// # Panics
    /// If a limiter mutex is poisoned.
    pub fn check(&self, client_id: &str) -> RateLimitDecision {
        if let Some(global_mutex) = &self.global {
            let mut global = global_mutex
                .lock()
                .expect("global rate limiter mutex poisoned");
            // Without a refill the stored level is whatever the last consuming
            // call left behind, so a drained bucket would keep reporting
            // "limited" (with a stale retry hint) after it has recovered.
            global.refill();
            if global.tokens < 1.0 {
                return Self::deny(&global);
            }
        }
        let mut clients = self
            .clients
            .lock()
            .expect("client rate limiter mutex poisoned");
        if let Some((bucket, _last_seen)) = clients.get_mut(client_id) {
            bucket.refill();
            if bucket.tokens < 1.0 {
                return Self::deny(bucket);
            }
        }
        RateLimitDecision::Allow
    }

    /// Decides whether a request from `client_id` may proceed and, if so,
    /// charges one token to the client's bucket and (when configured) one to
    /// the global bucket.
    ///
    /// A denied request spends nothing: the client's bucket is checked first,
    /// and the global token is only taken once the client is known to have one.
    /// A new client is added to the table (evicting others if it is full), and
    /// the client's last-seen time is refreshed even when the request is denied.
    ///
    /// # Panics
    /// If a limiter mutex is poisoned.
    pub fn check_and_consume(&self, client_id: &str) -> RateLimitDecision {
        let mut clients = self
            .clients
            .lock()
            .expect("client rate limiter mutex poisoned");
        // Read the clock after locking so no stored `last_seen` is newer than `now`.
        let now = Instant::now();

        // Only a new client needs room; known clients skip the eviction scan and
        // the key allocation.
        if !clients.contains_key(client_id) {
            self.make_room(&mut clients, now);
            let bucket = TokenBucket::new(self.config.burst, self.config.rps);
            clients.insert(client_id.to_owned(), (bucket, now));
        }
        let (bucket, last_seen) = clients
            .get_mut(client_id)
            .expect("client entry exists: found or inserted above under the same lock");
        *last_seen = now;

        bucket.refill();
        if bucket.tokens < 1.0 {
            return Self::deny(bucket);
        }

        // Lock order: `clients` (still held) before `global`.
        if let Some(global_mutex) = &self.global {
            let mut global = global_mutex
                .lock()
                .expect("global rate limiter mutex poisoned");
            if !global.try_consume(1.0) {
                return Self::deny(&global);
            }
        }

        // Still under the `clients` lock, so the token seen above is still there.
        bucket.tokens -= 1.0;
        RateLimitDecision::Allow
    }

    /// Frees space for one new client entry once the table has reached
    /// `max_clients`.
    ///
    /// Entries idle for at least `client_ttl` go first. If that is not enough,
    /// the least-recently-seen entries are dropped (linear scan each) until
    /// there is room, so the table cannot grow without bound when many distinct
    /// ids arrive within one TTL. An evicted client later starts again with a
    /// full bucket.
    fn make_room(&self, clients: &mut HashMap<String, (TokenBucket, Instant)>, now: Instant) {
        let max = self.config.max_clients;
        if clients.len() < max {
            return;
        }
        let ttl = self.config.client_ttl;
        clients.retain(|_, (_, last_seen)| now.duration_since(*last_seen) < ttl);
        while clients.len() >= max {
            let oldest = clients
                .iter()
                .min_by_key(|(_, (_, seen))| *seen)
                .map(|(id, _)| id.clone());
            match oldest {
                Some(id) => {
                    clients.remove(&id);
                }
                None => break,
            }
        }
    }

    /// Builds a denial whose retry hint is the time until `bucket` next holds a
    /// whole token, floored at 1 ms so callers never receive a "retry now" hint.
    fn deny(bucket: &TokenBucket) -> RateLimitDecision {
        RateLimitDecision::Deny {
            retry_after_ms: bucket.ms_until_available(1.0).max(1),
        }
    }

    /// Removes every client not seen within `client_ttl`.
    ///
    /// # Panics
    /// If the client mutex is poisoned.
    pub fn evict_stale(&self) {
        let ttl = self.config.client_ttl;
        let now = Instant::now();
        let mut clients = self
            .clients
            .lock()
            .expect("client rate limiter mutex poisoned");
        clients.retain(|_, (_, last_seen)| now.duration_since(*last_seen) < ttl);
    }

    /// Number of clients currently tracked.
    ///
    /// # Panics
    /// If the client mutex is poisoned.
    pub fn active_clients(&self) -> usize {
        self.clients
            .lock()
            .expect("client rate limiter mutex poisoned")
            .len()
    }

    /// Forgets `client_id`; its next request starts with a full bucket.
    ///
    /// # Panics
    /// If the client mutex is poisoned.
    pub fn reset_client(&self, client_id: &str) {
        self.clients
            .lock()
            .expect("client rate limiter mutex poisoned")
            .remove(client_id);
    }

    /// Returns `true` if a global limit is configured and its bucket, refilled
    /// to the current time, holds less than one token.
    ///
    /// # Panics
    /// If the global mutex is poisoned.
    pub fn is_global_limited(&self) -> bool {
        match &self.global {
            None => false,
            Some(global_mutex) => {
                let mut global = global_mutex
                    .lock()
                    .expect("global rate limiter mutex poisoned");
                // Refill so a bucket drained earlier is not reported as limited
                // after it has recovered.
                global.refill();
                global.tokens < 1.0
            }
        }
    }
}
/// Request-handler entry point: decides on (and, if allowed, charges) one
/// request from `client_id`.
///
/// Identical to [`RateLimiter::check_and_consume`]; the `Arc` is taken by value
/// but the limiter itself is only borrowed for the duration of the call.
pub fn rate_limit_middleware(
    limiter: std::sync::Arc<RateLimiter>,
    client_id: &str,
) -> RateLimitDecision {
    let shared: &RateLimiter = &limiter;
    shared.check_and_consume(client_id)
}
/// Derives the rate-limiting key for a request from its proxy headers.
///
/// Lookup order:
/// 1. the first comma-separated entry of `x-forwarded-for`, trimmed;
/// 2. the trimmed `x-real-ip` value;
/// 3. the literal `"unknown"`.
///
/// A header is skipped when it is absent, when `to_str()` rejects its value,
/// or when the relevant part is empty after trimming.
///
/// NOTE(review): both headers are client-controllable unless a trusted proxy
/// overwrites them. The left-most `x-forwarded-for` entry in particular is
/// whatever the client sent, so a client can rotate it to dodge its per-client
/// limit and to inflate the limiter's client table. Confirm the deployment
/// strips or rewrites these headers, or key on the proxy-appended (right-most)
/// entry or the socket peer address instead. Every request carrying neither
/// header shares the single `"unknown"` key.
#[cfg(feature = "server")]
pub fn extract_client_id(headers: &axum::http::HeaderMap) -> String {
    /// Header value as text, or `None` if missing or not representable as `&str`.
    fn header_text<'h>(headers: &'h axum::http::HeaderMap, name: &str) -> Option<&'h str> {
        headers.get(name)?.to_str().ok()
    }

    let from_forwarded = header_text(headers, "x-forwarded-for")
        .map(|list| list.split_once(',').map_or(list, |(first, _)| first).trim())
        .filter(|ip| !ip.is_empty());

    // Only consulted when the forwarded-for lookup produced nothing.
    let from_real_ip = || {
        header_text(headers, "x-real-ip")
            .map(str::trim)
            .filter(|ip| !ip.is_empty())
    };

    from_forwarded
        .or_else(from_real_ip)
        .unwrap_or("unknown")
        .to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_token_bucket_initial_full() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!((bucket.available() - 10.0).abs() < 1e-6);
    }

    #[test]
    fn test_token_bucket_consume_success() {
        let mut bucket = TokenBucket::new(10.0, 1.0);
        assert!(bucket.try_consume(5.0));
        let remaining = bucket.available();
        assert!((4.9..=5.1).contains(&remaining), "remaining={remaining}");
    }

    #[test]
    fn test_token_bucket_consume_fail_insufficient() {
        // A tiny refill rate keeps the bucket effectively empty for the second call.
        let mut bucket = TokenBucket::new(3.0, 0.01);
        assert!(bucket.try_consume(3.0));
        assert!(!bucket.try_consume(1.0));
    }

    #[test]
    fn test_token_bucket_refills_over_time() {
        // At 1000 tokens/s, 20 ms of sleep restores well over one token.
        let mut bucket = TokenBucket::new(10.0, 1000.0);
        assert!(bucket.try_consume(10.0));
        thread::sleep(Duration::from_millis(20));
        let available = bucket.available();
        assert!(
            available > 1.0,
            "bucket should have refilled; got {available}"
        );
    }

    #[test]
    fn test_rate_limiter_allows_first_request() {
        let config = RateLimitConfig {
            rps: 10.0,
            burst: 10.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        let decision = limiter.check_and_consume("client-1");
        assert_eq!(decision, RateLimitDecision::Allow);
    }

    #[test]
    fn test_rate_limiter_denies_after_burst() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 3.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        for i in 0..3 {
            let d = limiter.check_and_consume("client-burst");
            assert_eq!(d, RateLimitDecision::Allow, "request {i} should be allowed");
        }
        let denied = limiter.check_and_consume("client-burst");
        assert!(
            denied.retry_after_ms().is_some(),
            "4th request should be denied"
        );
    }

    #[test]
    fn test_rate_limiter_different_clients_independent() {
        let config = RateLimitConfig {
            rps: 1.0,
            burst: 1.0,
            ..Default::default()
        };
        let limiter = RateLimiter::new(config);
        assert_eq!(
            limiter.check_and_consume("client-a"),
            RateLimitDecision::Allow
        );
        let denied = limiter.check_and_consume("client-a");
        assert!(!denied.is_allowed());
        assert_eq!(
            limiter.check_and_consume("client-b"),
            RateLimitDecision::Allow
        );
    }

    #[test]
    fn test_rate_limit_decision_is_allowed() {
        assert!(RateLimitDecision::Allow.is_allowed());
        assert_eq!(RateLimitDecision::Allow.retry_after_ms(), None);
        let denied = RateLimitDecision::Deny {
            retry_after_ms: 500,
        };
        assert!(!denied.is_allowed());
        assert_eq!(denied.retry_after_ms(), Some(500));
    }

    // `extract_client_id` is only compiled under `#[cfg(feature = "server")]`,
    // so tests that call it (and use `axum` types) must carry the same gate or
    // `cargo test` without that feature fails to compile.
    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_x_forwarded_for() {
        use axum::http::HeaderMap;
        use axum::http::HeaderValue;
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-forwarded-for",
            HeaderValue::from_static("203.0.113.42, 10.0.0.1"),
        );
        let id = extract_client_id(&headers);
        assert_eq!(id, "203.0.113.42");
    }

    #[cfg(feature = "server")]
    #[test]
    fn test_extract_client_id_fallback() {
        use axum::http::HeaderMap;
        let headers = HeaderMap::new();
        let id = extract_client_id(&headers);
        assert_eq!(id, "unknown");
    }

    #[test]
    fn test_rate_limiter_active_clients_tracked() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        limiter.check_and_consume("alpha");
        limiter.check_and_consume("beta");
        assert_eq!(limiter.active_clients(), 2);
        limiter.reset_client("alpha");
        assert_eq!(limiter.active_clients(), 1);
    }

    #[test]
    fn test_rate_limiter_no_global_limit_by_default() {
        let limiter = RateLimiter::new(RateLimitConfig::default());
        assert!(!limiter.is_global_limited());
    }

    #[test]
    fn test_check_does_not_consume() {
        let limiter = RateLimiter::new(RateLimitConfig {
            rps: 0.01,
            burst: 1.0,
            ..Default::default()
        });
        assert!(limiter.check("peek").is_allowed());
        assert!(limiter.check("peek").is_allowed());
        assert!(limiter.check_and_consume("peek").is_allowed());
        assert!(!limiter.check("peek").is_allowed());
    }

    #[test]
    fn test_reset_client_restores_burst() {
        let limiter = RateLimiter::new(RateLimitConfig {
            rps: 0.01,
            burst: 1.0,
            ..Default::default()
        });
        assert!(limiter.check_and_consume("c").is_allowed());
        assert!(!limiter.check_and_consume("c").is_allowed());
        limiter.reset_client("c");
        assert!(limiter.check_and_consume("c").is_allowed());
    }

    #[test]
    fn test_evict_stale_with_zero_ttl_removes_all() {
        let limiter = RateLimiter::new(RateLimitConfig {
            client_ttl: Duration::ZERO,
            ..Default::default()
        });
        assert!(limiter.check_and_consume("a").is_allowed());
        assert!(limiter.check_and_consume("b").is_allowed());
        assert_eq!(limiter.active_clients(), 2);
        limiter.evict_stale();
        assert_eq!(limiter.active_clients(), 0);
    }
}