use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};

use thiserror::Error;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;

use crate::entry::AuditEntry;
11
/// Errors that can occur while writing an audit entry to a sink.
#[derive(Debug, Error)]
pub enum SinkError {
    /// The entry could not be encoded as JSON.
    #[error("JSON serialization failed: {0}")]
    Serialize(#[from] serde_json::Error),

    /// Opening, appending to, or flushing the audit log file failed.
    #[error("file I/O failed: {0}")]
    Io(#[from] std::io::Error),
}
21
/// Destinations an `AuditSink` writes entries to.
#[derive(Debug, Clone)]
pub struct AuditSinkConfig {
    /// When true, each serialized entry is emitted as a `tracing` info
    /// event with target `arbiter_audit`.
    pub write_stdout: bool,

    /// Optional path of a JSONL file to append entries to; `None`
    /// disables the file sink entirely.
    pub file_path: Option<PathBuf>,
}
31
32impl Default for AuditSinkConfig {
33 fn default() -> Self {
34 Self {
35 write_stdout: true,
36 file_path: None,
37 }
38 }
39}
40
/// Writes audit entries to the configured sinks (tracing and/or a JSONL
/// file) while tracking file-write failures for degradation reporting.
pub struct AuditSink {
    config: AuditSinkConfig,
    /// Aggregate statistics recorded for every entry passed to `write`.
    stats: crate::stats::AuditStats,
    /// Consecutive file-write failures; cleared only after
    /// `RECOVERY_THRESHOLD` consecutive successful writes.
    write_failures: AtomicU64,
    /// Lifetime count of file-write failures; never reset.
    total_write_failures: AtomicU64,
    /// Consecutive successful writes observed while degraded; reset on
    /// any failure and on recovery.
    recovery_successes: AtomicU64,
}
58
59impl AuditSink {
60 pub fn new(config: AuditSinkConfig) -> Self {
62 Self {
63 config,
64 stats: crate::stats::AuditStats::new(),
65 write_failures: AtomicU64::new(0),
66 total_write_failures: AtomicU64::new(0),
67 recovery_successes: AtomicU64::new(0),
68 }
69 }
70
71 pub fn stats(&self) -> &crate::stats::AuditStats {
73 &self.stats
74 }
75
76 const RECOVERY_THRESHOLD: u64 = 3;
79
80 pub fn is_degraded(&self) -> bool {
84 self.write_failures.load(Ordering::Relaxed) > 0
85 }
86
87 pub fn consecutive_failures(&self) -> u64 {
89 self.write_failures.load(Ordering::Relaxed)
90 }
91
92 pub fn total_failures(&self) -> u64 {
94 self.total_write_failures.load(Ordering::Relaxed)
95 }
96
97 pub async fn write(&self, entry: &AuditEntry) -> Result<(), SinkError> {
102 self.stats.record(entry).await;
104
105 let json = serde_json::to_string(entry)?;
106
107 if self.config.write_stdout {
108 tracing::info!(target: "arbiter_audit", audit_entry = %json);
110 }
111
112 if let Some(path) = &self.config.file_path {
113 match self.write_to_file(path, &json).await {
114 Ok(()) => {
115 let prev_failures = self.write_failures.load(Ordering::Relaxed);
116 if prev_failures > 0 {
117 let successes = self.recovery_successes.fetch_add(1, Ordering::Relaxed) + 1;
119 if successes >= Self::RECOVERY_THRESHOLD {
120 self.write_failures.store(0, Ordering::Relaxed);
121 self.recovery_successes.store(0, Ordering::Relaxed);
122 tracing::info!(
123 threshold = Self::RECOVERY_THRESHOLD,
124 "audit sink recovered after {} consecutive successful writes",
125 successes
126 );
127 }
128 }
129 }
130 Err(e) => {
131 let consecutive = self.write_failures.fetch_add(1, Ordering::Relaxed) + 1;
132 self.total_write_failures.fetch_add(1, Ordering::Relaxed);
133 self.recovery_successes.store(0, Ordering::Relaxed);
134 tracing::error!(
135 error = %e,
136 consecutive_failures = consecutive,
137 "audit file write failed; audit data may be lost"
138 );
139 return Err(e);
140 }
141 }
142 }
143
144 Ok(())
145 }
146
147 async fn write_to_file(&self, path: &PathBuf, json: &str) -> Result<(), SinkError> {
148 let mut file = OpenOptions::new()
149 .create(true)
150 .append(true)
151 .open(path)
152 .await?;
153 file.write_all(json.as_bytes()).await?;
154 file.write_all(b"\n").await?;
155 file.flush().await?;
156 Ok(())
157 }
158}
159
160#[cfg(test)]
161mod tests {
162 use super::*;
163 use uuid::Uuid;
164
165 #[tokio::test]
166 async fn write_to_file() {
167 let dir = std::env::temp_dir().join(format!("arbiter-audit-test-{}", Uuid::new_v4()));
168 let file_path = dir.join("audit.jsonl");
169 tokio::fs::create_dir_all(&dir).await.unwrap();
170
171 let sink = AuditSink::new(AuditSinkConfig {
172 write_stdout: false,
173 file_path: Some(file_path.clone()),
174 ..Default::default()
175 });
176
177 let mut entry = AuditEntry::new(Uuid::new_v4());
178 entry.agent_id = "test-agent".into();
179 entry.tool_called = "test_tool".into();
180 entry.latency_ms = 10;
181
182 sink.write(&entry).await.unwrap();
183 sink.write(&entry).await.unwrap();
184
185 let contents = tokio::fs::read_to_string(&file_path).await.unwrap();
186 let lines: Vec<&str> = contents.trim().lines().collect();
187 assert_eq!(lines.len(), 2);
188
189 let parsed: AuditEntry = serde_json::from_str(lines[0]).unwrap();
191 assert_eq!(parsed.agent_id, "test-agent");
192
193 let _ = tokio::fs::remove_dir_all(&dir).await;
195 }
196
197 #[tokio::test]
198 async fn tracks_write_failures() {
199 let sink = AuditSink::new(AuditSinkConfig {
201 write_stdout: false,
202 file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
203 ..Default::default()
204 });
205
206 assert!(!sink.is_degraded());
207 assert_eq!(sink.consecutive_failures(), 0);
208
209 let mut entry = AuditEntry::new(Uuid::new_v4());
210 entry.tool_called = "test".into();
211
212 assert!(sink.write(&entry).await.is_err());
214 assert!(sink.is_degraded());
215 assert_eq!(sink.consecutive_failures(), 1);
216 assert_eq!(sink.total_failures(), 1);
217
218 assert!(sink.write(&entry).await.is_err());
220 assert_eq!(sink.consecutive_failures(), 2);
221 assert_eq!(sink.total_failures(), 2);
222 }
223
224 #[tokio::test]
225 async fn resets_failures_on_success() {
226 let dir = std::env::temp_dir().join(format!("arbiter-audit-reset-{}", Uuid::new_v4()));
227 let file_path = dir.join("audit.jsonl");
228
229 let sink = AuditSink::new(AuditSinkConfig {
231 write_stdout: false,
232 file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
233 ..Default::default()
234 });
235
236 let mut entry = AuditEntry::new(Uuid::new_v4());
237 entry.tool_called = "test".into();
238
239 let _ = sink.write(&entry).await;
241 assert!(sink.is_degraded());
242
243 tokio::fs::create_dir_all(&dir).await.unwrap();
246 let recovered_sink = AuditSink::new(AuditSinkConfig {
247 write_stdout: false,
248 file_path: Some(file_path.clone()),
249 ..Default::default()
250 });
251 recovered_sink.write_failures.store(3, Ordering::Relaxed);
253 assert!(recovered_sink.is_degraded());
254
255 for i in 1..AuditSink::RECOVERY_THRESHOLD {
257 recovered_sink.write(&entry).await.unwrap();
258 assert!(
259 recovered_sink.is_degraded(),
260 "should still be degraded after {i} successful write(s)"
261 );
262 }
263 recovered_sink.write(&entry).await.unwrap();
265 assert!(!recovered_sink.is_degraded());
266 assert_eq!(recovered_sink.consecutive_failures(), 0);
267
268 let _ = tokio::fs::remove_dir_all(&dir).await;
269 }
270
271 #[test]
272 fn serialization_produces_valid_json() {
273 let mut entry = AuditEntry::new(Uuid::new_v4());
274 entry.agent_id = "test-agent".into();
275 entry.tool_called = "dangerous_tool".into();
276 entry.authorization_decision = "deny".into();
277 entry.policy_matched = Some("block-dangerous".into());
278 entry.anomaly_flags = vec!["scope_violation".into(), "unusual_hour".into()];
279 entry.latency_ms = 7;
280 entry.upstream_status = Some(403);
281
282 let json = serde_json::to_string(&entry).unwrap();
283
284 let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
286 assert_eq!(parsed.agent_id, "test-agent");
287 assert_eq!(parsed.authorization_decision, "deny");
288 assert_eq!(parsed.anomaly_flags.len(), 2);
289 assert_eq!(parsed.upstream_status, Some(403));
290
291 assert!(!json.contains('\n'), "JSON must be a single line");
293 }
294}