1use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use tokio::sync::RwLock;
11
/// Live, lock-free counters for a single task session.
///
/// Both fields are atomics so they can be incremented through a shared
/// reference, without exclusive access to the owning map.
#[derive(Debug, Default)]
struct SessionCounters {
    /// Number of recorded entries whose authorization decision was `"deny"`.
    denied: AtomicU64,
    /// Number of recorded entries that carried at least one anomaly flag.
    anomalies: AtomicU64,
}
18
/// Point-in-time snapshot of the audit counters of one session.
///
/// The default value has both counts at zero, which is also what is reported
/// for a session that has never been recorded.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionAuditStats {
    /// Number of entries recorded with a `"deny"` authorization decision.
    pub denied_count: u64,
    /// Number of entries recorded with a non-empty set of anomaly flags.
    pub anomaly_count: u64,
}
25
26#[derive(Debug, Clone, Default)]
28pub struct AggregateAuditStats {
29 pub total_denied: u64,
30 pub total_anomalies: u64,
31 pub sessions_with_anomalies: u64,
32 pub sessions_with_denials: u64,
33 pub per_session: Vec<(String, SessionAuditStats)>,
35}
36
37#[derive(Clone, Default)]
40pub struct AuditStats {
41 sessions: Arc<RwLock<HashMap<String, Arc<SessionCounters>>>>,
42}
43
44impl AuditStats {
45 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub async fn record(&self, entry: &crate::entry::AuditEntry) {
51 if entry.task_session_id.is_empty() {
52 return;
53 }
54
55 let counters = {
56 let sessions = self.sessions.read().await;
57 sessions.get(&entry.task_session_id).cloned()
58 };
59
60 let counters = match counters {
61 Some(c) => c,
62 None => {
63 let mut sessions = self.sessions.write().await;
64 sessions
65 .entry(entry.task_session_id.clone())
66 .or_insert_with(|| Arc::new(SessionCounters::default()))
67 .clone()
68 }
69 };
70
71 if entry.authorization_decision == "deny" {
72 counters.denied.fetch_add(1, Ordering::Relaxed);
73 }
74
75 if !entry.anomaly_flags.is_empty() {
76 counters.anomalies.fetch_add(1, Ordering::Relaxed);
77 }
78 }
79
80 pub async fn stats_for_session(&self, session_id: &str) -> SessionAuditStats {
82 let sessions = self.sessions.read().await;
83 match sessions.get(session_id) {
84 Some(counters) => SessionAuditStats {
85 denied_count: counters.denied.load(Ordering::Relaxed),
86 anomaly_count: counters.anomalies.load(Ordering::Relaxed),
87 },
88 None => SessionAuditStats::default(),
89 }
90 }
91
92 pub async fn remove_session(&self, session_id: &str) {
94 let mut sessions = self.sessions.write().await;
95 sessions.remove(session_id);
96 }
97
98 pub async fn aggregate(&self) -> AggregateAuditStats {
101 let sessions = self.sessions.read().await;
102 let mut total_denied: u64 = 0;
103 let mut total_anomalies: u64 = 0;
104 let mut sessions_with_anomalies: u64 = 0;
105 let mut sessions_with_denials: u64 = 0;
106 let mut per_session = Vec::with_capacity(sessions.len());
107
108 for (session_id, counters) in sessions.iter() {
109 let denied = counters.denied.load(Ordering::Relaxed);
110 let anomalies = counters.anomalies.load(Ordering::Relaxed);
111 total_denied += denied;
112 total_anomalies += anomalies;
113 if anomalies > 0 {
114 sessions_with_anomalies += 1;
115 }
116 if denied > 0 {
117 sessions_with_denials += 1;
118 }
119 per_session.push((
120 session_id.clone(),
121 SessionAuditStats {
122 denied_count: denied,
123 anomaly_count: anomalies,
124 },
125 ));
126 }
127
128 AggregateAuditStats {
129 total_denied,
130 total_anomalies,
131 sessions_with_anomalies,
132 sessions_with_denials,
133 per_session,
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::entry::AuditEntry;
    use uuid::Uuid;

    #[tokio::test]
    async fn tracks_denied_and_anomalies() {
        let stats = AuditStats::new();
        let session_id = Uuid::new_v4().to_string();

        // Allowed entry without flags: must not touch either counter.
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "allow".into();
        stats.record(&entry).await;

        // Denied entry without flags: counts as a denial only.
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "deny".into();
        stats.record(&entry).await;

        // Denied entry with a flag: counts as a denial and as an anomaly.
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "deny".into();
        entry.anomaly_flags = vec!["suspicious".into()];
        stats.record(&entry).await;

        let result = stats.stats_for_session(&session_id).await;
        assert_eq!(result.denied_count, 2);
        assert_eq!(result.anomaly_count, 1);
    }

    #[tokio::test]
    async fn unknown_session_returns_zero() {
        let stats = AuditStats::new();
        let result = stats.stats_for_session("nonexistent").await;
        assert_eq!(result.denied_count, 0);
        assert_eq!(result.anomaly_count, 0);
    }

    #[tokio::test]
    async fn remove_session_cleans_up() {
        let stats = AuditStats::new();
        let session_id = Uuid::new_v4().to_string();

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.task_session_id = session_id.clone();
        entry.authorization_decision = "deny".into();
        stats.record(&entry).await;

        // After removal the session must look like one that was never seen.
        stats.remove_session(&session_id).await;
        let result = stats.stats_for_session(&session_id).await;
        assert_eq!(result.denied_count, 0);
    }

    #[tokio::test]
    async fn stats_growth_with_many_sessions() {
        let stats = AuditStats::new();
        let session_count = 1000;

        // One denied entry per session; every third session is also flagged.
        let mut session_ids = Vec::with_capacity(session_count);
        for i in 0..session_count {
            let session_id = format!("session-{}", i);
            session_ids.push(session_id.clone());

            let mut entry = AuditEntry::new(Uuid::new_v4());
            entry.task_session_id = session_id;
            entry.authorization_decision = "deny".into();
            entry.anomaly_flags = if i % 3 == 0 {
                vec!["anomaly".into()]
            } else {
                vec![]
            };
            stats.record(&entry).await;
        }

        let agg = stats.aggregate().await;
        assert_eq!(
            agg.per_session.len(),
            session_count,
            "all {} sessions must be tracked",
            session_count
        );
        assert_eq!(
            agg.total_denied, session_count as u64,
            "every session had one denied entry"
        );
        assert_eq!(
            agg.sessions_with_denials, session_count as u64,
            "all sessions have denials"
        );

        let expected_anomaly_sessions = (0..session_count).filter(|i| i % 3 == 0).count() as u64;
        assert_eq!(
            agg.sessions_with_anomalies, expected_anomaly_sessions,
            "every 3rd session should have anomaly flag"
        );
        assert_eq!(
            agg.total_anomalies, expected_anomaly_sessions,
            "one anomaly entry per anomaly session"
        );

        // Drop the first half and make sure the map actually shrinks.
        let removed = session_count / 2;
        for session_id in &session_ids[..removed] {
            stats.remove_session(session_id).await;
        }
        let agg_after = stats.aggregate().await;
        assert_eq!(
            agg_after.per_session.len(),
            session_count - removed,
            "after removing {} sessions, {} must remain",
            removed,
            session_count - removed
        );
    }
}