use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
/// Default rates used by the `for_*` constructors of
/// [`TokenBucket`](super::TokenBucket).
///
/// Each constant is passed to `TokenBucket::new` as both the bucket capacity
/// and its refill rate (tokens per second), so a full bucket holds one
/// second's worth of requests.
/// NOTE(review): confirm these match the upstream API's published limits.
pub mod limits {
    /// Rate for the public REST bucket.
    pub const PUBLIC_REST_RATE: f64 = 10.0;
    /// Rate for the public WS bucket.
    pub const PUBLIC_WS_RATE: f64 = 8.0;

    /// Rate for the private REST bucket.
    pub const PRIVATE_REST_RATE: f64 = 30.0;
    /// Rate for the private WS bucket.
    pub const PRIVATE_WS_RATE: f64 = 750.0;
}
/// A token bucket holding up to `max_tokens` tokens, refilled continuously at
/// `refill_rate` tokens per second of elapsed time.
///
/// Refill is computed lazily from [`Instant`] timestamps whenever the bucket is
/// inspected or consumed from; no background task is involved.
#[derive(Debug, Clone)]
pub struct TokenBucket {
    /// Capacity; the token count is clamped to this value on refill.
    max_tokens: f64,
    /// Tokens regained per second.
    refill_rate: f64,
    /// Token count as of `last_update` (may be fractional).
    tokens: f64,
    /// When `tokens` was last brought up to date.
    last_update: Instant,
}

impl TokenBucket {
    /// Creates a full bucket with capacity `max_tokens` that refills at
    /// `refill_rate` tokens per second.
    ///
    /// A `refill_rate` of `0.0` gives a fixed budget that is never replenished.
    /// A `max_tokens` below `1.0` gives a bucket from which nothing can ever be
    /// consumed.
    pub fn new(max_tokens: f64, refill_rate: f64) -> Self {
        Self {
            max_tokens,
            refill_rate,
            tokens: max_tokens,
            last_update: Instant::now(),
        }
    }

    /// Bucket whose capacity and refill rate are [`limits::PUBLIC_REST_RATE`].
    pub fn for_public_rest() -> Self {
        Self::new(limits::PUBLIC_REST_RATE, limits::PUBLIC_REST_RATE)
    }

    /// Bucket whose capacity and refill rate are [`limits::PRIVATE_REST_RATE`].
    pub fn for_private_rest() -> Self {
        Self::new(limits::PRIVATE_REST_RATE, limits::PRIVATE_REST_RATE)
    }

    /// Bucket whose capacity and refill rate are [`limits::PUBLIC_WS_RATE`].
    pub fn for_public_ws() -> Self {
        Self::new(limits::PUBLIC_WS_RATE, limits::PUBLIC_WS_RATE)
    }

    /// Bucket whose capacity and refill rate are [`limits::PRIVATE_WS_RATE`].
    pub fn for_private_ws() -> Self {
        Self::new(limits::PRIVATE_WS_RATE, limits::PRIVATE_WS_RATE)
    }

    /// Token count at `now`, including refill accrued since `last_update`,
    /// clamped to `max_tokens`. Does not mutate the bucket.
    fn tokens_at(&self, now: Instant) -> f64 {
        let elapsed = now.saturating_duration_since(self.last_update).as_secs_f64();
        (self.tokens + elapsed * self.refill_rate).min(self.max_tokens)
    }

    /// Credits refill accrued since the last update and advances `last_update`.
    fn refill(&mut self) {
        // Use a single `now` for both the credit and the new timestamp so no
        // elapsed time is lost between the two.
        let now = Instant::now();
        self.tokens = self.tokens_at(now);
        self.last_update = now;
    }

    /// Takes one token if one is available.
    ///
    /// Returns `true` on success. Otherwise it returns `false` and leaves the
    /// token count unchanged, apart from crediting accrued refill.
    pub fn try_consume(&mut self) -> bool {
        self.refill();
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }

