use crate::agent::context::AgentContext;
use crate::agent::error::{AgentError, AgentResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// A named capability that an agent can invoke.
///
/// Implementors must provide [`Tool::name`], [`Tool::description`],
/// [`Tool::parameters_schema`] and the async [`Tool::execute`] body; every
/// other method has a default that may be overridden.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier under which this tool is exposed.
    fn name(&self) -> &str;

    /// Human-readable explanation of what the tool does.
    fn description(&self) -> &str;

    /// JSON value describing the parameters the tool accepts.
    fn parameters_schema(&self) -> serde_json::Value;

    /// Runs the tool against `input`.
    ///
    /// Failures are reported inside the returned [`ToolResult`] rather than
    /// through a `Result` error.
    async fn execute(&self, input: ToolInput, ctx: &AgentContext) -> ToolResult;

    /// Descriptive metadata for this tool. Defaults to an empty [`ToolMetadata`].
    fn metadata(&self) -> ToolMetadata {
        Default::default()
    }

    /// Pre-execution check of `input`.
    ///
    /// The default `ToolRegistry::execute` calls this before `execute` and
    /// propagates any error. The default implementation accepts every input.
    fn validate_input(&self, input: &ToolInput) -> AgentResult<()> {
        // Not inspected by the default; binding to `_` keeps the parameter
        // name readable while silencing the unused-variable lint.
        let _ = input;
        Ok(())
    }

    /// Whether this tool asks for confirmation before running. Returns `false` by default.
    ///
    /// NOTE(review): the default `ToolRegistry::execute` does not consult this flag.
    fn requires_confirmation(&self) -> bool {
        false
    }

    /// Packages name, description and parameter schema into an [`LLMTool`].
    fn to_llm_tool(&self) -> LLMTool {
        let name = self.name().to_owned();
        let description = self.description().to_owned();
        let parameters = self.parameters_schema();
        LLMTool {
            name,
            description,
            parameters,
        }
    }
}
/// Arguments handed to a [`Tool`] invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInput {
    /// Structured arguments. For inputs built with [`ToolInput::from_raw`] this
    /// is a JSON string containing the raw text.
    pub arguments: serde_json::Value,
    /// The original raw text, when the input was built from one.
    pub raw_input: Option<String>,
}

impl ToolInput {
    /// Wraps already-structured JSON arguments; `raw_input` is left `None`.
    pub fn from_json(arguments: serde_json::Value) -> Self {
        Self {
            arguments,
            raw_input: None,
        }
    }

    /// Builds an input from raw text.
    ///
    /// The text is kept in `raw_input` and also stored as a JSON string in
    /// `arguments`. Because `arguments` is then not an object, the keyed
    /// accessors (`get`, `get_str`, ...) all return `None` for such inputs.
    pub fn from_raw(raw: impl Into<String>) -> Self {
        let raw = raw.into();
        Self {
            arguments: serde_json::Value::String(raw.clone()),
            raw_input: Some(raw),
        }
    }

    /// Deserializes the value stored under `key` into `T`.
    ///
    /// Returns `None` when `arguments` has no such key (including when it is
    /// not an object) or when the value cannot be deserialized as `T`.
    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
        let value = self.arguments.get(key)?;
        // `&serde_json::Value` implements `Deserializer`, so decode straight
        // from the borrow instead of cloning the whole subtree first, as
        // `serde_json::from_value` (which takes `Value` by value) would require.
        T::deserialize(value).ok()
    }

    /// Returns the value under `key` if it is a JSON string.
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.arguments.get(key).and_then(|v| v.as_str())
    }

    /// Returns the value under `key` as `f64` if it is a JSON number.
    pub fn get_number(&self, key: &str) -> Option<f64> {
        self.arguments.get(key).and_then(|v| v.as_f64())
    }

    /// Returns the value under `key` if it is a JSON boolean.
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        self.arguments.get(key).and_then(|v| v.as_bool())
    }
}

/// Structured JSON arguments; equivalent to [`ToolInput::from_json`].
impl From<serde_json::Value> for ToolInput {
    fn from(v: serde_json::Value) -> Self {
        Self::from_json(v)
    }
}

/// Raw text input; equivalent to [`ToolInput::from_raw`].
impl From<String> for ToolInput {
    fn from(s: String) -> Self {
        Self::from_raw(s)
    }
}

/// Raw text input; equivalent to [`ToolInput::from_raw`].
impl From<&str> for ToolInput {
    fn from(s: &str) -> Self {
        Self::from_raw(s)
    }
}
/// Outcome of a [`Tool`] execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
    /// `true` for results built with [`ToolResult::success`] / [`ToolResult::success_text`],
    /// `false` for [`ToolResult::failure`].
    pub success: bool,
    /// Output payload; `Null` for results built with [`ToolResult::failure`].
    pub output: serde_json::Value,
    /// Error message; `Some` for results built with [`ToolResult::failure`].
    pub error: Option<String>,
    /// String key/value annotations added with [`ToolResult::with_metadata`].
    pub metadata: HashMap<String, String>,
}

