use std::collections::VecDeque;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
/// Settings for a [`RateLimiter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Tokens credited per second of elapsed time, i.e. the sustained throughput.
    pub entities_per_second: f64,
    /// Bucket capacity: the most tokens that can accumulate, and the initial fill.
    pub burst_size: u32,
    /// What the limiter does when an acquisition finds no whole token.
    pub backpressure: RateLimitBackpressure,
    /// When `false`, every acquisition proceeds immediately without spending tokens.
    pub enabled: bool,
}
impl Default for RateLimitConfig {
    /// Enabled, 1000 entities per second, bursts of up to 100, blocking backpressure.
    fn default() -> Self {
        Self {
            enabled: true,
            burst_size: 100,
            entities_per_second: 1_000.0,
            backpressure: RateLimitBackpressure::Block,
        }
    }
}
/// Policy applied by [`RateLimiter::acquire`] when the bucket holds no whole token.
///
/// Variant names are serialized in `snake_case` (`"block"`, `"drop"`, `"buffer"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitBackpressure {
    /// Sleep until a token accrues, then proceed. This is the default.
    #[default]
    Block,
    /// Give up immediately and report [`RateLimitAction::Dropped`].
    Drop,
    /// Record the entity in the limiter's internal queue (drained by
    /// `process_buffer`) while fewer than `max_buffered` are pending; once the
    /// queue is full, block exactly like [`RateLimitBackpressure::Block`].
    Buffer { max_buffered: usize },
}
/// Outcome of a single acquisition attempt on a [`RateLimiter`].
#[derive(Debug, Clone, PartialEq)]
pub enum RateLimitAction {
    /// A token was available immediately (or the limiter is disabled).
    Proceed,
    /// No token was available and the entity was discarded.
    Dropped,
    /// The entity was queued; `position` is its 1-based place in the queue.
    Buffered { position: usize },
    /// The caller was blocked until a token accrued; `wait_time_ms` is the
    /// planned sleep in whole milliseconds.
    Waited { wait_time_ms: u64 },
}
/// Counters describing a [`RateLimiter`]'s activity, as returned by `stats()`.
#[derive(Debug, Clone, Default)]
pub struct RateLimiterStats {
    /// Acquisition attempts recorded by the limiter.
    pub total_acquisitions: u64,
    /// Acquisitions that proceeded without waiting.
    pub immediate_proceeds: u64,
    /// Acquisitions that slept until a token accrued.
    pub waits: u64,
    /// Acquisitions rejected with [`RateLimitAction::Dropped`].
    pub drops: u64,
    /// Entities placed in the buffer.
    pub buffers: u64,
    /// Sum of all planned waits, in milliseconds.
    pub total_wait_time_ms: u64,
    /// Token balance when the snapshot was taken.
    pub current_tokens: f64,
    /// Entities pending in the buffer when the snapshot was taken.
    pub buffer_size: usize,
}
/// Token-bucket rate limiter with configurable backpressure.
///
/// Tokens accrue at `config.entities_per_second` up to a cap set by
/// `config.burst_size`; each admitted entity spends one token.
pub struct RateLimiter {
    /// Behaviour settings: rate, burst, backpressure policy and enabled flag.
    config: RateLimitConfig,
    /// Token balance as of `last_refill`.
    tokens: f64,
    /// When tokens were last credited.
    last_refill: Instant,
    /// Enqueue instants of entities held under `Buffer` backpressure, oldest first.
    buffer: VecDeque<Instant>,
    /// Running counters; `stats()` refreshes `current_tokens`/`buffer_size` on read.
    stats: RateLimiterStats,
}
impl RateLimiter {
    /// Creates a limiter from `config`, starting with a full bucket of
    /// `config.burst_size` tokens.
    pub fn new(config: RateLimitConfig) -> Self {
        let initial_tokens = f64::from(config.burst_size);
        Self {
            tokens: initial_tokens,
            last_refill: Instant::now(),
            buffer: VecDeque::new(),
            stats: RateLimiterStats {
                current_tokens: initial_tokens,
                ..Default::default()
            },
            config,
        }
    }