    /// How long until at least one whole token is available, accounting for
    /// refill accrued since the last update.
    ///
    /// Returns [`Duration::ZERO`] if a token is available now. Returns
    /// [`Duration::MAX`] in two cases: a token can never become available
    /// (non-positive refill rate, or capacity below one token), or the wait is
    /// too long to represent as a `Duration`.
    pub fn time_until_available(&self) -> Duration {
        let tokens = self.tokens_at(Instant::now());
        if tokens >= 1.0 {
            return Duration::ZERO;
        }
        // Without these guards the division below yields +inf (or the
        // clamped count never reaches 1.0), and `Duration::from_secs_f64`
        // panics on non-finite input.
        if self.max_tokens < 1.0 || self.refill_rate <= 0.0 {
            return Duration::MAX;
        }
        let secs = (1.0 - tokens) / self.refill_rate;
        // NaN (e.g. a NaN rate) or values beyond Duration's range would panic.
        if secs.is_finite() && secs < u64::MAX as f64 {
            Duration::from_secs_f64(secs)
        } else {
            Duration::MAX
        }
    }

    /// Asynchronously waits until a token can be taken, then takes it.
    ///
    /// If the bucket can never yield a token (see
    /// [`Self::time_until_available`]), this future does not complete.
    pub async fn wait_and_consume(&mut self) {
        while !self.try_consume() {
            let wait_time = self.time_until_available();
            tokio::time::sleep(wait_time).await;
        }
    }

    /// Current token count, including refill accrued since the last update.
    pub fn available_tokens(&self) -> f64 {
        self.tokens_at(Instant::now())
    }
}
#[derive(Clone)]
pub struct RateLimiter {
bucket: Arc<Mutex<TokenBucket>>,
}
impl RateLimiter {
pub fn new(bucket: TokenBucket) -> Self {
Self {
bucket: Arc::new(Mutex::new(bucket)),
}
}
pub fn for_public_rest() -> Self {
Self::new(TokenBucket::for_public_rest())
}
pub fn for_private_rest() -> Self {
Self::new(TokenBucket::for_private_rest())
}
pub async fn try_acquire(&self) -> bool {
let mut bucket = self.bucket.lock().await;
bucket.try_consume()
}
pub async fn acquire(&self) {
let mut bucket = self.bucket.lock().await;
bucket.wait_and_consume().await;
}
pub async fn available(&self) -> f64 {
let mut bucket = self.bucket.lock().await;
bucket.refill();
bucket.available_tokens()
}
}
/// Settings describing how rate limiting should be handled.
///
/// This type only carries values; nothing in it enforces them. Start from
/// [`RateLimitConfig::new`] (same as [`Default`]) or
/// [`RateLimitConfig::disabled`], then adjust with the `with_*` builders.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Whether rate-limit handling is on (`true` by default).
    pub enabled: bool,
    /// Retry budget (default `3`); how it is spent is up to the consumer.
    pub max_retries: u32,
    /// Starting backoff delay (default 1 s).
    pub initial_backoff: Duration,
    /// Backoff ceiling (default 60 s).
    pub max_backoff: Duration,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(1);
        let max_backoff = Duration::from_secs(60);
        Self {
            enabled: true,
            max_retries: 3,
            initial_backoff,
            max_backoff,
        }
    }
}

