Skip to main content

kraken_api_client/rate_limit/
keyed.rs

//! Per-key rate limiting.
//!
//! This module provides rate limiting that is applied independently per key,
//! such as per trading pair for order book requests.
//!
//! # Example
//!
//! ```rust
//! use std::time::Duration;
//! use kraken_api_client::rate_limit::KeyedRateLimiter;
//!
//! let mut limiter = KeyedRateLimiter::new(
//!     Duration::from_secs(1), // Window size
//!     5,                      // Max requests per window
//! );
//!
//! // Check if we can make a request for a specific key
//! assert!(limiter.try_acquire("BTC/USD").is_ok());
//! ```
20
21use std::collections::HashMap;
22use std::hash::Hash;
23use std::time::{Duration, Instant};
24
25/// Per-key rate limiter using a sliding window algorithm.
26///
27/// Each key (e.g., trading pair) has its own rate limit tracking.
28/// Useful for endpoints like order book that have per-pair limits.
29#[derive(Debug)]
30pub struct KeyedRateLimiter<K> {
31    /// Rate limits per key
32    limiters: HashMap<K, SlidingWindow>,
33    /// Window duration
34    window: Duration,
35    /// Maximum requests per window
36    max_requests: u32,
37}
38
39impl<K> KeyedRateLimiter<K>
40where
41    K: Hash + Eq + Clone,
42{
43    /// Create a new per-key rate limiter.
44    ///
45    /// # Arguments
46    ///
47    /// * `window` - The sliding window duration
48    /// * `max_requests` - Maximum number of requests allowed per window
49    pub fn new(window: Duration, max_requests: u32) -> Self {
50        Self {
51            limiters: HashMap::new(),
52            window,
53            max_requests,
54        }
55    }
56
57    /// Try to acquire a permit for the given key.
58    ///
59    /// Returns `Ok(())` if the request is allowed, or `Err(wait_time)` if
60    /// the rate limit has been exceeded and you need to wait.
61    pub fn try_acquire(&mut self, key: K) -> Result<(), Duration> {
62        let limiter = self
63            .limiters
64            .entry(key)
65            .or_insert_with(|| SlidingWindow::new(self.window, self.max_requests));
66
67        limiter.try_acquire()
68    }
69
70    /// Check if a request for the given key would be allowed without consuming a permit.
71    pub fn would_allow(&self, key: &K) -> bool {
72        self.limiters
73            .get(key)
74            .is_none_or(|limiter| limiter.would_allow())
75    }
76
77    /// Get the remaining permits for a key.
78    pub fn remaining(&self, key: &K) -> u32 {
79        self.limiters
80            .get(key)
81            .map_or(self.max_requests, |limiter| limiter.remaining())
82    }
83
84    /// Get the time until the next permit is available for a key.
85    pub fn time_until_available(&self, key: &K) -> Option<Duration> {
86        self.limiters
87            .get(key)
88            .and_then(|limiter| limiter.time_until_available())
89    }
90
91    /// Remove all rate limit tracking for a specific key.
92    pub fn remove(&mut self, key: &K) {
93        self.limiters.remove(key);
94    }
95
96    /// Clean up limiters that haven't been used recently.
97    ///
98    /// Removes limiters where all requests have expired from the window.
99    pub fn cleanup(&mut self) {
100        self.limiters.retain(|_, limiter| !limiter.is_empty());
101    }
102
103    /// Get the number of keys being tracked.
104    pub fn tracked_keys(&self) -> usize {
105        self.limiters.len()
106    }
107
108    /// Clear all rate limit tracking.
109    pub fn clear(&mut self) {
110        self.limiters.clear();
111    }
112}
113
114impl<K> Default for KeyedRateLimiter<K>
115where
116    K: Hash + Eq + Clone,
117{
118    fn default() -> Self {
119        // Default: 1 request per second per key
120        Self::new(Duration::from_secs(1), 1)
121    }
122}
123
/// A sliding window rate limiter.
///
/// Records the instant of every accepted request and admits a new one only while
/// fewer than `max_requests` recorded requests are younger than `window`.
#[derive(Debug)]
pub struct SlidingWindow {
    /// Instants of accepted requests, in the order they were accepted
    requests: Vec<Instant>,
    /// Window duration
    window: Duration,
    /// Maximum requests per window
    max_requests: u32,
}

