oxi_agent/tools.rs
1#![allow(unused_doc_comments)]
2/// Agent tools system
3/// This module provides the tool abstraction layer and built-in tools.
4use crate::types::ToolDefinition;
5use async_trait::async_trait;
6use serde_json::Value;
7use std::fmt;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
/// Workspace information handed to every tool invocation.
///
/// Carrying the directories here (rather than baking them into each tool)
/// lets one tool instance operate on different workspaces without being
/// rebuilt. Tools should resolve paths against [`ToolContext::root`], which
/// prefers `root_dir` and falls back to `workspace_dir`.
#[derive(Clone, Debug)]
pub struct ToolContext {
    /// Primary workspace directory; acts as the root when `root_dir` is `None`.
    pub workspace_dir: PathBuf,
    /// Optional explicit root directory for file tools.
    /// When present it takes priority over `workspace_dir`.
    pub root_dir: Option<PathBuf>,
    /// Optional session identifier for logging/tracing.
    pub session_id: Option<String>,
}

impl ToolContext {
    /// Build a context for `workspace_dir` with no explicit root and no session.
    pub fn new(workspace_dir: impl Into<PathBuf>) -> Self {
        let workspace_dir = workspace_dir.into();
        Self {
            workspace_dir,
            root_dir: None,
            session_id: None,
        }
    }

    /// Directory tools should treat as their base: `root_dir` when it is set,
    /// otherwise `workspace_dir`.
    pub fn root(&self) -> &Path {
        match &self.root_dir {
            Some(explicit) => explicit.as_path(),
            None => self.workspace_dir.as_path(),
        }
    }

