use std::collections::HashMap;
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Sliding-window limiter admitting at most `max_requests` calls per `window`.
///
/// Each admitted call is remembered by its timestamp; timestamps that are at
/// least `window` old are discarded lazily whenever the limiter is queried
/// mutably.
#[derive(Debug, Clone)]
pub struct SimpleRateLimiter {
    /// Upper bound on calls admitted inside any one window.
    max_requests: u32,
    /// Length of the sliding window.
    window: Duration,
    /// Admission timestamps, oldest at the front.
    requests: VecDeque<Instant>,
}

impl SimpleRateLimiter {
    /// Creates an empty limiter allowing `max_requests` calls per `window`.
    pub fn new(max_requests: u32, window: Duration) -> Self {
        let requests = VecDeque::with_capacity(max_requests as usize);
        SimpleRateLimiter {
            max_requests,
            window,
            requests,
        }
    }

    /// Records a call and returns `true` if the window still has room;
    /// otherwise returns `false` and records nothing.
    pub fn try_acquire(&mut self) -> bool {
        self.cleanup();
        let has_room = self.requests.len() < self.max_requests as usize;
        if has_room {
            self.requests.push_back(Instant::now());
        }
        has_room
    }

    /// How long until the oldest recorded call leaves the window.
    ///
    /// Returns `Duration::ZERO` when there is room right now. Stale entries
    /// are not removed here (this takes `&self`); an already-expired oldest
    /// entry simply yields zero.
    pub fn time_until_ready(&self) -> Duration {
        let limit = self.max_requests as usize;
        match self.requests.front() {
            Some(&oldest) if self.requests.len() >= limit => {
                self.window.saturating_sub(oldest.elapsed())
            }
            _ => Duration::ZERO,
        }
    }

    /// Number of calls currently inside the window.
    pub fn current_count(&mut self) -> u32 {
        self.cleanup();
        let live = self.requests.len();
        live as u32
    }

    /// Configured per-window call limit.
    pub fn max_requests(&self) -> u32 {
        self.max_requests
    }

    /// Calls still admissible in the current window.
    pub fn remaining(&mut self) -> u32 {
        let used = self.current_count();
        self.max_requests.saturating_sub(used)
    }

    /// Replaces local history with the server's view of `remaining` calls.
    ///
    /// The implied used count (`max_requests - remaining`, floored at 0) is
    /// recorded as calls made right now, so they expire one full window later.
    pub fn update_from_server(&mut self, remaining: u32) {
        let used = self.max_requests.saturating_sub(remaining) as usize;
        let stamp = Instant::now();
        self.requests.clear();
        self.requests.resize(used, stamp);
    }

    /// Pops timestamps from the front while they are at least `window` old.
    fn cleanup(&mut self) {
        let now = Instant::now();
        let window = self.window;
        while self
            .requests
            .front()
            .map_or(false, |&stamp| now.duration_since(stamp) >= window)
        {
            self.requests.pop_front();
        }
    }
}
/// Sliding-window limiter that budgets a total *weight* per `window`.
///
/// Every admitted request is recorded with its weight. The caller may also feed
/// in the server's own view of consumed weight via `update_from_server`; while
/// that snapshot is younger than `window` it replaces the locally tracked total,
/// plus whatever weight has been admitted locally since the snapshot was taken.
#[derive(Debug, Clone)]
pub struct WeightRateLimiter {
    /// Total weight allowed inside one window.
    max_weight: u32,
    /// Length of the sliding window (also the lifetime of a server snapshot).
    window: Duration,
    /// `(admitted_at, weight)` pairs, oldest at the front.
    entries: VecDeque<(Instant, u32)>,
    /// Most recent server-reported used weight, if still within the window.
    last_server_used: Option<u32>,
    /// When `last_server_used` was recorded; set and cleared together with it.
    last_server_update: Option<Instant>,
    /// Weight admitted locally since the last server snapshot. The snapshot
    /// cannot include these requests, so they are added on top of it.
    weight_since_server: u32,
}

impl WeightRateLimiter {
    /// Creates an empty limiter allowing `max_weight` per `window`.
    pub fn new(max_weight: u32, window: Duration) -> Self {
        Self {
            max_weight,
            window,
            entries: VecDeque::new(),
            last_server_used: None,
            last_server_update: None,
            weight_since_server: 0,
        }
    }

