use crate::error::AgentError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A named, described operation that can be stored in a `ToolRegistry` and
/// executed with JSON arguments.
///
/// Implementors must be `Send + Sync`.
pub trait Tool: Send + Sync {
/// Returns the tool's name. `ToolRegistry::register` uses this value as the
/// lookup key, and `ToolRegistry::tools_prompt` prints it in the tool listing.
fn name(&self) -> &str;
/// Returns a description of the tool; it is emitted under the `"description"`
/// key of the tool's entry in `ToolRegistry::tools_prompt`.
fn description(&self) -> &str;
/// Returns a JSON value describing the tool's parameters; it is emitted under
/// the `"parameters"` key of the tool's entry in `ToolRegistry::tools_prompt`.
fn parameters_schema(&self) -> serde_json::Value;
/// Runs the tool with the given JSON arguments, returning a `ToolResult` or an
/// `AgentError`.
fn execute(&self, args: &serde_json::Value) -> Result<ToolResult, AgentError>;
/// Reports whether the tool requires permission. Defaults to `true`.
///
/// NOTE(review): nothing in this file consumes this flag; how callers act on
/// it should be confirmed at the call sites.
fn requires_permission(&self) -> bool {
true
}
/// Reports whether a call with the given arguments is dangerous. Defaults to
/// `false` and ignores the arguments.
///
/// NOTE(review): nothing in this file consumes this flag; how callers act on
/// it should be confirmed at the call sites.
fn is_dangerous(&self, _args: &serde_json::Value) -> bool {
false
}
}
/// A request to run a tool: the tool's name plus its JSON arguments.
///
/// `parse_tool_calls` deserializes this from the JSON body of a `<tool_call>`
/// block, and `ToolRegistry::execute` dispatches it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Name of the tool to run; looked up in the registry by `ToolRegistry::execute`.
pub name: String,
/// JSON arguments passed unchanged to `Tool::execute`.
pub arguments: serde_json::Value,
}
/// The outcome of running a tool: a success flag and textual output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
/// `true` for results built with `ToolResult::ok`, `false` for `ToolResult::err`.
pub success: bool,
/// Text carried by the result.
pub output: String,
}
impl ToolResult {
pub fn ok(output: impl Into<String>) -> Self {
Self {
success: true,
output: output.into(),
}
}
pub fn err(output: impl Into<String>) -> Self {
Self {
success: false,
output: output.into(),
}
}
}
/// A collection of boxed `Tool` implementations looked up by name.
pub struct ToolRegistry {
// Registered tools, keyed by the name each tool reported when it was registered.
tools: HashMap<String, Box<dyn Tool>>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            tools: HashMap::new(),
        }
    }

    /// Registers `tool` under the name it reports via [`Tool::name`].
    ///
    /// If a tool with the same name is already registered, it is replaced.
    pub fn register(&mut self, tool: Box<dyn Tool>) {
        let name = tool.name().to_string();
        self.tools.insert(name, tool);
    }

    /// Looks up a registered tool by name, returning `None` if there is none.
    pub fn get(&self, name: &str) -> Option<&dyn Tool> {
        self.tools.get(name).map(|t| t.as_ref())
    }

    /// Executes the tool named by `call` with the call's arguments.
    ///
    /// # Errors
    ///
    /// Returns [`AgentError::ToolNotFound`] when no tool is registered under
    /// `call.name`; otherwise returns whatever the tool's `execute` returns.
    pub fn execute(&self, call: &ToolCall) -> Result<ToolResult, AgentError> {
        let tool = self
            .get(&call.name)
            .ok_or_else(|| AgentError::ToolNotFound(call.name.clone()))?;
        tool.execute(&call.arguments)
    }

    /// Iterates over the registered tools.
    ///
    /// The order is that of the underlying `HashMap` and is therefore arbitrary.
    pub fn iter(&self) -> impl Iterator<Item = &dyn Tool> {
        self.tools.values().map(|t| t.as_ref())
    }

    /// Renders the prompt section that describes the tool-call format and lists
    /// every registered tool with its JSON schema.
    ///
    /// Returns an empty string when no tools are registered. Tools are listed
    /// in ascending name order so that the same set of tools always produces
    /// the same prompt text.
    pub fn tools_prompt(&self) -> String {
        if self.is_empty() {
            return String::new();
        }

        // `HashMap` iteration order is arbitrary, so sort by name to keep the
        // rendered prompt stable across runs.
        let mut tools: Vec<_> = self.iter().collect();
        tools.sort_unstable_by(|a, b| a.name().cmp(b.name()));

        let mut lines = vec![
            "# Tools\n".to_string(),
            "You have access to the following tools. To use a tool, output a tool call in this exact format:\n".to_string(),
            "<tool_call>".to_string(),
            r#"{"name": "<tool_name>", "arguments": {<json_args>}}"#.to_string(),
            "</tool_call>\n".to_string(),
            "Available tools:\n".to_string(),
        ];

        lines.extend(tools.iter().map(|tool| {
            let schema = serde_json::json!({
                "name": tool.name(),
                "description": tool.description(),
                "parameters": tool.parameters_schema(),
            });
            format!(
                "- {}\n```json\n{}\n```\n",
                tool.name(),
                serde_json::to_string_pretty(&schema).unwrap_or_default()
            )
        }));

        lines.push("When you want to use a tool, output ONLY the <tool_call> block. You may use multiple tool calls in a single response. After each tool call, wait for the tool result before continuing.".to_string());
        lines.join("\n")
    }

    /// Returns the number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` when no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
