autoagents_core/agent/
direct.rs1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use crate::protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{channel, Receiver, Sender};
12
13use crate::utils::{receiver_into_stream, BoxEventStream};
14
15pub struct DirectAgent {}
16
17impl AgentType for DirectAgent {
18 fn type_name() -> &'static str {
19 "direct_agent"
20 }
21}
22
23pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks> {
25 pub agent: BaseAgent<T, DirectAgent>,
26 pub rx: BoxEventStream<Event>,
27}
28
29impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
30 pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
31 Self { agent, rx }
32 }
33}
34
35impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
36 #[allow(clippy::result_large_err)]
38 pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
39 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
40 "LLM provider is required".to_string(),
41 ))?;
42 let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
43 let agent: BaseAgent<T, DirectAgent> =
44 BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
45 let stream = receiver_into_stream(rx);
46 Ok(DirectAgentHandle::new(agent, stream))
47 }
48}
49
50impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
51 pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
52 where
53 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
54 {
55 let context = self.create_context();
56
57 let hook_outcome = self.inner.on_run_start(&task, &context).await;
59 match hook_outcome {
60 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
61 HookOutcome::Continue => {}
62 }
63
64 match self.inner().execute(&task, context.clone()).await {
66 Ok(output) => {
67 let output: <T as AgentExecutor>::Output = output;
68
69 let agent_out: <T as AgentDeriveT>::Output = output.into();
71
72 self.inner
74 .on_run_complete(&task, &agent_out, &context)
75 .await;
76 Ok(agent_out)
77 }
78 Err(e) => {
79 Err(RunnableAgentError::ExecutorError(e.to_string()))
81 }
82 }
83 }
84
85 pub async fn run_stream(
86 &self,
87 task: Task,
88 ) -> Result<
89 std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
90 RunnableAgentError,
91 >
92 where
93 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
94 {
95 let context = self.create_context();
96
97 let hook_outcome = self.inner.on_run_start(&task, &context).await;
99 match hook_outcome {
100 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
101 HookOutcome::Continue => {}
102 }
103
104 match self.inner().execute_stream(&task, context.clone()).await {
106 Ok(stream) => {
107 use futures::StreamExt;
108 let transformed_stream = stream.map(move |result| match result {
110 Ok(output) => Ok(output.into()),
111 Err(e) => {
112 let error_msg = e.to_string();
113 Err(RunnableAgentError::ExecutorError(error_msg).into())
114 }
115 });
116
117 Ok(Box::pin(transformed_stream))
118 }
119 Err(e) => {
120 Err(RunnableAgentError::ExecutorError(e.to_string()))
122 }
123 }
124 }
125}