kraken_api_client/rate_limit/
keyed.rs1use std::collections::HashMap;
22use std::hash::Hash;
23use std::time::{Duration, Instant};
24
25#[derive(Debug)]
30pub struct KeyedRateLimiter<K> {
31 limiters: HashMap<K, SlidingWindow>,
33 window: Duration,
35 max_requests: u32,
37}
38
39impl<K> KeyedRateLimiter<K>
40where
41 K: Hash + Eq + Clone,
42{
43 pub fn new(window: Duration, max_requests: u32) -> Self {
50 Self {
51 limiters: HashMap::new(),
52 window,
53 max_requests,
54 }
55 }
56
57 pub fn try_acquire(&mut self, key: K) -> Result<(), Duration> {
62 let limiter = self
63 .limiters
64 .entry(key)
65 .or_insert_with(|| SlidingWindow::new(self.window, self.max_requests));
66
67 limiter.try_acquire()
68 }
69
70 pub fn would_allow(&self, key: &K) -> bool {
72 self.limiters
73 .get(key)
74 .is_none_or(|limiter| limiter.would_allow())
75 }
76
77 pub fn remaining(&self, key: &K) -> u32 {
79 self.limiters
80 .get(key)
81 .map_or(self.max_requests, |limiter| limiter.remaining())
82 }
83
84 pub fn time_until_available(&self, key: &K) -> Option<Duration> {
86 self.limiters
87 .get(key)
88 .and_then(|limiter| limiter.time_until_available())
89 }
90
91 pub fn remove(&mut self, key: &K) {
93 self.limiters.remove(key);
94 }
95
96 pub fn cleanup(&mut self) {
100 self.limiters.retain(|_, limiter| !limiter.is_empty());
101 }
102
103 pub fn tracked_keys(&self) -> usize {
105 self.limiters.len()
106 }
107
108 pub fn clear(&mut self) {
110 self.limiters.clear();
111 }
112}
113
114impl<K> Default for KeyedRateLimiter<K>
115where
116 K: Hash + Eq + Clone,
117{
118 fn default() -> Self {
119 Self::new(Duration::from_secs(1), 1)
121 }
122}
123
/// Sliding-window rate limiter: admits at most `max_requests` requests within any span
/// of length `window`.
///
/// The admission time of each request is stored; a timestamp stops counting against the
/// limit once it is `window` old. Timestamps are appended from `Instant::now()`, which is
/// non-decreasing, so the buffer is ordered oldest-first.
#[derive(Debug)]
pub struct SlidingWindow {
    /// Admission times of recorded requests, oldest first.
    requests: Vec<Instant>,
    /// Length of the sliding window.
    window: Duration,
    /// Maximum number of requests admitted within any `window`.
    max_requests: u32,
}

impl SlidingWindow {
    /// Cap on the timestamp slots reserved up front by [`SlidingWindow::new`].
    ///
    /// The buffer still grows on demand up to `max_requests`. The cap only stops a very
    /// large limit (e.g. `u32::MAX`) from forcing a huge eager allocation that would abort
    /// the process and, inside a keyed limiter, be paid again for every key.
    const MAX_PREALLOCATED: usize = 1024;

    /// Creates an empty window admitting at most `max_requests` requests per `window`.
    pub fn new(window: Duration, max_requests: u32) -> Self {
        let capacity = usize::try_from(max_requests)
            .map_or(Self::MAX_PREALLOCATED, |n| n.min(Self::MAX_PREALLOCATED));
        Self {
            requests: Vec::with_capacity(capacity),
            window,
            max_requests,
        }
    }

    /// Records a request if the window has room, discarding expired timestamps first.
    ///
    /// # Errors
    ///
    /// Returns `Err(wait)` when the window is full, where `wait` is the time until the
    /// oldest unexpired request ages out. With `max_requests == 0` nothing can ever be
    /// admitted and `wait` is the full `window`, so sleep-and-retry callers do not spin.
    pub fn try_acquire(&mut self) -> Result<(), Duration> {
        // A single clock reading keeps the sweep, the capacity check and the wait consistent.
        let now = Instant::now();
        self.retain_active(now);

        if self.has_capacity(self.requests.len()) {
            self.requests.push(Instant::now());
            Ok(())
        } else {
            Err(self.wait_from(now))
        }
    }

    /// Returns whether [`try_acquire`](Self::try_acquire) would currently succeed,
    /// without recording a request.
    pub fn would_allow(&self) -> bool {
        self.has_capacity(self.active_count(Instant::now()))
    }

    /// Returns how many more requests the window would currently admit.
    pub fn remaining(&self) -> u32 {
        // The active count never exceeds `max_requests`, so it always fits in a `u32`.
        let active = u32::try_from(self.active_count(Instant::now())).unwrap_or(u32::MAX);
        self.max_requests.saturating_sub(active)
    }

    /// Returns how long until a request would be admitted, or `None` if one would be
    /// admitted right now.
    ///
    /// The wait matches the one [`try_acquire`](Self::try_acquire) reports, including the
    /// full `window` when `max_requests == 0`.
    pub fn time_until_available(&self) -> Option<Duration> {
        let now = Instant::now();
        if self.has_capacity(self.active_count(now)) {
            None
        } else {
            Some(self.wait_from(now))
        }
    }