    /// Creates an enabled limiter refilling at `entities_per_second`, with the
    /// default burst size and backpressure policy.
    pub fn with_rate(entities_per_second: f64) -> Self {
        Self::new(RateLimitConfig {
            entities_per_second,
            ..Default::default()
        })
    }

    /// Creates a limiter whose acquisitions always proceed immediately.
    pub fn disabled() -> Self {
        Self::new(RateLimitConfig {
            enabled: false,
            ..Default::default()
        })
    }

    /// Acquires one token, applying the configured backpressure policy when
    /// the bucket holds no whole token.
    ///
    /// Every call is counted in `total_acquisitions`.
    ///
    /// # Panics
    ///
    /// Under `Block` backpressure (or `Buffer` with a full buffer), panics if
    /// no token can ever accrue: `entities_per_second` is zero, negative, NaN,
    /// or so small that the required wait does not fit in a [`Duration`].
    pub fn acquire(&mut self) -> RateLimitAction {
        self.stats.total_acquisitions += 1;
        if !self.config.enabled {
            self.stats.immediate_proceeds += 1;
            return RateLimitAction::Proceed;
        }
        if self.take_available_token() {
            return RateLimitAction::Proceed;
        }
        match self.config.backpressure {
            RateLimitBackpressure::Block => {
                let waited_ms = self.wait_for_token();
                self.record_wait(waited_ms)
            }
            RateLimitBackpressure::Drop => {
                self.stats.drops += 1;
                RateLimitAction::Dropped
            }
            RateLimitBackpressure::Buffer { max_buffered } if self.buffer.len() < max_buffered => {
                self.buffer.push_back(Instant::now());
                self.stats.buffers += 1;
                self.stats.buffer_size = self.buffer.len();
                RateLimitAction::Buffered {
                    position: self.buffer.len(),
                }
            }
            // Buffer is full: fall back to blocking.
            RateLimitBackpressure::Buffer { .. } => {
                let waited_ms = self.wait_for_token();
                self.record_wait(waited_ms)
            }
        }
    }

    /// Takes a token only if one is available right now; never waits, drops
    /// or buffers.
    ///
    /// Returns `None`, recording nothing, when the bucket is empty.
    pub fn try_acquire(&mut self) -> Option<RateLimitAction> {
        if !self.config.enabled {
            self.stats.total_acquisitions += 1;
            self.stats.immediate_proceeds += 1;
            return Some(RateLimitAction::Proceed);
        }
        if self.take_available_token() {
            self.stats.total_acquisitions += 1;
            Some(RateLimitAction::Proceed)
        } else {
            None
        }
    }

    /// Acquires a token, waiting at most `timeout` for one to accrue.
    ///
    /// - A token is available now: returns `Some(Proceed)`.
    /// - A token will accrue within `timeout`: sleeps, then returns
    ///   `Some(Waited { .. })`, whatever the backpressure policy.
    /// - Otherwise: returns `Some(Dropped)` under `Drop` backpressure, and
    ///   `None` under any other policy.
    ///
    /// `None` means nothing was acquired, so (as with [`Self::try_acquire`])
    /// it is not counted in `total_acquisitions`. A rate that can never
    /// produce a token (zero, negative, NaN) is treated as exceeding every
    /// timeout instead of panicking.
    pub fn acquire_timeout(&mut self, timeout: Duration) -> Option<RateLimitAction> {
        if !self.config.enabled {
            self.stats.total_acquisitions += 1;
            self.stats.immediate_proceeds += 1;
            return Some(RateLimitAction::Proceed);
        }
        if self.take_available_token() {
            self.stats.total_acquisitions += 1;
            return Some(RateLimitAction::Proceed);
        }
        match self.time_until_token() {
            Some(wait) if wait <= timeout => {
                self.stats.total_acquisitions += 1;
                let waited_ms = self.sleep_then_spend(wait);
                Some(self.record_wait(waited_ms))
            }
            _ if matches!(self.config.backpressure, RateLimitBackpressure::Drop) => {
                self.stats.total_acquisitions += 1;
                self.stats.drops += 1;
                Some(RateLimitAction::Dropped)
            }
            _ => None,
        }
    }

