Skip to main content

autoagents_core/agent/prebuilt/executor/
react.rs

1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4    TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use autoagents_llm::error::LLMError;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashSet;
18use std::ops::Deref;
19use std::pin::Pin;
20use std::sync::Arc;
21use thiserror::Error;
22
23#[cfg(not(target_arch = "wasm32"))]
24pub use tokio::sync::mpsc::error::SendError;
25
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::hooks::{AgentHooks, HookOutcome};
30use autoagents_protocol::Event;
31
32/// Output of the ReAct-style agent
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ReActAgentOutput {
35    pub response: String,
36    pub tool_calls: Vec<ToolCallResult>,
37    pub done: bool,
38}
39
40impl From<ReActAgentOutput> for Value {
41    fn from(output: ReActAgentOutput) -> Self {
42        serde_json::to_value(output).unwrap_or(Value::Null)
43    }
44}
45impl From<ReActAgentOutput> for String {
46    fn from(output: ReActAgentOutput) -> Self {
47        output.response
48    }
49}
50
51impl ReActAgentOutput {
52    /// Try to parse the response string as structured JSON of type `T`.
53    /// Returns `serde_json::Error` if parsing fails.
54    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
55        serde_json::from_str::<T>(&self.response)
56    }
57
58    /// Parse the response string as structured JSON of type `T`, or map the raw
59    /// text into `T` using the provided fallback function if parsing fails.
60    /// This is useful in examples to avoid repeating parsing boilerplate.
61    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
62    where
63        T: for<'de> serde::Deserialize<'de>,
64        F: FnOnce(&str) -> T,
65    {
66        self.try_parse::<T>()
67            .unwrap_or_else(|_| fallback(&self.response))
68    }
69}
70
71fn dedupe_tool_calls(tool_calls: Vec<ToolCallResult>) -> Vec<ToolCallResult> {
72    let mut seen = HashSet::new();
73    let mut unique = Vec::with_capacity(tool_calls.len());
74
75    for call in tool_calls {
76        let key = format!(
77            "{}|{}|{}|{}",
78            call.tool_name, call.success, call.arguments, call.result
79        );
80        if seen.insert(key) {
81            unique.push(call);
82        }
83    }
84
85    unique
86}
87
88impl ReActAgentOutput {
89    /// Extract the agent output from the ReAct response
90    #[allow(clippy::result_large_err)]
91    pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
92    where
93        T: for<'de> serde::Deserialize<'de>,
94    {
95        let react_output: Self = serde_json::from_value(val)
96            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
97        serde_json::from_str(&react_output.response)
98            .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
99    }
100}
101
102#[derive(Error, Debug)]
103pub enum ReActExecutorError {
104    #[error("LLM error: {0}")]
105    LLMError(
106        #[from]
107        #[source]
108        LLMError,
109    ),
110
111    #[error("Maximum turns exceeded: {max_turns}")]
112    MaxTurnsExceeded { max_turns: usize },
113
114    #[error("Other error: {0}")]
115    Other(String),
116
117    #[cfg(not(target_arch = "wasm32"))]
118    #[error("Event error: {0}")]
119    EventError(#[from] SendError<Event>),
120
121    #[cfg(target_arch = "wasm32")]
122    #[error("Event error: {0}")]
123    EventError(#[from] SendError),
124
125    #[error("Extracting Agent Output Error: {0}")]
126    AgentOutputError(String),
127}
128
129impl From<TurnEngineError> for ReActExecutorError {
130    fn from(error: TurnEngineError) -> Self {
131        match error {
132            TurnEngineError::LLMError(err) => err.into(),
133            TurnEngineError::Aborted => {
134                ReActExecutorError::Other("Run aborted by hook".to_string())
135            }
136            TurnEngineError::Other(err) => ReActExecutorError::Other(err),
137        }
138    }
139}
140
141/// Wrapper type for the multi-turn ReAct executor with tool calling support.
142///
143/// Use `ReActAgent<T>` when your agent needs to perform tool calls, manage
144/// multiple turns, and optionally stream content and tool-call deltas.
145#[derive(Debug)]
146pub struct ReActAgent<T: AgentDeriveT> {
147    inner: Arc<T>,
148    max_turns: usize,
149}
150
151impl<T: AgentDeriveT> Clone for ReActAgent<T> {
152    fn clone(&self) -> Self {
153        Self {
154            inner: Arc::clone(&self.inner),
155            max_turns: self.max_turns,
156        }
157    }
158}
159
160impl<T: AgentDeriveT> ReActAgent<T> {
161    pub fn new(inner: T) -> Self {
162        Self {
163            inner: Arc::new(inner),
164            max_turns: 10,
165        }
166    }
167
168    pub fn with_max_turns(inner: T, max_turns: usize) -> Self {
169        Self {
170            inner: Arc::new(inner),
171            max_turns: max_turns.max(1),
172        }
173    }
174}
175
176impl<T: AgentDeriveT> Deref for ReActAgent<T> {
177    type Target = T;
178
179    fn deref(&self) -> &Self::Target {
180        &self.inner
181    }
182}
183
184/// Implement AgentDeriveT for the wrapper by delegating to the inner type
185#[async_trait]
186impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
187    type Output = <T as AgentDeriveT>::Output;
188
189    fn description(&self) -> &str {
190        self.inner.description()
191    }
192
193    fn output_schema(&self) -> Option<Value> {
194        self.inner.output_schema()
195    }
196
197    fn name(&self) -> &str {
198        self.inner.name()
199    }
200
201    fn tools(&self) -> Vec<Box<dyn ToolT>> {
202        self.inner.tools()
203    }
204}
205
206#[async_trait]
207impl<T> AgentHooks for ReActAgent<T>
208where
209    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
210{
211    async fn on_agent_create(&self) {
212        self.inner.on_agent_create().await
213    }
214
215    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
216        self.inner.on_run_start(task, ctx).await
217    }
218
219    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
220        self.inner.on_run_complete(task, result, ctx).await
221    }
222
223    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
224        self.inner.on_turn_start(turn_index, ctx).await
225    }
226
227    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
228        self.inner.on_turn_complete(turn_index, ctx).await
229    }
230
231    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
232        self.inner.on_tool_call(tool_call, ctx).await
233    }
234
235    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
236        self.inner.on_tool_start(tool_call, ctx).await
237    }
238
239    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
240        self.inner.on_tool_result(tool_call, result, ctx).await
241    }
242
243    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
244        self.inner.on_tool_error(tool_call, err, ctx).await
245    }
246    async fn on_agent_shutdown(&self) {
247        self.inner.on_agent_shutdown().await
248    }
249}
250
251/// Implementation of AgentExecutor for the ReActExecutorWrapper
252#[async_trait]
253impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
254    type Output = ReActAgentOutput;
255    type Error = ReActExecutorError;
256
257    fn config(&self) -> ExecutorConfig {
258        ExecutorConfig {
259            max_turns: self.max_turns,
260        }
261    }
262
263    async fn execute(
264        &self,
265        task: &Task,
266        context: Arc<Context>,
267    ) -> Result<Self::Output, Self::Error> {
268        record_task_state(&context, task);
269
270        let tx_event = context.tx().ok();
271        EventHelper::send_task_started(
272            &tx_event,
273            task.submission_id,
274            context.config().id,
275            context.config().name.clone(),
276            task.prompt.clone(),
277        )
278        .await;
279
280        let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
281        let mut turn_state = engine.turn_state(&context);
282        let max_turns = self.config().max_turns;
283        let mut accumulated_tool_calls = Vec::new();
284        let mut final_response = String::new();
285        for turn_index in 0..max_turns {
286            let result = engine
287                .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
288                .await?;
289
290            match result {
291                crate::agent::executor::TurnResult::Complete(output) => {
292                    final_response = output.response.clone();
293                    accumulated_tool_calls.extend(output.tool_calls);
294                    accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
295
296                    return Ok(ReActAgentOutput {
297                        response: final_response,
298                        done: true,
299                        tool_calls: accumulated_tool_calls,
300                    });
301                }
302                crate::agent::executor::TurnResult::Continue(Some(output)) => {
303                    if !output.response.is_empty() {
304                        final_response = output.response;
305                    }
306                    accumulated_tool_calls.extend(output.tool_calls);
307                    accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
308                }
309                crate::agent::executor::TurnResult::Continue(None) => {}
310            }
311        }
312
313        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
314            return Ok(ReActAgentOutput {
315                response: final_response,
316                done: true,
317                tool_calls: accumulated_tool_calls,
318            });
319        }
320
321        Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
322    }
323
324    async fn execute_stream(
325        &self,
326        task: &Task,
327        context: Arc<Context>,
328    ) -> Result<
329        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
330        Self::Error,
331    > {
332        record_task_state(&context, task);
333
334        let tx_event = context.tx().ok();
335        EventHelper::send_task_started(
336            &tx_event,
337            task.submission_id,
338            context.config().id,
339            context.config().name.clone(),
340            task.prompt.clone(),
341        )
342        .await;
343
344        let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
345        let mut turn_state = engine.turn_state(&context);
346        let max_turns = self.config().max_turns;
347        let context_clone = context.clone();
348        let task = task.clone();
349        let executor = self.clone();
350
351        let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
352
353        spawn_future(async move {
354            let mut accumulated_tool_calls = Vec::new();
355            let mut final_response = String::new();
356
357            for turn_index in 0..max_turns {
358                let turn_stream = engine
359                    .run_turn_stream(
360                        executor.clone(),
361                        &task,
362                        context_clone.clone(),
363                        &mut turn_state,
364                        turn_index,
365                        max_turns,
366                    )
367                    .await;
368
369                let mut turn_result = None;
370
371                match turn_stream {
372                    Ok(mut stream) => {
373                        use futures::StreamExt;
374                        while let Some(delta_result) = stream.next().await {
375                            match delta_result {
376                                Ok(TurnDelta::Text(content)) => {
377                                    let _ = tx
378                                        .send(Ok(ReActAgentOutput {
379                                            response: content,
380                                            tool_calls: Vec::new(),
381                                            done: false,
382                                        }))
383                                        .await;
384                                }
385                                Ok(TurnDelta::ReasoningContent(_)) => {}
386                                Ok(TurnDelta::ToolResults(tool_results)) => {
387                                    accumulated_tool_calls.extend(tool_results);
388                                    accumulated_tool_calls =
389                                        dedupe_tool_calls(accumulated_tool_calls);
390                                    let _ = tx
391                                        .send(Ok(ReActAgentOutput {
392                                            response: String::new(),
393                                            tool_calls: accumulated_tool_calls.clone(),
394                                            done: false,
395                                        }))
396                                        .await;
397                                }
398                                Ok(TurnDelta::Done(result)) => {
399                                    turn_result = Some(result);
400                                    break;
401                                }
402                                Err(err) => {
403                                    let _ = tx.send(Err(err.into())).await;
404                                    return;
405                                }
406                            }
407                        }
408                    }
409                    Err(err) => {
410                        let _ = tx.send(Err(err.into())).await;
411                        return;
412                    }
413                }
414
415                let Some(result) = turn_result else {
416                    let _ = tx
417                        .send(Err(ReActExecutorError::Other(
418                            "Stream ended without final result".to_string(),
419                        )))
420                        .await;
421                    return;
422                };
423
424                match result {
425                    crate::agent::executor::TurnResult::Complete(output) => {
426                        final_response = output.response.clone();
427                        accumulated_tool_calls.extend(output.tool_calls);
428                        accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
429                        break;
430                    }
431                    crate::agent::executor::TurnResult::Continue(Some(output)) => {
432                        if !output.response.is_empty() {
433                            final_response = output.response;
434                        }
435                        accumulated_tool_calls.extend(output.tool_calls);
436                        accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
437                    }
438                    crate::agent::executor::TurnResult::Continue(None) => {}
439                }
440            }
441
442            let tx_event = context_clone.tx().ok();
443            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
444            let output = ReActAgentOutput {
445                response: final_response.clone(),
446                done: true,
447                tool_calls: accumulated_tool_calls.clone(),
448            };
449            let _ = tx.send(Ok(output.clone())).await;
450
451            if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
452                let result = serde_json::to_string_pretty(&output)
453                    .unwrap_or_else(|_| output.response.clone());
454                EventHelper::send_task_completed(
455                    &tx_event,
456                    task.submission_id,
457                    context_clone.config().id,
458                    context_clone.config().name.clone(),
459                    result,
460                )
461                .await;
462            }
463        });
464
465        Ok(receiver_into_stream(rx))
466    }
467}
468
#[cfg(test)]
mod tests {
    //! Unit tests for the ReAct executor: output helpers, error conversions, and
    //! `execute` / `execute_stream` against the crate's mock LLM providers.

