Skip to main content

autoagents_core/agent/prebuilt/executor/
react.rs

1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4    TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22pub use tokio::sync::mpsc::error::SendError;
23
24#[cfg(target_arch = "wasm32")]
25type SendError = futures::channel::mpsc::SendError;
26
27use crate::agent::hooks::{AgentHooks, HookOutcome};
28use autoagents_protocol::Event;
29
/// Output of the ReAct-style agent.
///
/// The same type is used for the final result of a run (`done: true`) and for
/// the incremental chunks emitted by the streaming executor (`done: false`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
    /// Response text. In the final output this is the completing turn's
    /// response (or, if the turn budget ran out, the most recent non-empty
    /// response); in streamed chunks it is a text delta, or empty for
    /// tool-result chunks.
    pub response: String,
    /// Results of the tool calls made during the run. The final output holds
    /// all of them; streamed tool-result chunks hold those accumulated so far.
    pub tool_calls: Vec<ToolCallResult>,
    /// `true` for the final output of a run, `false` for intermediate
    /// streamed chunks.
    pub done: bool,
}
37
38impl From<ReActAgentOutput> for Value {
39    fn from(output: ReActAgentOutput) -> Self {
40        serde_json::to_value(output).unwrap_or(Value::Null)
41    }
42}
43impl From<ReActAgentOutput> for String {
44    fn from(output: ReActAgentOutput) -> Self {
45        output.response
46    }
47}
48
49impl ReActAgentOutput {
50    /// Try to parse the response string as structured JSON of type `T`.
51    /// Returns `serde_json::Error` if parsing fails.
52    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
53        serde_json::from_str::<T>(&self.response)
54    }
55
56    /// Parse the response string as structured JSON of type `T`, or map the raw
57    /// text into `T` using the provided fallback function if parsing fails.
58    /// This is useful in examples to avoid repeating parsing boilerplate.
59    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
60    where
61        T: for<'de> serde::Deserialize<'de>,
62        F: FnOnce(&str) -> T,
63    {
64        self.try_parse::<T>()
65            .unwrap_or_else(|_| fallback(&self.response))
66    }
67}
68
69impl ReActAgentOutput {
70    /// Extract the agent output from the ReAct response
71    #[allow(clippy::result_large_err)]
72    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
73    where
74        T: for<'de> serde::Deserialize<'de>,
75    {
76        let react_output: Self = serde_json::from_value(val)
77            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
78        serde_json::from_str(&react_output.response)
79            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
80    }
81}
82
/// Errors produced by the ReAct executor and by output extraction.
#[derive(Error, Debug)]
pub enum ReActExecutorError {
    /// Error reported by the LLM call, forwarded from `TurnEngineError::LLMError`.
    #[error("LLM error: {0}")]
    LLMError(String),

    /// The turn budget (`max_turns`) was exhausted without the run producing a
    /// response or any tool call.
    #[error("Maximum turns exceeded: {max_turns}")]
    MaxTurnsExceeded { max_turns: usize },

    /// Catch-all error. Also used when a hook aborts the run.
    #[error("Other error: {0}")]
    Other(String),

