Skip to main content

arbiter_audit/
stats.rs

//! In-memory audit statistics tracker.
//!
//! Maintains lightweight per-session counters for denied requests and
//! anomaly detections. The counters are updated on each audit entry write and
//! can be queried by session ID for session-close summaries. A session's
//! counters persist until `AuditStats::remove_session` is called for it.
6
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use tokio::sync::RwLock;
11
/// Live tallies for a single session, incremented atomically on the write path.
#[derive(Default, Debug)]
struct SessionCounters {
    /// Count of recorded entries whose authorization decision was a denial.
    denied: AtomicU64,
    /// Count of recorded entries that carried at least one anomaly flag.
    anomalies: AtomicU64,
}
18
/// Point-in-time copy of one session's audit counters, handed back to callers.
#[derive(Clone, Default, Debug)]
pub struct SessionAuditStats {
    /// How many recorded entries for the session were denied.
    pub denied_count: u64,
    /// How many recorded entries for the session carried anomaly flags.
    pub anomaly_count: u64,
}
25
26/// Aggregated audit statistics across all tracked sessions.
27#[derive(Debug, Clone, Default)]
28pub struct AggregateAuditStats {
29    pub total_denied: u64,
30    pub total_anomalies: u64,
31    pub sessions_with_anomalies: u64,
32    pub sessions_with_denials: u64,
33    /// Per-session stats keyed by session ID.
34    pub per_session: Vec<(String, SessionAuditStats)>,
35}
36
37/// Thread-safe audit stats tracker, shared between the audit write path
38/// and the session close query path.
39#[derive(Clone, Default)]
40pub struct AuditStats {
41    sessions: Arc<RwLock<HashMap<String, Arc<SessionCounters>>>>,
42}
43
44impl AuditStats {
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Record an audit entry. Inspects the entry to update per-session counters.
50    pub async fn record(&self, entry: &crate::entry::AuditEntry) {
51        if entry.task_session_id.is_empty() {
52            return;
53        }
54
55        let counters = {
56            let sessions = self.sessions.read().await;
57            sessions.get(&entry.task_session_id).cloned()
58        };
59
60        let counters = match counters {
61            Some(c) => c,
62            None => {
63                let mut sessions = self.sessions.write().await;
64                sessions
65                    .entry(entry.task_session_id.clone())
66                    .or_insert_with(|| Arc::new(SessionCounters::default()))
67                    .clone()
68            }
69        };
70
71        if entry.authorization_decision == "deny" {
72            counters.denied.fetch_add(1, Ordering::Relaxed);
73        }
74
75        if !entry.anomaly_flags.is_empty() {
76            counters.anomalies.fetch_add(1, Ordering::Relaxed);
77        }
78    }
79
80    /// Query stats for a specific session.
81    pub async fn stats_for_session(&self, session_id: &str) -> SessionAuditStats {
82        let sessions = self.sessions.read().await;
83        match sessions.get(session_id) {
84            Some(counters) => SessionAuditStats {
85                denied_count: counters.denied.load(Ordering::Relaxed),
86                anomaly_count: counters.anomalies.load(Ordering::Relaxed),
87            },
88            None => SessionAuditStats::default(),
89        }
90    }
91
92    /// Remove stats for a session (called after session close to prevent unbounded growth).
93    pub async fn remove_session(&self, session_id: &str) {
94        let mut sessions = self.sessions.write().await;
95        sessions.remove(session_id);
96    }
97
98    /// Aggregate stats across all tracked sessions. Returns totals and counts
99    /// of sessions that have anomalies or denials.
100    pub async fn aggregate(&self) -> AggregateAuditStats {
101        let sessions = self.sessions.read().await;
102        let mut total_denied: u64 = 0;
103        let mut total_anomalies: u64 = 0;
104        let mut sessions_with_anomalies: u64 = 0;
105        let mut sessions_with_denials: u64 = 0;
106        let mut per_session = Vec::with_capacity(sessions.len());
107
108        for (session_id, counters) in sessions.iter() {
109            let denied = counters.denied.load(Ordering::Relaxed);
110            let anomalies = counters.anomalies.load(Ordering::Relaxed);
111            total_denied += denied;
112            total_anomalies += anomalies;
113            if anomalies > 0 {
114                sessions_with_anomalies += 1;
115            }
116            if denied > 0 {
117                sessions_with_denials += 1;
118            }
119            per_session.push((
120                session_id.clone(),
121                SessionAuditStats {
122                    denied_count: denied,
123                    anomaly_count: anomalies,
124                },
125            ));
126        }
127
128        AggregateAuditStats {
129            total_denied,
130            total_anomalies,
131            sessions_with_anomalies,
132            sessions_with_denials,
133            per_session,
134        }
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entry::AuditEntry;
    use uuid::Uuid;

    #[tokio::test]
    async fn tracks_denied_and_anomalies() {
        let stats = AuditStats::new();
        let session_id = Uuid::new_v4().to_string();

        // Allowed request, no counters increment.
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "allow".into();
        stats.record(&entry).await;

        // Denied request.
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "deny".into();
        stats.record(&entry).await;

        // Another denied request with anomaly flag.
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "deny".into();
        entry.anomaly_flags = vec!["suspicious".into()];
        stats.record(&entry).await;

        let result = stats.stats_for_session(&session_id).await;
        assert_eq!(result.denied_count, 2);
        assert_eq!(result.anomaly_count, 1);
    }

    /// An entry with several anomaly flags counts as a single anomalous entry.
    #[tokio::test]
    async fn anomaly_counted_once_per_entry() {
        let stats = AuditStats::new();
        let session_id = Uuid::new_v4().to_string();

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "allow".into();
        entry.anomaly_flags = vec!["suspicious".into(), "rate_spike".into()];
        stats.record(&entry).await;

        let result = stats.stats_for_session(&session_id).await;
        assert_eq!(result.anomaly_count, 1);
        assert_eq!(result.denied_count, 0);
    }

    /// Entries without a session ID must not be tracked at all.
    #[tokio::test]
    async fn empty_session_id_is_ignored() {
        let stats = AuditStats::new();

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = String::new();
        entry.authorization_decision = "deny".into();
        entry.anomaly_flags = vec!["suspicious".into()];
        stats.record(&entry).await;

        let agg = stats.aggregate().await;
        assert!(
            agg.per_session.is_empty(),
            "entries without a session ID must not create a slot"
        );
        assert_eq!(agg.total_denied, 0);
        assert_eq!(agg.total_anomalies, 0);

        let result = stats.stats_for_session("").await;
        assert_eq!(result.denied_count, 0);
        assert_eq!(result.anomaly_count, 0);
    }

    #[tokio::test]
    async fn unknown_session_returns_zero() {
        let stats = AuditStats::new();
        let result = stats.stats_for_session("nonexistent").await;
        assert_eq!(result.denied_count, 0);
        assert_eq!(result.anomaly_count, 0);
    }

    #[tokio::test]
    async fn aggregate_of_empty_tracker_is_zero() {
        let agg = AuditStats::new().aggregate().await;
        assert!(agg.per_session.is_empty());
        assert_eq!(agg.total_denied, 0);
        assert_eq!(agg.total_anomalies, 0);
        assert_eq!(agg.sessions_with_denials, 0);
        assert_eq!(agg.sessions_with_anomalies, 0);
    }

    #[tokio::test]
    async fn remove_session_cleans_up() {
        let stats = AuditStats::new();
        let session_id = Uuid::new_v4().to_string();

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "deny".into();
        stats.record(&entry).await;

        stats.remove_session(&session_id).await;
        let result = stats.stats_for_session(&session_id).await;
        assert_eq!(result.denied_count, 0);
    }

    // -----------------------------------------------------------------------
    // Unbounded memory growth in audit stats map
    // -----------------------------------------------------------------------

    /// Create 1000 unique session IDs and record entries for each, then check
    /// that aggregate() returns correct totals. This documents the growth
    /// behavior: without remove_session(), the map grows monotonically.
    /// It also checks that remove_session() reclaims entries and that the
    /// totals reflect only the remaining sessions.
    #[tokio::test]
    async fn stats_growth_with_many_sessions() {
        let stats = AuditStats::new();
        let session_count = 1000;

        let mut session_ids = Vec::with_capacity(session_count);
        for i in 0..session_count {
            let session_id = format!("session-{}", i);
            session_ids.push(session_id.clone());

            let mut entry = AuditEntry::new(Uuid::new_v4());
            entry.task_session_id = session_id;
            entry.authorization_decision = "deny".into();
            entry.anomaly_flags = if i % 3 == 0 {
                vec!["anomaly".into()]
            } else {
                vec![]
            };
            stats.record(&entry).await;
        }

        // Verify aggregate returns correct totals.
        let agg = stats.aggregate().await;
        assert_eq!(
            agg.per_session.len(),
            session_count,
            "all {} sessions must be tracked",
            session_count
        );
        assert_eq!(
            agg.total_denied, session_count as u64,
            "every session had one denied entry"
        );
        assert_eq!(
            agg.sessions_with_denials, session_count as u64,
            "all sessions have denials"
        );
        // Every 3rd session (i % 3 == 0) has an anomaly: 0, 3, 6, ..., 999.
        // That is ceil(1000/3) = 334 sessions.
        let expected_anomaly_sessions = (0..session_count).filter(|i| i % 3 == 0).count() as u64;
        assert_eq!(
            agg.sessions_with_anomalies, expected_anomaly_sessions,
            "every 3rd session should have anomaly flag"
        );
        assert_eq!(
            agg.total_anomalies, expected_anomaly_sessions,
            "one anomaly entry per anomaly session"
        );

        // Verify that removing sessions reduces the map size.
        for session_id in &session_ids[..500] {
            stats.remove_session(session_id).await;
        }
        let agg_after = stats.aggregate().await;
        assert_eq!(
            agg_after.per_session.len(),
            500,
            "after removing 500 sessions, 500 must remain"
        );
        assert_eq!(
            agg_after.total_denied, 500,
            "totals must only cover the remaining sessions"
        );
        let expected_remaining_anomalies =
            (500..session_count).filter(|i| i % 3 == 0).count() as u64;
        assert_eq!(agg_after.sessions_with_anomalies, expected_remaining_anomalies);
        assert_eq!(agg_after.total_anomalies, expected_remaining_anomalies);

        // Removed sessions report zero; kept sessions keep their counts.
        let removed = stats.stats_for_session(&session_ids[0]).await;
        assert_eq!(removed.denied_count, 0, "removed session must report zero");
        let kept = stats.stats_for_session(&session_ids[session_count - 1]).await;
        assert_eq!(kept.denied_count, 1, "kept session must retain its count");
    }
}