// autoagents_core/agent/prebuilt/react.rs

1use crate::agent::base::AgentConfig;
2use crate::agent::executor::{AgentExecutor, ExecutorConfig, TurnResult};
3use crate::agent::runnable::AgentState;
4use crate::memory::MemoryProvider;
5use crate::protocol::Event;
6use crate::runtime::Task;
7use crate::tool::ToolCallResult;
8use async_trait::async_trait;
9use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, Tool};
10use autoagents_llm::{LLMProvider, ToolCall, ToolT};
11use log::{debug, error};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::sync::Arc;
15use thiserror::Error;
16use tokio::sync::mpsc::error::SendError;
17use tokio::sync::{mpsc, RwLock};
18
19/// Output of the ReAct-style agent
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ReActAgentOutput {
22    pub response: String,
23    pub tool_calls: Vec<ToolCallResult>,
24}
25
26impl From<ReActAgentOutput> for Value {
27    fn from(output: ReActAgentOutput) -> Self {
28        serde_json::to_value(output).unwrap_or(Value::Null)
29    }
30}
31
32impl ReActAgentOutput {
33    /// Extract the agent output from the ReAct response
34    /// This parses the response string as JSON and deserializes it to the target type
35    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
36    where
37        T: for<'de> serde::Deserialize<'de>,
38    {
39        let react_output: Self = serde_json::from_value(val)
40            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
41        serde_json::from_str(&react_output.response)
42            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
43    }
44}
45
46#[derive(Error, Debug)]
47pub enum ReActExecutorError {
48    #[error("LLM error: {0}")]
49    LLMError(String),
50
51    #[error("Tool execution error: {0}")]
52    ToolError(String),
53
54    #[error("Maximum turns exceeded: {max_turns}")]
55    MaxTurnsExceeded { max_turns: usize },
56
57    #[error("JSON parsing error: {0}")]
58    JsonError(#[from] serde_json::Error),
59
60    #[error("Other error: {0}")]
61    Other(String),
62
63    #[error("Event error: {0}")]
64    EventError(#[from] SendError<Event>),
65
66    #[error("Extracting Agent Output Error: {0}")]
67    AgentOutputError(String),
68}
69
70#[async_trait]
71pub trait ReActExecutor: Send + Sync + 'static {
72    async fn process_tool_calls(
73        &self,
74        tools: &[Box<dyn ToolT>],
75        tool_calls: Vec<autoagents_llm::ToolCall>,
76        tx_event: mpsc::Sender<Event>,
77        _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
78    ) -> Vec<ToolCallResult> {
79        let mut results = Vec::new();
80
81        for call in &tool_calls {
82            let tool_name = call.function.name.clone();
83            let tool_args = call.function.arguments.clone();
84
85            let result = match tools.iter().find(|t| t.name() == tool_name) {
86                Some(tool) => {
87                    let _ = tx_event
88                        .send(Event::ToolCallRequested {
89                            id: call.id.clone(),
90                            tool_name: tool_name.clone(),
91                            arguments: tool_args.clone(),
92                        })
93                        .await;
94
95                    match serde_json::from_str::<Value>(&tool_args) {
96                        Ok(parsed_args) => match tool.run(parsed_args) {
97                            Ok(output) => ToolCallResult {
98                                tool_name: tool_name.clone(),
99                                success: true,
100                                arguments: serde_json::from_str(&tool_args).unwrap_or(Value::Null),
101                                result: output,
102                            },
103                            Err(e) => ToolCallResult {
104                                tool_name: tool_name.clone(),
105                                success: false,
106                                arguments: serde_json::from_str(&tool_args).unwrap_or(Value::Null),
107                                result: serde_json::json!({"error": e.to_string()}),
108                            },
109                        },
110                        Err(e) => ToolCallResult {
111                            tool_name: tool_name.clone(),
112                            success: false,
113                            arguments: Value::Null,
114                            result: serde_json::json!({"error": format!("Failed to parse arguments: {}", e)}),
115                        },
116                    }
117                }
118                None => ToolCallResult {
119                    tool_name: tool_name.clone(),
120                    success: false,
121                    arguments: serde_json::from_str(&tool_args).unwrap_or(Value::Null),
122                    result: serde_json::json!({"error": format!("Tool '{}' not found", tool_name)}),
123                },
124            };
125
126            if result.success {
127                let _ = tx_event
128                    .send(Event::ToolCallCompleted {
129                        id: call.id.clone(),
130                        tool_name: tool_name.clone(),
131                        result: result.result.clone(),
132                    })
133                    .await;
134            } else {
135                let _ = tx_event
136                    .send(Event::ToolCallFailed {
137                        id: call.id.clone(),
138                        tool_name: tool_name.clone(),
139                        error: result.result.to_string(),
140                    })
141                    .await;
142            }
143
144            results.push(result);
145        }
146
147        results
148    }
149
150    #[allow(clippy::too_many_arguments)]
151    async fn process_turn(
152        &self,
153        llm: Arc<dyn LLMProvider>,
154        messages: &[ChatMessage],
155        memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
156        tools: &[Box<dyn ToolT>],
157        agent_config: &AgentConfig,
158        state: Arc<RwLock<AgentState>>,
159        tx_event: mpsc::Sender<Event>,
160    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
161        let response = if !tools.is_empty() {
162            let tools_serialized: Vec<Tool> = tools.iter().map(Tool::from).collect();
163            llm.chat_with_tools(
164                messages,
165                Some(&tools_serialized),
166                agent_config.output_schema.clone(),
167            )
168            .await
169            .map_err(|e| ReActExecutorError::LLMError(e.to_string()))?
170        } else {
171            llm.chat(messages, agent_config.output_schema.clone())
172                .await
173                .map_err(|e| ReActExecutorError::LLMError(e.to_string()))?
174        };
175
176        let response_text = response.text().unwrap_or_default();
177        if let Some(tool_calls) = response.tool_calls() {
178            let tool_results = self
179                .process_tool_calls(tools, tool_calls.clone(), tx_event.clone(), memory.clone())
180                .await;
181
182            // Store tool calls and results in memory
183            if let Some(mem) = &memory {
184                let mut mem = mem.write().await;
185
186                // Record that assistant is calling tools
187                let _ = mem
188                    .remember(&ChatMessage {
189                        role: ChatRole::Assistant,
190                        message_type: MessageType::ToolUse(tool_calls.clone()),
191                        content: response_text.clone(),
192                    })
193                    .await;
194
195                // Create ToolCall objects with the results for ToolResult message type
196                let mut result_tool_calls = Vec::new();
197                for (tool_call, result) in tool_calls.iter().zip(&tool_results) {
198                    let result_content = if result.success {
199                        match &result.result {
200                            serde_json::Value::String(s) => s.clone(),
201                            other => serde_json::to_string(other).unwrap_or_default(),
202                        }
203                    } else {
204                        serde_json::json!({"error": format!("{:?}", result.result)}).to_string()
205                    };
206
207                    // Create a new ToolCall with the result in the arguments field
208                    result_tool_calls.push(ToolCall {
209                        id: tool_call.id.clone(),
210                        call_type: tool_call.call_type.clone(),
211                        function: autoagents_llm::FunctionCall {
212                            name: tool_call.function.name.clone(),
213                            arguments: result_content,
214                        },
215                    });
216                }
217
218                // Store tool results using ToolResult message type with Tool role
219                let _ = mem
220                    .remember(&ChatMessage {
221                        role: ChatRole::Tool,
222                        message_type: MessageType::ToolResult(result_tool_calls),
223                        content: String::new(),
224                    })
225                    .await;
226            }
227
228            {
229                let mut guard = state.write().await;
230                for result in &tool_results {
231                    guard.record_tool_call(result.clone());
232                }
233            }
234
235            // Continue to let the LLM generate a response based on tool results
236            Ok(TurnResult::Continue(Some(ReActAgentOutput {
237                response: response_text,
238                tool_calls: tool_results,
239            })))
240        } else {
241            // Record the final response in memory
242            if !response_text.is_empty() {
243                if let Some(mem) = &memory {
244                    let mut mem = mem.write().await;
245                    let _ = mem
246                        .remember(&ChatMessage {
247                            role: ChatRole::Assistant,
248                            message_type: MessageType::Text,
249                            content: response_text.clone(),
250                        })
251                        .await;
252                }
253            }
254
255            Ok(TurnResult::Complete(ReActAgentOutput {
256                response: response_text,
257                tool_calls: vec![],
258            }))
259        }
260    }
261}
262
263#[async_trait]
264impl<T: ReActExecutor> AgentExecutor for T {
265    type Output = ReActAgentOutput;
266    type Error = ReActExecutorError;
267
268    fn config(&self) -> ExecutorConfig {
269        ExecutorConfig { max_turns: 10 }
270    }
271
272    async fn execute(
273        &self,
274        llm: Arc<dyn LLMProvider>,
275        mut memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
276        tools: Vec<Box<dyn ToolT>>,
277        agent_config: &AgentConfig,
278        task: Task,
279        state: Arc<RwLock<AgentState>>,
280        tx_event: mpsc::Sender<Event>,
281    ) -> Result<Self::Output, Self::Error> {
282        debug!("Starting ReAct Executor");
283        let max_turns = self.config().max_turns;
284        let mut accumulated_tool_calls = Vec::new();
285        let mut final_response = String::new();
286
287        if let Some(memory) = &mut memory {
288            let mut mem = memory.write().await;
289            let chat_msg = ChatMessage {
290                role: ChatRole::User,
291                message_type: MessageType::Text,
292                content: task.prompt.clone(),
293            };
294            let _ = mem.remember(&chat_msg).await;
295        }
296
297        // Record the task in state
298        {
299            let mut state = state.write().await;
300            state.record_task(task.clone());
301        }
302
303        tx_event
304            .send(Event::TaskStarted {
305                sub_id: task.submission_id,
306                agent_id: agent_config.id,
307                task_description: task.prompt,
308            })
309            .await?;
310
311        for turn in 0..max_turns {
312            //Prepare messages with memory
313            let mut messages = vec![ChatMessage {
314                role: ChatRole::System,
315                message_type: MessageType::Text,
316                content: agent_config.description.clone(),
317            }];
318            if let Some(memory) = &memory {
319                // Fetch All previous messsages and extend
320                messages.extend(
321                    memory
322                        .read()
323                        .await
324                        .recall("", None)
325                        .await
326                        .unwrap_or_default(),
327                );
328            }
329
330            tx_event
331                .send(Event::TurnStarted {
332                    turn_number: turn,
333                    max_turns,
334                })
335                .await?;
336            match self
337                .process_turn(
338                    llm.clone(),
339                    &messages,
340                    memory.clone(),
341                    &tools,
342                    agent_config,
343                    state.clone(),
344                    tx_event.clone(),
345                )
346                .await?
347            {
348                TurnResult::Complete(result) => {
349                    // If we have accumulated tool calls, merge them with the final result
350                    if !accumulated_tool_calls.is_empty() {
351                        tx_event
352                            .send(Event::TurnCompleted {
353                                turn_number: turn,
354                                final_turn: true,
355                            })
356                            .await?;
357                        return Ok(ReActAgentOutput {
358                            response: result.response,
359                            tool_calls: accumulated_tool_calls,
360                        });
361                    }
362                    tx_event
363                        .send(Event::TurnCompleted {
364                            turn_number: turn,
365                            final_turn: true,
366                        })
367                        .await?;
368                    return Ok(result);
369                }
370                TurnResult::Continue(Some(partial_result)) => {
371                    // Accumulate tool calls and continue for final response
372                    accumulated_tool_calls.extend(partial_result.tool_calls);
373                    if !partial_result.response.is_empty() {
374                        final_response = partial_result.response;
375                    }
376                    tx_event
377                        .send(Event::TurnCompleted {
378                            turn_number: turn,
379                            final_turn: false,
380                        })
381                        .await?;
382                    continue;
383                }
384                TurnResult::Continue(None) => {
385                    tx_event
386                        .send(Event::TurnCompleted {
387                            turn_number: turn,
388                            final_turn: false,
389                        })
390                        .await?;
391                    continue;
392                }
393            }
394        }
395
396        // If we've exhausted turns but have results, return what we have
397        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
398            Ok(ReActAgentOutput {
399                response: final_response,
400                tool_calls: accumulated_tool_calls,
401            })
402        } else {
403            Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
404        }
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Structured payload an agent is expected to return inside `ReActAgentOutput::response`.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Serializes a `ReActAgentOutput` with the given response and no tool calls.
    fn react_value(response: String) -> serde_json::Value {
        serde_json::to_value(ReActAgentOutput {
            response,
            tool_calls: vec![],
        })
        .unwrap()
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_value = react_value(serde_json::to_string(&agent_output).unwrap());
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let react_value = react_value("plain text, not JSON".to_string());
        let err =
            ReActAgentOutput::extract_agent_output::<TestAgentOutput>(react_value).unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_extract_agent_output_rejects_invalid_envelope() {
        let not_an_output = serde_json::json!({ "unexpected": true });
        let err =
            ReActAgentOutput::extract_agent_output::<TestAgentOutput>(not_an_output).unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_value_from_output() {
        let output = ReActAgentOutput {
            response: "done".to_string(),
            tool_calls: vec![],
        };
        let value: serde_json::Value = output.into();
        assert_eq!(
            value,
            serde_json::json!({ "response": "done", "tool_calls": [] })
        );
    }
}