use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, info};
use crate::tools;
use crate::tools::invocation::InvocationError;
use crate::tools::{
AdvancedTool, TerminalTool, TerminationMessage, Tool, ToolDescription, ToolUseError,
};
/// Per-tool usage counters, each map keyed by tool name.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Stats {
    /// Invocations that returned `Ok`, incremented by `Toolbox::report_success`.
    pub success_count: HashMap<String, usize>,
    /// Invocations that returned `Err`, incremented by `Toolbox::report_error`.
    pub error_count: HashMap<String, usize>,
    /// Requested tool names that no searched registry contained,
    /// incremented by `Toolbox::report_inexistent`.
    pub inexistent_count: HashMap<String, usize>,
}
/// Registry of invocable tools together with their usage statistics.
///
/// Every field is an `Arc<RwLock<..>>`, so cloning a `Toolbox` is cheap. All clones share
/// the same registries and the same stats.
#[derive(Clone, Default)]
pub struct Toolbox {
    /// Tools that can hand back a termination message (see `termination_messages`).
    terminal_tools: Arc<RwLock<HashMap<String, Box<dyn TerminalTool>>>>,
    /// Plain tools, invoked with their input only.
    tools: Arc<RwLock<HashMap<String, Box<dyn Tool>>>>,
    /// Tools whose invocation also receives a clone of the toolbox.
    advanced_tools: Arc<RwLock<HashMap<String, Box<dyn AdvancedTool>>>>,
    /// Success / error / inexistent counters shared by all clones.
    stats: Arc<RwLock<Stats>>,
}
/// Prints only the type name: the registries live behind async locks and are not shown.
impl Debug for Toolbox {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same bytes as an empty `debug_struct("Toolbox").finish()`.
        f.write_str("Toolbox")
    }
}
impl Toolbox {
    /// Collects the termination messages currently available from terminal tools.
    ///
    /// Each registered terminal tool is asked once via `take_done`; tools that yield
    /// `None` contribute nothing. The terminal-tool read lock is held for the whole scan.
    pub async fn termination_messages(&self) -> Vec<TerminationMessage> {
        let terminal = self.terminal_tools.read().await;
        let mut collected = Vec::new();
        for candidate in terminal.values() {
            if let Some(done) = candidate.take_done().await {
                collected.push(done);
            }
        }
        collected
    }

    /// Registers a terminal tool under the name from its own description,
    /// replacing any terminal tool previously registered under that name.
    pub async fn add_terminal_tool(&mut self, tool: impl TerminalTool + 'static) {
        // Read the name before taking the write lock.
        let key = tool.description().name;
        let mut registry = self.terminal_tools.write().await;
        registry.insert(key, Box::new(tool));
    }

    /// Returns `true` when at least one terminal tool is registered.
    pub async fn has_terminal_tools(&self) -> bool {
        let registry = self.terminal_tools.read().await;
        !registry.is_empty()
    }

    /// Registers a plain tool under its description's name, replacing any previous entry.
    pub async fn add_tool(&mut self, tool: impl Tool + 'static) {
        let key = tool.description().name;
        let mut registry = self.tools.write().await;
        registry.insert(key, Box::new(tool));
    }

    /// Registers an advanced tool under its description's name, replacing any previous entry.
    pub async fn add_advanced_tool(&mut self, tool: impl AdvancedTool + 'static) {
        let key = tool.description().name;
        let mut registry = self.advanced_tools.write().await;
        registry.insert(key, Box::new(tool));
    }

    /// Returns the description of every registered tool, keyed by registry name.
    ///
    /// Registries are merged in the order terminal, plain, advanced. When a name appears
    /// in several of them, the description from the later registry wins.
    pub async fn describe(&self) -> HashMap<String, ToolDescription> {
        let mut catalog = HashMap::new();
        catalog.extend(
            self.terminal_tools
                .read()
                .await
                .iter()
                .map(|(name, tool)| (name.clone(), tool.description())),
        );
        catalog.extend(
            self.tools
                .read()
                .await
                .iter()
                .map(|(name, tool)| (name.clone(), tool.description())),
        );
        catalog.extend(
            self.advanced_tools
                .read()
                .await
                .iter()
                .map(|(name, tool)| (name.clone(), tool.description())),
        );
        catalog
    }