/// Splits model output into parsed tool calls and the plain text around them.
///
/// The text is scanned left to right for `<tool_call>…</tool_call>` blocks.
/// The trimmed content of each block is deserialized as a [`ToolCall`]; text
/// outside the blocks is trimmed and kept when it is not empty.
///
/// # Returns
///
/// A pair `(calls, text_parts)`:
///
/// - `calls` holds the successfully parsed calls in order of appearance.
/// - `text_parts` holds, in order of appearance, the non-empty text segments,
///   a `[Failed to parse tool call: …]` marker for every block whose content
///   could not be deserialized, and — when an opening tag has no matching
///   closing tag — the trimmed remainder of the input starting at that tag.
pub fn parse_tool_calls(text: &str) -> (Vec<ToolCall>, Vec<String>) {
    const OPEN_TAG: &str = "<tool_call>";
    const CLOSE_TAG: &str = "</tool_call>";

    let mut calls = Vec::new();
    let mut text_parts = Vec::new();
    let mut remaining = text;

    while let Some(start) = remaining.find(OPEN_TAG) {
        // `start` comes from `find`, so it is always on a char boundary.
        let (before, tagged) = remaining.split_at(start);
        let before = before.trim();
        if !before.is_empty() {
            text_parts.push(before.to_string());
        }

        let after_tag = &tagged[OPEN_TAG.len()..];
        let end = match after_tag.find(CLOSE_TAG) {
            Some(end) => end,
            None => {
                // Unterminated block: the text before the tag was already
                // recorded above, so keep only the tail from the tag onward.
                // Pushing all of `remaining` would duplicate that leading text.
                text_parts.push(tagged.trim().to_string());
                return (calls, text_parts);
            }
        };

        match serde_json::from_str::<ToolCall>(after_tag[..end].trim()) {
            Ok(call) => calls.push(call),
            Err(e) => text_parts.push(format!("[Failed to parse tool call: {}]", e)),
        }
        remaining = &after_tag[end + CLOSE_TAG.len()..];
    }

    // Whatever follows the last complete block is ordinary text.
    let trailing = remaining.trim();
    if !trailing.is_empty() {
        text_parts.push(trailing.to_string());
    }

    (calls, text_parts)
}
// Unit tests for `parse_tool_calls`.
#[cfg(test)]
mod tests {
use super::*;
// One well-formed block preceded by prose: expects one parsed call and one
// text part. The newline after the closing tag is whitespace-only, so it does
// not produce a text part.
#[test]
fn test_parse_single_tool_call() {
let text = r#"Let me check that for you.
<tool_call>
{"name": "bash", "arguments": {"command": "ls -la"}}
</tool_call>
"#;
let (calls, text_parts) = parse_tool_calls(text);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "bash");
assert_eq!(text_parts.len(), 1);
assert!(text_parts[0].contains("Let me check"));
}
// Two well-formed blocks with prose before, between, and after them: expects
// two parsed calls and three text parts.
#[test]
fn test_parse_multiple_tool_calls() {
let text = r#"I'll read both files.
<tool_call>
{"name": "read", "arguments": {"path": "a.txt"}}
</tool_call>
And now the second one:
<tool_call>
{"name": "read", "arguments": {"path": "b.txt"}}
</tool_call>
Done."#;
let (calls, text_parts) = parse_tool_calls(text);
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].name, "read");
assert_eq!(calls[1].name, "read");
assert_eq!(text_parts.len(), 3);
}
// Input without any `<tool_call>` tag: expects no calls and the whole input as
// a single text part.
#[test]
fn test_parse_no_tool_calls() {
let text = "Just a normal response with no tools.";
let (calls, text_parts) = parse_tool_calls(text);
assert_eq!(calls.len(), 0);
assert_eq!(text_parts.len(), 1);
}
}