    /// Builder: attach a session identifier (replacing any previous one).
    pub fn with_session(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Builder: set an explicit root directory, overriding `workspace_dir`
    /// for the purposes of [`ToolContext::root`].
    pub fn with_root(self, root_dir: impl Into<PathBuf>) -> Self {
        Self {
            root_dir: Some(root_dir.into()),
            ..self
        }
    }
}

impl Default for ToolContext {
    /// Context for the process's current working directory, or `"."` when the
    /// current directory cannot be determined.
    fn default() -> Self {
        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
        Self::new(cwd)
    }
}
66
67/// Result type for tool execution
68pub type ToolError = String;
69
70/// Result of tool execution
71#[derive(Debug)]
72pub struct AgentToolResult {
73 /// pub.
74 pub success: bool,
75 /// pub.
76 pub output: String,
77 /// pub.
78 pub metadata: Option<serde_json::Value>,
79 /// Optional content blocks (e.g., image blocks) to include in the tool result message.
80 /// When present, these are used as the content of the ToolResultMessage instead of
81 /// wrapping `output` in a Text block.
82 pub content_blocks: Option<Vec<oxi_ai::ContentBlock>>,
83 /// When `true`, signals that the agent loop should terminate after this batch
84 /// of tool calls completes. Defaults to `false` so that the loop continues
85 /// unless a tool explicitly opts-in to termination.
86 pub terminate: bool,
87}
88
89impl AgentToolResult {
90 /// TODO.
91 pub fn success(output: impl Into<String>) -> Self {
92 Self {
93 success: true,
94 output: output.into(),
95 metadata: None,
96 content_blocks: None,
97 terminate: false,
98 }
99 }
100
101 /// TODO.
102 pub fn error(output: impl Into<String>) -> Self {
103 Self {
104 success: false,
105 output: output.into(),
106 metadata: None,
107 content_blocks: None,
108 terminate: false,
109 }
110 }
111
112 /// TODO: document this function.
113 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
114 self.metadata = Some(metadata);
115 self
116 }
117
118 /// TODO: document this function.
119 pub fn with_content_blocks(mut self, blocks: Vec<oxi_ai::ContentBlock>) -> Self {
120 self.content_blocks = Some(blocks);
121 self
122 }
123
124 /// Mark this result as requesting agent-loop termination.
125 pub fn with_terminate(mut self) -> Self {
126 self.terminate = true;
127 self
128 }
129}
130
131impl fmt::Display for AgentToolResult {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 write!(f, "{}", self.output)
134 }
135}
136
137/// Callback type for progress updates
138pub type ProgressCallback = Arc<dyn Fn(String) + Send + Sync>;
139
140/// Core trait for all agent tools
141#[async_trait]
142pub trait AgentTool: Send + Sync {
143 /// Tool name (used in function calls)
144 fn name(&self) -> &str;
145
146 /// Human-readable label
147 fn label(&self) -> &str;
148
149 /// Description for the model
150 fn description(&self) -> &str;
151
152 /// JSON Schema for parameters
153 fn parameters_schema(&self) -> Value;
154
155 /// Whether this tool is essential (cannot be disabled).
156 /// Essential tools: read, write, edit, bash, grep, find, ls
157 /// Optional tools: web_search, github, subagent, etc.
158 fn essential(&self) -> bool {
159 false
160 }
161
162 /// Execute the tool with the given tool call ID and parameters.
163 ///
164 /// The `ctx` parameter provides workspace information. File tools should
165 /// use `ctx.root()` to get the effective directory. Custom tools can use
166 /// `ctx.workspace_dir` for workspace-relative operations.
167 ///
168 /// # Examples
169 ///
170 /// ```ignore
171 /// use oxi_agent::{AgentTool, AgentToolResult, ToolContext};
172 /// use serde_json::json;
173 /// use async_trait::async_trait;
174 ///
175 /// struct MyTool;
176 ///
177 /// #[async_trait]
178 /// impl AgentTool for MyTool {
179 /// fn name(&self) -> &str { "my_tool" }
180 /// fn label(&self) -> &str { "My Tool" }
181 /// fn description(&self) -> &str { "A custom tool" }
182 /// fn parameters_schema(&self) -> Value { json!({
183 /// "type": "object",
184 /// "properties": {}
185 /// }) }
186 ///
187 /// async fn execute(&self, tool_call_id: &str, params: Value, _signal: Option<oneshot::Receiver<()>>, ctx: &ToolContext) -> Result<AgentToolResult, String> {
188 /// println!("Tool '{}' called with params: {:?}, workspace: {:?}", tool_call_id, params, ctx.workspace_dir);
189 /// Ok(AgentToolResult::success("Done!"))
190 /// }
191 /// }
192 /// ```
193 async fn execute(
194 &self,
195 tool_call_id: &str,
196 params: Value,
197 signal: Option<oneshot::Receiver<()>>,
198 ctx: &ToolContext,
199 ) -> Result<AgentToolResult, ToolError>;
200
201 /// Called with progress updates during execution.
202 /// Tools can override this to emit streaming updates.
203 fn on_progress(&self, _callback: ProgressCallback) {
204 // Default no-op
205 }
206
207 /// Convert to ToolDefinition
208 fn to_definition(&self) -> ToolDefinition {
209 ToolDefinition {
210 name: self.name().to_string(),
211 description: self.description().to_string(),
212 input_schema: serde_json::from_value(self.parameters_schema()).unwrap_or_default(),
213 }
214 }
215}
216
217// Built-in tools
218/// Bash shell execution tool.
219pub mod bash;
220/// Context7 documentation tools.
221pub mod context7;
222/// In-place file edit tool.
223pub mod edit;
224/// Diff-based edit helpers.
225pub mod edit_diff;
226/// Serialised file-mutation queue.
227pub mod file_mutation_queue;
228/// File-system find tool.
229pub mod find;
230/// GitHub integration tool (gh CLI-based).
231pub mod github;
232/// GitHub repository search tool (legacy REST API).
233pub mod github_search;
234/// Content search (grep) tool.
235pub mod grep;
236/// Shared HTTP client singleton.
237pub mod http_client;
238/// Directory listing tool.
239pub mod ls;
240/// Path security (traversal protection).
241pub mod path_security;
242/// Path manipulation utilities.
243pub mod path_utils;
244/// Questionnaire tool — interactive multi-question TUI overlay.
245pub mod questionnaire;
246/// File reading tool.
247pub mod read;
248/// Rendering utilities for tool output.
249pub mod render_utils;
250/// Search result cache and get_search_results tool.
251pub mod search_cache;
252/// Sub-agent delegation tool.
253pub mod subagent;
254/// Tool definition wrapper helpers.
255pub mod tool_definition_wrapper;
256/// Output truncation helpers.
257pub mod truncate;
258/// Multi-engine web search tool (a3s-search library + DuckDuckGo fallback).
259pub mod web_search;
260/// File writing tool.
261pub mod write;
262
263// Re-export for convenience
264pub use bash::BashTool;
265pub use edit::EditTool;
266pub use find::FindTool;
267pub use grep::GrepTool;
268pub use ls::LsTool;
269pub use read::ReadTool;
270// pub use search_cache;
271
272pub use crate::mcp::McpTool;
273pub use context7::{Context7QueryDocsTool, Context7ResolveLibraryIdTool};
274pub use questionnaire::{QuestionnaireBridge, QuestionnaireTool};
275pub use subagent::SubagentTool;
276pub use write::WriteTool;
277
278/// Tool registry for managing available tools
279#[derive(Clone)]
280pub struct ToolRegistry {
281 tools: Arc<parking_lot::RwLock<std::collections::HashMap<String, Arc<dyn AgentTool>>>>,
282}
283
284impl Default for ToolRegistry {
285 fn default() -> Self {
286 Self::new()
287 }
288}
289
290impl ToolRegistry {
291 /// TODO.
292 pub fn new() -> Self {
293 Self {
294 tools: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
295 }
296 }
297
298 /// Register a tool
299 pub fn register(&self, tool: impl AgentTool + 'static) {
300 let name = tool.name().to_string();
301 self.tools.write().insert(name, Arc::new(tool));
302 }
303
304 /// Register a tool that is already wrapped in an `Arc`.
305 /// This is the primary path for extensions that produce `Arc<dyn AgentTool>`.
306 pub fn register_arc(&self, tool: Arc<dyn AgentTool>) {
307 let name = tool.name().to_string();
308 self.tools.write().insert(name, tool);
309 }
310
311 /// Get a tool by name
312 pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
313 self.tools.read().get(name).cloned()
314 }
315
316 /// Unregister a tool by name.
317 /// Returns `true` if the tool was present and removed.
318 pub fn unregister(&self, name: &str) -> bool {
319 self.tools.write().remove(name).is_some()
320 }
321
322 /// List all registered tool names
323 pub fn names(&self) -> Vec<String> {
324 self.tools.read().keys().cloned().collect()
325 }
326
327 /// Get all tool definitions
328 pub fn definitions(&self) -> Vec<ToolDefinition> {
329 self.tools
330 .read()
331 .values()
332 .map(|t| t.to_definition())
333 .collect()
334 }
335
336 /// Get all tools as a slice
337 pub fn get_tools(&self) -> Vec<Arc<dyn AgentTool>> {
338 self.tools.read().values().cloned().collect()
339 }
340
341 /// Check whether all tools in `required` are registered.
342 ///
343 /// Useful for validating program/module dependencies before execution.
344 ///
345 /// # Example
346 ///
347 /// ```
348 /// use oxi_agent::ToolRegistry;
349 /// let registry = ToolRegistry::new();
350 /// assert!(!registry.has_all(&["read", "write"]));
351 /// ```
352 pub fn has_all(&self, required: &[&str]) -> bool {
353 let tools = self.tools.read();
354 required.iter().all(|name| tools.contains_key(*name))
355 }
356
357 /// Return the subset of `required` tool names that are **not** registered.
358 ///
359 /// # Example
360 ///
361 /// ```
362 /// use oxi_agent::ToolRegistry;
363 /// let registry = ToolRegistry::new();
364 /// let missing = registry.missing(&["read", "exec", "nonexistent"]);
365 /// assert_eq!(missing, vec!["read", "exec", "nonexistent"]);
366 /// ```
367 pub fn missing<'a>(&self, required: &[&'a str]) -> Vec<&'a str> {
368 let tools = self.tools.read();
369 required
370 .iter()
371 .filter(|name| !tools.contains_key(**name))
372 .copied()
373 .collect()
374 }
375
376 /// Create a registry with all built-in tools
377 ///
378 /// # Examples
379 ///
380 /// ```
381 /// use oxi_agent::ToolRegistry;
382 /// let registry = ToolRegistry::with_builtins();
383 /// let tools = registry.names();
384 /// assert!(tools.contains(&"read".to_string()));
385 /// assert!(tools.contains(&"write".to_string()));
386 /// assert!(tools.contains(&"bash".to_string()));
387 /// ```
388 pub fn with_builtins() -> Self {
389 Self::with_builtins_cwd(PathBuf::from("."), &[])
390 }
391
392 /// Create a registry with all built-in tools, using the given cwd.
393 ///
394 /// Pass `disabled_tools` to selectively disable built-in tools
395 /// (e.g. `["web_search", "github_search"]` for a minimal setup).
396 pub fn with_builtins_cwd(cwd: PathBuf, disabled_tools: &[String]) -> Self {
397 let registry = Self::new();
398 let disabled: std::collections::HashSet<&str> =
399 disabled_tools.iter().map(|s| s.as_str()).collect();
400
401 // Helper to create shared cache on demand
402 let cache_once: std::cell::OnceCell<Arc<search_cache::SearchCache>> =
403 std::cell::OnceCell::new();
404
405 // MCP: use OnceCell to avoid re-creating McpManager on repeated calls
406 let mcp_once: std::cell::OnceCell<Arc<crate::mcp::McpManager>> = std::cell::OnceCell::new();
407 let mcp_manager = mcp_once
408 .get_or_init(|| Arc::new(crate::mcp::McpManager::new()))
409 .clone();
410
411 // Register all builtin tools — essential ones ignore disabled list
412 let mut all_tools: Vec<Box<dyn AgentTool>> = vec![
413 Box::new(ReadTool::with_cwd(cwd.clone())),
414 Box::new(WriteTool::with_cwd(cwd.clone())),
415 Box::new(EditTool::with_cwd(cwd.clone())),
416 Box::new(BashTool::with_cwd(cwd.clone())),
417 Box::new(GrepTool::with_cwd(cwd.clone())),
418 Box::new(FindTool::with_cwd(cwd.clone())),
419 Box::new(LsTool::with_cwd(cwd.clone())),
420 Box::new(web_search::WebSearchTool::new(
421 cache_once
422 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
423 .clone(),
424 )),
425 Box::new(search_cache::GetSearchResultsTool::new(
426 cache_once
427 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
428 .clone(),
429 )),
430 Box::new(github::GitHubTool::new(
431 cache_once
432 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
433 .clone(),
434 )),
435 Box::new(SubagentTool::with_cwd(cwd)),
436 ];
437
438 all_tools.push(Box::new(crate::mcp::McpTool::new(mcp_manager)));
439 all_tools.push(Box::new(context7::Context7ResolveLibraryIdTool::new()));
440 all_tools.push(Box::new(context7::Context7QueryDocsTool::new()));
441
442 for tool in all_tools {
443 if tool.essential() || !disabled.contains(tool.name()) {
444 // web_search ↔ get_search_results coupling
445 if tool.name() == "get_search_results" && disabled.contains("web_search") {
446 continue;
447 }
448 registry.register_arc(Arc::from(tool));
449 }
450 }
451
452 registry
453 }
454
455 /// Create registry with selected builtins only.
456 pub fn with_selected_tools(cwd: PathBuf, names: &[&str]) -> Self {
457 let full = Self::with_builtins_cwd(cwd, &[]);
458 let registry = Self::new();
459 let set: std::collections::HashSet<&str> = names.iter().copied().collect();
460 for name in full.names() {
461 if set.contains(name.as_str()) {
462 if let Some(tool) = full.get(&name) {
463 registry.register_arc(tool);
464 }
465 }
466 }
467 registry
468 }
469}