Skip to main content

autoagents_core/agent/prebuilt/executor/
react.rs

1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4    TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22pub use tokio::sync::mpsc::error::SendError;
23
24#[cfg(target_arch = "wasm32")]
25type SendError = futures::channel::mpsc::SendError;
26
27use crate::agent::hooks::{AgentHooks, HookOutcome};
28use autoagents_protocol::Event;
29
/// Output of the ReAct-style agent.
///
/// Returned once by `AgentExecutor::execute` (always with `done: true`), and
/// yielded repeatedly by `AgentExecutor::execute_stream`: partial items carry
/// `done: false` and the final successful item carries `done: true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
    /// Response text. On streamed partial items this is a single text chunk
    /// (empty on tool-result updates); on the final item it is the response of
    /// the completing turn, or the last non-empty turn response when the turn
    /// budget ran out.
    pub response: String,
    /// Tool-call results. Empty on streamed text chunks; otherwise the results
    /// accumulated so far in the run.
    pub tool_calls: Vec<ToolCallResult>,
    /// `true` on the final output of a run.
    pub done: bool,
}
37
38impl From<ReActAgentOutput> for Value {
39    fn from(output: ReActAgentOutput) -> Self {
40        serde_json::to_value(output).unwrap_or(Value::Null)
41    }
42}
43impl From<ReActAgentOutput> for String {
44    fn from(output: ReActAgentOutput) -> Self {
45        output.response
46    }
47}
48
49impl ReActAgentOutput {
50    /// Try to parse the response string as structured JSON of type `T`.
51    /// Returns `serde_json::Error` if parsing fails.
52    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
53        serde_json::from_str::<T>(&self.response)
54    }
55
56    /// Parse the response string as structured JSON of type `T`, or map the raw
57    /// text into `T` using the provided fallback function if parsing fails.
58    /// This is useful in examples to avoid repeating parsing boilerplate.
59    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
60    where
61        T: for<'de> serde::Deserialize<'de>,
62        F: FnOnce(&str) -> T,
63    {
64        self.try_parse::<T>()
65            .unwrap_or_else(|_| fallback(&self.response))
66    }
67}
68
69impl ReActAgentOutput {
70    /// Extract the agent output from the ReAct response
71    #[allow(clippy::result_large_err)]
72    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
73    where
74        T: for<'de> serde::Deserialize<'de>,
75    {
76        let react_output: Self = serde_json::from_value(val)
77            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
78        serde_json::from_str(&react_output.response)
79            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
80    }
81}
82
/// Errors produced by the ReAct executor.
#[derive(Error, Debug)]
pub enum ReActExecutorError {
    /// Failure reported by the LLM (mapped from `TurnEngineError::LLMError`).
    #[error("LLM error: {0}")]
    LLMError(String),

    /// The turn budget was exhausted without any turn completing and without
    /// producing response text or tool results.
    #[error("Maximum turns exceeded: {max_turns}")]
    MaxTurnsExceeded { max_turns: usize },

    /// Catch-all error; also used for runs aborted by a hook
    /// ("Run aborted by hook").
    #[error("Other error: {0}")]
    Other(String),