    /// Admits a request of `weight` if it fits in the remaining budget.
    ///
    /// Returns `true` and records the request on success; `false` otherwise.
    pub fn try_acquire(&mut self, weight: u32) -> bool {
        // current_weight() runs cleanup() first.
        let current = self.current_weight();
        // Saturate so a huge `weight` cannot wrap around and sneak under the limit.
        if current.saturating_add(weight) > self.max_weight {
            return false;
        }
        self.entries.push_back((Instant::now(), weight));
        if self.last_server_used.is_some() {
            self.weight_since_server = self.weight_since_server.saturating_add(weight);
        }
        true
    }

    /// Estimated time until a request of `weight` would be admitted.
    ///
    /// Returns `Duration::ZERO` if it fits now. A request heavier than
    /// `max_weight` can never be admitted; for it this also returns
    /// `Duration::ZERO` (and `try_acquire` will keep rejecting it).
    pub fn time_until_ready(&mut self, weight: u32) -> Duration {
        let current = self.current_weight();
        if current.saturating_add(weight) <= self.max_weight || weight > self.max_weight {
            return Duration::ZERO;
        }
        // While a snapshot is live, current_weight() reports it (plus newer local
        // weight that cannot expire first), so nothing frees up before it ages out.
        let server_wait = match self.last_server_update {
            Some(taken_at) => self.window.saturating_sub(taken_at.elapsed()),
            None => Duration::ZERO,
        };
        // After that only local entries count.
        server_wait.max(self.local_wait(weight))
    }

    /// Records the server-reported used weight as the authoritative total for
    /// the next `window`.
    pub fn update_from_server(&mut self, used_weight: u32) {
        self.last_server_used = Some(used_weight);
        self.last_server_update = Some(Instant::now());
        // Assumes the snapshot already reflects every request admitted so far.
        self.weight_since_server = 0;
    }

    /// Weight currently counted against the budget.
    ///
    /// This is the live server snapshot plus weight admitted since it, or the
    /// sum of local entries inside the window when there is no snapshot.
    pub fn current_weight(&mut self) -> u32 {
        self.cleanup();
        // cleanup() has already dropped any snapshot at least `window` old.
        match self.last_server_used {
            Some(server_used) => server_used.saturating_add(self.weight_since_server),
            None => self.local_weight(),
        }
    }

    /// Configured per-window weight limit.
    pub fn max_weight(&self) -> u32 {
        self.max_weight
    }

    /// Weight still available in the current window.
    pub fn remaining(&mut self) -> u32 {
        self.max_weight.saturating_sub(self.current_weight())
    }

    /// Saturating sum of locally recorded weights.
    fn local_weight(&self) -> u32 {
        self.entries
            .iter()
            .fold(0u32, |acc, &(_, w)| acc.saturating_add(w))
    }

    /// Time until enough of the oldest local entries expire for `weight` to fit
    /// against the local total alone. Assumes `weight <= max_weight`.
    fn local_wait(&self, weight: u32) -> Duration {
        let local_total = self.local_weight();
        let demand = local_total.saturating_add(weight);
        if demand <= self.max_weight {
            return Duration::ZERO;
        }
        let needed = demand - self.max_weight;
        let mut freed = 0u32;
        for &(stamp, w) in &self.entries {
            freed = freed.saturating_add(w);
            if freed >= needed {
                return self.window.saturating_sub(stamp.elapsed());
            }
        }
        Duration::ZERO
    }

    /// Drops entries and the server snapshot once they are at least `window` old.
    fn cleanup(&mut self) {
        let now = Instant::now();
        while let Some(&(stamp, _)) = self.entries.front() {
            if now.duration_since(stamp) < self.window {
                break;
            }
            self.entries.pop_front();
        }
        if let Some(server_time) = self.last_server_update {
            if now.duration_since(server_time) >= self.window {
                self.last_server_used = None;
                self.last_server_update = None;
                self.weight_since_server = 0;
            }
        }
    }
}
#[derive(Debug, Clone)]
pub struct DecayingRateLimiter {
max_counter: f64,
decay_rate: f64,
counter: f64,
last_update: Instant,
}
impl DecayingRateLimiter {
pub fn new(max_counter: f64, decay_rate: f64) -> Self {
Self {
max_counter,
decay_rate,
counter: 0.0,
last_update: Instant::now(),
}
}
fn apply_decay(&mut self) {
let elapsed = self.last_update.elapsed().as_secs_f64();
self.counter = (self.counter - self.decay_rate * elapsed).max(0.0);
self.last_update = Instant::now();
}
pub fn try_acquire(&mut self, cost: f64) -> bool {
self.apply_decay();
if self.counter + cost <= self.max_counter {
self.counter += cost;
true
} else {
false
}
}
pub fn time_until_ready(&mut self, cost: f64) -> Duration {
self.apply_decay();
if self.counter + cost <= self.max_counter {
return Duration::ZERO;
}
let excess = self.counter + cost - self.max_counter;
let wait_secs = excess / self.decay_rate;
Duration::from_secs_f64(wait_secs)
}
pub fn current_level(&mut self) -> f64 {
self.apply_decay();
self.counter
}
pub fn max_level(&self) -> f64 {
self.max_counter
}
pub fn remaining(&mut self) -> f64 {
self.apply_decay();
(self.max_counter - self.counter).max(0.0)
}
}
/// A set of independently budgeted `WeightRateLimiter`s keyed by group name.
///
/// Requests against a group that was never registered are not limited.
#[derive(Debug, Clone)]
pub struct GroupRateLimiter {
    /// Per-group limiters.
    groups: HashMap<&'static str, WeightRateLimiter>,
    /// Group names in the order they were first added. `HashMap` iteration
    /// order is unspecified, so this keeps `primary_stats` and `all_stats`
    /// deterministic.
    order: Vec<&'static str>,
}

