1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use autoagents_llm::error::LLMError;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::collections::HashSet;
18use std::ops::Deref;
19use std::pin::Pin;
20use std::sync::Arc;
21use thiserror::Error;
22
23#[cfg(not(target_arch = "wasm32"))]
24pub use tokio::sync::mpsc::error::SendError;
25
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::hooks::{AgentHooks, HookOutcome};
30use autoagents_protocol::Event;
31
/// Output emitted by the ReAct executor in this file.
///
/// The same shape is used for intermediate streaming items (`done == false`)
/// and for the final result of a run (`done == true`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
    /// Response text. `try_parse` / `extract_agent_output` attempt to decode
    /// this string as JSON.
    pub response: String,
    /// Tool call results reported alongside the response.
    pub tool_calls: Vec<ToolCallResult>,
    /// `true` on the final output of a run, `false` on streamed partial items.
    pub done: bool,
}
39
40impl From<ReActAgentOutput> for Value {
41 fn from(output: ReActAgentOutput) -> Self {
42 serde_json::to_value(output).unwrap_or(Value::Null)
43 }
44}
45impl From<ReActAgentOutput> for String {
46 fn from(output: ReActAgentOutput) -> Self {
47 output.response
48 }
49}
50
51impl ReActAgentOutput {
52 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
55 serde_json::from_str::<T>(&self.response)
56 }
57
58 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
62 where
63 T: for<'de> serde::Deserialize<'de>,
64 F: FnOnce(&str) -> T,
65 {
66 self.try_parse::<T>()
67 .unwrap_or_else(|_| fallback(&self.response))
68 }
69}
70
71fn dedupe_tool_calls(tool_calls: Vec<ToolCallResult>) -> Vec<ToolCallResult> {
72 let mut seen = HashSet::new();
73 let mut unique = Vec::with_capacity(tool_calls.len());
74
75 for call in tool_calls {
76 let key = format!(
77 "{}|{}|{}|{}",
78 call.tool_name, call.success, call.arguments, call.result
79 );
80 if seen.insert(key) {
81 unique.push(call);
82 }
83 }
84
85 unique
86}
87
88impl ReActAgentOutput {
89 #[allow(clippy::result_large_err)]
91 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
92 where
93 T: for<'de> serde::Deserialize<'de>,
94 {
95 let react_output: Self = serde_json::from_value(val)
96 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
97 serde_json::from_str(&react_output.response)
98 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
99 }
100}
101
102#[derive(Error, Debug)]
103pub enum ReActExecutorError {
104 #[error("LLM error: {0}")]
105 LLMError(
106 #[from]
107 #[source]
108 LLMError,
109 ),
110
111 #[error("Maximum turns exceeded: {max_turns}")]
112 MaxTurnsExceeded { max_turns: usize },
113
114 #[error("Other error: {0}")]
115 Other(String),
116
117 #[cfg(not(target_arch = "wasm32"))]
118 #[error("Event error: {0}")]
119 EventError(#[from] SendError<Event>),
120
121 #[cfg(target_arch = "wasm32")]
122 #[error("Event error: {0}")]
123 EventError(#[from] SendError),
124
125 #[error("Extracting Agent Output Error: {0}")]
126 AgentOutputError(String),
127}
128
129impl From<TurnEngineError> for ReActExecutorError {
130 fn from(error: TurnEngineError) -> Self {
131 match error {
132 TurnEngineError::LLMError(err) => err.into(),
133 TurnEngineError::Aborted => {
134 ReActExecutorError::Other("Run aborted by hook".to_string())
135 }
136 TurnEngineError::Other(err) => ReActExecutorError::Other(err),
137 }
138 }
139}
140
/// Wrapper around an agent definition `T` that adds a per-run turn limit.
///
/// The wrapped value is shared through an [`Arc`], so cloning the wrapper
/// does not clone `T`.
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
    /// The wrapped agent definition, also reachable through `Deref`.
    inner: Arc<T>,
    /// Maximum number of turns per run; constructors keep it at 1 or more.
    max_turns: usize,
}
150
151impl<T: AgentDeriveT> Clone for ReActAgent<T> {
152 fn clone(&self) -> Self {
153 Self {
154 inner: Arc::clone(&self.inner),
155 max_turns: self.max_turns,
156 }
157 }
158}
159
160impl<T: AgentDeriveT> ReActAgent<T> {
161 pub fn new(inner: T) -> Self {
162 Self {
163 inner: Arc::new(inner),
164 max_turns: 10,
165 }
166 }
167
168 pub fn with_max_turns(inner: T, max_turns: usize) -> Self {
169 Self {
170 inner: Arc::new(inner),
171 max_turns: max_turns.max(1),
172 }
173 }
174}
175
176impl<T: AgentDeriveT> Deref for ReActAgent<T> {
177 type Target = T;
178
179 fn deref(&self) -> &Self::Target {
180 &self.inner
181 }
182}
183
184#[async_trait]
186impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
187 type Output = <T as AgentDeriveT>::Output;
188
189 fn description(&self) -> &str {
190 self.inner.description()
191 }
192
193 fn output_schema(&self) -> Option<Value> {
194 self.inner.output_schema()
195 }
196
197 fn name(&self) -> &str {
198 self.inner.name()
199 }
200
201 fn tools(&self) -> Vec<Box<dyn ToolT>> {
202 self.inner.tools()
203 }
204}
205
206#[async_trait]
207impl<T> AgentHooks for ReActAgent<T>
208where
209 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
210{
211 async fn on_agent_create(&self) {
212 self.inner.on_agent_create().await
213 }
214
215 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
216 self.inner.on_run_start(task, ctx).await
217 }
218
219 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
220 self.inner.on_run_complete(task, result, ctx).await
221 }
222
223 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
224 self.inner.on_turn_start(turn_index, ctx).await
225 }
226
227 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
228 self.inner.on_turn_complete(turn_index, ctx).await
229 }
230
231 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
232 self.inner.on_tool_call(tool_call, ctx).await
233 }
234
235 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
236 self.inner.on_tool_start(tool_call, ctx).await
237 }
238
239 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
240 self.inner.on_tool_result(tool_call, result, ctx).await
241 }
242
243 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
244 self.inner.on_tool_error(tool_call, err, ctx).await
245 }
246 async fn on_agent_shutdown(&self) {
247 self.inner.on_agent_shutdown().await
248 }
249}
250
251#[async_trait]
253impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
254 type Output = ReActAgentOutput;
255 type Error = ReActExecutorError;
256
257 fn config(&self) -> ExecutorConfig {
258 ExecutorConfig {
259 max_turns: self.max_turns,
260 }
261 }
262
263 async fn execute(
264 &self,
265 task: &Task,
266 context: Arc<Context>,
267 ) -> Result<Self::Output, Self::Error> {
268 record_task_state(&context, task);
269
270 let tx_event = context.tx().ok();
271 EventHelper::send_task_started(
272 &tx_event,
273 task.submission_id,
274 context.config().id,
275 context.config().name.clone(),
276 task.prompt.clone(),
277 )
278 .await;
279
280 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
281 let mut turn_state = engine.turn_state(&context);
282 let max_turns = self.config().max_turns;
283 let mut accumulated_tool_calls = Vec::new();
284 let mut final_response = String::new();
285 for turn_index in 0..max_turns {
286 let result = engine
287 .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
288 .await?;
289
290 match result {
291 crate::agent::executor::TurnResult::Complete(output) => {
292 final_response = output.response.clone();
293 accumulated_tool_calls.extend(output.tool_calls);
294 accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
295
296 return Ok(ReActAgentOutput {
297 response: final_response,
298 done: true,
299 tool_calls: accumulated_tool_calls,
300 });
301 }
302 crate::agent::executor::TurnResult::Continue(Some(output)) => {
303 if !output.response.is_empty() {
304 final_response = output.response;
305 }
306 accumulated_tool_calls.extend(output.tool_calls);
307 accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
308 }
309 crate::agent::executor::TurnResult::Continue(None) => {}
310 }
311 }
312
313 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
314 return Ok(ReActAgentOutput {
315 response: final_response,
316 done: true,
317 tool_calls: accumulated_tool_calls,
318 });
319 }
320
321 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
322 }
323
324 async fn execute_stream(
325 &self,
326 task: &Task,
327 context: Arc<Context>,
328 ) -> Result<
329 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
330 Self::Error,
331 > {
332 record_task_state(&context, task);
333
334 let tx_event = context.tx().ok();
335 EventHelper::send_task_started(
336 &tx_event,
337 task.submission_id,
338 context.config().id,
339 context.config().name.clone(),
340 task.prompt.clone(),
341 )
342 .await;
343
344 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
345 let mut turn_state = engine.turn_state(&context);
346 let max_turns = self.config().max_turns;
347 let context_clone = context.clone();
348 let task = task.clone();
349 let executor = self.clone();
350
351 let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
352
353 spawn_future(async move {
354 let mut accumulated_tool_calls = Vec::new();
355 let mut final_response = String::new();
356
357 for turn_index in 0..max_turns {
358 let turn_stream = engine
359 .run_turn_stream(
360 executor.clone(),
361 &task,
362 context_clone.clone(),
363 &mut turn_state,
364 turn_index,
365 max_turns,
366 )
367 .await;
368
369 let mut turn_result = None;
370
371 match turn_stream {
372 Ok(mut stream) => {
373 use futures::StreamExt;
374 while let Some(delta_result) = stream.next().await {
375 match delta_result {
376 Ok(TurnDelta::Text(content)) => {
377 let _ = tx
378 .send(Ok(ReActAgentOutput {
379 response: content,
380 tool_calls: Vec::new(),
381 done: false,
382 }))
383 .await;
384 }
385 Ok(TurnDelta::ReasoningContent(_)) => {}
386 Ok(TurnDelta::ToolResults(tool_results)) => {
387 accumulated_tool_calls.extend(tool_results);
388 accumulated_tool_calls =
389 dedupe_tool_calls(accumulated_tool_calls);
390 let _ = tx
391 .send(Ok(ReActAgentOutput {
392 response: String::new(),
393 tool_calls: accumulated_tool_calls.clone(),
394 done: false,
395 }))
396 .await;
397 }
398 Ok(TurnDelta::Done(result)) => {
399 turn_result = Some(result);
400 break;
401 }
402 Err(err) => {
403 let _ = tx.send(Err(err.into())).await;
404 return;
405 }
406 }
407 }
408 }
409 Err(err) => {
410 let _ = tx.send(Err(err.into())).await;
411 return;
412 }
413 }
414
415 let Some(result) = turn_result else {
416 let _ = tx
417 .send(Err(ReActExecutorError::Other(
418 "Stream ended without final result".to_string(),
419 )))
420 .await;
421 return;
422 };
423
424 match result {
425 crate::agent::executor::TurnResult::Complete(output) => {
426 final_response = output.response.clone();
427 accumulated_tool_calls.extend(output.tool_calls);
428 accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
429 break;
430 }
431 crate::agent::executor::TurnResult::Continue(Some(output)) => {
432 if !output.response.is_empty() {
433 final_response = output.response;
434 }
435 accumulated_tool_calls.extend(output.tool_calls);
436 accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
437 }
438 crate::agent::executor::TurnResult::Continue(None) => {}
439 }
440 }
441
442 let tx_event = context_clone.tx().ok();
443 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
444 let output = ReActAgentOutput {
445 response: final_response.clone(),
446 done: true,
447 tool_calls: accumulated_tool_calls.clone(),
448 };
449 let _ = tx.send(Ok(output.clone())).await;
450
451 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
452 let result = serde_json::to_string_pretty(&output)
453 .unwrap_or_else(|_| output.response.clone());
454 EventHelper::send_task_completed(
455 &tx_event,
456 task.submission_id,
457 context_clone.config().id,
458 context_clone.config().name.clone(),
459 result,
460 )
461 .await;
462 }
463 });
464
465 Ok(receiver_into_stream(rx))
466 }
467}
468
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::direct::DirectAgent;
    use crate::agent::error::RunnableAgentError;
    use crate::agent::{AgentBuilder, AgentConfig};
    use crate::tests::{
        ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider, StaticChatResponse,
    };
    use async_trait::async_trait;
    use autoagents_llm::chat::StreamChunk;
    use autoagents_llm::{FunctionCall, ToolCall};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;

    /// Tool stub that ignores its arguments and always returns a fixed value.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: Value,
    }

    impl LocalTool {
        fn new(name: &str, output: Value) -> Self {
            Self {
                name: name.to_string(),
                output,
            }
        }
    }

    impl ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "local tool"
        }

        fn args_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
    }

    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(&self, _args: Value) -> Result<Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }

    /// Structured payload used to exercise the JSON parsing helpers.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct ReActTestOutput {
        value: i32,
        message: String,
    }

    /// Agent whose `on_run_start` hook always aborts the run.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = String;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Agent config with a fresh id, the given name and description "desc".
    fn test_config(name: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: "desc".to_string(),
            output_schema: None,
        }
    }

    /// The `tool_a` invocation shared by the tool-calling tests.
    fn tool_a_call() -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = ReActTestOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = ReActAgentOutput {
            response: serde_json::to_string(&agent_output).unwrap(),
            done: true,
            tool_calls: vec![],
        };

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: ReActTestOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_invalid_react() {
        let result = ReActAgentOutput::extract_agent_output::<ReActTestOutput>(
            serde_json::json!({"not": "react"}),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_react_agent_output_try_parse_success() {
        let output = ReActAgentOutput {
            response: r#"{"value":1,"message":"hi"}"#.to_string(),
            tool_calls: vec![],
            done: true,
        };
        let parsed: ReActTestOutput = output.try_parse().unwrap();
        assert_eq!(parsed.value, 1);
    }

    #[test]
    fn test_react_agent_output_try_parse_failure() {
        let output = ReActAgentOutput {
            response: "not json".to_string(),
            tool_calls: vec![],
            done: true,
        };
        assert!(output.try_parse::<ReActTestOutput>().is_err());
    }

    #[test]
    fn test_react_agent_output_parse_or_map() {
        let output = ReActAgentOutput {
            response: "plain text".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let result: String = output.parse_or_map(|s| s.to_uppercase());
        assert_eq!(result, "PLAIN TEXT");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let err: ReActExecutorError =
            TurnEngineError::LLMError(LLMError::Generic("llm err".to_string())).into();
        assert!(matches!(err, ReActExecutorError::LLMError(_)));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err: ReActExecutorError = TurnEngineError::Aborted.into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err: ReActExecutorError = TurnEngineError::Other("other".to_string()).into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_react_agent_config() {
        let agent = ReActAgent::new(MockAgentImpl::new("cfg_test", "desc"));
        assert_eq!(agent.config().max_turns, 10);
    }

    #[test]
    fn test_react_agent_metadata_and_output_conversion() {
        let agent = ReActAgent::new(MockAgentImpl::new("react_meta", "desc"));
        let cloned = agent.clone();
        assert_eq!(cloned.name(), "react_meta");
        assert_eq!(cloned.description(), "desc");

        let output = ReActAgentOutput {
            response: "resp".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let value: Value = output.clone().into();
        assert_eq!(value["response"], "resp");
        let string: String = output.into();
        assert_eq!(string, "resp");
    }

    #[tokio::test]
    async fn test_react_agent_execute() {
        let agent = ReActAgent::new(MockAgentImpl::new("exec_test", "desc"));
        let llm = Arc::new(MockLLMProvider {});
        let context = Arc::new(Context::new(llm, None).with_config(test_config("exec_test")));
        let task = Task::new("test");

        let result = agent.execute(&task, context).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.done);
        assert_eq!(output.response, "Mock response");
    }

    #[tokio::test]
    async fn test_react_agent_execute_with_tool_calls() {
        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("Use tool".to_string()),
                tool_calls: Some(vec![tool_a_call()]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("exec_tool", "desc"));
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(test_config("exec_tool"))
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let result = agent.execute(&task, context).await.unwrap();
        assert!(result.done);
        assert!(!result.tool_calls.is_empty());
        assert!(result.tool_calls[0].success);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_text() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::Text("Hello ".to_string()),
                StreamChunk::Text("world".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("stream_test", "desc"));
        let context = Arc::new(Context::new(llm, None).with_config(test_config("stream_test")));
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut final_output = None;
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_ignores_reasoning_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::ReasoningContent("plan".to_string()),
                StreamChunk::Text("done".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("stream_reasoning_test", "desc"));
        let context = Arc::new(
            Context::new(llm, None).with_config(test_config("stream_reasoning_test")),
        );
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut outputs = Vec::new();
        while let Some(item) = stream.next().await {
            outputs.push(item.unwrap());
        }

        // One text delta plus the final item; the reasoning chunk is dropped.
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].response, "done");
        assert!(!outputs[0].done);
        assert_eq!(outputs[1].response, "done");
        assert!(outputs[1].done);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_tool_results() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![StreamChunk::ToolUseComplete {
                index: 0,
                tool_call: tool_a_call(),
            }],
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("stream_tool", "desc"));
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(test_config("stream_tool"))
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut saw_tool_results = false;
        let mut final_output = None;

        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if !output.tool_calls.is_empty() {
                saw_tool_results = true;
                assert!(output.tool_calls[0].success);
            }
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        assert!(saw_tool_results);
        let output = final_output.expect("final output");
        assert!(output.done);
        assert!(!output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_react_agent_run_aborts_on_hook() {
        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");
        let task = Task::new("abort");

        let err = handle.agent.run(task).await.expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
    }

    #[tokio::test]
    async fn test_react_agent_run_stream_aborts_on_hook() {
        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");
        let task = Task::new("abort");

        let err = match handle.agent.run_stream(task).await {
            Ok(_) => panic!("expected abort"),
            Err(err) => err,
        };
        assert!(matches!(err, RunnableAgentError::Abort));
    }
}