1use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::fs::OpenOptions;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15use super::{ToolContext, ToolResult, ToolError};
16use super::permission::RiskLevel;
17
/// One record in the audit log, describing a single tool invocation.
///
/// An entry is created with status `Started` by `AuditLogger::log_tool_start` and later
/// updated in place by `log_tool_completion` / `log_tool_failure`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
    /// Random UUID v4 identifying this entry; returned by `log_tool_start`.
    pub entry_id: String,
    /// Time at which the invocation was recorded as started.
    pub timestamp: DateTime<Utc>,
    /// Session id copied from the `ToolContext`.
    pub session_id: String,
    /// Message id copied from the `ToolContext`.
    pub message_id: String,
    /// Identifier of the tool that was invoked.
    pub tool_id: String,
    /// Category of the operation performed by the tool.
    pub operation_type: OperationType,
    /// Current lifecycle status of the invocation.
    pub status: ExecutionStatus,
    /// Risk level supplied by the caller, if any.
    pub risk_level: Option<RiskLevel>,
    /// Parameters the tool was invoked with, as JSON.
    pub parameters: Value,
    /// Metadata of the `ToolResult`, set when the invocation completes successfully.
    pub result_metadata: Option<Value>,
    /// Error message, set when the invocation fails.
    pub error_details: Option<String>,
    /// Wall-clock duration reported by the caller on completion or failure.
    pub execution_time_ms: Option<u64>,
    /// Free-form caller context; `log_tool_start` always initialises it empty.
    pub user_context: HashMap<String, Value>,
    /// Snapshot of the host/process environment at start time.
    pub system_context: SystemContext,
}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub enum OperationType {
40 FileRead,
41 FileWrite,
42 FileEdit,
43 FileDelete,
44 CommandExecution,
45 NetworkRequest,
46 SystemQuery,
47 ProcessSpawn,
48 Other(String),
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum ExecutionStatus {
54 Started,
55 Completed,
56 Failed,
57 Aborted,
58 PermissionDenied,
59}
60
/// Snapshot of the host and process environment captured when an entry is started.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemContext {
    /// Working directory taken from the `ToolContext`.
    pub working_directory: PathBuf,
    /// Target OS name (`std::env::consts::OS`).
    pub platform: String,
    /// Host name, or `None` if it could not be read or is not valid UTF-8.
    pub hostname: Option<String>,
    /// Id of the current process.
    pub process_id: u32,
    /// SHA-256 hex digest of selected environment variables, or `None` if none are set.
    pub environment_hash: Option<String>,
}
70
/// Records tool invocations in a bounded in-memory buffer and, optionally, a
/// JSON-lines file.
pub struct AuditLogger {
    /// When set, every written or updated entry is appended as one JSON line.
    log_file_path: Option<PathBuf>,
    /// In-memory entries in insertion order; the oldest are evicted first.
    in_memory_logs: Arc<RwLock<Vec<AuditLogEntry>>>,
    /// Maximum number of entries kept in `in_memory_logs`.
    max_memory_entries: usize,
    /// When `false`, the `log_*` methods record nothing.
    enabled: bool,
}
78
79impl Default for AuditLogger {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl AuditLogger {
86 pub fn new() -> Self {
88 Self {
89 log_file_path: None,
90 in_memory_logs: Arc::new(RwLock::new(Vec::new())),
91 max_memory_entries: 10000,
92 enabled: true,
93 }
94 }
95
96 pub fn with_file(log_file_path: PathBuf) -> Self {
98 Self {
99 log_file_path: Some(log_file_path),
100 in_memory_logs: Arc::new(RwLock::new(Vec::new())),
101 max_memory_entries: 10000,
102 enabled: true,
103 }
104 }
105
106 pub fn set_enabled(&mut self, enabled: bool) {
108 self.enabled = enabled;
109 }
110
111 pub async fn log_tool_start(
113 &self,
114 tool_id: &str,
115 operation_type: OperationType,
116 ctx: &ToolContext,
117 parameters: Value,
118 risk_level: Option<RiskLevel>,
119 ) -> Result<String, ToolError> {
120 if !self.enabled {
121 return Ok(String::new());
122 }
123
124 let entry_id = Uuid::new_v4().to_string();
125 let entry = AuditLogEntry {
126 entry_id: entry_id.clone(),
127 timestamp: Utc::now(),
128 session_id: ctx.session_id.clone(),
129 message_id: ctx.message_id.clone(),
130 tool_id: tool_id.to_string(),
131 operation_type,
132 status: ExecutionStatus::Started,
133 risk_level,
134 parameters,
135 result_metadata: None,
136 error_details: None,
137 execution_time_ms: None,
138 user_context: HashMap::new(),
139 system_context: self.create_system_context(ctx).await,
140 };
141
142 self.write_log_entry(&entry).await?;
143 Ok(entry_id)
144 }
145
146 pub async fn log_tool_completion(
148 &self,
149 entry_id: &str,
150 result: &ToolResult,
151 execution_time_ms: u64,
152 ) -> Result<(), ToolError> {
153 if !self.enabled {
154 return Ok(());
155 }
156
157 self.update_log_entry(
158 entry_id,
159 ExecutionStatus::Completed,
160 Some(result.metadata.clone()),
161 None,
162 Some(execution_time_ms),
163 ).await
164 }
165
166 pub async fn log_tool_failure(
168 &self,
169 entry_id: &str,
170 error: &ToolError,
171 execution_time_ms: u64,
172 ) -> Result<(), ToolError> {
173 if !self.enabled {
174 return Ok(());
175 }
176
177 let status = match error {
178 ToolError::Aborted => ExecutionStatus::Aborted,
179 ToolError::PermissionDenied(_) => ExecutionStatus::PermissionDenied,
180 _ => ExecutionStatus::Failed,
181 };
182
183 self.update_log_entry(
184 entry_id,
185 status,
186 None,
187 Some(error.to_string()),
188 Some(execution_time_ms),
189 ).await
190 }
191
192 pub async fn get_logs(
194 &self,
195 session_id: Option<&str>,
196 tool_id: Option<&str>,
197 start_time: Option<DateTime<Utc>>,
198 end_time: Option<DateTime<Utc>>,
199 limit: Option<usize>,
200 ) -> Vec<AuditLogEntry> {
201 let logs = self.in_memory_logs.read().await;
202
203 logs.iter()
204 .filter(|entry| {
205 if let Some(sid) = session_id {
206 if entry.session_id != sid {
207 return false;
208 }
209 }
210
211 if let Some(tid) = tool_id {
212 if entry.tool_id != tid {
213 return false;
214 }
215 }
216
217 if let Some(start) = start_time {
218 if entry.timestamp < start {
219 return false;
220 }
221 }
222
223 if let Some(end) = end_time {
224 if entry.timestamp > end {
225 return false;
226 }
227 }
228
229 true
230 })
231 .take(limit.unwrap_or(usize::MAX))
232 .cloned()
233 .collect()
234 }
235
236 pub async fn get_statistics(&self) -> AuditStatistics {
238 let logs = self.in_memory_logs.read().await;
239
240 let mut stats = AuditStatistics {
241 total_entries: logs.len(),
242 by_tool: HashMap::new(),
243 by_status: HashMap::new(),
244 by_risk_level: HashMap::new(),
245 average_execution_time_ms: 0.0,
246 total_execution_time_ms: 0,
247 };
248
249 let mut total_time = 0u64;
250 let mut completed_count = 0;
251
252 for entry in logs.iter() {
253 *stats.by_tool.entry(entry.tool_id.clone()).or_insert(0) += 1;
255
256 let status_key = format!("{:?}", entry.status);
258 *stats.by_status.entry(status_key).or_insert(0) += 1;
259
260 if let Some(risk) = &entry.risk_level {
262 let risk_key = format!("{:?}", risk);
263 *stats.by_risk_level.entry(risk_key).or_insert(0) += 1;
264 }
265
266 if let Some(time) = entry.execution_time_ms {
268 total_time += time;
269 completed_count += 1;
270 }
271 }
272
273 stats.total_execution_time_ms = total_time;
274 if completed_count > 0 {
275 stats.average_execution_time_ms = total_time as f64 / completed_count as f64;
276 }
277
278 stats
279 }
280
281 pub async fn cleanup_old_logs(&self, older_than: DateTime<Utc>) -> usize {
283 let mut logs = self.in_memory_logs.write().await;
284 let original_count = logs.len();
285
286 logs.retain(|entry| entry.timestamp >= older_than);
287
288 original_count - logs.len()
289 }
290
291 async fn create_system_context(&self, ctx: &ToolContext) -> SystemContext {
293 SystemContext {
294 working_directory: ctx.working_directory.clone(),
295 platform: std::env::consts::OS.to_string(),
296 hostname: hostname::get().ok().and_then(|h| h.into_string().ok()),
297 process_id: std::process::id(),
298 environment_hash: self.hash_environment(),
299 }
300 }
301
302 fn hash_environment(&self) -> Option<String> {
304 use std::collections::BTreeMap;
305 use sha2::{Sha256, Digest};
306
307 let relevant_vars = ["PATH", "HOME", "USER", "USERNAME", "SHELL"];
308 let mut env_map = BTreeMap::new();
309
310 for var in &relevant_vars {
311 if let Ok(value) = std::env::var(var) {
312 env_map.insert(*var, value);
313 }
314 }
315
316 if env_map.is_empty() {
317 return None;
318 }
319
320 let serialized = serde_json::to_string(&env_map).ok()?;
321 let mut hasher = Sha256::new();
322 hasher.update(serialized.as_bytes());
323 Some(format!("{:x}", hasher.finalize()))
324 }
325
326 async fn write_log_entry(&self, entry: &AuditLogEntry) -> Result<(), ToolError> {
328 {
330 let mut logs = self.in_memory_logs.write().await;
331 logs.push(entry.clone());
332
333 if logs.len() > self.max_memory_entries {
335 logs.remove(0);
336 }
337 }
338
339 if let Some(log_path) = &self.log_file_path {
341 let log_line = serde_json::to_string(entry)
342 .map_err(|e| ToolError::ExecutionFailed(format!("Failed to serialize log entry: {}", e)))?;
343
344 let mut file = OpenOptions::new()
345 .create(true)
346 .append(true)
347 .open(log_path)
348 .await
349 .map_err(|e| ToolError::ExecutionFailed(format!("Failed to open audit log file: {}", e)))?;
350
351 file.write_all(format!("{}\n", log_line).as_bytes())
352 .await
353 .map_err(|e| ToolError::ExecutionFailed(format!("Failed to write to audit log: {}", e)))?;
354
355 file.flush().await
356 .map_err(|e| ToolError::ExecutionFailed(format!("Failed to flush audit log: {}", e)))?;
357 }
358
359 Ok(())
360 }
361
362 async fn update_log_entry(
364 &self,
365 entry_id: &str,
366 status: ExecutionStatus,
367 result_metadata: Option<Value>,
368 error_details: Option<String>,
369 execution_time_ms: Option<u64>,
370 ) -> Result<(), ToolError> {
371 let mut logs = self.in_memory_logs.write().await;
372
373 if let Some(entry) = logs.iter_mut().find(|e| e.entry_id == entry_id) {
374 entry.status = status;
375 entry.result_metadata = result_metadata;
376 entry.error_details = error_details;
377 entry.execution_time_ms = execution_time_ms;
378
379 if self.log_file_path.is_some() {
381 self.write_log_entry(entry).await?;
382 }
383 }
384
385 Ok(())
386 }
387}
388
/// Aggregate view over an `AuditLogger`'s in-memory entries.
#[derive(Debug, Clone, Serialize)]
pub struct AuditStatistics {
    /// Number of entries currently held in memory.
    pub total_entries: usize,
    /// Entry count per tool id.
    pub by_tool: HashMap<String, usize>,
    /// Entry count per status, keyed by the status's `Debug` name (e.g. "Completed").
    pub by_status: HashMap<String, usize>,
    /// Entry count per risk level, keyed by the level's `Debug` name; entries without a
    /// risk level are not counted here.
    pub by_risk_level: HashMap<String, usize>,
    /// Mean execution time over entries that have one; 0.0 if none do.
    pub average_execution_time_ms: f64,
    /// Sum of execution times over entries that have one.
    pub total_execution_time_ms: u64,
}
399
400pub fn operation_type_from_tool(tool_id: &str) -> OperationType {
402 match tool_id {
403 "read" => OperationType::FileRead,
404 "write" => OperationType::FileWrite,
405 "edit" | "multiedit" => OperationType::FileEdit,
406 "bash" => OperationType::CommandExecution,
407 "web_fetch" | "web_search" => OperationType::NetworkRequest,
408 "grep" | "glob" => OperationType::SystemQuery,
409 _ => OperationType::Other(tool_id.to_string()),
410 }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Tool context shared by every test in this module.
    fn test_context() -> ToolContext {
        ToolContext {
            session_id: "test_session".to_string(),
            message_id: "test_message".to_string(),
            abort_signal: tokio::sync::watch::channel(false).1,
            working_directory: std::env::current_dir().unwrap(),
        }
    }

    /// A successful tool result with fixed content.
    fn successful_result() -> ToolResult {
        ToolResult {
            title: "Test".to_string(),
            metadata: serde_json::json!({"result": "success"}),
            output: "Test output".to_string(),
        }
    }

    #[tokio::test]
    async fn test_audit_logger() {
        let logger = AuditLogger::new();
        let ctx = test_context();

        let id = logger
            .log_tool_start(
                "test_tool",
                OperationType::FileRead,
                &ctx,
                serde_json::json!({"test": "value"}),
                Some(RiskLevel::Low),
            )
            .await
            .unwrap();

        logger
            .log_tool_completion(&id, &successful_result(), 100)
            .await
            .unwrap();

        let entries = logger
            .get_logs(Some("test_session"), None, None, None, None)
            .await;
        assert_eq!(entries.len(), 1);
        let only = &entries[0];
        assert_eq!(only.tool_id, "test_tool");
        assert!(matches!(only.status, ExecutionStatus::Completed));
    }

    #[tokio::test]
    async fn test_audit_statistics() {
        let logger = AuditLogger::new();
        let ctx = test_context();

        for step in 0..3u64 {
            let id = logger
                .log_tool_start(
                    "test_tool",
                    OperationType::FileRead,
                    &ctx,
                    serde_json::json!({"test": step}),
                    Some(RiskLevel::Low),
                )
                .await
                .unwrap();

            logger
                .log_tool_completion(&id, &successful_result(), 100 + step * 50)
                .await
                .unwrap();
        }

        let summary = logger.get_statistics().await;
        assert_eq!(summary.total_entries, 3);
        assert_eq!(summary.by_tool.get("test_tool"), Some(&3));
        assert!(summary.average_execution_time_ms > 0.0);
    }

    #[tokio::test]
    async fn test_file_logging() {
        // Keep the handle alive for the whole test: dropping it deletes the file.
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_path_buf();

        let logger = AuditLogger::with_file(path.clone());
        let ctx = test_context();

        logger
            .log_tool_start(
                "test_tool",
                OperationType::FileRead,
                &ctx,
                serde_json::json!({"test": "value"}),
                Some(RiskLevel::Low),
            )
            .await
            .unwrap();

        let written = tokio::fs::read_to_string(&path).await.unwrap();
        assert!(written.contains("test_tool"));
        assert!(written.contains("test_session"));
    }
}