Skip to main content

heliosdb_proxy/rate_limit/
sliding_window.rs

//! Sliding Window Rate Limiter
//!
//! Implements a sliding window algorithm for rate limiting over
//! rolling time periods (e.g., queries per minute).
5
6use std::collections::VecDeque;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::time::{Duration, Instant};
9
10use parking_lot::Mutex;
11
12/// Sliding window rate limiter
13///
14/// Tracks events over a rolling time window, allowing precise
15/// rate limiting over time periods like "100 queries per minute".
16#[derive(Debug)]
17pub struct SlidingWindow {
18    /// Window duration
19    window_size: Duration,
20
21    /// Maximum events allowed in window
22    max_events: u32,
23
24    /// Event timestamps (relative to epoch)
25    events: Mutex<VecDeque<u64>>,
26
27    /// Epoch for time calculations
28    epoch: Instant,
29
30    /// Total events processed (for metrics)
31    total_events: AtomicU64,
32
33    /// Events rejected (for metrics)
34    rejected_events: AtomicU64,
35}
36
37impl SlidingWindow {
38    /// Create a new sliding window
39    ///
40    /// # Arguments
41    /// * `window_size` - Duration of the sliding window
42    /// * `max_events` - Maximum events allowed within the window
43    pub fn new(window_size: Duration, max_events: u32) -> Self {
44        Self {
45            window_size,
46            max_events,
47            events: Mutex::new(VecDeque::with_capacity(max_events as usize)),
48            epoch: Instant::now(),
49            total_events: AtomicU64::new(0),
50            rejected_events: AtomicU64::new(0),
51        }
52    }
53
54    /// Create a sliding window for events per second
55    pub fn per_second(max_events: u32) -> Self {
56        Self::new(Duration::from_secs(1), max_events)
57    }
58
59    /// Create a sliding window for events per minute
60    pub fn per_minute(max_events: u32) -> Self {
61        Self::new(Duration::from_secs(60), max_events)
62    }
63
64    /// Create a sliding window for events per hour
65    pub fn per_hour(max_events: u32) -> Self {
66        Self::new(Duration::from_secs(3600), max_events)
67    }
68
69    /// Try to record an event
70    ///
71    /// Returns Ok(()) if event was recorded, Err with wait time if limit exceeded.
72    pub fn try_record(&self) -> Result<(), SlidingWindowExceeded> {
73        self.try_record_n(1)
74    }
75
76    /// Try to record multiple events
77    pub fn try_record_n(&self, count: u32) -> Result<(), SlidingWindowExceeded> {
78        let now = self.epoch.elapsed().as_nanos() as u64;
79        let window_nanos = self.window_size.as_nanos() as u64;
80        let cutoff = now.saturating_sub(window_nanos);
81
82        let mut events = self.events.lock();
83
84        // Remove expired events
85        while let Some(&front) = events.front() {
86            if front < cutoff {
87                events.pop_front();
88            } else {
89                break;
90            }
91        }
92
93        // Check if we have room
94        let current_count = events.len() as u32;
95        if current_count + count > self.max_events {
96            self.rejected_events.fetch_add(count as u64, Ordering::Relaxed);
97
98            let wait_time = if let Some(&oldest) = events.front() {
99                let expires_at = oldest + window_nanos;
100                if expires_at > now {
101                    Duration::from_nanos(expires_at - now)
102                } else {
103                    Duration::ZERO
104                }
105            } else {
106                Duration::ZERO
107            };
108
109            return Err(SlidingWindowExceeded {
110                retry_after: wait_time,
111                current_count,
112                max_count: self.max_events,
113                window_size: self.window_size,
114            });
115        }
116
117        // Record events
118        for _ in 0..count {
119            events.push_back(now);
120        }
121
122        self.total_events.fetch_add(count as u64, Ordering::Relaxed);
123        Ok(())
124    }
125
126    /// Record an event, blocking until allowed (with timeout)
127    pub fn record_blocking(&self, timeout: Duration) -> Result<(), SlidingWindowExceeded> {
128        let deadline = Instant::now() + timeout;
129
130        loop {
131            match self.try_record() {
132                Ok(()) => return Ok(()),
133                Err(exceeded) => {
134                    let now = Instant::now();
135                    if now >= deadline {
136                        return Err(exceeded);
137                    }
138
139                    let wait = exceeded.retry_after.min(deadline - now);
140                    std::thread::sleep(wait);
141                }
142            }
143        }
144    }
145
146    /// Get current event count in window
147    pub fn current_count(&self) -> u32 {
148        let now = self.epoch.elapsed().as_nanos() as u64;
149        let cutoff = now.saturating_sub(self.window_size.as_nanos() as u64);
150
151        let events = self.events.lock();
152        events.iter().filter(|&&t| t >= cutoff).count() as u32
153    }
154
155    /// Get remaining capacity
156    pub fn remaining_capacity(&self) -> u32 {
157        self.max_events.saturating_sub(self.current_count())
158    }
159
160    /// Get window size
161    pub fn window_size(&self) -> Duration {
162        self.window_size
163    }
164
165    /// Get max events
166    pub fn max_events(&self) -> u32 {
167        self.max_events
168    }
169
170    /// Get utilization ratio (0.0 - 1.0)
171    pub fn utilization(&self) -> f64 {
172        self.current_count() as f64 / self.max_events as f64
173    }
174
175    /// Get total events processed
176    pub fn total_events(&self) -> u64 {
177        self.total_events.load(Ordering::Relaxed)
178    }
179
180    /// Get total events rejected
181    pub fn rejected_events(&self) -> u64 {
182        self.rejected_events.load(Ordering::Relaxed)
183    }
184
185    /// Get rejection rate (0.0 - 1.0)
186    pub fn rejection_rate(&self) -> f64 {
187        let total = self.total_events();
188        let rejected = self.rejected_events();
189        let attempted = total + rejected;
190
191        if attempted == 0 {
192            0.0
193        } else {
194            rejected as f64 / attempted as f64
195        }
196    }
197
198    /// Reset the sliding window
199    pub fn reset(&self) {
200        self.events.lock().clear();
201        self.total_events.store(0, Ordering::Relaxed);
202        self.rejected_events.store(0, Ordering::Relaxed);
203    }
204
205    /// Get event rate (events per second)
206    pub fn current_rate(&self) -> f64 {
207        let count = self.current_count();
208        count as f64 / self.window_size.as_secs_f64()
209    }
210
211    /// Estimate time until an event can be recorded
212    pub fn time_until_available(&self) -> Duration {
213        if self.remaining_capacity() > 0 {
214            return Duration::ZERO;
215        }
216
217        let now = self.epoch.elapsed().as_nanos() as u64;
218        let window_nanos = self.window_size.as_nanos() as u64;
219
220        let events = self.events.lock();
221        if let Some(&oldest) = events.front() {
222            let expires_at = oldest + window_nanos;
223            if expires_at > now {
224                return Duration::from_nanos(expires_at - now);
225            }
226        }
227
228        Duration::ZERO
229    }
230
231    /// Update max events (for dynamic limits)
232    pub fn set_max_events(&mut self, max_events: u32) {
233        self.max_events = max_events;
234    }
235
236    /// Update window size (for dynamic limits)
237    pub fn set_window_size(&mut self, window_size: Duration) {
238        self.window_size = window_size;
239    }
240}
241
242impl Clone for SlidingWindow {
243    fn clone(&self) -> Self {
244        Self {
245            window_size: self.window_size,
246            max_events: self.max_events,
247            events: Mutex::new(self.events.lock().clone()),
248            epoch: self.epoch,
249            total_events: AtomicU64::new(self.total_events.load(Ordering::Relaxed)),
250            rejected_events: AtomicU64::new(self.rejected_events.load(Ordering::Relaxed)),
251        }
252    }
253}
254
/// Error returned when sliding window limit is exceeded
///
/// Carries the limiter state observed at the moment of rejection so callers
/// can decide how long to back off.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlidingWindowExceeded {
    /// Time until an event slot opens up
    pub retry_after: Duration,

    /// Current event count in window
    pub current_count: u32,

    /// Maximum events allowed
    pub max_count: u32,

    /// Window size
    pub window_size: Duration,
}