impl GroupRateLimiter {
    /// Creates a limiter with no groups.
    pub fn new() -> Self {
        Self {
            groups: HashMap::new(),
            order: Vec::new(),
        }
    }

    /// Registers a group, replacing any existing limiter of the same name.
    /// A replaced group keeps its original position in the ordering.
    pub fn add_group(&mut self, name: &'static str, max_weight: u32, window: Duration) {
        let previous = self
            .groups
            .insert(name, WeightRateLimiter::new(max_weight, window));
        if previous.is_none() {
            self.order.push(name);
        }
    }

    /// Tries to admit `weight` against `group`; unknown groups always admit.
    pub fn try_acquire(&mut self, group: &str, weight: u32) -> bool {
        match self.groups.get_mut(group) {
            Some(limiter) => limiter.try_acquire(weight),
            None => true,
        }
    }

    /// Estimated wait before `weight` fits in `group`; zero for unknown groups.
    pub fn time_until_ready(&mut self, group: &str, weight: u32) -> Duration {
        match self.groups.get_mut(group) {
            Some(limiter) => limiter.time_until_ready(weight),
            None => Duration::ZERO,
        }
    }

    /// Forwards a server-reported used weight to `group`; ignored for unknown groups.
    pub fn update_from_server(&mut self, group: &str, used_weight: u32) {
        if let Some(limiter) = self.groups.get_mut(group) {
            limiter.update_from_server(used_weight);
        }
    }

    /// `(current_weight, max_weight)` for `group`, or `None` if it is unknown.
    pub fn group_stats(&mut self, group: &str) -> Option<(u32, u32)> {
        self.groups
            .get_mut(group)
            .map(|l| (l.current_weight(), l.max_weight()))
    }

    /// `(name, current_weight, max_weight)` for every group, in insertion order.
    pub fn all_stats(&mut self) -> Vec<(&str, u32, u32)> {
        let groups = &mut self.groups;
        self.order
            .iter()
            .filter_map(|&name| {
                groups
                    .get_mut(name)
                    .map(|l| (name, l.current_weight(), l.max_weight()))
            })
            .collect()
    }

    /// Stats of the first group ever added, or `(0, 0)` if there are no groups.
    pub fn primary_stats(&mut self) -> (u32, u32) {
        let primary = match self.order.first() {
            Some(&name) => name,
            None => return (0, 0),
        };
        self.group_stats(primary).unwrap_or((0, 0))
    }
}

