use crate::agent::executor::AgentExecutor;
use crate::agent::executor::event_helper::EventHelper;
use crate::agent::executor::turn_engine::{
TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, record_task_state,
};
use crate::agent::task::Task;
use crate::agent::{AgentDeriveT, Context, ExecutorConfig};
use crate::channel::channel;
use crate::tool::{ToolCallResult, ToolT};
use crate::utils::{receiver_into_stream, spawn_future};
use async_trait::async_trait;
use autoagents_llm::ToolCall;
use autoagents_llm::error::LLMError;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
use thiserror::Error;
#[cfg(not(target_arch = "wasm32"))]
pub use tokio::sync::mpsc::error::SendError;
#[cfg(target_arch = "wasm32")]
type SendError = futures::channel::mpsc::SendError;
use crate::agent::hooks::{AgentHooks, HookOutcome};
use autoagents_protocol::Event;
/// Output produced by a ReAct agent run.
///
/// The same shape is used for the final result and for the intermediate items
/// emitted while streaming; the `done` flag tells the two apart.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
/// Response text. The parsing helpers on this type treat it as JSON when a
/// typed value is requested.
pub response: String,
/// Tool call results attached to this output.
pub tool_calls: Vec<ToolCallResult>,
/// `true` on a final output, `false` on intermediate streaming items.
pub done: bool,
}
impl From<ReActAgentOutput> for Value {
    /// Serializes the output into a JSON value, yielding `Value::Null` when
    /// serialization fails.
    fn from(output: ReActAgentOutput) -> Self {
        match serde_json::to_value(output) {
            Ok(json) => json,
            Err(_) => Value::Null,
        }
    }
}
impl From<ReActAgentOutput> for String {
fn from(output: ReActAgentOutput) -> Self {
output.response
}
}
impl ReActAgentOutput {
    /// Attempts to deserialize the response text as JSON into `T`.
    ///
    /// # Errors
    /// Returns the `serde_json` error when the response text cannot be
    /// deserialized into `T`.
    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
        let raw = self.response.as_str();
        serde_json::from_str(raw)
    }

    /// Deserializes the response text as JSON into `T`, or builds a `T` from
    /// the raw response text via `fallback` when deserialization fails.
    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
    where
        T: for<'de> serde::Deserialize<'de>,
        F: FnOnce(&str) -> T,
    {
        match self.try_parse::<T>() {
            Ok(parsed) => parsed,
            Err(_) => fallback(&self.response),
        }
    }
}
/// Removes repeated tool call results, keeping the first occurrence of each and
/// preserving the original order.
///
/// Two results are treated as the same when their tool name, success flag,
/// arguments and result format to the same `|`-separated text.
fn dedupe_tool_calls(tool_calls: Vec<ToolCallResult>) -> Vec<ToolCallResult> {
    let mut seen_keys = HashSet::new();
    let mut kept: Vec<ToolCallResult> = Vec::with_capacity(tool_calls.len());
    kept.extend(tool_calls.into_iter().filter(|call| {
        let fingerprint = format!(
            "{}|{}|{}|{}",
            call.tool_name, call.success, call.arguments, call.result
        );
        // `insert` returns `true` only the first time a fingerprint is seen.
        seen_keys.insert(fingerprint)
    }));
    kept
}
impl ReActAgentOutput {
#[allow(clippy::result_large_err)]
pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
where
T: for<'de> serde::Deserialize<'de>,
{
let react_output: Self = serde_json::from_value(val)
.map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
serde_json::from_str(&react_output.response)
.map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
}
}
/// Errors produced by the ReAct executor.
#[derive(Error, Debug)]
pub enum ReActExecutorError {
/// Failure reported by the LLM layer; convertible from `LLMError` via `From`.
#[error("LLM error: {0}")]
LLMError(
#[from]
#[source]
LLMError,
),
/// Returned by `execute` when the turn limit is reached with no response
/// text and no tool call results collected.
#[error("Maximum turns exceeded: {max_turns}")]
MaxTurnsExceeded { max_turns: usize },
/// Any other failure, described by a message. Turn-engine aborts and
/// turn-engine `Other` errors are mapped to this variant.
#[error("Other error: {0}")]
Other(String),
/// Wraps a channel send error for an `Event` (non-wasm targets).
#[cfg(not(target_arch = "wasm32"))]
#[error("Event error: {0}")]
EventError(#[from] SendError<Event>),
/// Wraps a channel send error (wasm32 targets).
#[cfg(target_arch = "wasm32")]
#[error("Event error: {0}")]
EventError(#[from] SendError),
/// Deserializing an agent output failed; carries the underlying error message.
#[error("Extracting Agent Output Error: {0}")]
AgentOutputError(String),
}
impl From<TurnEngineError> for ReActExecutorError {
    /// Maps a turn-engine failure onto the executor's error type.
    ///
    /// LLM errors keep their own variant; an abort and any other engine error
    /// both become [`ReActExecutorError::Other`].
    fn from(error: TurnEngineError) -> Self {
        match error {
            TurnEngineError::LLMError(llm_error) => Self::from(llm_error),
            TurnEngineError::Aborted => Self::Other(String::from("Run aborted by hook")),
            TurnEngineError::Other(message) => Self::Other(message),
        }
    }
}
/// Wrapper that runs an inner agent definition with the ReAct executor.
///
/// The inner agent is held behind an `Arc`, so clones of this wrapper share it.
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
/// The wrapped agent definition.
inner: Arc<T>,
/// Upper bound on the number of turns a single run may take.
max_turns: usize,
}
impl<T: AgentDeriveT> Clone for ReActAgent<T> {
    /// Creates another handle to the same inner agent with the same turn limit.
    ///
    /// Only the `Arc` is cloned, so `T` itself does not need to be `Clone`.
    fn clone(&self) -> Self {
        let Self { inner, max_turns } = self;
        Self {
            inner: Arc::clone(inner),
            max_turns: *max_turns,
        }
    }
}
impl<T: AgentDeriveT> ReActAgent<T> {
    /// Wraps `inner` using the default limit of 10 turns.
    pub fn new(inner: T) -> Self {
        const DEFAULT_MAX_TURNS: usize = 10;
        Self::with_max_turns(inner, DEFAULT_MAX_TURNS)
    }