impl std::fmt::Display for SlidingWindowExceeded {
    /// Writes a one-line summary: the count against the limit, the window
    /// length, and `retry_after` in whole milliseconds (rounded down).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Sliding window exceeded: {}/{} events in {:?}, retry after {}ms",
            self.current_count,
            self.max_count,
            self.window_size,
            self.retry_after.as_millis()
        )
    }
}

impl std::error::Error for SlidingWindowExceeded {}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Limiter whose window is `secs` seconds long.
    fn window_secs(secs: u64, max_events: u32) -> SlidingWindow {
        SlidingWindow::new(Duration::from_secs(secs), max_events)
    }

    /// Limiter whose window is `millis` milliseconds long.
    fn window_millis(millis: u64, max_events: u32) -> SlidingWindow {
        SlidingWindow::new(Duration::from_millis(millis), max_events)
    }

    /// True when `actual` lies within `tolerance` of `expected`.
    fn close_to(actual: f64, expected: f64, tolerance: f64) -> bool {
        (actual - expected).abs() < tolerance
    }

    #[test]
    fn test_window_creation() {
        let limiter = window_secs(60, 100);

        assert_eq!(limiter.window_size(), Duration::from_secs(60));
        assert_eq!(limiter.max_events(), 100);
        assert_eq!(limiter.current_count(), 0);
    }

    #[test]
    fn test_per_second() {
        let limiter = SlidingWindow::per_second(10);

        assert_eq!(limiter.window_size(), Duration::from_secs(1));
        assert_eq!(limiter.max_events(), 10);
    }

    #[test]
    fn test_per_minute() {
        let limiter = SlidingWindow::per_minute(100);

        assert_eq!(limiter.window_size(), Duration::from_secs(60));
        assert_eq!(limiter.max_events(), 100);
    }

    #[test]
    fn test_record_success() {
        let limiter = window_secs(60, 10);

        // Every event up to the limit must be accepted.
        (0..10).for_each(|i| {
            assert!(limiter.try_record().is_ok(), "Failed on event {}", i);
        });

        assert_eq!(limiter.current_count(), 10);
    }

    #[test]
    fn test_record_exceeded() {
        let limiter = window_secs(60, 5);

        for _ in 0..5 {
            assert!(limiter.try_record().is_ok());
        }

        // The sixth event no longer fits and reports the limiter state.
        let outcome = limiter.try_record();
        assert!(outcome.is_err());

        let rejection = outcome.unwrap_err();
        assert_eq!(rejection.current_count, 5);
        assert_eq!(rejection.max_count, 5);
    }

    #[test]
    fn test_record_n() {
        let limiter = window_secs(60, 10);

        // Two batches of five fill the window exactly.
        for &expected in &[5u32, 10] {
            assert!(limiter.try_record_n(5).is_ok());
            assert_eq!(limiter.current_count(), expected);
        }

        // Should fail - would exceed
        assert!(limiter.try_record_n(1).is_err());
    }

    #[test]
    fn test_event_expiration() {
        let limiter = window_millis(50, 5);

        // Fill window
        for _ in 0..5 {
            assert!(limiter.try_record().is_ok());
        }
        assert_eq!(limiter.current_count(), 5);

        // Should be full
        assert!(limiter.try_record().is_err());

        // Wait for events to expire
        std::thread::sleep(Duration::from_millis(60));

        // Should be able to record again
        assert!(limiter.try_record().is_ok());
        // Only the new event should remain; allow some timing variance
        assert!(limiter.current_count() <= 2);
    }

    #[test]
    fn test_remaining_capacity() {
        let limiter = window_secs(60, 10);
        assert_eq!(limiter.remaining_capacity(), 10);

        // (batch size, capacity left afterwards)
        for &(batch, left) in &[(3u32, 7u32), (7, 0)] {
            assert!(limiter.try_record_n(batch).is_ok());
            assert_eq!(limiter.remaining_capacity(), left);
        }
    }

    #[test]
    fn test_utilization() {
        let limiter = window_secs(60, 10);
        assert!(close_to(limiter.utilization(), 0.0, 0.01));

        // Each batch of five adds half of the capacity.
        for &expected in &[0.5f64, 1.0] {
            assert!(limiter.try_record_n(5).is_ok());
            assert!(close_to(limiter.utilization(), expected, 0.01));
        }
    }

    #[test]
    fn test_total_and_rejected() {
        let limiter = window_secs(60, 3);

        // Three accepted, then two rejected.
        for _ in 0..3 {
            assert!(limiter.try_record().is_ok());
        }
        for _ in 0..2 {
            assert!(limiter.try_record().is_err());
        }

        assert_eq!(limiter.total_events(), 3);
        assert_eq!(limiter.rejected_events(), 2);
    }

    #[test]
    fn test_rejection_rate() {
        let limiter = window_secs(60, 2);

        // Two successes followed by two failures.
        for _ in 0..2 {
            assert!(limiter.try_record().is_ok());
        }
        for _ in 0..2 {
            assert!(limiter.try_record().is_err());
        }

        // 2 rejected out of 4 attempts = 50%
        assert!(close_to(limiter.rejection_rate(), 0.5, 0.01));
    }

    #[test]
    fn test_reset() {
        let limiter = window_secs(60, 10);

        assert!(limiter.try_record_n(5).is_ok());
        assert_eq!(limiter.current_count(), 5);

        limiter.reset();

        // Both the window contents and the metrics start over.
        assert_eq!(limiter.current_count(), 0);
        assert_eq!(limiter.total_events(), 0);
        assert_eq!(limiter.rejected_events(), 0);
    }

    #[test]
    fn test_current_rate() {
        let limiter = window_secs(10, 100);

        assert!(limiter.try_record_n(50).is_ok());

        // 50 events in a 10 second window = 5 events/sec
        assert!(close_to(limiter.current_rate(), 5.0, 0.1));
    }

    #[test]
    fn test_time_until_available() {
        let limiter = window_millis(100, 1);

        // Empty window should be immediately available
        assert_eq!(limiter.time_until_available(), Duration::ZERO);

        // Fill window
        assert!(limiter.try_record().is_ok());

        // Should need to wait for expiration
        let wait_millis = limiter.time_until_available().as_millis();
        assert!(wait_millis > 0);
        assert!(wait_millis <= 100);
    }

    #[test]
    fn test_clone() {
        let original = window_secs(60, 10);
        assert!(original.try_record_n(5).is_ok());

        let copy = original.clone();

        assert_eq!(copy.current_count(), 5);
        assert_eq!(copy.max_events(), 10);
    }

    #[test]
    fn test_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let shared = Arc::new(window_secs(60, 100));

        // Spawn 10 threads, each trying to record 20 events
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let limiter = Arc::clone(&shared);
                thread::spawn(move || {
                    for _ in 0..20 {
                        let _ = limiter.try_record();
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Should have exactly 100 events (limited by max)
        assert_eq!(shared.current_count(), 100);
        // Should have 100 rejected (200 attempts - 100 success)
        assert_eq!(shared.rejected_events(), 100);
    }

    #[test]
    fn test_record_blocking() {
        let limiter = window_millis(20, 1);

        // Record first event
        assert!(limiter.try_record().is_ok());

        // Should succeed after waiting
        let outcome = limiter.record_blocking(Duration::from_millis(50));
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_record_blocking_timeout() {
        let limiter = window_secs(60, 1);

        // Fill window
        assert!(limiter.try_record().is_ok());

        // Should timeout
        let outcome = limiter.record_blocking(Duration::from_millis(10));
        assert!(outcome.is_err());
    }
}