impl RateLimitConfig {
    /// Equivalent to [`RateLimitConfig::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Default settings with `enabled` switched off.
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            ..Self::new()
        }
    }

    /// Returns `self` with `max_retries` replaced; other fields are kept.
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Returns `self` with `initial_backoff` replaced; other fields are kept.
    pub fn with_initial_backoff(self, duration: Duration) -> Self {
        Self {
            initial_backoff: duration,
            ..self
        }
    }

    /// Returns `self` with `max_backoff` replaced; other fields are kept.
    pub fn with_max_backoff(self, duration: Duration) -> Self {
        Self {
            max_backoff: duration,
            ..self
        }
    }
}
#[derive(Debug, Clone)]
pub struct RateLimitInfo {
pub limit: Option<u32>,
pub remaining: Option<u32>,
pub reset: Option<u64>,
}
impl RateLimitInfo {
pub fn from_headers(headers: &reqwest::header::HeaderMap) -> Self {
Self {
limit: headers
.get("x-ratelimit-limit")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok()),
remaining: headers
.get("x-ratelimit-remaining")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok()),
reset: headers
.get("x-ratelimit-reset")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse().ok()),
}
}
pub fn is_exhausted(&self) -> bool {
self.remaining == Some(0)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_token_bucket_new() {
        let bucket = TokenBucket::new(10.0, 5.0);
        assert_eq!(bucket.max_tokens, 10.0);
        assert_eq!(bucket.refill_rate, 5.0);
        assert_eq!(bucket.tokens, 10.0);
    }

    #[test]
    fn test_token_bucket_consume() {
        let mut bucket = TokenBucket::new(5.0, 1.0);
        for _ in 0..5 {
            assert!(bucket.try_consume());
        }
        assert!(!bucket.try_consume());
    }

    #[test]
    fn test_token_bucket_time_until_available() {
        let mut bucket = TokenBucket::new(1.0, 1.0);
        // Full bucket: a token is available immediately.
        assert_eq!(bucket.time_until_available(), Duration::ZERO);
        assert!(bucket.try_consume());
        // Drained at 1 token/s: the wait is positive and at most one second.
        let wait = bucket.time_until_available();
        assert!(wait > Duration::ZERO);
        assert!(wait <= Duration::from_secs(1));
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_retries, 3);
    }

    #[test]
    fn test_rate_limit_config_disabled() {
        let config = RateLimitConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_rate_limit_config_builders() {
        let config = RateLimitConfig::new()
            .with_max_retries(5)
            .with_initial_backoff(Duration::from_millis(250))
            .with_max_backoff(Duration::from_secs(10));
        assert!(config.enabled);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(250));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
    }

    #[test]
    fn test_rate_limit_info_parse() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-ratelimit-limit", "100".parse().unwrap());
        headers.insert("x-ratelimit-remaining", "50".parse().unwrap());
        headers.insert("x-ratelimit-reset", "1234567890".parse().unwrap());
        let info = RateLimitInfo::from_headers(&headers);
        assert_eq!(info.limit, Some(100));
        assert_eq!(info.remaining, Some(50));
        assert_eq!(info.reset, Some(1234567890));
    }

    #[test]
    fn test_rate_limit_info_missing_or_invalid_headers() {
        let empty = RateLimitInfo::from_headers(&reqwest::header::HeaderMap::new());
        assert_eq!(empty.limit, None);
        assert_eq!(empty.remaining, None);
        assert_eq!(empty.reset, None);
        assert!(!empty.is_exhausted());

        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("x-ratelimit-limit", "abc".parse().unwrap());
        headers.insert("x-ratelimit-remaining", "-1".parse().unwrap());
        let info = RateLimitInfo::from_headers(&headers);
        assert_eq!(info.limit, None);
        assert_eq!(info.remaining, None);
        assert!(!info.is_exhausted());
    }

    #[test]
    fn test_rate_limit_info_exhausted() {
        let info = RateLimitInfo {
            limit: Some(100),
            remaining: Some(0),
            reset: Some(1234567890),
        };
        assert!(info.is_exhausted());
        let info2 = RateLimitInfo {
            limit: Some(100),
            remaining: Some(50),
            reset: Some(1234567890),
        };
        assert!(!info2.is_exhausted());
    }

    #[tokio::test]
    async fn test_rate_limiter_acquire() {
        let limiter = RateLimiter::new(TokenBucket::new(2.0, 10.0));
        assert!(limiter.try_acquire().await);
        assert!(limiter.try_acquire().await);
        assert!(!limiter.try_acquire().await);
    }

    #[tokio::test]
    async fn test_rate_limiter_acquire_waits_for_refill() {
        // 1000 tokens/s: the second acquire needs to wait only about 1 ms.
        let limiter = RateLimiter::new(TokenBucket::new(1.0, 1000.0));
        limiter.acquire().await;
        limiter.acquire().await;
    }

    #[tokio::test]
    async fn test_rate_limiter_available_tracks_consumption() {
        // A tiny refill rate keeps drift negligible over the test's lifetime.
        let limiter = RateLimiter::new(TokenBucket::new(3.0, 0.001));
        assert!((limiter.available().await - 3.0).abs() < 1e-6);
        assert!(limiter.try_acquire().await);
        assert!((limiter.available().await - 2.0).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_rate_limiter_clones_share_bucket() {
        let a = RateLimiter::new(TokenBucket::new(1.0, 0.001));
        let b = a.clone();
        assert!(a.try_acquire().await);
        assert!(!b.try_acquire().await);
    }
}