1use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::collections::HashMap;
8
9pub mod audit;
11pub mod file_watcher;
12pub mod permission;
13
14pub mod read;
16pub mod write;
17pub mod edit;
18pub mod multiedit;
19
20pub mod bash;
22
23pub mod grep;
25pub mod glob;
26
27pub mod task;
29pub mod todo;
30
31pub mod http;
33pub mod web;
34
35pub use audit::{AuditLogger, AuditLogEntry, OperationType, ExecutionStatus, AuditStatistics, operation_type_from_tool};
37pub use file_watcher::{FileWatcherTool, FileChangeEvent};
38pub use permission::{PermissionManager, PermissionProvider, PermissionRequest, RiskLevel,
39 InteractivePermissionProvider, AutoApprovePermissionProvider,
40 create_permission_request, PermissionResult};
41
42pub use read::ReadTool;
44pub use write::WriteTool;
45pub use edit::EditTool;
46pub use multiedit::MultiEditTool;
47pub use bash::BashTool;
48pub use grep::GrepTool;
49pub use glob::{GlobTool, GlobAdvancedTool};
50pub use task::TaskTool;
51pub use web::{WebFetchTool, WebSearchTool};
52pub use todo::TodoTool;
53
54#[async_trait]
56pub trait Tool: Send + Sync {
57 fn id(&self) -> &str;
59
60 fn description(&self) -> &str;
62
63 fn parameters_schema(&self) -> Value;
65
66 async fn execute(
68 &self,
69 args: Value,
70 ctx: ToolContext,
71 ) -> Result<ToolResult, ToolError>;
72}
73
74#[derive(Debug, Clone)]
76pub struct ToolContext {
77 pub session_id: String,
78 pub message_id: String,
79 pub abort_signal: tokio::sync::watch::Receiver<bool>,
80 pub working_directory: std::path::PathBuf,
81}
82
83#[derive(Debug, Serialize, Deserialize)]
85pub struct ToolResult {
86 pub title: String,
87 pub metadata: Value,
88 pub output: String,
89}
90
/// Errors produced while looking up or running a tool.
#[derive(Debug, thiserror::Error)]
pub enum ToolError {
    /// The arguments passed to the tool were rejected.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(String),

    /// The tool could not complete; also used by the registry when a tool id is unknown.
    #[error("Execution failed: {0}")]
    ExecutionFailed(String),

    /// The operation was refused by a permission check.
    #[error("Permission denied: {0}")]
    PermissionDenied(String),

    /// The operation was cancelled before completion.
    #[error("Operation aborted")]
    Aborted,

