idprova_core/policy/
rate.rs1use std::collections::{HashMap, VecDeque};
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// Tracks per-agent action rates and in-flight operation counts.
///
/// All mutable state sits behind a single [`Mutex`], so one tracker can be
/// shared between threads (for example inside an `Arc`).
pub struct RateTracker {
    /// Per-agent bookkeeping, guarded by one lock.
    inner: Mutex<RateTrackerInner>,
    /// Length of the sliding window behind the "hourly" count.
    hour_window: Duration,
    /// Length of the sliding window behind the "daily" count; timestamps
    /// older than this are discarded.
    day_window: Duration,
}

/// State guarded by `RateTracker::inner`.
struct RateTrackerInner {
    /// Action timestamps per agent DID, oldest at the front.
    actions: HashMap<String, VecDeque<Instant>>,
    /// Currently held concurrency slots per agent DID.
    concurrent: HashMap<String, u64>,
}

impl RateTracker {
    /// Creates an empty tracker with a 3600-second hourly window and an
    /// 86400-second daily window.
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(RateTrackerInner {
                actions: HashMap::new(),
                concurrent: HashMap::new(),
            }),
            hour_window: Duration::from_secs(3600),
            day_window: Duration::from_secs(86400),
        }
    }

    /// Locks the shared state, recovering the guard if the mutex is poisoned.
    ///
    /// The guarded data is only counters and timestamps, which remain usable
    /// after a panic in another thread, so one failed thread should not make
    /// every later call panic as well.
    fn lock_inner(&self) -> std::sync::MutexGuard<'_, RateTrackerInner> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Pops timestamps from the front of `timestamps` whose age at `now` is
    /// strictly greater than `window`.
    fn prune_expired(timestamps: &mut VecDeque<Instant>, now: Instant, window: Duration) {
        while timestamps
            .front()
            .is_some_and(|t| now.duration_since(*t) > window)
        {
            timestamps.pop_front();
        }
    }

    /// Records one action for `agent_did` at the current instant.
    ///
    /// Timestamps older than the daily window are dropped first, so the
    /// per-agent history stays bounded even if [`RateTracker::get_counts`]
    /// is never called for this agent.
    pub fn record_action(&self, agent_did: &str) {
        let mut inner = self.lock_inner();
        // Read the clock while holding the lock so every deque stays in
        // chronological order; `get_counts` relies on that ordering.
        let now = Instant::now();
        let timestamps = inner.actions.entry(agent_did.to_string()).or_default();
        Self::prune_expired(timestamps, now, self.day_window);
        timestamps.push_back(now);
    }

    /// Returns `(hourly, daily, concurrent)` for `agent_did`.
    ///
    /// * `hourly` - actions whose age is at most the hourly window.
    /// * `daily` - actions whose age is at most the daily window.
    /// * `concurrent` - slots acquired and not yet released.
    ///
    /// An agent the tracker has never seen yields `(0, 0, 0)`. As a side
    /// effect, expired timestamps are discarded and an agent left with no
    /// timestamps is removed from the action map.
    pub fn get_counts(&self, agent_did: &str) -> (u64, u64, u64) {
        let mut inner = self.lock_inner();
        let now = Instant::now();
        let hour_window = self.hour_window;

        let (hourly, daily) = match inner.actions.get_mut(agent_did) {
            Some(timestamps) => {
                Self::prune_expired(timestamps, now, self.day_window);

                let daily = timestamps.len() as u64;
                // The deque is ordered oldest to newest, so walking from the
                // newest end can stop at the first entry outside the window.
                let hourly = timestamps
                    .iter()
                    .rev()
                    .take_while(|t| now.duration_since(**t) <= hour_window)
                    .count() as u64;

                (hourly, daily)
            }
            None => (0, 0),
        };

        if daily == 0 {
            // Drop empty histories so idle agents do not pin a map entry.
            // This is a no-op when the agent was never recorded.
            inner.actions.remove(agent_did);
        }

        let concurrent = inner.concurrent.get(agent_did).copied().unwrap_or(0);
        (hourly, daily, concurrent)
    }

    /// Marks one more operation as in flight for `agent_did`.
    pub fn acquire_concurrent(&self, agent_did: &str) {
        let mut inner = self.lock_inner();
        let count = inner.concurrent.entry(agent_did.to_string()).or_insert(0);
        // Saturating, to mirror the saturating decrement on release.
        *count = count.saturating_add(1);
    }

    /// Marks one in-flight operation of `agent_did` as finished.
    ///
    /// Releasing with nothing acquired is a no-op. When the count reaches
    /// zero the agent's entry is removed, so the map holds only agents that
    /// currently have work in flight.
    pub fn release_concurrent(&self, agent_did: &str) {
        let mut inner = self.lock_inner();
        let now_idle = match inner.concurrent.get_mut(agent_did) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count == 0
            }
            None => false,
        };
        if now_idle {
            inner.concurrent.remove(agent_did);
        }
    }
}
92
93impl Default for RateTracker {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Recorded actions are reflected in both the hourly and daily counts.
    #[test]
    fn test_record_and_count() {
        let tracker = RateTracker::new();
        let agent = "did:idprova:test:agent1";

        // An agent with no history reports zero in every slot.
        assert_eq!(tracker.get_counts(agent), (0, 0, 0));

        for _ in 0..3 {
            tracker.record_action(agent);
        }

        assert_eq!(tracker.get_counts(agent), (3, 3, 0));
    }

    /// Acquire and release move the concurrent count up and down, and
    /// releasing below zero leaves it at zero.
    #[test]
    fn test_concurrent_tracking() {
        let tracker = RateTracker::new();
        let agent = "did:idprova:test:agent1";
        let in_flight = |t: &RateTracker| t.get_counts(agent).2;

        tracker.acquire_concurrent(agent);
        tracker.acquire_concurrent(agent);
        assert_eq!(in_flight(&tracker), 2);

        tracker.release_concurrent(agent);
        assert_eq!(in_flight(&tracker), 1);

        tracker.release_concurrent(agent);
        assert_eq!(in_flight(&tracker), 0);

        // One release too many must not underflow.
        tracker.release_concurrent(agent);
        assert_eq!(in_flight(&tracker), 0);
    }

    /// Actions recorded for one agent do not affect another agent's counts.
    #[test]
    fn test_separate_agents() {
        let tracker = RateTracker::new();
        let first = "did:idprova:test:agent1";
        let second = "did:idprova:test:agent2";

        for agent in [first, first, second] {
            tracker.record_action(agent);
        }

        let (first_hourly, _, _) = tracker.get_counts(first);
        let (second_hourly, _, _) = tracker.get_counts(second);
        assert_eq!(first_hourly, 2);
        assert_eq!(second_hourly, 1);
    }

    /// Ten threads recording one hundred actions each lose no updates.
    #[test]
    fn test_thread_safety() {
        let tracker = std::sync::Arc::new(RateTracker::new());
        let agent = "did:idprova:test:agent1";

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = tracker.clone();
                let owned_did = agent.to_string();
                thread::spawn(move || {
                    for _ in 0..100 {
                        shared.record_action(&owned_did);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let (hourly, daily, _) = tracker.get_counts(agent);
        assert_eq!(hourly, 1000);
        assert_eq!(daily, 1000);
    }
}
179}