toast_api/
agent.rs

//! Agent implementation for toast.
//!
//! This module provides an agent that can use tools to accomplish tasks,
//! similar to the DGM paper's approach but adapted for toast's architecture.
5
6use crate::tools::{ToolRegistry, parse_tool_calls, format_tool_output};
7use anyhow::Result;
8use std::collections::HashMap;
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11
/// Configuration for an `Agent`.
///
/// Derives `Debug`, `Clone` and equality so callers can log, duplicate and
/// compare configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentConfig {
    /// Limit on the number of tool-execution rounds (responses that contain
    /// tool calls) the agent performs; once reached, further tool calls are
    /// skipped until the agent is reset.
    pub max_iterations: usize,
    /// Base system prompt; the agent appends its tool descriptions to it.
    pub system_prompt: String,
}
17
18impl Default for AgentConfig {
19    fn default() -> Self {
20        Self {
21            max_iterations: 40,
22            system_prompt: crate::config::SYSTEM_PROMPT.to_string(),
23        }
24    }
25}
26
27/// Agent that can use tools to accomplish tasks
28pub struct Agent {
29    config: AgentConfig,
30    tools: ToolRegistry,
31    iteration_count: usize,
32    interrupt_flag: Arc<AtomicBool>,
33}
34
35impl Agent {
36    pub fn new(config: AgentConfig) -> Self {
37        Self {
38            config,
39            tools: ToolRegistry::new(),
40            iteration_count: 0,
41            interrupt_flag: Arc::new(AtomicBool::new(false)),
42        }
43    }
44
45    /// Get tool descriptions for the system prompt
46    pub fn get_tool_descriptions(&self) -> String {
47        let mut descriptions = String::from("\n\nYou have access to the following tools:\n\n");
48        
49        for tool in self.tools.all_tools() {
50            descriptions.push_str(&format!("**{}**\n", tool.name));
51            descriptions.push_str(&format!("{}\n", tool.description));
52            descriptions.push_str(&format!("Schema: {}\n\n", serde_json::to_string_pretty(&tool.input_schema).unwrap()));
53        }
54
55        descriptions.push_str(r#"To use a tool, wrap your tool call in <tool_use> tags:
56<tool_use>
57{"tool": "bash", "params": {"command": "ls -la"}}
58</tool_use>
59
60Always think step by step and use tools to explore, understand, and solve problems."#);
61
62
63        descriptions
64    }
65
66    /// Process a response and execute any tool calls
67    pub async fn process_tool_calls(&mut self, response: &str) -> Result<Vec<(String, String)>> {
68        let tool_calls = parse_tool_calls(response);
69        let mut results = Vec::new();
70
71        // Check if interrupted before starting
72        if self.interrupt_flag.load(Ordering::Relaxed) {
73            println!("\n⏹️  Agent interrupted by user");
74            self.interrupt_flag.store(false, Ordering::Relaxed);
75            return Ok(results);
76        }
77
78        if tool_calls.is_empty() {
79            return Ok(results);
80        }
81
82        self.iteration_count += 1;
83        if self.iteration_count >= self.config.max_iterations {
84            println!("⚠️  Maximum iterations ({}) reached. Stopping tool execution.", self.config.max_iterations);
85            return Ok(results);
86        }
87
88        for tool_call in tool_calls {
89            // Check for interrupt before each tool execution
90            if self.interrupt_flag.load(Ordering::Relaxed) {
91                println!("\n⏹️  Agent interrupted by user");
92                self.interrupt_flag.store(false, Ordering::Relaxed);
93                break;
94            }
95
96            let colored_tool = match tool_call.tool.as_str() {
97                "read_file" => format!("\x1b[34m{}\x1b[0m", tool_call.tool),
98                "bash" | "exec" => format!("\x1b[38;5;208m{}\x1b[0m", tool_call.tool),
99                _ => tool_call.tool.clone(),
100            };
101            println!("Executing {}: {}", 
102                     colored_tool, 
103                     serde_json::to_string(&tool_call.params).unwrap_or_default());
104
105            match self.tools.execute(&tool_call.tool, tool_call.params).await {
106                Ok(output) => {
107                    let formatted = format_tool_output(&tool_call.tool, &output);
108                    println!("{formatted}");
109                    results.push((tool_call.tool, output));
110                }
111                Err(e) => {
112                    let error_msg = format!("Error executing {}: {}", tool_call.tool, e);
113                    println!("❌ {error_msg}");
114                    results.push((tool_call.tool, error_msg));
115                }
116            }
117        }
118
119        Ok(results)
120    }
121
122    /// Reset iteration count for a new task
123    pub fn reset(&mut self) {
124        self.iteration_count = 0;
125    }
126
127    /// Get the interrupt flag for external signaling
128    pub fn interrupt_flag(&self) -> Arc<AtomicBool> {
129        Arc::clone(&self.interrupt_flag)
130    }
131
132    /// Get enhanced system prompt with tool descriptions
133    pub fn get_system_prompt(&self) -> String {
134        format!("{}{}", self.config.system_prompt, self.get_tool_descriptions())
135    }
136}
137
138/// Agent session that maintains context across multiple interactions
139pub struct AgentSession {
140    agent: Agent,
141    context: HashMap<String, String>,
142}
143
144impl AgentSession {
145    pub fn new(config: AgentConfig) -> Self {
146        Self {
147            agent: Agent::new(config),
148            context: HashMap::new(),
149        }
150    }
151
152    /// Add context information
153    pub fn add_context(&mut self, key: String, value: String) {
154        self.context.insert(key, value);
155    }
156
157    /// Get the agent
158    pub fn agent(&mut self) -> &mut Agent {
159        &mut self.agent
160    }
161
162    /// Get context
163    pub fn context(&self) -> &HashMap<String, String> {
164        &self.context
165    }
166
167    /// Get interrupt flag
168    pub fn interrupt_flag(&self) -> Arc<AtomicBool> {
169        self.agent.interrupt_flag()
170    }
171
172    /// Reset for a new task
173    pub fn reset(&mut self) {
174        self.agent.reset();
175        self.context.clear();
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tool_parsing() {
        // Realistic model output: one JSON <tool_use> block followed by two
        // shorthand `# <tool> <arg>` lines, none of them indented.
        let response = r#"Let me check the files:
<tool_use>
{"tool": "bash", "params": {"command": "ls -la"}}
</tool_use>

And also:
# exec pwd
# read_file test.txt"#;

        let calls = parse_tool_calls(response);
        assert_eq!(calls.len(), 3);

        // NOTE(review): these expectations assume the parser maps `exec` to
        // `bash` and `read_file` to `editor`, while `Agent::process_tool_calls`
        // treats `exec`/`read_file` as tool names in their own right. Confirm
        // which naming `parse_tool_calls` actually produces.
        let bash_calls: Vec<_> = calls.iter().filter(|call| call.tool == "bash").collect();
        let editor_count = calls.iter().filter(|call| call.tool == "editor").count();
        assert_eq!(bash_calls.len(), 2);
        assert_eq!(editor_count, 1);

        let commands: Vec<Option<&str>> = bash_calls
            .iter()
            .map(|call| call.params.get("command").and_then(|v| v.as_str()))
            .collect();
        assert!(commands.contains(&Some("ls -la")));
        assert!(commands.contains(&Some("pwd")));
    }
}
213}