Skip to main content

agent_diva_core/security/
rate_limit.rs

1//! Sliding-window rate limiting for security actions
2
3use parking_lot::Mutex;
4use std::time::{Duration, Instant};
5
6/// Tracks actions in a sliding window for rate limiting
7#[derive(Debug)]
8pub struct ActionTracker {
9    /// Recent action timestamps (within the window)
10    actions: Mutex<Vec<Instant>>,
11    /// Window size in seconds (default: 3600 = 1 hour)
12    window_secs: u64,
13}
14
15impl ActionTracker {
16    /// Create a new action tracker with default 1-hour window
17    pub fn new() -> Self {
18        Self {
19            actions: Mutex::new(Vec::new()),
20            window_secs: 3600,
21        }
22    }
23
24    /// Create a new action tracker with custom window size
25    pub fn with_window(window_secs: u64) -> Self {
26        Self {
27            actions: Mutex::new(Vec::new()),
28            window_secs,
29        }
30    }
31
32    /// Record an action and return the current count in the window
33    pub fn record(&self) -> usize {
34        let mut actions = self.actions.lock();
35        self.cleanup(&mut actions);
36        actions.push(Instant::now());
37        actions.len()
38    }
39
40    /// Get the current action count without recording
41    pub fn count(&self) -> usize {
42        let mut actions = self.actions.lock();
43        self.cleanup(&mut actions);
44        actions.len()
45    }
46
47    /// Check if the action count exceeds the limit
48    pub fn is_rate_limited(&self, max_actions: u32) -> bool {
49        self.count() >= max_actions as usize
50    }
51
52    /// Try to record an action, returning false if rate limited
53    pub fn try_record(&self, max_actions: u32) -> bool {
54        let mut actions = self.actions.lock();
55        self.cleanup(&mut actions);
56
57        if actions.len() >= max_actions as usize {
58            false
59        } else {
60            actions.push(Instant::now());
61            true
62        }
63    }
64
65    /// Clean up expired actions
66    fn cleanup(&self, actions: &mut Vec<Instant>) {
67        // If we can't subtract (program running less than window), nothing is expired
68        let Some(cutoff) = Instant::now().checked_sub(Duration::from_secs(self.window_secs)) else {
69            return;
70        };
71        actions.retain(|t| *t > cutoff);
72    }
73
74    /// Get the window duration
75    pub fn window_duration(&self) -> Duration {
76        Duration::from_secs(self.window_secs)
77    }
78
79    /// Reset all tracked actions
80    pub fn reset(&self) {
81        let mut actions = self.actions.lock();
82        actions.clear();
83    }
84}
85
86impl Default for ActionTracker {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl Clone for ActionTracker {
93    fn clone(&self) -> Self {
94        let actions = self.actions.lock();
95        Self {
96            actions: Mutex::new(actions.clone()),
97            window_secs: self.window_secs,
98        }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Actions are counted while inside the window and dropped once it elapses.
    #[test]
    fn test_action_tracker_basic() {
        // One-second window so expiry can be observed quickly.
        let tracker = ActionTracker::with_window(1);

        // `record` returns the running total, including the action just added.
        for expected in 1..=3 {
            assert_eq!(tracker.record(), expected);
        }
        assert_eq!(tracker.count(), 3);

        // Sleep past the window; the next query must prune everything.
        thread::sleep(Duration::from_secs(2));
        assert_eq!(tracker.count(), 0);
    }

    /// The limit is hit once the count equals `max_actions`.
    #[test]
    fn test_rate_limiting() {
        let tracker = ActionTracker::with_window(3600);

        // While below the limit, every check must pass before recording.
        (0..5).for_each(|i| {
            let limited = tracker.is_rate_limited(5);
            assert!(!limited, "Should not be rate limited at action {}", i);
            tracker.record();
        });

        // Five recorded actions against a limit of five: now limited.
        let limited_at_five = tracker.is_rate_limited(5);
        assert!(
            limited_at_five,
            "Should be rate limited after 5 actions with limit of 5"
        );
        assert_eq!(tracker.count(), 5, "Count should be 5");

        let accepted_at_five = tracker.try_record(5);
        assert!(
            !accepted_at_five,
            "Should not be able to record when rate limited"
        );

        // Raising the limit by one leaves room for exactly one more action.
        let accepted_at_six = tracker.try_record(6);
        assert!(accepted_at_six, "Should be able to record when limit is 6");
    }

    /// A clone starts with the same history but evolves independently.
    #[test]
    fn test_clone() {
        let original = ActionTracker::with_window(3600);
        original.record();
        original.record();
        assert_eq!(original.count(), 2, "Original tracker should have 2 actions");

        let copy = original.clone();
        assert_eq!(copy.count(), 2, "Cloned tracker should have 2 actions");

        // Recording on the clone must leave the original untouched.
        copy.record();
        assert_eq!(copy.count(), 3, "Cloned tracker should have 3 actions after record");
        assert_eq!(original.count(), 2, "Original tracker should still have 2 actions");
    }
}