    /// Returns a snapshot of the counters with the current token balance and
    /// buffer length filled in.
    pub fn stats(&self) -> RateLimiterStats {
        let mut stats = self.stats.clone();
        stats.current_tokens = self.tokens;
        stats.buffer_size = self.buffer.len();
        stats
    }

    /// Refills the bucket to `burst_size`, empties the buffer and zeroes all counters.
    pub fn reset(&mut self) {
        self.tokens = f64::from(self.config.burst_size);
        self.last_refill = Instant::now();
        self.buffer.clear();
        self.stats = RateLimiterStats {
            current_tokens: self.tokens,
            ..Default::default()
        };
    }

    /// Token balance as of the most recent refill. Time elapsed since then is
    /// not yet credited.
    pub fn available_tokens(&self) -> f64 {
        self.tokens
    }

    /// The configuration this limiter runs with.
    pub fn config(&self) -> &RateLimitConfig {
        &self.config
    }

    /// Changes the refill rate.
    ///
    /// The next refill uses the new rate, including for the interval since the
    /// previous refill. Zero, negative or NaN rates stop refilling altogether.
    pub fn set_rate(&mut self, entities_per_second: f64) {
        self.config.entities_per_second = entities_per_second;
    }

    /// Enables or disables limiting. While disabled, every acquisition proceeds.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.config.enabled = enabled;
    }

    /// Credits tokens for the time elapsed since the last refill, capped at
    /// the bucket capacity.
    fn refill_tokens(&mut self) {
        let now = Instant::now();
        let elapsed = now.duration_since(self.last_refill);
        // `f64::max` returns the non-NaN operand, so negative and NaN rates
        // both become "no refill" instead of draining or poisoning the balance.
        let rate = self.config.entities_per_second.max(0.0);
        let new_tokens = elapsed.as_secs_f64() * rate;
        self.tokens = (self.tokens + new_tokens).min(self.capacity());
        self.last_refill = now;
    }

    /// Maximum token balance.
    ///
    /// The capacity is at least one token. Otherwise `burst_size: 0` would never
    /// hold a whole token: immediate acquisitions would always fail and blocking
    /// ones would run at half the configured rate.
    fn capacity(&self) -> f64 {
        f64::from(self.config.burst_size.max(1))
    }

    /// Refills, then spends one token if a whole token is available.
    ///
    /// On success, records an immediate proceed and returns `true`.
    fn take_available_token(&mut self) -> bool {
        self.refill_tokens();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            self.stats.current_tokens = self.tokens;
            self.stats.immediate_proceeds += 1;
            true
        } else {
            false
        }
    }

    /// Time until the balance reaches one whole token at the configured rate.
    ///
    /// Returns `None` when that can never happen or cannot be represented.
    /// This covers a zero, negative or NaN rate, or a wait that overflows `Duration`.
    fn time_until_token(&self) -> Option<Duration> {
        let tokens_needed = 1.0 - self.tokens;
        Duration::try_from_secs_f64(tokens_needed / self.config.entities_per_second).ok()
    }

    /// Blocks until one token has accrued, spends it, and returns the planned
    /// sleep in whole milliseconds.
    ///
    /// # Panics
    ///
    /// Panics if no token can ever accrue (see [`Self::time_until_token`]),
    /// since blocking would otherwise hang forever.
    fn wait_for_token(&mut self) -> u64 {
        let Some(wait) = self.time_until_token() else {
            panic!(
                "RateLimiter cannot wait for a token: no token can accrue at \
                 entities_per_second = {}",
                self.config.entities_per_second
            );
        };
        self.sleep_then_spend(wait)
    }

    /// Sleeps for `wait`, refills, spends one token, and returns `wait` in
    /// whole milliseconds.
    fn sleep_then_spend(&mut self, wait: Duration) -> u64 {
        std::thread::sleep(wait);
        self.refill_tokens();
        self.tokens -= 1.0;
        self.stats.current_tokens = self.tokens;
        Self::whole_millis(wait)
    }

    /// Records a completed wait in the counters and builds the matching action.
    fn record_wait(&mut self, wait_time_ms: u64) -> RateLimitAction {
        self.stats.waits += 1;
        self.stats.total_wait_time_ms = self.stats.total_wait_time_ms.saturating_add(wait_time_ms);
        RateLimitAction::Waited { wait_time_ms }
    }

    /// Whole milliseconds in `duration`, saturating at `u64::MAX` instead of
    /// truncating the `u128` from `as_millis`.
    fn whole_millis(duration: Duration) -> u64 {
        u64::try_from(duration.as_millis()).unwrap_or(u64::MAX)
    }

    /// Releases buffered entries in FIFO order while whole tokens are
    /// available, spending one token per entry.
    ///
    /// Returns how long each released entry spent in the buffer.
    pub fn process_buffer(&mut self) -> Vec<Duration> {
        self.refill_tokens();
        let mut wait_times = Vec::new();
        while self.tokens >= 1.0 {
            let Some(enqueued_at) = self.buffer.pop_front() else {
                break;
            };
            wait_times.push(enqueued_at.elapsed());
            self.tokens -= 1.0;
        }
        self.stats.buffer_size = self.buffer.len();
        self.stats.current_tokens = self.tokens;
        wait_times
    }
}
/// Iterator adaptor that passes the items of `inner` through a [`RateLimiter`].
///
/// Usually built with [`RateLimitExt::rate_limit`] or
/// [`RateLimitExt::rate_limit_with`]. Like every iterator adaptor, it does
/// nothing until it is consumed.
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
pub struct RateLimitedIterator<I> {
    /// The wrapped source of items.
    inner: I,
    /// Limiter consulted on each `next` call.
    limiter: RateLimiter,
}
impl<I> RateLimitedIterator<I> {
pub fn new(inner: I, limiter: RateLimiter) -> Self {
Self { inner, limiter }
}
pub fn with_rate(inner: I, entities_per_second: f64) -> Self {
Self::new(inner, RateLimiter::with_rate(entities_per_second))
}
pub fn stats(&self) -> RateLimiterStats {
self.limiter.stats()
}
}
impl<I: Iterator> Iterator for RateLimitedIterator<I> {
    type Item = I::Item;