impl SlidingWindow {
    /// Cap on the number of timestamp slots reserved up front by [`new`](Self::new).
    ///
    /// Reserving `max_requests` slots unconditionally makes a very large limit
    /// (e.g. `u32::MAX`) request tens of gigabytes and abort, and wastes memory
    /// for every idle key in a keyed limiter. Beyond this cap the buffer grows
    /// on demand.
    const MAX_PREALLOCATED: u32 = 64;

    /// Create a new sliding window rate limiter.
    pub fn new(window: Duration, max_requests: u32) -> Self {
        Self {
            requests: Vec::with_capacity(max_requests.min(Self::MAX_PREALLOCATED) as usize),
            window,
            max_requests,
        }
    }

    /// Try to acquire a permit.
    ///
    /// Expired requests are pruned first, so the window's storage never grows
    /// beyond `max_requests` entries.
    ///
    /// # Errors
    ///
    /// Returns `Err(wait_time)` when the window is full. `wait_time` is the time
    /// until the earliest-accepted request still in the window expires. With
    /// `max_requests == 0` every call fails with `Duration::ZERO`.
    pub fn try_acquire(&mut self) -> Result<(), Duration> {
        self.cleanup_old();

        if self.has_room_for(self.requests.len()) {
            self.requests.push(Instant::now());
            return Ok(());
        }

        // After pruning, the first entry is the earliest-accepted request still
        // inside the window; the next permit frees up when it leaves.
        let wait_time = self.requests.first().map_or(Duration::ZERO, |oldest| {
            self.window.saturating_sub(oldest.elapsed())
        });
        Err(wait_time)
    }

    /// Check if a request would be allowed without consuming a permit.
    pub fn would_allow(&self) -> bool {
        self.has_room_for(self.active_count())
    }

    /// Get the number of remaining permits.
    pub fn remaining(&self) -> u32 {
        // A count that does not fit in u32 already exceeds any limit.
        let active = u32::try_from(self.active_count()).unwrap_or(u32::MAX);
        self.max_requests.saturating_sub(active)
    }

    /// Get the time until the next permit is available.
    ///
    /// Returns `None` if a permit is available now. Also returns `None` when
    /// `max_requests == 0`, because there is no recorded request whose expiry
    /// would free a permit.
    pub fn time_until_available(&self) -> Option<Duration> {
        if self.would_allow() {
            return None;
        }

        // The earliest recorded request still inside the window is the next to
        // expire. Each age is measured once so `age < window` holds for the
        // subtraction below.
        self.requests
            .iter()
            .map(Instant::elapsed)
            .find(|&age| age < self.window)
            .map(|age| self.window - age)
    }

    /// Check if the window has no active requests.
    pub fn is_empty(&self) -> bool {
        !self.requests.iter().any(|ts| self.is_active(ts))
    }

    /// Remove requests that are outside the window.
    fn cleanup_old(&mut self) {
        // `self` is mutably borrowed by `retain`, so copy the window out.
        let window = self.window;
        self.requests.retain(|ts| ts.elapsed() < window);
    }

    /// Whether a recorded request is still inside the window.
    fn is_active(&self, ts: &Instant) -> bool {
        ts.elapsed() < self.window
    }

    /// Number of recorded requests still inside the window (read-only; does not prune).
    fn active_count(&self) -> usize {
        self.requests.iter().filter(|&ts| self.is_active(ts)).count()
    }

