// autoagents_core/agent/direct.rs

use futures::Stream;

use crate::agent::base::AgentType;
use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
use crate::agent::error::{AgentBuildError, RunnableAgentError};
use crate::agent::task::Task;
use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
use crate::channel::{channel, Receiver, Sender};
use crate::error::Error;
use crate::protocol::Event;
use crate::utils::{receiver_into_stream, BoxEventStream};

/// Agent-type marker for "direct" agents.
///
/// It holds no data. It is only used as the agent-type parameter of
/// [`BaseAgent`] and [`AgentBuilder`] (for example `BaseAgent<T, DirectAgent>`),
/// which selects the `build`, `run` and `run_stream` implementations in this module.
pub struct DirectAgent {}

22impl AgentType for DirectAgent {
23 fn type_name() -> &'static str {
24 "direct_agent"
25 }
26}
27
28pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
32 pub agent: BaseAgent<T, DirectAgent>,
33 pub rx: BoxEventStream<Event>,
34}
35
36impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
37 pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
38 Self { agent, rx }
39 }
40}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
43 #[allow(clippy::result_large_err)]
45 pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
46 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
47 "LLM provider is required".to_string(),
48 ))?;
49 let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
50 let agent: BaseAgent<T, DirectAgent> =
51 BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
52 let stream = receiver_into_stream(rx);
53 Ok(DirectAgentHandle::new(agent, stream))
54 }
55}
56
57impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
58 pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
60 where
61 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
62 {
63 let context = self.create_context();
64
65 let hook_outcome = self.inner.on_run_start(&task, &context).await;
67 match hook_outcome {
68 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
69 HookOutcome::Continue => {}
70 }
71
72 match self.inner().execute(&task, context.clone()).await {
74 Ok(output) => {
75 let output: <T as AgentExecutor>::Output = output;
76
77 let agent_out: <T as AgentDeriveT>::Output = output.into();
79
80 self.inner
82 .on_run_complete(&task, &agent_out, &context)
83 .await;
84 Ok(agent_out)
85 }
86 Err(e) => {
87 Err(RunnableAgentError::ExecutorError(e.to_string()))
89 }
90 }
91 }
92
93 pub async fn run_stream(
96 &self,
97 task: Task,
98 ) -> Result<
99 std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
100 RunnableAgentError,
101 >
102 where
103 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
104 {
105 let context = self.create_context();
106
107 let hook_outcome = self.inner.on_run_start(&task, &context).await;
109 match hook_outcome {
110 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
111 HookOutcome::Continue => {}
112 }
113
114 match self.inner().execute_stream(&task, context.clone()).await {
116 Ok(stream) => {
117 use futures::StreamExt;
118 let transformed_stream = stream.map(move |result| match result {
120 Ok(output) => Ok(output.into()),
121 Err(e) => {
122 let error_msg = e.to_string();
123 Err(RunnableAgentError::ExecutorError(error_msg).into())
124 }
125 });
126
127 Ok(Box::pin(transformed_stream))
128 }
129 Err(e) => {
130 Err(RunnableAgentError::ExecutorError(e.to_string()))
132 }
133 }
134 }
135}