use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};

use thiserror::Error;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;

use crate::entry::AuditEntry;
12
/// Errors that can occur while emitting an audit entry.
#[derive(Debug, Error)]
pub enum SinkError {
    /// An [`AuditEntry`] could not be serialized to JSON.
    #[error("JSON serialization failed: {0}")]
    Serialize(#[from] serde_json::Error),

    /// Opening, writing, flushing, or syncing the audit file failed.
    #[error("file I/O failed: {0}")]
    Io(#[from] std::io::Error),
}
22
/// Configuration for [`AuditSink`].
#[derive(Debug, Clone)]
pub struct AuditSinkConfig {
    /// Emit each entry as a `tracing` event (target `arbiter_audit`).
    pub write_stdout: bool,

    /// Optional JSONL file that entries are appended to.
    pub file_path: Option<PathBuf>,

    /// Maximum audit file size in bytes.
    /// NOTE(review): not enforced anywhere in the write path visible in this
    /// file — confirm rotation/size-capping is handled elsewhere.
    pub max_file_size_bytes: u64,

    /// Thread a tamper-evident blake3 hash chain (sequence number +
    /// previous-record hash) through the emitted entries.
    pub hash_chain: bool,
}
41
/// Default cap for the audit file: 100 MiB.
const DEFAULT_MAX_AUDIT_FILE_SIZE: u64 = 100 * 1024 * 1024;
44
45impl Default for AuditSinkConfig {
46 fn default() -> Self {
47 Self {
48 write_stdout: true,
49 file_path: None,
50 max_file_size_bytes: DEFAULT_MAX_AUDIT_FILE_SIZE,
51 hash_chain: true,
52 }
53 }
54}
55
/// Mutable hash-chain state; guarded by a mutex so entries are chained in
/// the same order they are written.
struct ChainState {
    // Sequence number of the most recently chained entry (0 = none yet).
    sequence: u64,
    // Hex blake3 hash of the previous record ("genesis" before the first).
    prev_hash: String,
}
68
/// Sink that emits [`AuditEntry`] records to stdout and/or a JSONL file,
/// optionally threading a tamper-evident hash chain through them, and
/// tracking file-write health (degraded / recovered).
pub struct AuditSink {
    config: AuditSinkConfig,
    stats: crate::stats::AuditStats,
    // Consecutive failed file writes; cleared only after a sustained run of
    // successes (see `RECOVERY_THRESHOLD`).
    write_failures: AtomicU64,
    // Failed file writes over the sink's lifetime; never reset.
    total_write_failures: AtomicU64,
    // Consecutive successful writes observed while degraded.
    recovery_successes: AtomicU64,
    // Hash-chain state; the lock also serializes `write` calls so chain
    // order matches file order.
    chain: Mutex<ChainState>,
    // Pre-opened append handle (set by `init_file`); when `None`, the file
    // is re-opened on every write.
    file: Option<Mutex<tokio::fs::File>>,
}
86
87impl AuditSink {
88 pub fn new(config: AuditSinkConfig) -> Self {
90 Self {
91 config,
92 stats: crate::stats::AuditStats::new(),
93 write_failures: AtomicU64::new(0),
94 total_write_failures: AtomicU64::new(0),
95 recovery_successes: AtomicU64::new(0),
96 chain: Mutex::new(ChainState {
97 sequence: 0,
98 prev_hash: "genesis".into(),
99 }),
100 file: None,
101 }
102 }
103
104 pub async fn init_file(&mut self) -> Result<(), SinkError> {
107 if let Some(ref path) = self.config.file_path {
108 let file = OpenOptions::new()
109 .create(true)
110 .append(true)
111 .open(path)
112 .await?;
113 self.file = Some(Mutex::new(file));
114 }
115 Ok(())
116 }
117
118 pub fn stats(&self) -> &crate::stats::AuditStats {
120 &self.stats
121 }
122
123 const RECOVERY_THRESHOLD: u64 = 3;
126
127 pub fn is_degraded(&self) -> bool {
131 self.write_failures.load(Ordering::Relaxed) > 0
132 }
133
134 pub fn consecutive_failures(&self) -> u64 {
136 self.write_failures.load(Ordering::Relaxed)
137 }
138
139 pub fn total_failures(&self) -> u64 {
141 self.total_write_failures.load(Ordering::Relaxed)
142 }
143
144 pub async fn write(&self, entry: &AuditEntry) -> Result<(), SinkError> {
154 let mut chain_guard = self.chain.lock().await;
161
162 let mut chained_entry = entry.clone();
163 if self.config.hash_chain {
164 chain_guard.sequence += 1;
165 chained_entry.chain_sequence = Some(chain_guard.sequence);
166 chained_entry.chain_prev_hash = Some(chain_guard.prev_hash.clone());
167 chained_entry.chain_record_hash = None;
170 let pre_hash_json = serde_json::to_string(&chained_entry).unwrap_or_default();
171 let record_hash = blake3::hash(pre_hash_json.as_bytes()).to_hex().to_string();
172 chained_entry.chain_record_hash = Some(record_hash.clone());
173 chain_guard.prev_hash = record_hash;
174 }
175
176 let json = serde_json::to_string(&chained_entry)?;
177
178 if self.config.write_stdout {
179 tracing::info!(target: "arbiter_audit", audit_entry = %json);
181 }
182
183 if let Some(path) = &self.config.file_path {
184 match self.write_to_file(path, &json).await {
185 Ok(()) => {
186 let prev_failures = self.write_failures.load(Ordering::Relaxed);
187 if prev_failures > 0 {
188 let successes = self.recovery_successes.fetch_add(1, Ordering::Relaxed) + 1;
190 if successes >= Self::RECOVERY_THRESHOLD {
191 self.write_failures.store(0, Ordering::Relaxed);
192 self.recovery_successes.store(0, Ordering::Relaxed);
193 tracing::info!(
194 threshold = Self::RECOVERY_THRESHOLD,
195 "audit sink recovered after {} consecutive successful writes",
196 successes
197 );
198 }
199 }
200 }
201 Err(e) => {
202 let consecutive = self.write_failures.fetch_add(1, Ordering::Relaxed) + 1;
203 self.total_write_failures.fetch_add(1, Ordering::Relaxed);
204 self.recovery_successes.store(0, Ordering::Relaxed);
205 tracing::error!(
206 error = %e,
207 consecutive_failures = consecutive,
208 "audit file write failed; audit data may be lost"
209 );
210 return Err(e);
211 }
212 }
213 }
214
215 self.stats.record(entry).await;
219
220 Ok(())
221 }
222
223 async fn write_to_file(&self, path: &PathBuf, json: &str) -> Result<(), SinkError> {
224 if let Some(ref file_mutex) = self.file {
227 let mut file = file_mutex.lock().await;
228 file.write_all(json.as_bytes()).await?;
229 file.write_all(b"\n").await?;
230 file.flush().await?;
231 file.sync_all().await?;
232 return Ok(());
233 }
234
235 let mut file = OpenOptions::new()
237 .create(true)
238 .append(true)
239 .open(path)
240 .await?;
241 file.write_all(json.as_bytes()).await?;
242 file.write_all(b"\n").await?;
243 file.flush().await?;
244 file.sync_all().await?;
245 Ok(())
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Each `write` appends exactly one parseable JSON line to the file.
    #[tokio::test]
    async fn write_to_file() {
        let dir = std::env::temp_dir().join(format!("arbiter-audit-test-{}", Uuid::new_v4()));
        let file_path = dir.join("audit.jsonl");
        tokio::fs::create_dir_all(&dir).await.unwrap();

        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(file_path.clone()),
            ..Default::default()
        });

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.agent_id = "test-agent".into();
        entry.tool_called = "test_tool".into();
        entry.latency_ms = 10;

        sink.write(&entry).await.unwrap();
        sink.write(&entry).await.unwrap();

        // Two writes -> two JSONL lines.
        let contents = tokio::fs::read_to_string(&file_path).await.unwrap();
        let lines: Vec<&str> = contents.trim().lines().collect();
        assert_eq!(lines.len(), 2);

        // Round-trip: the stored line deserializes back into an AuditEntry.
        let parsed: AuditEntry = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(parsed.agent_id, "test-agent");

        // Best-effort cleanup of the temp dir.
        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    /// Failed file writes increment both the consecutive and lifetime
    /// counters and mark the sink degraded.
    #[tokio::test]
    async fn tracks_write_failures() {
        // Path in a nonexistent directory guarantees every write fails.
        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
            ..Default::default()
        });

        assert!(!sink.is_degraded());
        assert_eq!(sink.consecutive_failures(), 0);

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.tool_called = "test".into();

        // First failure: degraded, both counters at 1.
        assert!(sink.write(&entry).await.is_err());
        assert!(sink.is_degraded());
        assert_eq!(sink.consecutive_failures(), 1);
        assert_eq!(sink.total_failures(), 1);

        // Second failure: both counters advance together.
        assert!(sink.write(&entry).await.is_err());
        assert_eq!(sink.consecutive_failures(), 2);
        assert_eq!(sink.total_failures(), 2);
    }

