use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};

use thiserror::Error;
use tokio::fs::OpenOptions;
use tokio::io::AsyncWriteExt;

use crate::entry::AuditEntry;

/// Errors that can occur while writing an audit entry to a sink.
#[derive(Debug, Error)]
pub enum SinkError {
    /// The entry could not be serialized to a JSON string.
    #[error("JSON serialization failed: {0}")]
    Serialize(#[from] serde_json::Error),

    /// Opening, appending to, or flushing the audit file failed.
    #[error("file I/O failed: {0}")]
    Io(#[from] std::io::Error),
}
21
/// Configuration selecting which destinations an [`AuditSink`] writes to.
#[derive(Debug, Clone)]
pub struct AuditSinkConfig {
    /// When `true`, each entry is emitted as a `tracing` info event under the
    /// `arbiter_audit` target (despite the name, output goes through the
    /// tracing subscriber rather than raw stdout).
    pub write_stdout: bool,

    /// Optional path of a JSONL file to append entries to; `None` disables
    /// file output entirely.
    pub file_path: Option<PathBuf>,
}
31
32impl Default for AuditSinkConfig {
33 fn default() -> Self {
34 Self {
35 write_stdout: true,
36 file_path: None,
37 }
38 }
39}
40
/// Writes audit entries to the destinations selected by [`AuditSinkConfig`]
/// (tracing output and/or a JSONL file) while tracking per-entry statistics
/// and file-write failures.
pub struct AuditSink {
    // Destination configuration, fixed at construction time.
    config: AuditSinkConfig,
    // Aggregated statistics; updated for every entry passed to `write`,
    // even entries whose file write later fails.
    stats: crate::stats::AuditStats,
    // Consecutive file-write failures; reset to zero by a successful write.
    write_failures: AtomicU64,
    // Total file-write failures over the sink's lifetime; never reset.
    total_write_failures: AtomicU64,
}
54
55impl AuditSink {
56 pub fn new(config: AuditSinkConfig) -> Self {
58 Self {
59 config,
60 stats: crate::stats::AuditStats::new(),
61 write_failures: AtomicU64::new(0),
62 total_write_failures: AtomicU64::new(0),
63 }
64 }
65
66 pub fn stats(&self) -> &crate::stats::AuditStats {
68 &self.stats
69 }
70
71 pub fn is_degraded(&self) -> bool {
73 self.write_failures.load(Ordering::Relaxed) > 0
74 }
75
76 pub fn consecutive_failures(&self) -> u64 {
78 self.write_failures.load(Ordering::Relaxed)
79 }
80
81 pub fn total_failures(&self) -> u64 {
83 self.total_write_failures.load(Ordering::Relaxed)
84 }
85
86 pub async fn write(&self, entry: &AuditEntry) -> Result<(), SinkError> {
91 self.stats.record(entry).await;
93
94 let json = serde_json::to_string(entry)?;
95
96 if self.config.write_stdout {
97 tracing::info!(target: "arbiter_audit", audit_entry = %json);
99 }
100
101 if let Some(path) = &self.config.file_path {
102 match self.write_to_file(path, &json).await {
103 Ok(()) => {
104 self.write_failures.store(0, Ordering::Relaxed);
105 }
106 Err(e) => {
107 let consecutive = self.write_failures.fetch_add(1, Ordering::Relaxed) + 1;
108 self.total_write_failures.fetch_add(1, Ordering::Relaxed);
109 tracing::error!(
110 error = %e,
111 consecutive_failures = consecutive,
112 "audit file write failed; audit data may be lost"
113 );
114 return Err(e);
115 }
116 }
117 }
118
119 Ok(())
120 }
121
122 async fn write_to_file(&self, path: &PathBuf, json: &str) -> Result<(), SinkError> {
123 let mut file = OpenOptions::new()
124 .create(true)
125 .append(true)
126 .open(path)
127 .await?;
128 file.write_all(json.as_bytes()).await?;
129 file.write_all(b"\n").await?;
130 file.flush().await?;
131 Ok(())
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    // NOTE: the `..Default::default()` spreads previously used in these
    // struct literals were dead (every field was already specified) and
    // tripped clippy's `needless_update`; they have been removed.

    /// Two writes to a file sink produce two parseable JSON lines.
    #[tokio::test]
    async fn write_to_file() {
        let dir = std::env::temp_dir().join(format!("arbiter-audit-test-{}", Uuid::new_v4()));
        let file_path = dir.join("audit.jsonl");
        tokio::fs::create_dir_all(&dir).await.unwrap();

        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(file_path.clone()),
        });

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.agent_id = "test-agent".into();
        entry.tool_called = "test_tool".into();
        entry.latency_ms = 10;

        sink.write(&entry).await.unwrap();
        sink.write(&entry).await.unwrap();

        let contents = tokio::fs::read_to_string(&file_path).await.unwrap();
        let lines: Vec<&str> = contents.trim().lines().collect();
        assert_eq!(lines.len(), 2);

        // Each line must round-trip back into an AuditEntry.
        let parsed: AuditEntry = serde_json::from_str(lines[0]).unwrap();
        assert_eq!(parsed.agent_id, "test-agent");

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    /// Failed file writes bump both the consecutive and lifetime counters.
    #[tokio::test]
    async fn tracks_write_failures() {
        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
        });

        assert!(!sink.is_degraded());
        assert_eq!(sink.consecutive_failures(), 0);

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.tool_called = "test".into();

        assert!(sink.write(&entry).await.is_err());
        assert!(sink.is_degraded());
        assert_eq!(sink.consecutive_failures(), 1);
        assert_eq!(sink.total_failures(), 1);

        assert!(sink.write(&entry).await.is_err());
        assert_eq!(sink.consecutive_failures(), 2);
        assert_eq!(sink.total_failures(), 2);
    }

