Skip to main content

wasm_agent/
loop_runner.rs

// SPDX-License-Identifier: MIT
//! ReAct loop runner: Thought -> Action -> Observation -> FinalAnswer.

4use crate::error::AgentError;
5use crate::history::ConversationHistory;
6use crate::tools::ToolRegistry;
7use crate::types::{AgentConfig, Message, ReActStep, Role};
8
9/// Parses a raw LLM response string into a [`ReActStep`].
10///
11/// Supported formats:
12/// - `Thought: <text>`
13/// - `Action: <tool_name>(<input>)`
14/// - `Action: <tool_name>: <input>`
15/// - `Final Answer: <text>`
16///
17/// # Errors
18/// Returns [`AgentError::ParseError`] if the response does not match any known prefix.
19pub fn parse_react_step(response: &str) -> Result<ReActStep, AgentError> {
20    let trimmed = response.trim();
21    if let Some(rest) = trimmed.strip_prefix("Final Answer:") {
22        return Ok(ReActStep::FinalAnswer(rest.trim().to_string()));
23    }
24    if let Some(rest) = trimmed.strip_prefix("Thought:") {
25        return Ok(ReActStep::Thought(rest.trim().to_string()));
26    }
27    if let Some(rest) = trimmed.strip_prefix("Action:") {
28        let rest = rest.trim();
29        if let Some(paren) = rest.find('(') {
30            let tool = rest[..paren].trim().to_string();
31            let input = rest[paren + 1..].trim_end_matches(')').trim().to_string();
32            return Ok(ReActStep::Action { tool, input });
33        }
34        if let Some(colon) = rest.find(':') {
35            let tool = rest[..colon].trim().to_string();
36            let input = rest[colon + 1..].trim().to_string();
37            return Ok(ReActStep::Action { tool, input });
38        }
39        return Ok(ReActStep::Action { tool: rest.to_string(), input: String::new() });
40    }
41    Err(AgentError::ParseError(format!("Could not parse ReAct step from: {trimmed}")))
42}
43
44/// The ReAct loop execution context.
45///
46/// Holds references to the agent configuration and tool registry.
47/// Drives the Thought -> Action -> Observation -> FinalAnswer cycle.
48pub struct LoopRunner<'a> {
49    config: &'a AgentConfig,
50    registry: &'a ToolRegistry,
51    history: ConversationHistory,
52}
53
54impl<'a> LoopRunner<'a> {
55    /// Creates a new runner with its own conversation history.
56    pub fn new(config: &'a AgentConfig, registry: &'a ToolRegistry) -> Self {
57        let history = ConversationHistory::new(config.context_token_limit);
58        Self { config, registry, history }
59    }
60
61    /// Processes one LLM response, dispatching tool calls when the response is an Action.
62    ///
63    /// In production `llm_response` comes from an LLM API call.
64    /// Accepting it as a parameter keeps the runner pure and fully testable without I/O.
65    ///
66    /// # Errors
67    /// - [`AgentError::ParseError`] — response does not match any ReAct format.
68    /// - [`AgentError::ToolNotFound`] — the requested tool is not in the registry.
69    pub fn step(&mut self, llm_response: &str) -> Result<ReActStep, AgentError> {
70        let step = parse_react_step(llm_response)?;
71        let obs_msg = match &step {
72            ReActStep::Action { tool, input } => {
73                let result = self.registry.dispatch(tool, input)?;
74                Some(Message::new(Role::Tool, result.output))
75            }
76            _ => None,
77        };
78        self.history.push_with_eviction(Message::assistant(llm_response));
79        if let Some(obs) = obs_msg {
80            self.history.push_with_eviction(obs);
81        }
82        Ok(step)
83    }
84
85    /// Runs a pre-scripted sequence of LLM responses, returning the final answer.
86    ///
87    /// Useful for offline simulation and deterministic testing.
88    ///
89    /// # Errors
90    /// - [`AgentError::MaxIterationsExceeded`] — sequence exhausted before a `FinalAnswer`.
91    /// - Any error from [`step`](Self::step).
92    pub fn run_scripted(&mut self, responses: &[&str]) -> Result<String, AgentError> {
93        for (i, response) in responses.iter().enumerate() {
94            if i as u32 >= self.config.max_iterations {
95                return Err(AgentError::MaxIterationsExceeded(self.config.max_iterations));
96            }
97            let step = self.step(response)?;
98            if let ReActStep::FinalAnswer(answer) = step {
99                return Ok(answer);
100            }
101        }
102        Err(AgentError::MaxIterationsExceeded(self.config.max_iterations))
103    }
104
105    /// Returns a reference to the accumulated conversation history.
106    pub fn history(&self) -> &ConversationHistory { &self.history }
107
108    /// Returns the number of messages currently in the history.
109    pub fn iteration_count(&self) -> usize { self.history.len() }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::{ToolRegistry, ToolSpec};
    use crate::types::ToolResult;

    /// Builds a registry with a single `echo` tool that returns `echoed: <input>`.
    fn registry_with_echo() -> ToolRegistry {
        let mut reg = ToolRegistry::new();
        reg.register(
            ToolSpec::new("echo", "Echoes input", "{}"),
            Box::new(|input: &str| ToolResult {
                tool_name: "echo".into(),
                output: format!("echoed: {input}"),
                success: true,
            }),
        )
        .unwrap();
        reg
    }

    #[test]
    fn test_parse_thought() {
        let step = parse_react_step("Thought: I need to search for information").unwrap();
        assert!(matches!(&step, ReActStep::Thought(t) if t == "I need to search for information"));
    }

    #[test]
    fn test_parse_action_paren_syntax() {
        let step = parse_react_step("Action: echo(hello world)").unwrap();
        assert!(matches!(&step, ReActStep::Action { tool, input }
            if tool == "echo" && input == "hello world"));
    }

    #[test]
    fn test_parse_action_colon_syntax() {
        let step = parse_react_step("Action: echo: hello").unwrap();
        assert!(matches!(&step, ReActStep::Action { tool, input }
            if tool == "echo" && input == "hello"));
    }

    #[test]
    fn test_parse_final_answer() {
        let step = parse_react_step("Final Answer: 42").unwrap();
        assert!(matches!(&step, ReActStep::FinalAnswer(s) if s == "42"));
    }

    #[test]
    fn test_parse_trims_surrounding_whitespace() {
        let step = parse_react_step("  \n Final Answer:   done  \n").unwrap();
        assert!(matches!(&step, ReActStep::FinalAnswer(s) if s == "done"));
    }

    #[test]
    fn test_parse_unknown_format_returns_parse_error() {
        let err = parse_react_step("This is just random text without a prefix").unwrap_err();
        assert!(matches!(err, AgentError::ParseError(_)));
    }

    #[test]
    fn test_parse_action_no_args() {
        let step = parse_react_step("Action: mytool").unwrap();
        assert!(matches!(&step, ReActStep::Action { tool, input }
            if tool == "mytool" && input.is_empty()));
    }

    #[test]
    fn test_loop_runner_scripted_final_answer() {
        let config = AgentConfig::default();
        let reg = registry_with_echo();
        let mut runner = LoopRunner::new(&config, &reg);
        let answer = runner
            .run_scripted(&["Thought: Let me think", "Action: echo(test)", "Final Answer: done"])
            .unwrap();
        assert_eq!(answer, "done");
    }

    #[test]
    fn test_loop_runner_final_answer_on_last_allowed_iteration() {
        let config = AgentConfig { max_iterations: 2, ..Default::default() };
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        let answer = runner.run_scripted(&["Thought: one", "Final Answer: two"]).unwrap();
        assert_eq!(answer, "two");
    }

    #[test]
    fn test_loop_runner_max_iterations_returns_error() {
        let config = AgentConfig { max_iterations: 2, ..Default::default() };
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        let err = runner
            .run_scripted(&["Thought: step 1", "Thought: step 2", "Thought: step 3"])
            .unwrap_err();
        assert!(matches!(err, AgentError::MaxIterationsExceeded(2)));
    }

    #[test]
    fn test_loop_runner_script_exhausted_returns_error() {
        let config = AgentConfig { max_iterations: 5, ..Default::default() };
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        let err = runner.run_scripted(&["Thought: only one"]).unwrap_err();
        assert!(matches!(err, AgentError::MaxIterationsExceeded(5)));
    }

    #[test]
    fn test_loop_runner_zero_max_iterations_processes_nothing() {
        let config = AgentConfig { max_iterations: 0, ..Default::default() };
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        let before = runner.history().len();
        let err = runner.run_scripted(&["Final Answer: too late"]).unwrap_err();
        assert!(matches!(err, AgentError::MaxIterationsExceeded(0)));
        assert_eq!(runner.history().len(), before);
    }

    #[test]
    fn test_loop_runner_tool_not_found_returns_error() {
        let config = AgentConfig::default();
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        let err = runner.step("Action: missing_tool(input)").unwrap_err();
        assert!(matches!(err, AgentError::ToolNotFound { .. }));
    }

    #[test]
    fn test_loop_runner_failed_step_leaves_history_unchanged() {
        let config = AgentConfig::default();
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        let before = runner.history().len();
        assert!(runner.step("Action: missing_tool(input)").is_err());
        assert!(runner.step("no recognised prefix").is_err());
        assert_eq!(runner.history().len(), before);
    }

    #[test]
    fn test_loop_runner_history_grows_with_steps() {
        let config = AgentConfig::default();
        let reg = registry_with_echo();
        let mut runner = LoopRunner::new(&config, &reg);
        runner.step("Thought: thinking").unwrap();
        runner.step("Thought: still thinking").unwrap();
        assert!(runner.history().len() >= 2);
    }

    #[test]
    fn test_loop_runner_action_adds_observation_to_history() {
        let config = AgentConfig::default();
        let reg = registry_with_echo();
        let mut runner = LoopRunner::new(&config, &reg);
        runner.step("Action: echo(hi)").unwrap();
        // Two messages: the Action response + the Observation
        assert!(runner.history().len() >= 2);
    }

    #[test]
    fn test_loop_runner_iteration_count_matches_history_len() {
        let config = AgentConfig::default();
        let reg = ToolRegistry::new();
        let mut runner = LoopRunner::new(&config, &reg);
        runner.step("Thought: one").unwrap();
        assert_eq!(runner.iteration_count(), runner.history().len());
    }
}