pub mod builtin;
pub mod plugin_adapter;
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;
use self::builtin::{
AudioTranscriptionTool, BashTool, CodeSearchTool, EchoTool, FileExtractTool, FileReadTool,
FileWriteTool, GenerateCodeTool, GraphTool, GrepTool, MathTool, PromptUserTool, RgTool,
SearchTool, ShellTool,
};
#[cfg(feature = "api")]
use self::builtin::WebSearchTool;
#[cfg(feature = "web-scraping")]
use self::builtin::WebScraperTool;
use crate::spec_ai_core::agent::model::ModelProvider;
use crate::spec_ai_core::embeddings::EmbeddingsClient;
use crate::spec_ai_core::persistence::Persistence;
pub use plugin_adapter::PluginToolAdapter;
#[cfg(feature = "openai")]
use async_openai::types::ChatCompletionTool;
/// Outcome of a single tool invocation.
///
/// Values are normally built with [`ToolResult::success`] or
/// [`ToolResult::failure`], which keep the three fields consistent with
/// each other. The type derives `Serialize`/`Deserialize`, so it can be
/// converted to and from serde-supported formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// `true` when the result was built as a success, `false` for a failure.
pub success: bool,
/// Text produced by the tool; the `failure` constructor leaves it empty.
pub output: String,
/// Error message; the `success` constructor leaves it as `None`.
pub error: Option<String>,
}
impl ToolResult {
pub fn success(output: impl Into<String>) -> Self {
Self {
success: true,
output: output.into(),
error: None,
}
}
pub fn failure(error: impl Into<String>) -> Self {
Self {
success: false,
output: String::new(),
error: Some(error.into()),
}
}
}
/// Interface implemented by every tool that a [`ToolRegistry`] can hold.
///
/// The `Send + Sync` supertraits are required because the registry stores
/// tools as `Arc<dyn Tool>`.
#[async_trait]
pub trait Tool: Send + Sync {
/// Name of the tool; [`ToolRegistry::register`] uses it as the lookup key.
fn name(&self) -> &str;
/// Human-readable description of the tool.
fn description(&self) -> &str;
/// Description of the accepted arguments as a JSON value.
// NOTE(review): the in-file test tool returns a JSON-Schema-style object
// (`"type": "object"`, `"properties"`); presumably all tools do — confirm.
fn parameters(&self) -> Value;
/// Runs the tool with `args` and reports the outcome.
///
/// Returns `Err` when the call itself fails; otherwise a [`ToolResult`]
/// whose `success` flag carries the tool's own verdict.
async fn execute(&self, args: Value) -> Result<ToolResult>;
}
/// Collection of [`Tool`] implementations addressable by name.
pub struct ToolRegistry {
// Keyed by the value each tool returns from `Tool::name`; tools are kept
// behind `Arc` so lookups can hand out shared handles.
tools: HashMap<String, Arc<dyn Tool>>,
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
#[allow(unused_variables)]
pub fn with_builtin_tools(
persistence: Option<Arc<Persistence>>,
embeddings: Option<EmbeddingsClient>,
code_model_provider: Option<Arc<dyn ModelProvider>>,
) -> Self {
let mut registry = Self::new();
registry.register(Arc::new(EchoTool::new()));
registry.register(Arc::new(MathTool::new()));
registry.register(Arc::new(FileReadTool::new()));
registry.register(Arc::new(FileExtractTool::new()));
registry.register(Arc::new(FileWriteTool::new()));
registry.register(Arc::new(PromptUserTool::new()));
registry.register(Arc::new(SearchTool::new()));
registry.register(Arc::new(GrepTool::new()));
registry.register(Arc::new(RgTool::new()));
registry.register(Arc::new(CodeSearchTool::new()));
registry.register(Arc::new(BashTool::new()));
registry.register(Arc::new(ShellTool::new()));
if let Some(provider) = code_model_provider {
registry.register(Arc::new(GenerateCodeTool::new(provider)));
}
#[cfg(feature = "api")]
registry.register(Arc::new(WebSearchTool::new().with_embeddings(embeddings)));
#[cfg(feature = "web-scraping")]
registry.register(Arc::new(WebScraperTool::new()));
if let Some(persistence) = persistence {
registry.register(Arc::new(GraphTool::new(persistence.clone())));
registry.register(Arc::new(AudioTranscriptionTool::with_persistence(
persistence,
)));
} else {
registry.register(Arc::new(AudioTranscriptionTool::new()));
}
tracing::debug!("ToolRegistry created with {} tools", registry.tools.len());
for name in registry.tools.keys() {
tracing::debug!(" - Tool: {}", name);
}
registry
}
pub fn register(&mut self, tool: Arc<dyn Tool>) {
let name = tool.name().to_string();
self.tools.insert(name, tool);
}
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
self.tools.get(name).cloned()
}
pub fn list(&self) -> Vec<&str> {
self.tools.keys().map(|s| s.as_str()).collect()
}
pub fn has(&self, name: &str) -> bool {
self.tools.contains_key(name)
}
pub async fn execute(&self, name: &str, args: Value) -> Result<ToolResult> {
let tool = self
.get(name)
.ok_or_else(|| anyhow::anyhow!("Tool not found: {}", name))?;
debug!("Executing tool '{}'", name);
let result = tool.execute(args).await;
match &result {
Ok(res) => {
debug!(
"Tool '{}' completed: success={}, error={:?}",
name, res.success, res.error
);
}
Err(err) => {
debug!("Tool '{}' failed to execute: {}", name, err);
}
}
result
}
pub fn len(&self) -> usize {
self.tools.len()
}
pub fn is_empty(&self) -> bool {
self.tools.is_empty()
}
pub fn load_plugins(
&mut self,
dir: &std::path::Path,
allow_override: bool,
) -> anyhow::Result<crate::spec_ai_plugin::LoadStats> {
use crate::spec_ai_plugin::{expand_tilde, PluginLoader};
let expanded_dir = expand_tilde(dir);
let mut loader = PluginLoader::new();
let stats = loader.load_directory(&expanded_dir)?;
for (tool_ref, plugin_name) in loader.all_tools() {
let adapter = match PluginToolAdapter::new(tool_ref, plugin_name) {
Ok(a) => a,
Err(e) => {
tracing::warn!(
"Failed to create adapter for tool from {}: {}",
plugin_name,
e
);
continue;
}
};
let tool_name = adapter.name().to_string();
if self.has(&tool_name) {
if allow_override {
tracing::info!(
"Plugin tool '{}' from '{}' overriding built-in tool",
tool_name,
plugin_name
);
} else {
tracing::warn!(
"Plugin tool '{}' from '{}' would override built-in, skipping (set allow_override_builtin=true to allow)",
tool_name,
plugin_name
);
continue;
}
}
tracing::debug!(
"Registering plugin tool '{}' from '{}'",
tool_name,
plugin_name
);
self.register(Arc::new(adapter));
}
Ok(stats)
}
#[cfg(any(feature = "openai", feature = "mlx", feature = "lmstudio"))]
pub fn to_openai_tools(&self) -> Vec<ChatCompletionTool> {
use crate::spec_ai_core::agent::function_calling::tool_to_openai_function;
self.tools
.values()
.map(|tool| {
tool_to_openai_function(tool.name(), tool.description(), &tool.parameters())
})
.collect()
}
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal [`Tool`] used to exercise the registry: it is named `"dummy"`
    /// and always succeeds with the output `"dummy output"`.
    struct DummyTool;

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &str {
            "dummy"
        }

        fn description(&self) -> &str {
            "A dummy tool for testing"
        }

        fn parameters(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success("dummy output"))
        }
    }

    // The tests below that never await anything are plain `#[test]`s so they
    // do not spin up a Tokio runtime for no reason.

    #[test]
    fn test_register_and_get_tool() {
        let mut registry = ToolRegistry::new();
        assert!(registry.is_empty());

        registry.register(Arc::new(DummyTool));

        assert!(registry.has("dummy"));
        assert!(registry.get("dummy").is_some());
        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());

        // Unknown names must not be reported as present.
        assert!(!registry.has("missing"));
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_register_same_name_replaces_entry() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));
        registry.register(Arc::new(DummyTool));

        // Tools are keyed by name, so a second registration replaces the first.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_list_tools() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert!(tools.contains(&"dummy"));
    }

    #[tokio::test]
    async fn test_execute_tool() {
        let mut registry = ToolRegistry::new();
        registry.register(Arc::new(DummyTool));

        let result = registry.execute("dummy", Value::Null).await.unwrap();
        assert!(result.success);
        assert_eq!(result.output, "dummy output");
        assert!(result.error.is_none());
    }

    #[tokio::test]
    async fn test_execute_nonexistent_tool() {
        let registry = ToolRegistry::new();

        let result = registry.execute("nonexistent", Value::Null).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_result_success() {
        let result = ToolResult::success("test output");
        assert!(result.success);
        assert_eq!(result.output, "test output");
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_result_failure() {
        let result = ToolResult::failure("test error");
        assert!(!result.success);
        assert!(result.output.is_empty());
        assert_eq!(result.error, Some("test error".to_string()));
    }
}