    /// Clears every counter.
    pub async fn reset_stats(&self) {
        let mut current = self.stats.write().await;
        *current = Stats::default();
    }

    /// Returns a snapshot copy of the current counters.
    pub async fn stats(&self) -> Stats {
        let current = self.stats.read().await;
        Stats::clone(&current)
    }

    /// Records one successful invocation of `tool_name`.
    pub async fn report_success(&self, tool_name: &str) {
        Self::bump_counter(&mut self.stats.write().await.success_count, tool_name);
    }

    /// Records one failed invocation of `tool_name`.
    pub async fn report_error(&self, tool_name: &str) {
        Self::bump_counter(&mut self.stats.write().await.error_count, tool_name);
    }

    /// Records one request for a tool name that was not found.
    pub async fn report_inexistent(&self, tool_name: &str) {
        Self::bump_counter(&mut self.stats.write().await.inexistent_count, tool_name);
    }

    /// Adds one to `tool_name`'s entry in `counts`, starting a missing entry at zero.
    fn bump_counter(counts: &mut HashMap<String, usize>, tool_name: &str) {
        *counts.entry(tool_name.to_string()).or_insert(0) += 1;
    }
}
async fn invoke_from_toolbox(
toolbox: Toolbox,
tool_name: &str,
input: serde_yaml::Value,
) -> Result<serde_yaml::Value, ToolUseError> {
if let Some(tool) = toolbox.clone().advanced_tools.read().await.get(tool_name) {
let result = tool.invoke_with_toolbox(toolbox.clone(), input).await;
if result.is_ok() {
toolbox.report_success(tool_name).await;
} else {
toolbox.report_error(tool_name).await;
}
return result;
}
{
let guard = toolbox.terminal_tools.read().await;
if let Some(tool) = guard.get(tool_name) {
let result = tool.invoke(input).await;
if result.is_ok() {
toolbox.report_success(tool_name).await;
} else {
toolbox.report_error(tool_name).await;
}
return result;
}
}
let guard = toolbox.tools.read().await;
let tool = guard.get(tool_name);
if tool.is_none() {
toolbox.report_inexistent(tool_name).await;
}
let tool = tool.ok_or(ToolUseError::ToolNotFound(tool_name.to_string()))?;
let result = tool.invoke(input).await;
if result.is_ok() {
toolbox.report_success(tool_name).await;
} else {
toolbox.report_error(tool_name).await;
}
result
}
/// Invokes a terminal or plain tool by name, without consulting advanced tools.
///
/// Terminal tools are looked up first, then plain tools. The outcome is recorded in the
/// toolbox stats: success or error for an invoked tool, or "inexistent" when neither
/// registry contains `tool_name`.
///
/// # Errors
///
/// Returns [`ToolUseError::ToolNotFound`] when no terminal or plain tool is registered
/// under `tool_name`. Otherwise returns whatever error the tool's `invoke` produced.
pub async fn invoke_simple_from_toolbox(
    toolbox: Toolbox,
    tool_name: &str,
    input: serde_yaml::Value,
) -> Result<serde_yaml::Value, ToolUseError> {
    {
        // Scoped so the terminal-tool read guard is released before `tools` is locked.
        let terminal = toolbox.terminal_tools.read().await;
        if let Some(tool) = terminal.get(tool_name) {
            let result = tool.invoke(input).await;
            if result.is_ok() {
                toolbox.report_success(tool_name).await;
            } else {
                toolbox.report_error(tool_name).await;
            }
            return result;
        }
    }
    let plain = toolbox.tools.read().await;
    let tool = match plain.get(tool_name) {
        Some(tool) => tool,
        None => {
            toolbox.report_inexistent(tool_name).await;
            // Build the error (and its String) only on the miss path, not on every call.
            return Err(ToolUseError::ToolNotFound(tool_name.to_string()));
        }
    };
    let result = tool.invoke(input).await;
    if result.is_ok() {
        toolbox.report_success(tool_name).await;
    } else {
        toolbox.report_error(tool_name).await;
    }
    result
}
/// Outcome of [`invoke_tool`]: either why no tool ran, or what the chosen tool produced.
#[derive(Clone, Debug)]
pub enum InvokeResult {
    /// `tools::invocation::find_all` failed to extract invocations from the input text.
    NoInvocationsFound { e: InvocationError },
    /// Invocations were extracted, but `tools::choose_invocation` returned an error.
    NoValidInvocationsFound {
        e: InvocationError,
        /// Number of invocations extracted from the input.
        invocation_count: usize,
    },
    /// The chosen tool ran and returned `Ok`.
    Success {
        /// Number of invocations extracted from the input.
        invocation_count: usize,
        /// Name of the tool that was invoked.
        tool_name: String,
        /// The tool's parameters rendered as YAML, or a placeholder message if that failed.
        extracted_input: String,
        /// The tool's output rendered as YAML, or a placeholder message if that failed.
        result: String,
    },
    /// The chosen tool was missing or its invocation returned `Err`.
    Error {
        /// Number of invocations extracted from the input.
        invocation_count: usize,
        /// Name of the tool that was requested.
        tool_name: String,
        /// The tool's parameters rendered as YAML, or a placeholder message if that failed.
        extracted_input: String,
        /// Error produced by the lookup or by the tool itself.
        e: ToolUseError,
    },
}
/// Extracts a tool invocation from `data`, runs it against `toolbox`, and summarizes the outcome.
///
/// Steps:
/// 1. `tools::invocation::find_all` scans `data`. On failure the result is
///    [`InvokeResult::NoInvocationsFound`].
/// 2. `tools::choose_invocation` picks one invocation. On failure the result is
///    [`InvokeResult::NoValidInvocationsFound`].
/// 3. The chosen tool is invoked through the toolbox. The result is
///    [`InvokeResult::Success`] or [`InvokeResult::Error`].
///
/// Parameters and output are rendered as YAML strings. If rendering fails, a
/// human-readable placeholder is stored instead of aborting.
#[tracing::instrument(skip(toolbox, data))]
pub async fn invoke_tool(toolbox: Toolbox, data: &str) -> InvokeResult {
    let tool_invocations = match tools::invocation::find_all(data) {
        Ok(invocations) => invocations,
        Err(e) => return InvokeResult::NoInvocationsFound { e },
    };
    // Counted up front because `choose_invocation` takes ownership of the collection.
    let invocation_count = tool_invocations.invocations.len();
    info!(
        "{} YAML blocks and {} Tool invocations found",
        tool_invocations.yaml_block_count, invocation_count
    );
    let invocation = match tools::choose_invocation(tool_invocations).await {
        Ok(invocation) => invocation,
        Err(e) => {
            return InvokeResult::NoValidInvocationsFound {
                e,
                invocation_count,
            }
        }
    };
    debug!(tool_name = invocation.tool_name, "Invocation found");
    // Move both fields out; `invocation` is not needed afterwards.
    let tool_name = invocation.tool_name;
    let input = invocation.parameters;
    let extracted_input = serde_yaml::to_string(&input)
        .unwrap_or_else(|_| format!("Failed to serialize input for tool {}", tool_name));
    // `input` is not used after this point, so hand it over by value instead of deep-cloning
    // the YAML tree.
    let result = invoke_from_toolbox(toolbox, &tool_name, input).await;
    match result {
        Ok(output) => {
            let result = serde_yaml::to_string(&output)
                .unwrap_or_else(|_| format!("Failed to serialize output for tool {}", tool_name));
            InvokeResult::Success {
                tool_name,
                extracted_input,
                invocation_count,
                result,
            }
        }
        Err(e) => InvokeResult::Error {
            tool_name,
            extracted_input,
            invocation_count,
            e,
        },
    }
}