use std::sync::Mutex;
use std::time::{Duration, Instant};
/// Token-bucket rate limiter usable from multiple threads through `&self`.
///
/// Tokens accumulate at `rate` per second up to a cap of `burst`; each
/// successful acquisition consumes one whole token. The mutable bucket state
/// lives behind a [`Mutex`], so every method takes `&self`.
#[derive(Debug)]
pub struct RateLimiter {
    /// Current token count and the time it was last refilled.
    state: Mutex<RateLimiterState>,
    /// Refill rate in tokens per second; `RateLimiter::new` asserts it is > 0.
    rate: f64,
    /// Maximum number of tokens the bucket can hold (also the initial fill).
    burst: u32,
}
/// Mutable bucket contents guarded by `RateLimiter::state`.
#[derive(Debug)]
struct RateLimiterState {
    /// Tokens currently available; fractional, capped at the limiter's `burst`.
    tokens: f64,
    /// When `tokens` was last brought up to date by a refill.
    last_update: Instant,
}
impl RateLimiter {
    /// Longest single sleep the blocking methods will compute, in seconds (~31.7 years).
    ///
    /// `Duration::from_secs_f64` panics on values it cannot represent, and an
    /// extremely small `rate` would otherwise produce one. The sleeping loops
    /// simply go around again if this cap is ever reached.
    const MAX_WAIT_SECS: f64 = 1.0e9;

    /// Creates a limiter that refills `rate` tokens per second and holds at
    /// most `burst` tokens. The bucket starts full, so up to `burst`
    /// acquisitions succeed immediately.
    ///
    /// # Panics
    ///
    /// Panics if `rate` is not strictly positive (this includes NaN) or if
    /// `burst` is zero.
    pub fn new(rate: f64, burst: u32) -> Self {
        assert!(rate > 0.0, "rate must be positive");
        assert!(burst > 0, "burst must be at least 1");
        Self {
            state: Mutex::new(RateLimiterState {
                tokens: burst as f64,
                last_update: Instant::now(),
            }),
            rate,
            burst,
        }
    }

    /// Preset: 10 tokens per second, burst of 5.
    pub fn default_http() -> Self {
        Self::new(10.0, 5)
    }

    /// Preset: 2 tokens per second, burst of 3.
    pub fn for_schema_fetch() -> Self {
        Self::new(2.0, 3)
    }

    /// Preset: 1 token per second, burst of 2.
    pub fn for_telemetry() -> Self {
        Self::new(1.0, 2)
    }

    /// Consumes one token if one is available right now; never blocks.
    ///
    /// Returns `true` if a token was consumed.
    pub fn try_acquire(&self) -> bool {
        self.try_take().is_ok()
    }

    /// Blocks the current thread until a token is available, then consumes it.
    ///
    /// Each retry sleeps for the time the bucket needs to accumulate the
    /// missing fraction of a token. The loop retries because another thread
    /// may claim the refilled token first.
    pub fn acquire(&self) {
        while let Err(wait) = self.try_take() {
            std::thread::sleep(wait);
        }
    }

    /// Like [`acquire`](Self::acquire), but gives up once `timeout` has elapsed.
    ///
    /// Returns `true` if a token was consumed and `false` if the deadline passed
    /// first. A zero `timeout` makes exactly one attempt. A `timeout` too large
    /// to add to the current [`Instant`] (e.g. [`Duration::MAX`]) is treated as
    /// "no deadline" instead of panicking on overflow.
    pub fn acquire_timeout(&self, timeout: Duration) -> bool {
        let deadline = match Instant::now().checked_add(timeout) {
            Some(deadline) => deadline,
            None => {
                // Unrepresentable deadline: effectively wait forever.
                self.acquire();
                return true;
            }
        };
        loop {
            let wait = match self.try_take() {
                Ok(()) => return true,
                Err(wait) => wait,
            };
            let now = Instant::now();
            if now >= deadline {
                return false;
            }
            // Never sleep past the deadline; `now < deadline`, so this cannot underflow.
            std::thread::sleep(wait.min(deadline - now));
        }
    }

    /// Returns the (possibly fractional) number of tokens in the bucket after
    /// accounting for time elapsed since the last update.
    pub fn available_tokens(&self) -> f64 {
        let mut state = self.lock_state();
        self.refill_tokens(&mut state);
        state.tokens
    }

    /// Refill rate in tokens per second.
    pub fn rate(&self) -> f64 {
        self.rate
    }

    /// Maximum number of tokens the bucket can hold.
    pub fn burst(&self) -> u32 {
        self.burst
    }

