use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Serializable snapshot of a rate limit.
///
/// Nothing in this file constructs or reads this type, so the field notes
/// below describe only what the types themselves guarantee.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimit {
    /// Utilization figure.
    /// NOTE(review): presumably the share of the limit already used; the
    /// scale (0–1 vs. percent) is not established here — confirm with the producer.
    pub utilization: f64,
    /// Reset time as free-form text, if known. The format is not specified in this file.
    pub resets_at: Option<String>,
    /// Remaining allowance, if known.
    pub remaining: Option<u32>,
    /// Total allowance, if known.
    pub limit: Option<u32>,
}
/// Tunable limits used to build a rate limiter.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Number of requests allowed per minute.
    pub requests_per_minute: u32,
    /// Number of tokens allowed per minute.
    pub tokens_per_minute: u32,
    /// Burst toggle.
    /// NOTE(review): no code visible in this file reads this flag — confirm
    /// the intended semantics before relying on it.
    pub burst: bool,
}

impl Default for RateLimitConfig {
    /// Returns 60 requests/minute, 100 000 tokens/minute, with `burst` enabled.
    fn default() -> Self {
        const DEFAULT_REQUESTS_PER_MINUTE: u32 = 60;
        const DEFAULT_TOKENS_PER_MINUTE: u32 = 100_000;
        const DEFAULT_BURST: bool = true;

        Self {
            requests_per_minute: DEFAULT_REQUESTS_PER_MINUTE,
            tokens_per_minute: DEFAULT_TOKENS_PER_MINUTE,
            burst: DEFAULT_BURST,
        }
    }
}
/// Token bucket that starts full and is replenished lazily from elapsed time.
///
/// Tokens are only credited when [`TokenBucket::try_consume`] runs; there is
/// no background task.
#[derive(Debug)]
pub struct TokenBucket {
    /// Maximum number of tokens the bucket can hold.
    capacity: u64,
    /// Whole tokens currently available.
    tokens: u64,
    /// Refill speed in tokens per millisecond.
    refill_rate: f64,
    /// Instant up to which elapsed time has already been converted into tokens.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a full bucket holding `capacity` tokens that regains
    /// `refill_per_second` tokens every second.
    pub fn new(capacity: u64, refill_per_second: f64) -> Self {
        Self {
            capacity,
            tokens: capacity,
            // Stored per millisecond.
            refill_rate: refill_per_second / 1000.0,
            last_refill: Instant::now(),
        }
    }

    /// Credits any tokens earned since the last refill, then removes `tokens`
    /// from the bucket if that many are available.
    ///
    /// Returns `true` when the tokens were taken and `false` (leaving the
    /// bucket untouched apart from the refill) otherwise. A request larger
    /// than the capacity can never succeed.
    pub fn try_consume(&mut self, tokens: u64) -> bool {
        self.refill();
        if self.tokens >= tokens {
            self.tokens -= tokens;
            true
        } else {
            false
        }
    }

    /// Converts elapsed time into whole tokens, capped at `capacity`.
    ///
    /// Only the time that paid for whole tokens is consumed; the remainder
    /// stays on the clock. Resetting the checkpoint to "now" on every call
    /// would throw away fractional progress, so a caller polling faster than
    /// one token-interval would never see the bucket refill.
    fn refill(&mut self) {
        let now = Instant::now();

        // A full bucket does not bank idle time.
        if self.tokens >= self.capacity {
            self.last_refill = now;
            return;
        }

        let elapsed = now.saturating_duration_since(self.last_refill);
        let earned = elapsed.as_secs_f64() * 1000.0 * self.refill_rate;
        if earned.is_nan() || earned < 1.0 {
            // Not a whole token yet (or a non-positive/NaN rate): keep the
            // checkpoint so partial progress keeps accumulating.
            return;
        }

        let whole = earned.floor();
        // The float-to-int cast saturates, so the addition must saturate too.
        self.tokens = self.tokens.saturating_add(whole as u64).min(self.capacity);

        if self.tokens >= self.capacity {
            self.last_refill = now;
            return;
        }

        // `earned` is finite and >= 1 here, so the ratio lies in [0, 1) and
        // `unspent` never exceeds `elapsed`.
        let unspent = elapsed.mul_f64((earned - whole) / earned);
        self.last_refill = now.checked_sub(unspent).unwrap_or(now);
    }

    /// Returns the number of tokens available as of the last refill.
    ///
    /// This does not credit newly earned tokens; only `try_consume` does.
    pub fn available(&self) -> u64 {
        self.tokens
    }

