autoagents_core/agent/prebuilt/executor/react.rs

1use crate::agent::executor::AgentExecutor;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, Context, ExecutorConfig, TurnResult};
4use crate::protocol::{Event, StreamingTurnResult, SubmissionId};
5use crate::tool::{to_llm_tool, ToolCallResult, ToolT};
6use async_trait::async_trait;
7use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChoice, Tool};
8use autoagents_llm::error::LLMError;
9use autoagents_llm::{FunctionCall, ToolCall};
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17use thiserror::Error;
18
19#[cfg(not(target_arch = "wasm32"))]
20pub use tokio::sync::mpsc::error::SendError;
21
22#[cfg(target_arch = "wasm32")]
23pub use futures::lock::Mutex;
24#[cfg(target_arch = "wasm32")]
25use futures::SinkExt;
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::executor::event_helper::EventHelper;
30use crate::agent::executor::memory_helper::MemoryHelper;
31use crate::agent::executor::tool_processor::ToolProcessor;
32use crate::agent::hooks::{AgentHooks, HookOutcome};
33use crate::channel::{channel, Sender};
34use crate::utils::{receiver_into_stream, spawn_future};
35
36/// Output of the ReAct-style agent
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ReActAgentOutput {
39    pub response: String,
40    pub tool_calls: Vec<ToolCallResult>,
41    pub done: bool,
42}
43
44impl From<ReActAgentOutput> for Value {
45    fn from(output: ReActAgentOutput) -> Self {
46        serde_json::to_value(output).unwrap_or(Value::Null)
47    }
48}
49impl From<ReActAgentOutput> for String {
50    fn from(output: ReActAgentOutput) -> Self {
51        output.response
52    }
53}
54
55impl ReActAgentOutput {
56    /// Extract the agent output from the ReAct response
57    #[allow(clippy::result_large_err)]
58    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
59    where
60        T: for<'de> serde::Deserialize<'de>,
61    {
62        let react_output: Self = serde_json::from_value(val)
63            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
64        serde_json::from_str(&react_output.response)
65            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
66    }
67}
68
69#[derive(Error, Debug)]
70pub enum ReActExecutorError {
71    #[error("LLM error: {0}")]
72    LLMError(String),
73
74    #[error("Maximum turns exceeded: {max_turns}")]
75    MaxTurnsExceeded { max_turns: usize },
76
77    #[error("Other error: {0}")]
78    Other(String),
79
80    #[cfg(not(target_arch = "wasm32"))]
81    #[error("Event error: {0}")]
82    EventError(#[from] SendError<Event>),
83
84    #[cfg(target_arch = "wasm32")]
85    #[error("Event error: {0}")]
86    EventError(#[from] SendError),
87
88    #[error("Extracting Agent Output Error: {0}")]
89    AgentOutputError(String),
90}
91
92/// Wrapper type for ReAct executor
93#[derive(Debug)]
94pub struct ReActAgent<T: AgentDeriveT> {
95    inner: Arc<T>,
96}
97
98impl<T: AgentDeriveT> Clone for ReActAgent<T> {
99    fn clone(&self) -> Self {
100        Self {
101            inner: Arc::clone(&self.inner),
102        }
103    }
104}
105
106impl<T: AgentDeriveT> ReActAgent<T> {
107    pub fn new(inner: T) -> Self {
108        Self {
109            inner: Arc::new(inner),
110        }
111    }
112}
113
114impl<T: AgentDeriveT> Deref for ReActAgent<T> {
115    type Target = T;
116
117    fn deref(&self) -> &Self::Target {
118        &self.inner
119    }
120}
121
122/// Implement AgentDeriveT for the wrapper by delegating to the inner type
123#[async_trait]
124impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
125    type Output = <T as AgentDeriveT>::Output;
126
127    fn description(&self) -> &'static str {
128        self.inner.description()
129    }
130
131    fn output_schema(&self) -> Option<Value> {
132        self.inner.output_schema()
133    }
134
135    fn name(&self) -> &'static str {
136        self.inner.name()
137    }
138
139    fn tools(&self) -> Vec<Box<dyn ToolT>> {
140        self.inner.tools()
141    }
142}
143
144#[async_trait]
145impl<T> AgentHooks for ReActAgent<T>
146where
147    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
148{
149    async fn on_agent_create(&self) {
150        self.inner.on_agent_create().await
151    }
152
153    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
154        self.inner.on_run_start(task, ctx).await
155    }
156
157    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
158        self.inner.on_run_complete(task, result, ctx).await
159    }
160
161    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
162        self.inner.on_turn_start(turn_index, ctx).await
163    }
164
165    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
166        self.inner.on_turn_complete(turn_index, ctx).await
167    }
168
169    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
170        self.inner.on_tool_call(tool_call, ctx).await
171    }
172
173    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
174        self.inner.on_tool_start(tool_call, ctx).await
175    }
176
177    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
178        self.inner.on_tool_result(tool_call, result, ctx).await
179    }
180
181    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
182        self.inner.on_tool_error(tool_call, err, ctx).await
183    }
184    async fn on_agent_shutdown(&self) {
185        self.inner.on_agent_shutdown().await
186    }
187}
188
189impl<T: AgentDeriveT + AgentHooks> ReActAgent<T> {
190    /// Process a single turn with the LLM
191    async fn process_turn(
192        &self,
193        context: &Context,
194        tools: &[Box<dyn ToolT>],
195    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
196        let messages = self.prepare_messages(context).await;
197        let response = self.get_llm_response(context, &messages, tools).await?;
198        let response_text = response.text().unwrap_or_default();
199
200        if let Some(tool_calls) = response.tool_calls() {
201            self.handle_tool_calls(context, tools, tool_calls.clone(), response_text)
202                .await
203        } else {
204            self.handle_text_response(context, response_text).await
205        }
206    }
207
208    /// Get LLM response for the given messages and tools
209    async fn get_llm_response(
210        &self,
211        context: &Context,
212        messages: &[ChatMessage],
213        tools: &[Box<dyn ToolT>],
214    ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, ReActExecutorError> {
215        let llm = context.llm();
216        let agent_config = context.config();
217        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
218
219        llm.chat(
220            messages,
221            if tools.is_empty() {
222                None
223            } else {
224                Some(&tools_serialized)
225            },
226            agent_config.output_schema.clone(),
227        )
228        .await
229        .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
230    }
231
232    /// Handle tool calls and return the result
233    async fn handle_tool_calls(
234        &self,
235        context: &Context,
236        tools: &[Box<dyn ToolT>],
237        tool_calls: Vec<ToolCall>,
238        response_text: String,
239    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
240        let tx_event = context.tx().ok();
241
242        // Process tool calls
243        let mut tool_results = Vec::new();
244        for call in &tool_calls {
245            if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
246                self, context, tools, call, &tx_event,
247            )
248            .await
249            {
250                tool_results.push(result);
251            }
252        }
253
254        // Store in memory
255        MemoryHelper::store_tool_interaction(
256            &context.memory(),
257            &tool_calls,
258            &tool_results,
259            &response_text,
260        )
261        .await;
262
263        // Update state - use try_lock to avoid deadlock
264        {
265            let state = context.state();
266            #[cfg(not(target_arch = "wasm32"))]
267            if let Ok(mut guard) = state.try_lock() {
268                for result in &tool_results {
269                    guard.record_tool_call(result.clone());
270                }
271            };
272            #[cfg(target_arch = "wasm32")]
273            if let Some(mut guard) = state.try_lock() {
274                for result in &tool_results {
275                    guard.record_tool_call(result.clone());
276                }
277            };
278        }
279
280        Ok(TurnResult::Continue(Some(ReActAgentOutput {
281            response: response_text,
282            done: true,
283            tool_calls: tool_results,
284        })))
285    }
286
287    /// Handle text-only response
288    async fn handle_text_response(
289        &self,
290        context: &Context,
291        response_text: String,
292    ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
293        if !response_text.is_empty() {
294            MemoryHelper::store_assistant_response(&context.memory(), response_text.clone()).await;
295        }
296
297        Ok(TurnResult::Complete(ReActAgentOutput {
298            response: response_text,
299            done: true,
300            tool_calls: vec![],
301        }))
302    }
303
304    /// Prepare messages for the current turn
305    async fn prepare_messages(&self, context: &Context) -> Vec<ChatMessage> {
306        let mut messages = vec![ChatMessage {
307            role: ChatRole::System,
308            message_type: MessageType::Text,
309            content: context.config().description.clone(),
310        }];
311
312        let recalled = MemoryHelper::recall_messages(&context.memory()).await;
313        messages.extend(recalled);
314
315        messages
316    }
317
318    /// Process a streaming turn with tool support
319    async fn process_streaming_turn(
320        &self,
321        context: &Context,
322        tools: &[Box<dyn ToolT>],
323        tx: &mut Sender<Result<ReActAgentOutput, ReActExecutorError>>,
324        submission_id: SubmissionId,
325    ) -> Result<StreamingTurnResult, ReActExecutorError> {
326        let messages = self.prepare_messages(context).await;
327        let mut stream = self.get_llm_stream(context, &messages, tools).await?;
328
329        let mut response_text = String::new();
330        let mut tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)> =
331            HashMap::new();
332
333        // Process stream chunks
334        while let Some(chunk_result) = stream.next().await {
335            let chunk = chunk_result.map_err(|e| ReActExecutorError::LLMError(e.to_string()))?;
336
337            if let Some(choice) = chunk.choices.first() {
338                // Handle content
339                if let Some(content) = &choice.delta.content {
340                    response_text.push_str(content);
341                    let _ = tx
342                        .send(Ok(ReActAgentOutput {
343                            response: content.to_string(),
344                            tool_calls: vec![],
345                            done: false,
346                        }))
347                        .await;
348                }
349
350                // Handle tool calls
351                self.process_stream_tool_calls(&mut tool_calls_map, choice);
352
353                // Send stream chunk event
354                let tx_event = context.tx().ok();
355                EventHelper::send_stream_chunk(&tx_event, submission_id, choice.clone()).await;
356            }
357        }
358
359        // Process collected tool calls if any
360        self.finalize_stream_tool_calls(
361            context,
362            tools,
363            tool_calls_map,
364            submission_id,
365            response_text,
366        )
367        .await
368    }
369
370    /// Get streaming LLM response
371    async fn get_llm_stream(
372        &self,
373        context: &Context,
374        messages: &[ChatMessage],
375        tools: &[Box<dyn ToolT>],
376    ) -> Result<
377        Pin<Box<dyn Stream<Item = Result<autoagents_llm::chat::StreamResponse, LLMError>> + Send>>,
378        ReActExecutorError,
379    > {
380        let llm = context.llm();
381        let agent_config = context.config();
382        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
383
384        llm.chat_stream_struct(
385            messages,
386            if tools.is_empty() {
387                None
388            } else {
389                Some(&tools_serialized)
390            },
391            agent_config.output_schema.clone(),
392        )
393        .await
394        .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
395    }
396
397    /// Process tool calls from stream chunks
398    fn process_stream_tool_calls(
399        &self,
400        tool_calls_map: &mut HashMap<usize, (Option<String>, Option<String>, String)>,
401        choice: &StreamChoice,
402    ) {
403        if let Some(tool_call_deltas) = &choice.delta.tool_calls {
404            for delta in tool_call_deltas {
405                let entry =
406                    tool_calls_map
407                        .entry(delta.index)
408                        .or_insert((None, None, String::new()));
409
410                if let Some(function) = &delta.function {
411                    if !function.name.is_empty() {
412                        entry.0 = Some(function.name.clone());
413                    }
414                    entry.2.push_str(&function.arguments);
415                }
416            }
417        }
418    }
419
420    /// Finalize and process collected tool calls from streaming
421    async fn finalize_stream_tool_calls(
422        &self,
423        context: &Context,
424        tools: &[Box<dyn ToolT>],
425        tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)>,
426        submission_id: SubmissionId,
427        response_text: String,
428    ) -> Result<StreamingTurnResult, ReActExecutorError> {
429        if tool_calls_map.is_empty() {
430            if !response_text.is_empty() {
431                MemoryHelper::store_assistant_response(&context.memory(), response_text.clone())
432                    .await;
433            }
434            return Ok(StreamingTurnResult::Complete(response_text));
435        }
436
437        // Convert map to tool calls
438        let mut sorted_calls: Vec<_> = tool_calls_map.into_iter().collect();
439        sorted_calls.sort_by_key(|(index, _)| *index);
440
441        let collected_tool_calls: Vec<ToolCall> = sorted_calls
442            .into_iter()
443            .filter_map(|(_, (name, id, args))| {
444                name.map(|name| ToolCall {
445                    id: id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
446                    call_type: "function".to_string(),
447                    function: FunctionCall {
448                        name,
449                        arguments: args,
450                    },
451                })
452            })
453            .collect();
454
455        // Send tool call events
456        let tx_event = context.tx().ok();
457        for tool_call in &collected_tool_calls {
458            EventHelper::send_stream_tool_call(
459                &tx_event,
460                submission_id,
461                serde_json::to_value(tool_call).unwrap_or(Value::Null),
462            )
463            .await;
464        }
465
466        // Process tool calls
467        let tool_results =
468            ToolProcessor::process_tool_calls(tools, collected_tool_calls.clone(), tx_event).await;
469
470        // Update memory
471        MemoryHelper::store_tool_interaction(
472            &context.memory(),
473            &collected_tool_calls,
474            &tool_results,
475            &response_text,
476        )
477        .await;
478
479        // Update state
480        let state = context.state();
481        let mut guard = state.lock().await;
482        for result in &tool_results {
483            guard.record_tool_call(result.clone());
484        }
485
486        Ok(StreamingTurnResult::ToolCallsProcessed(tool_results))
487    }
488}
489
490/// Implementation of AgentExecutor for the ReActExecutorWrapper
491#[async_trait]
492impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
493    type Output = ReActAgentOutput;
494    type Error = ReActExecutorError;
495
496    fn config(&self) -> ExecutorConfig {
497        ExecutorConfig { max_turns: 10 }
498    }
499
500    async fn execute(
501        &self,
502        task: &Task,
503        context: Arc<Context>,
504    ) -> Result<Self::Output, Self::Error> {
505        // Initialize task
506        MemoryHelper::store_user_message(
507            &context.memory(),
508            task.prompt.clone(),
509            task.image.clone(),
510        )
511        .await;
512
513        // Record task in state - use try_lock to avoid deadlock
514        {
515            let state = context.state();
516            #[cfg(not(target_arch = "wasm32"))]
517            if let Ok(mut guard) = state.try_lock() {
518                guard.record_task(task.clone());
519            };
520            #[cfg(target_arch = "wasm32")]
521            if let Some(mut guard) = state.try_lock() {
522                guard.record_task(task.clone());
523            };
524        }
525
526        // Send task started event
527        let tx_event = context.tx().ok();
528        EventHelper::send_task_started(
529            &tx_event,
530            task.submission_id,
531            context.config().id,
532            task.prompt.clone(),
533            context.config().name.clone(),
534        )
535        .await;
536
537        // Execute turns
538        let max_turns = self.config().max_turns;
539        let mut accumulated_tool_calls = Vec::new();
540        let mut final_response = String::new();
541
542        for turn_num in 0..max_turns {
543            let tools = context.tools();
544            EventHelper::send_turn_started(&tx_event, turn_num, max_turns).await;
545
546            //Run Hook
547            self.on_turn_start(turn_num, &context).await;
548
549            match self.process_turn(&context, tools).await? {
550                TurnResult::Complete(result) => {
551                    if !accumulated_tool_calls.is_empty() {
552                        return Ok(ReActAgentOutput {
553                            response: result.response,
554                            done: true,
555                            tool_calls: accumulated_tool_calls,
556                        });
557                    }
558                    EventHelper::send_turn_completed(&tx_event, turn_num, false).await;
559                    //Run Hook
560                    self.on_turn_complete(turn_num, &context).await;
561                    return Ok(result);
562                }
563                TurnResult::Continue(Some(partial_result)) => {
564                    accumulated_tool_calls.extend(partial_result.tool_calls);
565                    if !partial_result.response.is_empty() {
566                        final_response = partial_result.response;
567                    }
568                }
569                TurnResult::Continue(None) => continue,
570            }
571        }
572
573        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
574            EventHelper::send_task_completed(
575                &tx_event,
576                task.submission_id,
577                context.config().id,
578                final_response.clone(),
579                context.config().name.clone(),
580            )
581            .await;
582            Ok(ReActAgentOutput {
583                response: final_response,
584                done: true,
585                tool_calls: accumulated_tool_calls,
586            })
587        } else {
588            Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
589        }
590    }
591
592    async fn execute_stream(
593        &self,
594        task: &Task,
595        context: Arc<Context>,
596    ) -> Result<
597        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
598        Self::Error,
599    > {
600        // Initialize task
601        MemoryHelper::store_user_message(
602            &context.memory(),
603            task.prompt.clone(),
604            task.image.clone(),
605        )
606        .await;
607
608        // Record task in state - use try_lock to avoid deadlock
609        {
610            let state = context.state();
611            #[cfg(not(target_arch = "wasm32"))]
612            if let Ok(mut guard) = state.try_lock() {
613                guard.record_task(task.clone());
614            };
615            #[cfg(target_arch = "wasm32")]
616            if let Some(mut guard) = state.try_lock() {
617                guard.record_task(task.clone());
618            };
619        }
620
621        // Send task started event
622        let tx_event = context.tx().ok();
623        EventHelper::send_task_started(
624            &tx_event,
625            task.submission_id,
626            context.config().id,
627            task.prompt.clone(),
628            context.config().name.clone(),
629        )
630        .await;
631
632        // Create channel for streaming
633        let (mut tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
634
635        // Clone necessary components
636        let executor = self.clone();
637        let context_clone = context.clone();
638        let submission_id = task.submission_id;
639        let max_turns = executor.config().max_turns;
640
641        // Spawn streaming task
642        spawn_future(async move {
643            let mut accumulated_tool_calls = Vec::new();
644            let mut final_response = String::new();
645            let tools = context_clone.tools();
646
647            for turn in 0..max_turns {
648                // Send turn events
649                let tx_event = context_clone.tx().ok();
650                EventHelper::send_turn_started(&tx_event, turn, max_turns).await;
651
652                // Process streaming turn
653                match executor
654                    .process_streaming_turn(&context_clone, tools, &mut tx, submission_id)
655                    .await
656                {
657                    Ok(StreamingTurnResult::Complete(response)) => {
658                        final_response = response;
659                        EventHelper::send_turn_completed(&tx_event, turn, true).await;
660                        break;
661                    }
662                    Ok(StreamingTurnResult::ToolCallsProcessed(tool_results)) => {
663                        accumulated_tool_calls.extend(tool_results);
664
665                        let _ = tx
666                            .send(Ok(ReActAgentOutput {
667                                response: String::new(),
668                                done: false,
669                                tool_calls: accumulated_tool_calls.clone(),
670                            }))
671                            .await;
672
673                        EventHelper::send_turn_completed(&tx_event, turn, false).await;
674                    }
675                    Err(e) => {
676                        let _ = tx.send(Err(e)).await;
677                        return;
678                    }
679                }
680            }
681
682            // Send final result
683            let tx_event = context_clone.tx().ok();
684            EventHelper::send_stream_complete(&tx_event, submission_id).await;
685
686            let _ = tx
687                .send(Ok(ReActAgentOutput {
688                    response: final_response,
689                    done: true,
690                    tool_calls: accumulated_tool_calls,
691                }))
692                .await;
693        });
694
695        Ok(receiver_into_stream(rx))
696    }
697}
698
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Wrap `response` in a serialized `ReActAgentOutput` envelope with no tool calls.
    fn envelope(response: String) -> Value {
        serde_json::to_value(ReActAgentOutput {
            response,
            done: true,
            tool_calls: vec![],
        })
        .unwrap()
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_value = envelope(serde_json::to_string(&agent_output).unwrap());
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let react_value = envelope("plain text, not JSON".to_string());
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(react_value);
        assert!(matches!(
            result,
            Err(ReActExecutorError::AgentOutputError(_))
        ));
    }

    #[test]
    fn test_extract_agent_output_rejects_malformed_envelope() {
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(Value::Null);
        assert!(matches!(
            result,
            Err(ReActExecutorError::AgentOutputError(_))
        ));
    }

    #[test]
    fn test_output_converts_to_response_string() {
        let output = ReActAgentOutput {
            response: "final answer".to_string(),
            done: true,
            tool_calls: vec![],
        };
        assert_eq!(String::from(output), "final answer");
    }
}