use crate::tools::{ToolRegistry, parse_tool_calls, format_tool_output};
use anyhow::Result;
use std::collections::HashMap;
/// Settings that control an [`Agent`]'s behaviour.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Limit on tool-call rounds checked by `Agent::process_tool_calls`; once the
    /// limit is reached, further tool calls are skipped until the agent is reset.
    /// The `Default` impl sets this to 40.
    pub max_iterations: usize,
    /// Base system prompt; `Agent::get_system_prompt` appends the tool
    /// descriptions to it.
    pub system_prompt: String,
}
impl Default for AgentConfig {
    /// Default configuration: an iteration limit of 40 and the crate-wide
    /// `config::SYSTEM_PROMPT` as the base prompt.
    fn default() -> Self {
        let system_prompt = crate::config::SYSTEM_PROMPT.to_string();
        Self {
            system_prompt,
            max_iterations: 40,
        }
    }
}
/// Runs the tool calls embedded in model responses, subject to an iteration budget.
pub struct Agent {
    /// Iteration budget and base system prompt.
    config: AgentConfig,
    /// Tool-call rounds counted since construction or the last `reset`.
    iteration_count: usize,
    /// Registry through which tools are listed and executed.
    tools: ToolRegistry,
}
impl Agent {
    /// Creates an agent from `config` with a fresh `ToolRegistry::new()` and a
    /// zeroed iteration counter.
    pub fn new(config: AgentConfig) -> Self {
        Self {
            config,
            tools: ToolRegistry::new(),
            iteration_count: 0,
        }
    }

    /// Renders the tool section appended to the system prompt.
    ///
    /// For every registered tool this lists its name (in bold), its description
    /// and its pretty-printed input schema. It then appends instructions that
    /// show the `<tool_use>` JSON call format.
    pub fn get_tool_descriptions(&self) -> String {
        let mut descriptions = String::from("\n\nYou have access to the following tools:\n\n");
        for tool in self.tools.all_tools() {
            // Don't panic while building a prompt if the schema fails to
            // serialize; fall back to an empty string, as `process_tool_calls`
            // does when logging tool params.
            let schema = serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default();
            descriptions.push_str(&format!(
                "**{}**\n{}\nSchema: {}\n\n",
                tool.name, tool.description, schema
            ));
        }
        descriptions.push_str(r#"To use a tool, wrap your tool call in <tool_use> tags:
<tool_use>
{"tool": "bash", "params": {"command": "ls -la"}}
</tool_use>
Always think step by step and use tools to explore, understand, and solve problems."#);
        descriptions
    }

    /// Parses `response` for tool calls and executes them in order.
    ///
    /// Returns one `(tool_name, output)` pair per call. A failed execution
    /// contributes its error message in place of the output rather than
    /// aborting the batch.
    ///
    /// - A response with no tool calls returns an empty vector and does not
    ///   consume an iteration.
    /// - Once `config.max_iterations` rounds have executed, later rounds print
    ///   a warning and return an empty vector until [`Agent::reset`] is called.
    ///
    /// # Errors
    /// The signature is fallible, but this body itself always returns `Ok`.
    pub async fn process_tool_calls(&mut self, response: &str) -> Result<Vec<(String, String)>> {
        let tool_calls = parse_tool_calls(response);
        if tool_calls.is_empty() {
            return Ok(Vec::new());
        }

        // Check the budget before counting this round so that exactly
        // `max_iterations` rounds are allowed to execute.
        if self.iteration_count >= self.config.max_iterations {
            println!("⚠️ Maximum iterations ({}) reached. Stopping tool execution.", self.config.max_iterations);
            return Ok(Vec::new());
        }
        self.iteration_count += 1;

        let mut results = Vec::with_capacity(tool_calls.len());
        for tool_call in tool_calls {
            println!(
                "Executing {}: {}",
                Self::colored_tool_name(&tool_call.tool),
                serde_json::to_string(&tool_call.params).unwrap_or_default()
            );
            match self.tools.execute(&tool_call.tool, tool_call.params).await {
                Ok(output) => {
                    let formatted = format_tool_output(&tool_call.tool, &output);
                    println!("{formatted}");
                    results.push((tool_call.tool, output));
                }
                Err(e) => {
                    let error_msg = format!("Error executing {}: {}", tool_call.tool, e);
                    println!("❌ {error_msg}");
                    results.push((tool_call.tool, error_msg));
                }
            }
        }
        Ok(results)
    }

    /// Wraps a tool name in an ANSI colour escape for console logging.
    ///
    /// `read_file` is shown in blue, `bash`/`exec` in xterm-256 colour 208
    /// (orange), and any other name is left uncoloured.
    fn colored_tool_name(tool: &str) -> String {
        match tool {
            "read_file" => format!("\x1b[34m{tool}\x1b[0m"),
            "bash" | "exec" => format!("\x1b[38;5;208m{tool}\x1b[0m"),
            _ => tool.to_string(),
        }
    }

    /// Zeroes the iteration counter so tool execution can resume after the
    /// `max_iterations` limit was hit.
    pub fn reset(&mut self) {
        self.iteration_count = 0;
    }

    /// Returns the configured system prompt followed by the tool descriptions.
    pub fn get_system_prompt(&self) -> String {
        format!("{}{}", self.config.system_prompt, self.get_tool_descriptions())
    }
}
/// An [`Agent`] paired with a free-form string key/value context map.
pub struct AgentSession {
    /// Context entries added via `add_context`; cleared on `reset`.
    context: HashMap<String, String>,
    /// The agent driven by this session.
    agent: Agent,
}
impl AgentSession {
    /// Builds a session around a new [`Agent`] created from `config`, with an
    /// empty context map.
    pub fn new(config: AgentConfig) -> Self {
        let agent = Agent::new(config);
        Self {
            context: HashMap::default(),
            agent,
        }
    }

    /// Stores `value` under `key`, replacing any earlier value for that key.
    pub fn add_context(&mut self, key: String, value: String) {
        self.context.insert(key, value);
    }

    /// Mutable access to the wrapped agent.
    pub fn agent(&mut self) -> &mut Agent {
        let Self { agent, .. } = self;
        agent
    }

    /// Read-only view of the context entries.
    pub fn context(&self) -> &HashMap<String, String> {
        &self.context
    }

    /// Clears all context entries and resets the agent's iteration counter.
    pub fn reset(&mut self) {
        self.context.clear();
        self.agent.reset();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `parse_tool_calls` one `<tool_use>` JSON call plus two `# <tool> <arg>`
    /// shorthand lines. It expects three calls in total:
    /// - two routed to `bash`: the JSON `ls -la` call and, per these
    ///   assertions, the `exec pwd` line;
    /// - one routed to `editor`: presumably the `read_file` line.
    #[test]
    fn test_tool_parsing() {
        let response = r#"Let me check the files:
<tool_use>
{"tool": "bash", "params": {"command": "ls -la"}}
</tool_use>
And also:
# exec pwd
# read_file test.txt"#;
        let calls = parse_tool_calls(response);
        assert_eq!(calls.len(), 3);

        let bash: Vec<_> = calls.iter().filter(|call| call.tool == "bash").collect();
        let editor_count = calls.iter().filter(|call| call.tool == "editor").count();
        assert_eq!(bash.len(), 2);
        assert_eq!(editor_count, 1);

        for expected in ["ls -la", "pwd"] {
            let found = bash.iter().any(|call| {
                call.params.get("command").and_then(|v| v.as_str()) == Some(expected)
            });
            assert!(found);
        }
    }
}