Skip to main content

autoagents_core/agent/prebuilt/executor/
react.rs

1use crate::agent::executor::AgentExecutor;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, Context, ExecutorConfig, TurnResult};
4use crate::protocol::{Event, StreamingTurnResult, SubmissionId};
5use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
6use async_trait::async_trait;
7use autoagents_llm::ToolCall;
8use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, Tool};
9use autoagents_llm::error::LLMError;
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashSet;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17use thiserror::Error;
18
19#[cfg(not(target_arch = "wasm32"))]
20pub use tokio::sync::mpsc::error::SendError;
21
22#[cfg(target_arch = "wasm32")]
23use futures::SinkExt;
24#[cfg(target_arch = "wasm32")]
25pub use futures::lock::Mutex;
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::executor::event_helper::EventHelper;
30use crate::agent::executor::memory_helper::MemoryHelper;
31use crate::agent::executor::tool_processor::ToolProcessor;
32use crate::agent::hooks::{AgentHooks, HookOutcome};
33use crate::channel::{Sender, channel};
34use crate::utils::{receiver_into_stream, spawn_future};
35
36/// Output of the ReAct-style agent
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ReActAgentOutput {
39    pub response: String,
40    pub tool_calls: Vec<ToolCallResult>,
41    pub done: bool,
42}
43
44impl From<ReActAgentOutput> for Value {
45    fn from(output: ReActAgentOutput) -> Self {
46        serde_json::to_value(output).unwrap_or(Value::Null)
47    }
48}
49impl From<ReActAgentOutput> for String {
50    fn from(output: ReActAgentOutput) -> Self {
51        output.response
52    }
53}
54
55impl ReActAgentOutput {
56    /// Try to parse the response string as structured JSON of type `T`.
57    /// Returns `serde_json::Error` if parsing fails.
58    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
59        serde_json::from_str::<T>(&self.response)
60    }
61
62    /// Parse the response string as structured JSON of type `T`, or map the raw
63    /// text into `T` using the provided fallback function if parsing fails.
64    /// This is useful in examples to avoid repeating parsing boilerplate.
65    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
66    where
67        T: for<'de> serde::Deserialize<'de>,
68        F: FnOnce(&str) -> T,
69    {
70        self.try_parse::<T>()
71            .unwrap_or_else(|_| fallback(&self.response))
72    }
73}
74
75impl ReActAgentOutput {
76    /// Extract the agent output from the ReAct response
77    #[allow(clippy::result_large_err)]
78    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
79    where
80        T: for<'de> serde::Deserialize<'de>,
81    {
82        let react_output: Self = serde_json::from_value(val)
83            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
84        serde_json::from_str(&react_output.response)
85            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
86    }
87}
88
89#[derive(Error, Debug)]
90pub enum ReActExecutorError {
91    #[error("LLM error: {0}")]
92    LLMError(String),
93
94    #[error("Maximum turns exceeded: {max_turns}")]
95    MaxTurnsExceeded { max_turns: usize },
96
97    #[error("Other error: {0}")]
98    Other(String),
99
100    #[cfg(not(target_arch = "wasm32"))]
101    #[error("Event error: {0}")]
102    EventError(#[from] SendError<Event>),
103
104    #[cfg(target_arch = "wasm32")]
105    #[error("Event error: {0}")]
106    EventError(#[from] SendError),
107
108    #[error("Extracting Agent Output Error: {0}")]
109    AgentOutputError(String),
110}
111
112/// Wrapper type for the multi-turn ReAct executor with tool calling support.
113///
114/// Use `ReActAgent<T>` when your agent needs to perform tool calls, manage
115/// multiple turns, and optionally stream content and tool-call deltas.
116#[derive(Debug)]
117pub struct ReActAgent<T: AgentDeriveT> {
118    inner: Arc<T>,
119}
120
121impl<T: AgentDeriveT> Clone for ReActAgent<T> {
122    fn clone(&self) -> Self {
123        Self {
124            inner: Arc::clone(&self.inner),
125        }
126    }
127}
128
129impl<T: AgentDeriveT> ReActAgent<T> {
130    pub fn new(inner: T) -> Self {
131        Self {
132            inner: Arc::new(inner),
133        }
134    }
135}
136
137impl<T: AgentDeriveT> Deref for ReActAgent<T> {
138    type Target = T;
139
140    fn deref(&self) -> &Self::Target {
141        &self.inner
142    }
143}
144
145/// Implement AgentDeriveT for the wrapper by delegating to the inner type
146#[async_trait]
147impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
148    type Output = <T as AgentDeriveT>::Output;
149
150    fn description(&self) -> &'static str {
151        self.inner.description()
152    }
153
154    fn output_schema(&self) -> Option<Value> {
155        self.inner.output_schema()
156    }
157
158    fn name(&self) -> &'static str {
159        self.inner.name()
160    }
161
162    fn tools(&self) -> Vec<Box<dyn ToolT>> {
163        self.inner.tools()
164    }
165}
166
167#[async_trait]
168impl<T> AgentHooks for ReActAgent<T>
169where
170    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
171{
172    async fn on_agent_create(&self) {
173        self.inner.on_agent_create().await
174    }
175
176    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
177        self.inner.on_run_start(task, ctx).await
178    }
179
180    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
181        self.inner.on_run_complete(task, result, ctx).await
182    }
183
184    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
185        self.inner.on_turn_start(turn_index, ctx).await
186    }
187
188    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
189        self.inner.on_turn_complete(turn_index, ctx).await
190    }
191
192    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
193        self.inner.on_tool_call(tool_call, ctx).await
194    }
195
196    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
197        self.inner.on_tool_start(tool_call, ctx).await
198    }
199
200    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
201        self.inner.on_tool_result(tool_call, result, ctx).await
202    }
203
204    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
205        self.inner.on_tool_error(tool_call, err, ctx).await
206    }
207    async fn on_agent_shutdown(&self) {
208        self.inner.on_agent_shutdown().await
209    }
210}
211
212impl<T: AgentDeriveT + AgentHooks> ReActAgent<T> {
213    /// Process a single turn with the LLM
214    async fn process_turn(
215        &self,
216        context: &Context,
217        tools: &[Box<dyn ToolT>],
218    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
219        let messages = self.prepare_messages(context).await;
220        let response = self.get_llm_response(context, &messages, tools).await?;
221        let response_text = response.text().unwrap_or_default();
222
223        if let Some(tool_calls) = response.tool_calls() {
224            self.handle_tool_calls(context, tools, tool_calls.clone(), response_text)
225                .await
226        } else {
227            self.handle_text_response(context, response_text).await
228        }
229    }
230
231    /// Get LLM response for the given messages and tools
232    async fn get_llm_response(
233        &self,
234        context: &Context,
235        messages: &[ChatMessage],
236        tools: &[Box<dyn ToolT>],
237    ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, ReActExecutorError> {
238        let llm = context.llm();
239        let agent_config = context.config();
240        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
241
242        if tools.is_empty() {
243            llm.chat(messages, agent_config.output_schema.clone())
244                .await
245                .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
246        } else {
247            llm.chat_with_tools(
248                messages,
249                Some(&tools_serialized),
250                agent_config.output_schema.clone(),
251            )
252            .await
253            .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
254        }
255    }
256
257    /// Handle tool calls and return the result
258    async fn handle_tool_calls(
259        &self,
260        context: &Context,
261        tools: &[Box<dyn ToolT>],
262        tool_calls: Vec<ToolCall>,
263        response_text: String,
264    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
265        let tx_event = context.tx().ok();
266
267        // Process tool calls
268        let mut tool_results = Vec::new();
269        for call in &tool_calls {
270            if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
271                self, context, tools, call, &tx_event,
272            )
273            .await
274            {
275                tool_results.push(result);
276            }
277        }
278
279        // Store in memory
280        MemoryHelper::store_tool_interaction(
281            &context.memory(),
282            &tool_calls,
283            &tool_results,
284            &response_text,
285        )
286        .await;
287
288        // Update state - use try_lock to avoid deadlock
289        {
290            let state = context.state();
291            #[cfg(not(target_arch = "wasm32"))]
292            if let Ok(mut guard) = state.try_lock() {
293                for result in &tool_results {
294                    guard.record_tool_call(result.clone());
295                }
296            };
297            #[cfg(target_arch = "wasm32")]
298            if let Some(mut guard) = state.try_lock() {
299                for result in &tool_results {
300                    guard.record_tool_call(result.clone());
301                }
302            };
303        }
304
305        Ok(TurnResult::Continue(Some(ReActAgentOutput {
306            response: response_text,
307            done: true,
308            tool_calls: tool_results,
309        })))
310    }
311
312    /// Handle text-only response
313    async fn handle_text_response(
314        &self,
315        context: &Context,
316        response_text: String,
317    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
318        if !response_text.is_empty() {
319            MemoryHelper::store_assistant_response(&context.memory(), response_text.clone()).await;
320        }
321
322        Ok(TurnResult::Complete(ReActAgentOutput {
323            response: response_text,
324            done: true,
325            tool_calls: vec![],
326        }))
327    }
328
329    /// Prepare messages for the current turn
330    async fn prepare_messages(&self, context: &Context) -> Vec<ChatMessage> {
331        let mut messages = vec![ChatMessage {
332            role: ChatRole::System,
333            message_type: MessageType::Text,
334            content: context.config().description.clone(),
335        }];
336
337        let recalled = MemoryHelper::recall_messages(&context.memory()).await;
338        messages.extend(recalled);
339
340        messages
341    }
342
343    /// Process a streaming turn with tool support
344    async fn process_streaming_turn(
345        &self,
346        context: &Context,
347        tools: &[Box<dyn ToolT>],
348        tx: &mut Sender<Result<ReActAgentOutput, ReActExecutorError>>,
349        submission_id: SubmissionId,
350    ) -> Result<StreamingTurnResult, ReActExecutorError> {
351        let messages = self.prepare_messages(context).await;
352        let mut stream = self.get_llm_stream(context, &messages, tools).await?;
353
354        let mut response_text = String::new();
355        let mut tool_calls: Vec<ToolCall> = Vec::new();
356        let mut tool_call_ids: HashSet<String> = HashSet::new();
357
358        // Process stream chunks
359        while let Some(chunk_result) = stream.next().await {
360            let chunk: StreamChunk =
361                chunk_result.map_err(|e| ReActExecutorError::LLMError(e.to_string()))?;
362            let chunk_clone = chunk.clone();
363
364            match chunk {
365                StreamChunk::Text(content) => {
366                    response_text.push_str(&content);
367                    let _ = tx
368                        .send(Ok(ReActAgentOutput {
369                            response: content.to_string(),
370                            tool_calls: vec![],
371                            done: false,
372                        }))
373                        .await;
374                }
375                StreamChunk::ToolUseComplete {
376                    index: _,
377                    tool_call,
378                } => {
379                    if tool_call_ids.insert(tool_call.id.clone()) {
380                        tool_calls.push(tool_call.clone());
381
382                        let tx_event = context.tx().ok();
383                        EventHelper::send_stream_tool_call(
384                            &tx_event,
385                            submission_id,
386                            serde_json::to_value(tool_call).unwrap_or(Value::Null),
387                        )
388                        .await;
389                    }
390                }
391                StreamChunk::Usage(_) => {
392                    //TODO: Add Usage Analytics
393                }
394                _ => {
395                    //Do nothing
396                }
397            }
398            // Send stream chunk event
399            let tx_event = context.tx().ok();
400            EventHelper::send_stream_chunk(&tx_event, submission_id, chunk_clone).await;
401        }
402
403        // Process collected tool calls if any
404        if tool_calls.is_empty() {
405            if !response_text.is_empty() {
406                MemoryHelper::store_assistant_response(&context.memory(), response_text.clone())
407                    .await;
408            }
409            return Ok(StreamingTurnResult::Complete(response_text));
410        }
411
412        let tx_event = context.tx().ok();
413        let tool_results =
414            ToolProcessor::process_tool_calls(tools, tool_calls.clone(), tx_event).await;
415
416        MemoryHelper::store_tool_interaction(
417            &context.memory(),
418            &tool_calls,
419            &tool_results,
420            &response_text,
421        )
422        .await;
423
424        let state = context.state();
425        let mut guard = state.lock().await;
426        for result in &tool_results {
427            guard.record_tool_call(result.clone());
428        }
429
430        Ok(StreamingTurnResult::ToolCallsProcessed(tool_results))
431    }
432
433    /// Get streaming LLM response
434    async fn get_llm_stream(
435        &self,
436        context: &Context,
437        messages: &[ChatMessage],
438        tools: &[Box<dyn ToolT>],
439    ) -> Result<
440        Pin<Box<dyn Stream<Item = Result<autoagents_llm::chat::StreamChunk, LLMError>> + Send>>,
441        ReActExecutorError,
442    > {
443        let llm = context.llm();
444        let agent_config = context.config();
445        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
446
447        llm.chat_stream_with_tools(
448            messages,
449            if tools.is_empty() {
450                None
451            } else {
452                Some(&tools_serialized)
453            },
454            agent_config.output_schema.clone(),
455        )
456        .await
457        .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
458    }
459}
460
461/// Implementation of AgentExecutor for the ReActExecutorWrapper
462#[async_trait]
463impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
464    type Output = ReActAgentOutput;
465    type Error = ReActExecutorError;
466
467    fn config(&self) -> ExecutorConfig {
468        ExecutorConfig { max_turns: 10 }
469    }
470
471    async fn execute(
472        &self,
473        task: &Task,
474        context: Arc<Context>,
475    ) -> Result<Self::Output, Self::Error> {
476        // Initialize task
477        MemoryHelper::store_user_message(
478            &context.memory(),
479            task.prompt.clone(),
480            task.image.clone(),
481        )
482        .await;
483
484        // Record task in state - use try_lock to avoid deadlock
485        {
486            let state = context.state();
487            #[cfg(not(target_arch = "wasm32"))]
488            if let Ok(mut guard) = state.try_lock() {
489                guard.record_task(task.clone());
490            };
491            #[cfg(target_arch = "wasm32")]
492            if let Some(mut guard) = state.try_lock() {
493                guard.record_task(task.clone());
494            };
495        }
496
497        // Send task started event
498        let tx_event = context.tx().ok();
499        EventHelper::send_task_started(
500            &tx_event,
501            task.submission_id,
502            context.config().id,
503            task.prompt.clone(),
504            context.config().name.clone(),
505        )
506        .await;
507
508        // Execute turns
509        let max_turns = self.config().max_turns;
510        let mut accumulated_tool_calls = Vec::new();
511        let mut final_response = String::new();
512
513        for turn_num in 0..max_turns {
514            let tools = context.tools();
515            EventHelper::send_turn_started(&tx_event, turn_num, max_turns).await;
516
517            //Run Hook
518            self.on_turn_start(turn_num, &context).await;
519
520            match self.process_turn(&context, tools).await? {
521                TurnResult::Complete(result) => {
522                    if !accumulated_tool_calls.is_empty() {
523                        return Ok(ReActAgentOutput {
524                            response: result.response,
525                            done: true,
526                            tool_calls: accumulated_tool_calls,
527                        });
528                    }
529                    EventHelper::send_turn_completed(&tx_event, turn_num, false).await;
530                    //Run Hook
531                    self.on_turn_complete(turn_num, &context).await;
532                    return Ok(result);
533                }
534                TurnResult::Continue(Some(partial_result)) => {
535                    accumulated_tool_calls.extend(partial_result.tool_calls);
536                    if !partial_result.response.is_empty() {
537                        final_response = partial_result.response;
538                    }
539                }
540                TurnResult::Continue(None) => continue,
541            }
542        }
543
544        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
545            EventHelper::send_task_completed(
546                &tx_event,
547                task.submission_id,
548                context.config().id,
549                final_response.clone(),
550                context.config().name.clone(),
551            )
552            .await;
553            Ok(ReActAgentOutput {
554                response: final_response,
555                done: true,
556                tool_calls: accumulated_tool_calls,
557            })
558        } else {
559            Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
560        }
561    }
562
563    async fn execute_stream(
564        &self,
565        task: &Task,
566        context: Arc<Context>,
567    ) -> Result<
568        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
569        Self::Error,
570    > {
571        // Initialize task
572        MemoryHelper::store_user_message(
573            &context.memory(),
574            task.prompt.clone(),
575            task.image.clone(),
576        )
577        .await;
578
579        // Record task in state - use try_lock to avoid deadlock
580        {
581            let state = context.state();
582            #[cfg(not(target_arch = "wasm32"))]
583            if let Ok(mut guard) = state.try_lock() {
584                guard.record_task(task.clone());
585            };
586            #[cfg(target_arch = "wasm32")]
587            if let Some(mut guard) = state.try_lock() {
588                guard.record_task(task.clone());
589            };
590        }
591
592        // Send task started event
593        let tx_event = context.tx().ok();
594        EventHelper::send_task_started(
595            &tx_event,
596            task.submission_id,
597            context.config().id,
598            task.prompt.clone(),
599            context.config().name.clone(),
600        )
601        .await;
602
603        // Create channel for streaming
604        let (mut tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
605
606        // Clone necessary components
607        let executor = self.clone();
608        let context_clone = context.clone();
609        let submission_id = task.submission_id;
610        let max_turns = executor.config().max_turns;
611
612        // Spawn streaming task
613        spawn_future(async move {
614            let mut accumulated_tool_calls = Vec::new();
615            let mut final_response = String::new();
616            let tools = context_clone.tools();
617
618            for turn in 0..max_turns {
619                // Send turn events
620                let tx_event = context_clone.tx().ok();
621                EventHelper::send_turn_started(&tx_event, turn, max_turns).await;
622
623                // Process streaming turn
624                match executor
625                    .process_streaming_turn(&context_clone, tools, &mut tx, submission_id)
626                    .await
627                {
628                    Ok(StreamingTurnResult::Complete(response)) => {
629                        final_response = response;
630                        EventHelper::send_turn_completed(&tx_event, turn, true).await;
631                        break;
632                    }
633                    Ok(StreamingTurnResult::ToolCallsProcessed(tool_results)) => {
634                        accumulated_tool_calls.extend(tool_results);
635
636                        let _ = tx
637                            .send(Ok(ReActAgentOutput {
638                                response: String::new(),
639                                done: false,
640                                tool_calls: accumulated_tool_calls.clone(),
641                            }))
642                            .await;
643
644                        EventHelper::send_turn_completed(&tx_event, turn, false).await;
645                    }
646                    Err(e) => {
647                        let _ = tx.send(Err(e)).await;
648                        return;
649                    }
650                }
651            }
652
653            // Send final result
654            let tx_event = context_clone.tx().ok();
655            EventHelper::send_stream_complete(&tx_event, submission_id).await;
656
657            let _ = tx
658                .send(Ok(ReActAgentOutput {
659                    response: final_response,
660                    done: true,
661                    tool_calls: accumulated_tool_calls,
662                }))
663                .await;
664        });
665
666        Ok(receiver_into_stream(rx))
667    }
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Builds a final output with the given response and no tool calls.
    fn output_with_response(response: String) -> ReActAgentOutput {
        ReActAgentOutput {
            response,
            done: true,
            tool_calls: vec![],
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = output_with_response(serde_json::to_string(&agent_output).unwrap());

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let react_value =
            serde_json::to_value(output_with_response("plain text".to_string())).unwrap();
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(react_value);
        assert!(matches!(result, Err(ReActExecutorError::AgentOutputError(_))));
    }

    #[test]
    fn test_extract_agent_output_rejects_wrong_envelope() {
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(
            serde_json::json!({ "unexpected": 1 }),
        );
        assert!(matches!(result, Err(ReActExecutorError::AgentOutputError(_))));
    }

    #[test]
    fn test_try_parse_and_parse_or_map() {
        let parsed = output_with_response(r#"{"value":1,"message":"ok"}"#.to_string());
        assert_eq!(
            parsed.try_parse::<TestAgentOutput>().unwrap(),
            TestAgentOutput {
                value: 1,
                message: "ok".to_string(),
            }
        );

        let raw = output_with_response("not json".to_string());
        assert!(raw.try_parse::<TestAgentOutput>().is_err());
        let mapped = raw.parse_or_map(|text| TestAgentOutput {
            value: 0,
            message: text.to_string(),
        });
        assert_eq!(mapped.message, "not json");
        assert_eq!(mapped.value, 0);
    }

    #[test]
    fn test_string_and_value_conversions() {
        let output = output_with_response("hello".to_string());

        let value: Value = output.clone().into();
        assert_eq!(value["response"], "hello");
        assert_eq!(value["done"], true);

        let text: String = output.into();
        assert_eq!(text, "hello");
    }
}