use crate::relay::{RelayError, RelayResult};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Admission control keyed by the remote socket address.
///
/// Implementors must be `Send + Sync` and every method takes `&self`, so any
/// per-address state has to live behind interior mutability.
pub trait RateLimiter: Sync + Send {
    /// Accounts for one request from `addr`.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the request from `addr` should be rejected.
    fn check_rate_limit(&self, addr: &SocketAddr) -> RelayResult<()>;

    /// Discards whatever limiter state is held for `addr`.
    fn reset(&self, addr: &SocketAddr);

    /// Discards per-address state the implementation considers stale.
    fn cleanup_expired(&self);
}
/// A rate limiter that gives every socket address its own token bucket.
///
/// A bucket starts full with `max_tokens` tokens and refills continuously at
/// `tokens_per_second`. Each admitted request spends one token.
///
/// NOTE(review): buckets are keyed by the full `SocketAddr` (IP *and* port), so a
/// peer that rotates source ports gets a fresh bucket per port — confirm intended.
/// Entries are only removed by `reset` / `cleanup_expired`, so the map grows with
/// every new address until one of those is called.
#[derive(Debug)]
pub struct TokenBucket {
    /// Continuous refill rate; the constructor rejects zero.
    tokens_per_second: u32,
    /// Bucket capacity and initial burst size; the constructor rejects zero.
    max_tokens: u32,
    /// Per-address bucket state behind a mutex so the limiter can be driven via `&self`.
    buckets: Arc<Mutex<HashMap<SocketAddr, BucketState>>>,
}
/// Mutable state of a single address's bucket.
#[derive(Clone, Debug)]
struct BucketState {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// The instant at which `tokens` was last brought up to date.
    last_update: Instant,
}
impl TokenBucket {
    /// Creates a limiter whose per-address buckets refill at `tokens_per_second` and
    /// hold at most `max_tokens`. `max_tokens` is also the initial burst allowance.
    ///
    /// # Errors
    ///
    /// Returns `RelayError::ConfigurationError` naming the offending parameter if
    /// either argument is zero.
    pub fn new(tokens_per_second: u32, max_tokens: u32) -> RelayResult<Self> {
        if tokens_per_second == 0 {
            return Err(RelayError::ConfigurationError {
                parameter: "tokens_per_second".to_string(),
                reason: "must be greater than 0".to_string(),
            });
        }
        if max_tokens == 0 {
            return Err(RelayError::ConfigurationError {
                parameter: "max_tokens".to_string(),
                reason: "must be greater than 0".to_string(),
            });
        }
        Ok(Self {
            tokens_per_second,
            max_tokens,
            buckets: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Refills `addr`'s bucket for the time elapsed since its last update, then tries
    /// to spend one token.
    ///
    /// An address seen for the first time starts with a full bucket.
    ///
    /// # Errors
    ///
    /// Returns `RelayError::RateLimitExceeded` when fewer than one token is available.
    /// `retry_after_ms` is the wait, rounded up to whole milliseconds, until one full
    /// token will have accrued.
    fn try_consume_token(&self, addr: &SocketAddr) -> RelayResult<()> {
        // A poisoned lock only means another thread panicked while holding it. The
        // map holds plain counters and timestamps that remain usable, so recover the
        // guard instead of turning one panic into a panic on every later request.
        let mut buckets = self
            .buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let now = Instant::now();
        let max_tokens = f64::from(self.max_tokens);
        let rate = f64::from(self.tokens_per_second);

        let state = buckets.entry(*addr).or_insert_with(|| BucketState {
            tokens: max_tokens,
            last_update: now,
        });

        let elapsed_seconds = now.duration_since(state.last_update).as_secs_f64();
        state.tokens = (state.tokens + elapsed_seconds * rate).min(max_tokens);
        state.last_update = now;

        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            return Ok(());
        }

        let tokens_needed = 1.0 - state.tokens;
        // Round up. Truncating under-reports the wait (down to 0 ms when the bucket
        // is just short of a token), so a client retrying exactly after the hint
        // would be rejected again.
        let retry_after_ms = (tokens_needed / rate * 1000.0).ceil() as u64;
        Err(RelayError::RateLimitExceeded { retry_after_ms })
    }
}
impl RateLimiter for TokenBucket {
    /// Spends one token from `addr`'s bucket, refilling it first.
    fn check_rate_limit(&self, addr: &SocketAddr) -> RelayResult<()> {
        self.try_consume_token(addr)
    }

    /// Forgets `addr`'s bucket, so its next request is treated as a first contact.
    fn reset(&self, addr: &SocketAddr) {
        // Poison recovery: the map holds plain counters/timestamps, still usable.
        self.buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .remove(addr);
    }

    /// Drops buckets that have sat idle for at least the cleanup threshold.
    ///
    /// The threshold is 300 s. It is raised to the full-refill time
    /// (`max_tokens / tokens_per_second`, rounded up) when that is longer. Evicting a
    /// bucket that has not yet refilled would give its owner a full bucket early.
    /// Once the refill time has elapsed, a bucket is equivalent to a freshly created
    /// one, so dropping it changes nothing.
    fn cleanup_expired(&self) {
        const MIN_IDLE_THRESHOLD: Duration = Duration::from_secs(300);
        let refill_secs =
            (f64::from(self.max_tokens) / f64::from(self.tokens_per_second)).ceil() as u64;
        let cleanup_threshold = MIN_IDLE_THRESHOLD.max(Duration::from_secs(refill_secs));

        let mut buckets = self
            .buckets
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Sample the clock only after acquiring the lock, so no entry's `last_update`
        // can be later than `now`.
        let now = Instant::now();
        buckets.retain(|_, state| now.duration_since(state.last_update) < cleanup_threshold);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use std::thread;
    use std::time::{Duration, Instant};

    /// Loopback address shared by the single-peer tests.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    #[test]
    fn test_token_bucket_creation() {
        let bucket = TokenBucket::new(10, 100).unwrap();
        assert_eq!(bucket.tokens_per_second, 10);
        assert_eq!(bucket.max_tokens, 100);
    }

    #[test]
    fn test_token_bucket_invalid_config() {
        assert!(TokenBucket::new(0, 100).is_err());
        assert!(TokenBucket::new(10, 0).is_err());
    }

    #[test]
    fn test_rate_limiting_allows_initial_requests() {
        let bucket = TokenBucket::new(10, 100).unwrap();
        let addr = test_addr();
        for _ in 0..100 {
            assert!(bucket.check_rate_limit(&addr).is_ok());
        }
        assert!(bucket.check_rate_limit(&addr).is_err());
    }

    #[test]
    fn test_token_replenishment() {
        let bucket = TokenBucket::new(10, 10).unwrap();
        let addr = test_addr();
        for _ in 0..10 {
            assert!(bucket.check_rate_limit(&addr).is_ok());
        }
        assert!(bucket.check_rate_limit(&addr).is_err());
        // At 10 tokens/s, 100 ms refills exactly one token.
        thread::sleep(Duration::from_millis(100));
        assert!(bucket.check_rate_limit(&addr).is_ok());
    }

    #[test]
    fn test_per_address_isolation() {
        let bucket = TokenBucket::new(1, 1).unwrap();
        let addr1 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
        let addr2 = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2)), 8080);
        assert!(bucket.check_rate_limit(&addr1).is_ok());
        assert!(bucket.check_rate_limit(&addr1).is_err());
        assert!(bucket.check_rate_limit(&addr2).is_ok());
    }

