autoagents_core/agent/prebuilt/executor/
react.rs1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22pub use tokio::sync::mpsc::error::SendError;
23
24#[cfg(target_arch = "wasm32")]
25type SendError = futures::channel::mpsc::SendError;
26
27use crate::agent::hooks::{AgentHooks, HookOutcome};
28use autoagents_protocol::Event;
29
/// Output produced by [`ReActAgent`] for a task.
///
/// `execute` returns a single value with `done == true`. `execute_stream`
/// emits intermediate values with `done == false` followed by one final value
/// with `done == true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
    /// Response text.
    ///
    /// On intermediate streaming items this is either a text chunk or empty.
    /// For structured agents it is expected to hold JSON; see `try_parse` and
    /// `extract_agent_output`.
    pub response: String,
    /// Results of the tool calls executed during the run (on intermediate
    /// streaming items, the calls accumulated so far).
    pub tool_calls: Vec<ToolCallResult>,
    /// `true` on the final output of a run, `false` on intermediate streaming updates.
    pub done: bool,
}
37
38impl From<ReActAgentOutput> for Value {
39 fn from(output: ReActAgentOutput) -> Self {
40 serde_json::to_value(output).unwrap_or(Value::Null)
41 }
42}
43impl From<ReActAgentOutput> for String {
44 fn from(output: ReActAgentOutput) -> Self {
45 output.response
46 }
47}
48
49impl ReActAgentOutput {
50 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
53 serde_json::from_str::<T>(&self.response)
54 }
55
56 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
60 where
61 T: for<'de> serde::Deserialize<'de>,
62 F: FnOnce(&str) -> T,
63 {
64 self.try_parse::<T>()
65 .unwrap_or_else(|_| fallback(&self.response))
66 }
67}
68
69impl ReActAgentOutput {
70 #[allow(clippy::result_large_err)]
72 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
73 where
74 T: for<'de> serde::Deserialize<'de>,
75 {
76 let react_output: Self = serde_json::from_value(val)
77 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
78 serde_json::from_str(&react_output.response)
79 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
80 }
81}
82
83#[derive(Error, Debug)]
84pub enum ReActExecutorError {
85 #[error("LLM error: {0}")]
86 LLMError(String),
87
88 #[error("Maximum turns exceeded: {max_turns}")]
89 MaxTurnsExceeded { max_turns: usize },
90
91 #[error("Other error: {0}")]
92 Other(String),
93
94 #[cfg(not(target_arch = "wasm32"))]
95 #[error("Event error: {0}")]
96 EventError(#[from] SendError<Event>),
97
98 #[cfg(target_arch = "wasm32")]
99 #[error("Event error: {0}")]
100 EventError(#[from] SendError),
101
102 #[error("Extracting Agent Output Error: {0}")]
103 AgentOutputError(String),
104}
105
106impl From<TurnEngineError> for ReActExecutorError {
107 fn from(error: TurnEngineError) -> Self {
108 match error {
109 TurnEngineError::LLMError(err) => ReActExecutorError::LLMError(err),
110 TurnEngineError::Aborted => {
111 ReActExecutorError::Other("Run aborted by hook".to_string())
112 }
113 TurnEngineError::Other(err) => ReActExecutorError::Other(err),
114 }
115 }
116}
117
/// Wraps an agent definition `T` so it is executed with the ReAct turn loop
/// (`TurnEngineConfig::react`).
///
/// The definition is stored behind an `Arc`, so clones of the wrapper share
/// the same `T`.
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
    /// Shared agent definition; every trait impl on the wrapper delegates to it.
    inner: Arc<T>,
}
126
127impl<T: AgentDeriveT> Clone for ReActAgent<T> {
128 fn clone(&self) -> Self {
129 Self {
130 inner: Arc::clone(&self.inner),
131 }
132 }
133}
134
135impl<T: AgentDeriveT> ReActAgent<T> {
136 pub fn new(inner: T) -> Self {
137 Self {
138 inner: Arc::new(inner),
139 }
140 }
141}
142
143impl<T: AgentDeriveT> Deref for ReActAgent<T> {
144 type Target = T;
145
146 fn deref(&self) -> &Self::Target {
147 &self.inner
148 }
149}
150
151#[async_trait]
153impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
154 type Output = <T as AgentDeriveT>::Output;
155
156 fn description(&self) -> &str {
157 self.inner.description()
158 }
159
160 fn output_schema(&self) -> Option<Value> {
161 self.inner.output_schema()
162 }
163
164 fn name(&self) -> &str {
165 self.inner.name()
166 }
167
168 fn tools(&self) -> Vec<Box<dyn ToolT>> {
169 self.inner.tools()
170 }
171}
172
173#[async_trait]
174impl<T> AgentHooks for ReActAgent<T>
175where
176 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
177{
178 async fn on_agent_create(&self) {
179 self.inner.on_agent_create().await
180 }
181
182 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
183 self.inner.on_run_start(task, ctx).await
184 }
185
186 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
187 self.inner.on_run_complete(task, result, ctx).await
188 }
189
190 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
191 self.inner.on_turn_start(turn_index, ctx).await
192 }
193
194 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
195 self.inner.on_turn_complete(turn_index, ctx).await
196 }
197
198 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
199 self.inner.on_tool_call(tool_call, ctx).await
200 }
201
202 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
203 self.inner.on_tool_start(tool_call, ctx).await
204 }
205
206 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
207 self.inner.on_tool_result(tool_call, result, ctx).await
208 }
209
210 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
211 self.inner.on_tool_error(tool_call, err, ctx).await
212 }
213 async fn on_agent_shutdown(&self) {
214 self.inner.on_agent_shutdown().await
215 }
216}
217
218#[async_trait]
220impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
221 type Output = ReActAgentOutput;
222 type Error = ReActExecutorError;
223
224 fn config(&self) -> ExecutorConfig {
225 ExecutorConfig { max_turns: 10 }
226 }
227
228 async fn execute(
229 &self,
230 task: &Task,
231 context: Arc<Context>,
232 ) -> Result<Self::Output, Self::Error> {
233 if self.on_run_start(task, &context).await == HookOutcome::Abort {
234 return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
235 }
236
237 record_task_state(&context, task);
238
239 let tx_event = context.tx().ok();
240 EventHelper::send_task_started(
241 &tx_event,
242 task.submission_id,
243 context.config().id,
244 context.config().name.clone(),
245 task.prompt.clone(),
246 )
247 .await;
248
249 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
250 let mut turn_state = engine.turn_state(&context);
251 let max_turns = self.config().max_turns;
252 let mut accumulated_tool_calls = Vec::new();
253 let mut final_response = String::new();
254
255 for turn_index in 0..max_turns {
256 let result = engine
257 .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
258 .await?;
259
260 match result {
261 crate::agent::executor::TurnResult::Complete(output) => {
262 final_response = output.response.clone();
263 EventHelper::send_task_completed(
264 &tx_event,
265 task.submission_id,
266 context.config().id,
267 context.config().name.clone(),
268 final_response.clone(),
269 )
270 .await;
271
272 accumulated_tool_calls.extend(output.tool_calls);
273
274 return Ok(ReActAgentOutput {
275 response: final_response,
276 done: true,
277 tool_calls: accumulated_tool_calls,
278 });
279 }
280 crate::agent::executor::TurnResult::Continue(Some(output)) => {
281 if !output.response.is_empty() {
282 final_response = output.response;
283 }
284 accumulated_tool_calls.extend(output.tool_calls);
285 }
286 crate::agent::executor::TurnResult::Continue(None) => {}
287 }
288 }
289
290 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
291 EventHelper::send_task_completed(
292 &tx_event,
293 task.submission_id,
294 context.config().id,
295 context.config().name.clone(),
296 final_response.clone(),
297 )
298 .await;
299
300 return Ok(ReActAgentOutput {
301 response: final_response,
302 done: true,
303 tool_calls: accumulated_tool_calls,
304 });
305 }
306
307 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
308 }
309
310 async fn execute_stream(
311 &self,
312 task: &Task,
313 context: Arc<Context>,
314 ) -> Result<
315 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
316 Self::Error,
317 > {
318 if self.on_run_start(task, &context).await == HookOutcome::Abort {
319 return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
320 }
321
322 record_task_state(&context, task);
323
324 let tx_event = context.tx().ok();
325 EventHelper::send_task_started(
326 &tx_event,
327 task.submission_id,
328 context.config().id,
329 context.config().name.clone(),
330 task.prompt.clone(),
331 )
332 .await;
333
334 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
335 let mut turn_state = engine.turn_state(&context);
336 let max_turns = self.config().max_turns;
337 let context_clone = context.clone();
338 let task = task.clone();
339 let executor = self.clone();
340
341 let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
342
343 spawn_future(async move {
344 let mut accumulated_tool_calls = Vec::new();
345 let mut final_response = String::new();
346
347 for turn_index in 0..max_turns {
348 let turn_stream = engine
349 .run_turn_stream(
350 executor.clone(),
351 &task,
352 context_clone.clone(),
353 &mut turn_state,
354 turn_index,
355 max_turns,
356 )
357 .await;
358
359 let mut turn_result = None;
360
361 match turn_stream {
362 Ok(mut stream) => {
363 use futures::StreamExt;
364 while let Some(delta_result) = stream.next().await {
365 match delta_result {
366 Ok(TurnDelta::Text(content)) => {
367 let _ = tx
368 .send(Ok(ReActAgentOutput {
369 response: content,
370 tool_calls: Vec::new(),
371 done: false,
372 }))
373 .await;
374 }
375 Ok(TurnDelta::ToolResults(tool_results)) => {
376 accumulated_tool_calls.extend(tool_results);
377 let _ = tx
378 .send(Ok(ReActAgentOutput {
379 response: String::new(),
380 tool_calls: accumulated_tool_calls.clone(),
381 done: false,
382 }))
383 .await;
384 }
385 Ok(TurnDelta::Done(result)) => {
386 turn_result = Some(result);
387 break;
388 }
389 Err(err) => {
390 let _ = tx.send(Err(err.into())).await;
391 return;
392 }
393 }
394 }
395 }
396 Err(err) => {
397 let _ = tx.send(Err(err.into())).await;
398 return;
399 }
400 }
401
402 let Some(result) = turn_result else {
403 let _ = tx
404 .send(Err(ReActExecutorError::Other(
405 "Stream ended without final result".to_string(),
406 )))
407 .await;
408 return;
409 };
410
411 match result {
412 crate::agent::executor::TurnResult::Complete(output) => {
413 final_response = output.response.clone();
414 accumulated_tool_calls.extend(output.tool_calls);
415 break;
416 }
417 crate::agent::executor::TurnResult::Continue(Some(output)) => {
418 if !output.response.is_empty() {
419 final_response = output.response;
420 }
421 accumulated_tool_calls.extend(output.tool_calls);
422 }
423 crate::agent::executor::TurnResult::Continue(None) => {}
424 }
425 }
426
427 let tx_event = context_clone.tx().ok();
428 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
429 let _ = tx
430 .send(Ok(ReActAgentOutput {
431 response: final_response.clone(),
432 done: true,
433 tool_calls: accumulated_tool_calls.clone(),
434 }))
435 .await;
436
437 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
438 EventHelper::send_task_completed(
439 &tx_event,
440 task.submission_id,
441 context_clone.config().id,
442 context_clone.config().name.clone(),
443 final_response,
444 )
445 .await;
446 }
447 });
448
449 Ok(receiver_into_stream(rx))
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Builds a finished output with the given response and no tool calls.
    fn output_with_response(response: impl Into<String>) -> ReActAgentOutput {
        ReActAgentOutput {
            response: response.into(),
            tool_calls: vec![],
            done: true,
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = ReActAgentOutput {
            response: serde_json::to_string(&agent_output).unwrap(),
            done: true,
            tool_calls: vec![],
        };

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_react_value() {
        let err = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(Value::String(
            "not a react output".to_string(),
        ))
        .unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let value: Value = output_with_response("plain text, not json").into();
        let err = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(value).unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_try_parse_and_parse_or_map() {
        let structured = output_with_response(r#"{"value":7,"message":"ok"}"#);
        assert_eq!(
            structured.try_parse::<TestAgentOutput>().unwrap(),
            TestAgentOutput {
                value: 7,
                message: "ok".to_string(),
            }
        );

        let free_text = output_with_response("plain text");
        assert!(free_text.try_parse::<TestAgentOutput>().is_err());

        let mapped = free_text.parse_or_map(|text| TestAgentOutput {
            value: -1,
            message: text.to_string(),
        });
        assert_eq!(mapped.value, -1);
        assert_eq!(mapped.message, "plain text");
    }

    #[test]
    fn test_string_conversion_returns_response() {
        let text: String = output_with_response("final answer").into();
        assert_eq!(text, "final answer");
    }
}