autoagents_core/agent/
direct.rs1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use autoagents_protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{Receiver, Sender, channel};
12
13#[cfg(not(target_arch = "wasm32"))]
14use crate::event_fanout::EventFanout;
15use crate::utils::{BoxEventStream, receiver_into_stream};
16#[cfg(not(target_arch = "wasm32"))]
17use futures_util::stream;
18
19pub struct DirectAgent {}
25
26impl AgentType for DirectAgent {
27 fn type_name() -> &'static str {
28 "direct_agent"
29 }
30}
31
32pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
36 pub agent: BaseAgent<T, DirectAgent>,
37 pub rx: BoxEventStream<Event>,
38 #[cfg(not(target_arch = "wasm32"))]
39 fanout: Option<EventFanout>,
40}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
43 pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
44 Self {
45 agent,
46 rx,
47 #[cfg(not(target_arch = "wasm32"))]
48 fanout: None,
49 }
50 }
51
52 #[cfg(not(target_arch = "wasm32"))]
53 pub fn subscribe_events(&mut self) -> BoxEventStream<Event> {
54 if let Some(fanout) = &self.fanout {
55 return fanout.subscribe();
56 }
57
58 let stream = std::mem::replace(&mut self.rx, Box::pin(stream::empty::<Event>()));
59 let fanout = EventFanout::new(stream, DEFAULT_CHANNEL_BUFFER);
60 self.rx = fanout.subscribe();
61 let stream = fanout.subscribe();
62 self.fanout = Some(fanout);
63 stream
64 }
65}
66
67impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
68 #[allow(clippy::result_large_err)]
70 pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
71 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
72 "LLM provider is required".to_string(),
73 ))?;
74 let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
75 let agent: BaseAgent<T, DirectAgent> =
76 BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
77 let stream = receiver_into_stream(rx);
78 Ok(DirectAgentHandle::new(agent, stream))
79 }
80}
81
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
83 pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
85 where
86 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87 {
88 let context = self.create_context();
89
90 let hook_outcome = self.inner.on_run_start(&task, &context).await;
92 match hook_outcome {
93 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
94 HookOutcome::Continue => {}
95 }
96
97 match self.inner().execute(&task, context.clone()).await {
99 Ok(output) => {
100 let output: <T as AgentExecutor>::Output = output;
101
102 let agent_out: <T as AgentDeriveT>::Output = output.into();
104
105 self.inner
107 .on_run_complete(&task, &agent_out, &context)
108 .await;
109 Ok(agent_out)
110 }
111 Err(e) => {
112 Err(RunnableAgentError::ExecutorError(e.to_string()))
114 }
115 }
116 }
117
118 pub async fn run_stream(
121 &self,
122 task: Task,
123 ) -> Result<
124 std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
125 RunnableAgentError,
126 >
127 where
128 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
129 {
130 let context = self.create_context();
131
132 let hook_outcome = self.inner.on_run_start(&task, &context).await;
134 match hook_outcome {
135 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
136 HookOutcome::Continue => {}
137 }
138
139 match self.inner().execute_stream(&task, context.clone()).await {
141 Ok(stream) => {
142 use futures::StreamExt;
143 let transformed_stream = stream.map(move |result| match result {
145 Ok(output) => Ok(output.into()),
146 Err(e) => {
147 let error_msg = e.to_string();
148 Err(RunnableAgentError::ExecutorError(error_msg).into())
149 }
150 });
151
152 Ok(Box::pin(transformed_stream))
153 }
154 Err(e) => {
155 Err(RunnableAgentError::ExecutorError(e.to_string()))
157 }
158 }
159 }
160}