autoagents_core/agent/prebuilt/executor/
react.rs

1use crate::agent::executor::AgentExecutor;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, Context, ExecutorConfig, TurnResult};
4use crate::protocol::{Event, StreamingTurnResult, SubmissionId};
5use crate::tool::{to_llm_tool, ToolCallResult, ToolT};
6use async_trait::async_trait;
7use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChoice, Tool};
8use autoagents_llm::error::LLMError;
9use autoagents_llm::{FunctionCall, ToolCall};
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17use thiserror::Error;
18
19#[cfg(not(target_arch = "wasm32"))]
20pub use tokio::sync::mpsc::error::SendError;
21
22#[cfg(target_arch = "wasm32")]
23pub use futures::lock::Mutex;
24#[cfg(target_arch = "wasm32")]
25use futures::SinkExt;
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::executor::event_helper::EventHelper;
30use crate::agent::executor::memory_helper::MemoryHelper;
31use crate::agent::executor::tool_processor::ToolProcessor;
32use crate::agent::hooks::{AgentHooks, HookOutcome};
33use crate::channel::{channel, Sender};
34use crate::utils::{receiver_into_stream, spawn_future};
35
36/// Output of the ReAct-style agent
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ReActAgentOutput {
39    pub response: String,
40    pub tool_calls: Vec<ToolCallResult>,
41    pub done: bool,
42}
43
44impl From<ReActAgentOutput> for Value {
45    fn from(output: ReActAgentOutput) -> Self {
46        serde_json::to_value(output).unwrap_or(Value::Null)
47    }
48}
49impl From<ReActAgentOutput> for String {
50    fn from(output: ReActAgentOutput) -> Self {
51        output.response
52    }
53}
54
55impl ReActAgentOutput {
56    /// Try to parse the response string as structured JSON of type `T`.
57    /// Returns `serde_json::Error` if parsing fails.
58    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
59        serde_json::from_str::<T>(&self.response)
60    }
61
62    /// Parse the response string as structured JSON of type `T`, or map the raw
63    /// text into `T` using the provided fallback function if parsing fails.
64    /// This is useful in examples to avoid repeating parsing boilerplate.
65    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
66    where
67        T: for<'de> serde::Deserialize<'de>,
68        F: FnOnce(&str) -> T,
69    {
70        self.try_parse::<T>()
71            .unwrap_or_else(|_| fallback(&self.response))
72    }
73}
74
75impl ReActAgentOutput {
76    /// Extract the agent output from the ReAct response
77    #[allow(clippy::result_large_err)]
78    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
79    where
80        T: for<'de> serde::Deserialize<'de>,
81    {
82        let react_output: Self = serde_json::from_value(val)
83            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
84        serde_json::from_str(&react_output.response)
85            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
86    }
87}
88
89#[derive(Error, Debug)]
90pub enum ReActExecutorError {
91    #[error("LLM error: {0}")]
92    LLMError(String),
93
94    #[error("Maximum turns exceeded: {max_turns}")]
95    MaxTurnsExceeded { max_turns: usize },
96
97    #[error("Other error: {0}")]
98    Other(String),
99
100    #[cfg(not(target_arch = "wasm32"))]
101    #[error("Event error: {0}")]
102    EventError(#[from] SendError<Event>),
103
104    #[cfg(target_arch = "wasm32")]
105    #[error("Event error: {0}")]
106    EventError(#[from] SendError),
107
108    #[error("Extracting Agent Output Error: {0}")]
109    AgentOutputError(String),
110}
111
112/// Wrapper type for the multi-turn ReAct executor with tool calling support.
113///
114/// Use `ReActAgent<T>` when your agent needs to perform tool calls, manage
115/// multiple turns, and optionally stream content and tool-call deltas.
116#[derive(Debug)]
117pub struct ReActAgent<T: AgentDeriveT> {
118    inner: Arc<T>,
119}
120
121impl<T: AgentDeriveT> Clone for ReActAgent<T> {
122    fn clone(&self) -> Self {
123        Self {
124            inner: Arc::clone(&self.inner),
125        }
126    }
127}
128
129impl<T: AgentDeriveT> ReActAgent<T> {
130    pub fn new(inner: T) -> Self {
131        Self {
132            inner: Arc::new(inner),
133        }
134    }
135}
136
137impl<T: AgentDeriveT> Deref for ReActAgent<T> {
138    type Target = T;
139
140    fn deref(&self) -> &Self::Target {
141        &self.inner
142    }
143}
144
145/// Implement AgentDeriveT for the wrapper by delegating to the inner type
146#[async_trait]
147impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
148    type Output = <T as AgentDeriveT>::Output;
149
150    fn description(&self) -> &'static str {
151        self.inner.description()
152    }
153
154    fn output_schema(&self) -> Option<Value> {
155        self.inner.output_schema()
156    }
157
158    fn name(&self) -> &'static str {
159        self.inner.name()
160    }
161
162    fn tools(&self) -> Vec<Box<dyn ToolT>> {
163        self.inner.tools()
164    }
165}
166
167#[async_trait]
168impl<T> AgentHooks for ReActAgent<T>
169where
170    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
171{
172    async fn on_agent_create(&self) {
173        self.inner.on_agent_create().await
174    }
175
176    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
177        self.inner.on_run_start(task, ctx).await
178    }
179
180    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
181        self.inner.on_run_complete(task, result, ctx).await
182    }
183
184    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
185        self.inner.on_turn_start(turn_index, ctx).await
186    }
187
188    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
189        self.inner.on_turn_complete(turn_index, ctx).await
190    }
191
192    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
193        self.inner.on_tool_call(tool_call, ctx).await
194    }
195
196    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
197        self.inner.on_tool_start(tool_call, ctx).await
198    }
199
200    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
201        self.inner.on_tool_result(tool_call, result, ctx).await
202    }
203
204    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
205        self.inner.on_tool_error(tool_call, err, ctx).await
206    }
207    async fn on_agent_shutdown(&self) {
208        self.inner.on_agent_shutdown().await
209    }
210}
211
212impl<T: AgentDeriveT + AgentHooks> ReActAgent<T> {
213    /// Process a single turn with the LLM
214    async fn process_turn(
215        &self,
216        context: &Context,
217        tools: &[Box<dyn ToolT>],
218    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
219        let messages = self.prepare_messages(context).await;
220        let response = self.get_llm_response(context, &messages, tools).await?;
221        let response_text = response.text().unwrap_or_default();
222
223        if let Some(tool_calls) = response.tool_calls() {
224            self.handle_tool_calls(context, tools, tool_calls.clone(), response_text)
225                .await
226        } else {
227            self.handle_text_response(context, response_text).await
228        }
229    }
230
231    /// Get LLM response for the given messages and tools
232    async fn get_llm_response(
233        &self,
234        context: &Context,
235        messages: &[ChatMessage],
236        tools: &[Box<dyn ToolT>],
237    ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, ReActExecutorError> {
238        let llm = context.llm();
239        let agent_config = context.config();
240        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
241
242        llm.chat(
243            messages,
244            if tools.is_empty() {
245                None
246            } else {
247                Some(&tools_serialized)
248            },
249            agent_config.output_schema.clone(),
250        )
251        .await
252        .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
253    }
254
255    /// Handle tool calls and return the result
256    async fn handle_tool_calls(
257        &self,
258        context: &Context,
259        tools: &[Box<dyn ToolT>],
260        tool_calls: Vec<ToolCall>,
261        response_text: String,
262    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
263        let tx_event = context.tx().ok();
264
265        // Process tool calls
266        let mut tool_results = Vec::new();
267        for call in &tool_calls {
268            if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
269                self, context, tools, call, &tx_event,
270            )
271            .await
272            {
273                tool_results.push(result);
274            }
275        }
276
277        // Store in memory
278        MemoryHelper::store_tool_interaction(
279            &context.memory(),
280            &tool_calls,
281            &tool_results,
282            &response_text,
283        )
284        .await;
285
286        // Update state - use try_lock to avoid deadlock
287        {
288            let state = context.state();
289            #[cfg(not(target_arch = "wasm32"))]
290            if let Ok(mut guard) = state.try_lock() {
291                for result in &tool_results {
292                    guard.record_tool_call(result.clone());
293                }
294            };
295            #[cfg(target_arch = "wasm32")]
296            if let Some(mut guard) = state.try_lock() {
297                for result in &tool_results {
298                    guard.record_tool_call(result.clone());
299                }
300            };
301        }
302
303        Ok(TurnResult::Continue(Some(ReActAgentOutput {
304            response: response_text,
305            done: true,
306            tool_calls: tool_results,
307        })))
308    }
309
310    /// Handle text-only response
311    async fn handle_text_response(
312        &self,
313        context: &Context,
314        response_text: String,
315    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
316        if !response_text.is_empty() {
317            MemoryHelper::store_assistant_response(&context.memory(), response_text.clone()).await;
318        }
319
320        Ok(TurnResult::Complete(ReActAgentOutput {
321            response: response_text,
322            done: true,
323            tool_calls: vec![],
324        }))
325    }
326
327    /// Prepare messages for the current turn
328    async fn prepare_messages(&self, context: &Context) -> Vec<ChatMessage> {
329        let mut messages = vec![ChatMessage {
330            role: ChatRole::System,
331            message_type: MessageType::Text,
332            content: context.config().description.clone(),
333        }];
334
335        let recalled = MemoryHelper::recall_messages(&context.memory()).await;
336        messages.extend(recalled);
337
338        messages
339    }
340
341    /// Process a streaming turn with tool support
342    async fn process_streaming_turn(
343        &self,
344        context: &Context,
345        tools: &[Box<dyn ToolT>],
346        tx: &mut Sender<Result<ReActAgentOutput, ReActExecutorError>>,
347        submission_id: SubmissionId,
348    ) -> Result<StreamingTurnResult, ReActExecutorError> {
349        let messages = self.prepare_messages(context).await;
350        let mut stream = self.get_llm_stream(context, &messages, tools).await?;
351
352        let mut response_text = String::new();
353        let mut tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)> =
354            HashMap::new();
355
356        // Process stream chunks
357        while let Some(chunk_result) = stream.next().await {
358            let chunk = chunk_result.map_err(|e| ReActExecutorError::LLMError(e.to_string()))?;
359
360            if let Some(choice) = chunk.choices.first() {
361                // Handle content
362                if let Some(content) = &choice.delta.content {
363                    response_text.push_str(content);
364                    let _ = tx
365                        .send(Ok(ReActAgentOutput {
366                            response: content.to_string(),
367                            tool_calls: vec![],
368                            done: false,
369                        }))
370                        .await;
371                }
372
373                // Handle tool calls
374                self.process_stream_tool_calls(&mut tool_calls_map, choice);
375
376                // Send stream chunk event
377                let tx_event = context.tx().ok();
378                EventHelper::send_stream_chunk(&tx_event, submission_id, choice.clone()).await;
379            }
380        }
381
382        // Process collected tool calls if any
383        self.finalize_stream_tool_calls(
384            context,
385            tools,
386            tool_calls_map,
387            submission_id,
388            response_text,
389        )
390        .await
391    }
392
393    /// Get streaming LLM response
394    async fn get_llm_stream(
395        &self,
396        context: &Context,
397        messages: &[ChatMessage],
398        tools: &[Box<dyn ToolT>],
399    ) -> Result<
400        Pin<Box<dyn Stream<Item = Result<autoagents_llm::chat::StreamResponse, LLMError>> + Send>>,
401        ReActExecutorError,
402    > {
403        let llm = context.llm();
404        let agent_config = context.config();
405        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
406
407        llm.chat_stream_struct(
408            messages,
409            if tools.is_empty() {
410                None
411            } else {
412                Some(&tools_serialized)
413            },
414            agent_config.output_schema.clone(),
415        )
416        .await
417        .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
418    }
419
420    /// Process tool calls from stream chunks
421    fn process_stream_tool_calls(
422        &self,
423        tool_calls_map: &mut HashMap<usize, (Option<String>, Option<String>, String)>,
424        choice: &StreamChoice,
425    ) {
426        if let Some(tool_call_deltas) = &choice.delta.tool_calls {
427            for delta in tool_call_deltas {
428                let entry =
429                    tool_calls_map
430                        .entry(delta.index)
431                        .or_insert((None, None, String::new()));
432
433                if let Some(function) = &delta.function {
434                    if !function.name.is_empty() {
435                        entry.0 = Some(function.name.clone());
436                    }
437                    entry.2.push_str(&function.arguments);
438                }
439            }
440        }
441    }
442
443    /// Finalize and process collected tool calls from streaming
444    async fn finalize_stream_tool_calls(
445        &self,
446        context: &Context,
447        tools: &[Box<dyn ToolT>],
448        tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)>,
449        submission_id: SubmissionId,
450        response_text: String,
451    ) -> Result<StreamingTurnResult, ReActExecutorError> {
452        if tool_calls_map.is_empty() {
453            if !response_text.is_empty() {
454                MemoryHelper::store_assistant_response(&context.memory(), response_text.clone())
455                    .await;
456            }
457            return Ok(StreamingTurnResult::Complete(response_text));
458        }
459
460        // Convert map to tool calls
461        let mut sorted_calls: Vec<_> = tool_calls_map.into_iter().collect();
462        sorted_calls.sort_by_key(|(index, _)| *index);
463
464        let collected_tool_calls: Vec<ToolCall> = sorted_calls
465            .into_iter()
466            .filter_map(|(_, (name, id, args))| {
467                name.map(|name| ToolCall {
468                    id: id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
469                    call_type: "function".to_string(),
470                    function: FunctionCall {
471                        name,
472                        arguments: args,
473                    },
474                })
475            })
476            .collect();
477
478        // Send tool call events
479        let tx_event = context.tx().ok();
480        for tool_call in &collected_tool_calls {
481            EventHelper::send_stream_tool_call(
482                &tx_event,
483                submission_id,
484                serde_json::to_value(tool_call).unwrap_or(Value::Null),
485            )
486            .await;
487        }
488
489        // Process tool calls
490        let tool_results =
491            ToolProcessor::process_tool_calls(tools, collected_tool_calls.clone(), tx_event).await;
492
493        // Update memory
494        MemoryHelper::store_tool_interaction(
495            &context.memory(),
496            &collected_tool_calls,
497            &tool_results,
498            &response_text,
499        )
500        .await;
501
502        // Update state
503        let state = context.state();
504        let mut guard = state.lock().await;
505        for result in &tool_results {
506            guard.record_tool_call(result.clone());
507        }
508
509        Ok(StreamingTurnResult::ToolCallsProcessed(tool_results))
510    }
511}
512
513/// Implementation of AgentExecutor for the ReActExecutorWrapper
514#[async_trait]
515impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
516    type Output = ReActAgentOutput;
517    type Error = ReActExecutorError;
518
519    fn config(&self) -> ExecutorConfig {
520        ExecutorConfig { max_turns: 10 }
521    }
522
523    async fn execute(
524        &self,
525        task: &Task,
526        context: Arc<Context>,
527    ) -> Result<Self::Output, Self::Error> {
528        // Initialize task
529        MemoryHelper::store_user_message(
530            &context.memory(),
531            task.prompt.clone(),
532            task.image.clone(),
533        )
534        .await;
535
536        // Record task in state - use try_lock to avoid deadlock
537        {
538            let state = context.state();
539            #[cfg(not(target_arch = "wasm32"))]
540            if let Ok(mut guard) = state.try_lock() {
541                guard.record_task(task.clone());
542            };
543            #[cfg(target_arch = "wasm32")]
544            if let Some(mut guard) = state.try_lock() {
545                guard.record_task(task.clone());
546            };
547        }
548
549        // Send task started event
550        let tx_event = context.tx().ok();
551        EventHelper::send_task_started(
552            &tx_event,
553            task.submission_id,
554            context.config().id,
555            task.prompt.clone(),
556            context.config().name.clone(),
557        )
558        .await;
559
560        // Execute turns
561        let max_turns = self.config().max_turns;
562        let mut accumulated_tool_calls = Vec::new();
563        let mut final_response = String::new();
564
565        for turn_num in 0..max_turns {
566            let tools = context.tools();
567            EventHelper::send_turn_started(&tx_event, turn_num, max_turns).await;
568
569            //Run Hook
570            self.on_turn_start(turn_num, &context).await;
571
572            match self.process_turn(&context, tools).await? {
573                TurnResult::Complete(result) => {
574                    if !accumulated_tool_calls.is_empty() {
575                        return Ok(ReActAgentOutput {
576                            response: result.response,
577                            done: true,
578                            tool_calls: accumulated_tool_calls,
579                        });
580                    }
581                    EventHelper::send_turn_completed(&tx_event, turn_num, false).await;
582                    //Run Hook
583                    self.on_turn_complete(turn_num, &context).await;
584                    return Ok(result);
585                }
586                TurnResult::Continue(Some(partial_result)) => {
587                    accumulated_tool_calls.extend(partial_result.tool_calls);
588                    if !partial_result.response.is_empty() {
589                        final_response = partial_result.response;
590                    }
591                }
592                TurnResult::Continue(None) => continue,
593            }
594        }
595
596        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
597            EventHelper::send_task_completed(
598                &tx_event,
599                task.submission_id,
600                context.config().id,
601                final_response.clone(),
602                context.config().name.clone(),
603            )
604            .await;
605            Ok(ReActAgentOutput {
606                response: final_response,
607                done: true,
608                tool_calls: accumulated_tool_calls,
609            })
610        } else {
611            Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
612        }
613    }
614
615    async fn execute_stream(
616        &self,
617        task: &Task,
618        context: Arc<Context>,
619    ) -> Result<
620        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
621        Self::Error,
622    > {
623        // Initialize task
624        MemoryHelper::store_user_message(
625            &context.memory(),
626            task.prompt.clone(),
627            task.image.clone(),
628        )
629        .await;
630
631        // Record task in state - use try_lock to avoid deadlock
632        {
633            let state = context.state();
634            #[cfg(not(target_arch = "wasm32"))]
635            if let Ok(mut guard) = state.try_lock() {
636                guard.record_task(task.clone());
637            };
638            #[cfg(target_arch = "wasm32")]
639            if let Some(mut guard) = state.try_lock() {
640                guard.record_task(task.clone());
641            };
642        }
643
644        // Send task started event
645        let tx_event = context.tx().ok();
646        EventHelper::send_task_started(
647            &tx_event,
648            task.submission_id,
649            context.config().id,
650            task.prompt.clone(),
651            context.config().name.clone(),
652        )
653        .await;
654
655        // Create channel for streaming
656        let (mut tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
657
658        // Clone necessary components
659        let executor = self.clone();
660        let context_clone = context.clone();
661        let submission_id = task.submission_id;
662        let max_turns = executor.config().max_turns;
663
664        // Spawn streaming task
665        spawn_future(async move {
666            let mut accumulated_tool_calls = Vec::new();
667            let mut final_response = String::new();
668            let tools = context_clone.tools();
669
670            for turn in 0..max_turns {
671                // Send turn events
672                let tx_event = context_clone.tx().ok();
673                EventHelper::send_turn_started(&tx_event, turn, max_turns).await;
674
675                // Process streaming turn
676                match executor
677                    .process_streaming_turn(&context_clone, tools, &mut tx, submission_id)
678                    .await
679                {
680                    Ok(StreamingTurnResult::Complete(response)) => {
681                        final_response = response;
682                        EventHelper::send_turn_completed(&tx_event, turn, true).await;
683                        break;
684                    }
685                    Ok(StreamingTurnResult::ToolCallsProcessed(tool_results)) => {
686                        accumulated_tool_calls.extend(tool_results);
687
688                        let _ = tx
689                            .send(Ok(ReActAgentOutput {
690                                response: String::new(),
691                                done: false,
692                                tool_calls: accumulated_tool_calls.clone(),
693                            }))
694                            .await;
695
696                        EventHelper::send_turn_completed(&tx_event, turn, false).await;
697                    }
698                    Err(e) => {
699                        let _ = tx.send(Err(e)).await;
700                        return;
701                    }
702                }
703            }
704
705            // Send final result
706            let tx_event = context_clone.tx().ok();
707            EventHelper::send_stream_complete(&tx_event, submission_id).await;
708
709            let _ = tx
710                .send(Ok(ReActAgentOutput {
711                    response: final_response,
712                    done: true,
713                    tool_calls: accumulated_tool_calls,
714                }))
715                .await;
716        });
717
718        Ok(receiver_into_stream(rx))
719    }
720}
721
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal structured payload used as the typed agent output.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Serializes a finished `ReActAgentOutput` whose response text is
    /// `response` and which carries no tool calls.
    fn wrap_response(response: &str) -> Value {
        let react_output = ReActAgentOutput {
            response: response.to_string(),
            done: true,
            tool_calls: vec![],
        };
        serde_json::to_value(react_output).unwrap()
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_value = wrap_response(&serde_json::to_string(&agent_output).unwrap());
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    /// A response that is not JSON must surface as `AgentOutputError`.
    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let result =
            ReActAgentOutput::extract_agent_output::<TestAgentOutput>(wrap_response("not json"));
        assert!(matches!(
            result,
            Err(ReActExecutorError::AgentOutputError(_))
        ));
    }

    /// A value that is not a serialized `ReActAgentOutput` must surface as
    /// `AgentOutputError` as well.
    #[test]
    fn test_extract_agent_output_rejects_malformed_wrapper() {
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(Value::Null);
        assert!(matches!(
            result,
            Err(ReActExecutorError::AgentOutputError(_))
        ));
    }
}