1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use autoagents_protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{Receiver, Sender, channel};
12
13#[cfg(not(target_arch = "wasm32"))]
14use crate::event_fanout::EventFanout;
15use crate::utils::{BoxEventStream, receiver_into_stream};
16#[cfg(not(target_arch = "wasm32"))]
17use futures_util::stream;
18
19pub struct DirectAgent {}
25
26impl AgentType for DirectAgent {
27 fn type_name() -> &'static str {
28 "direct_agent"
29 }
30}
31
/// Handle returned by `AgentBuilder::<T, DirectAgent>::build`.
///
/// Bundles the built agent with the stream of protocol [`Event`]s it emits.
pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
    /// The agent itself; call `run` / `run_stream` on it to execute tasks.
    pub agent: BaseAgent<T, DirectAgent>,
    /// Event stream for this agent. Initially this is the receiving side of
    /// the agent's event channel; once `subscribe_events` has been called it
    /// is replaced by a subscriber of `fanout`.
    pub rx: BoxEventStream<Event>,
    /// Fan-out created lazily by `subscribe_events` and used to hand out
    /// additional event subscriptions (non-wasm targets only).
    #[cfg(not(target_arch = "wasm32"))]
    fanout: Option<EventFanout>,
}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
43 pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
44 Self {
45 agent,
46 rx,
47 #[cfg(not(target_arch = "wasm32"))]
48 fanout: None,
49 }
50 }
51
52 #[cfg(not(target_arch = "wasm32"))]
53 pub fn subscribe_events(&mut self) -> BoxEventStream<Event> {
54 if let Some(fanout) = &self.fanout {
55 return fanout.subscribe();
56 }
57
58 let stream = std::mem::replace(&mut self.rx, Box::pin(stream::empty::<Event>()));
59 let fanout = EventFanout::new(stream, DEFAULT_CHANNEL_BUFFER);
60 self.rx = fanout.subscribe();
61 let stream = fanout.subscribe();
62 self.fanout = Some(fanout);
63 stream
64 }
65}
66
67impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
68 #[allow(clippy::result_large_err)]
70 pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
71 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
72 "LLM provider is required".to_string(),
73 ))?;
74 let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
75 let agent: BaseAgent<T, DirectAgent> =
76 BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
77 let stream = receiver_into_stream(rx);
78 Ok(DirectAgentHandle::new(agent, stream))
79 }
80}
81
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
83 pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
85 where
86 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87 {
88 let context = self.create_context();
89
90 let hook_outcome = self.inner.on_run_start(&task, &context).await;
92 match hook_outcome {
93 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
94 HookOutcome::Continue => {}
95 }
96
97 match self.inner().execute(&task, context.clone()).await {
99 Ok(output) => {
100 let output: <T as AgentExecutor>::Output = output;
101
102 let agent_out: <T as AgentDeriveT>::Output = output.into();
104
105 self.inner
107 .on_run_complete(&task, &agent_out, &context)
108 .await;
109 Ok(agent_out)
110 }
111 Err(e) => {
112 Err(RunnableAgentError::ExecutorError(e.to_string()))
114 }
115 }
116 }
117
118 pub async fn run_stream(
121 &self,
122 task: Task,
123 ) -> Result<
124 std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
125 RunnableAgentError,
126 >
127 where
128 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
129 {
130 let context = self.create_context();
131
132 let hook_outcome = self.inner.on_run_start(&task, &context).await;
134 match hook_outcome {
135 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
136 HookOutcome::Continue => {}
137 }
138
139 match self.inner().execute_stream(&task, context.clone()).await {
141 Ok(stream) => {
142 use futures::StreamExt;
143 let transformed_stream = stream.map(move |result| match result {
145 Ok(output) => Ok(output.into()),
146 Err(e) => {
147 let error_msg = e.to_string();
148 Err(RunnableAgentError::ExecutorError(error_msg).into())
149 }
150 });
151
152 Ok(Box::pin(transformed_stream))
153 }
154 Err(e) => {
155 Err(RunnableAgentError::ExecutorError(e.to_string()))
157 }
158 }
159 }
160}
161
#[cfg(test)]
mod tests {
    //! Tests for the direct-agent build, run, streaming and hook-abort paths.

    use super::*;
    use crate::agent::hooks::HookOutcome;
    use crate::agent::output::AgentOutputT;
    use crate::agent::task::Task;
    use crate::agent::{Context, ExecutorConfig};
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, TestAgentOutput, TestError};
    use crate::tool::ToolT;
    use async_trait::async_trait;
    use futures::StreamExt;
    use serde_json::Value;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    #[tokio::test]
    async fn test_direct_agent_build_requires_llm() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        // `DirectAgentHandle` is not `Debug`, so `expect_err` is unavailable.
        let err = match AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .build()
            .await
        {
            Ok(_) => panic!("expected missing llm error"),
            Err(err) => err,
        };

        assert!(matches!(err, crate::error::Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_direct_agent_run_success() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");
        assert_eq!(result.result, "Processed: hello");
    }

    #[tokio::test]
    async fn test_direct_agent_run_executor_error() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent").with_failure(true);
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("fail");
        let err = handle.agent.run(task).await.expect_err("expected error");
        assert!(matches!(err, RunnableAgentError::ExecutorError(_)));
    }

    /// Agent that relies on the default `execute_stream` implementation.
    #[derive(Clone, Debug)]
    struct StreamAgent;

    #[async_trait]
    impl AgentDeriveT for StreamAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "stream agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "stream_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for StreamAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            Ok(TestAgentOutput {
                result: format!("Streamed: {}", task.prompt),
            })
        }
    }

    impl AgentHooks for StreamAgent {}

    #[tokio::test]
    async fn test_direct_agent_run_stream_default_executes_once() {
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(StreamAgent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("stream");
        let stream = handle
            .agent
            .run_stream(task)
            .await
            .expect("stream should succeed");
        let outputs: Vec<_> = stream.collect().await;
        assert_eq!(outputs.len(), 1);
        let output = outputs[0].as_ref().expect("expected Ok output");
        assert_eq!(output.result, "Streamed: stream");
    }

    /// Agent whose `on_run_start` hook always aborts; `executed` records
    /// whether the executor was (incorrectly) reached.
    #[derive(Debug)]
    struct AbortAgent {
        executed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "abort agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for AbortAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(TestAgentOutput {
                result: "should-not-run".to_string(),
            })
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    #[tokio::test]
    async fn test_direct_agent_run_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        let err = handle.agent.run(task).await.expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
        assert!(!executed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_direct_agent_run_stream_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        // The boxed stream in the `Ok` variant is not `Debug`, so `expect_err`
        // cannot be used here.
        let err = match handle.agent.run_stream(task).await {
            Ok(_) => panic!("expected abort"),
            Err(err) => err,
        };
        assert!(matches!(err, RunnableAgentError::Abort));
        assert!(!executed.load(Ordering::SeqCst));
    }
}