    /// Sending an `Event` over a channel failed (non-wasm32 targets).
    #[cfg(not(target_arch = "wasm32"))]
    #[error("Event error: {0}")]
    EventError(#[from] SendError<Event>),

    /// Sending over a channel failed (wasm32 targets).
    #[cfg(target_arch = "wasm32")]
    #[error("Event error: {0}")]
    EventError(#[from] SendError),

    /// `ReActAgentOutput::extract_agent_output` could not decode the output.
    #[error("Extracting Agent Output Error: {0}")]
    AgentOutputError(String),
}
105
106impl From<TurnEngineError> for ReActExecutorError {
107    fn from(error: TurnEngineError) -> Self {
108        match error {
109            TurnEngineError::LLMError(err) => ReActExecutorError::LLMError(err),
110            TurnEngineError::Aborted => {
111                ReActExecutorError::Other("Run aborted by hook".to_string())
112            }
113            TurnEngineError::Other(err) => ReActExecutorError::Other(err),
114        }
115    }
116}
117
/// Wrapper type for the multi-turn ReAct executor with tool calling support.
///
/// Use `ReActAgent<T>` when your agent needs to perform tool calls, manage
/// multiple turns, and optionally stream content and tool-call deltas.
///
/// The wrapped agent is held in an `Arc`, so cloning a `ReActAgent` shares the
/// same inner agent instead of copying it.
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
    /// The wrapped agent; metadata and hooks are delegated to it.
    inner: Arc<T>,
}
126
127impl<T: AgentDeriveT> Clone for ReActAgent<T> {
128    fn clone(&self) -> Self {
129        Self {
130            inner: Arc::clone(&self.inner),
131        }
132    }
133}
134
135impl<T: AgentDeriveT> ReActAgent<T> {
136    pub fn new(inner: T) -> Self {
137        Self {
138            inner: Arc::new(inner),
139        }
140    }
141}
142
143impl<T: AgentDeriveT> Deref for ReActAgent<T> {
144    type Target = T;
145
146    fn deref(&self) -> &Self::Target {
147        &self.inner
148    }
149}
150
151/// Implement AgentDeriveT for the wrapper by delegating to the inner type
152#[async_trait]
153impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
154    type Output = <T as AgentDeriveT>::Output;
155
156    fn description(&self) -> &str {
157        self.inner.description()
158    }
159
160    fn output_schema(&self) -> Option<Value> {
161        self.inner.output_schema()
162    }
163
164    fn name(&self) -> &str {
165        self.inner.name()
166    }
167
168    fn tools(&self) -> Vec<Box<dyn ToolT>> {
169        self.inner.tools()
170    }
171}
172
173#[async_trait]
174impl<T> AgentHooks for ReActAgent<T>
175where
176    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
177{
178    async fn on_agent_create(&self) {
179        self.inner.on_agent_create().await
180    }
181
182    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
183        self.inner.on_run_start(task, ctx).await
184    }
185
186    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
187        self.inner.on_run_complete(task, result, ctx).await
188    }
189
190    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
191        self.inner.on_turn_start(turn_index, ctx).await
192    }
193
194    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
195        self.inner.on_turn_complete(turn_index, ctx).await
196    }
197
198    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
199        self.inner.on_tool_call(tool_call, ctx).await
200    }
201
202    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
203        self.inner.on_tool_start(tool_call, ctx).await
204    }
205
206    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
207        self.inner.on_tool_result(tool_call, result, ctx).await
208    }
209
210    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
211        self.inner.on_tool_error(tool_call, err, ctx).await
212    }
213    async fn on_agent_shutdown(&self) {
214        self.inner.on_agent_shutdown().await
215    }
216}
217
218/// Implementation of AgentExecutor for the ReActExecutorWrapper
219#[async_trait]
220impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
221    type Output = ReActAgentOutput;
222    type Error = ReActExecutorError;
223
224    fn config(&self) -> ExecutorConfig {
225        ExecutorConfig { max_turns: 10 }
226    }
227
228    async fn execute(
229        &self,
230        task: &Task,
231        context: Arc<Context>,
232    ) -> Result<Self::Output, Self::Error> {
233        if self.on_run_start(task, &context).await == HookOutcome::Abort {
234            return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
235        }
236
237        record_task_state(&context, task);
238
239        let tx_event = context.tx().ok();
240        EventHelper::send_task_started(
241            &tx_event,
242            task.submission_id,
243            context.config().id,
244            context.config().name.clone(),
245            task.prompt.clone(),
246        )
247        .await;
248
249        let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
250        let mut turn_state = engine.turn_state(&context);
251        let max_turns = self.config().max_turns;
252        let mut accumulated_tool_calls = Vec::new();
253        let mut final_response = String::new();
254
255        for turn_index in 0..max_turns {
256            let result = engine
257                .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
258                .await?;
259
260            match result {
261                crate::agent::executor::TurnResult::Complete(output) => {
262                    final_response = output.response.clone();
263                    EventHelper::send_task_completed(
264                        &tx_event,
265                        task.submission_id,
266                        context.config().id,
267                        context.config().name.clone(),
268                        final_response.clone(),
269                    )
270                    .await;
271
272                    accumulated_tool_calls.extend(output.tool_calls);
273
274                    return Ok(ReActAgentOutput {
275                        response: final_response,
276                        done: true,
277                        tool_calls: accumulated_tool_calls,
278                    });
279                }
280                crate::agent::executor::TurnResult::Continue(Some(output)) => {
281                    if !output.response.is_empty() {
282                        final_response = output.response;
283                    }
284                    accumulated_tool_calls.extend(output.tool_calls);
285                }
286                crate::agent::executor::TurnResult::Continue(None) => {}
287            }
288        }
289
290        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
291            EventHelper::send_task_completed(
292                &tx_event,
293                task.submission_id,
294                context.config().id,
295                context.config().name.clone(),
296                final_response.clone(),
297            )
298            .await;
299
300            return Ok(ReActAgentOutput {
301                response: final_response,
302                done: true,
303                tool_calls: accumulated_tool_calls,
304            });
305        }
306
307        Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
308    }
309
310    async fn execute_stream(
311        &self,
312        task: &Task,
313        context: Arc<Context>,
314    ) -> Result<
315        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
316        Self::Error,
317    > {
318        if self.on_run_start(task, &context).await == HookOutcome::Abort {
319            return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
320        }
321
322        record_task_state(&context, task);
323
324        let tx_event = context.tx().ok();
325        EventHelper::send_task_started(
326            &tx_event,
327            task.submission_id,
328            context.config().id,
329            context.config().name.clone(),
330            task.prompt.clone(),
331        )
332        .await;
333
334        let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
335        let mut turn_state = engine.turn_state(&context);
336        let max_turns = self.config().max_turns;
337        let context_clone = context.clone();
338        let task = task.clone();
339        let executor = self.clone();
340
341        let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
342
343        spawn_future(async move {
344            let mut accumulated_tool_calls = Vec::new();
345            let mut final_response = String::new();
346
347            for turn_index in 0..max_turns {
348                let turn_stream = engine
349                    .run_turn_stream(
350                        executor.clone(),
351                        &task,
352                        context_clone.clone(),
353                        &mut turn_state,
354                        turn_index,
355                        max_turns,
356                    )
357                    .await;
358
359                let mut turn_result = None;
360
361                match turn_stream {
362                    Ok(mut stream) => {
363                        use futures::StreamExt;
364                        while let Some(delta_result) = stream.next().await {
365                            match delta_result {
366                                Ok(TurnDelta::Text(content)) => {
367                                    let _ = tx
368                                        .send(Ok(ReActAgentOutput {
369                                            response: content,
370                                            tool_calls: Vec::new(),
371                                            done: false,
372                                        }))
373                                        .await;
374                                }
375                                Ok(TurnDelta::ToolResults(tool_results)) => {
376                                    accumulated_tool_calls.extend(tool_results);
377                                    let _ = tx
378                                        .send(Ok(ReActAgentOutput {
379                                            response: String::new(),
380                                            tool_calls: accumulated_tool_calls.clone(),
381                                            done: false,
382                                        }))
383                                        .await;
384                                }
385                                Ok(TurnDelta::Done(result)) => {
386                                    turn_result = Some(result);
387                                    break;
388                                }
389                                Err(err) => {
390                                    let _ = tx.send(Err(err.into())).await;
391                                    return;
392                                }
393                            }
394                        }
395                    }
396                    Err(err) => {
397                        let _ = tx.send(Err(err.into())).await;
398                        return;
399                    }
400                }
401
402                let Some(result) = turn_result else {
403                    let _ = tx
404                        .send(Err(ReActExecutorError::Other(
405                            "Stream ended without final result".to_string(),
406                        )))
407                        .await;
408                    return;
409                };
410
411                match result {
412                    crate::agent::executor::TurnResult::Complete(output) => {
413                        final_response = output.response.clone();
414                        accumulated_tool_calls.extend(output.tool_calls);
415                        break;
416                    }
417                    crate::agent::executor::TurnResult::Continue(Some(output)) => {
418                        if !output.response.is_empty() {
419                            final_response = output.response;
420                        }
421                        accumulated_tool_calls.extend(output.tool_calls);
422                    }
423                    crate::agent::executor::TurnResult::Continue(None) => {}
424                }
425            }
426
427            let tx_event = context_clone.tx().ok();
428            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
429            let _ = tx
430                .send(Ok(ReActAgentOutput {
431                    response: final_response.clone(),
432                    done: true,
433                    tool_calls: accumulated_tool_calls.clone(),
434                }))
435                .await;
436
437            if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
438                EventHelper::send_task_completed(
439                    &tx_event,
440                    task.submission_id,
441                    context_clone.config().id,
442                    context_clone.config().name.clone(),
443                    final_response,
444                )
445                .await;
446            }
447        });
448
449        Ok(receiver_into_stream(rx))
450    }
451}
452
#[cfg(test)]
mod tests {
    // Unit tests for the ReAct executor: output-parsing helpers, error
    // conversions, wrapper metadata, and end-to-end `execute` /
    // `execute_stream` runs against the mock LLM providers in `crate::tests`.
    use super::*;
    use crate::tests::{
        ConfigurableLLMProvider, MockAgentImpl, StaticChatResponse,
        TestAgentOutput as TestUtilsOutput,
    };
    use async_trait::async_trait;
    use autoagents_llm::chat::StreamChunk;
    use autoagents_llm::{FunctionCall, ToolCall};

