use crate::config::Config;
use crate::log_debug;
use crate::session::chat::session::ChatSession;
use crate::session::chat::tool_error_tracker::ToolErrorTracker;
use anyhow::Result;
use colored::Colorize;
/// Executes batches of MCP tool calls for a chat session and tracks tool failures.
pub struct ToolProcessor {
    /// Records successes and errors keyed by tool name; `execute_tool_calls` consults it to
    /// decide whether too many errors have occurred and execution should stop early.
    pub error_tracker: ToolErrorTracker,
}
impl ToolProcessor {
    /// Creates a processor whose error tracker is built with `ToolErrorTracker::new(3)`.
    ///
    /// NOTE(review): `3` is presumably the consecutive-error limit — confirm against
    /// `ToolErrorTracker::new`.
    pub fn new() -> Self {
        Self {
            error_tracker: ToolErrorTracker::new(3),
        }
    }

    /// Runs every call in `tool_calls` concurrently and records the outcomes.
    ///
    /// Each call is spawned as its own tokio task running `crate::mcp::execute_tool_call`, and the
    /// tasks are then awaited in submission order. For each awaited call (success or failure), a
    /// summary line is appended to the returned vector. A `"tool"` message carrying the call's
    /// `tool_id` and `tool_name` is also pushed onto `chat_session.session.messages`.
    ///
    /// # Returns
    /// `(results, should_continue)`:
    /// - `should_continue` is `true` only when every task was processed and at least one result
    ///   was produced.
    /// - It is `false` when `operation_cancelled` is observed as `true`.
    /// - It is also `false` when `error_tracker` reports that its error threshold was hit; the
    ///   remaining, not-yet-awaited tasks are then aborted.
    ///
    /// # Errors
    /// The visible code never returns `Err`; tool and task failures are reported as result lines.
    pub async fn execute_tool_calls(
        &mut self,
        tool_calls: Vec<crate::mcp::McpToolCall>,
        chat_session: &mut ChatSession,
        config: &Config,
        operation_cancelled: tokio::sync::watch::Receiver<bool>,
    ) -> Result<(Vec<String>, bool)> {
        let mut tool_results = Vec::with_capacity(tool_calls.len());

        // Spawn every call up front so the tools run concurrently.
        let mut tool_tasks = Vec::with_capacity(tool_calls.len());
        for tool_call in &tool_calls {
            if *operation_cancelled.borrow() {
                // Tasks spawned so far are detached; each was handed a clone of the receiver.
                return Ok((tool_results, false));
            }
            let owned_call = tool_call.clone();
            let task_cancel_rx = operation_cancelled.clone();
            let task_config = config.clone();
            tool_tasks.push(tokio::spawn(async move {
                crate::mcp::execute_tool_call(&owned_call, &task_config, Some(task_cancel_rx))
                    .await
            }));
        }

        // Await in submission order. `pending` is kept so unawaited tasks can be aborted.
        let mut pending = tool_calls.iter().zip(tool_tasks);
        while let Some((tool_call, task)) = pending.next() {
            if *operation_cancelled.borrow() {
                // Remaining tasks are detached; they hold clones of the same cancel receiver.
                return Ok((tool_results, false));
            }

            let error_msg = match task.await {
                Ok(Ok((tool_result, _duration_ms))) => {
                    self.error_tracker.record_success(&tool_call.tool_name);
                    // Plain strings are used verbatim; other JSON values are serialized.
                    let result_content = match &tool_result.result {
                        serde_json::Value::String(s) => s.clone(),
                        other => other.to_string(),
                    };
                    log_debug!("Tool {} executed successfully", tool_call.tool_name);
                    tool_results.push(format!(
                        "**{}**: {}",
                        tool_call.tool_name,
                        result_content.trim()
                    ));
                    chat_session
                        .session
                        .messages
                        .push(Self::tool_message(tool_call, result_content));
                    continue;
                }
                Ok(Err(e)) => format!("Error executing {}: {}", tool_call.tool_name, e),
                // The spawned task itself failed (e.g. panicked or was cancelled).
                Err(e) => format!("Task error for {}: {}", tool_call.tool_name, e),
            };

            // Shared failure path for tool errors and task (join) errors.
            let has_hit_threshold = self.error_tracker.record_error(&tool_call.tool_name);
            log_debug!("{}", error_msg);
            if has_hit_threshold {
                println!(
                    "{}",
                    "Too many consecutive tool errors. Stopping tool execution.".red()
                );
                // Nothing has signalled these tasks to stop, so dropping the handles would
                // leave them running detached; abort them explicitly instead.
                for (_, remaining) in pending.by_ref() {
                    remaining.abort();
                }
                return Ok((tool_results, false));
            }
            tool_results.push(error_msg.clone());
            // Record every failure in the session, including join errors, so each tool call
            // gets a response message.
            chat_session
                .session
                .messages
                .push(Self::tool_message(tool_call, error_msg));
        }

        let should_continue = !tool_results.is_empty();
        Ok((tool_results, should_continue))
    }

    /// Builds an uncached `"tool"` session message answering `tool_call`.
    ///
    /// The message is timestamped with the current Unix time in seconds, or 0 if the clock is
    /// before the epoch.
    fn tool_message(
        tool_call: &crate::mcp::McpToolCall,
        content: String,
    ) -> crate::session::Message {
        crate::session::Message {
            role: "tool".to_string(),
            content,
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            cached: false,
            tool_call_id: Some(tool_call.tool_id.clone()),
            name: Some(tool_call.tool_name.clone()),
            ..Default::default()
        }
    }
}
impl Default for ToolProcessor {
fn default() -> Self {
Self::new()
}
}