1use std::path::Path;
5
6use crate::config::AuditConfig;
7
8#[derive(Debug)]
9pub struct AuditLogger {
10 destination: AuditDestination,
11}
12
13#[derive(Debug)]
14enum AuditDestination {
15 Stdout,
16 File(tokio::sync::Mutex<tokio::fs::File>),
17}
18
19#[derive(serde::Serialize)]
20pub struct AuditEntry {
21 pub timestamp: String,
22 pub tool: String,
23 pub command: String,
24 pub result: AuditResult,
25 pub duration_ms: u64,
26 #[serde(skip_serializing_if = "Option::is_none")]
28 pub error_category: Option<String>,
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub error_domain: Option<String>,
32 #[serde(skip_serializing_if = "Option::is_none")]
35 pub error_phase: Option<String>,
36 #[serde(skip_serializing_if = "Option::is_none")]
38 pub claim_source: Option<crate::executor::ClaimSource>,
39 #[serde(skip_serializing_if = "Option::is_none")]
41 pub mcp_server_id: Option<String>,
42 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
44 pub injection_flagged: bool,
45 #[serde(default, skip_serializing_if = "std::ops::Not::not")]
48 pub embedding_anomalous: bool,
49}
50
51#[derive(serde::Serialize)]
52#[serde(tag = "type")]
53pub enum AuditResult {
54 #[serde(rename = "success")]
55 Success,
56 #[serde(rename = "blocked")]
57 Blocked { reason: String },
58 #[serde(rename = "error")]
59 Error { message: String },
60 #[serde(rename = "timeout")]
61 Timeout,
62}
63
64impl AuditLogger {
65 pub async fn from_config(config: &AuditConfig) -> Result<Self, std::io::Error> {
71 let destination = if config.destination == "stdout" {
72 AuditDestination::Stdout
73 } else {
74 let file = tokio::fs::OpenOptions::new()
75 .create(true)
76 .append(true)
77 .open(Path::new(&config.destination))
78 .await?;
79 AuditDestination::File(tokio::sync::Mutex::new(file))
80 };
81
82 Ok(Self { destination })
83 }
84
85 pub async fn log(&self, entry: &AuditEntry) {
86 let Ok(json) = serde_json::to_string(entry) else {
87 return;
88 };
89
90 match &self.destination {
91 AuditDestination::Stdout => {
92 tracing::info!(target: "audit", "{json}");
93 }
94 AuditDestination::File(file) => {
95 use tokio::io::AsyncWriteExt;
96 let mut f = file.lock().await;
97 let line = format!("{json}\n");
98 if let Err(e) = f.write_all(line.as_bytes()).await {
99 tracing::error!("failed to write audit log: {e}");
100 } else if let Err(e) = f.flush().await {
101 tracing::error!("failed to flush audit log: {e}");
102 }
103 }
104 }
105 }
106}
107
/// Returns the current Unix time in whole seconds, formatted as a decimal string.
///
/// Despite its name, this uses only `std::time` (not the `chrono` crate) and
/// does not produce an RFC 3339 timestamp. If the system clock reads earlier
/// than the Unix epoch, `"0"` is returned rather than panicking.
#[must_use]
pub fn chrono_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let seconds_since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs());
    seconds_since_epoch.to_string()
}
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline entry with every optional field unset and both flags false;
    /// tests override only what they care about via struct-update syntax.
    fn sample_entry(command: &str, result: AuditResult) -> AuditEntry {
        AuditEntry {
            timestamp: "0".into(),
            tool: "shell".into(),
            command: command.into(),
            result,
            duration_ms: 0,
            error_category: None,
            error_domain: None,
            error_phase: None,
            claim_source: None,
            mcp_server_id: None,
            injection_flagged: false,
            embedding_anomalous: false,
        }
    }

    /// Creates a file-backed logger in a fresh temp dir. The `TempDir` is
    /// returned so the caller keeps the directory alive for the whole test.
    async fn file_logger() -> (tempfile::TempDir, std::path::PathBuf, AuditLogger) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("audit.log");
        let config = AuditConfig {
            enabled: true,
            destination: path.display().to_string(),
        };
        let logger = AuditLogger::from_config(&config).await.unwrap();
        (dir, path, logger)
    }

    #[test]
    fn audit_entry_serialization() {
        let entry = AuditEntry {
            timestamp: "1234567890".into(),
            duration_ms: 42,
            ..sample_entry("echo hello", AuditResult::Success)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"success\""));
        assert!(json.contains("\"tool\":\"shell\""));
        assert!(json.contains("\"duration_ms\":42"));
    }

    #[test]
    fn audit_result_blocked_serialization() {
        let entry = AuditEntry {
            error_category: Some("policy_blocked".to_owned()),
            error_domain: Some("action".to_owned()),
            ..sample_entry(
                "sudo rm",
                AuditResult::Blocked {
                    reason: "blocked command: sudo".into(),
                },
            )
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"blocked\""));
        assert!(json.contains("\"reason\":\"blocked command: sudo\""));
        assert!(json.contains("\"error_category\":\"policy_blocked\""));
        assert!(json.contains("\"error_domain\":\"action\""));
    }

    #[test]
    fn audit_result_error_serialization() {
        let entry = sample_entry(
            "bad",
            AuditResult::Error {
                message: "exec failed".into(),
            },
        );
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("\"message\":\"exec failed\""));
    }

    #[test]
    fn audit_result_timeout_serialization() {
        let entry = AuditEntry {
            duration_ms: 30000,
            error_category: Some("timeout".to_owned()),
            error_domain: Some("system".to_owned()),
            ..sample_entry("sleep 999", AuditResult::Timeout)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"type\":\"timeout\""));
        assert!(json.contains("\"duration_ms\":30000"));
    }

    #[test]
    fn audit_entry_optional_fields_and_false_flags_omitted() {
        let json = serde_json::to_string(&sample_entry("echo", AuditResult::Success)).unwrap();
        for key in [
            "error_category",
            "error_domain",
            "error_phase",
            "claim_source",
            "mcp_server_id",
            "injection_flagged",
            "embedding_anomalous",
        ] {
            assert!(!json.contains(key), "{key} must be omitted: {json}");
        }
    }

    #[test]
    fn audit_entry_true_flags_present() {
        let entry = AuditEntry {
            injection_flagged: true,
            embedding_anomalous: true,
            ..sample_entry("echo", AuditResult::Success)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(json.contains("\"injection_flagged\":true"), "{json}");
        assert!(json.contains("\"embedding_anomalous\":true"), "{json}");
    }

    #[tokio::test]
    async fn audit_logger_stdout() {
        let config = AuditConfig {
            enabled: true,
            destination: "stdout".into(),
        };
        let logger = AuditLogger::from_config(&config).await.unwrap();
        let entry = AuditEntry {
            duration_ms: 1,
            ..sample_entry("echo test", AuditResult::Success)
        };
        logger.log(&entry).await;
    }

    #[tokio::test]
    async fn audit_logger_file() {
        let (_dir, path, logger) = file_logger().await;
        let entry = AuditEntry {
            duration_ms: 1,
            ..sample_entry("echo test", AuditResult::Success)
        };
        logger.log(&entry).await;

        let content = tokio::fs::read_to_string(&path).await.unwrap();
        assert!(content.contains("\"tool\":\"shell\""));
        assert!(content.ends_with('\n'), "each entry must be newline-terminated");
        assert_eq!(content.lines().count(), 1);
    }

    #[tokio::test]
    async fn audit_logger_from_config_fails_for_missing_dir() {
        let config = AuditConfig {
            enabled: true,
            destination: "/nonexistent/dir/audit.log".into(),
        };
        let result = AuditLogger::from_config(&config).await;
        assert!(result.is_err());
    }

    #[test]
    fn claim_source_serde_roundtrip() {
        use crate::executor::ClaimSource;
        let cases = [
            (ClaimSource::Shell, "\"shell\""),
            (ClaimSource::FileSystem, "\"file_system\""),
            (ClaimSource::WebScrape, "\"web_scrape\""),
            (ClaimSource::Mcp, "\"mcp\""),
            (ClaimSource::A2a, "\"a2a\""),
            (ClaimSource::CodeSearch, "\"code_search\""),
            (ClaimSource::Diagnostics, "\"diagnostics\""),
            (ClaimSource::Memory, "\"memory\""),
        ];
        for (variant, expected_json) in cases {
            let serialized = serde_json::to_string(&variant).unwrap();
            assert_eq!(serialized, expected_json, "serialize {variant:?}");
            let deserialized: ClaimSource = serde_json::from_str(&serialized).unwrap();
            assert_eq!(deserialized, variant, "deserialize {variant:?}");
        }
    }

    #[test]
    fn audit_entry_claim_source_none_omitted() {
        let entry = AuditEntry {
            duration_ms: 1,
            ..sample_entry("echo", AuditResult::Success)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(
            !json.contains("claim_source"),
            "claim_source must be omitted when None: {json}"
        );
    }

    #[test]
    fn audit_entry_claim_source_some_present() {
        use crate::executor::ClaimSource;
        let entry = AuditEntry {
            duration_ms: 1,
            claim_source: Some(ClaimSource::Shell),
            ..sample_entry("echo", AuditResult::Success)
        };
        let json = serde_json::to_string(&entry).unwrap();
        assert!(
            json.contains("\"claim_source\":\"shell\""),
            "expected claim_source=shell in JSON: {json}"
        );
    }

    #[tokio::test]
    async fn audit_logger_multiple_entries() {
        let (_dir, path, logger) = file_logger().await;

        for i in 0..5u64 {
            let entry = AuditEntry {
                timestamp: i.to_string(),
                duration_ms: i,
                ..sample_entry(&format!("cmd{i}"), AuditResult::Success)
            };
            logger.log(&entry).await;
        }

        let content = tokio::fs::read_to_string(&path).await.unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert_eq!(lines.len(), 5);
        // Every line must be a standalone JSON object, in write order.
        for (i, line) in lines.iter().enumerate() {
            let value: serde_json::Value = serde_json::from_str(line).unwrap();
            assert_eq!(value["command"], format!("cmd{i}"));
        }
    }
}