    /// Minimal tool for the tool-calling tests: it ignores its arguments and
    /// always returns the configured `output`.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: serde_json::Value,
    }

    impl LocalTool {
        fn new(name: &str, output: serde_json::Value) -> Self {
            Self {
                name: name.to_string(),
                output,
            }
        }
    }

    impl crate::tool::ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "local tool"
        }

        fn args_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(
            &self,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }

    /// Structured payload used to exercise JSON parsing of `response`.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct ReActTestOutput {
        value: i32,
        message: String,
    }

    /// Agent whose `on_run_start` hook always aborts the run.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestUtilsOutput;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    // Round-trip: serialize a ReActAgentOutput whose response is JSON, then
    // extract the inner payload back out.
    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = ReActTestOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = ReActAgentOutput {
            response: serde_json::to_string(&agent_output).unwrap(),
            done: true,
            tool_calls: vec![],
        };

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: ReActTestOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    // A value that is not a ReActAgentOutput must be rejected.
    #[test]
    fn test_extract_agent_output_invalid_react() {
        let result = ReActAgentOutput::extract_agent_output::<ReActTestOutput>(
            serde_json::json!({"not": "react"}),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_react_agent_output_try_parse_success() {
        let output = ReActAgentOutput {
            response: r#"{"value":1,"message":"hi"}"#.to_string(),
            tool_calls: vec![],
            done: true,
        };
        let parsed: ReActTestOutput = output.try_parse().unwrap();
        assert_eq!(parsed.value, 1);
    }

    #[test]
    fn test_react_agent_output_try_parse_failure() {
        let output = ReActAgentOutput {
            response: "not json".to_string(),
            tool_calls: vec![],
            done: true,
        };
        assert!(output.try_parse::<ReActTestOutput>().is_err());
    }

    // Non-JSON text falls through to the fallback mapper.
    #[test]
    fn test_react_agent_output_parse_or_map() {
        let output = ReActAgentOutput {
            response: "plain text".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let result: String = output.parse_or_map(|s| s.to_uppercase());
        assert_eq!(result, "PLAIN TEXT");
    }

    // TurnEngineError -> ReActExecutorError variant mapping.
    #[test]
    fn test_error_from_turn_engine_llm() {
        let err: ReActExecutorError = TurnEngineError::LLMError("llm err".to_string()).into();
        assert!(matches!(err, ReActExecutorError::LLMError(_)));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err: ReActExecutorError = TurnEngineError::Aborted.into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err: ReActExecutorError = TurnEngineError::Other("other".to_string()).into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_react_agent_config() {
        let mock = MockAgentImpl::new("cfg_test", "desc");
        let agent = ReActAgent::new(mock);
        assert_eq!(agent.config().max_turns, 10);
    }

    // Cloning keeps the delegated metadata; both output conversions are checked.
    #[test]
    fn test_react_agent_metadata_and_output_conversion() {
        let mock = MockAgentImpl::new("react_meta", "desc");
        let agent = ReActAgent::new(mock);
        let cloned = agent.clone();
        assert_eq!(cloned.name(), "react_meta");
        assert_eq!(cloned.description(), "desc");

        let output = ReActAgentOutput {
            response: "resp".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let value: Value = output.clone().into();
        assert_eq!(value["response"], "resp");
        let string: String = output.into();
        assert_eq!(string, "resp");
    }

    // MockLLMProvider is expected to answer with "Mock response".
    #[tokio::test]
    async fn test_react_agent_execute() {
        use crate::agent::{AgentConfig, Context};
        use crate::tests::MockLLMProvider;
        use autoagents_protocol::ActorID;

        let mock = MockAgentImpl::new("exec_test", "desc");
        let agent = ReActAgent::new(mock);
        let llm = std::sync::Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "exec_test".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = crate::agent::task::Task::new("test");

        let result = agent.execute(&task, context).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.done);
        assert_eq!(output.response, "Mock response");
    }

    // The LLM requests `tool_a`; the final output must include a successful
    // result for it.
    #[tokio::test]
    async fn test_react_agent_execute_with_tool_calls() {
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;

        let tool_call = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: autoagents_llm::FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        };

        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("Use tool".to_string()),
                tool_calls: Some(vec![tool_call.clone()]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });

        let mock = MockAgentImpl::new("exec_tool", "desc");
        let agent = ReActAgent::new(mock);
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "exec_tool".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };

        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(config)
                .with_tools(vec![Box::new(tool)]),
        );
        let task = crate::agent::task::Task::new("test");

        let result = agent.execute(&task, context).await.unwrap();
        assert!(result.done);
        assert!(!result.tool_calls.is_empty());
        assert!(result.tool_calls[0].success);
    }

    // Two streamed text chunks must be joined into the final `done` response.
    #[tokio::test]
    async fn test_react_agent_execute_stream_text() {
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;
        use futures::StreamExt;

        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::Text("Hello ".to_string()),
                StreamChunk::Text("world".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let mock = MockAgentImpl::new("stream_test", "desc");
        let agent = ReActAgent::new(mock);
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "stream_test".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = crate::agent::task::Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut final_output = None;
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    // The LLM streams one completed call to `tool_a`; the stream must surface
    // tool results before the final `done` item, and the final item must
    // carry them too.
    #[tokio::test]
    async fn test_react_agent_execute_stream_tool_results() {
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;
        use futures::StreamExt;

        let tool_call = ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        };

        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![StreamChunk::ToolUseComplete {
                index: 0,
                tool_call: tool_call.clone(),
            }],
            ..ConfigurableLLMProvider::default()
        });

        let mock = MockAgentImpl::new("stream_tool", "desc");
        let agent = ReActAgent::new(mock);
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "stream_tool".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(config)
                .with_tools(vec![Box::new(tool)]),
        );
        let task = crate::agent::task::Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut saw_tool_results = false;
        let mut final_output = None;

        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if !output.tool_calls.is_empty() {
                saw_tool_results = true;
                assert!(output.tool_calls[0].success);
            }
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        assert!(saw_tool_results);
        let output = final_output.expect("final output");
        assert!(output.done);
        assert!(!output.tool_calls.is_empty());
    }

    // An aborting `on_run_start` hook must fail `execute` with an "aborted" error.
    #[tokio::test]
    async fn test_react_agent_execute_aborts_on_hook() {
        use crate::agent::{AgentConfig, Context};
        use crate::tests::MockLLMProvider;
        use autoagents_protocol::ActorID;

        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "abort_agent".to_string(),
            description: "abort".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = crate::agent::task::Task::new("abort");

        let err = agent.execute(&task, context).await.unwrap_err();
        assert!(err.to_string().contains("aborted"));
    }

    // Same as above for `execute_stream`: the abort is reported before any
    // stream is returned.
    #[tokio::test]
    async fn test_react_agent_execute_stream_aborts_on_hook() {
        use crate::agent::{AgentConfig, Context};
        use crate::tests::MockLLMProvider;
        use autoagents_protocol::ActorID;

        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "abort_agent".to_string(),
            description: "abort".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = crate::agent::task::Task::new("abort");

        let err = agent.execute_stream(&task, context).await.err().unwrap();
        assert!(err.to_string().contains("aborted"));
    }
}