use crate::agent::executor::event_helper::EventHelper;
use crate::agent::executor::turn_engine::{
TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, TurnEngineOutput, record_task_state,
};
use crate::agent::hooks::HookOutcome;
use crate::agent::task::Task;
use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
use crate::channel::channel;
use crate::tool::{ToolCallResult, ToolT};
use crate::utils::{receiver_into_stream, spawn_future};
use async_trait::async_trait;
use autoagents_llm::ToolCall;
use autoagents_llm::error::LLMError;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;
/// Output produced by [`BasicAgent`] from both `execute` and `execute_stream`.
///
/// `execute` always yields `done == true`. While streaming, text deltas are emitted with
/// `done == false` and the final turn result is emitted with `done == true`.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct BasicAgentOutput {
    /// Model text: a streamed text delta when `done` is false, the turn's response otherwise.
    pub response: String,
    /// Whether this item carries the final response of the run.
    pub done: bool,
}
/// Converts the output into a JSON object with `response` and `done` keys.
///
/// Falls back to `Value::Null` (the `Default` for `Value`) if serialization fails.
impl From<BasicAgentOutput> for Value {
    fn from(output: BasicAgentOutput) -> Self {
        serde_json::to_value(&output).unwrap_or_default()
    }
}
impl From<BasicAgentOutput> for String {
fn from(output: BasicAgentOutput) -> Self {
output.response
}
}
impl BasicAgentOutput {
    /// Deserializes `response` as JSON into `T`.
    ///
    /// # Errors
    /// Returns the `serde_json` error when `response` is not valid JSON for `T`.
    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
        let raw: &str = &self.response;
        serde_json::from_str(raw)
    }

    /// Deserializes `response` into `T`; if that fails, passes the raw text to `fallback`
    /// and returns its result instead.
    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
    where
        T: for<'de> serde::Deserialize<'de>,
        F: FnOnce(&str) -> T,
    {
        match self.try_parse::<T>() {
            Ok(parsed) => parsed,
            Err(_) => fallback(&self.response),
        }
    }
}
/// Errors returned by the [`BasicAgent`] executor.
#[derive(Debug, thiserror::Error)]
pub enum BasicExecutorError {
    /// The underlying LLM call failed; `#[from]` also exposes it as the error source.
    #[error("LLM error: {0}")]
    LLMError(#[from] LLMError),
    /// Any other failure, including a turn-engine abort (see the `TurnEngineError` conversion).
    #[error("Other error: {0}")]
    Other(String),
}
/// Maps turn-engine failures onto executor errors.
///
/// LLM errors are preserved as-is. An abort becomes `Other("Run aborted by hook")`, and any
/// other message is carried through unchanged.
impl From<TurnEngineError> for BasicExecutorError {
    fn from(error: TurnEngineError) -> Self {
        match error {
            TurnEngineError::LLMError(llm_err) => Self::LLMError(llm_err),
            TurnEngineError::Aborted => Self::Other(String::from("Run aborted by hook")),
            TurnEngineError::Other(message) => Self::Other(message),
        }
    }
}
/// Executor that answers each task with a single LLM turn (`max_turns: 1`), wrapping a
/// user-defined agent `T`.
///
/// The wrapped agent is held in an [`Arc`], so clones of a `BasicAgent` share it.
#[derive(Debug)]
pub struct BasicAgent<T>
where
    T: AgentDeriveT,
{
    inner: Arc<T>,
}
/// Cloning bumps the reference count of the shared inner agent, so `T` need not be `Clone`.
impl<T: AgentDeriveT> Clone for BasicAgent<T> {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<T: AgentDeriveT> BasicAgent<T> {
    /// Takes ownership of `inner` and wraps it in a new `BasicAgent`.
    pub fn new(inner: T) -> Self {
        let shared = Arc::new(inner);
        Self { inner: shared }
    }
}
/// Exposes the wrapped agent, so its inherent methods are callable on the `BasicAgent`.
impl<T: AgentDeriveT> Deref for BasicAgent<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.inner.as_ref()
    }
}
/// Forwards all agent metadata (name, description, output schema, tools) to the wrapped
/// agent. `Output` is the wrapped agent's own `Output` type.
#[async_trait]
impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
    type Output = <T as AgentDeriveT>::Output;

    fn name(&self) -> &str {
        <T as AgentDeriveT>::name(&self.inner)
    }

    fn description(&self) -> &str {
        <T as AgentDeriveT>::description(&self.inner)
    }

    fn output_schema(&self) -> Option<Value> {
        <T as AgentDeriveT>::output_schema(&self.inner)
    }

    fn tools(&self) -> Vec<Box<dyn ToolT>> {
        <T as AgentDeriveT>::tools(&self.inner)
    }
}
/// Forwards each lifecycle hook listed here to the wrapped agent's `AgentHooks` implementation.
/// The fully-qualified calls make explicit that `T`'s hooks run, not this impl's own.
#[async_trait]
impl<T> AgentHooks for BasicAgent<T>
where
    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
{
    async fn on_agent_create(&self) {
        <T as AgentHooks>::on_agent_create(&self.inner).await
    }

    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
        <T as AgentHooks>::on_run_start(&self.inner, task, ctx).await
    }

    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
        <T as AgentHooks>::on_run_complete(&self.inner, task, result, ctx).await
    }

    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
        <T as AgentHooks>::on_turn_start(&self.inner, turn_index, ctx).await
    }

    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
        <T as AgentHooks>::on_turn_complete(&self.inner, turn_index, ctx).await
    }

    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
        <T as AgentHooks>::on_tool_call(&self.inner, tool_call, ctx).await
    }

    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
        <T as AgentHooks>::on_tool_start(&self.inner, tool_call, ctx).await
    }

    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
        <T as AgentHooks>::on_tool_result(&self.inner, tool_call, result, ctx).await
    }

    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
        <T as AgentHooks>::on_tool_error(&self.inner, tool_call, err, ctx).await
    }

    async fn on_agent_shutdown(&self) {
        <T as AgentHooks>::on_agent_shutdown(&self.inner).await
    }
}
/// Single-turn executor: every task is handled by one run of the `TurnEngine`, bounded by
/// the `max_turns` reported by `config` (always 1 for a basic agent).
#[async_trait]
impl<T: AgentDeriveT + AgentHooks> AgentExecutor for BasicAgent<T> {
    type Output = BasicAgentOutput;
    type Error = BasicExecutorError;

