#![allow(unused_doc_comments)]
use crate::types::ToolDefinition;
use async_trait::async_trait;
use serde_json::Value;
use std::fmt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::oneshot;
/// Per-invocation context passed to a tool's `execute` method.
#[derive(Debug, Clone)]
pub struct ToolContext {
    /// Workspace directory; also acts as the root when `root_dir` is unset.
    pub workspace_dir: PathBuf,
    /// Optional root directory override (see [`ToolContext::root`]).
    pub root_dir: Option<PathBuf>,
    /// Optional session identifier (see [`ToolContext::with_session`]).
    pub session_id: Option<String>,
}

impl ToolContext {
    /// Builds a context for `workspace_dir` with no root override and no session id.
    pub fn new(workspace_dir: impl Into<PathBuf>) -> Self {
        let workspace_dir = workspace_dir.into();
        Self {
            workspace_dir,
            root_dir: None,
            session_id: None,
        }
    }

    /// Returns `root_dir` when one is set, otherwise `workspace_dir`.
    pub fn root(&self) -> &Path {
        match &self.root_dir {
            Some(root) => root.as_path(),
            None => self.workspace_dir.as_path(),
        }
    }

    /// Builder-style setter: returns the context with `session_id` attached.
    pub fn with_session(self, session_id: impl Into<String>) -> Self {
        Self {
            session_id: Some(session_id.into()),
            ..self
        }
    }

    /// Builder-style setter: returns the context with `root_dir` set.
    pub fn with_root(self, root_dir: impl Into<PathBuf>) -> Self {
        Self {
            root_dir: Some(root_dir.into()),
            ..self
        }
    }
}

impl Default for ToolContext {
    /// Uses the process's current directory as the workspace, falling back to `"."` when the
    /// current directory cannot be determined.
    fn default() -> Self {
        let cwd = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
        Self::new(cwd)
    }
}
/// Error type for tools; currently just an alias for a `String` message.
pub type ToolError = String;
/// Outcome of a tool invocation.
#[derive(Debug)]
pub struct AgentToolResult {
    /// `true` for results built with [`AgentToolResult::success`], `false` for
    /// [`AgentToolResult::error`].
    pub success: bool,
    /// Textual output; this is exactly what the `Display` impl writes.
    pub output: String,
    /// Optional JSON metadata (see [`AgentToolResult::with_metadata`]).
    pub metadata: Option<serde_json::Value>,
    /// Optional content blocks (see [`AgentToolResult::with_content_blocks`]).
    pub content_blocks: Option<Vec<oxi_ai::ContentBlock>>,
    /// Set by [`AgentToolResult::with_terminate`].
    /// NOTE(review): presumably tells the caller to stop after this result — confirm against
    /// the code that consumes it.
    pub terminate: bool,
}

impl AgentToolResult {
    /// Builds a successful result carrying `output`; every optional part starts empty and
    /// `terminate` is `false`.
    pub fn success(output: impl Into<String>) -> Self {
        let output = output.into();
        Self {
            success: true,
            output,
            metadata: None,
            content_blocks: None,
            terminate: false,
        }
    }

    /// Builds a failed result carrying `output`; identical to [`AgentToolResult::success`]
    /// except that `success` is `false`.
    pub fn error(output: impl Into<String>) -> Self {
        Self {
            success: false,
            ..Self::success(output)
        }
    }

    /// Returns the result with `metadata` attached.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Returns the result with `blocks` attached as its content blocks.
    pub fn with_content_blocks(self, blocks: Vec<oxi_ai::ContentBlock>) -> Self {
        Self {
            content_blocks: Some(blocks),
            ..self
        }
    }

    /// Returns the result with the `terminate` flag raised.
    pub fn with_terminate(self) -> Self {
        Self {
            terminate: true,
            ..self
        }
    }
}

impl fmt::Display for AgentToolResult {
    /// Writes `output` verbatim; formatter width/precision flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.output)
    }
}
/// Shared, thread-safe (`Send + Sync`) callback that is invoked with a `String` progress message.
pub type ProgressCallback = Arc<dyn Fn(String) + Send + Sync>;
/// Interface every agent tool implements.
///
/// Implementors must be `Send + Sync`; the registry stores them as `Arc<dyn AgentTool>`.
#[async_trait]
pub trait AgentTool: Send + Sync {
    /// Name of the tool. `ToolRegistry` uses it as the lookup key and it becomes
    /// `ToolDefinition::name`.
    fn name(&self) -> &str;

    /// Label for the tool.
    /// NOTE(review): presumably a human-readable display name — confirm with callers.
    fn label(&self) -> &str;

    /// Description of the tool; copied into `ToolDefinition::description`.
    fn description(&self) -> &str;

    /// JSON value describing the tool's parameters; deserialized into
    /// `ToolDefinition::input_schema` by [`AgentTool::to_definition`].
    fn parameters_schema(&self) -> Value;

    /// Whether the tool is essential. Defaults to `false`.
    ///
    /// `ToolRegistry::with_builtins_cwd` registers essential tools even when their name is in
    /// the disabled list.
    fn essential(&self) -> bool {
        false
    }

    /// Runs the tool.
    ///
    /// * `tool_call_id` – identifier of the call being executed.
    /// * `params` – JSON arguments for the call.
    /// * `signal` – optional one-shot receiver.
    ///   NOTE(review): presumably a cancellation signal — confirm with implementors.
    /// * `ctx` – the [`ToolContext`] for this invocation.
    ///
    /// Returns the tool's result, or a [`ToolError`] message on failure.
    async fn execute(
        &self,
        tool_call_id: &str,
        params: Value,
        signal: Option<oneshot::Receiver<()>>,
        ctx: &ToolContext,
    ) -> Result<AgentToolResult, ToolError>;