    /// Locks the bucket state.
    ///
    /// Panics if the mutex is poisoned.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, RateLimiterState> {
        self.state.lock().expect("rate limiter lock poisoned")
    }

    /// Refills the bucket and consumes one token if a whole token is available.
    ///
    /// On failure, returns how long the bucket needs to accumulate one whole
    /// token, assuming no other caller takes it first.
    fn try_take(&self) -> Result<(), Duration> {
        let mut state = self.lock_state();
        self.refill_tokens(&mut state);
        if state.tokens >= 1.0 {
            state.tokens -= 1.0;
            Ok(())
        } else {
            Err(self.time_to_accumulate(1.0 - state.tokens))
        }
    }

    /// Converts a token deficit into a sleep duration at `self.rate`.
    ///
    /// Clamps the result to `MAX_WAIT_SECS` so the conversion cannot panic.
    fn time_to_accumulate(&self, deficit: f64) -> Duration {
        let secs = deficit / self.rate;
        // `secs > 0.0` is false for NaN as well as for non-positive values.
        if secs > 0.0 {
            Duration::from_secs_f64(secs.min(Self::MAX_WAIT_SECS))
        } else {
            Duration::ZERO
        }
    }

    /// Adds the tokens earned since `state.last_update`, capped at `burst`.
    /// Also moves the timestamp forward to now.
    fn refill_tokens(&self, state: &mut RateLimiterState) {
        let now = Instant::now();
        let elapsed = now.duration_since(state.last_update).as_secs_f64();
        let tokens_to_add = elapsed * self.rate;
        state.tokens = (state.tokens + tokens_to_add).min(self.burst as f64);
        state.last_update = now;
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::default_http()
}
}
/// Plain-data settings from which a [`RateLimiter`] can be built via `build`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Refill rate passed to `RateLimiter::new`; must be > 0 when `enabled`.
    pub requests_per_second: f64,
    /// Bucket capacity passed to `RateLimiter::new`; must be >= 1 when `enabled`.
    pub burst_size: u32,
    /// When `false`, `build` returns `None` and no limiter is created.
    pub enabled: bool,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_second: 10.0,
burst_size: 5,
enabled: true,
}
}
}
impl RateLimitConfig {
    /// Builds a [`RateLimiter`] from this configuration.
    ///
    /// Returns `None` when `enabled` is `false`; otherwise returns a limiter
    /// created with `requests_per_second` and `burst_size`.
    ///
    /// # Panics
    ///
    /// When enabled, panics (via `RateLimiter::new`) if `requests_per_second`
    /// is not strictly positive (including NaN) or `burst_size` is zero.
    pub fn build(&self) -> Option<RateLimiter> {
        self.enabled
            .then(|| RateLimiter::new(self.requests_per_second, self.burst_size))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Consumes `count` tokens, asserting that every attempt succeeds.
    fn drain(limiter: &RateLimiter, count: usize) {
        for _ in 0..count {
            assert!(limiter.try_acquire());
        }
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = RateLimiter::new(10.0, 5);
        assert_eq!((limiter.rate(), limiter.burst()), (10.0, 5));
    }

    #[test]
    fn test_initial_burst_available() {
        let limiter = RateLimiter::new(10.0, 5);
        drain(&limiter, 5);
        assert!(!limiter.try_acquire());
    }

    #[test]
    fn test_default_constructors() {
        let presets = [
            (RateLimiter::default_http(), 10.0, 5),
            (RateLimiter::for_schema_fetch(), 2.0, 3),
            (RateLimiter::for_telemetry(), 1.0, 2),
        ];
        for (limiter, expected_rate, expected_burst) in presets.iter() {
            assert_eq!(limiter.rate(), *expected_rate);
            assert_eq!(limiter.burst(), *expected_burst);
        }
    }

    #[test]
    fn test_refill_over_time() {
        let limiter = RateLimiter::new(100.0, 5);
        drain(&limiter, 5);
        assert!(!limiter.try_acquire());
        // At 100 tokens/s, 15 ms earns 1.5 tokens.
        std::thread::sleep(Duration::from_millis(15));
        assert!(limiter.try_acquire());
    }

    #[test]
    fn test_acquire_blocking() {
        let limiter = RateLimiter::new(100.0, 1);
        drain(&limiter, 1);
        let started = Instant::now();
        limiter.acquire();
        assert!(started.elapsed() >= Duration::from_millis(5));
    }

    #[test]
    fn test_acquire_timeout_success() {
        let limiter = RateLimiter::new(100.0, 1);
        drain(&limiter, 1);
        assert!(limiter.acquire_timeout(Duration::from_millis(50)));
    }

    #[test]
    fn test_acquire_timeout_failure() {
        let limiter = RateLimiter::new(1.0, 1);
        drain(&limiter, 1);
        // One token takes a full second to refill; 10 ms is not enough.
        assert!(!limiter.acquire_timeout(Duration::from_millis(10)));
    }

    #[test]
    fn test_config_build() {
        let config = RateLimitConfig {
            enabled: true,
            requests_per_second: 5.0,
            burst_size: 3,
        };
        let limiter = config
            .build()
            .expect("an enabled config should build a limiter");
        assert_eq!(limiter.rate(), 5.0);
        assert_eq!(limiter.burst(), 3);
    }

    #[test]
    fn test_config_disabled() {
        let config = RateLimitConfig {
            enabled: false,
            requests_per_second: 5.0,
            burst_size: 3,
        };
        assert!(config.build().is_none());
    }

    #[test]
    #[should_panic(expected = "rate must be positive")]
    fn test_invalid_rate() {
        let _ = RateLimiter::new(0.0, 5);
    }

    #[test]
    #[should_panic(expected = "burst must be at least 1")]
    fn test_invalid_burst() {
        let _ = RateLimiter::new(10.0, 0);
    }
}