    /// Basic agents are limited to a single turn per task.
    fn config(&self) -> ExecutorConfig {
        ExecutorConfig { max_turns: 1 }
    }

    /// Runs the task and returns the turn's final response with `done: true`.
    ///
    /// First records the task on `context` and emits a task-started event. It then runs turn
    /// index 0 of a `TurnEngine` configured with this executor's `max_turns`.
    ///
    /// # Errors
    /// Any `TurnEngineError` from the turn is converted into a `BasicExecutorError`.
    async fn execute(
        &self,
        task: &Task,
        context: Arc<Context>,
    ) -> Result<Self::Output, Self::Error> {
        record_task_state(&context, task);
        let tx_event = context.tx().ok();
        EventHelper::send_task_started(
            &tx_event,
            task.submission_id,
            context.config().id,
            context.config().name.clone(),
            task.prompt.clone(),
        )
        .await;

        let max_turns = self.config().max_turns;
        let engine = TurnEngine::new(TurnEngineConfig::basic(max_turns));
        let mut turn_state = engine.turn_state(&context);
        let turn_result = engine
            .run_turn(self, task, &context, &mut turn_state, 0, max_turns)
            .await?;

        let output = extract_turn_output(turn_result);
        Ok(BasicAgentOutput {
            response: output.response,
            done: true,
        })
    }

