1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22pub use tokio::sync::mpsc::error::SendError;
23
24#[cfg(target_arch = "wasm32")]
25type SendError = futures::channel::mpsc::SendError;
26
27use crate::agent::hooks::{AgentHooks, HookOutcome};
28use autoagents_protocol::Event;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ReActAgentOutput {
33 pub response: String,
34 pub tool_calls: Vec<ToolCallResult>,
35 pub done: bool,
36}
37
38impl From<ReActAgentOutput> for Value {
39 fn from(output: ReActAgentOutput) -> Self {
40 serde_json::to_value(output).unwrap_or(Value::Null)
41 }
42}
43impl From<ReActAgentOutput> for String {
44 fn from(output: ReActAgentOutput) -> Self {
45 output.response
46 }
47}
48
49impl ReActAgentOutput {
50 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
53 serde_json::from_str::<T>(&self.response)
54 }
55
56 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
60 where
61 T: for<'de> serde::Deserialize<'de>,
62 F: FnOnce(&str) -> T,
63 {
64 self.try_parse::<T>()
65 .unwrap_or_else(|_| fallback(&self.response))
66 }
67}
68
69impl ReActAgentOutput {
70 #[allow(clippy::result_large_err)]
72 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
73 where
74 T: for<'de> serde::Deserialize<'de>,
75 {
76 let react_output: Self = serde_json::from_value(val)
77 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
78 serde_json::from_str(&react_output.response)
79 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
80 }
81}
82
83#[derive(Error, Debug)]
84pub enum ReActExecutorError {
85 #[error("LLM error: {0}")]
86 LLMError(String),
87
88 #[error("Maximum turns exceeded: {max_turns}")]
89 MaxTurnsExceeded { max_turns: usize },
90
91 #[error("Other error: {0}")]
92 Other(String),
93
94 #[cfg(not(target_arch = "wasm32"))]
95 #[error("Event error: {0}")]
96 EventError(#[from] SendError<Event>),
97
98 #[cfg(target_arch = "wasm32")]
99 #[error("Event error: {0}")]
100 EventError(#[from] SendError),
101
102 #[error("Extracting Agent Output Error: {0}")]
103 AgentOutputError(String),
104}
105
106impl From<TurnEngineError> for ReActExecutorError {
107 fn from(error: TurnEngineError) -> Self {
108 match error {
109 TurnEngineError::LLMError(err) => ReActExecutorError::LLMError(err),
110 TurnEngineError::Aborted => {
111 ReActExecutorError::Other("Run aborted by hook".to_string())
112 }
113 TurnEngineError::Other(err) => ReActExecutorError::Other(err),
114 }
115 }
116}
117
118#[derive(Debug)]
123pub struct ReActAgent<T: AgentDeriveT> {
124 inner: Arc<T>,
125}
126
127impl<T: AgentDeriveT> Clone for ReActAgent<T> {
128 fn clone(&self) -> Self {
129 Self {
130 inner: Arc::clone(&self.inner),
131 }
132 }
133}
134
135impl<T: AgentDeriveT> ReActAgent<T> {
136 pub fn new(inner: T) -> Self {
137 Self {
138 inner: Arc::new(inner),
139 }
140 }
141}
142
143impl<T: AgentDeriveT> Deref for ReActAgent<T> {
144 type Target = T;
145
146 fn deref(&self) -> &Self::Target {
147 &self.inner
148 }
149}
150
151#[async_trait]
153impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
154 type Output = <T as AgentDeriveT>::Output;
155
156 fn description(&self) -> &str {
157 self.inner.description()
158 }
159
160 fn output_schema(&self) -> Option<Value> {
161 self.inner.output_schema()
162 }
163
164 fn name(&self) -> &str {
165 self.inner.name()
166 }
167
168 fn tools(&self) -> Vec<Box<dyn ToolT>> {
169 self.inner.tools()
170 }
171}
172
173#[async_trait]
174impl<T> AgentHooks for ReActAgent<T>
175where
176 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
177{
178 async fn on_agent_create(&self) {
179 self.inner.on_agent_create().await
180 }
181
182 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
183 self.inner.on_run_start(task, ctx).await
184 }
185
186 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
187 self.inner.on_run_complete(task, result, ctx).await
188 }
189
190 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
191 self.inner.on_turn_start(turn_index, ctx).await
192 }
193
194 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
195 self.inner.on_turn_complete(turn_index, ctx).await
196 }
197
198 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
199 self.inner.on_tool_call(tool_call, ctx).await
200 }
201
202 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
203 self.inner.on_tool_start(tool_call, ctx).await
204 }
205
206 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
207 self.inner.on_tool_result(tool_call, result, ctx).await
208 }
209
210 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
211 self.inner.on_tool_error(tool_call, err, ctx).await
212 }
213 async fn on_agent_shutdown(&self) {
214 self.inner.on_agent_shutdown().await
215 }
216}
217
218#[async_trait]
220impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
221 type Output = ReActAgentOutput;
222 type Error = ReActExecutorError;
223
224 fn config(&self) -> ExecutorConfig {
225 ExecutorConfig { max_turns: 10 }
226 }
227
228 async fn execute(
229 &self,
230 task: &Task,
231 context: Arc<Context>,
232 ) -> Result<Self::Output, Self::Error> {
233 if self.on_run_start(task, &context).await == HookOutcome::Abort {
234 return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
235 }
236
237 record_task_state(&context, task);
238
239 let tx_event = context.tx().ok();
240 EventHelper::send_task_started(
241 &tx_event,
242 task.submission_id,
243 context.config().id,
244 context.config().name.clone(),
245 task.prompt.clone(),
246 )
247 .await;
248
249 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
250 let mut turn_state = engine.turn_state(&context);
251 let max_turns = self.config().max_turns;
252 let mut accumulated_tool_calls = Vec::new();
253 let mut final_response = String::new();
254
255 for turn_index in 0..max_turns {
256 let result = engine
257 .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
258 .await?;
259
260 match result {
261 crate::agent::executor::TurnResult::Complete(output) => {
262 final_response = output.response.clone();
263 EventHelper::send_task_completed(
264 &tx_event,
265 task.submission_id,
266 context.config().id,
267 context.config().name.clone(),
268 final_response.clone(),
269 )
270 .await;
271
272 accumulated_tool_calls.extend(output.tool_calls);
273
274 return Ok(ReActAgentOutput {
275 response: final_response,
276 done: true,
277 tool_calls: accumulated_tool_calls,
278 });
279 }
280 crate::agent::executor::TurnResult::Continue(Some(output)) => {
281 if !output.response.is_empty() {
282 final_response = output.response;
283 }
284 accumulated_tool_calls.extend(output.tool_calls);
285 }
286 crate::agent::executor::TurnResult::Continue(None) => {}
287 }
288 }
289
290 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
291 EventHelper::send_task_completed(
292 &tx_event,
293 task.submission_id,
294 context.config().id,
295 context.config().name.clone(),
296 final_response.clone(),
297 )
298 .await;
299
300 return Ok(ReActAgentOutput {
301 response: final_response,
302 done: true,
303 tool_calls: accumulated_tool_calls,
304 });
305 }
306
307 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
308 }
309
310 async fn execute_stream(
311 &self,
312 task: &Task,
313 context: Arc<Context>,
314 ) -> Result<
315 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
316 Self::Error,
317 > {
318 if self.on_run_start(task, &context).await == HookOutcome::Abort {
319 return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
320 }
321
322 record_task_state(&context, task);
323
324 let tx_event = context.tx().ok();
325 EventHelper::send_task_started(
326 &tx_event,
327 task.submission_id,
328 context.config().id,
329 context.config().name.clone(),
330 task.prompt.clone(),
331 )
332 .await;
333
334 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
335 let mut turn_state = engine.turn_state(&context);
336 let max_turns = self.config().max_turns;
337 let context_clone = context.clone();
338 let task = task.clone();
339 let executor = self.clone();
340
341 let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
342
343 spawn_future(async move {
344 let mut accumulated_tool_calls = Vec::new();
345 let mut final_response = String::new();
346
347 for turn_index in 0..max_turns {
348 let turn_stream = engine
349 .run_turn_stream(
350 executor.clone(),
351 &task,
352 context_clone.clone(),
353 &mut turn_state,
354 turn_index,
355 max_turns,
356 )
357 .await;
358
359 let mut turn_result = None;
360
361 match turn_stream {
362 Ok(mut stream) => {
363 use futures::StreamExt;
364 while let Some(delta_result) = stream.next().await {
365 match delta_result {
366 Ok(TurnDelta::Text(content)) => {
367 let _ = tx
368 .send(Ok(ReActAgentOutput {
369 response: content,
370 tool_calls: Vec::new(),
371 done: false,
372 }))
373 .await;
374 }
375 Ok(TurnDelta::ToolResults(tool_results)) => {
376 accumulated_tool_calls.extend(tool_results);
377 let _ = tx
378 .send(Ok(ReActAgentOutput {
379 response: String::new(),
380 tool_calls: accumulated_tool_calls.clone(),
381 done: false,
382 }))
383 .await;
384 }
385 Ok(TurnDelta::Done(result)) => {
386 turn_result = Some(result);
387 break;
388 }
389 Err(err) => {
390 let _ = tx.send(Err(err.into())).await;
391 return;
392 }
393 }
394 }
395 }
396 Err(err) => {
397 let _ = tx.send(Err(err.into())).await;
398 return;
399 }
400 }
401
402 let Some(result) = turn_result else {
403 let _ = tx
404 .send(Err(ReActExecutorError::Other(
405 "Stream ended without final result".to_string(),
406 )))
407 .await;
408 return;
409 };
410
411 match result {
412 crate::agent::executor::TurnResult::Complete(output) => {
413 final_response = output.response.clone();
414 accumulated_tool_calls.extend(output.tool_calls);
415 break;
416 }
417 crate::agent::executor::TurnResult::Continue(Some(output)) => {
418 if !output.response.is_empty() {
419 final_response = output.response;
420 }
421 accumulated_tool_calls.extend(output.tool_calls);
422 }
423 crate::agent::executor::TurnResult::Continue(None) => {}
424 }
425 }
426
427 let tx_event = context_clone.tx().ok();
428 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
429 let _ = tx
430 .send(Ok(ReActAgentOutput {
431 response: final_response.clone(),
432 done: true,
433 tool_calls: accumulated_tool_calls.clone(),
434 }))
435 .await;
436
437 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
438 EventHelper::send_task_completed(
439 &tx_event,
440 task.submission_id,
441 context_clone.config().id,
442 context_clone.config().name.clone(),
443 final_response,
444 )
445 .await;
446 }
447 });
448
449 Ok(receiver_into_stream(rx))
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentConfig;
    use crate::tests::{
        ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider, StaticChatResponse,
        TestAgentOutput as TestUtilsOutput,
    };
    use async_trait::async_trait;
    use autoagents_llm::chat::StreamChunk;
    use autoagents_llm::{FunctionCall, ToolCall};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;

    /// Tool that ignores its arguments and always returns a fixed JSON value.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: serde_json::Value,
    }

    impl LocalTool {
        fn new(name: &str, output: serde_json::Value) -> Self {
            Self {
                name: name.to_string(),
                output,
            }
        }
    }

    impl crate::tool::ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "local tool"
        }

        fn args_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(
            &self,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }

    /// Typed payload used to exercise JSON extraction from a response.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct ReActTestOutput {
        value: i32,
        message: String,
    }

    /// Agent whose `on_run_start` hook always aborts the run.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestUtilsOutput;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Builds an agent config with a fresh id and no output schema.
    fn agent_config(name: &str, description: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: description.to_string(),
            output_schema: None,
        }
    }

    /// The single tool call the mocked LLM asks for in the tool tests.
    fn tool_a_call() -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let expected = ReActTestOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };
        let envelope = ReActAgentOutput {
            response: serde_json::to_string(&expected).unwrap(),
            done: true,
            tool_calls: vec![],
        };

        let as_json = serde_json::to_value(envelope).unwrap();
        let extracted: ReActTestOutput = ReActAgentOutput::extract_agent_output(as_json).unwrap();
        assert_eq!(extracted, expected);
    }

    #[test]
    fn test_extract_agent_output_invalid_react() {
        let not_react = serde_json::json!({"not": "react"});
        let result = ReActAgentOutput::extract_agent_output::<ReActTestOutput>(not_react);
        assert!(result.is_err());
    }

    #[test]
    fn test_react_agent_output_try_parse_success() {
        let output = ReActAgentOutput {
            response: r#"{"value":1,"message":"hi"}"#.to_string(),
            tool_calls: vec![],
            done: true,
        };
        let parsed: ReActTestOutput = output.try_parse().unwrap();
        assert_eq!(parsed.value, 1);
    }

    #[test]
    fn test_react_agent_output_try_parse_failure() {
        let output = ReActAgentOutput {
            response: "not json".to_string(),
            tool_calls: vec![],
            done: true,
        };
        assert!(output.try_parse::<ReActTestOutput>().is_err());
    }

    #[test]
    fn test_react_agent_output_parse_or_map() {
        let output = ReActAgentOutput {
            response: "plain text".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let result: String = output.parse_or_map(|s| s.to_uppercase());
        assert_eq!(result, "PLAIN TEXT");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let err: ReActExecutorError = TurnEngineError::LLMError("llm err".to_string()).into();
        assert!(matches!(err, ReActExecutorError::LLMError(_)));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err: ReActExecutorError = TurnEngineError::Aborted.into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err: ReActExecutorError = TurnEngineError::Other("other".to_string()).into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_react_agent_config() {
        let agent = ReActAgent::new(MockAgentImpl::new("cfg_test", "desc"));
        assert_eq!(agent.config().max_turns, 10);
    }

    #[test]
    fn test_react_agent_metadata_and_output_conversion() {
        let agent = ReActAgent::new(MockAgentImpl::new("react_meta", "desc"));
        let cloned = agent.clone();
        assert_eq!(cloned.name(), "react_meta");
        assert_eq!(cloned.description(), "desc");

        let output = ReActAgentOutput {
            response: "resp".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let value: Value = output.clone().into();
        assert_eq!(value["response"], "resp");
        let string: String = output.into();
        assert_eq!(string, "resp");
    }

    #[tokio::test]
    async fn test_react_agent_execute() {
        let agent = ReActAgent::new(MockAgentImpl::new("exec_test", "desc"));
        let llm = Arc::new(MockLLMProvider {});
        let context =
            Arc::new(Context::new(llm, None).with_config(agent_config("exec_test", "desc")));
        let task = Task::new("test");

        let result = agent.execute(&task, context).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.done);
        assert_eq!(output.response, "Mock response");
    }

    #[tokio::test]
    async fn test_react_agent_execute_with_tool_calls() {
        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("Use tool".to_string()),
                tool_calls: Some(vec![tool_a_call()]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("exec_tool", "desc"));
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(agent_config("exec_tool", "desc"))
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let result = agent.execute(&task, context).await.unwrap();
        assert!(result.done);
        assert!(!result.tool_calls.is_empty());
        assert!(result.tool_calls[0].success);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_text() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::Text("Hello ".to_string()),
                StreamChunk::Text("world".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("stream_test", "desc"));
        let context =
            Arc::new(Context::new(llm, None).with_config(agent_config("stream_test", "desc")));
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut final_output = None;
        // Drain the stream until the final (`done`) item shows up.
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_tool_results() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![StreamChunk::ToolUseComplete {
                index: 0,
                tool_call: tool_a_call(),
            }],
            ..ConfigurableLLMProvider::default()
        });

        let agent = ReActAgent::new(MockAgentImpl::new("stream_tool", "desc"));
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(agent_config("stream_tool", "desc"))
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut saw_tool_results = false;
        let mut final_output = None;

        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if !output.tool_calls.is_empty() {
                saw_tool_results = true;
                assert!(output.tool_calls[0].success);
            }
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        assert!(saw_tool_results);
        let output = final_output.expect("final output");
        assert!(output.done);
        assert!(!output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_react_agent_execute_aborts_on_hook() {
        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let context =
            Arc::new(Context::new(llm, None).with_config(agent_config("abort_agent", "abort")));
        let task = Task::new("abort");

        let err = agent.execute(&task, context).await.unwrap_err();
        assert!(err.to_string().contains("aborted"));
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_aborts_on_hook() {
        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let context =
            Arc::new(Context::new(llm, None).with_config(agent_config("abort_agent", "abort")));
        let task = Task::new("abort");

        let err = agent.execute_stream(&task, context).await.err().unwrap();
        assert!(err.to_string().contains("aborted"));
    }
}