impl Default for GroupRateLimiter {
    fn default() -> Self {
        Self::new()
    }
}
/// Unit tests for the rate limiters. Several tests sleep briefly to let
/// sliding windows expire, so they take a few hundred milliseconds in total.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_simple_rate_limiter_allows_under_limit() {
        let mut limiter = SimpleRateLimiter::new(5, Duration::from_secs(1));
        for i in 0..5 {
            assert!(limiter.try_acquire(), "Request {} should be allowed", i + 1);
        }
        assert_eq!(limiter.current_count(), 5);
        assert_eq!(limiter.remaining(), 0);
    }

    #[test]
    fn test_simple_rate_limiter_blocks_over_limit() {
        let mut limiter = SimpleRateLimiter::new(3, Duration::from_secs(1));
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire(), "4th request should be blocked");
        assert_eq!(limiter.current_count(), 3);
    }

    #[test]
    fn test_simple_rate_limiter_allows_after_window() {
        let mut limiter = SimpleRateLimiter::new(2, Duration::from_millis(100));
        assert!(limiter.try_acquire());
        assert!(limiter.try_acquire());
        assert!(!limiter.try_acquire());
        thread::sleep(Duration::from_millis(110));
        assert!(
            limiter.try_acquire(),
            "Request should be allowed after window expires"
        );
    }

    #[test]
    fn test_simple_rate_limiter_time_until_ready() {
        let mut limiter = SimpleRateLimiter::new(1, Duration::from_secs(1));
        assert!(limiter.try_acquire());
        let wait = limiter.time_until_ready();
        assert!(
            wait > Duration::from_millis(900),
            "Wait time should be close to 1 second"
        );
        assert!(
            wait <= Duration::from_secs(1),
            "Wait time should not exceed window"
        );
    }

    #[test]
    fn test_simple_rate_limiter_time_until_ready_when_available() {
        let mut limiter = SimpleRateLimiter::new(5, Duration::from_secs(1));
        assert!(limiter.try_acquire());
        let wait = limiter.time_until_ready();
        assert_eq!(
            wait,
            Duration::ZERO,
            "Should return zero wait when capacity available"
        );
    }

    #[test]
    fn test_weight_rate_limiter_allows_under_limit() {
        let mut limiter = WeightRateLimiter::new(100, Duration::from_secs(1));
        assert!(limiter.try_acquire(10));
        assert!(limiter.try_acquire(20));
        assert!(limiter.try_acquire(30));
        assert_eq!(limiter.current_weight(), 60);
        assert_eq!(limiter.remaining(), 40);
    }

    #[test]
    fn test_weight_rate_limiter_blocks_over_limit() {
        let mut limiter = WeightRateLimiter::new(50, Duration::from_secs(1));
        assert!(limiter.try_acquire(30));
        assert!(limiter.try_acquire(15));
        assert_eq!(limiter.current_weight(), 45);
        assert!(
            !limiter.try_acquire(10),
            "Request should be blocked when it would exceed limit"
        );
        assert_eq!(
            limiter.current_weight(),
            45,
            "Weight should not increase after blocked request"
        );
    }

    #[test]
    fn test_weight_rate_limiter_allows_after_window() {
        let mut limiter = WeightRateLimiter::new(50, Duration::from_millis(100));
        assert!(limiter.try_acquire(50));
        assert!(!limiter.try_acquire(1));
        thread::sleep(Duration::from_millis(110));
        assert!(
            limiter.try_acquire(50),
            "Request should be allowed after window expires"
        );
    }

    #[test]
    fn test_weight_rate_limiter_time_until_ready() {
        let mut limiter = WeightRateLimiter::new(100, Duration::from_secs(1));
        assert!(limiter.try_acquire(100));
        let wait = limiter.time_until_ready(1);
        assert!(
            wait > Duration::from_millis(900),
            "Wait time should be close to 1 second"
        );
        assert!(
            wait <= Duration::from_secs(1),
            "Wait time should not exceed window"
        );
    }

    #[test]
    fn test_weight_rate_limiter_partial_expiry() {
        let mut limiter = WeightRateLimiter::new(100, Duration::from_millis(100));
        assert!(limiter.try_acquire(50));
        thread::sleep(Duration::from_millis(60));
        assert!(limiter.try_acquire(40));
        // By now the first entry (50) is >= 100ms old and has expired.
        thread::sleep(Duration::from_millis(50));
        assert!(
            limiter.try_acquire(50),
            "Should allow request after partial expiry"
        );
    }

    #[test]
    fn test_weight_rate_limiter_server_update() {
        let mut limiter = WeightRateLimiter::new(1000, Duration::from_secs(60));
        assert!(limiter.try_acquire(100));
        assert!(limiter.try_acquire(50));
        assert_eq!(limiter.current_weight(), 150);
        limiter.update_from_server(500);
        assert_eq!(
            limiter.current_weight(),
            500,
            "Should use server-reported weight"
        );
        assert_eq!(limiter.remaining(), 500);
    }

    #[test]
    fn test_weight_rate_limiter_server_update_expires() {
        let mut limiter = WeightRateLimiter::new(1000, Duration::from_millis(100));
        limiter.update_from_server(500);
        assert_eq!(limiter.current_weight(), 500);
        thread::sleep(Duration::from_millis(110));
        // current_weight() runs cleanup itself; no explicit call is needed.
        assert_eq!(
            limiter.current_weight(),
            0,
            "Should revert to client tracking after server data expires"
        );
    }

    #[test]
    fn test_weight_rate_limiter_different_weights() {
        let mut limiter = WeightRateLimiter::new(100, Duration::from_secs(1));
        assert!(limiter.try_acquire(1));
        assert!(limiter.try_acquire(1));
        assert!(limiter.try_acquire(5));
        assert!(limiter.try_acquire(10));
        assert!(limiter.try_acquire(50));
        assert_eq!(limiter.current_weight(), 67);
        assert_eq!(limiter.remaining(), 33);
        assert!(limiter.try_acquire(33));
        assert!(!limiter.try_acquire(1), "Should be at capacity");
    }

    #[test]
    fn test_simple_rate_limiter_update_from_server() {
        let mut limiter = SimpleRateLimiter::new(10, Duration::from_secs(60));
        limiter.update_from_server(3);
        assert_eq!(limiter.remaining(), 3);
        assert_eq!(limiter.current_count(), 7);
        for _ in 0..3 {
            assert!(
                limiter.try_acquire(),
                "Should allow request within remaining capacity"
            );
        }
        assert!(
            !limiter.try_acquire(),
            "Should block when remaining exhausted"
        );
    }

    #[test]
    fn test_decaying_rate_limiter_allows_under_limit() {
        let mut limiter = DecayingRateLimiter::new(15.0, 0.33);
        for i in 0..15 {
            assert!(
                limiter.try_acquire(1.0),
                "Request {} should be allowed",
                i + 1
            );
        }
        assert!(limiter.current_level() <= 15.0);
    }

    #[test]
    fn test_decaying_rate_limiter_blocks_over_limit() {
        let mut limiter = DecayingRateLimiter::new(10.0, 1.0);
        assert!(
            limiter.try_acquire(10.0),
            "Should allow request at exactly max"
        );
        assert!(
            !limiter.try_acquire(1.0),
            "Should block when counter is at max"
        );
    }

    #[test]
    fn test_decaying_rate_limiter_decays_over_time() {
        let mut limiter = DecayingRateLimiter::new(10.0, 10.0);
        assert!(limiter.try_acquire(10.0));
        let level_before = limiter.current_level();
        assert!(
            level_before > 9.0,
            "Counter should be near 10 right after request"
        );
        thread::sleep(Duration::from_millis(200));
        let level_after = limiter.current_level();
        assert!(
            level_after < level_before,
            "Counter should decay over time: before={}, after={}",
            level_before,
            level_after
        );
    }

    #[test]
    fn test_decaying_rate_limiter_time_until_ready() {
        let mut limiter = DecayingRateLimiter::new(10.0, 10.0);
        assert!(limiter.try_acquire(10.0));
        let wait = limiter.time_until_ready(5.0);
        assert!(
            wait > Duration::from_millis(400),
            "Wait should be roughly 0.5s, got {:?}",
            wait
        );
        assert!(
            wait <= Duration::from_secs(1),
            "Wait should not exceed 1s, got {:?}",
            wait
        );
    }

    #[test]
    fn test_group_rate_limiter_independent_groups() {
        let mut limiter = GroupRateLimiter::new();
        limiter.add_group("public", 100, Duration::from_secs(10));
        limiter.add_group("private", 20, Duration::from_secs(10));
        assert!(limiter.try_acquire("private", 20));
        assert!(
            !limiter.try_acquire("private", 1),
            "private group should be at capacity"
        );
        assert!(
            limiter.try_acquire("public", 50),
            "public group should be unaffected"
        );
    }

    #[test]
    fn test_group_rate_limiter_unknown_group_allows() {
        let mut limiter = GroupRateLimiter::new();
        limiter.add_group("public", 100, Duration::from_secs(10));
        assert!(
            limiter.try_acquire("nonexistent", 9999),
            "Unknown group should return true"
        );
        assert_eq!(
            limiter.time_until_ready("nonexistent", 9999),
            Duration::ZERO,
            "Unknown group wait time should be zero"
        );
    }

    #[test]
    fn test_group_rate_limiter_all_stats() {
        let mut limiter = GroupRateLimiter::new();
        limiter.add_group("spot", 50, Duration::from_secs(10));
        limiter.add_group("futures", 200, Duration::from_secs(10));
        // The stats assertions below depend on these being admitted.
        assert!(limiter.try_acquire("spot", 10));
        assert!(limiter.try_acquire("futures", 40));
        let stats = limiter.all_stats();
        assert_eq!(
            stats.len(),
            2,
            "all_stats should return one entry per group"
        );
        for (name, current, max) in &stats {
            match *name {
                "spot" => {
                    assert_eq!(*max, 50);
                    assert_eq!(*current, 10);
                }
                "futures" => {
                    assert_eq!(*max, 200);
                    assert_eq!(*current, 40);
                }
                other => panic!("Unexpected group name: {}", other),
            }
        }
    }
}