    /// Streams the task's output while the turn runs on a spawned future.
    ///
    /// - Each `TurnDelta::Text` is forwarded as a `BasicAgentOutput` with `done: false`.
    /// - The `TurnDelta::Done` result is forwarded with `done: true`.
    /// - Reasoning and tool-result deltas are not forwarded.
    ///
    /// Once the delta stream ends without error, stream-complete and task-completed events
    /// are emitted. If an error occurs, it is sent on the returned stream and no completion
    /// events follow.
    async fn execute_stream(
        &self,
        task: &Task,
        context: Arc<Context>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
    {
        record_task_state(&context, task);
        let tx_event = context.tx().ok();
        EventHelper::send_task_started(
            &tx_event,
            task.submission_id,
            context.config().id,
            context.config().name.clone(),
            task.prompt.clone(),
        )
        .await;

        // Use the configured turn budget (as `execute` does) instead of a hard-coded literal,
        // so the streaming and non-streaming paths cannot drift apart.
        let max_turns = self.config().max_turns;
        let engine = TurnEngine::new(TurnEngineConfig::basic(max_turns));
        let mut turn_state = engine.turn_state(&context);
        let context_clone = Arc::clone(&context);
        let task = task.clone();
        let executor = self.clone();
        let (tx, rx) = channel::<Result<BasicAgentOutput, BasicExecutorError>>(100);
        spawn_future(async move {
            use futures::StreamExt;

            let turn_stream = engine
                .run_turn_stream(
                    executor,
                    &task,
                    Arc::clone(&context_clone),
                    &mut turn_state,
                    0,
                    max_turns,
                )
                .await;
            let mut final_response = String::default();
            match turn_stream {
                Ok(mut stream) => {
                    while let Some(delta_result) = stream.next().await {
                        match delta_result {
                            Ok(TurnDelta::Text(content)) => {
                                // Send failures mean the consumer dropped the stream; the
                                // turn keeps running so completion events are still emitted.
                                let _ = tx
                                    .send(Ok(BasicAgentOutput {
                                        response: content,
                                        done: false,
                                    }))
                                    .await;
                            }
                            Ok(TurnDelta::ReasoningContent(_)) => {}
                            Ok(TurnDelta::ToolResults(_)) => {}
                            Ok(TurnDelta::Done(result)) => {
                                let output = extract_turn_output(result);
                                final_response = output.response.clone();
                                let _ = tx
                                    .send(Ok(BasicAgentOutput {
                                        response: output.response,
                                        done: true,
                                    }))
                                    .await;
                                break;
                            }
                            Err(err) => {
                                let _ = tx.send(Err(err.into())).await;
                                return;
                            }
                        }
                    }
                    // NOTE(review): if the delta stream ends without `TurnDelta::Done`, no
                    // `done: true` item is sent and the task-completed event below reports an
                    // empty response. Confirm whether the turn engine guarantees a `Done`.
                }
                Err(err) => {
                    let _ = tx.send(Err(err.into())).await;
                    return;
                }
            }
            let tx_event = context_clone.tx().ok();
            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
            let output = BasicAgentOutput {
                response: final_response,
                done: true,
            };
            let result =
                serde_json::to_string_pretty(&output).unwrap_or_else(|_| output.response.clone());
            EventHelper::send_task_completed(
                &tx_event,
                task.submission_id,
                context_clone.config().id,
                context_clone.config().name.clone(),
                result,
            )
            .await;
        });
        Ok(receiver_into_stream(rx))
    }
}
/// Unwraps the `TurnEngineOutput` carried by a turn result.
///
/// `Complete(output)` and `Continue(Some(output))` both yield `output`. `Continue(None)`
/// yields an output with an empty response, empty reasoning content and no tool calls.
fn extract_turn_output(
    result: crate::agent::executor::TurnResult<TurnEngineOutput>,
) -> TurnEngineOutput {
    use crate::agent::executor::TurnResult;

    match result {
        TurnResult::Complete(output) | TurnResult::Continue(Some(output)) => output,
        TurnResult::Continue(None) => TurnEngineOutput {
            response: String::default(),
            reasoning_content: String::default(),
            tool_calls: Vec::default(),
        },
    }
}
// Tests for `BasicAgent` execution and streaming, `BasicAgentOutput` conversions and
// parsing helpers, `TurnEngineError` conversion, and `extract_turn_output`.
#[cfg(test)]
mod tests {
use super::*;
use crate::agent::AgentDeriveT;
use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider};
use async_trait::async_trait;
use autoagents_llm::chat::{StreamChoice, StreamDelta, StreamResponse};
use std::sync::Arc;
// Minimal test agent whose `on_run_start` hook always returns `HookOutcome::Abort`.
// The abort tests use it to check that `run` and `run_stream` report `RunnableAgentError::Abort`.
#[derive(Debug, Clone)]
struct AbortAgent;
#[async_trait]
impl AgentDeriveT for AbortAgent {
type Output = String;
fn description(&self) -> &str {
"abort"
}
fn output_schema(&self) -> Option<Value> {
None
}
fn name(&self) -> &str {
"abort_agent"
}
fn tools(&self) -> Vec<Box<dyn ToolT>> {
vec![]
}
}
#[async_trait]
impl AgentHooks for AbortAgent {
// Only this hook is overridden; every other hook uses the trait default.
async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
HookOutcome::Abort
}
}
// Non-streaming `execute` against `MockLLMProvider` should return the mock's
// "Mock response" text with `done` set.
#[tokio::test]
async fn test_basic_agent_execute() {
use crate::agent::task::Task;
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
let basic_agent = BasicAgent::new(mock_agent);
let llm = Arc::new(MockLLMProvider {});
let config = AgentConfig {
id: ActorID::new_v4(),
name: "test_agent".to_string(),
description: "Test agent description".to_string(),
output_schema: None,
};
let context = Context::new(llm, None).with_config(config);
let context_arc = Arc::new(context);
let task = Task::new("Test task");
let result = basic_agent.execute(&task, context_arc).await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.response, "Mock response");
assert!(output.done);
}
// Checks the single-turn config, that clones delegate name/description to the inner agent,
// and the `Value` / `String` conversions of `BasicAgentOutput`.
#[test]
fn test_basic_agent_metadata_and_output_conversion() {
let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
let basic_agent = BasicAgent::new(mock_agent);
let config = basic_agent.config();
assert_eq!(config.max_turns, 1);
let cloned = basic_agent.clone();
assert_eq!(cloned.name(), "test_agent");
assert_eq!(cloned.description(), "Test agent description");
let output = BasicAgentOutput {
response: "Test response".to_string(),
done: true,
};
let value: Value = output.clone().into();
assert_eq!(value["response"], "Test response");
let string: String = output.into();
assert_eq!(string, "Test response");
}
// `try_parse` deserializes a JSON object in `response` into a typed struct.
#[test]
fn test_basic_agent_output_try_parse_success() {
let output = BasicAgentOutput {
response: r#"{"name":"test","value":42}"#.to_string(),
done: true,
};
#[derive(serde::Deserialize, PartialEq, Debug)]
struct Data {
name: String,
value: i32,
}
let parsed: Data = output.try_parse().unwrap();
assert_eq!(
parsed,
Data {
name: "test".to_string(),
value: 42
}
);
}
// `try_parse` returns an error when `response` is not JSON.
#[test]
fn test_basic_agent_output_try_parse_failure() {
let output = BasicAgentOutput {
response: "not json".to_string(),
done: true,
};
let result = output.try_parse::<serde_json::Value>();
assert!(result.is_err());
}
// `parse_or_map` falls back to the closure when `response` is not valid JSON.
#[test]
fn test_basic_agent_output_parse_or_map_fallback() {
let output = BasicAgentOutput {
response: "plain text".to_string(),
done: true,
};
let result: String = output.parse_or_map(|s| s.to_uppercase());
assert_eq!(result, "PLAIN TEXT");
}
// `parse_or_map` uses the parsed value (a JSON string) and does not call the fallback.
#[test]
fn test_basic_agent_output_parse_or_map_success() {
let output = BasicAgentOutput {
response: r#""hello""#.to_string(),
done: true,
};
let result: String = output.parse_or_map(|s| s.to_uppercase());
assert_eq!(result, "hello");
}
// An LLM error from the turn engine maps to the `LLMError` variant, keeping its message.
#[test]
fn test_error_from_turn_engine_llm() {
let err: BasicExecutorError =
TurnEngineError::LLMError(LLMError::Generic("bad".to_string())).into();
assert!(matches!(err, BasicExecutorError::LLMError(_)));
assert!(err.to_string().contains("bad"));
}
// An engine abort maps to `Other` with a message mentioning "aborted".
#[test]
fn test_error_from_turn_engine_aborted() {
let err: BasicExecutorError = TurnEngineError::Aborted.into();
assert!(matches!(err, BasicExecutorError::Other(_)));
assert!(err.to_string().contains("aborted"));
}
// An engine `Other` message is carried through into `BasicExecutorError::Other`.
#[test]
fn test_error_from_turn_engine_other() {
let err: BasicExecutorError = TurnEngineError::Other("misc".to_string()).into();
assert!(matches!(err, BasicExecutorError::Other(_)));
assert!(err.to_string().contains("misc"));
}
// `Complete(output)` is unwrapped to its output.
#[test]
fn test_extract_turn_output_complete() {
let result = crate::agent::executor::TurnResult::Complete(
crate::agent::executor::turn_engine::TurnEngineOutput {
response: "done".to_string(),
reasoning_content: String::default(),
tool_calls: Vec::new(),
},
);
let output = extract_turn_output(result);
assert_eq!(output.response, "done");
}
// `Continue(Some(output))` is unwrapped to its output.
#[test]
fn test_extract_turn_output_continue_some() {
let result = crate::agent::executor::TurnResult::Continue(Some(
crate::agent::executor::turn_engine::TurnEngineOutput {
response: "partial".to_string(),
reasoning_content: String::default(),
tool_calls: Vec::new(),
},
));
let output = extract_turn_output(result);
assert_eq!(output.response, "partial");
}
// `Continue(None)` yields an empty output.
#[test]
fn test_extract_turn_output_continue_none() {
let result = crate::agent::executor::TurnResult::Continue(None);
let output = extract_turn_output(result);
assert!(output.response.is_empty());
assert!(output.tool_calls.is_empty());
}
// Streaming two content chunks ("Hello " and "world") should end with a `done` item whose
// response is "Hello world".
#[tokio::test]
async fn test_basic_agent_execute_stream_returns_output() {
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
use futures::StreamExt;
let llm = Arc::new(ConfigurableLLMProvider {
structured_stream: vec![
StreamResponse {
choices: vec![StreamChoice {
delta: StreamDelta {
content: Some("Hello ".to_string()),
reasoning_content: None,
tool_calls: None,
},
}],
usage: None,
},
StreamResponse {
choices: vec![StreamChoice {
delta: StreamDelta {
content: Some("world".to_string()),
reasoning_content: None,
tool_calls: None,
},
}],
usage: None,
},
],
..ConfigurableLLMProvider::default()
});
let mock_agent = MockAgentImpl::new("stream_agent", "desc");
let basic_agent = BasicAgent::new(mock_agent);
let config = AgentConfig {
id: ActorID::new_v4(),
name: "stream_agent".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let context = Arc::new(Context::new(llm, None).with_config(config));
let task = Task::new("Test task");
let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();
let mut final_output = None;
// Skip intermediate deltas; keep only the first item flagged `done`.
while let Some(item) = stream.next().await {
let output = item.unwrap();
if output.done {
final_output = Some(output);
break;
}
}
let output = final_output.expect("final output");
assert_eq!(output.response, "Hello world");
assert!(output.done);
}
// A reasoning-only chunk must not produce a stream item. The expected output is one text
// delta ("done", not done) followed by the final item ("done", done).
#[tokio::test]
async fn test_basic_agent_execute_stream_ignores_reasoning_output() {
use crate::agent::{AgentConfig, Context};
use autoagents_protocol::ActorID;
use futures::StreamExt;
let llm = Arc::new(ConfigurableLLMProvider {
structured_stream: vec![
StreamResponse {
choices: vec![StreamChoice {
delta: StreamDelta {
content: None,
reasoning_content: Some("plan".to_string()),
tool_calls: None,
},
}],
usage: None,
},
StreamResponse {
choices: vec![StreamChoice {
delta: StreamDelta {
content: Some("done".to_string()),
reasoning_content: None,
tool_calls: None,
},
}],
usage: None,
},
],
..ConfigurableLLMProvider::default()
});
let mock_agent = MockAgentImpl::new("stream_agent_reasoning", "desc");
let basic_agent = BasicAgent::new(mock_agent);
let config = AgentConfig {
id: ActorID::new_v4(),
name: "stream_agent_reasoning".to_string(),
description: "desc".to_string(),
output_schema: None,
};
let context = Arc::new(Context::new(llm, None).with_config(config));
let task = Task::new("Test task");
let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();
let mut outputs = Vec::new();
// Drain the whole stream; it ends when the spawned producer drops its sender.
while let Some(item) = stream.next().await {
outputs.push(item.unwrap());
}
assert_eq!(outputs.len(), 2);
assert_eq!(outputs[0].response, "done");
assert!(!outputs[0].done);
assert_eq!(outputs[1].response, "done");
assert!(outputs[1].done);
}
// Running an agent whose `on_run_start` aborts through `DirectAgent::run` yields `RunnableAgentError::Abort`.
#[tokio::test]
async fn test_basic_agent_run_aborts_on_hook() {
use crate::agent::AgentBuilder;
use crate::agent::direct::DirectAgent;
use crate::agent::error::RunnableAgentError;
let agent = BasicAgent::new(AbortAgent);
let llm = Arc::new(MockLLMProvider {});
let handle = AgentBuilder::<_, DirectAgent>::new(agent)
.llm(llm)
.build()
.await
.expect("build should succeed");
let task = Task::new("abort");
let err = handle.agent.run(task).await.expect_err("expected abort");
assert!(matches!(err, RunnableAgentError::Abort));
}
// Same as above, but through `run_stream`: the call itself returns `RunnableAgentError::Abort`.
#[tokio::test]
async fn test_basic_agent_run_stream_aborts_on_hook() {
use crate::agent::AgentBuilder;
use crate::agent::direct::DirectAgent;
use crate::agent::error::RunnableAgentError;
let agent = BasicAgent::new(AbortAgent);
let llm = Arc::new(MockLLMProvider {});
let handle = AgentBuilder::<_, DirectAgent>::new(agent)
.llm(llm)
.build()
.await
.expect("build should succeed");
let task = Task::new("abort");
// The stream type is not `Debug`, so `expect_err` is unavailable; match explicitly.
let err = match handle.agent.run_stream(task).await {
Ok(_) => panic!("expected abort"),
Err(err) => err,
};
assert!(matches!(err, RunnableAgentError::Abort));
}
}