1#![allow(unused_doc_comments)]
2use crate::types::ToolDefinition;
5use async_trait::async_trait;
6use serde_json::Value;
7use std::fmt;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
/// Execution environment passed to every tool invocation.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Directory the tool operates in; also the fallback for [`ToolContext::root`].
    pub workspace_dir: PathBuf,
    /// Optional override returned by [`ToolContext::root`] instead of `workspace_dir`.
    pub root_dir: Option<PathBuf>,
    /// Identifier of the owning session, if one was attached via [`ToolContext::with_session`].
    pub session_id: Option<String>,
}

impl ToolContext {
    /// Creates a context for `workspace_dir` with no root override and no session.
    pub fn new(workspace_dir: impl Into<PathBuf>) -> Self {
        let workspace_dir = workspace_dir.into();
        Self {
            workspace_dir,
            root_dir: None,
            session_id: None,
        }
    }

    /// Returns the effective root directory: `root_dir` when set, otherwise `workspace_dir`.
    pub fn root(&self) -> &Path {
        match self.root_dir.as_deref() {
            Some(root) => root,
            None => self.workspace_dir.as_path(),
        }
    }

    /// Builder: attaches `session_id`, replacing any previous session.
    pub fn with_session(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Builder: sets the root override consulted by [`ToolContext::root`].
    pub fn with_root(self, root_dir: impl Into<PathBuf>) -> Self {
        Self {
            root_dir: Some(root_dir.into()),
            ..self
        }
    }
}

impl Default for ToolContext {
    /// Uses the process's current directory as the workspace, or `"."` if it cannot be read.
    fn default() -> Self {
        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
        Self::new(cwd)
    }
}
66
67pub type ToolError = String;
69
70#[derive(Debug)]
72pub struct AgentToolResult {
73 pub success: bool,
75 pub output: String,
77 pub metadata: Option<serde_json::Value>,
79 pub content_blocks: Option<Vec<oxi_ai::ContentBlock>>,
83 pub terminate: bool,
87}
88
89impl AgentToolResult {
90 pub fn success(output: impl Into<String>) -> Self {
92 Self {
93 success: true,
94 output: output.into(),
95 metadata: None,
96 content_blocks: None,
97 terminate: false,
98 }
99 }
100
101 pub fn error(output: impl Into<String>) -> Self {
103 Self {
104 success: false,
105 output: output.into(),
106 metadata: None,
107 content_blocks: None,
108 terminate: false,
109 }
110 }
111
112 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
114 self.metadata = Some(metadata);
115 self
116 }
117
118 pub fn with_content_blocks(mut self, blocks: Vec<oxi_ai::ContentBlock>) -> Self {
120 self.content_blocks = Some(blocks);
121 self
122 }
123
124 pub fn with_terminate(mut self) -> Self {
126 self.terminate = true;
127 self
128 }
129}
130
131impl fmt::Display for AgentToolResult {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 write!(f, "{}", self.output)
134 }
135}
136
137pub type ProgressCallback = Arc<dyn Fn(String) + Send + Sync>;
139
140#[async_trait]
142pub trait AgentTool: Send + Sync {
143 fn name(&self) -> &str;
145
146 fn label(&self) -> &str;
148
149 fn description(&self) -> &str;
151
152 fn parameters_schema(&self) -> Value;
154
155 fn essential(&self) -> bool {
159 false
160 }
161
162 async fn execute(
194 &self,
195 tool_call_id: &str,
196 params: Value,
197 signal: Option<oneshot::Receiver<()>>,
198 ctx: &ToolContext,
199 ) -> Result<AgentToolResult, ToolError>;
200
201 fn on_progress(&self, _callback: ProgressCallback) {
204 }
206
207 fn to_definition(&self) -> ToolDefinition {
209 ToolDefinition {
210 name: self.name().to_string(),
211 description: self.description().to_string(),
212 input_schema: serde_json::from_value(self.parameters_schema()).unwrap_or_default(),
213 }
214 }
215}
216
217pub mod bash;
220pub mod context7;
222pub mod edit;
224pub mod edit_diff;
226pub mod file_mutation_queue;
228pub mod find;
230pub mod github;
232pub mod github_search;
234pub mod grep;
236pub mod http_client;
238pub mod ls;
240pub mod path_security;
242pub mod path_utils;
244pub mod questionnaire;
246pub mod read;
248pub mod render_utils;
250pub mod search_cache;
252pub mod subagent;
254pub mod tool_definition_wrapper;
256pub mod truncate;
258pub mod web_search;
260pub mod write;
262
263pub use bash::BashTool;
265pub use edit::EditTool;
266pub use find::FindTool;
267pub use grep::GrepTool;
268pub use ls::LsTool;
269pub use read::ReadTool;
270pub use crate::mcp::McpTool;
273pub use context7::{Context7QueryDocsTool, Context7ResolveLibraryIdTool};
274pub use questionnaire::{QuestionnaireBridge, QuestionnaireTool};
275pub use subagent::SubagentTool;
276pub use write::WriteTool;
277
278#[derive(Clone)]
280pub struct ToolRegistry {
281 tools: Arc<parking_lot::RwLock<std::collections::HashMap<String, Arc<dyn AgentTool>>>>,
282}
283
284impl Default for ToolRegistry {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290impl ToolRegistry {
291 pub fn new() -> Self {
293 Self {
294 tools: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
295 }
296 }
297
298 pub fn register(&self, tool: impl AgentTool + 'static) {
300 let name = tool.name().to_string();
301 self.tools.write().insert(name, Arc::new(tool));
302 }
303
304 pub fn register_arc(&self, tool: Arc<dyn AgentTool>) {
307 let name = tool.name().to_string();
308 self.tools.write().insert(name, tool);
309 }
310
311 pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
313 self.tools.read().get(name).cloned()
314 }
315
316 pub fn unregister(&self, name: &str) -> bool {
319 self.tools.write().remove(name).is_some()
320 }
321
322 pub fn names(&self) -> Vec<String> {
324 self.tools.read().keys().cloned().collect()
325 }
326
327 pub fn definitions(&self) -> Vec<ToolDefinition> {
329 self.tools
330 .read()
331 .values()
332 .map(|t| t.to_definition())
333 .collect()
334 }
335
336 pub fn get_tools(&self) -> Vec<Arc<dyn AgentTool>> {
338 self.tools.read().values().cloned().collect()
339 }
340
341 pub fn with_builtins() -> Self {
354 Self::with_builtins_cwd(PathBuf::from("."), &[])
355 }
356
357 pub fn with_builtins_cwd(cwd: PathBuf, disabled_tools: &[String]) -> Self {
362 let registry = Self::new();
363 let disabled: std::collections::HashSet<&str> =
364 disabled_tools.iter().map(|s| s.as_str()).collect();
365
366 let cache_once: std::cell::OnceCell<Arc<search_cache::SearchCache>> =
368 std::cell::OnceCell::new();
369
370 let mcp_once: std::cell::OnceCell<Arc<crate::mcp::McpManager>> = std::cell::OnceCell::new();
372 let mcp_manager = mcp_once
373 .get_or_init(|| Arc::new(crate::mcp::McpManager::new()))
374 .clone();
375
376 let mut all_tools: Vec<Box<dyn AgentTool>> = vec![
378 Box::new(ReadTool::with_cwd(cwd.clone())),
379 Box::new(WriteTool::with_cwd(cwd.clone())),
380 Box::new(EditTool::with_cwd(cwd.clone())),
381 Box::new(BashTool::with_cwd(cwd.clone())),
382 Box::new(GrepTool::with_cwd(cwd.clone())),
383 Box::new(FindTool::with_cwd(cwd.clone())),
384 Box::new(LsTool::with_cwd(cwd.clone())),
385 Box::new(web_search::WebSearchTool::new(
386 cache_once
387 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
388 .clone(),
389 )),
390 Box::new(search_cache::GetSearchResultsTool::new(
391 cache_once
392 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
393 .clone(),
394 )),
395 Box::new(github::GitHubTool::new(
396 cache_once
397 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
398 .clone(),
399 )),
400 Box::new(SubagentTool::with_cwd(cwd)),
401 ];
402
403 all_tools.push(Box::new(crate::mcp::McpTool::new(mcp_manager)));
404 all_tools.push(Box::new(context7::Context7ResolveLibraryIdTool::new()));
405 all_tools.push(Box::new(context7::Context7QueryDocsTool::new()));
406
407 for tool in all_tools {
408 if tool.essential() || !disabled.contains(tool.name()) {
409 if tool.name() == "get_search_results" && disabled.contains("web_search") {
411 continue;
412 }
413 registry.register_arc(Arc::from(tool));
414 }
415 }
416
417 registry
418 }
419
420 pub fn with_selected_tools(cwd: PathBuf, names: &[&str]) -> Self {
422 let full = Self::with_builtins_cwd(cwd, &[]);
423 let registry = Self::new();
424 let set: std::collections::HashSet<&str> = names.iter().copied().collect();
425 for name in full.names() {
426 if set.contains(name.as_str()) {
427 if let Some(tool) = full.get(&name) {
428 registry.register_arc(tool);
429 }
430 }
431 }
432 registry
433 }
434}