use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};

use thiserror::Error;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;

use crate::entry::AuditEntry;
12
/// Errors that can occur while emitting an audit entry to a sink.
#[derive(Debug, Error)]
pub enum SinkError {
    /// The entry could not be encoded as JSON.
    #[error("JSON serialization failed: {0}")]
    Serialize(#[from] serde_json::Error),

    /// Opening, appending to, or syncing the audit file failed.
    #[error("file I/O failed: {0}")]
    Io(#[from] std::io::Error),
}
22
/// Configuration for [`AuditSink`].
#[derive(Debug, Clone)]
pub struct AuditSinkConfig {
    /// Emit each entry to stdout (via `tracing`) when true.
    pub write_stdout: bool,

    /// Append entries to this file as JSON lines when set.
    pub file_path: Option<PathBuf>,

    /// Maximum audit file size in bytes.
    /// NOTE(review): this value is never read anywhere in this file —
    /// rotation/size enforcement appears unimplemented; confirm before
    /// relying on it.
    pub max_file_size_bytes: u64,
}
37
/// Default cap on the audit file size: 100 MiB.
const DEFAULT_MAX_AUDIT_FILE_SIZE: u64 = 100 * 1024 * 1024;

41impl Default for AuditSinkConfig {
42 fn default() -> Self {
43 Self {
44 write_stdout: true,
45 file_path: None,
46 max_file_size_bytes: DEFAULT_MAX_AUDIT_FILE_SIZE,
47 }
48 }
49}
50
/// Mutable state of the tamper-evidence hash chain, guarded by a mutex
/// inside [`AuditSink`].
struct ChainState {
    /// Monotonically increasing sequence number of the last chained entry.
    sequence: u64,
    /// Hash of the most recently written record ("genesis" before the
    /// first write).
    prev_hash: String,
}
63
/// Audit sink that emits [`AuditEntry`] records as JSON lines to stdout
/// (via `tracing`) and/or an append-only file, hash-chaining each record
/// for tamper evidence and tracking file-write failures.
pub struct AuditSink {
    config: AuditSinkConfig,
    /// Aggregate statistics over recorded entries.
    stats: crate::stats::AuditStats,
    /// Consecutive file-write failures; reset after enough successes.
    write_failures: AtomicU64,
    /// Lifetime total of file-write failures (never reset).
    total_write_failures: AtomicU64,
    /// Consecutive successful writes while degraded; drives recovery.
    recovery_successes: AtomicU64,
    /// Hash-chain state shared across concurrent writers.
    chain: Mutex<ChainState>,
    /// Persistent file handle opened by `init_file`; when `None`, each
    /// file write opens the file itself.
    file: Option<Mutex<tokio::fs::File>>,
}
81
82impl AuditSink {
83 pub fn new(config: AuditSinkConfig) -> Self {
85 Self {
86 config,
87 stats: crate::stats::AuditStats::new(),
88 write_failures: AtomicU64::new(0),
89 total_write_failures: AtomicU64::new(0),
90 recovery_successes: AtomicU64::new(0),
91 chain: Mutex::new(ChainState {
92 sequence: 0,
93 prev_hash: "genesis".into(),
94 }),
95 file: None,
96 }
97 }
98
99 pub async fn init_file(&mut self) -> Result<(), SinkError> {
102 if let Some(ref path) = self.config.file_path {
103 let file = OpenOptions::new()
104 .create(true)
105 .append(true)
106 .open(path)
107 .await?;
108 self.file = Some(Mutex::new(file));
109 }
110 Ok(())
111 }
112
113 pub fn stats(&self) -> &crate::stats::AuditStats {
115 &self.stats
116 }
117
118 const RECOVERY_THRESHOLD: u64 = 3;
121
122 pub fn is_degraded(&self) -> bool {
126 self.write_failures.load(Ordering::Relaxed) > 0
127 }
128
129 pub fn consecutive_failures(&self) -> u64 {
131 self.write_failures.load(Ordering::Relaxed)
132 }
133
134 pub fn total_failures(&self) -> u64 {
136 self.total_write_failures.load(Ordering::Relaxed)
137 }
138
139 pub async fn write(&self, entry: &AuditEntry) -> Result<(), SinkError> {
144 let mut chained_entry = entry.clone();
146 {
147 let mut chain = self.chain.lock().await;
148 chain.sequence += 1;
149 chained_entry.chain_sequence = Some(chain.sequence);
150 chained_entry.chain_prev_hash = Some(chain.prev_hash.clone());
151 chained_entry.chain_record_hash = None;
154 let pre_hash_json = serde_json::to_string(&chained_entry).unwrap_or_default();
155 let record_hash = blake3::hash(pre_hash_json.as_bytes()).to_hex().to_string();
156 chained_entry.chain_record_hash = Some(record_hash.clone());
157 chain.prev_hash = record_hash;
158 }
159
160 let json = serde_json::to_string(&chained_entry)?;
161
162 if self.config.write_stdout {
163 tracing::info!(target: "arbiter_audit", audit_entry = %json);
165 }
166
167 if let Some(path) = &self.config.file_path {
168 match self.write_to_file(path, &json).await {
169 Ok(()) => {
170 let prev_failures = self.write_failures.load(Ordering::Relaxed);
171 if prev_failures > 0 {
172 let successes = self.recovery_successes.fetch_add(1, Ordering::Relaxed) + 1;
174 if successes >= Self::RECOVERY_THRESHOLD {
175 self.write_failures.store(0, Ordering::Relaxed);
176 self.recovery_successes.store(0, Ordering::Relaxed);
177 tracing::info!(
178 threshold = Self::RECOVERY_THRESHOLD,
179 "audit sink recovered after {} consecutive successful writes",
180 successes
181 );
182 }
183 }
184 }
185 Err(e) => {
186 let consecutive = self.write_failures.fetch_add(1, Ordering::Relaxed) + 1;
187 self.total_write_failures.fetch_add(1, Ordering::Relaxed);
188 self.recovery_successes.store(0, Ordering::Relaxed);
189 tracing::error!(
190 error = %e,
191 consecutive_failures = consecutive,
192 "audit file write failed; audit data may be lost"
193 );
194 return Err(e);
195 }
196 }
197 }
198
199 self.stats.record(entry).await;
203
204 Ok(())
205 }
206
207 async fn write_to_file(&self, path: &PathBuf, json: &str) -> Result<(), SinkError> {
208 if let Some(ref file_mutex) = self.file {
211 let mut file = file_mutex.lock().await;
212 file.write_all(json.as_bytes()).await?;
213 file.write_all(b"\n").await?;
214 file.flush().await?;
215 file.sync_all().await?;
216 return Ok(());
217 }
218
219 let mut file = OpenOptions::new()
221 .create(true)
222 .append(true)
223 .open(path)
224 .await?;
225 file.write_all(json.as_bytes()).await?;
226 file.write_all(b"\n").await?;
227 file.flush().await?;
228 file.sync_all().await?;
229 Ok(())
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Entries written through a file-backed sink land as one JSON line
    /// each and round-trip back through serde.
    #[tokio::test]
    async fn write_to_file() {
        // Unique temp dir so parallel test runs do not collide.
        let dir = std::env::temp_dir().join(format!("arbiter-audit-test-{}", Uuid::new_v4()));
        let file_path = dir.join("audit.jsonl");
        tokio::fs::create_dir_all(&dir).await.unwrap();

        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(file_path.clone()),
            ..Default::default()
        });

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.agent_id = "test-agent".into();
        entry.tool_called = "test_tool".into();
        entry.latency_ms = 10;

        sink.write(&entry).await.unwrap();
        sink.write(&entry).await.unwrap();

        // Two writes => exactly two JSONL lines.
        let contents = tokio::fs::read_to_string(&file_path).await.unwrap();
        let lines: Vec<&str> = contents.trim().lines().collect();
        assert_eq!(lines.len(), 2);

        let parsed: AuditEntry = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(parsed.agent_id, "test-agent");

        // Best-effort cleanup; failure here should not fail the test.
        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    /// Failed file writes mark the sink degraded and bump both the
    /// consecutive and lifetime failure counters.
    #[tokio::test]
    async fn tracks_write_failures() {
        // Path in a directory that does not exist, so every write fails.
        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
            ..Default::default()
        });

        assert!(!sink.is_degraded());
        assert_eq!(sink.consecutive_failures(), 0);

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.tool_called = "test".into();

        assert!(sink.write(&entry).await.is_err());
        assert!(sink.is_degraded());
        assert_eq!(sink.consecutive_failures(), 1);
        assert_eq!(sink.total_failures(), 1);

        assert!(sink.write(&entry).await.is_err());
        assert_eq!(sink.consecutive_failures(), 2);
        assert_eq!(sink.total_failures(), 2);
    }

