titanium_http/
ratelimit.rs

1//! HTTP rate limiting.
2//!
3//! Implements Discord's bucket-based rate limiting system.
4
5use dashmap::DashMap;
6use parking_lot::Mutex;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::Semaphore;
10use tokio::time::sleep;
11
/// Rate limiter for Discord API requests.
///
/// Holds one `Bucket` per route string plus an optional global deadline.
/// `acquire` waits on both before letting a request proceed; `update` and
/// `set_global` feed limits back in.
pub struct RateLimiter {
    /// Per-route buckets, keyed by the route string passed to `acquire` and
    /// `update`. Buckets are shared out as `Arc`s so a caller can keep using
    /// one without holding a reference into the map.
    buckets: DashMap<String, Arc<Bucket>>,
    /// Global rate limit semaphore.
    ///
    /// NOTE(review): this is created in `new` but no code visible in this file
    /// acquires from it, which is why the `dead_code` lint is silenced here.
    #[allow(dead_code)]
    global: Arc<Semaphore>,
    /// Instant until which the global rate limit is in effect, or `None` if
    /// `set_global` has never been called.
    global_until: Mutex<Option<Instant>>,
}
22
/// A rate limit bucket for a specific route.
///
/// Instances are created by `RateLimiter::acquire` the first time a route is
/// seen and are refreshed by `RateLimiter::update`.
struct Bucket {
    /// Remaining requests in this bucket, as last recorded.
    remaining: Mutex<u32>,
    /// When the bucket resets. Only consulted when `remaining` is zero.
    reset_at: Mutex<Instant>,
    /// Semaphore to queue requests. It is created with a single permit, so
    /// callers for the same route pass through `acquire` one at a time.
    semaphore: Semaphore,
}
32
33impl RateLimiter {
34    /// Create a new rate limiter.
35    pub fn new() -> Self {
36        Self {
37            buckets: DashMap::new(),
38            global: Arc::new(Semaphore::new(50)), // Discord allows 50 requests/second globally
39            global_until: Mutex::new(None),
40        }
41    }
42
43    /// Acquire permission to make a request to the given route.
44    pub async fn acquire(&self, route: &str) {
45        // Check global rate limit
46        let until = { *self.global_until.lock() };
47        if let Some(until) = until {
48            if Instant::now() < until {
49                sleep(until - Instant::now()).await;
50            }
51        }
52
53        // Get or create bucket for route
54        let bucket = self
55            .buckets
56            .entry(route.to_string())
57            .or_insert_with(|| {
58                Arc::new(Bucket {
59                    remaining: Mutex::new(1),
60                    reset_at: Mutex::new(Instant::now()),
61                    semaphore: Semaphore::new(1),
62                })
63            })
64            .clone();
65
66        // Acquire semaphore permit
67        let _permit = bucket.semaphore.acquire().await.expect("semaphore closed");
68
69        // Check if we need to wait for reset
70        let wait = {
71            let remaining = *bucket.remaining.lock();
72            if remaining == 0 {
73                let reset_at = *bucket.reset_at.lock();
74                if Instant::now() < reset_at {
75                    Some(reset_at - Instant::now())
76                } else {
77                    None
78                }
79            } else {
80                None
81            }
82        };
83
84        if let Some(duration) = wait {
85            sleep(duration).await;
86        }
87    }
88
89    /// Update rate limit info from response headers.
90    pub fn update(&self, route: &str, remaining: u32, reset_after_ms: u64) {
91        if let Some(bucket) = self.buckets.get(route) {
92            *bucket.remaining.lock() = remaining;
93            *bucket.reset_at.lock() = Instant::now() + Duration::from_millis(reset_after_ms);
94        }
95    }
96
97    /// Set global rate limit.
98    pub fn set_global(&self, retry_after_ms: u64) {
99        *self.global_until.lock() = Some(Instant::now() + Duration::from_millis(retry_after_ms));
100    }
101}
102
103impl Default for RateLimiter {
104    fn default() -> Self {
105        Self::new()
106    }
107}