use std::collections::HashMap;
use std::time::{Duration, SystemTime};
/// Tuning knobs for [`AdaptiveRateLimiter`].
///
/// A peer's limit per window is `base_rate` scaled linearly by its reputation
/// score (clamped to `0.0..=1.0`) from a factor of `1.0` up to
/// `reputation_multiplier`, then bounded by `min_rate` / `max_rate`.
#[derive(Debug, Clone, PartialEq)]
pub struct AdaptiveRateLimitConfig {
    /// Requests per window for a peer with reputation `0.0` (before bounding).
    pub base_rate: u64,
    /// Length of the sliding window, in seconds.
    pub base_window_secs: u64,
    /// Lower bound on any computed limit.
    pub min_rate: u64,
    /// Upper bound on any computed limit.
    pub max_rate: u64,
    /// Factor applied to `base_rate` for a peer with reputation `1.0`.
    pub reputation_multiplier: f64,
    /// Factor applied to the computed limit to obtain the hard cap on
    /// admitted requests per window.
    pub burst_multiplier: f64,
    /// Minimum number of seconds between sweeps that evict idle peers.
    pub cleanup_interval_secs: u64,
}
impl Default for AdaptiveRateLimitConfig {
    /// Baseline tuning: 100 requests per 60-second window, limits bounded to
    /// `10..=1000`, a 2.0 reputation factor, 1.5x burst headroom, and idle
    /// sweeps at most every 300 seconds.
    fn default() -> Self {
        // Sustained rate and the window it is measured over.
        let (base_rate, base_window_secs) = (100, 60);
        // Hard bounds on any reputation-scaled limit.
        let (min_rate, max_rate) = (10, 1000);
        let reputation_multiplier = 2.0;
        let burst_multiplier = 1.5;
        let cleanup_interval_secs = 300;
        Self {
            base_rate,
            base_window_secs,
            min_rate,
            max_rate,
            reputation_multiplier,
            burst_multiplier,
            cleanup_interval_secs,
        }
    }
}
/// One admitted request held in a peer's sliding window.
#[derive(Debug, Clone)]
struct RequestRecord {
/// Wall-clock time at which the request was admitted.
timestamp: SystemTime,
/// Number of requests this record stands for; `check_rate_limit` always records 1.
count: u64,
}
/// Sliding-window state tracked for a single peer.
#[derive(Debug, Clone)]
struct PeerRateLimit {
/// Limit most recently computed from the peer's reputation score.
current_limit: u64,
/// Admitted requests in insertion order; expired entries are pruned lazily
/// when the limiter next examines this peer.
requests: Vec<RequestRecord>,
/// When the peer was first tracked or last passed to `reset_peer`.
last_reset: SystemTime,
/// Requests admitted while the peer has been tracked.
total_requests: u64,
/// Requests rejected while the peer has been tracked.
violations: u64,
}
/// Per-peer sliding-window rate limiter whose limits scale with peer reputation.
#[derive(Debug)]
pub struct AdaptiveRateLimiter {
    /// Tuning applied to every peer.
    config: AdaptiveRateLimitConfig,
    /// Per-peer state, keyed by the `peer_id` passed to the limiter's methods.
    peer_limits: HashMap<String, PeerRateLimit>,
    /// Wall-clock time of the last idle-peer sweep (or of construction/`clear`).
    last_cleanup: SystemTime,
}
impl AdaptiveRateLimiter {
#[must_use]
#[inline]
pub fn new(config: AdaptiveRateLimitConfig) -> Self {
Self {
config,
peer_limits: HashMap::new(),
last_cleanup: SystemTime::now(),
}
}
pub fn check_rate_limit(&mut self, peer_id: &str, reputation_score: f64) -> bool {
self.maybe_cleanup();
let now = SystemTime::now();
let calculated_limit = self.calculate_limit(reputation_score);
let state = self
.peer_limits
.entry(peer_id.to_string())
.or_insert_with(|| PeerRateLimit {
current_limit: calculated_limit,
requests: Vec::new(),
last_reset: now,
total_requests: 0,
violations: 0,
});
state.current_limit = calculated_limit;
let window = Duration::from_secs(self.config.base_window_secs);
state.requests.retain(|r| {
if let Ok(age) = now.duration_since(r.timestamp) {
age < window
} else {
false
}
});
let current_count: u64 = state.requests.iter().map(|r| r.count).sum();
let burst_limit = (state.current_limit as f64 * self.config.burst_multiplier) as u64;
if current_count < burst_limit {
state.requests.push(RequestRecord {
timestamp: now,
count: 1,
});
state.total_requests += 1;
true
} else {
state.violations += 1;
false
}
}
#[inline]
fn calculate_limit(&self, reputation_score: f64) -> u64 {
let reputation_score = reputation_score.clamp(0.0, 1.0);
let multiplier = 1.0 + (reputation_score * (self.config.reputation_multiplier - 1.0));
let limit = (self.config.base_rate as f64 * multiplier) as u64;
limit.clamp(self.config.min_rate, self.config.max_rate)
}
#[must_use]
#[inline]
pub fn get_limit(&mut self, peer_id: &str, reputation_score: f64) -> u64 {
let limit = self.calculate_limit(reputation_score);
if let Some(state) = self.peer_limits.get_mut(peer_id) {
state.current_limit = limit;
}
limit
}
#[must_use]
#[inline]
pub fn get_remaining(&mut self, peer_id: &str, reputation_score: f64) -> u64 {
let now = SystemTime::now();
let window = Duration::from_secs(self.config.base_window_secs);
let state = match self.peer_limits.get_mut(peer_id) {
Some(s) => s,
None => return self.calculate_limit(reputation_score),
};
state.requests.retain(|r| {
if let Ok(age) = now.duration_since(r.timestamp) {
age < window
} else {
false
}
});
let current_count: u64 = state.requests.iter().map(|r| r.count).sum();
let limit = self.calculate_limit(reputation_score);
limit.saturating_sub(current_count)
}
#[must_use]
#[inline]
pub fn get_reset_time(&self, peer_id: &str) -> Option<Duration> {
let state = self.peer_limits.get(peer_id)?;
let now = SystemTime::now();
let oldest = state.requests.iter().min_by_key(|r| r.timestamp)?;
let window = Duration::from_secs(self.config.base_window_secs);
let age = now.duration_since(oldest.timestamp).ok()?;
if age < window {
Some(window - age)
} else {
Some(Duration::from_secs(0))
}
}
#[inline]
pub fn reset_peer(&mut self, peer_id: &str) {
if let Some(state) = self.peer_limits.get_mut(peer_id) {
state.requests.clear();
state.last_reset = SystemTime::now();
}
}
#[must_use]
#[inline]
pub fn get_peer_stats(&self, peer_id: &str) -> Option<PeerRateLimitStats> {
let state = self.peer_limits.get(peer_id)?;
let current_count: u64 = state.requests.iter().map(|r| r.count).sum();
Some(PeerRateLimitStats {
current_limit: state.current_limit,
current_usage: current_count,
total_requests: state.total_requests,
violations: state.violations,
})
}
#[must_use]
#[inline]
pub fn get_global_stats(&self) -> GlobalRateLimitStats {
let total_peers = self.peer_limits.len();
let total_requests: u64 = self.peer_limits.values().map(|s| s.total_requests).sum();
let total_violations: u64 = self.peer_limits.values().map(|s| s.violations).sum();
GlobalRateLimitStats {
total_peers,
total_requests,
total_violations,
}
}
#[inline]
fn maybe_cleanup(&mut self) {
let now = SystemTime::now();
if let Ok(duration) = now.duration_since(self.last_cleanup) {
if duration.as_secs() < self.config.cleanup_interval_secs {
return;
}
}
let cleanup_threshold = Duration::from_secs(self.config.base_window_secs * 5);
self.peer_limits.retain(|_, state| {
if state.requests.is_empty() {
if let Ok(age) = now.duration_since(state.last_reset) {
age < cleanup_threshold
} else {
true
}
} else {
true
}
});
self.last_cleanup = now;
}
#[inline]
pub fn remove_peer(&mut self, peer_id: &str) {
self.peer_limits.remove(peer_id);
}
#[must_use]
#[inline]
pub fn peer_count(&self) -> usize {
self.peer_limits.len()
}
#[inline]
pub fn clear(&mut self) {
self.peer_limits.clear();
self.last_cleanup = SystemTime::now();
}
}
/// Snapshot of one peer's counters, as returned by
/// [`AdaptiveRateLimiter::get_peer_stats`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeerRateLimitStats {
    /// Limit most recently computed for the peer from its reputation score.
    pub current_limit: u64,
    /// Window usage reported by [`AdaptiveRateLimiter::get_peer_stats`].
    pub current_usage: u64,
    /// Requests admitted while the peer has been tracked.
    pub total_requests: u64,
    /// Requests rejected while the peer has been tracked.
    pub violations: u64,
}
/// Counters aggregated over the peers currently tracked by an
/// [`AdaptiveRateLimiter`]; removed or evicted peers no longer contribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GlobalRateLimitStats {
    /// Number of peers currently tracked.
    pub total_peers: usize,
    /// Sum of admitted requests over the tracked peers.
    pub total_requests: u64,
    /// Sum of rejected requests over the tracked peers.
    pub total_violations: u64,
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `AdaptiveRateLimiter`.
    use super::*;
    use std::thread;

    #[test]
    fn test_basic_rate_limiting() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            base_window_secs: 1,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        for _ in 0..10 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }
        assert!(!limiter.check_rate_limit("peer1", 0.5));
    }

    #[test]
    fn test_reputation_based_limits() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 100,
            reputation_multiplier: 3.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        let low_limit = limiter.get_limit("peer1", 0.1);
        let high_limit = limiter.get_limit("peer2", 0.9);
        assert!(high_limit > low_limit);
    }

    #[test]
    fn test_window_expiration() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 5,
            base_window_secs: 1,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            min_rate: 1,
            max_rate: 1000,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }
        assert!(!limiter.check_rate_limit("peer1", 0.5));
        // Wait just past the 1-second window so every record expires.
        thread::sleep(Duration::from_millis(1100));
        assert!(limiter.check_rate_limit("peer1", 0.5));
    }

    #[test]
    fn test_burst_allowance() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            burst_multiplier: 2.0,
            reputation_multiplier: 1.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        for _ in 0..20 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }
        assert!(!limiter.check_rate_limit("peer1", 0.5));
    }

    #[test]
    fn test_get_remaining() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        assert_eq!(limiter.get_remaining("peer1", 0.5), 10);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);
        assert_eq!(limiter.get_remaining("peer1", 0.5), 7);
    }

    #[test]
    fn test_reset_peer() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 5,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            min_rate: 1,
            max_rate: 1000,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        for _ in 0..5 {
            assert!(limiter.check_rate_limit("peer1", 0.5));
        }
        assert_eq!(limiter.get_remaining("peer1", 0.5), 0);
        limiter.reset_peer("peer1");
        assert_eq!(limiter.get_remaining("peer1", 0.5), 5);
    }

    #[test]
    fn test_peer_stats() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 10,
            burst_multiplier: 1.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer1", 0.5);
        let stats = limiter.get_peer_stats("peer1").unwrap();
        assert_eq!(stats.total_requests, 3);
        assert_eq!(stats.current_usage, 3);
    }

    #[test]
    fn test_violation_tracking() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 2,
            burst_multiplier: 1.0,
            reputation_multiplier: 1.0,
            min_rate: 1,
            max_rate: 1000,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        assert!(limiter.check_rate_limit("peer1", 0.5));
        assert!(limiter.check_rate_limit("peer1", 0.5));
        // Each rejected request counts as a separate violation.
        assert!(!limiter.check_rate_limit("peer1", 0.5));
        assert!(!limiter.check_rate_limit("peer1", 0.5));
        let stats = limiter.get_peer_stats("peer1").unwrap();
        assert_eq!(stats.violations, 2);
    }

    #[test]
    fn test_global_stats() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer2", 0.5);
        limiter.check_rate_limit("peer3", 0.5);
        let stats = limiter.get_global_stats();
        assert_eq!(stats.total_peers, 3);
        assert_eq!(stats.total_requests, 3);
    }

    #[test]
    fn test_min_max_limits() {
        let config = AdaptiveRateLimitConfig {
            base_rate: 100,
            min_rate: 50,
            max_rate: 200,
            reputation_multiplier: 10.0,
            ..Default::default()
        };
        let mut limiter = AdaptiveRateLimiter::new(config);
        // Reputation 0.0 leaves base_rate (100) untouched, inside the bounds.
        let low_limit = limiter.get_limit("peer1", 0.0);
        assert_eq!(low_limit, 100);
        // Reputation 1.0 gives 1000, which max_rate caps at 200.
        let high_limit = limiter.get_limit("peer2", 1.0);
        assert_eq!(high_limit, 200);
    }

    #[test]
    fn test_reputation_is_clamped() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);
        let at_one = limiter.get_limit("peer1", 1.0);
        let above_one = limiter.get_limit("peer1", 5.0);
        assert_eq!(above_one, at_one);
        let at_zero = limiter.get_limit("peer1", 0.0);
        let below_zero = limiter.get_limit("peer1", -3.0);
        assert_eq!(below_zero, at_zero);
    }

    #[test]
    fn test_unknown_peer_queries() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);
        assert!(limiter.get_peer_stats("ghost").is_none());
        assert!(limiter.get_reset_time("ghost").is_none());
        // Default config: 100 * (1 + 0.5 * (2.0 - 1.0)) = 150.
        assert_eq!(limiter.get_remaining("ghost", 0.5), 150);
        // Read-only queries must not start tracking the peer.
        assert_eq!(limiter.peer_count(), 0);
    }

    #[test]
    fn test_reset_time_after_request() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);
        assert!(limiter.check_rate_limit("peer1", 0.5));
        let reset = limiter
            .get_reset_time("peer1")
            .expect("peer has a freshly recorded request");
        assert!(reset > Duration::from_secs(0));
        assert!(reset <= Duration::from_secs(60));
    }

    #[test]
    fn test_remove_peer() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);
        limiter.check_rate_limit("peer1", 0.5);
        assert_eq!(limiter.peer_count(), 1);
        limiter.remove_peer("peer1");
        assert_eq!(limiter.peer_count(), 0);
    }

    #[test]
    fn test_clear() {
        let config = AdaptiveRateLimitConfig::default();
        let mut limiter = AdaptiveRateLimiter::new(config);
        limiter.check_rate_limit("peer1", 0.5);
        limiter.check_rate_limit("peer2", 0.5);
        limiter.check_rate_limit("peer3", 0.5);
        assert_eq!(limiter.peer_count(), 3);
        limiter.clear();
        assert_eq!(limiter.peer_count(), 0);
    }
}