    #[test]
    fn test_reset_functionality() {
        let bucket = TokenBucket::new(1, 1).unwrap();
        let addr = test_addr();
        assert!(bucket.check_rate_limit(&addr).is_ok());
        assert!(bucket.check_rate_limit(&addr).is_err());
        bucket.reset(&addr);
        assert!(bucket.check_rate_limit(&addr).is_ok());
    }

    #[test]
    fn test_cleanup_expired() {
        let bucket = TokenBucket::new(10, 10).unwrap();
        let addr = test_addr();
        assert!(bucket.check_rate_limit(&addr).is_ok());
        {
            let buckets = bucket.buckets.lock().unwrap();
            assert!(buckets.contains_key(&addr));
        }
        // A just-used bucket is not idle, so cleanup must keep it.
        bucket.cleanup_expired();
        {
            let buckets = bucket.buckets.lock().unwrap();
            assert!(buckets.contains_key(&addr));
        }
    }

    #[test]
    fn test_cleanup_evicts_stale_entries() {
        let bucket = TokenBucket::new(10, 10).unwrap();
        let addr = test_addr();
        assert!(bucket.check_rate_limit(&addr).is_ok());
        // Backdate the entry well past the 300 s idle threshold. `checked_sub` yields
        // `None` when the monotonic clock started less than that long ago (e.g. just
        // after boot). Staleness cannot be simulated then, so the check is skipped.
        if let Some(stale) = Instant::now().checked_sub(Duration::from_secs(400)) {
            bucket
                .buckets
                .lock()
                .unwrap()
                .get_mut(&addr)
                .unwrap()
                .last_update = stale;
            bucket.cleanup_expired();
            assert!(!bucket.buckets.lock().unwrap().contains_key(&addr));
        }
    }

    #[test]
    fn test_rate_limit_error_retry_calculation() {
        let bucket = TokenBucket::new(2, 1).unwrap();
        let addr = test_addr();
        assert!(bucket.check_rate_limit(&addr).is_ok());
        // One token at 2 tokens/s takes ~500 ms.
        match bucket.check_rate_limit(&addr) {
            Err(RelayError::RateLimitExceeded { retry_after_ms }) => {
                assert!((400..=600).contains(&retry_after_ms));
            }
            _ => panic!("Expected RateLimitExceeded error"),
        }
    }
}