    /// Wraps `inner` with a custom turn limit.
    ///
    /// A limit of `0` is raised to `1`.
    pub fn with_max_turns(inner: T, max_turns: usize) -> Self {
        let max_turns = if max_turns == 0 { 1 } else { max_turns };
        Self {
            inner: Arc::new(inner),
            max_turns,
        }
    }
}
impl<T: AgentDeriveT> Deref for ReActAgent<T> {
    type Target = T;

    /// Gives direct read access to the wrapped agent.
    fn deref(&self) -> &Self::Target {
        &*self.inner
    }
}
#[async_trait]
impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
    type Output = <T as AgentDeriveT>::Output;

    /// Name of the wrapped agent.
    fn name(&self) -> &str {
        <T as AgentDeriveT>::name(&self.inner)
    }

    /// Description of the wrapped agent.
    fn description(&self) -> &str {
        <T as AgentDeriveT>::description(&self.inner)
    }

    /// Output schema of the wrapped agent, if it declares one.
    fn output_schema(&self) -> Option<Value> {
        <T as AgentDeriveT>::output_schema(&self.inner)
    }

    /// Tools exposed by the wrapped agent.
    fn tools(&self) -> Vec<Box<dyn ToolT>> {
        <T as AgentDeriveT>::tools(&self.inner)
    }
}
// Every hook is forwarded unchanged to the wrapped agent, so wrapping an agent
// in `ReActAgent` does not alter its lifecycle callbacks.
#[async_trait]
impl<T> AgentHooks for ReActAgent<T>
where
    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
{
    /// Forwards the agent-creation hook to the wrapped agent.
    async fn on_agent_create(&self) {
        <T as AgentHooks>::on_agent_create(&self.inner).await
    }

    /// Forwards the run-start hook and returns the wrapped agent's outcome.
    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
        <T as AgentHooks>::on_run_start(&self.inner, task, ctx).await
    }

    /// Forwards the run-complete hook together with the run result.
    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
        <T as AgentHooks>::on_run_complete(&self.inner, task, result, ctx).await
    }

    /// Forwards the turn-start hook for `turn_index`.
    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
        <T as AgentHooks>::on_turn_start(&self.inner, turn_index, ctx).await
    }

    /// Forwards the turn-complete hook for `turn_index`.
    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
        <T as AgentHooks>::on_turn_complete(&self.inner, turn_index, ctx).await
    }

    /// Forwards the tool-call hook and returns the wrapped agent's outcome.
    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
        <T as AgentHooks>::on_tool_call(&self.inner, tool_call, ctx).await
    }

    /// Forwards the tool-start hook.
    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
        <T as AgentHooks>::on_tool_start(&self.inner, tool_call, ctx).await
    }

    /// Forwards the tool-result hook together with the tool's result.
    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
        <T as AgentHooks>::on_tool_result(&self.inner, tool_call, result, ctx).await
    }

    /// Forwards the tool-error hook together with the error value.
    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
        <T as AgentHooks>::on_tool_error(&self.inner, tool_call, err, ctx).await
    }

    /// Forwards the agent-shutdown hook to the wrapped agent.
    async fn on_agent_shutdown(&self) {
        <T as AgentHooks>::on_agent_shutdown(&self.inner).await
    }
}
#[async_trait]
impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
    type Output = ReActAgentOutput;
    type Error = ReActExecutorError;

    /// Exposes the turn limit this agent was constructed with.
    fn config(&self) -> ExecutorConfig {
        ExecutorConfig {
            max_turns: self.max_turns,
        }
    }

    /// Runs the task turn by turn until the engine reports completion or the
    /// turn limit is reached.
    ///
    /// Tool call results from every turn are accumulated and de-duplicated.
    /// The latest non-empty response seen on a continuing turn is remembered so
    /// it can be returned if the limit is hit before any turn completes.
    ///
    /// # Errors
    /// - Any turn-engine error, converted into [`ReActExecutorError`].
    /// - [`ReActExecutorError::MaxTurnsExceeded`] when the limit is reached
    ///   with neither response text nor tool call results collected.
    async fn execute(
        &self,
        task: &Task,
        context: Arc<Context>,
    ) -> Result<Self::Output, Self::Error> {
        record_task_state(&context, task);
        let tx_event = context.tx().ok();
        EventHelper::send_task_started(
            &tx_event,
            task.submission_id,
            context.config().id,
            context.config().name.clone(),
            task.prompt.clone(),
        )
        .await;

        // Read the limit once and share it between the engine and the loop.
        let max_turns = self.config().max_turns;
        let engine = TurnEngine::new(TurnEngineConfig::react(max_turns));
        let mut turn_state = engine.turn_state(&context);
        let mut accumulated_tool_calls = Vec::new();
        let mut final_response = String::new();

        for turn_index in 0..max_turns {
            let result = engine
                .run_turn(self, task, &context, &mut turn_state, turn_index, max_turns)
                .await?;
            match result {
                crate::agent::executor::TurnResult::Complete(output) => {
                    // `output` is owned, so its fields are moved rather than cloned.
                    accumulated_tool_calls.extend(output.tool_calls);
                    return Ok(ReActAgentOutput {
                        response: output.response,
                        done: true,
                        tool_calls: dedupe_tool_calls(accumulated_tool_calls),
                    });
                }
                crate::agent::executor::TurnResult::Continue(Some(output)) => {
                    // An empty response must not overwrite text from an earlier turn.
                    if !output.response.is_empty() {
                        final_response = output.response;
                    }
                    accumulated_tool_calls.extend(output.tool_calls);
                    accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
                }
                crate::agent::executor::TurnResult::Continue(None) => {}
            }
        }

        // Turn limit reached: return whatever was gathered, if anything.
        if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
            return Ok(ReActAgentOutput {
                response: final_response,
                done: true,
                tool_calls: accumulated_tool_calls,
            });
        }
        Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
    }

    /// Streaming variant of `execute`.
    ///
    /// The turn loop runs in a task started with `spawn_future`; the returned
    /// stream is fed from a channel created with `channel(100)`. Items sent:
    /// - each text delta as an item with `done: false` and only that text;
    /// - after tool results arrive, an item with `done: false` carrying all
    ///   de-duplicated tool call results gathered so far;
    /// - reasoning content is not forwarded;
    /// - a last item with `done: true` holding the final response and all tool
    ///   call results.
    ///
    /// When a turn fails, or a turn's stream ends without a final result, a
    /// single `Err` item is sent and the task stops without a `done: true` item.
    ///
    /// NOTE(review): unlike `execute`, reaching the turn limit with nothing
    /// collected still sends an `Ok` final item instead of `MaxTurnsExceeded`;
    /// confirm this difference is intended.
    ///
    /// # Errors
    /// The outer `Result` is always `Ok` in this implementation; failures are
    /// delivered as `Err` items on the stream.
    async fn execute_stream(
        &self,
        task: &Task,
        context: Arc<Context>,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
        Self::Error,
    > {
        record_task_state(&context, task);
        let tx_event = context.tx().ok();
        EventHelper::send_task_started(
            &tx_event,
            task.submission_id,
            context.config().id,
            context.config().name.clone(),
            task.prompt.clone(),
        )
        .await;

        // Read the limit once and share it between the engine and the loop.
        let max_turns = self.config().max_turns;
        let engine = TurnEngine::new(TurnEngineConfig::react(max_turns));
        let mut turn_state = engine.turn_state(&context);
        let context_clone = context.clone();
        let task = task.clone();
        let executor = self.clone();
        let (tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);

        spawn_future(async move {
            let mut accumulated_tool_calls = Vec::new();
            let mut final_response = String::new();

            for turn_index in 0..max_turns {
                let turn_stream = engine
                    .run_turn_stream(
                        executor.clone(),
                        &task,
                        context_clone.clone(),
                        &mut turn_state,
                        turn_index,
                        max_turns,
                    )
                    .await;
                let mut turn_result = None;
                match turn_stream {
                    Ok(mut stream) => {
                        use futures::StreamExt;
                        while let Some(delta_result) = stream.next().await {
                            match delta_result {
                                Ok(TurnDelta::Text(content)) => {
                                    // Send failures are ignored on purpose: the
                                    // receiver may already have been dropped.
                                    let _ = tx
                                        .send(Ok(ReActAgentOutput {
                                            response: content,
                                            tool_calls: Vec::new(),
                                            done: false,
                                        }))
                                        .await;
                                }
                                Ok(TurnDelta::ReasoningContent(_)) => {}
                                Ok(TurnDelta::ToolResults(tool_results)) => {
                                    accumulated_tool_calls.extend(tool_results);
                                    accumulated_tool_calls =
                                        dedupe_tool_calls(accumulated_tool_calls);
                                    let _ = tx
                                        .send(Ok(ReActAgentOutput {
                                            response: String::new(),
                                            tool_calls: accumulated_tool_calls.clone(),
                                            done: false,
                                        }))
                                        .await;
                                }
                                Ok(TurnDelta::Done(result)) => {
                                    turn_result = Some(result);
                                    break;
                                }
                                Err(err) => {
                                    let _ = tx.send(Err(err.into())).await;
                                    return;
                                }
                            }
                        }
                    }
                    Err(err) => {
                        let _ = tx.send(Err(err.into())).await;
                        return;
                    }
                }

                let Some(result) = turn_result else {
                    let _ = tx
                        .send(Err(ReActExecutorError::Other(
                            "Stream ended without final result".to_string(),
                        )))
                        .await;
                    return;
                };

                match result {
                    crate::agent::executor::TurnResult::Complete(output) => {
                        // `output` is owned, so the response is moved, not cloned.
                        final_response = output.response;
                        accumulated_tool_calls.extend(output.tool_calls);
                        accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
                        break;
                    }
                    crate::agent::executor::TurnResult::Continue(Some(output)) => {
                        // An empty response must not overwrite text from an earlier turn.
                        if !output.response.is_empty() {
                            final_response = output.response;
                        }
                        accumulated_tool_calls.extend(output.tool_calls);
                        accumulated_tool_calls = dedupe_tool_calls(accumulated_tool_calls);
                    }
                    crate::agent::executor::TurnResult::Continue(None) => {}
                }
            }

            let tx_event = context_clone.tx().ok();
            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;

            let output = ReActAgentOutput {
                response: final_response,
                done: true,
                tool_calls: accumulated_tool_calls,
            };
            // Serialize the completion payload before handing `output` to the
            // channel so the output can be moved instead of cloned. The order of
            // the sends below is unchanged: final item first, then the event.
            let completion_payload = if output.response.is_empty() && output.tool_calls.is_empty()
            {
                None
            } else {
                Some(
                    serde_json::to_string_pretty(&output)
                        .unwrap_or_else(|_| output.response.clone()),
                )
            };
            let _ = tx.send(Ok(output)).await;

            // The task-completed event is only emitted when something was produced.
            if let Some(result) = completion_payload {
                EventHelper::send_task_completed(
                    &tx_event,
                    task.submission_id,
                    context_clone.config().id,
                    context_clone.config().name.clone(),
                    result,
                )
                .await;
            }
        });

        Ok(receiver_into_stream(rx))
    }
}
// Unit tests for the ReAct executor: output parsing helpers, error
// conversions, agent construction, and the blocking and streaming execution
// paths driven by mock LLM providers.
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, StaticChatResponse};
use async_trait::async_trait;
use autoagents_llm::chat::StreamChunk;
use autoagents_llm::{FunctionCall, ToolCall};
/// Minimal tool that ignores its arguments and always returns a fixed value.
#[derive(Debug)]
struct LocalTool {
name: String,
output: serde_json::Value,
}
impl LocalTool {
/// Builds a tool called `name` that will return `output` on every call.
fn new(name: &str, output: serde_json::Value) -> Self {
Self {
name: name.to_string(),
output,
}
}
}
impl crate::tool::ToolT for LocalTool {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
"local tool"
}
fn args_schema(&self) -> serde_json::Value {
serde_json::json!({"type": "object"})
}
}
#[async_trait]
impl crate::tool::ToolRuntime for LocalTool {
// Always succeeds with a copy of the configured output.
async fn execute(
&self,
_args: serde_json::Value,
) -> Result<serde_json::Value, crate::tool::ToolCallError> {
Ok(self.output.clone())
}
}
/// Typed payload used to exercise JSON extraction from a response string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct ReActTestOutput {
value: i32,
message: String,
}
/// Agent whose run-start hook always aborts; it exposes no tools.
#[derive(Debug, Clone)]
struct AbortAgent;
#[async_trait]
impl AgentDeriveT for AbortAgent {
type Output = String;
fn description(&self) -> &str {
"abort"
}
fn output_schema(&self) -> Option<Value> {
None
}
fn name(&self) -> &str {
"abort_agent"
}
fn tools(&self) -> Vec<Box<dyn ToolT>> {
vec![]
}
}
#[async_trait]
impl AgentHooks for AbortAgent {
// Only this hook is overridden here.
async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
HookOutcome::Abort
}
}
// A typed value serialized into `response` round-trips through
// `extract_agent_output`.
#[test]
fn test_extract_agent_output_success() {
let agent_output = ReActTestOutput {
value: 42,
message: "Hello, world!".to_string(),
};
let react_output = ReActAgentOutput {
response: serde_json::to_string(&agent_output).unwrap(),
done: true,
tool_calls: vec![],
};
let react_value = serde_json::to_value(react_output).unwrap();
let extracted: ReActTestOutput =
ReActAgentOutput::extract_agent_output(react_value).unwrap();
assert_eq!(extracted, agent_output);
}
// A JSON value that is not a serialized `ReActAgentOutput` is rejected.
#[test]
fn test_extract_agent_output_invalid_react() {
let result = ReActAgentOutput::extract_agent_output::<ReActTestOutput>(
serde_json::json!({"not": "react"}),
);
assert!(result.is_err());
}
// `try_parse` deserializes a JSON response into the requested type.
#[test]
fn test_react_agent_output_try_parse_success() {
let output = ReActAgentOutput {
response: r#"{"value":1,"message":"hi"}"#.to_string(),
tool_calls: vec![],
done: true,
};
let parsed: ReActTestOutput = output.try_parse().unwrap();
assert_eq!(parsed.value, 1);
}
// `try_parse` reports an error for a non-JSON response.
#[test]
fn test_react_agent_output_try_parse_failure() {
let output = ReActAgentOutput {
response: "not json".to_string(),
tool_calls: vec![],
done: true,
};
assert!(output.try_parse::<ReActTestOutput>().is_err());
}
// When parsing fails, `parse_or_map` applies the fallback to the raw text.
#[test]
fn test_react_agent_output_parse_or_map() {
let output = ReActAgentOutput {
response: "plain text".to_string(),
tool_calls: vec![],
done: true,
};
let result: String = output.parse_or_map(|s| s.to_uppercase());
assert_eq!(result, "PLAIN TEXT");
}
// Turn-engine LLM errors keep the dedicated `LLMError` variant.
#[test]
fn test_error_from_turn_engine_llm() {
let err: ReActExecutorError =
TurnEngineError::LLMError(LLMError::Generic("llm err".to_string())).into();
assert!(matches!(err, ReActExecutorError::LLMError(_)));
}
// A turn-engine abort is converted to the `Other` variant.
#[test]
fn test_error_from_turn_engine_aborted() {
let err: ReActExecutorError = TurnEngineError::Aborted.into();
assert!(matches!(err, ReActExecutorError::Other(_)));
}
// A turn-engine `Other` error stays `Other`.
#[test]
fn test_error_from_turn_engine_other() {
let err: ReActExecutorError = TurnEngineError::Other("other".to_string()).into();
assert!(matches!(err, ReActExecutorError::Other(_)));
}
// `ReActAgent::new` uses a turn limit of 10.
#[test]
fn test_react_agent_config() {
let mock = MockAgentImpl::new("cfg_test", "desc");
let agent = ReActAgent::new(mock);
assert_eq!(agent.config().max_turns, 10);
}
// A clone still reports the inner agent's metadata, and the output converts
// into both `Value` and `String`.
#[test]
fn test_react_agent_metadata_and_output_conversion() {
let mock = MockAgentImpl::new("react_meta", "desc");
let agent = ReActAgent::new(mock);
let cloned = agent.clone();
assert_eq!(cloned.name(), "react_meta");
assert_eq!(cloned.description(), "desc");
let output = ReActAgentOutput {
response: "resp".to_string(),
tool_calls: vec![],
done: true,
};
let value: Value = output.clone().into();
assert_eq!(value["response"], "resp");
let string: String = output.into();
assert_eq!(string, "resp");
}
// Blocking execution against the mock provider yields its response with
// `done` set.
#[tokio::test]
async fn test_react_agent_execute() {
use crate::agent::{AgentConfig, Context};
use crate::tests::MockLLMProvider;
use autoagents_protocol::ActorID;
let mock = MockAgentImpl::new("exec_test", "desc");
let agent = ReActAgent::new(mock);
let llm = std::sync::Arc::new(MockLLMProvider {});
let config = AgentConfig {
id: ActorID::new_v4(),
name: "exec_test".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let context = Arc::new(Context::new(llm, None).with_config(config));
let task = crate::agent::task::Task::new("test");
let result = agent.execute(&task, context).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.done);
assert_eq!(output.response, "Mock response");
}
// Blocking execution where the provider requests a tool call: the result
// carries at least one successful tool call.
#[tokio::test]
async fn test_react_agent_execute_with_tool_calls() {
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
let tool_call = ToolCall {
id: "call_1".to_string(),
call_type: "function".to_string(),
function: autoagents_llm::FunctionCall {
name: "tool_a".to_string(),
arguments: r#"{"value":1}"#.to_string(),
},
};
let llm = Arc::new(ConfigurableLLMProvider {
chat_response: StaticChatResponse {
text: Some("Use tool".to_string()),
tool_calls: Some(vec![tool_call.clone()]),
usage: None,
thinking: None,
},
..ConfigurableLLMProvider::default()
});
let mock = MockAgentImpl::new("exec_tool", "desc");
let agent = ReActAgent::new(mock);
let config = AgentConfig {
id: ActorID::new_v4(),
name: "exec_tool".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
let context = Arc::new(
Context::new(llm, None)
.with_config(config)
.with_tools(vec![Box::new(tool)]),
);
let task = crate::agent::task::Task::new("test");
let result = agent.execute(&task, context).await.unwrap();
assert!(result.done);
assert!(!result.tool_calls.is_empty());
assert!(result.tool_calls[0].success);
}
// Streaming execution: the final `done` item holds the concatenation of the
// streamed text chunks.
#[tokio::test]
async fn test_react_agent_execute_stream_text() {
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
use futures::StreamExt;
let llm = Arc::new(ConfigurableLLMProvider {
stream_chunks: vec![
StreamChunk::Text("Hello ".to_string()),
StreamChunk::Text("world".to_string()),
StreamChunk::Done {
stop_reason: "end_turn".to_string(),
},
],
..ConfigurableLLMProvider::default()
});
let mock = MockAgentImpl::new("stream_test", "desc");
let agent = ReActAgent::new(mock);
let config = AgentConfig {
id: ActorID::new_v4(),
name: "stream_test".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let context = Arc::new(Context::new(llm, None).with_config(config));
let task = crate::agent::task::Task::new("test");
let mut stream = agent.execute_stream(&task, context).await.unwrap();
let mut final_output = None;
while let Some(item) = stream.next().await {
let output = item.unwrap();
if output.done {
final_output = Some(output);
break;
}
}
let output = final_output.expect("final output");
assert_eq!(output.response, "Hello world");
assert!(output.done);
}
// Streaming execution: reasoning chunks produce no stream items, so only the
// text delta and the final item are observed.
#[tokio::test]
async fn test_react_agent_execute_stream_ignores_reasoning_output() {
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
use futures::StreamExt;
let llm = Arc::new(ConfigurableLLMProvider {
stream_chunks: vec![
StreamChunk::ReasoningContent("plan".to_string()),
StreamChunk::Text("done".to_string()),
StreamChunk::Done {
stop_reason: "end_turn".to_string(),
},
],
..ConfigurableLLMProvider::default()
});
let mock = MockAgentImpl::new("stream_reasoning_test", "desc");
let agent = ReActAgent::new(mock);
let config = AgentConfig {
id: ActorID::new_v4(),
name: "stream_reasoning_test".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let context = Arc::new(Context::new(llm, None).with_config(config));
let task = crate::agent::task::Task::new("test");
let mut stream = agent.execute_stream(&task, context).await.unwrap();
let mut outputs = Vec::new();
while let Some(item) = stream.next().await {
outputs.push(item.unwrap());
}
assert_eq!(outputs.len(), 2);
assert_eq!(outputs[0].response, "done");
assert!(!outputs[0].done);
assert_eq!(outputs[1].response, "done");
assert!(outputs[1].done);
}
// Streaming execution with a tool call: tool results appear on the stream
// and are also present on the final `done` item.
#[tokio::test]
async fn test_react_agent_execute_stream_tool_results() {
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
use futures::StreamExt;
let tool_call = ToolCall {
id: "call_1".to_string(),
call_type: "function".to_string(),
function: FunctionCall {
name: "tool_a".to_string(),
arguments: r#"{"value":1}"#.to_string(),
},
};
let llm = Arc::new(ConfigurableLLMProvider {
stream_chunks: vec![StreamChunk::ToolUseComplete {
index: 0,
tool_call: tool_call.clone(),
}],
..ConfigurableLLMProvider::default()
});
let mock = MockAgentImpl::new("stream_tool", "desc");
let agent = ReActAgent::new(mock);
let config = AgentConfig {
id: ActorID::new_v4(),
name: "stream_tool".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
let context = Arc::new(
Context::new(llm, None)
.with_config(config)
.with_tools(vec![Box::new(tool)]),
);
let task = crate::agent::task::Task::new("test");
let mut stream = agent.execute_stream(&task, context).await.unwrap();
let mut saw_tool_results = false;
let mut final_output = None;
while let Some(item) = stream.next().await {
let output = item.unwrap();
if !output.tool_calls.is_empty() {
saw_tool_results = true;
assert!(output.tool_calls[0].success);
}
if output.done {
final_output = Some(output);
break;
}
}
assert!(saw_tool_results);
let output = final_output.expect("final output");
assert!(output.done);
assert!(!output.tool_calls.is_empty());
}
// Running an agent whose run-start hook aborts surfaces
// `RunnableAgentError::Abort`.
#[tokio::test]
async fn test_react_agent_run_aborts_on_hook() {
use crate::agent::AgentBuilder;
use crate::agent::direct::DirectAgent;
use crate::agent::error::RunnableAgentError;
use crate::tests::MockLLMProvider;
let agent = ReActAgent::new(AbortAgent);
let llm = Arc::new(MockLLMProvider {});
let handle = AgentBuilder::<_, DirectAgent>::new(agent)
.llm(llm)
.build()
.await
.expect("build should succeed");
let task = crate::agent::task::Task::new("abort");
let err = handle.agent.run(task).await.expect_err("expected abort");
assert!(matches!(err, RunnableAgentError::Abort));
}
// Same abort scenario through the streaming entry point: `run_stream` itself
// returns the abort error instead of a stream.
#[tokio::test]
async fn test_react_agent_run_stream_aborts_on_hook() {
use crate::agent::AgentBuilder;
use crate::agent::direct::DirectAgent;
use crate::agent::error::RunnableAgentError;
use crate::tests::MockLLMProvider;
let agent = ReActAgent::new(AbortAgent);
let llm = Arc::new(MockLLMProvider {});
let handle = AgentBuilder::<_, DirectAgent>::new(agent)
.llm(llm)
.build()
.await
.expect("build should succeed");
let task = crate::agent::task::Task::new("abort");
let err = match handle.agent.run_stream(task).await {
Ok(_) => panic!("expected abort"),
Err(err) => err,
};
assert!(matches!(err, RunnableAgentError::Abort));
}
}