Skip to main content

idprova_core/policy/
rate.rs

1//! In-memory rate tracking for constraint enforcement.
2//!
3//! `RateTracker` maintains sliding-window action counts per agent DID,
4//! used to populate `EvaluationContext` fields for rate limit constraints.
5
6use std::collections::{HashMap, VecDeque};
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// Thread-safe in-memory rate tracker.
///
/// Tracks action timestamps per agent DID using sliding windows.
/// Not persistent — resets on process restart (by design; DATs
/// are short-lived and rate limits are best-effort).
///
/// Memory stays bounded by recent activity:
/// - Expired timestamps for an agent are evicted whenever that agent is read or written.
/// - Agents whose history has fully expired are dropped from the map.
/// - At most once per hour window, `record_action` sweeps every agent, so idle DIDs are
///   not retained forever.
pub struct RateTracker {
    inner: Mutex<RateTrackerInner>,
    hour_window: Duration,
    day_window: Duration,
}

struct RateTrackerInner {
    /// Sliding-window timestamps per agent DID, oldest first.
    ///
    /// Timestamps are taken with `Instant::now()` while the lock is held, so each
    /// deque is in non-decreasing order; eviction and counting rely on that.
    actions: HashMap<String, VecDeque<Instant>>,
    /// Active concurrent operation counts per agent DID.
    /// Entries are removed as soon as their count drops to zero.
    concurrent: HashMap<String, u64>,
    /// When `actions` was last swept for fully-expired agents.
    last_sweep: Instant,
}

impl RateTracker {
    /// Create a new rate tracker with standard 1-hour and 24-hour windows.
    pub fn new() -> Self {
        Self::with_windows(Duration::from_secs(3600), Duration::from_secs(86400))
    }

    /// Build a tracker with explicit window lengths.
    ///
    /// `hour_window` is also used as the interval between global sweeps.
    /// It must not exceed `day_window`, because timestamps older than `day_window`
    /// are discarded and could no longer be counted in the shorter window.
    fn with_windows(hour_window: Duration, day_window: Duration) -> Self {
        debug_assert!(hour_window <= day_window, "hour window must not exceed day window");
        Self {
            inner: Mutex::new(RateTrackerInner {
                actions: HashMap::new(),
                concurrent: HashMap::new(),
                last_sweep: Instant::now(),
            }),
            hour_window,
            day_window,
        }
    }

    /// Acquire the inner state, recovering from a poisoned mutex.
    ///
    /// The guarded data is plain timestamps and counters, and rate limiting is
    /// best-effort. A panic in another thread while it held the lock leaves the
    /// data at worst slightly stale. Continuing is therefore preferable to making
    /// every later caller panic.
    fn lock_inner(&self) -> std::sync::MutexGuard<'_, RateTrackerInner> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Pop timestamps from the front of `timestamps` whose age at `now` exceeds `window`.
    ///
    /// Relies on `timestamps` being sorted oldest-first.
    fn evict_older_than(timestamps: &mut VecDeque<Instant>, now: Instant, window: Duration) {
        while timestamps
            .front()
            .is_some_and(|t| now.duration_since(*t) > window)
        {
            timestamps.pop_front();
        }
    }

    /// Evict expired timestamps for every agent and drop agents left with none.
    ///
    /// Concurrent counts are not touched: a non-zero count represents operations
    /// that are still in flight.
    fn sweep_expired(&self, inner: &mut RateTrackerInner, now: Instant) {
        inner.actions.retain(|_, timestamps| {
            Self::evict_older_than(timestamps, now, self.day_window);
            !timestamps.is_empty()
        });
        inner.last_sweep = now;
    }

    /// Record an action for the given agent.
    ///
    /// The agent's timestamps older than the day window are evicted before the
    /// new one is appended. At most once per hour window, all agents are also
    /// swept, so DIDs that stopped acting are eventually released.
    pub fn record_action(&self, agent_did: &str) {
        let mut inner = self.lock_inner();
        let now = Instant::now();

        if now.duration_since(inner.last_sweep) >= self.hour_window {
            self.sweep_expired(&mut inner, now);
        }

        let timestamps = inner.actions.entry(agent_did.to_string()).or_default();
        Self::evict_older_than(timestamps, now, self.day_window);
        timestamps.push_back(now);
    }

