oxi_agent/tools.rs
1#![allow(unused_doc_comments)]
//! Agent tools system.
//!
//! This module provides the tool abstraction layer and built-in tools.
4use crate::types::ToolDefinition;
5use async_trait::async_trait;
6use serde_json::Value;
7use std::fmt;
8use std::path::{Path, PathBuf};
9use std::sync::Arc;
10use tokio::sync::oneshot;
11
/// Execution-time context handed to every tool call.
///
/// Carrying the workspace here lets one tool instance serve different
/// workspaces without being rebuilt. File tools should resolve paths against
/// [`ToolContext::root`], which prefers `root_dir` and otherwise falls back to
/// `workspace_dir`.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Primary workspace directory; used whenever `root_dir` is `None`.
    pub workspace_dir: PathBuf,
    /// Optional explicit root for file tools; overrides `workspace_dir` when set.
    pub root_dir: Option<PathBuf>,
    /// Optional session identifier for logging and tracing.
    pub session_id: Option<String>,
}

impl ToolContext {
    /// Builds a context for `workspace_dir` with no explicit root and no session.
    pub fn new(workspace_dir: impl Into<PathBuf>) -> Self {
        let workspace_dir = workspace_dir.into();
        ToolContext {
            workspace_dir,
            root_dir: None,
            session_id: None,
        }
    }

    /// Returns the directory tools should treat as their base: `root_dir`
    /// when present, otherwise `workspace_dir`.
    pub fn root(&self) -> &Path {
        match self.root_dir {
            Some(ref explicit) => explicit.as_path(),
            None => self.workspace_dir.as_path(),
        }
    }

    /// Returns this context tagged with the given session identifier.
    pub fn with_session(self, session_id: impl Into<String>) -> Self {
        ToolContext {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Returns this context with `root_dir` overriding the workspace directory.
    pub fn with_root(self, root_dir: impl Into<PathBuf>) -> Self {
        ToolContext {
            root_dir: Some(root_dir.into()),
            ..self
        }
    }
}

impl Default for ToolContext {
    /// Uses the process's current directory as the workspace, or `"."` when
    /// the current directory cannot be determined.
    fn default() -> Self {
        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
        ToolContext::new(cwd)
    }
}
66
/// Error type returned by [`AgentTool::execute`]: a human-readable message
/// describing why the tool could not complete.
pub type ToolError = String;
69
70/// Result of tool execution
71#[derive(Debug)]
72pub struct AgentToolResult {
73 /// pub.
74 pub success: bool,
75 /// pub.
76 pub output: String,
77 /// pub.
78 pub metadata: Option<serde_json::Value>,
79 /// Optional content blocks (e.g., image blocks) to include in the tool result message.
80 /// When present, these are used as the content of the ToolResultMessage instead of
81 /// wrapping `output` in a Text block.
82 pub content_blocks: Option<Vec<oxi_ai::ContentBlock>>,
83 /// When `true`, signals that the agent loop should terminate after this batch
84 /// of tool calls completes. Defaults to `false` so that the loop continues
85 /// unless a tool explicitly opts-in to termination.
86 pub terminate: bool,
87}
88
89impl AgentToolResult {
90 /// Creates a successful tool result with the given output text.
91 pub fn success(output: impl Into<String>) -> Self {
92 Self {
93 success: true,
94 output: output.into(),
95 metadata: None,
96 content_blocks: None,
97 terminate: false,
98 }
99 }
100
101 /// Creates an error tool result with the given error message.
102 pub fn error(output: impl Into<String>) -> Self {
103 Self {
104 success: false,
105 output: output.into(),
106 metadata: None,
107 content_blocks: None,
108 terminate: false,
109 }
110 }
111
112 /// Attaches structured metadata (JSON) to this result.
113 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
114 self.metadata = Some(metadata);
115 self
116 }
117
118 /// Attaches rich content blocks (images, code, etc.) to this result.
119 pub fn with_content_blocks(mut self, blocks: Vec<oxi_ai::ContentBlock>) -> Self {
120 self.content_blocks = Some(blocks);
121 self
122 }
123
124 /// Mark this result as requesting agent-loop termination.
125 pub fn with_terminate(mut self) -> Self {
126 self.terminate = true;
127 self
128 }
129}
130
131impl fmt::Display for AgentToolResult {
132 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133 write!(f, "{}", self.output)
134 }
135}
136
/// Shared, thread-safe callback that receives plain-text progress updates.
pub type ProgressCallback = Arc<dyn Fn(String) + Sync + Send>;
139
/// Structured progress event emitted while a tool is executing.
#[derive(Debug, Clone)]
pub enum ToolProgress {
    /// Free-form status message describing the current step.
    Status {
        /// The status text.
        message: String,
    },
    /// A chunk of partial output (e.g., streamed bash stdout/stderr).
    PartialOutput {
        /// The partial output text.
        output: String,
        /// `true` if this chunk came from stderr.
        is_error: bool,
    },
    /// Numeric progress, nominally a fraction in `0.0..=1.0`.
    ///
    /// NOTE(review): when `total` is present, `current` may instead be an
    /// absolute count out of `total` — confirm against the tools that emit it.
    Percentage {
        /// Current progress value.
        current: f64,
        /// Optional total value that `current` is measured against.
        total: Option<f64>,
        /// Optional human-readable message.
        message: Option<String>,
    },
    /// Progress of an operation on a single file.
    FileOperation {
        /// Kind of file operation.
        operation: FileOp,
        /// Path of the file being operated on.
        path: std::path::PathBuf,
        /// Bytes processed so far, if known.
        bytes_processed: Option<u64>,
        /// Total bytes to process, if known.
        total_bytes: Option<u64>,
    },
}