    /// A successful write clears the consecutive-failure (degraded) state.
    #[tokio::test]
    async fn resets_failures_on_success() {
        let dir = std::env::temp_dir().join(format!("arbiter-audit-reset-{}", Uuid::new_v4()));
        let file_path = dir.join("audit.jsonl");

        let sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(PathBuf::from("/nonexistent/dir/audit.jsonl")),
        });

        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.tool_called = "test".into();

        let _ = sink.write(&entry).await;
        assert!(sink.is_degraded());

        // Simulate recovery with a sink pointed at a writable path and a
        // pre-seeded failure streak.
        tokio::fs::create_dir_all(&dir).await.unwrap();
        let recovered_sink = AuditSink::new(AuditSinkConfig {
            write_stdout: false,
            file_path: Some(file_path.clone()),
        });
        recovered_sink.write_failures.store(3, Ordering::Relaxed);
        assert!(recovered_sink.is_degraded());

        recovered_sink.write(&entry).await.unwrap();
        assert!(!recovered_sink.is_degraded());
        assert_eq!(recovered_sink.consecutive_failures(), 0);

        let _ = tokio::fs::remove_dir_all(&dir).await;
    }

    /// Entries serialize to single-line JSON that round-trips losslessly.
    #[test]
    fn serialization_produces_valid_json() {
        let mut entry = AuditEntry::new(Uuid::new_v4());
        entry.agent_id = "test-agent".into();
        entry.tool_called = "dangerous_tool".into();
        entry.authorization_decision = "deny".into();
        entry.policy_matched = Some("block-dangerous".into());
        entry.anomaly_flags = vec!["scope_violation".into(), "unusual_hour".into()];
        entry.latency_ms = 7;
        entry.upstream_status = Some(403);

        let json = serde_json::to_string(&entry).unwrap();

        let parsed: AuditEntry = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.agent_id, "test-agent");
        assert_eq!(parsed.authorization_decision, "deny");
        assert_eq!(parsed.anomaly_flags.len(), 2);
        assert_eq!(parsed.upstream_status, Some(403));

        // JSONL framing requires the payload itself to contain no newlines.
        assert!(!json.contains('\n'), "JSON must be a single line");
    }
}
261}