    /// Hands the tool a progress callback. The default implementation ignores it.
    fn on_progress(&self, _callback: ProgressCallback) {}

    /// Builds the [`ToolDefinition`] for this tool from its name, description and schema.
    ///
    /// If the schema value cannot be deserialized into the `input_schema` type, that type's
    /// `Default` value is used instead.
    fn to_definition(&self) -> ToolDefinition {
        let name = self.name().to_string();
        let description = self.description().to_string();
        let input_schema = serde_json::from_value(self.parameters_schema()).unwrap_or_default();
        ToolDefinition {
            name,
            description,
            input_schema,
        }
    }
}
pub mod bash;
pub mod context7;
pub mod edit;
pub mod edit_diff;
pub mod file_mutation_queue;
pub mod find;
pub mod github;
pub mod github_search;
pub mod grep;
pub mod http_client;
pub mod ls;
pub mod path_security;
pub mod path_utils;
pub mod questionnaire;
pub mod read;
pub mod render_utils;
pub mod search_cache;
pub mod subagent;
pub mod tool_definition_wrapper;
pub mod truncate;
pub mod web_search;
pub mod write;
pub use bash::BashTool;
pub use edit::EditTool;
pub use find::FindTool;
pub use grep::GrepTool;
pub use ls::LsTool;
pub use read::ReadTool;
pub use crate::mcp::McpTool;
pub use context7::{Context7QueryDocsTool, Context7ResolveLibraryIdTool};
pub use questionnaire::{QuestionnaireBridge, QuestionnaireTool};
pub use subagent::SubagentTool;
pub use write::WriteTool;
#[derive(Clone)]
pub struct ToolRegistry {
tools: Arc<parking_lot::RwLock<std::collections::HashMap<String, Arc<dyn AgentTool>>>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
}
}
pub fn register(&self, tool: impl AgentTool + 'static) {
let name = tool.name().to_string();
self.tools.write().insert(name, Arc::new(tool));
}
pub fn register_arc(&self, tool: Arc<dyn AgentTool>) {
let name = tool.name().to_string();
self.tools.write().insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<Arc<dyn AgentTool>> {
self.tools.read().get(name).cloned()
}
pub fn unregister(&self, name: &str) -> bool {
self.tools.write().remove(name).is_some()
}
pub fn names(&self) -> Vec<String> {
self.tools.read().keys().cloned().collect()
}
pub fn definitions(&self) -> Vec<ToolDefinition> {
self.tools
.read()
.values()
.map(|t| t.to_definition())
.collect()
}
pub fn get_tools(&self) -> Vec<Arc<dyn AgentTool>> {
self.tools.read().values().cloned().collect()
}
pub fn has_all(&self, required: &[&str]) -> bool {
let tools = self.tools.read();
required.iter().all(|name| tools.contains_key(*name))
}
pub fn missing<'a>(&self, required: &[&'a str]) -> Vec<&'a str> {
let tools = self.tools.read();
required
.iter()
.filter(|name| !tools.contains_key(**name))
.copied()
.collect()
}
pub fn with_builtins() -> Self {
Self::with_builtins_cwd(PathBuf::from("."), &[])
}
pub fn with_builtins_cwd(cwd: PathBuf, disabled_tools: &[String]) -> Self {
let registry = Self::new();
let disabled: std::collections::HashSet<&str> =
disabled_tools.iter().map(|s| s.as_str()).collect();
let cache_once: std::cell::OnceCell<Arc<search_cache::SearchCache>> =
std::cell::OnceCell::new();
let mcp_once: std::cell::OnceCell<Arc<crate::mcp::McpManager>> = std::cell::OnceCell::new();
let mcp_manager = mcp_once
.get_or_init(|| Arc::new(crate::mcp::McpManager::new()))
.clone();
let mut all_tools: Vec<Box<dyn AgentTool>> = vec![
Box::new(ReadTool::with_cwd(cwd.clone())),
Box::new(WriteTool::with_cwd(cwd.clone())),
Box::new(EditTool::with_cwd(cwd.clone())),
Box::new(BashTool::with_cwd(cwd.clone())),
Box::new(GrepTool::with_cwd(cwd.clone())),
Box::new(FindTool::with_cwd(cwd.clone())),
Box::new(LsTool::with_cwd(cwd.clone())),
Box::new(web_search::WebSearchTool::new(
cache_once
.get_or_init(|| Arc::new(search_cache::SearchCache::new()))
.clone(),
)),
Box::new(search_cache::GetSearchResultsTool::new(
cache_once
.get_or_init(|| Arc::new(search_cache::SearchCache::new()))
.clone(),
)),
Box::new(github::GitHubTool::new(
cache_once
.get_or_init(|| Arc::new(search_cache::SearchCache::new()))
.clone(),
)),
Box::new(SubagentTool::with_cwd(cwd)),
];
all_tools.push(Box::new(crate::mcp::McpTool::new(mcp_manager)));
all_tools.push(Box::new(context7::Context7ResolveLibraryIdTool::new()));
all_tools.push(Box::new(context7::Context7QueryDocsTool::new()));
for tool in all_tools {
if tool.essential() || !disabled.contains(tool.name()) {
if tool.name() == "get_search_results" && disabled.contains("web_search") {
continue;
}
registry.register_arc(Arc::from(tool));
}
}
registry
}
pub fn with_selected_tools(cwd: PathBuf, names: &[&str]) -> Self {
let full = Self::with_builtins_cwd(cwd, &[]);
let registry = Self::new();
let set: std::collections::HashSet<&str> = names.iter().copied().collect();
for name in full.names() {
if set.contains(name.as_str()) {
if let Some(tool) = full.get(&name) {
registry.register_arc(tool);
}
}
}
registry
}
}