impl ToolResult {
    /// A successful result carrying `output` and no metadata.
    pub fn success(output: serde_json::Value) -> Self {
        Self {
            success: true,
            output,
            error: None,
            metadata: HashMap::new(),
        }
    }

    /// A successful result whose output is the JSON string `text`.
    pub fn success_text(text: impl Into<String>) -> Self {
        Self::success(serde_json::Value::String(text.into()))
    }

    /// A failed result with `Null` output and the given error message.
    pub fn failure(error: impl Into<String>) -> Self {
        Self {
            success: false,
            output: serde_json::Value::Null,
            error: Some(error.into()),
            metadata: HashMap::new(),
        }
    }

    /// Adds (or overwrites) the metadata entry `key` and returns the updated result.
    ///
    /// This consumes `self`; discarding the return value discards the result,
    /// hence `#[must_use]`.
    #[must_use]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Borrows the output if it is a JSON string.
    pub fn as_text(&self) -> Option<&str> {
        self.output.as_str()
    }

    /// Renders the result as plain text.
    ///
    /// Successful string outputs are returned verbatim (without JSON quotes),
    /// and other successful outputs are rendered as compact JSON. Failures
    /// render as `"Error: <message>"`, falling back to `"Unknown error"` when
    /// no message is set.
    pub fn to_string_output(&self) -> String {
        if !self.success {
            let message = self.error.as_deref().unwrap_or("Unknown error");
            return format!("Error: {}", message);
        }
        match &self.output {
            serde_json::Value::String(s) => s.clone(),
            other => other.to_string(),
        }
    }
}
/// Descriptive labels and flags attached to a [`Tool`].
///
/// This is plain data. Nothing in this module acts on these fields beyond
/// storing and serializing them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ToolMetadata {
    /// Optional grouping label, set via [`ToolMetadata::with_category`].
    pub category: Option<String>,
    /// Free-form labels in insertion order, appended via [`ToolMetadata::with_tag`].
    pub tags: Vec<String>,
    /// Set via [`ToolMetadata::dangerous`].
    pub is_dangerous: bool,
    /// Set via [`ToolMetadata::needs_network`].
    pub requires_network: bool,
    /// Set via [`ToolMetadata::needs_filesystem`].
    pub requires_filesystem: bool,
    /// Extra implementation-specific values keyed by name (no builder method sets these).
    pub custom: HashMap<String, serde_json::Value>,
}

// Every builder method consumes `self` and returns the updated value, so a
// call whose result is ignored silently drops the configuration. `#[must_use]`
// turns that mistake into a compiler warning.
impl ToolMetadata {
    /// Empty metadata: no category, no tags, all flags `false`, no custom values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets (replacing any previous) category.
    #[must_use]
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
        self.category = Some(category.into());
        self
    }

    /// Appends a tag. Duplicates are not removed.
    #[must_use]
    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Sets `is_dangerous` to `true`.
    #[must_use]
    pub fn dangerous(mut self) -> Self {
        self.is_dangerous = true;
        self
    }

    /// Sets `requires_network` to `true`.
    #[must_use]
    pub fn needs_network(mut self) -> Self {
        self.requires_network = true;
        self
    }

    /// Sets `requires_filesystem` to `true`.
    #[must_use]
    pub fn needs_filesystem(mut self) -> Self {
        self.requires_filesystem = true;
        self
    }
}
/// Owned, serializable snapshot of a tool's public description.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDescriptor {
    pub name: String,
    pub description: String,
    pub parameters_schema: serde_json::Value,
    pub metadata: ToolMetadata,
}

impl ToolDescriptor {
    /// Captures `tool`'s name, description, parameter schema and metadata,
    /// querying them in that order.
    pub fn from_tool(tool: &dyn Tool) -> Self {
        let name = String::from(tool.name());
        let description = String::from(tool.description());
        let parameters_schema = tool.parameters_schema();
        let metadata = tool.metadata();
        Self {
            name,
            description,
            parameters_schema,
            metadata,
        }
    }
}
/// Name, description and JSON parameter schema of a tool, as produced by
/// `Tool::to_llm_tool` and `ToolRegistry::to_llm_tools`.
///
/// NOTE(review): presumably serialized into a language model's tool list;
/// confirm with callers. With serde's derived `Serialize`, fields are emitted
/// in declaration order, so keep this order stable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LLMTool {
// Tool identifier (copied from `Tool::name`).
pub name: String,
// Human-readable description (copied from `Tool::description`).
pub description: String,
// Parameter schema (copied from `Tool::parameters_schema`).
pub parameters: serde_json::Value,
}
/// A mutable, name-addressed collection of [`Tool`]s.
#[async_trait]
pub trait ToolRegistry: Send + Sync {
    /// Adds `tool` to the registry. Conflict and error semantics are implementation-defined.
    fn register(&mut self, tool: Arc<dyn Tool>) -> AgentResult<()>;

    /// Registers each tool in order, stopping at the first error.
    ///
    /// Tools registered before the failing one remain registered.
    fn register_all(&mut self, tools: Vec<Arc<dyn Tool>>) -> AgentResult<()> {
        for tool in tools {
            self.register(tool)?;
        }
        Ok(())
    }