    use super::*;
    use crate::agent::AgentConfig;
    use crate::tests::{
        ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider, StaticChatResponse,
    };
    use async_trait::async_trait;
    use autoagents_llm::chat::StreamChunk;
    use autoagents_llm::{FunctionCall, ToolCall};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;

    /// Tool that ignores its arguments and always returns a fixed JSON value.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: serde_json::Value,
    }

    impl LocalTool {
        fn new(name: &str, output: serde_json::Value) -> Self {
            Self {
                output,
                name: name.to_owned(),
            }
        }
    }

    impl crate::tool::ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn args_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }

        fn description(&self) -> &str {
            "local tool"
        }
    }

    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(
            &self,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct ReActTestOutput {
        value: i32,
        message: String,
    }

    /// Agent whose `on_run_start` hook always aborts the run.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = String;

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Agent configuration used by the execution tests; only the name varies.
    fn agent_config(name: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: "desc".to_string(),
            output_schema: None,
        }
    }

    /// The `tool_a` call with `{"value":1}` arguments used by the tool tests.
    fn tool_a_call() -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        }
    }

    /// A finished output (`done: true`, no tool calls) with the given response text.
    fn finished_output(response: &str) -> ReActAgentOutput {
        ReActAgentOutput {
            response: response.to_string(),
            tool_calls: vec![],
            done: true,
        }
    }

    /// Drain `stream` up to the first `done` item, panicking on any error item.
    async fn first_done_output<S>(mut stream: S) -> Option<ReActAgentOutput>
    where
        S: Stream<Item = Result<ReActAgentOutput, ReActExecutorError>> + Unpin,
    {
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                return Some(output);
            }
        }
        None
    }

    #[test]
    fn test_extract_agent_output_success() {
        let expected = ReActTestOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };
        let wrapped = serde_json::to_value(finished_output(
            &serde_json::to_string(&expected).unwrap(),
        ))
        .unwrap();

        let extracted: ReActTestOutput = ReActAgentOutput::extract_agent_output(wrapped).unwrap();
        assert_eq!(extracted, expected);
    }

    #[test]
    fn test_extract_agent_output_invalid_react() {
        let not_react = serde_json::json!({"not": "react"});
        assert!(ReActAgentOutput::extract_agent_output::<ReActTestOutput>(not_react).is_err());
    }

    #[test]
    fn test_react_agent_output_try_parse_success() {
        let parsed: ReActTestOutput = finished_output(r#"{"value":1,"message":"hi"}"#)
            .try_parse()
            .unwrap();
        assert_eq!(parsed.value, 1);
    }

    #[test]
    fn test_react_agent_output_try_parse_failure() {
        let parsed = finished_output("not json").try_parse::<ReActTestOutput>();
        assert!(parsed.is_err());
    }

    #[test]
    fn test_react_agent_output_parse_or_map() {
        let mapped: String = finished_output("plain text").parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(mapped, "PLAIN TEXT");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let converted = ReActExecutorError::from(TurnEngineError::LLMError(LLMError::Generic(
            "llm err".to_string(),
        )));
        assert!(matches!(converted, ReActExecutorError::LLMError(_)));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let converted = ReActExecutorError::from(TurnEngineError::Aborted);
        assert!(matches!(converted, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let converted = ReActExecutorError::from(TurnEngineError::Other("other".to_string()));
        assert!(matches!(converted, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_react_agent_config() {
        let agent = ReActAgent::new(MockAgentImpl::new("cfg_test", "desc"));
        assert_eq!(agent.config().max_turns, 10);
    }

    #[test]
    fn test_react_agent_metadata_and_output_conversion() {
        let agent = ReActAgent::new(MockAgentImpl::new("react_meta", "desc"));
        let cloned = agent.clone();
        assert_eq!(cloned.name(), "react_meta");
        assert_eq!(cloned.description(), "desc");

        let output = finished_output("resp");
        let as_value = Value::from(output.clone());
        assert_eq!(as_value["response"], "resp");
        let as_string = String::from(output);
        assert_eq!(as_string, "resp");
    }

    #[tokio::test]
    async fn test_react_agent_execute() {
        let agent = ReActAgent::new(MockAgentImpl::new("exec_test", "desc"));
        let llm = Arc::new(MockLLMProvider {});
        let context = Arc::new(Context::new(llm, None).with_config(agent_config("exec_test")));
        let task = Task::new("test");

        let output = agent
            .execute(&task, context)
            .await
            .expect("execute should succeed");
        assert!(output.done);
        assert_eq!(output.response, "Mock response");
    }

    #[tokio::test]
    async fn test_react_agent_execute_with_tool_calls() {
        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("Use tool".to_string()),
                tool_calls: Some(vec![tool_a_call()]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });
        let agent = ReActAgent::new(MockAgentImpl::new("exec_tool", "desc"));
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(agent_config("exec_tool"))
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let result = agent.execute(&task, context).await.unwrap();
        assert!(result.done);
        let first_call = result.tool_calls.first().expect("at least one tool call");
        assert!(first_call.success);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_text() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::Text("Hello ".to_string()),
                StreamChunk::Text("world".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });
        let agent = ReActAgent::new(MockAgentImpl::new("stream_test", "desc"));
        let context = Arc::new(Context::new(llm, None).with_config(agent_config("stream_test")));
        let task = Task::new("test");

        let stream = agent.execute_stream(&task, context).await.unwrap();
        let output = first_done_output(stream).await.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_ignores_reasoning_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::ReasoningContent("plan".to_string()),
                StreamChunk::Text("done".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });
        let agent = ReActAgent::new(MockAgentImpl::new("stream_reasoning_test", "desc"));
        let context = Arc::new(
            Context::new(llm, None).with_config(agent_config("stream_reasoning_test")),
        );
        let task = Task::new("test");

        let outputs: Vec<ReActAgentOutput> = agent
            .execute_stream(&task, context)
            .await
            .unwrap()
            .map(|item| item.unwrap())
            .collect()
            .await;

        let [streamed, last] = outputs.as_slice() else {
            panic!("expected exactly two stream items, got {}", outputs.len());
        };
        assert_eq!(streamed.response, "done");
        assert!(!streamed.done);
        assert_eq!(last.response, "done");
        assert!(last.done);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_tool_results() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![StreamChunk::ToolUseComplete {
                index: 0,
                tool_call: tool_a_call(),
            }],
            ..ConfigurableLLMProvider::default()
        });
        let agent = ReActAgent::new(MockAgentImpl::new("stream_tool", "desc"));
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(agent_config("stream_tool"))
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut saw_tool_results = false;
        let final_output = loop {
            let Some(item) = stream.next().await else {
                break None;
            };
            let output = item.unwrap();
            if let Some(first_call) = output.tool_calls.first() {
                saw_tool_results = true;
                assert!(first_call.success);
            }
            if output.done {
                break Some(output);
            }
        };

        assert!(saw_tool_results);
        let output = final_output.expect("final output");
        assert!(output.done);
        assert!(!output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_react_agent_run_aborts_on_hook() {
        use crate::agent::AgentBuilder;
        use crate::agent::direct::DirectAgent;
        use crate::agent::error::RunnableAgentError;

        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");
        let task = Task::new("abort");

        let err = handle.agent.run(task).await.expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
    }

    #[tokio::test]
    async fn test_react_agent_run_stream_aborts_on_hook() {
        use crate::agent::AgentBuilder;
        use crate::agent::direct::DirectAgent;
        use crate::agent::error::RunnableAgentError;

        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");
        let task = Task::new("abort");

        let Err(err) = handle.agent.run_stream(task).await else {
            panic!("expected abort");
        };
        assert!(matches!(err, RunnableAgentError::Abort));
    }
}