1use crate::agent::executor::AgentExecutor;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::turn_engine::{
4 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
5};
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22pub use tokio::sync::mpsc::error::SendError;
23
24#[cfg(target_arch = "wasm32")]
25type SendError = futures::channel::mpsc::SendError;
26
27use crate::agent::hooks::{AgentHooks, HookOutcome};
28use autoagents_protocol::Event;
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ReActAgentOutput {
33 pub response: String,
34 pub tool_calls: Vec<ToolCallResult>,
35 pub done: bool,
36}
37
38impl From<ReActAgentOutput> for Value {
39 fn from(output: ReActAgentOutput) -> Self {
40 serde_json::to_value(output).unwrap_or(Value::Null)
41 }
42}
43impl From<ReActAgentOutput> for String {
44 fn from(output: ReActAgentOutput) -> Self {
45 output.response
46 }
47}
48
49impl ReActAgentOutput {
50 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
53 serde_json::from_str::<T>(&self.response)
54 }
55
56 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
60 where
61 T: for<'de> serde::Deserialize<'de>,
62 F: FnOnce(&str) -> T,
63 {
64 self.try_parse::<T>()
65 .unwrap_or_else(|_| fallback(&self.response))
66 }
67}
68
69impl ReActAgentOutput {
70 #[allow(clippy::result_large_err)]
72 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
73 where
74 T: for<'de> serde::Deserialize<'de>,
75 {
76 let react_output: Self = serde_json::from_value(val)
77 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
78 serde_json::from_str(&react_output.response)
79 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
80 }
81}
82
83#[derive(Error, Debug)]
84pub enum ReActExecutorError {
85 #[error("LLM error: {0}")]
86 LLMError(String),
87
88 #[error("Maximum turns exceeded: {max_turns}")]
89 MaxTurnsExceeded { max_turns: usize },
90
91 #[error("Other error: {0}")]
92 Other(String),
93
94 #[cfg(not(target_arch = "wasm32"))]
95 #[error("Event error: {0}")]
96 EventError(#[from] SendError<Event>),
97
98 #[cfg(target_arch = "wasm32")]
99 #[error("Event error: {0}")]
100 EventError(#[from] SendError),
101
102 #[error("Extracting Agent Output Error: {0}")]
103 AgentOutputError(String),
104}
105
106impl From<TurnEngineError> for ReActExecutorError {
107 fn from(error: TurnEngineError) -> Self {
108 match error {
109 TurnEngineError::LLMError(err) => ReActExecutorError::LLMError(err),
110 TurnEngineError::Aborted => {
111 ReActExecutorError::Other("Run aborted by hook".to_string())
112 }
113 TurnEngineError::Other(err) => ReActExecutorError::Other(err),
114 }
115 }
116}
117
118#[derive(Debug)]
123pub struct ReActAgent<T: AgentDeriveT> {
124 inner: Arc<T>,
125}
126
127impl<T: AgentDeriveT> Clone for ReActAgent<T> {
128 fn clone(&self) -> Self {
129 Self {
130 inner: Arc::clone(&self.inner),
131 }
132 }
133}
134
135impl<T: AgentDeriveT> ReActAgent<T> {
136 pub fn new(inner: T) -> Self {
137 Self {
138 inner: Arc::new(inner),
139 }
140 }
141}
142
143impl<T: AgentDeriveT> Deref for ReActAgent<T> {
144 type Target = T;
145
146 fn deref(&self) -> &Self::Target {
147 &self.inner
148 }
149}
150
151#[async_trait]
153impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
154 type Output = <T as AgentDeriveT>::Output;
155
156 fn description(&self) -> &str {
157 self.inner.description()
158 }
159
160 fn output_schema(&self) -> Option<Value> {
161 self.inner.output_schema()
162 }
163
164 fn name(&self) -> &str {
165 self.inner.name()
166 }
167
168 fn tools(&self) -> Vec<Box<dyn ToolT>> {
169 self.inner.tools()
170 }
171}
172
173#[async_trait]
174impl<T> AgentHooks for ReActAgent<T>
175where
176 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
177{
178 async fn on_agent_create(&self) {
179 self.inner.on_agent_create().await
180 }
181
182 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
183 self.inner.on_run_start(task, ctx).await
184 }
185
186 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
187 self.inner.on_run_complete(task, result, ctx).await
188 }
189
190 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
191 self.inner.on_turn_start(turn_index, ctx).await
192 }
193
194 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
195 self.inner.on_turn_complete(turn_index, ctx).await
196 }
197
198 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
199 self.inner.on_tool_call(tool_call, ctx).await
200 }
201
202 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
203 self.inner.on_tool_start(tool_call, ctx).await
204 }
205
206 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
207 self.inner.on_tool_result(tool_call, result, ctx).await
208 }
209
210 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
211 self.inner.on_tool_error(tool_call, err, ctx).await
212 }
213 async fn on_agent_shutdown(&self) {
214 self.inner.on_agent_shutdown().await
215 }
216}
217
218#[async_trait]
220impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
221 type Output = ReActAgentOutput;
222 type Error = ReActExecutorError;
223
224 fn config(&self) -> ExecutorConfig {
225 ExecutorConfig { max_turns: 10 }
226 }
227
228 async fn execute(
229 &self,
230 task: &Task,
231 context: Arc<Context>,
232 ) -> Result<Self::Output, Self::Error> {
233 if self.on_run_start(task, &context).await == HookOutcome::Abort {
234 return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
235 }
236
237 record_task_state(&context, task);
238
239 let tx_event = context.tx().ok();
240 EventHelper::send_task_started(
241 &tx_event,
242 task.submission_id,
243 context.config().id,
244 context.config().name.clone(),
245 task.prompt.clone(),
246 )
247 .await;
248
249 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
250 let mut turn_state = engine.turn_state(&context);
251 let max_turns = self.config().max_turns;
252 let mut accumulated_tool_calls = Vec::new();
253 let mut final_response = String::new();
254
255 for turn_index in 0..max_turns {
256 let result = engine
257 .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
258 .await?;
259
260 match result {
261 crate::agent::executor::TurnResult::Complete(output) => {
262 final_response = output.response.clone();
263 accumulated_tool_calls.extend(output.tool_calls);
264
265 return Ok(ReActAgentOutput {
266 response: final_response,
267 done: true,
268 tool_calls: accumulated_tool_calls,
269 });
270 }
271 crate::agent::executor::TurnResult::Continue(Some(output)) => {
272 if !output.response.is_empty() {
273 final_response = output.response;
274 }
275 accumulated_tool_calls.extend(output.tool_calls);
276 }
277 crate::agent::executor::TurnResult::Continue(None) => {}
278 }
279 }
280
281 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
282 return Ok(ReActAgentOutput {
283 response: final_response,
284 done: true,
285 tool_calls: accumulated_tool_calls,
286 });
287 }
288
289 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
290 }
291
292 async fn execute_stream(
293 &self,
294 task: &Task,
295 context: Arc<Context>,
296 ) -> Result<
297 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
298 Self::Error,
299 > {
300 if self.on_run_start(task, &context).await == HookOutcome::Abort {
301 return Err(ReActExecutorError::Other("Run aborted by hook".to_string()));
302 }
303
304 record_task_state(&context, task);
305
306 let tx_event = context.tx().ok();
307 EventHelper::send_task_started(
308 &tx_event,
309 task.submission_id,
310 context.config().id,
311 context.config().name.clone(),
312 task.prompt.clone(),
313 )
314 .await;
315
316 let engine = TurnEngine::new(TurnEngineConfig::react(self.config().max_turns));
317 let mut turn_state = engine.turn_state(&context);
318 let max_turns = self.config().max_turns;
319 let context_clone = context.clone();
320 let task = task.clone();
321 let executor = self.clone();
322
323 let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
324
325 spawn_future(async move {
326 let mut accumulated_tool_calls = Vec::new();
327 let mut final_response = String::new();
328
329 for turn_index in 0..max_turns {
330 let turn_stream = engine
331 .run_turn_stream(
332 executor.clone(),
333 &task,
334 context_clone.clone(),
335 &mut turn_state,
336 turn_index,
337 max_turns,
338 )
339 .await;
340
341 let mut turn_result = None;
342
343 match turn_stream {
344 Ok(mut stream) => {
345 use futures::StreamExt;
346 while let Some(delta_result) = stream.next().await {
347 match delta_result {
348 Ok(TurnDelta::Text(content)) => {
349 let _ = tx
350 .send(Ok(ReActAgentOutput {
351 response: content,
352 tool_calls: Vec::new(),
353 done: false,
354 }))
355 .await;
356 }
357 Ok(TurnDelta::ToolResults(tool_results)) => {
358 accumulated_tool_calls.extend(tool_results);
359 let _ = tx
360 .send(Ok(ReActAgentOutput {
361 response: String::new(),
362 tool_calls: accumulated_tool_calls.clone(),
363 done: false,
364 }))
365 .await;
366 }
367 Ok(TurnDelta::Done(result)) => {
368 turn_result = Some(result);
369 break;
370 }
371 Err(err) => {
372 let _ = tx.send(Err(err.into())).await;
373 return;
374 }
375 }
376 }
377 }
378 Err(err) => {
379 let _ = tx.send(Err(err.into())).await;
380 return;
381 }
382 }
383
384 let Some(result) = turn_result else {
385 let _ = tx
386 .send(Err(ReActExecutorError::Other(
387 "Stream ended without final result".to_string(),
388 )))
389 .await;
390 return;
391 };
392
393 match result {
394 crate::agent::executor::TurnResult::Complete(output) => {
395 final_response = output.response.clone();
396 accumulated_tool_calls.extend(output.tool_calls);
397 break;
398 }
399 crate::agent::executor::TurnResult::Continue(Some(output)) => {
400 if !output.response.is_empty() {
401 final_response = output.response;
402 }
403 accumulated_tool_calls.extend(output.tool_calls);
404 }
405 crate::agent::executor::TurnResult::Continue(None) => {}
406 }
407 }
408
409 let tx_event = context_clone.tx().ok();
410 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
411 let output = ReActAgentOutput {
412 response: final_response.clone(),
413 done: true,
414 tool_calls: accumulated_tool_calls.clone(),
415 };
416 let _ = tx.send(Ok(output.clone())).await;
417
418 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
419 let result = serde_json::to_string_pretty(&output)
420 .unwrap_or_else(|_| output.response.clone());
421 EventHelper::send_task_completed(
422 &tx_event,
423 task.submission_id,
424 context_clone.config().id,
425 context_clone.config().name.clone(),
426 result,
427 )
428 .await;
429 }
430 });
431
432 Ok(receiver_into_stream(rx))
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentConfig;
    use crate::tests::{
        ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider, StaticChatResponse,
        TestAgentOutput as TestUtilsOutput,
    };
    use async_trait::async_trait;
    use autoagents_llm::chat::StreamChunk;
    use autoagents_llm::{FunctionCall, ToolCall};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;

    /// Tool that ignores its arguments and always returns a fixed JSON value.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: serde_json::Value,
    }

    impl LocalTool {
        fn new(name: &str, output: serde_json::Value) -> Self {
            Self {
                name: name.to_string(),
                output,
            }
        }
    }

    impl crate::tool::ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "local tool"
        }

        fn args_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }

    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(
            &self,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }

    /// Typed payload used to exercise the JSON parsing helpers.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct ReActTestOutput {
        value: i32,
        message: String,
    }

    /// Agent whose `on_run_start` hook always aborts the run.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestUtilsOutput;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Builds an agent config with a fresh id and no output schema.
    fn agent_config(name: &str, description: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: description.to_string(),
            output_schema: None,
        }
    }

    /// Builds the tool call targeting `tool_a` that the tool-related tests share.
    fn tool_a_call() -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "tool_a".to_string(),
                arguments: r#"{"value":1}"#.to_string(),
            },
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = ReActTestOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = ReActAgentOutput {
            response: serde_json::to_string(&agent_output).unwrap(),
            done: true,
            tool_calls: vec![],
        };

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: ReActTestOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_invalid_react() {
        let result = ReActAgentOutput::extract_agent_output::<ReActTestOutput>(
            serde_json::json!({"not": "react"}),
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_react_agent_output_try_parse_success() {
        let output = ReActAgentOutput {
            response: r#"{"value":1,"message":"hi"}"#.to_string(),
            tool_calls: vec![],
            done: true,
        };
        let parsed: ReActTestOutput = output.try_parse().unwrap();
        assert_eq!(parsed.value, 1);
    }

    #[test]
    fn test_react_agent_output_try_parse_failure() {
        let output = ReActAgentOutput {
            response: "not json".to_string(),
            tool_calls: vec![],
            done: true,
        };
        assert!(output.try_parse::<ReActTestOutput>().is_err());
    }

    #[test]
    fn test_react_agent_output_parse_or_map() {
        let output = ReActAgentOutput {
            response: "plain text".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let result: String = output.parse_or_map(|s| s.to_uppercase());
        assert_eq!(result, "PLAIN TEXT");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let err: ReActExecutorError = TurnEngineError::LLMError("llm err".to_string()).into();
        assert!(matches!(err, ReActExecutorError::LLMError(_)));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err: ReActExecutorError = TurnEngineError::Aborted.into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err: ReActExecutorError = TurnEngineError::Other("other".to_string()).into();
        assert!(matches!(err, ReActExecutorError::Other(_)));
    }

    #[test]
    fn test_react_agent_config() {
        let mock = MockAgentImpl::new("cfg_test", "desc");
        let agent = ReActAgent::new(mock);
        assert_eq!(agent.config().max_turns, 10);
    }

    #[test]
    fn test_react_agent_metadata_and_output_conversion() {
        let mock = MockAgentImpl::new("react_meta", "desc");
        let agent = ReActAgent::new(mock);
        let cloned = agent.clone();
        assert_eq!(cloned.name(), "react_meta");
        assert_eq!(cloned.description(), "desc");

        let output = ReActAgentOutput {
            response: "resp".to_string(),
            tool_calls: vec![],
            done: true,
        };
        let value: Value = output.clone().into();
        assert_eq!(value["response"], "resp");
        let string: String = output.into();
        assert_eq!(string, "resp");
    }

    #[tokio::test]
    async fn test_react_agent_execute() {
        let mock = MockAgentImpl::new("exec_test", "desc");
        let agent = ReActAgent::new(mock);
        let llm = Arc::new(MockLLMProvider {});
        let config = agent_config("exec_test", "desc");
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("test");

        let result = agent.execute(&task, context).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.done);
        assert_eq!(output.response, "Mock response");
    }

    #[tokio::test]
    async fn test_react_agent_execute_with_tool_calls() {
        let llm = Arc::new(ConfigurableLLMProvider {
            chat_response: StaticChatResponse {
                text: Some("Use tool".to_string()),
                tool_calls: Some(vec![tool_a_call()]),
                usage: None,
                thinking: None,
            },
            ..ConfigurableLLMProvider::default()
        });

        let mock = MockAgentImpl::new("exec_tool", "desc");
        let agent = ReActAgent::new(mock);
        let config = agent_config("exec_tool", "desc");

        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(config)
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let result = agent.execute(&task, context).await.unwrap();
        assert!(result.done);
        assert!(!result.tool_calls.is_empty());
        assert!(result.tool_calls[0].success);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_text() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![
                StreamChunk::Text("Hello ".to_string()),
                StreamChunk::Text("world".to_string()),
                StreamChunk::Done {
                    stop_reason: "end_turn".to_string(),
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let mock = MockAgentImpl::new("stream_test", "desc");
        let agent = ReActAgent::new(mock);
        let config = agent_config("stream_test", "desc");
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut final_output = None;
        // Skip intermediate deltas; only the item flagged `done` is checked.
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_tool_results() {
        let llm = Arc::new(ConfigurableLLMProvider {
            stream_chunks: vec![StreamChunk::ToolUseComplete {
                index: 0,
                tool_call: tool_a_call(),
            }],
            ..ConfigurableLLMProvider::default()
        });

        let mock = MockAgentImpl::new("stream_tool", "desc");
        let agent = ReActAgent::new(mock);
        let config = agent_config("stream_tool", "desc");
        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
        let context = Arc::new(
            Context::new(llm, None)
                .with_config(config)
                .with_tools(vec![Box::new(tool)]),
        );
        let task = Task::new("test");

        let mut stream = agent.execute_stream(&task, context).await.unwrap();
        let mut saw_tool_results = false;
        let mut final_output = None;

        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if !output.tool_calls.is_empty() {
                saw_tool_results = true;
                assert!(output.tool_calls[0].success);
            }
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        assert!(saw_tool_results);
        let output = final_output.expect("final output");
        assert!(output.done);
        assert!(!output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_react_agent_execute_aborts_on_hook() {
        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let config = agent_config("abort_agent", "abort");
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("abort");

        let err = agent.execute(&task, context).await.unwrap_err();
        assert!(err.to_string().contains("aborted"));
    }

    #[tokio::test]
    async fn test_react_agent_execute_stream_aborts_on_hook() {
        let agent = ReActAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let config = agent_config("abort_agent", "abort");
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("abort");

        let err = agent.execute_stream(&task, context).await.err().unwrap();
        assert!(err.to_string().contains("aborted"));
    }
}