    /// Returns `true` when no recorded request is still inside the window.
    pub fn is_empty(&self) -> bool {
        let now = Instant::now();
        !self.requests.iter().any(|&ts| self.is_active(ts, now))
    }

    /// Returns whether a request admitted at `ts` still counts against the limit at `now`.
    fn is_active(&self, ts: Instant, now: Instant) -> bool {
        now.saturating_duration_since(ts) < self.window
    }

    /// Counts the requests still inside the window at `now`.
    fn active_count(&self, now: Instant) -> usize {
        self.requests
            .iter()
            .filter(|&&ts| self.is_active(ts, now))
            .count()
    }

    /// Returns whether `active` outstanding requests leave room for one more.
    fn has_capacity(&self, active: usize) -> bool {
        u64::try_from(active).is_ok_and(|n| n < u64::from(self.max_requests))
    }

    /// Time from `now` until the oldest request still active at `now` ages out.
    ///
    /// Timestamps are stored oldest-first, so the first active one is the oldest. When
    /// nothing is active (only possible for a "full" window if `max_requests == 0`) the
    /// full `window` is returned instead of a zero wait.
    fn wait_from(&self, now: Instant) -> Duration {
        self.requests
            .iter()
            .map(|&ts| now.saturating_duration_since(ts))
            .find(|&age| age < self.window)
            .map_or(self.window, |age| self.window - age)
    }

    /// Drops every timestamp that has aged out of the window as of `now`.
    fn retain_active(&mut self, now: Instant) {
        // Copy `window` out: `retain` mutably borrows `self.requests`, so the closure
        // cannot borrow `self` to call `is_active`.
        let window = self.window;
        self.requests
            .retain(|&ts| now.saturating_duration_since(ts) < window);
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_sliding_window_allows_within_limit() {
        let mut limiter = SlidingWindow::new(Duration::from_secs(1), 3);

        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_err());
    }

    #[test]
    fn test_sliding_window_resets_after_window() {
        let mut limiter = SlidingWindow::new(Duration::from_millis(50), 2);

        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_err());

        thread::sleep(Duration::from_millis(60));

        assert!(limiter.try_acquire().is_ok());
    }

    #[test]
    fn test_remaining() {
        let mut limiter = SlidingWindow::new(Duration::from_secs(1), 3);

        assert_eq!(limiter.remaining(), 3);
        limiter.try_acquire().expect("first request fits");
        assert_eq!(limiter.remaining(), 2);
        limiter.try_acquire().expect("second request fits");
        assert_eq!(limiter.remaining(), 1);
    }

    #[test]
    fn test_would_allow_and_time_until_available() {
        let window = Duration::from_secs(1);
        let mut limiter = SlidingWindow::new(window, 1);

        assert!(limiter.would_allow());
        assert_eq!(limiter.time_until_available(), None);

        limiter.try_acquire().expect("first request fits");
        assert!(!limiter.would_allow());
        let wait = limiter
            .time_until_available()
            .expect("a full window reports a wait");
        assert!(wait <= window);

        let retry = limiter.try_acquire().expect_err("limit reached");
        assert!(retry > Duration::ZERO && retry <= window);
    }

    #[test]
    fn test_keyed_limiter() {
        let mut limiter: KeyedRateLimiter<String> =
            KeyedRateLimiter::new(Duration::from_secs(1), 2);

        assert!(limiter.try_acquire("BTC/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("BTC/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("BTC/USD".to_string()).is_err());

        assert!(limiter.try_acquire("ETH/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("ETH/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("ETH/USD".to_string()).is_err());
    }

    #[test]
    fn test_keyed_limiter_remaining_remove_and_clear() {
        let mut limiter: KeyedRateLimiter<&str> =
            KeyedRateLimiter::new(Duration::from_secs(1), 2);

        assert_eq!(limiter.remaining(&"BTC/USD"), 2);
        assert!(limiter.would_allow(&"BTC/USD"));
        limiter.try_acquire("BTC/USD").expect("first request fits");
        assert_eq!(limiter.remaining(&"BTC/USD"), 1);

        limiter.remove(&"BTC/USD");
        assert_eq!(limiter.remaining(&"BTC/USD"), 2);
        assert_eq!(limiter.tracked_keys(), 0);

        limiter.try_acquire("ETH/USD").expect("first request fits");
        limiter.try_acquire("XBT/EUR").expect("first request fits");
        assert_eq!(limiter.tracked_keys(), 2);
        limiter.clear();
        assert_eq!(limiter.tracked_keys(), 0);
    }

    #[test]
    fn test_keyed_limiter_cleanup() {
        let mut limiter: KeyedRateLimiter<String> =
            KeyedRateLimiter::new(Duration::from_millis(50), 1);

        limiter
            .try_acquire("key1".to_string())
            .expect("first request fits");
        limiter
            .try_acquire("key2".to_string())
            .expect("first request fits");
        assert_eq!(limiter.tracked_keys(), 2);

        thread::sleep(Duration::from_millis(60));
        limiter.cleanup();

        assert_eq!(limiter.tracked_keys(), 0);
    }
}