titanium_http/
ratelimit.rs1use dashmap::DashMap;
6use parking_lot::Mutex;
7use std::sync::Arc;
8use std::time::{Duration, Instant};
9use tokio::sync::Semaphore;
10use tokio::time::sleep;
11
12pub struct RateLimiter {
14 buckets: DashMap<String, Arc<Bucket>>,
16 #[allow(dead_code)]
18 global: Arc<Semaphore>,
19 global_until: Mutex<Option<Instant>>,
21}
22
23struct Bucket {
25 remaining: Mutex<u32>,
27 reset_at: Mutex<Instant>,
29 semaphore: Semaphore,
31}
32
33impl RateLimiter {
34 pub fn new() -> Self {
36 Self {
37 buckets: DashMap::new(),
38 global: Arc::new(Semaphore::new(50)), global_until: Mutex::new(None),
40 }
41 }
42
43 pub async fn acquire(&self, route: &str) -> Result<(), crate::HttpError> {
45 let until = { *self.global_until.lock() };
47 if let Some(until) = until {
48 if Instant::now() < until {
49 sleep(until - Instant::now()).await;
50 }
51 }
52
53 let bucket = self
55 .buckets
56 .entry(route.to_string())
57 .or_insert_with(|| {
58 Arc::new(Bucket {
59 remaining: Mutex::new(1),
60 reset_at: Mutex::new(Instant::now()),
61 semaphore: Semaphore::new(1),
62 })
63 })
64 .clone();
65
66 let _permit = bucket.semaphore.acquire().await.map_err(|_| {
68 crate::HttpError::ClientError("Rate limit semaphore closed".to_string())
69 })?;
70
71 let wait = {
73 let remaining = *bucket.remaining.lock();
74 if remaining == 0 {
75 let reset_at = *bucket.reset_at.lock();
76 if Instant::now() < reset_at {
77 Some(reset_at - Instant::now())
78 } else {
79 None
80 }
81 } else {
82 None
83 }
84 };
85
86 if let Some(duration) = wait {
87 sleep(duration).await;
88 }
89
90 Ok(())
91 }
92
93 pub fn update(&self, route: &str, remaining: u32, reset_after_ms: u64) {
95 if let Some(bucket) = self.buckets.get(route) {
96 *bucket.remaining.lock() = remaining;
97 *bucket.reset_at.lock() = Instant::now() + Duration::from_millis(reset_after_ms);
98 }
99 }
100
101 pub fn set_global(&self, retry_after_ms: u64) {
103 *self.global_until.lock() = Some(Instant::now() + Duration::from_millis(retry_after_ms));
104 }
105}
106
107impl Default for RateLimiter {
108 fn default() -> Self {
109 Self::new()
110 }
111}