code_mesh_core/tool/
audit.rs

1//! Audit logging system for tool execution
2//! Provides comprehensive logging and monitoring of all tool operations
3
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::fs::OpenOptions;
11use tokio::io::AsyncWriteExt;
12use tokio::sync::RwLock;
13use uuid::Uuid;
14
15use super::{ToolContext, ToolResult, ToolError};
16use super::permission::RiskLevel;
17
18/// Audit log entry for tool execution
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct AuditLogEntry {
21    pub entry_id: String,
22    pub timestamp: DateTime<Utc>,
23    pub session_id: String,
24    pub message_id: String,
25    pub tool_id: String,
26    pub operation_type: OperationType,
27    pub status: ExecutionStatus,
28    pub risk_level: Option<RiskLevel>,
29    pub parameters: Value,
30    pub result_metadata: Option<Value>,
31    pub error_details: Option<String>,
32    pub execution_time_ms: Option<u64>,
33    pub user_context: HashMap<String, Value>,
34    pub system_context: SystemContext,
35}
36
37/// Type of operation performed
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub enum OperationType {
40    FileRead,
41    FileWrite,
42    FileEdit,
43    FileDelete,
44    CommandExecution,
45    NetworkRequest,
46    SystemQuery,
47    ProcessSpawn,
48    Other(String),
49}
50
51/// Execution status
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum ExecutionStatus {
54    Started,
55    Completed,
56    Failed,
57    Aborted,
58    PermissionDenied,
59}
60
61/// System context information
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct SystemContext {
64    pub working_directory: PathBuf,
65    pub platform: String,
66    pub hostname: Option<String>,
67    pub process_id: u32,
68    pub environment_hash: Option<String>,
69}
70
71/// Audit logger implementation
72pub struct AuditLogger {
73    log_file_path: Option<PathBuf>,
74    in_memory_logs: Arc<RwLock<Vec<AuditLogEntry>>>,
75    max_memory_entries: usize,
76    enabled: bool,
77}
78
79impl Default for AuditLogger {
80    fn default() -> Self {
81        Self::new()
82    }
83}
84
85impl AuditLogger {
86    /// Create a new audit logger
87    pub fn new() -> Self {
88        Self {
89            log_file_path: None,
90            in_memory_logs: Arc::new(RwLock::new(Vec::new())),
91            max_memory_entries: 10000,
92            enabled: true,
93        }
94    }
95    
96    /// Create audit logger with file output
97    pub fn with_file(log_file_path: PathBuf) -> Self {
98        Self {
99            log_file_path: Some(log_file_path),
100            in_memory_logs: Arc::new(RwLock::new(Vec::new())),
101            max_memory_entries: 10000,
102            enabled: true,
103        }
104    }
105    
106    /// Enable or disable audit logging
107    pub fn set_enabled(&mut self, enabled: bool) {
108        self.enabled = enabled;
109    }
110    
111    /// Log the start of a tool execution
112    pub async fn log_tool_start(
113        &self,
114        tool_id: &str,
115        operation_type: OperationType,
116        ctx: &ToolContext,
117        parameters: Value,
118        risk_level: Option<RiskLevel>,
119    ) -> Result<String, ToolError> {
120        if !self.enabled {
121            return Ok(String::new());
122        }
123        
124        let entry_id = Uuid::new_v4().to_string();
125        let entry = AuditLogEntry {
126            entry_id: entry_id.clone(),
127            timestamp: Utc::now(),
128            session_id: ctx.session_id.clone(),
129            message_id: ctx.message_id.clone(),
130            tool_id: tool_id.to_string(),
131            operation_type,
132            status: ExecutionStatus::Started,
133            risk_level,
134            parameters,
135            result_metadata: None,
136            error_details: None,
137            execution_time_ms: None,
138            user_context: HashMap::new(),
139            system_context: self.create_system_context(ctx).await,
140        };
141        
142        self.write_log_entry(&entry).await?;
143        Ok(entry_id)
144    }
145    
146    /// Log the completion of a tool execution
147    pub async fn log_tool_completion(
148        &self,
149        entry_id: &str,
150        result: &ToolResult,
151        execution_time_ms: u64,
152    ) -> Result<(), ToolError> {
153        if !self.enabled {
154            return Ok(());
155        }
156        
157        self.update_log_entry(
158            entry_id,
159            ExecutionStatus::Completed,
160            Some(result.metadata.clone()),
161            None,
162            Some(execution_time_ms),
163        ).await
164    }
165    
166    /// Log a tool execution failure
167    pub async fn log_tool_failure(
168        &self,
169        entry_id: &str,
170        error: &ToolError,
171        execution_time_ms: u64,
172    ) -> Result<(), ToolError> {
173        if !self.enabled {
174            return Ok(());
175        }
176        
177        let status = match error {
178            ToolError::Aborted => ExecutionStatus::Aborted,
179            ToolError::PermissionDenied(_) => ExecutionStatus::PermissionDenied,
180            _ => ExecutionStatus::Failed,
181        };
182        
183        self.update_log_entry(
184            entry_id,
185            status,
186            None,
187            Some(error.to_string()),
188            Some(execution_time_ms),
189        ).await
190    }
191    
192    /// Get audit logs matching criteria
193    pub async fn get_logs(
194        &self,
195        session_id: Option<&str>,
196        tool_id: Option<&str>,
197        start_time: Option<DateTime<Utc>>,
198        end_time: Option<DateTime<Utc>>,
199        limit: Option<usize>,
200    ) -> Vec<AuditLogEntry> {
201        let logs = self.in_memory_logs.read().await;
202        
203        logs.iter()
204            .filter(|entry| {
205                if let Some(sid) = session_id {
206                    if entry.session_id != sid {
207                        return false;
208                    }
209                }
210                
211                if let Some(tid) = tool_id {
212                    if entry.tool_id != tid {
213                        return false;
214                    }
215                }
216                
217                if let Some(start) = start_time {
218                    if entry.timestamp < start {
219                        return false;
220                    }
221                }
222                
223                if let Some(end) = end_time {
224                    if entry.timestamp > end {
225                        return false;
226                    }
227                }
228                
229                true
230            })
231            .take(limit.unwrap_or(usize::MAX))
232            .cloned()
233            .collect()
234    }
235    
236    /// Get audit statistics
237    pub async fn get_statistics(&self) -> AuditStatistics {
238        let logs = self.in_memory_logs.read().await;
239        
240        let mut stats = AuditStatistics {
241            total_entries: logs.len(),
242            by_tool: HashMap::new(),
243            by_status: HashMap::new(),
244            by_risk_level: HashMap::new(),
245            average_execution_time_ms: 0.0,
246            total_execution_time_ms: 0,
247        };
248        
249        let mut total_time = 0u64;
250        let mut completed_count = 0;
251        
252        for entry in logs.iter() {
253            // Count by tool
254            *stats.by_tool.entry(entry.tool_id.clone()).or_insert(0) += 1;
255            
256            // Count by status
257            let status_key = format!("{:?}", entry.status);
258            *stats.by_status.entry(status_key).or_insert(0) += 1;
259            
260            // Count by risk level
261            if let Some(risk) = &entry.risk_level {
262                let risk_key = format!("{:?}", risk);
263                *stats.by_risk_level.entry(risk_key).or_insert(0) += 1;
264            }
265            
266            // Calculate execution time
267            if let Some(time) = entry.execution_time_ms {
268                total_time += time;
269                completed_count += 1;
270            }
271        }
272        
273        stats.total_execution_time_ms = total_time;
274        if completed_count > 0 {
275            stats.average_execution_time_ms = total_time as f64 / completed_count as f64;
276        }
277        
278        stats
279    }
280    
281    /// Clear old audit logs
282    pub async fn cleanup_old_logs(&self, older_than: DateTime<Utc>) -> usize {
283        let mut logs = self.in_memory_logs.write().await;
284        let original_count = logs.len();
285        
286        logs.retain(|entry| entry.timestamp >= older_than);
287        
288        original_count - logs.len()
289    }
290    
291    /// Create system context information
292    async fn create_system_context(&self, ctx: &ToolContext) -> SystemContext {
293        SystemContext {
294            working_directory: ctx.working_directory.clone(),
295            platform: std::env::consts::OS.to_string(),
296            hostname: hostname::get().ok().and_then(|h| h.into_string().ok()),
297            process_id: std::process::id(),
298            environment_hash: self.hash_environment(),
299        }
300    }
301    
302    /// Create a hash of relevant environment variables
303    fn hash_environment(&self) -> Option<String> {
304        use std::collections::BTreeMap;
305        use sha2::{Sha256, Digest};
306        
307        let relevant_vars = ["PATH", "HOME", "USER", "USERNAME", "SHELL"];
308        let mut env_map = BTreeMap::new();
309        
310        for var in &relevant_vars {
311            if let Ok(value) = std::env::var(var) {
312                env_map.insert(*var, value);
313            }
314        }
315        
316        if env_map.is_empty() {
317            return None;
318        }
319        
320        let serialized = serde_json::to_string(&env_map).ok()?;
321        let mut hasher = Sha256::new();
322        hasher.update(serialized.as_bytes());
323        Some(format!("{:x}", hasher.finalize()))
324    }
325    
326    /// Write a log entry to storage
327    async fn write_log_entry(&self, entry: &AuditLogEntry) -> Result<(), ToolError> {
328        // Add to in-memory storage
329        {
330            let mut logs = self.in_memory_logs.write().await;
331            logs.push(entry.clone());
332            
333            // Trim if over limit
334            if logs.len() > self.max_memory_entries {
335                logs.remove(0);
336            }
337        }
338        
339        // Write to file if configured
340        if let Some(log_path) = &self.log_file_path {
341            let log_line = serde_json::to_string(entry)
342                .map_err(|e| ToolError::ExecutionFailed(format!("Failed to serialize log entry: {}", e)))?;
343            
344            let mut file = OpenOptions::new()
345                .create(true)
346                .append(true)
347                .open(log_path)
348                .await
349                .map_err(|e| ToolError::ExecutionFailed(format!("Failed to open audit log file: {}", e)))?;
350            
351            file.write_all(format!("{}\n", log_line).as_bytes())
352                .await
353                .map_err(|e| ToolError::ExecutionFailed(format!("Failed to write to audit log: {}", e)))?;
354            
355            file.flush().await
356                .map_err(|e| ToolError::ExecutionFailed(format!("Failed to flush audit log: {}", e)))?;
357        }
358        
359        Ok(())
360    }
361    
362    /// Update an existing log entry
363    async fn update_log_entry(
364        &self,
365        entry_id: &str,
366        status: ExecutionStatus,
367        result_metadata: Option<Value>,
368        error_details: Option<String>,
369        execution_time_ms: Option<u64>,
370    ) -> Result<(), ToolError> {
371        let mut logs = self.in_memory_logs.write().await;
372        
373        if let Some(entry) = logs.iter_mut().find(|e| e.entry_id == entry_id) {
374            entry.status = status;
375            entry.result_metadata = result_metadata;
376            entry.error_details = error_details;
377            entry.execution_time_ms = execution_time_ms;
378            
379            // Write updated entry to file if configured
380            if self.log_file_path.is_some() {
381                self.write_log_entry(entry).await?;
382            }
383        }
384        
385        Ok(())
386    }
387}
388
389/// Audit statistics summary
390#[derive(Debug, Clone, Serialize)]
391pub struct AuditStatistics {
392    pub total_entries: usize,
393    pub by_tool: HashMap<String, usize>,
394    pub by_status: HashMap<String, usize>,
395    pub by_risk_level: HashMap<String, usize>,
396    pub average_execution_time_ms: f64,
397    pub total_execution_time_ms: u64,
398}
399
400/// Helper function to determine operation type from tool ID
401pub fn operation_type_from_tool(tool_id: &str) -> OperationType {
402    match tool_id {
403        "read" => OperationType::FileRead,
404        "write" => OperationType::FileWrite,
405        "edit" | "multiedit" => OperationType::FileEdit,
406        "bash" => OperationType::CommandExecution,
407        "web_fetch" | "web_search" => OperationType::NetworkRequest,
408        "grep" | "glob" => OperationType::SystemQuery,
409        _ => OperationType::Other(tool_id.to_string()),
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Tool context used by every test in this module.
    fn sample_context() -> ToolContext {
        ToolContext {
            session_id: "test_session".to_string(),
            message_id: "test_message".to_string(),
            abort_signal: tokio::sync::watch::channel(false).1,
            working_directory: std::env::current_dir().unwrap(),
        }
    }

    /// A successful tool result with fixed content.
    fn sample_result() -> ToolResult {
        ToolResult {
            title: "Test".to_string(),
            metadata: serde_json::json!({"result": "success"}),
            output: "Test output".to_string(),
        }
    }

    // A started entry that is completed must show up once, with the final status.
    #[tokio::test]
    async fn test_audit_logger() {
        let logger = AuditLogger::new();
        let ctx = sample_context();

        let entry_id = logger
            .log_tool_start(
                "test_tool",
                OperationType::FileRead,
                &ctx,
                serde_json::json!({"test": "value"}),
                Some(RiskLevel::Low),
            )
            .await
            .unwrap();

        let result = sample_result();
        logger
            .log_tool_completion(&entry_id, &result, 100)
            .await
            .unwrap();

        let logs = logger
            .get_logs(Some("test_session"), None, None, None, None)
            .await;
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].tool_id, "test_tool");
        assert!(matches!(logs[0].status, ExecutionStatus::Completed));
    }

    // Statistics must reflect every logged invocation.
    #[tokio::test]
    async fn test_audit_statistics() {
        let logger = AuditLogger::new();
        let ctx = sample_context();

        for i in 0..3 {
            let entry_id = logger
                .log_tool_start(
                    "test_tool",
                    OperationType::FileRead,
                    &ctx,
                    serde_json::json!({"test": i}),
                    Some(RiskLevel::Low),
                )
                .await
                .unwrap();

            let result = sample_result();
            logger
                .log_tool_completion(&entry_id, &result, 100 + i * 50)
                .await
                .unwrap();
        }

        let stats = logger.get_statistics().await;
        assert_eq!(stats.total_entries, 3);
        assert_eq!(stats.by_tool.get("test_tool"), Some(&3));
        assert!(stats.average_execution_time_ms > 0.0);
    }

    // With a log file configured, a started entry must be written to that file.
    #[tokio::test]
    async fn test_file_logging() {
        let temp_file = NamedTempFile::new().unwrap();
        let log_path = temp_file.path().to_path_buf();

        let logger = AuditLogger::with_file(log_path.clone());
        let ctx = sample_context();

        logger
            .log_tool_start(
                "test_tool",
                OperationType::FileRead,
                &ctx,
                serde_json::json!({"test": "value"}),
                Some(RiskLevel::Low),
            )
            .await
            .unwrap();

        let content = tokio::fs::read_to_string(&log_path).await.unwrap();
        assert!(content.contains("test_tool"));
        assert!(content.contains("test_session"));
    }
}
517}