    /// Looks up a tool by name.
    fn get(&self, name: &str) -> Option<Arc<dyn Tool>>;

    /// Removes the tool called `name`. The returned `bool` is implementation-defined
    /// (presumably whether a tool was actually removed; confirm with implementors).
    fn unregister(&mut self, name: &str) -> AgentResult<bool>;

    /// Owned descriptors for every registered tool.
    fn list(&self) -> Vec<ToolDescriptor>;

    /// Names of every registered tool.
    fn list_names(&self) -> Vec<String>;

    /// Whether a tool called `name` is registered.
    fn contains(&self, name: &str) -> bool;

    /// Number of registered tools.
    fn count(&self) -> usize;

    /// Looks up `name`, validates `input` with `Tool::validate_input`, then runs the tool.
    ///
    /// # Errors
    /// Returns `AgentError::ToolNotFound` if no tool is registered under `name`,
    /// or the error produced by `validate_input`. Failures reported by the tool
    /// itself come back as `Ok(ToolResult)`.
    ///
    /// NOTE(review): this default does not consult `Tool::requires_confirmation`.
    async fn execute(
        &self,
        name: &str,
        input: ToolInput,
        ctx: &AgentContext,
    ) -> AgentResult<ToolResult> {
        let tool = self
            .get(name)
            .ok_or_else(|| AgentError::ToolNotFound(name.to_string()))?;
        tool.validate_input(&input)?;
        Ok(tool.execute(input, ctx).await)
    }

    /// Converts every registered tool's descriptor into an [`LLMTool`].
    fn to_llm_tools(&self) -> Vec<LLMTool> {
        // `list()` already returns owned descriptors, so move their fields into
        // the `LLMTool`s instead of cloning every string and schema.
        self.list()
            .into_iter()
            .map(|descriptor| LLMTool {
                name: descriptor.name,
                description: descriptor.description,
                parameters: descriptor.parameters_schema,
            })
            .collect()
    }
}
/// Unit tests for the value types in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_input_from_json() {
        let input = ToolInput::from_json(serde_json::json!({
            "name": "test",
            "count": 42
        }));
        assert_eq!(input.get_str("name"), Some("test"));
        assert_eq!(input.get_number("count"), Some(42.0));
        assert!(input.raw_input.is_none());
    }

    #[test]
    fn test_tool_input_from_raw() {
        let input = ToolInput::from("hello");
        assert_eq!(input.raw_input.as_deref(), Some("hello"));
        assert_eq!(
            input.arguments,
            serde_json::Value::String("hello".to_string())
        );
        // Raw input is stored as a JSON string, not an object, so keyed lookups miss.
        assert_eq!(input.get_str("hello"), None);
    }

    #[test]
    fn test_tool_input_typed_get() {
        let input = ToolInput::from(serde_json::json!({
            "count": 42,
            "flag": true,
            "name": "x"
        }));
        assert_eq!(input.get::<u32>("count"), Some(42));
        assert_eq!(input.get::<String>("name"), Some("x".to_string()));
        assert_eq!(input.get_bool("flag"), Some(true));
        // Type mismatches and missing keys both yield None.
        assert_eq!(input.get::<u32>("name"), None);
        assert_eq!(input.get::<u32>("missing"), None);
    }

    #[test]
    fn test_tool_result() {
        let success = ToolResult::success_text("OK");
        assert!(success.success);
        assert_eq!(success.as_text(), Some("OK"));
        let failure = ToolResult::failure("Something went wrong");
        assert!(!failure.success);
        assert_eq!(failure.error.as_deref(), Some("Something went wrong"));
        assert_eq!(failure.output, serde_json::Value::Null);
    }

    #[test]
    fn test_tool_result_string_output() {
        assert_eq!(ToolResult::success_text("OK").to_string_output(), "OK");
        assert_eq!(
            ToolResult::success(serde_json::json!({ "a": 1 })).to_string_output(),
            r#"{"a":1}"#
        );
        assert_eq!(ToolResult::failure("boom").to_string_output(), "Error: boom");
        let anonymous = ToolResult {
            success: false,
            output: serde_json::Value::Null,
            error: None,
            metadata: std::collections::HashMap::new(),
        };
        assert_eq!(anonymous.to_string_output(), "Error: Unknown error");
    }

    #[test]
    fn test_tool_result_metadata() {
        let result = ToolResult::success_text("OK").with_metadata("k", "v");
        assert_eq!(result.metadata.get("k").map(String::as_str), Some("v"));
    }

    #[test]
    fn test_tool_metadata_builder() {
        let meta = ToolMetadata::new()
            .with_category("io")
            .with_tag("fs")
            .with_tag("fs")
            .dangerous()
            .needs_filesystem();
        assert_eq!(meta.category.as_deref(), Some("io"));
        assert_eq!(meta.tags, vec!["fs".to_string(), "fs".to_string()]);
        assert!(meta.is_dangerous);
        assert!(meta.requires_filesystem);
        assert!(!meta.requires_network);
        assert!(meta.custom.is_empty());
    }
}