    /// Refills the bucket to capacity and restarts the refill clock.
    pub fn reset(&mut self) {
        self.tokens = self.capacity;
        self.last_refill = Instant::now();
    }
}
/// Sliding-window request counter: at most `max_requests` acquisitions may
/// fall inside any trailing window of `window_ms` milliseconds.
#[derive(Debug)]
pub struct SlidingWindow {
    /// Maximum number of acquisitions allowed inside one window.
    max_requests: u32,
    /// Window length in milliseconds.
    window_ms: u64,
    /// Timestamps of the acquisitions that have not been evicted yet.
    requests: Vec<Instant>,
}

impl SlidingWindow {
    /// Creates an empty window allowing `max_requests` acquisitions per
    /// `window_duration`.
    ///
    /// The duration is stored with millisecond resolution; values beyond
    /// `u64::MAX` milliseconds saturate instead of wrapping around.
    pub fn new(max_requests: u32, window_duration: Duration) -> Self {
        let window_ms = window_duration.as_millis().min(u128::from(u64::MAX)) as u64;
        Self {
            max_requests,
            window_ms,
            requests: Vec::new(),
        }
    }

    /// Window length as a `Duration`.
    fn window(&self) -> Duration {
        Duration::from_millis(self.window_ms)
    }

    /// Evicts expired timestamps and records a new acquisition if there is room.
    ///
    /// Returns `true` when the acquisition was recorded.
    pub fn try_acquire(&mut self) -> bool {
        let now = Instant::now();
        let window = self.window();
        // Compare ages rather than computing `now - window`: that subtraction
        // can be unrepresentable, and falling back to `now` would evict every
        // entry and let all traffic through.
        self.requests
            .retain(|&stamp| now.saturating_duration_since(stamp) < window);

        if self.requests.len() < self.max_requests as usize {
            self.requests.push(now);
            true
        } else {
            false
        }
    }

    /// Returns how long until the oldest recorded acquisition leaves the window.
    ///
    /// Returns `None` when fewer than `max_requests` timestamps are stored
    /// (or when nothing is stored at all), and `Some(Duration::ZERO)` when the
    /// oldest one has already expired but has not been evicted yet.
    pub fn time_until_available(&self) -> Option<Duration> {
        if self.requests.len() < self.max_requests as usize {
            return None;
        }
        let oldest = self.requests.iter().min().copied()?;
        let age = Instant::now().saturating_duration_since(oldest);
        Some(self.window().saturating_sub(age))
    }

    /// Counts the recorded acquisitions that are still inside the window,
    /// without evicting anything.
    pub fn current_count(&self) -> u32 {
        let now = Instant::now();
        let window = self.window();
        self.requests
            .iter()
            .filter(|&&stamp| now.saturating_duration_since(stamp) < window)
            .count() as u32
    }

    /// Forgets every recorded acquisition.
    pub fn reset(&mut self) {
        self.requests.clear();
    }
}
/// Combined limiter: a call is admitted only when both a sliding request
/// window and a token bucket have room.
#[derive(Debug)]
pub struct RateLimiter {
    /// Caps the number of requests inside a 60-second sliding window.
    request_limiter: SlidingWindow,
    /// Caps token throughput.
    token_limiter: TokenBucket,
}

impl RateLimiter {
    /// Builds a limiter from `config`.
    ///
    /// The request window is fixed at 60 seconds. The token bucket holds
    /// `tokens_per_minute` tokens and refills at `tokens_per_minute / 60`
    /// tokens per second. `config.burst` is not consulted here.
    pub fn new(config: &RateLimitConfig) -> Self {
        let request_limiter =
            SlidingWindow::new(config.requests_per_minute, Duration::from_secs(60));
        let token_limiter = TokenBucket::new(
            u64::from(config.tokens_per_minute),
            f64::from(config.tokens_per_minute) / 60.0,
        );
        Self {
            request_limiter,
            token_limiter,
        }
    }

    /// Tries to admit one request costing `token_count` tokens.
    ///
    /// Either both limits are charged or neither is: a request slot is only
    /// recorded once the tokens have actually been taken, so a call rejected
    /// for lack of tokens does not use up a slot.
    pub fn try_acquire(&mut self, token_count: u64) -> bool {
        // Non-mutating peek at the request window first.
        if self.request_limiter.current_count() >= self.request_limiter.max_requests {
            return false;
        }
        if !self.token_limiter.try_consume(token_count) {
            return false;
        }
        // A slot was free a moment ago and recorded entries only ever age out,
        // so recording the request succeeds.
        self.request_limiter.try_acquire()
    }

    /// Waits until one request costing `token_count` tokens is admitted.
    ///
    /// Between attempts it sleeps until both limits are expected to have
    /// room, falling back to a 100 ms poll when no estimate is available.
    ///
    /// If `token_count` exceeds the bucket capacity the request can never be
    /// satisfied and this future does not complete.
    pub async fn acquire(&mut self, token_count: u64) {
        while !self.try_acquire(token_count) {
            // `try_acquire` may have bailed out before touching the bucket;
            // refresh it so the estimate below is not based on a stale count.
            self.token_limiter.refill();

            let request_wait = self.request_limiter.time_until_available();
            let token_wait = self.token_wait(token_count);
            let wait_time = match (request_wait, token_wait) {
                // Both limits must clear, so wait for the slower one.
                (Some(a), Some(b)) => a.max(b),
                (Some(a), None) | (None, Some(a)) => a,
                (None, None) => Duration::from_millis(100),
            };
            tokio::time::sleep(wait_time).await;
        }
    }

