autoagents_core/agent/prebuilt/
react.rs1use crate::agent::base::AgentConfig;
2use crate::agent::executor::{AgentExecutor, ExecutorConfig, TurnResult};
3use crate::agent::runnable::AgentState;
4use crate::memory::MemoryProvider;
5use crate::protocol::Event;
6use crate::runtime::Task;
7use crate::tool::ToolCallResult;
8use async_trait::async_trait;
9use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, Tool};
10use autoagents_llm::{LLMProvider, ToolCall, ToolT};
11use log::{debug, error};
12use serde::{Deserialize, Serialize};
13use serde_json::Value;
14use std::sync::Arc;
15use thiserror::Error;
16use tokio::sync::mpsc::error::SendError;
17use tokio::sync::{mpsc, RwLock};
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct ReActAgentOutput {
22 pub response: String,
23 pub tool_calls: Vec<ToolCallResult>,
24}
25
26impl From<ReActAgentOutput> for Value {
27 fn from(output: ReActAgentOutput) -> Self {
28 serde_json::to_value(output).unwrap_or(Value::Null)
29 }
30}
31
32impl ReActAgentOutput {
33 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
36 where
37 T: for<'de> serde::Deserialize<'de>,
38 {
39 let react_output: Self = serde_json::from_value(val)
40 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
41 serde_json::from_str(&react_output.response)
42 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
43 }
44}
45
46#[derive(Error, Debug)]
47pub enum ReActExecutorError {
48 #[error("LLM error: {0}")]
49 LLMError(String),
50
51 #[error("Tool execution error: {0}")]
52 ToolError(String),
53
54 #[error("Maximum turns exceeded: {max_turns}")]
55 MaxTurnsExceeded { max_turns: usize },
56
57 #[error("JSON parsing error: {0}")]
58 JsonError(#[from] serde_json::Error),
59
60 #[error("Other error: {0}")]
61 Other(String),
62
63 #[error("Event error: {0}")]
64 EventError(#[from] SendError<Event>),
65
66 #[error("Extracting Agent Output Error: {0}")]
67 AgentOutputError(String),
68}
69
70#[async_trait]
71pub trait ReActExecutor: Send + Sync + 'static {
72 async fn process_tool_calls(
73 &self,
74 tools: &[Box<dyn ToolT>],
75 tool_calls: Vec<autoagents_llm::ToolCall>,
76 tx_event: mpsc::Sender<Event>,
77 _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
78 ) -> Vec<ToolCallResult> {
79 let mut results = Vec::new();
80
81 for call in &tool_calls {
82 let tool_name = call.function.name.clone();
83 let tool_args = call.function.arguments.clone();
84
85 let result = match tools.iter().find(|t| t.name() == tool_name) {
86 Some(tool) => {
87 let _ = tx_event
88 .send(Event::ToolCallRequested {
89 id: call.id.clone(),
90 tool_name: tool_name.clone(),
91 arguments: tool_args.clone(),
92 })
93 .await;
94
95 match serde_json::from_str::<Value>(&tool_args) {
96 Ok(parsed_args) => match tool.run(parsed_args) {
97 Ok(output) => ToolCallResult {
98 tool_name: tool_name.clone(),
99 success: true,
100 arguments: serde_json::from_str(&tool_args).unwrap_or(Value::Null),
101 result: output,
102 },
103 Err(e) => ToolCallResult {
104 tool_name: tool_name.clone(),
105 success: false,
106 arguments: serde_json::from_str(&tool_args).unwrap_or(Value::Null),
107 result: serde_json::json!({"error": e.to_string()}),
108 },
109 },
110 Err(e) => ToolCallResult {
111 tool_name: tool_name.clone(),
112 success: false,
113 arguments: Value::Null,
114 result: serde_json::json!({"error": format!("Failed to parse arguments: {}", e)}),
115 },
116 }
117 }
118 None => ToolCallResult {
119 tool_name: tool_name.clone(),
120 success: false,
121 arguments: serde_json::from_str(&tool_args).unwrap_or(Value::Null),
122 result: serde_json::json!({"error": format!("Tool '{}' not found", tool_name)}),
123 },
124 };
125
126 if result.success {
127 let _ = tx_event
128 .send(Event::ToolCallCompleted {
129 id: call.id.clone(),
130 tool_name: tool_name.clone(),
131 result: result.result.clone(),
132 })
133 .await;
134 } else {
135 let _ = tx_event
136 .send(Event::ToolCallFailed {
137 id: call.id.clone(),
138 tool_name: tool_name.clone(),
139 error: result.result.to_string(),
140 })
141 .await;
142 }
143
144 results.push(result);
145 }
146
147 results
148 }
149
150 #[allow(clippy::too_many_arguments)]
151 async fn process_turn(
152 &self,
153 llm: Arc<dyn LLMProvider>,
154 messages: &[ChatMessage],
155 memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
156 tools: &[Box<dyn ToolT>],
157 agent_config: &AgentConfig,
158 state: Arc<RwLock<AgentState>>,
159 tx_event: mpsc::Sender<Event>,
160 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
161 let response = if !tools.is_empty() {
162 let tools_serialized: Vec<Tool> = tools.iter().map(Tool::from).collect();
163 llm.chat_with_tools(
164 messages,
165 Some(&tools_serialized),
166 agent_config.output_schema.clone(),
167 )
168 .await
169 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))?
170 } else {
171 llm.chat(messages, agent_config.output_schema.clone())
172 .await
173 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))?
174 };
175
176 let response_text = response.text().unwrap_or_default();
177 if let Some(tool_calls) = response.tool_calls() {
178 let tool_results = self
179 .process_tool_calls(tools, tool_calls.clone(), tx_event.clone(), memory.clone())
180 .await;
181
182 if let Some(mem) = &memory {
184 let mut mem = mem.write().await;
185
186 let _ = mem
188 .remember(&ChatMessage {
189 role: ChatRole::Assistant,
190 message_type: MessageType::ToolUse(tool_calls.clone()),
191 content: response_text.clone(),
192 })
193 .await;
194
195 let mut result_tool_calls = Vec::new();
197 for (tool_call, result) in tool_calls.iter().zip(&tool_results) {
198 let result_content = if result.success {
199 match &result.result {
200 serde_json::Value::String(s) => s.clone(),
201 other => serde_json::to_string(other).unwrap_or_default(),
202 }
203 } else {
204 serde_json::json!({"error": format!("{:?}", result.result)}).to_string()
205 };
206
207 result_tool_calls.push(ToolCall {
209 id: tool_call.id.clone(),
210 call_type: tool_call.call_type.clone(),
211 function: autoagents_llm::FunctionCall {
212 name: tool_call.function.name.clone(),
213 arguments: result_content,
214 },
215 });
216 }
217
218 let _ = mem
220 .remember(&ChatMessage {
221 role: ChatRole::Tool,
222 message_type: MessageType::ToolResult(result_tool_calls),
223 content: String::new(),
224 })
225 .await;
226 }
227
228 {
229 let mut guard = state.write().await;
230 for result in &tool_results {
231 guard.record_tool_call(result.clone());
232 }
233 }
234
235 Ok(TurnResult::Continue(Some(ReActAgentOutput {
237 response: response_text,
238 tool_calls: tool_results,
239 })))
240 } else {
241 if !response_text.is_empty() {
243 if let Some(mem) = &memory {
244 let mut mem = mem.write().await;
245 let _ = mem
246 .remember(&ChatMessage {
247 role: ChatRole::Assistant,
248 message_type: MessageType::Text,
249 content: response_text.clone(),
250 })
251 .await;
252 }
253 }
254
255 Ok(TurnResult::Complete(ReActAgentOutput {
256 response: response_text,
257 tool_calls: vec![],
258 }))
259 }
260 }
261}
262
263#[async_trait]
264impl<T: ReActExecutor> AgentExecutor for T {
265 type Output = ReActAgentOutput;
266 type Error = ReActExecutorError;
267
268 fn config(&self) -> ExecutorConfig {
269 ExecutorConfig { max_turns: 10 }
270 }
271
272 async fn execute(
273 &self,
274 llm: Arc<dyn LLMProvider>,
275 mut memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
276 tools: Vec<Box<dyn ToolT>>,
277 agent_config: &AgentConfig,
278 task: Task,
279 state: Arc<RwLock<AgentState>>,
280 tx_event: mpsc::Sender<Event>,
281 ) -> Result<Self::Output, Self::Error> {
282 debug!("Starting ReAct Executor");
283 let max_turns = self.config().max_turns;
284 let mut accumulated_tool_calls = Vec::new();
285 let mut final_response = String::new();
286
287 if let Some(memory) = &mut memory {
288 let mut mem = memory.write().await;
289 let chat_msg = ChatMessage {
290 role: ChatRole::User,
291 message_type: MessageType::Text,
292 content: task.prompt.clone(),
293 };
294 let _ = mem.remember(&chat_msg).await;
295 }
296
297 {
299 let mut state = state.write().await;
300 state.record_task(task.clone());
301 }
302
303 tx_event
304 .send(Event::TaskStarted {
305 sub_id: task.submission_id,
306 agent_id: agent_config.id,
307 task_description: task.prompt,
308 })
309 .await?;
310
311 for turn in 0..max_turns {
312 let mut messages = vec![ChatMessage {
314 role: ChatRole::System,
315 message_type: MessageType::Text,
316 content: agent_config.description.clone(),
317 }];
318 if let Some(memory) = &memory {
319 messages.extend(
321 memory
322 .read()
323 .await
324 .recall("", None)
325 .await
326 .unwrap_or_default(),
327 );
328 }
329
330 tx_event
331 .send(Event::TurnStarted {
332 turn_number: turn,
333 max_turns,
334 })
335 .await?;
336 match self
337 .process_turn(
338 llm.clone(),
339 &messages,
340 memory.clone(),
341 &tools,
342 agent_config,
343 state.clone(),
344 tx_event.clone(),
345 )
346 .await?
347 {
348 TurnResult::Complete(result) => {
349 if !accumulated_tool_calls.is_empty() {
351 tx_event
352 .send(Event::TurnCompleted {
353 turn_number: turn,
354 final_turn: true,
355 })
356 .await?;
357 return Ok(ReActAgentOutput {
358 response: result.response,
359 tool_calls: accumulated_tool_calls,
360 });
361 }
362 tx_event
363 .send(Event::TurnCompleted {
364 turn_number: turn,
365 final_turn: true,
366 })
367 .await?;
368 return Ok(result);
369 }
370 TurnResult::Continue(Some(partial_result)) => {
371 accumulated_tool_calls.extend(partial_result.tool_calls);
373 if !partial_result.response.is_empty() {
374 final_response = partial_result.response;
375 }
376 tx_event
377 .send(Event::TurnCompleted {
378 turn_number: turn,
379 final_turn: false,
380 })
381 .await?;
382 continue;
383 }
384 TurnResult::Continue(None) => {
385 tx_event
386 .send(Event::TurnCompleted {
387 turn_number: turn,
388 final_turn: false,
389 })
390 .await?;
391 continue;
392 }
393 }
394 }
395
396 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
398 Ok(ReActAgentOutput {
399 response: final_response,
400 tool_calls: accumulated_tool_calls,
401 })
402 } else {
403 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
404 }
405 }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Minimal structured payload used as the agent's typed output.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// A typed output serialized into `response` survives the round trip
    /// through `extract_agent_output`.
    #[test]
    fn test_extract_agent_output_success() {
        let expected = TestAgentOutput {
            value: 42,
            message: String::from("Hello, world!"),
        };
        let encoded_response = serde_json::to_string(&expected).unwrap();

        let react_value = serde_json::to_value(ReActAgentOutput {
            response: encoded_response,
            tool_calls: Vec::new(),
        })
        .unwrap();

        let actual: TestAgentOutput = ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(actual, expected);
    }
}