/// Kinds of file operation reported through [`ToolProgress::FileOperation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileOp {
    /// Reading a file.
    Reading,
    /// Writing a file.
    Writing,
    /// Searching file contents.
    Searching,
    /// Editing a file.
    Editing,
}
189
/// How a tool may be scheduled relative to other tool calls.
#[derive(Clone, Debug)]
pub enum ToolExecutionMode {
    /// May run concurrently with any other tool.
    ParallelSafe,
    /// Must never run concurrently with another tool.
    SequentialOnly,
    /// Mutates the given file; the file mutation queue serializes calls that
    /// target the same file.
    MutatesFile(PathBuf),
    /// Only reads state, so it is always safe to run in parallel.
    ReadOnly,
}
202
/// Tool-specific rendering used by the TUI in place of the default formatter.
#[derive(Clone, Debug)]
pub struct RenderOutput {
    /// Text to display (markdown or plain).
    pub content: String,
    /// `true` if the rendered block starts out collapsed.
    pub collapsed: bool,
    /// Optional short summary shown in the TUI footer.
    pub summary: Option<String>,
}
213
214/// Structured progress callback (alongside existing String callback)
215pub type StructuredProgressCallback = Arc<dyn Fn(ToolProgress) + Send + Sync>;
216
217/// Core trait for all agent tools
218#[async_trait]
219pub trait AgentTool: Send + Sync {
220 /// Tool name (used in function calls)
221 fn name(&self) -> &str;
222
223 /// Human-readable label
224 fn label(&self) -> &str;
225
226 /// Description for the model
227 fn description(&self) -> &str;
228
229 /// JSON Schema for parameters
230 fn parameters_schema(&self) -> Value;
231
232 /// Whether this tool is essential (cannot be disabled).
233 /// Essential tools: read, write, edit, bash, grep, find, ls
234 /// Optional tools: web_search, github, subagent, etc.
235 fn essential(&self) -> bool {
236 false
237 }
238
239 /// Execute the tool with the given tool call ID and parameters.
240 ///
241 /// The `ctx` parameter provides workspace information. File tools should
242 /// use `ctx.root()` to get the effective directory. Custom tools can use
243 /// `ctx.workspace_dir` for workspace-relative operations.
244 ///
245 /// # Examples
246 ///
247 /// ```ignore
248 /// use oxi_agent::{AgentTool, AgentToolResult, ToolContext};
249 /// use serde_json::json;
250 /// use async_trait::async_trait;
251 ///
252 /// struct MyTool;
253 ///
254 /// #[async_trait]
255 /// impl AgentTool for MyTool {
256 /// fn name(&self) -> &str { "my_tool" }
257 /// fn label(&self) -> &str { "My Tool" }
258 /// fn description(&self) -> &str { "A custom tool" }
259 /// fn parameters_schema(&self) -> Value { json!({
260 /// "type": "object",
261 /// "properties": {}
262 /// }) }
263 ///
264 /// async fn execute(&self, tool_call_id: &str, params: Value, _signal: Option<oneshot::Receiver<()>>, ctx: &ToolContext) -> Result<AgentToolResult, String> {
265 /// println!("Tool '{}' called with params: {:?}, workspace: {:?}", tool_call_id, params, ctx.workspace_dir);
266 /// Ok(AgentToolResult::success("Done!"))
267 /// }
268 /// }
269 /// ```
270 async fn execute(
271 &self,
272 tool_call_id: &str,
273 params: Value,
274 signal: Option<oneshot::Receiver<()>>,
275 ctx: &ToolContext,
276 ) -> Result<AgentToolResult, ToolError>;
277
278 /// Called with progress updates during execution.
279 /// Tools can override this to emit streaming updates.
280 fn on_progress(&self, _callback: ProgressCallback) {
281 // Default no-op
282 }
283
284 /// Structured progress callback for streaming tool execution updates.
285 /// Default implementation is no-op. Override in tools that support
286 /// structured progress (e.g., BashTool for partial output streaming).
287 fn on_structured_progress(&self, _callback: StructuredProgressCallback) {}
288
289 /// Custom rendering for tool call (TUI visualization).
290 /// Return None to use the default tool_renderer.rs formatter.
291 fn render_call(&self, _params: &serde_json::Value) -> Option<RenderOutput> {
292 None
293 }
294
295 /// Custom rendering for tool result (TUI visualization).
296 /// Return None to use the default tool_renderer.rs formatter.
297 fn render_result(&self, _result: &AgentToolResult) -> Option<RenderOutput> {
298 None
299 }
300
301 /// Execution mode for parallel safety.
302 /// Defaults to ParallelSafe. Override for file-mutating or sequential tools.
303 fn execution_mode(&self) -> ToolExecutionMode {
304 ToolExecutionMode::ParallelSafe
305 }
306
307 /// Convert to ToolDefinition
308 fn to_definition(&self) -> ToolDefinition {
309 ToolDefinition {
310 name: self.name().to_string(),
311 description: self.description().to_string(),
312 input_schema: serde_json::from_value(self.parameters_schema()).unwrap_or_default(),
313 }
314 }
315}
316
317// Built-in tools
318/// Bash shell execution tool.
319pub mod bash;
320/// Browser tools (engine abstraction always compiled).
321pub mod browse;
322/// Context7 documentation tools.
323pub mod context7;
324/// In-place file edit tool.
325pub mod edit;
326/// Diff-based edit helpers.
327pub mod edit_diff;
328/// Serialised file-mutation queue.
329pub mod file_mutation_queue;
/// Filesystem find tool.
331pub mod find;
332/// Image generation tool (OpenRouter API).
333pub mod generate_image;
334/// GitHub integration tool (gh CLI-based).
335pub mod github;
336/// GitHub repository search tool (legacy REST API).
337pub mod github_search;
338/// Content search (grep) tool.
339pub mod grep;
340/// Shared HTTP client singleton.
341pub mod http_client;
342/// Directory listing tool.
343pub mod ls;
344/// Path security (traversal protection).
345pub mod path_security;
346/// Path manipulation utilities.
347pub mod path_utils;
348/// Questionnaire tool — interactive multi-question TUI overlay.
349pub mod questionnaire;
350/// File reading tool.
351pub mod read;
352/// Rendering utilities for tool output.
353pub mod render_utils;
354/// Search result cache and get_search_results tool.
355pub mod search_cache;
356/// Sub-agent delegation tool.
357pub mod subagent;
358/// Tool definition wrapper helpers.
359pub mod tool_definition_wrapper;
360/// Output truncation helpers.
361pub mod truncate;
362/// Multi-engine web search tool (a3s-search library + DuckDuckGo fallback).
363pub mod web_search;
364/// File writing tool.
365pub mod write;
366
367// Re-export for convenience
368pub use bash::BashTool;
369pub use edit::EditTool;
370pub use find::FindTool;
371pub use grep::GrepTool;
372pub use ls::LsTool;
373pub use read::ReadTool;
374// pub use search_cache;
375
376pub use crate::mcp::McpTool;
377pub use context7::{Context7QueryDocsTool, Context7ResolveLibraryIdTool};
378pub use questionnaire::{QuestionnaireBridge, QuestionnaireTool};
379pub use subagent::SubagentTool;
380pub use write::WriteTool;
381
382/// Tool registry for managing available tools
383#[derive(Clone)]
384pub struct ToolRegistry {
385 tools: Arc<parking_lot::RwLock<std::collections::HashMap<String, Arc<dyn AgentTool>>>>,
386}
387
388impl Default for ToolRegistry {
389 fn default() -> Self {
390 Self::new()
391 }
392}
393
394impl ToolRegistry {
395 /// Creates an empty tool registry.
396 pub fn new() -> Self {
397 Self {
398 tools: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
399 }
400 }
401
402 /// Register a tool
403 pub fn register(&self, tool: impl AgentTool + 'static) {
404 let name = tool.name().to_string();
405 self.tools.write().insert(name, Arc::new(tool));
406 }
407
408 /// Register a tool that is already wrapped in an `Arc`.
409 /// This is the primary path for extensions that produce `Arc<dyn AgentTool>`.
410 pub fn register_arc(&self, tool: Arc<dyn AgentTool>) {
411 let name = tool.name().to_string();
412 self.tools.write().insert(name, tool);
413 }
414
415 /// Get a tool by name
416 pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
417 self.tools.read().get(name).cloned()
418 }
419
420 /// Unregister a tool by name.
421 /// Returns `true` if the tool was present and removed.
422 pub fn unregister(&self, name: &str) -> bool {
423 self.tools.write().remove(name).is_some()
424 }
425
426 /// List all registered tool names
427 pub fn names(&self) -> Vec<String> {
428 self.tools.read().keys().cloned().collect()
429 }
430
431 /// Get all tool definitions
432 pub fn definitions(&self) -> Vec<ToolDefinition> {
433 self.tools
434 .read()
435 .values()
436 .map(|t| t.to_definition())
437 .collect()
438 }
439
440 /// Get all tools as a slice
441 pub fn get_tools(&self) -> Vec<Arc<dyn AgentTool>> {
442 self.tools.read().values().cloned().collect()
443 }
444
445 /// Check whether all tools in `required` are registered.
446 ///
447 /// Useful for validating program/module dependencies before execution.
448 ///
449 /// # Example
450 ///
451 /// ```
452 /// use oxi_agent::ToolRegistry;
453 /// let registry = ToolRegistry::new();
454 /// assert!(!registry.has_all(&["read", "write"]));
455 /// ```
456 pub fn has_all(&self, required: &[&str]) -> bool {
457 let tools = self.tools.read();
458 required.iter().all(|name| tools.contains_key(*name))
459 }
460
461 /// Return the subset of `required` tool names that are **not** registered.
462 ///
463 /// # Example
464 ///
465 /// ```
466 /// use oxi_agent::ToolRegistry;
467 /// let registry = ToolRegistry::new();
468 /// let missing = registry.missing(&["read", "exec", "nonexistent"]);
469 /// assert_eq!(missing, vec!["read", "exec", "nonexistent"]);
470 /// ```
471 pub fn missing<'a>(&self, required: &[&'a str]) -> Vec<&'a str> {
472 let tools = self.tools.read();
473 required
474 .iter()
475 .filter(|name| !tools.contains_key(**name))
476 .copied()
477 .collect()
478 }
479
480 /// Create a registry with all built-in tools
481 ///
482 /// # Examples
483 ///
484 /// ```
485 /// use oxi_agent::ToolRegistry;
486 /// let registry = ToolRegistry::with_builtins();
487 /// let tools = registry.names();
488 /// assert!(tools.contains(&"read".to_string()));
489 /// assert!(tools.contains(&"write".to_string()));
490 /// assert!(tools.contains(&"bash".to_string()));
491 /// ```
492 pub fn with_builtins() -> Self {
493 Self::with_builtins_cwd(PathBuf::from("."), &[])
494 }
495
496 /// Create a registry with all built-in tools, using the given cwd.
497 ///
498 /// Pass `disabled_tools` to selectively disable built-in tools
499 /// (e.g. `["web_search", "github_search"]` for a minimal setup).
500 pub fn with_builtins_cwd(cwd: PathBuf, disabled_tools: &[String]) -> Self {
501 let registry = Self::new();
502 let disabled: std::collections::HashSet<&str> =
503 disabled_tools.iter().map(|s| s.as_str()).collect();
504
505 // Helper to create shared cache on demand
506 let cache_once: std::cell::OnceCell<Arc<search_cache::SearchCache>> =
507 std::cell::OnceCell::new();
508
509 // MCP: use OnceCell to avoid re-creating McpManager on repeated calls
510 let mcp_once: std::cell::OnceCell<Arc<crate::mcp::McpManager>> = std::cell::OnceCell::new();
511 let mcp_manager = mcp_once
512 .get_or_init(|| Arc::new(crate::mcp::McpManager::new()))
513 .clone();
514
515 // Register all builtin tools — essential ones ignore disabled list
516 let mut all_tools: Vec<Box<dyn AgentTool>> = vec![
517 Box::new(ReadTool::with_cwd(cwd.clone())),
518 Box::new(WriteTool::with_cwd(cwd.clone())),
519 Box::new(EditTool::with_cwd(cwd.clone())),
520 Box::new(BashTool::with_cwd(cwd.clone())),
521 Box::new(GrepTool::with_cwd(cwd.clone())),
522 Box::new(FindTool::with_cwd(cwd.clone())),
523 Box::new(LsTool::with_cwd(cwd.clone())),
524 Box::new(web_search::WebSearchTool::new(
525 cache_once
526 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
527 .clone(),
528 )),
529 Box::new(search_cache::GetSearchResultsTool::new(
530 cache_once
531 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
532 .clone(),
533 )),
534 Box::new(github::GitHubTool::new(
535 cache_once
536 .get_or_init(|| Arc::new(search_cache::SearchCache::new()))
537 .clone(),
538 )),
539 Box::new(SubagentTool::with_cwd(cwd)),
540 ];
541
542 all_tools.push(Box::new(crate::mcp::McpTool::new(mcp_manager)));
543 all_tools.push(Box::new(context7::Context7ResolveLibraryIdTool::new()));
544 all_tools.push(Box::new(context7::Context7QueryDocsTool::new()));
545 all_tools.push(Box::new(generate_image::GenerateImageTool::new()));
546
547 for tool in all_tools {
548 if tool.essential() || !disabled.contains(tool.name()) {
549 // web_search ↔ get_search_results coupling
550 if tool.name() == "get_search_results" && disabled.contains("web_search") {
551 continue;
552 }
553 registry.register_arc(Arc::from(tool));
554 }
555 }
556
557 registry
558 }
559
560 /// Extend this registry with all tools from another registry.
561 ///
562 /// Useful for composing tool sets from multiple sources
563 /// (e.g., coding tools + kernel tools + browser tools).
564 ///
565 /// # Example
566 ///
567 /// ```ignore
568 /// let base = ToolRegistry::new();
569 /// base.extend_from(&other_registry);
570 /// ```
571 pub fn extend_from(&self, other: &ToolRegistry) {
572 for name in other.names() {
573 if let Some(tool) = other.get(&name) {
574 self.register_arc(tool);
575 }
576 }
577 }
578
579 /// Create registry with selected builtins only.
580 pub fn with_selected_tools(cwd: PathBuf, names: &[&str]) -> Self {
581 let full = Self::with_builtins_cwd(cwd, &[]);
582 let registry = Self::new();
583 let set: std::collections::HashSet<&str> = names.iter().copied().collect();
584 for name in full.names() {
585 if set.contains(name.as_str()) {
586 if let Some(tool) = full.get(&name) {
587 registry.register_arc(tool);
588 }
589 }
590 }
591 registry
592 }
593}