// agent_diva_core/security/rate_limit.rs
use parking_lot::Mutex;
use std::time::{Duration, Instant};

6#[derive(Debug)]
8pub struct ActionTracker {
9 actions: Mutex<Vec<Instant>>,
11 window_secs: u64,
13}
14
15impl ActionTracker {
16 pub fn new() -> Self {
18 Self {
19 actions: Mutex::new(Vec::new()),
20 window_secs: 3600,
21 }
22 }
23
24 pub fn with_window(window_secs: u64) -> Self {
26 Self {
27 actions: Mutex::new(Vec::new()),
28 window_secs,
29 }
30 }
31
32 pub fn record(&self) -> usize {
34 let mut actions = self.actions.lock();
35 self.cleanup(&mut actions);
36 actions.push(Instant::now());
37 actions.len()
38 }
39
40 pub fn count(&self) -> usize {
42 let mut actions = self.actions.lock();
43 self.cleanup(&mut actions);
44 actions.len()
45 }
46
47 pub fn is_rate_limited(&self, max_actions: u32) -> bool {
49 self.count() >= max_actions as usize
50 }
51
52 pub fn try_record(&self, max_actions: u32) -> bool {
54 let mut actions = self.actions.lock();
55 self.cleanup(&mut actions);
56
57 if actions.len() >= max_actions as usize {
58 false
59 } else {
60 actions.push(Instant::now());
61 true
62 }
63 }
64
65 fn cleanup(&self, actions: &mut Vec<Instant>) {
67 let Some(cutoff) = Instant::now().checked_sub(Duration::from_secs(self.window_secs)) else {
69 return;
70 };
71 actions.retain(|t| *t > cutoff);
72 }
73
74 pub fn window_duration(&self) -> Duration {
76 Duration::from_secs(self.window_secs)
77 }
78
79 pub fn reset(&self) {
81 let mut actions = self.actions.lock();
82 actions.clear();
83 }
84}
85
86impl Default for ActionTracker {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl Clone for ActionTracker {
93 fn clone(&self) -> Self {
94 let actions = self.actions.lock();
95 Self {
96 actions: Mutex::new(actions.clone()),
97 window_secs: self.window_secs,
98 }
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns an `Instant` lying `secs` seconds in the past, or `None` when the
    /// platform clock cannot represent it (possible shortly after boot).
    fn instant_ago(secs: u64) -> Option<Instant> {
        Instant::now().checked_sub(Duration::from_secs(secs))
    }

    #[test]
    fn test_action_tracker_basic() {
        let tracker = ActionTracker::with_window(60);
        assert_eq!(tracker.record(), 1);
        assert_eq!(tracker.record(), 2);
        assert_eq!(tracker.record(), 3);

        assert_eq!(tracker.count(), 3);

        tracker.reset();
        assert_eq!(tracker.count(), 0, "reset should drop every recorded action");
    }

    #[test]
    fn test_expired_actions_are_pruned() {
        // Back-date entries instead of sleeping past the window, so the test is
        // fast and does not depend on wall-clock scheduling.
        let Some(stale) = instant_ago(120) else {
            return;
        };
        let tracker = ActionTracker::with_window(60);
        tracker.actions.lock().push(stale);
        tracker.actions.lock().push(stale);

        assert_eq!(tracker.count(), 0, "actions older than the window must expire");
        assert_eq!(tracker.record(), 1);
    }

    #[test]
    fn test_expired_actions_do_not_count_against_limit() {
        let Some(stale) = instant_ago(120) else {
            return;
        };
        let tracker = ActionTracker::with_window(60);
        tracker.actions.lock().push(stale);

        assert!(!tracker.is_rate_limited(1));
        assert!(tracker.try_record(1), "a stale entry must not block recording");
        assert!(!tracker.try_record(1));
    }

    #[test]
    fn test_rate_limiting() {
        let tracker = ActionTracker::with_window(3600);

        for i in 0..5 {
            assert!(
                !tracker.is_rate_limited(5),
                "Should not be rate limited at action {}",
                i
            );
            tracker.record();
        }

        assert!(
            tracker.is_rate_limited(5),
            "Should be rate limited after 5 actions with limit of 5"
        );
        assert_eq!(tracker.count(), 5, "Count should be 5");
        assert!(
            !tracker.try_record(5),
            "Should not be able to record when rate limited"
        );
        assert!(
            tracker.try_record(6),
            "Should be able to record when limit is 6"
        );
    }

    #[test]
    fn test_clone() {
        let tracker = ActionTracker::with_window(3600);
        tracker.record();
        tracker.record();

        assert_eq!(tracker.count(), 2, "Original tracker should have 2 actions");

        let cloned = tracker.clone();
        assert_eq!(cloned.count(), 2, "Cloned tracker should have 2 actions");

        cloned.record();
        assert_eq!(
            cloned.count(),
            3,
            "Cloned tracker should have 3 actions after record"
        );
        assert_eq!(
            tracker.count(),
            2,
            "Original tracker should still have 2 actions"
        );
    }
}