    /// Whether another request fits when `count` requests are active.
    fn has_room_for(&self, count: usize) -> bool {
        u32::try_from(count).is_ok_and(|count| count < self.max_requests)
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_sliding_window_allows_within_limit() {
        let mut limiter = SlidingWindow::new(Duration::from_secs(1), 3);

        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_err());
    }

    #[test]
    fn test_sliding_window_resets_after_window() {
        let mut limiter = SlidingWindow::new(Duration::from_millis(50), 2);

        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_ok());
        assert!(limiter.try_acquire().is_err());

        thread::sleep(Duration::from_millis(60));

        assert!(limiter.try_acquire().is_ok());
    }

    #[test]
    fn test_remaining() {
        let mut limiter = SlidingWindow::new(Duration::from_secs(1), 3);

        assert_eq!(limiter.remaining(), 3);
        limiter.try_acquire().ok();
        assert_eq!(limiter.remaining(), 2);
        limiter.try_acquire().ok();
        assert_eq!(limiter.remaining(), 1);
    }

    #[test]
    fn test_would_allow_and_time_until_available() {
        let mut limiter = SlidingWindow::new(Duration::from_secs(60), 1);
        assert!(limiter.would_allow());
        assert_eq!(limiter.time_until_available(), None);

        limiter.try_acquire().unwrap();
        assert!(!limiter.would_allow());
        let wait = limiter.time_until_available().expect("window is full");
        assert!(wait > Duration::ZERO && wait <= Duration::from_secs(60));
    }

    #[test]
    fn test_keyed_limiter() {
        let mut limiter: KeyedRateLimiter<String> =
            KeyedRateLimiter::new(Duration::from_secs(1), 2);

        // Different keys have independent limits
        assert!(limiter.try_acquire("BTC/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("BTC/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("BTC/USD".to_string()).is_err());

        // ETH/USD has its own limit
        assert!(limiter.try_acquire("ETH/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("ETH/USD".to_string()).is_ok());
        assert!(limiter.try_acquire("ETH/USD".to_string()).is_err());
    }

    #[test]
    fn test_keyed_limiter_untracked_key() {
        let limiter: KeyedRateLimiter<&str> = KeyedRateLimiter::new(Duration::from_secs(1), 4);

        assert!(limiter.would_allow(&"BTC/USD"));
        assert_eq!(limiter.remaining(&"BTC/USD"), 4);
        assert_eq!(limiter.time_until_available(&"BTC/USD"), None);
        assert_eq!(limiter.tracked_keys(), 0);
    }

    #[test]
    fn test_keyed_limiter_remove_and_clear() {
        let mut limiter: KeyedRateLimiter<&str> =
            KeyedRateLimiter::new(Duration::from_secs(60), 1);

        limiter.try_acquire("a").unwrap();
        limiter.try_acquire("b").unwrap();
        assert!(limiter.try_acquire("a").is_err());

        // Forgetting a key resets its budget.
        limiter.remove(&"a");
        assert_eq!(limiter.tracked_keys(), 1);
        assert!(limiter.try_acquire("a").is_ok());

        limiter.clear();
        assert_eq!(limiter.tracked_keys(), 0);
    }

    #[test]
    fn test_default_allows_one_per_key() {
        let mut limiter: KeyedRateLimiter<u32> = KeyedRateLimiter::default();

        assert!(limiter.try_acquire(1).is_ok());
        assert!(limiter.try_acquire(1).is_err());
        assert!(limiter.try_acquire(2).is_ok());
    }

    #[test]
    fn test_keyed_limiter_cleanup() {
        let mut limiter: KeyedRateLimiter<String> =
            KeyedRateLimiter::new(Duration::from_millis(50), 1);

        limiter.try_acquire("key1".to_string()).ok();
        limiter.try_acquire("key2".to_string()).ok();
        assert_eq!(limiter.tracked_keys(), 2);

        thread::sleep(Duration::from_millis(60));
        limiter.cleanup();

        assert_eq!(limiter.tracked_keys(), 0);
    }
}