    /// A degraded sink recovers only after RECOVERY_THRESHOLD consecutive
    /// successful writes, then the failure counter resets to zero.
    #[tokio::test]
    async fn resets_failures_on_success() {
        let dir = std::env::temp_dir().join(format!("arbiter-audit-reset-{}", Uuid::new_v4()));
        let file_path = dir.join("audit.jsonl");

        // Sanity: a failing sink becomes degraded.
        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
            ..Default::default()
        });

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.tool_called = "test".into();

        let _ = sink.write(&entry).await;
        assert!(sink.is_degraded());

        // Fresh sink with a valid path, manually put into a degraded state
        // (write_failures is accessible here because tests live in the same
        // module).
        tokio::fs::create_dir_all(&dir).await.unwrap();
        let recovered_sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(file_path.clone()),
            ..Default::default()
        });
        recovered_sink.write_failures.store(3, Ordering::Relaxed);
        assert!(recovered_sink.is_degraded());

        // THRESHOLD - 1 successes are not enough to clear degradation.
        for i in 1..AuditSink::RECOVERY_THRESHOLD {
            recovered_sink.write(&entry).await.unwrap();
            assert!(
                recovered_sink.is_degraded(),
                "should still be degraded after {i} successful write(s)"
            );
        }
        // The THRESHOLD-th success flips the sink back to healthy.
        recovered_sink.write(&entry).await.unwrap();
        assert!(!recovered_sink.is_degraded());
        assert_eq!(recovered_sink.consecutive_failures(), 0);

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    /// Entries serialize to single-line JSON that round-trips losslessly
    /// (required for the JSONL file format).
    #[test]
    fn serialization_produces_valid_json() {
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.agent_id = "test-agent".into();
        entry.tool_called = "dangerous_tool".into();
        entry.authorization_decision = "deny".into();
        entry.policy_matched = Some("block-dangerous".into());
        entry.anomaly_flags = vec!["scope_violation".into(), "unusual_hour".into()];
        entry.latency_ms = 7;
        entry.upstream_status = Some(403);

        let json = serde_json::to_string(&entry).unwrap();

        // Round-trip preserves every field we set.
        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.agent_id, "test-agent");
        assert_eq!(parsed.authorization_decision, "deny");
        assert_eq!(parsed.anomaly_flags.len(), 2);
        assert_eq!(parsed.upstream_status, Some(403));

        // JSONL invariant: no embedded newlines.
        assert!(!json.contains('\n'), "JSON must be a single line");
    }
}