    /// Get current rate counts for an agent: `(hourly, daily, concurrent)`.
    ///
    /// - `hourly`: actions recorded within the hour window (inclusive).
    /// - `daily`: actions recorded within the day window (inclusive).
    /// - `concurrent`: operations acquired but not yet released.
    ///
    /// Expired timestamps are evicted as a side effect. An agent with no
    /// remaining timestamps is removed from the action map.
    pub fn get_counts(&self, agent_did: &str) -> (u64, u64, u64) {
        let mut inner = self.lock_inner();
        let now = Instant::now();

        let (hourly, daily) = match inner.actions.get_mut(agent_did) {
            Some(timestamps) => {
                Self::evict_older_than(timestamps, now, self.day_window);
                // The deque is oldest-first, so entries outside the hour window
                // form a prefix; binary-search its end instead of scanning.
                let outside_hour = timestamps
                    .partition_point(|t| now.duration_since(*t) > self.hour_window);
                (timestamps.len() - outside_hour, timestamps.len())
            }
            None => (0, 0),
        };

        if daily == 0 {
            // Fully expired (or never seen): don't keep an empty deque around.
            inner.actions.remove(agent_did);
        }

        let concurrent = inner.concurrent.get(agent_did).copied().unwrap_or(0);
        (hourly as u64, daily as u64, concurrent)
    }

    /// Increment the concurrent operation count for an agent.
    ///
    /// Each call should be paired with a later [`RateTracker::release_concurrent`];
    /// nothing else decrements the count.
    pub fn acquire_concurrent(&self, agent_did: &str) {
        let mut inner = self.lock_inner();
        *inner.concurrent.entry(agent_did.to_string()).or_insert(0) += 1;
    }

    /// Decrement the concurrent operation count for an agent.
    ///
    /// Releasing an agent with no active operations is a no-op, so the count
    /// saturates at zero. When the count reaches zero, the entry is removed.
    pub fn release_concurrent(&self, agent_did: &str) {
        let mut inner = self.lock_inner();
        let now_idle = match inner.concurrent.get_mut(agent_did) {
            Some(count) => {
                *count = count.saturating_sub(1);
                *count == 0
            }
            None => false,
        };
        if now_idle {
            inner.concurrent.remove(agent_did);
        }
    }
}

impl Default for RateTracker {
    fn default() -> Self {
        Self::new()
    }
}
98
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Primary agent DID shared by most tests.
    const AGENT_A: &str = "did:idprova:test:agent1";
    /// Second agent DID, used to check per-agent isolation.
    const AGENT_B: &str = "did:idprova:test:agent2";

    #[test]
    fn test_record_and_count() {
        let tracker = RateTracker::new();
        assert_eq!(tracker.get_counts(AGENT_A), (0, 0, 0));

        (0..3).for_each(|_| tracker.record_action(AGENT_A));

        // Three fresh actions land in both windows; nothing was acquired.
        assert_eq!(tracker.get_counts(AGENT_A), (3, 3, 0));
    }

    #[test]
    fn test_concurrent_tracking() {
        let tracker = RateTracker::new();
        let concurrent_of = |t: &RateTracker| t.get_counts(AGENT_A).2;

        tracker.acquire_concurrent(AGENT_A);
        tracker.acquire_concurrent(AGENT_A);
        assert_eq!(concurrent_of(&tracker), 2);

        // Each release lowers the count by one; the final release happens at
        // zero and must saturate rather than underflow.
        for expected in [1, 0, 0] {
            tracker.release_concurrent(AGENT_A);
            assert_eq!(concurrent_of(&tracker), expected);
        }
    }

    #[test]
    fn test_separate_agents() {
        let tracker = RateTracker::new();
        for did in [AGENT_A, AGENT_A, AGENT_B] {
            tracker.record_action(did);
        }

        let hourly_of = |did: &str| tracker.get_counts(did).0;
        assert_eq!(hourly_of(AGENT_A), 2);
        assert_eq!(hourly_of(AGENT_B), 1);
    }

    #[test]
    fn test_thread_safety() {
        const THREADS: usize = 10;
        const PER_THREAD: usize = 100;
        let tracker = Arc::new(RateTracker::new());

        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let shared = Arc::clone(&tracker);
                thread::spawn(move || {
                    for _ in 0..PER_THREAD {
                        shared.record_action(AGENT_A);
                    }
                })
            })
            .collect();

        workers.into_iter().for_each(|w| w.join().unwrap());

        let (hourly, daily, _) = tracker.get_counts(AGENT_A);
        assert_eq!(hourly, 1000);
        assert_eq!(daily, 1000);
    }
}