    /// Estimates how long the bucket needs to accumulate `token_count` tokens.
    ///
    /// Returns `None` when enough tokens are already available or when the
    /// bucket has no positive refill rate to base an estimate on.
    fn token_wait(&self, token_count: u64) -> Option<Duration> {
        let deficit = token_count
            .checked_sub(self.token_limiter.available())
            .filter(|&missing| missing > 0)?;

        // Tokens per millisecond, as stored by the bucket.
        let per_ms = self.token_limiter.refill_rate;
        if per_ms.is_nan() || per_ms <= 0.0 {
            return None;
        }

        // Round up, and wait at least 1 ms so the retry loop does not spin.
        let millis = (deficit as f64 / per_ms).ceil().max(1.0);
        Some(Duration::from_millis(millis as u64))
    }

    /// Reports the remaining allowance.
    ///
    /// `requests_remaining` is computed against the current window.
    /// `tokens_remaining` is the bucket level as of its last refill, clamped
    /// to `u32::MAX`.
    pub fn status(&self) -> RateLimitStatus {
        let used = self.request_limiter.current_count();
        let tokens = self.token_limiter.available().min(u64::from(u32::MAX)) as u32;
        RateLimitStatus {
            requests_remaining: self.request_limiter.max_requests.saturating_sub(used),
            tokens_remaining: tokens,
        }
    }

    /// Clears the request window and refills the token bucket.
    pub fn reset(&mut self) {
        self.request_limiter.reset();
        self.token_limiter.reset();
    }
}
/// Serializable summary of the allowance left in a [`RateLimiter`], as
/// produced by `RateLimiter::status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitStatus {
    /// Request slots still free in the request window.
    pub requests_remaining: u32,
    /// Tokens reported as available by the token bucket.
    pub tokens_remaining: u32,
}
pub struct RateLimiterBuilder {
config: RateLimitConfig,
}
impl RateLimiterBuilder {
pub fn new() -> Self {
Self {
config: RateLimitConfig::default(),
}
}
pub fn requests_per_minute(mut self, rpm: u32) -> Self {
self.config.requests_per_minute = rpm;
self
}
pub fn tokens_per_minute(mut self, tpm: u32) -> Self {
self.config.tokens_per_minute = tpm;
self
}
pub fn burst(mut self, enable: bool) -> Self {
self.config.burst = enable;
self
}
pub fn build(self) -> RateLimiter {
RateLimiter::new(&self.config)
}
}
impl Default for RateLimiterBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draining the bucket blocks further consumption until time has passed.
    #[test]
    fn test_token_bucket() {
        let mut bucket = TokenBucket::new(10, 2.0);
        for _ in 0..2 {
            assert!(bucket.try_consume(5));
        }
        assert!(!bucket.try_consume(1));

        // 600 ms at 2 tokens per second is enough for one whole token.
        std::thread::sleep(Duration::from_millis(600));
        assert!(bucket.try_consume(1));
    }

    /// The window rejects the fourth request and admits again once it has elapsed.
    #[test]
    fn test_sliding_window() {
        let mut window = SlidingWindow::new(3, Duration::from_millis(100));
        for _ in 0..3 {
            assert!(window.try_acquire());
        }
        assert!(!window.try_acquire());

        std::thread::sleep(Duration::from_millis(150));
        assert!(window.try_acquire());
    }

    /// `current_count` tracks the number of recorded acquisitions.
    #[test]
    fn test_sliding_window_count() {
        let mut window = SlidingWindow::new(5, Duration::from_secs(1));
        assert_eq!(window.current_count(), 0);

        for _ in 0..2 {
            window.try_acquire();
        }
        assert_eq!(window.current_count(), 2);
    }

    /// A freshly built limiter reports the configured request allowance.
    #[test]
    fn test_rate_limiter_builder() {
        let builder = RateLimiterBuilder::new()
            .requests_per_minute(100)
            .tokens_per_minute(50000);
        let limiter = builder.build();

        assert_eq!(limiter.status().requests_remaining, 100);
    }

    /// A successful `acquire` uses up part of the request allowance.
    #[tokio::test]
    async fn test_rate_limiter_acquire() {
        let mut limiter = RateLimiterBuilder::new()
            .requests_per_minute(10)
            .tokens_per_minute(1000)
            .build();

        limiter.acquire(100).await;

        let remaining = limiter.status().requests_remaining;
        assert!(remaining < 10);
    }
}