    /// Wraps an I/O error; convertible via `?` thanks to `#[from]`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// Wraps any other error; convertible via `?` thanks to `#[from]`.
    #[error("Other error: {0}")]
    Other(#[from] anyhow::Error),
}
112
113pub struct ToolRegistry {
115 tools: HashMap<String, Box<dyn Tool>>,
116 audit_logger: Option<AuditLogger>,
117 permission_manager: Option<PermissionManager>,
118}
119
120impl Default for ToolRegistry {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126impl ToolRegistry {
127 pub fn new() -> Self {
128 Self {
129 tools: HashMap::new(),
130 audit_logger: None,
131 permission_manager: None,
132 }
133 }
134
135 pub fn with_defaults() -> Result<Self, ToolError> {
137 let mut registry = Self::new();
138
139 registry.register(Box::new(ReadTool));
141 registry.register(Box::new(WriteTool));
142 registry.register(Box::new(EditTool));
143 registry.register(Box::new(MultiEditTool));
144
145 registry.register(Box::new(GrepTool));
147 registry.register(Box::new(GlobTool));
148 registry.register(Box::new(GlobAdvancedTool));
149
150 registry.register(Box::new(BashTool));
152
153 registry.register(Box::new(FileWatcherTool::new()));
155
156 #[cfg(not(feature = "wasm"))]
158 {
159 if let Ok(web_fetch) = WebFetchTool::new() {
160 registry.register(Box::new(web_fetch));
161 }
162 if let Ok(web_search) = WebSearchTool::new() {
163 registry.register(Box::new(web_search));
164 }
165 }
166
167 registry.register(Box::new(TaskTool::new()));
169 registry.register(Box::new(TodoTool::new()));
170
171 Ok(registry)
172 }
173
174 #[cfg(feature = "wasm")]
176 pub fn with_wasm_tools() -> Result<Self, ToolError> {
177 let mut registry = Self::new();
178
179 registry.register(Box::new(ReadTool));
181 registry.register(Box::new(WriteTool));
182 registry.register(Box::new(EditTool));
183 registry.register(Box::new(MultiEditTool));
184 registry.register(Box::new(GrepTool));
185 registry.register(Box::new(GlobTool));
186
187 if let Ok(web_fetch) = WebFetchTool::new() {
189 registry.register(Box::new(web_fetch));
190 }
191 if let Ok(web_search) = WebSearchTool::new() {
192 registry.register(Box::new(web_search));
193 }
194
195 Ok(registry)
196 }
197
198 pub fn with_audit_logger(mut self, logger: AuditLogger) -> Self {
200 self.audit_logger = Some(logger);
201 self
202 }
203
204 pub fn with_permission_manager(mut self, manager: PermissionManager) -> Self {
206 self.permission_manager = Some(manager);
207 self
208 }
209
210 pub fn register(&mut self, tool: Box<dyn Tool>) {
212 self.tools.insert(tool.id().to_string(), tool);
213 }
214
215 pub fn get(&self, id: &str) -> Option<&Box<dyn Tool>> {
217 self.tools.get(id)
218 }
219
220 pub fn list(&self) -> Vec<&str> {
222 self.tools.keys().map(|s| s.as_str()).collect()
223 }
224
225 pub fn get_definitions(&self) -> Vec<ToolDefinition> {
227 self.tools.values().map(|tool| {
228 ToolDefinition {
229 name: tool.id().to_string(),
230 description: tool.description().to_string(),
231 parameters: tool.parameters_schema(),
232 }
233 }).collect()
234 }
235
236 pub async fn execute_tool(
238 &self,
239 tool_id: &str,
240 args: Value,
241 ctx: ToolContext,
242 ) -> Result<ToolResult, ToolError> {
243 let tool = self.get(tool_id)
244 .ok_or_else(|| ToolError::ExecutionFailed(format!("Tool '{}' not found", tool_id)))?;
245
246 let start_time = std::time::Instant::now();
247
248 let audit_entry_id = if let Some(logger) = &self.audit_logger {
250 let operation_type = operation_type_from_tool(tool_id);
251 let risk_level = self.assess_tool_risk(tool_id, &args);
252
253 Some(logger.log_tool_start(
254 tool_id,
255 operation_type,
256 &ctx,
257 args.clone(),
258 risk_level,
259 ).await?)
260 } else {
261 None
262 };
263
264 let result = tool.execute(args, ctx).await;
266 let execution_time = start_time.elapsed().as_millis() as u64;
267
268 if let Some(logger) = &self.audit_logger {
270 if let Some(entry_id) = audit_entry_id {
271 match &result {
272 Ok(tool_result) => {
273 logger.log_tool_completion(&entry_id, tool_result, execution_time).await?;
274 }
275 Err(error) => {
276 logger.log_tool_failure(&entry_id, error, execution_time).await?;
277 }
278 }
279 }
280 }
281
282 result
283 }
284
285 fn assess_tool_risk(&self, tool_id: &str, _args: &Value) -> Option<RiskLevel> {
287 match tool_id {
288 "read" | "grep" | "glob" => Some(RiskLevel::Low),
289 "write" | "edit" | "multiedit" => Some(RiskLevel::Medium),
290 "bash" => Some(RiskLevel::High),
291 "web_fetch" | "web_search" => Some(RiskLevel::Medium),
292 _ => Some(RiskLevel::Low),
293 }
294 }
295
296 pub async fn get_audit_statistics(&self) -> Option<AuditStatistics> {
298 if let Some(logger) = &self.audit_logger {
299 Some(logger.get_statistics().await)
300 } else {
301 None
302 }
303 }
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
308pub struct ToolDefinition {
309 pub name: String,
310 pub description: String,
311 pub parameters: Value,
312}
313
314#[derive(Debug, Clone)]
316pub struct ToolConfig {
317 pub enable_audit_logging: bool,
318 pub audit_log_path: Option<std::path::PathBuf>,
319 pub permission_provider: PermissionProviderConfig,
320 pub security_mode: SecurityMode,
321}
322
/// Selects the permission provider used by [`ToolRegistryFactory::create_with_config`].
///
/// Derives `Copy`/`PartialEq`/`Eq` so configurations can be compared and passed around by
/// value; all variants hold only plain data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionProviderConfig {
    /// Approve every request automatically.
    AutoApprove,
    /// Ask interactively; `auto_approve_low_risk` is forwarded to the interactive provider.
    Interactive { auto_approve_low_risk: bool },
    /// Permission prompting disabled; the factory currently maps this to auto-approval.
    Disabled,
}
329
/// Requested security posture for a tool registry.
///
/// The factory presets use `Strict` for production and `Permissive` for development, and
/// `ToolConfig::default()` uses `Balanced`. Derives `Copy`/`PartialEq`/`Eq`/`Hash` so the mode
/// can be compared, copied, and used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SecurityMode {
    Strict,
    Balanced,
    Permissive,
}
336
337impl Default for ToolConfig {
338 fn default() -> Self {
339 Self {
340 enable_audit_logging: true,
341 audit_log_path: None,
342 permission_provider: PermissionProviderConfig::Interactive {
343 auto_approve_low_risk: true
344 },
345 security_mode: SecurityMode::Balanced,
346 }
347 }
348}
349
/// Namespace for constructors that build a fully configured [`ToolRegistry`].
pub struct ToolRegistryFactory;
352
353impl ToolRegistryFactory {
354 pub fn create_with_config(config: ToolConfig) -> Result<ToolRegistry, ToolError> {
356 let mut registry = ToolRegistry::with_defaults()?;
357
358 if config.enable_audit_logging {
360 let logger = if let Some(log_path) = config.audit_log_path {
361 AuditLogger::with_file(log_path)
362 } else {
363 AuditLogger::new()
364 };
365 registry = registry.with_audit_logger(logger);
366 }
367
368 let permission_manager = match config.permission_provider {
370 PermissionProviderConfig::AutoApprove => {
371 PermissionManager::new(Box::new(AutoApprovePermissionProvider))
372 }
373 PermissionProviderConfig::Interactive { auto_approve_low_risk } => {
374 PermissionManager::new(Box::new(
375 InteractivePermissionProvider::new(auto_approve_low_risk)
376 ))
377 }
378 PermissionProviderConfig::Disabled => {
379 PermissionManager::new(Box::new(AutoApprovePermissionProvider))
380 }
381 };
382
383 registry = registry.with_permission_manager(permission_manager);
384
385 Ok(registry)
386 }
387
388 pub fn create_for_development() -> Result<ToolRegistry, ToolError> {
390 let config = ToolConfig {
391 enable_audit_logging: true,
392 audit_log_path: None,
393 permission_provider: PermissionProviderConfig::AutoApprove,
394 security_mode: SecurityMode::Permissive,
395 };
396
397 Self::create_with_config(config)
398 }
399
400 pub fn create_for_production(audit_log_path: std::path::PathBuf) -> Result<ToolRegistry, ToolError> {
402 let config = ToolConfig {
403 enable_audit_logging: true,
404 audit_log_path: Some(audit_log_path),
405 permission_provider: PermissionProviderConfig::Interactive {
406 auto_approve_low_risk: false
407 },
408 security_mode: SecurityMode::Strict,
409 };
410
411 Self::create_with_config(config)
412 }
413
414 #[cfg(feature = "wasm")]
416 pub fn create_for_wasm() -> Result<ToolRegistry, ToolError> {
417 let mut registry = ToolRegistry::with_wasm_tools()?;
418
419 let logger = AuditLogger::new();
421 let permission_manager = PermissionManager::new(Box::new(AutoApprovePermissionProvider));
422
423 registry = registry
424 .with_audit_logger(logger)
425 .with_permission_manager(permission_manager);
426
427 Ok(registry)
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolContext` rooted at the current directory.
    ///
    /// The abort sender is returned with the context so the caller keeps it alive for the
    /// whole test: dropping it closes the watch channel, which a tool waiting on
    /// `abort_signal.changed()` would observe as an error.
    fn test_context() -> (tokio::sync::watch::Sender<bool>, ToolContext) {
        let (abort_tx, abort_rx) = tokio::sync::watch::channel(false);
        let ctx = ToolContext {
            session_id: "test".to_string(),
            message_id: "test".to_string(),
            abort_signal: abort_rx,
            working_directory: std::env::current_dir()
                .expect("current directory should be readable"),
        };
        (abort_tx, ctx)
    }

    #[tokio::test]
    async fn test_tool_registry_creation() {
        let registry = ToolRegistry::with_defaults().expect("default registry should build");
        let tools = registry.list();

        for expected in ["read", "write", "edit", "bash", "grep", "glob"].iter() {
            assert!(tools.contains(expected), "missing tool '{}' in {:?}", expected, tools);
        }
    }

    #[tokio::test]
    async fn test_tool_execution() {
        let registry = ToolRegistry::with_defaults().expect("default registry should build");
        // `_abort_tx` (not `_`) so the sender lives until the end of the test.
        let (_abort_tx, ctx) = test_context();

        let cargo_toml = std::env::current_dir()
            .expect("current directory should be readable")
            .join("Cargo.toml");
        let args = serde_json::json!({
            "filePath": cargo_toml.to_string_lossy()
        });

        let result = registry.execute_tool("read", args, ctx).await;
        assert!(result.is_ok(), "read tool failed: {:?}", result.err());
    }

    #[test]
    fn test_factory_configurations() {
        let _dev_registry = ToolRegistryFactory::create_for_development()
            .expect("development registry should build");

        // Per-process file name so concurrent test runs don't share one audit log.
        let temp_path =
            std::env::temp_dir().join(format!("test_audit_{}.log", std::process::id()));
        let prod_registry = ToolRegistryFactory::create_for_production(temp_path.clone())
            .expect("production registry should build");
        drop(prod_registry);

        // Best-effort cleanup: the logger may not have created the file at all.
        let _ = std::fs::remove_file(&temp_path);
    }

    #[test]
    fn test_tool_definitions() {
        let registry = ToolRegistry::with_defaults().expect("default registry should build");
        let definitions = registry.get_definitions();

        assert!(!definitions.is_empty());

        for def in &definitions {
            assert!(!def.name.is_empty());
            assert!(
                !def.description.is_empty(),
                "tool '{}' has an empty description",
                def.name
            );
            assert!(
                def.parameters.is_object(),
                "tool '{}' parameters are not a JSON object",
                def.name
            );
        }
    }
}