    /// Failed to send an `Event` on the event channel (non-wasm targets, which
    /// use tokio's mpsc `SendError`).
    #[cfg(not(target_arch = "wasm32"))]
    #[error("Event error: {0}")]
    EventError(#[from] SendError<Event>),

    /// Failed to send an event on the event channel (wasm32, which uses
    /// `futures::channel::mpsc::SendError`).
    #[cfg(target_arch = "wasm32")]
    #[error("Event error: {0}")]
    EventError(#[from] SendError),

    /// A structured agent output could not be deserialized. Produced by
    /// `ReActAgentOutput::extract_agent_output`.
    #[error("Extracting Agent Output Error: {0}")]
    AgentOutputError(String),
}
105
106impl From<TurnEngineError> for ReActExecutorError {
107    fn from(error: TurnEngineError) -> Self {
108        match error {
109            TurnEngineError::LLMError(err) => ReActExecutorError::LLMError(err),
110            TurnEngineError::Aborted => {
111                ReActExecutorError::Other("Run aborted by hook".to_string())
112            }
113            TurnEngineError::Other(err) => ReActExecutorError::Other(err),
114        }
115    }
116}
117
/// Wrapper type for the multi-turn ReAct executor with tool calling support.
///
/// Use `ReActAgent<T>` when your agent needs to perform tool calls, manage
/// multiple turns, and optionally stream content and tool-call deltas.
///
/// The wrapped agent is stored behind an `Arc`, so clones of the wrapper share
/// the same `T` instead of copying it.
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
    /// The wrapped agent definition. Metadata, tools and hooks are delegated to it.
    inner: Arc<T>,
}
126
127impl<T: AgentDeriveT> Clone for ReActAgent<T> {
128    fn clone(&self) -> Self {
129        Self {
130            inner: Arc::clone(&self.inner),
131        }
132    }
133}
134
135impl<T: AgentDeriveT> ReActAgent<T> {
136    pub fn new(inner: T) -> Self {
137        Self {
138            inner: Arc::new(inner),
139        }
140    }
141}
142
143impl<T: AgentDeriveT> Deref for ReActAgent<T> {
144    type Target = T;
145
146    fn deref(&self) -> &Self::Target {
147        &self.inner
148    }
149}
150
151/// Implement AgentDeriveT for the wrapper by delegating to the inner type
152#[async_trait]
153impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
154    type Output = <T as AgentDeriveT>::Output;
155
156    fn description(&self) -> &str {
157        self.inner.description()
158    }
159
160    fn output_schema(&self) -> Option<Value> {
161        self.inner.output_schema()
162    }
163
164    fn name(&self) -> &str {
165        self.inner.name()
166    }
167
168    fn tools(&self) -> Vec<Box<dyn ToolT>> {
169        self.inner.tools()
170    }
171}
172
173#[async_trait]
174impl<T> AgentHooks for ReActAgent<T>
175where
176    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
177{
178    async fn on_agent_create(&self) {
179        self.inner.on_agent_create().await
180    }
181
182    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
183        self.inner.on_run_start(task, ctx).await
184    }
185
186    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
187        self.inner.on_run_complete(task, result, ctx).await
188    }
189
190    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
191        self.inner.on_turn_start(turn_index, ctx).await
192    }
193
194    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
195        self.inner.on_turn_complete(turn_index, ctx).await
196    }
197
198    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
199        self.inner.on_tool_call(tool_call, ctx).await
200    }
201
202    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
203        self.inner.on_tool_start(tool_call, ctx).await
204    }
205
206    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
207        self.inner.on_tool_result(tool_call, result, ctx).await
208    }
209
210    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
211        self.inner.on_tool_error(tool_call, err, ctx).await
212    }
213    async fn on_agent_shutdown(&self) {
214        self.inner.on_agent_shutdown().await
215    }
216}
217
218/// Implementation of AgentExecutor for the ReActExecutorWrapper
219#[async_trait]
220impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
221    type Output = ReActAgentOutput;
222    type Error = ReActExecutorError;
223
224    fn config(&self) -> ExecutorConfig {
225        ExecutorConfig { max_turns: 10 }
226    }
227
228    async fn execute(
229        &self,
230        task: &Task,
231        context: Arc<Context>,
232    ) -> Result<Self::Output, Self::Error> {
233        if self.on_run_start(task, &context).await == HookOutcome::Abort {
234            return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
235        }
236
237        record_task_state(&context, task);
238
239        let tx_event = context.tx().ok();
240        EventHelper::send_task_started(
241            &tx_event,
242            task.submission_id,
243            context.config().id,
244            context.config().name.clone(),
245            task.prompt.clone(),
246        )
247        .await;
248
249        let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
250        let mut turn_state = engine.turn_state(&context);
251        let max_turns = self.config().max_turns;
252        let mut accumulated_tool_calls = Vec::new();
253        let mut final_response = String::new();
254
255        for turn_index in 0..max_turns {
256            let result = engine
257                .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
258                .await?;
259
260            match result {
261                crate::agent::executor::TurnResult::Complete(output) => {
262                    final_response = output.response.clone();
263                    EventHelper::send_task_completed(
264                        &tx_event,
265                        task.submission_id,
266                        context.config().id,
267                        context.config().name.clone(),
268                        final_response.clone(),
269                    )
270                    .await;
271
272                    accumulated_tool_calls.extend(output.tool_calls);
273
274                    return Ok(ReActAgentOutput {
275                        response: final_response,
276                        done: true,
277                        tool_calls: accumulated_tool_calls,
278                    });
279                }
280                crate::agent::executor::TurnResult::Continue(Some(output)) => {
281                    if !output.response.is_empty() {
282                        final_response = output.response;
283                    }
284                    accumulated_tool_calls.extend(output.tool_calls);
285                }
286                crate::agent::executor::TurnResult::Continue(None) => {}
287            }
288        }
289
290        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
291            EventHelper::send_task_completed(
292                &tx_event,
293                task.submission_id,
294                context.config().id,
295                context.config().name.clone(),
296                final_response.clone(),
297            )
298            .await;
299
300            return Ok(ReActAgentOutput {
301                response: final_response,
302                done: true,
303                tool_calls: accumulated_tool_calls,
304            });
305        }
306
307        Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
308    }
309
310    async fn execute_stream(
311        &self,
312        task: &Task,
313        context: Arc<Context>,
314    ) -> Result<
315        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
316        Self::Error,
317    > {
318        if self.on_run_start(task, &context).await == HookOutcome::Abort {
319            return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
320        }
321
322        record_task_state(&context, task);
323
324        let tx_event = context.tx().ok();
325        EventHelper::send_task_started(
326            &tx_event,
327            task.submission_id,
328            context.config().id,
329            context.config().name.clone(),
330            task.prompt.clone(),
331        )
332        .await;
333
334        let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
335        let mut turn_state = engine.turn_state(&context);
336        let max_turns = self.config().max_turns;
337        let context_clone = context.clone();
338        let task = task.clone();
339        let executor = self.clone();
340
341        let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
342
343        spawn_future(async move {
344            let mut accumulated_tool_calls = Vec::new();
345            let mut final_response = String::new();
346
347            for turn_index in 0..max_turns {
348                let turn_stream = engine
349                    .run_turn_stream(
350                        executor.clone(),
351                        &task,
352                        context_clone.clone(),
353                        &mut turn_state,
354                        turn_index,
355                        max_turns,
356                    )
357                    .await;
358
359                let mut turn_result = None;
360
361                match turn_stream {
362                    Ok(mut stream) => {
363                        use futures::StreamExt;
364                        while let Some(delta_result) = stream.next().await {
365                            match delta_result {
366                                Ok(TurnDelta::Text(content)) => {
367                                    let _ = tx
368                                        .send(Ok(ReActAgentOutput {
369                                            response: content,
370                                            tool_calls: Vec::new(),
371                                            done: false,
372                                        }))
373                                        .await;
374                                }
375                                Ok(TurnDelta::ToolResults(tool_results)) => {
376                                    accumulated_tool_calls.extend(tool_results);
377                                    let _ = tx
378                                        .send(Ok(ReActAgentOutput {
379                                            response: String::new(),
380                                            tool_calls: accumulated_tool_calls.clone(),
381                                            done: false,
382                                        }))
383                                        .await;
384                                }
385                                Ok(TurnDelta::Done(result)) => {
386                                    turn_result = Some(result);
387                                    break;
388                                }
389                                Err(err) => {
390                                    let _ = tx.send(Err(err.into())).await;
391                                    return;
392                                }
393                            }
394                        }
395                    }
396                    Err(err) => {
397                        let _ = tx.send(Err(err.into())).await;
398                        return;
399                    }
400                }
401
402                let Some(result) = turn_result else {
403                    let _ = tx
404                        .send(Err(ReActExecutorError::Other(
405                            "Stream ended without final result".to_string(),
406                        )))
407                        .await;
408                    return;
409                };
410
411                match result {
412                    crate::agent::executor::TurnResult::Complete(output) => {
413                        final_response = output.response.clone();
414                        accumulated_tool_calls.extend(output.tool_calls);
415                        break;
416                    }
417                    crate::agent::executor::TurnResult::Continue(Some(output)) => {
418                        if !output.response.is_empty() {
419                            final_response = output.response;
420                        }
421                        accumulated_tool_calls.extend(output.tool_calls);
422                    }
423                    crate::agent::executor::TurnResult::Continue(None) => {}
424                }
425            }
426
427            let tx_event = context_clone.tx().ok();
428            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
429            let _ = tx
430                .send(Ok(ReActAgentOutput {
431                    response: final_response.clone(),
432                    done: true,
433                    tool_calls: accumulated_tool_calls.clone(),
434                }))
435                .await;
436
437            if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
438                EventHelper::send_task_completed(
439                    &tx_event,
440                    task.submission_id,
441                    context_clone.config().id,
442                    context_clone.config().name.clone(),
443                    final_response,
444                )
445                .await;
446            }
447        });
448
449        Ok(receiver_into_stream(rx))
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Builds a finished `ReActAgentOutput` with the given response and no tool calls.
    fn react_output(response: impl Into<String>) -> ReActAgentOutput {
        ReActAgentOutput {
            response: response.into(),
            done: true,
            tool_calls: vec![],
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = react_output(serde_json::to_string(&agent_output).unwrap());

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let react_value = serde_json::to_value(react_output("plain text")).unwrap();
        let err = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(react_value)
            .unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_extract_agent_output_rejects_malformed_envelope() {
        let malformed = serde_json::json!({ "unexpected": 1 });
        let err =
            ReActAgentOutput::extract_agent_output::<TestAgentOutput>(malformed).unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_parse_or_map_falls_back_to_raw_text() {
        let output = react_output("not json");
        let parsed = output.parse_or_map(|raw| TestAgentOutput {
            value: 0,
            message: raw.to_string(),
        });
        assert_eq!(
            parsed,
            TestAgentOutput {
                value: 0,
                message: "not json".to_string(),
            }
        );
    }

    #[test]
    fn test_conversions() {
        let output = react_output("done");

        let json = Value::from(output.clone());
        assert_eq!(json["response"], Value::from("done"));
        assert_eq!(json["done"], Value::Bool(true));
        assert_eq!(json["tool_calls"], Value::Array(vec![]));

        assert_eq!(String::from(output), "done");
    }
}