    /// Pulls the next item from the inner iterator, then charges one token for it.
    ///
    /// The item is fetched *before* acquiring. An exhausted inner iterator
    /// therefore returns `None` without spending a token, possibly sleeping, or
    /// adding a phantom acquisition to the limiter's stats.
    ///
    /// NOTE(review): the `RateLimitAction` returned by `acquire` is ignored, so
    /// items are yielded even when the limiter reports them dropped or
    /// buffered. Confirm this is intended for `Drop`/`Buffer` backpressure.
    fn next(&mut self) -> Option<Self::Item> {
        let item = self.inner.next()?;
        let _ = self.limiter.acquire();
        Some(item)
    }

    /// Every inner item is yielded exactly once, so the inner hint applies unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// Adds rate-limiting adaptors to every [`Iterator`].
pub trait RateLimitExt: Iterator + Sized {
    /// Throttles this iterator to `entities_per_second`, using the default
    /// burst size and backpressure policy.
    fn rate_limit(self, entities_per_second: f64) -> RateLimitedIterator<Self> {
        RateLimitedIterator::new(self, RateLimiter::with_rate(entities_per_second))
    }

    /// Throttles this iterator with a limiter built from an explicit `config`.
    fn rate_limit_with(self, config: RateLimitConfig) -> RateLimitedIterator<Self> {
        let limiter = RateLimiter::new(config);
        RateLimitedIterator::new(self, limiter)
    }
}

impl<T> RateLimitExt for T where T: Iterator {}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds an enabled limiter with the given rate, burst size and policy.
    fn build_limiter(
        entities_per_second: f64,
        burst_size: u32,
        backpressure: RateLimitBackpressure,
    ) -> RateLimiter {
        RateLimiter::new(RateLimitConfig {
            entities_per_second,
            burst_size,
            backpressure,
            ..Default::default()
        })
    }