    /// A degraded sink recovers only after RECOVERY_THRESHOLD consecutive
    /// successful writes.
    #[tokio::test]
    async fn resets_failures_on_success() {
        let dir = std::env::temp_dir().join(format!("arbiter-audit-reset-{}", Uuid::new_v4()));
        let file_path = dir.join("audit.jsonl");

        // First sink only exists to demonstrate entering the degraded state.
        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
            ..Default::default()
        });

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.tool_called = "test".into();

        let _ = sink.write(&entry).await;
        assert!(sink.is_degraded());

        tokio::fs::create_dir_all(&dir).await.unwrap();
        let recovered_sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(file_path.clone()),
            ..Default::default()
        });
        // Seed the failure counter directly to simulate a degraded sink.
        recovered_sink.write_failures.store(3, Ordering::Relaxed);
        assert!(recovered_sink.is_degraded());

        // The first THRESHOLD-1 successes must not clear the degraded flag.
        for i in 1..AuditSink::RECOVERY_THRESHOLD {
            recovered_sink.write(&entry).await.unwrap();
            assert!(
                recovered_sink.is_degraded(),
                "should still be degraded after {i} successful write(s)"
            );
        }
        // The THRESHOLD-th success completes recovery.
        recovered_sink.write(&entry).await.unwrap();
        assert!(!recovered_sink.is_degraded());
        assert_eq!(recovered_sink.consecutive_failures(), 0);

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    /// A populated entry serializes to a single line of JSON that parses
    /// back into an equivalent entry.
    #[test]
    fn serialization_produces_valid_json() {
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.agent_id = "test-agent".into();
        entry.tool_called = "dangerous_tool".into();
        entry.authorization_decision = "deny".into();
        entry.policy_matched = Some("block-dangerous".into());
        entry.anomaly_flags = vec!["scope_violation".into(), "unusual_hour".into()];
        entry.latency_ms = 7;
        entry.upstream_status = Some(403);

        let json = serde_json::to_string(&entry).unwrap();

        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.agent_id, "test-agent");
        assert_eq!(parsed.authorization_decision, "deny");
        assert_eq!(parsed.anomaly_flags.len(), 2);
        assert_eq!(parsed.upstream_status, Some(403));

        // The JSONL file format requires one record per line.
        assert!(!json.contains('\n'), "JSON must be a single line");
    }
}