    #[test]
    fn test_rate_limiter_immediate_proceed() {
        let mut rl = build_limiter(1000.0, 10, RateLimitBackpressure::Block);
        for attempt in 0..10 {
            assert_eq!(rl.acquire(), RateLimitAction::Proceed, "attempt {attempt}");
        }
        let snapshot = rl.stats();
        assert_eq!(snapshot.total_acquisitions, 10);
        assert_eq!(snapshot.immediate_proceeds, 10);
    }

    #[test]
    fn test_rate_limiter_blocking() {
        let mut rl = build_limiter(1000.0, 1, RateLimitBackpressure::Block);
        assert_eq!(rl.acquire(), RateLimitAction::Proceed);
        assert!(matches!(rl.acquire(), RateLimitAction::Waited { .. }));
    }

    #[test]
    fn test_rate_limiter_drop() {
        let mut rl = build_limiter(10.0, 1, RateLimitBackpressure::Drop);
        assert_eq!(rl.acquire(), RateLimitAction::Proceed);
        assert_eq!(rl.acquire(), RateLimitAction::Dropped);
        assert_eq!(rl.stats().drops, 1);
    }

    #[test]
    fn test_rate_limiter_buffer() {
        let mut rl = build_limiter(10.0, 1, RateLimitBackpressure::Buffer { max_buffered: 5 });
        assert_eq!(rl.acquire(), RateLimitAction::Proceed);
        assert_eq!(rl.acquire(), RateLimitAction::Buffered { position: 1 });
        let snapshot = rl.stats();
        assert_eq!(snapshot.buffers, 1);
        assert_eq!(snapshot.buffer_size, 1);
    }

    #[test]
    fn test_rate_limiter_try_acquire() {
        let mut rl = build_limiter(10.0, 1, RateLimitBackpressure::Block);
        assert!(rl.try_acquire().is_some());
        assert!(rl.try_acquire().is_none());
    }

    #[test]
    fn test_rate_limiter_disabled() {
        let mut rl = RateLimiter::disabled();
        assert!((0..100).all(|_| rl.acquire() == RateLimitAction::Proceed));
    }

    #[test]
    fn test_rate_limiter_reset() {
        let mut rl = build_limiter(10.0, 5, RateLimitBackpressure::Block);
        for _ in 0..5 {
            let _ = rl.acquire();
        }
        assert!(rl.available_tokens() < 1.0);
        rl.reset();
        assert_eq!(rl.available_tokens(), 5.0);
    }

    #[test]
    fn test_rate_limited_iterator() {
        let config = RateLimitConfig {
            entities_per_second: 10_000.0,
            burst_size: 100,
            ..Default::default()
        };
        let collected: Vec<i32> = (1..=5).rate_limit_with(config).collect();
        assert_eq!(collected, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_rate_limiter_refill() {
        let mut rl = build_limiter(100.0, 10, RateLimitBackpressure::Block);
        for _ in 0..10 {
            let _ = rl.try_acquire();
        }
        assert!(rl.available_tokens() < 1.0);
        std::thread::sleep(Duration::from_millis(25));
        assert!(rl.try_acquire().is_some());
    }

    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            enabled,
            entities_per_second,
            burst_size,
            ..
        } = RateLimitConfig::default();
        assert!(enabled);
        assert_eq!(entities_per_second, 1000.0);
        assert_eq!(burst_size, 100);
    }
}