1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use autoagents_protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{Receiver, Sender, channel};
12
13#[cfg(not(target_arch = "wasm32"))]
14use crate::event_fanout::EventFanout;
15use crate::utils::{BoxEventStream, receiver_into_stream};
16#[cfg(not(target_arch = "wasm32"))]
17use futures_util::stream;
18
/// Marker type selecting the "direct" agent flavour.
///
/// Used as the agent-type parameter of [`AgentBuilder`] and [`BaseAgent`]:
/// building an `AgentBuilder<_, DirectAgent>` yields a [`DirectAgentHandle`], and
/// the resulting `BaseAgent<_, DirectAgent>` is driven by calling its `run` /
/// `run_stream` methods directly.
pub struct DirectAgent {}

impl AgentType for DirectAgent {
    /// Returns `"direct_agent"` as the name of this agent type.
    fn type_name() -> &'static str {
        "direct_agent"
    }
}
31
/// Product of building a [`DirectAgent`]: the agent plus the stream of events it sends.
pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
    /// The built agent; call `run` / `run_stream` on it to execute tasks.
    pub agent: BaseAgent<T, DirectAgent>,
    /// Events sent by the agent over the channel created in `AgentBuilder::build`.
    /// On non-wasm targets, the first `subscribe_events` call swaps this for a
    /// subscription to the shared fanout.
    pub rx: BoxEventStream<Event>,
    /// Fanout over the original event stream; `None` until `subscribe_events` is first called.
    #[cfg(not(target_arch = "wasm32"))]
    fanout: Option<EventFanout>,
}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
43 pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
44 Self {
45 agent,
46 rx,
47 #[cfg(not(target_arch = "wasm32"))]
48 fanout: None,
49 }
50 }
51
52 #[cfg(not(target_arch = "wasm32"))]
53 pub fn subscribe_events(&mut self) -> BoxEventStream<Event> {
54 if let Some(fanout) = &self.fanout {
55 return fanout.subscribe();
56 }
57
58 let stream = std::mem::replace(&mut self.rx, Box::pin(stream::empty::<Event>()));
59 let fanout = EventFanout::new(stream, DEFAULT_CHANNEL_BUFFER);
60 self.rx = fanout.subscribe();
61 let stream = fanout.subscribe();
62 self.fanout = Some(fanout);
63 stream
64 }
65}
66
67impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
68 #[allow(clippy::result_large_err)]
70 pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
71 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
72 "LLM provider is required".to_string(),
73 ))?;
74 let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
75 let agent: BaseAgent<T, DirectAgent> =
76 BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
77 let stream = receiver_into_stream(rx);
78 Ok(DirectAgentHandle::new(agent, stream))
79 }
80}
81
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
83 pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
85 where
86 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87 <T as AgentExecutor>::Error: Into<RunnableAgentError>,
88 {
89 let context = self.create_context();
90
91 let hook_outcome = self.inner.on_run_start(&task, &context).await;
93 match hook_outcome {
94 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
95 HookOutcome::Continue => {}
96 }
97
98 match self.inner().execute(&task, context.clone()).await {
100 Ok(output) => {
101 let output: <T as AgentExecutor>::Output = output;
102
103 let agent_out: <T as AgentDeriveT>::Output = output.into();
105
106 self.inner
108 .on_run_complete(&task, &agent_out, &context)
109 .await;
110 Ok(agent_out)
111 }
112 Err(e) => {
113 Err(e.into())
115 }
116 }
117 }
118
119 pub async fn run_stream(
122 &self,
123 task: Task,
124 ) -> Result<
125 std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
126 RunnableAgentError,
127 >
128 where
129 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
130 <T as AgentExecutor>::Error: Into<RunnableAgentError>,
131 {
132 let context = self.create_context();
133
134 let hook_outcome = self.inner.on_run_start(&task, &context).await;
136 match hook_outcome {
137 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
138 HookOutcome::Continue => {}
139 }
140
141 match self.inner().execute_stream(&task, context.clone()).await {
143 Ok(stream) => {
144 use futures::TryStreamExt;
145 let transformed_stream = stream
147 .map_ok(Into::into)
148 .map_err(Into::<RunnableAgentError>::into)
149 .map_err(Error::from);
150
151 Ok(Box::pin(transformed_stream))
152 }
153 Err(e) => {
154 Err(e.into())
156 }
157 }
158 }
159}
160
#[cfg(test)]
mod tests {
    // Tests for the direct agent: building, `run` / `run_stream`, and hook ordering.
    use super::*;
    use crate::agent::hooks::HookOutcome;
    use crate::agent::output::AgentOutputT;
    use crate::agent::prebuilt::executor::{
        BasicAgent as StableBasicAgent, BasicAgentOutput, ReActAgent as StableReActAgent,
        ReActAgentOutput,
    };
    use crate::agent::task::Task;
    use crate::agent::{Context, ExecutorConfig};
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, TestAgentOutput, TestError};
    use crate::tool::ToolT;
    use async_trait::async_trait;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    // `build` must fail with a build error when no LLM provider was supplied.
    #[tokio::test]
    async fn test_direct_agent_build_requires_llm() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        // Matching instead of `expect_err` — presumably because `DirectAgentHandle`
        // does not implement `Debug` (which `expect_err` requires); confirm.
        let err = match AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .build()
            .await
        {
            Ok(_) => panic!("expected missing llm error"),
            Err(err) => err,
        };

        assert!(matches!(err, crate::error::Error::AgentBuildError(_)));
    }

    // Happy path: the mock agent's result is the prompt prefixed with "Processed: ".
    #[tokio::test]
    async fn test_direct_agent_run_success() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");
        assert_eq!(result.result, "Processed: hello");
    }

    // A failing executor must surface as `RunnableAgentError::ExecutorError`.
    #[tokio::test]
    async fn test_direct_agent_run_executor_error() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent").with_failure(true);
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("fail");
        let err = handle.agent.run(task).await.expect_err("expected error");
        assert!(matches!(err, RunnableAgentError::ExecutorError(_)));
    }

    // Output type for `CountingHookAgent`. The prebuilt executors' outputs convert
    // into it through their `response` field (see the `From` impls below).
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct HookCountOutput {
        result: String,
    }

    impl AgentOutputT for HookCountOutput {
        fn output_schema() -> &'static str {
            r#"{"type":"object","properties":{"result":{"type":"string"}},"required":["result"]}"#
        }

        fn structured_output_format() -> Value {
            serde_json::json!({
                "name": "HookCountOutput",
                "description": "Hook count output",
                "schema": {
                    "type": "object",
                    "properties": {
                        "result": {"type": "string"}
                    },
                    "required": ["result"]
                },
                "strict": true
            })
        }
    }

    impl From<BasicAgentOutput> for HookCountOutput {
        fn from(output: BasicAgentOutput) -> Self {
            Self {
                result: output.response,
            }
        }
    }

    impl From<ReActAgentOutput> for HookCountOutput {
        fn from(output: ReActAgentOutput) -> Self {
            Self {
                result: output.response,
            }
        }
    }

    // Agent definition that counts its `on_run_start` invocations in a shared counter.
    // It is wrapped by the prebuilt `BasicAgent` / `ReActAgent` executors below.
    #[derive(Debug, Clone)]
    struct CountingHookAgent {
        on_run_start_calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl AgentDeriveT for CountingHookAgent {
        type Output = HookCountOutput;

        fn description(&self) -> &'static str {
            "counting hook agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(serde_json::json!({
                "type": "object",
                "properties": {"result": {"type": "string"}},
                "required": ["result"]
            }))
        }

        fn name(&self) -> &'static str {
            "counting_hook_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for CountingHookAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            self.on_run_start_calls.fetch_add(1, Ordering::SeqCst);
            HookOutcome::Continue
        }
    }

    // One `run` through the prebuilt `BasicAgent` executor must call `on_run_start`
    // exactly once.
    #[tokio::test]
    async fn test_direct_basic_agent_run_calls_on_run_start_once() {
        let calls = Arc::new(AtomicUsize::new(0));
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle =
            AgentBuilder::<_, DirectAgent>::new(StableBasicAgent::new(CountingHookAgent {
                on_run_start_calls: Arc::clone(&calls),
            }))
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");

        assert_eq!(result.result, "Mock response");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // Same check through the prebuilt `ReActAgent` executor.
    #[tokio::test]
    async fn test_direct_react_agent_run_calls_on_run_start_once() {
        let calls = Arc::new(AtomicUsize::new(0));
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle =
            AgentBuilder::<_, DirectAgent>::new(StableReActAgent::new(CountingHookAgent {
                on_run_start_calls: Arc::clone(&calls),
            }))
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");

        assert_eq!(result.result, "Mock response");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // Agent that implements only `execute`. `run_stream` therefore goes through the
    // default `execute_stream` provided by `AgentExecutor`.
    #[derive(Clone, Debug)]
    struct StreamAgent;

    #[async_trait]
    impl AgentDeriveT for StreamAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "stream agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "stream_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for StreamAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            Ok(TestAgentOutput {
                result: format!("Streamed: {}", task.prompt),
            })
        }
    }

    impl AgentHooks for StreamAgent {}

    // With the default `execute_stream`, the stream must yield exactly one item, and
    // that item carries the `execute` result.
    #[tokio::test]
    async fn test_direct_agent_run_stream_default_executes_once() {
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(StreamAgent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("stream");
        let stream = handle
            .agent
            .run_stream(task)
            .await
            .expect("stream should succeed");
        let outputs: Vec<_> = stream.collect().await;
        assert_eq!(outputs.len(), 1);
        let output = outputs[0].as_ref().expect("expected Ok output");
        assert_eq!(output.result, "Streamed: stream");
    }

    // Agent whose `on_run_start` always aborts; `executed` records whether `execute` ran.
    #[derive(Debug)]
    struct AbortAgent {
        executed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "abort agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for AbortAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(TestAgentOutput {
                result: "should-not-run".to_string(),
            })
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    // An aborting `on_run_start` must make `run` return `Abort` without calling `execute`.
    #[tokio::test]
    async fn test_direct_agent_run_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        let err = handle.agent.run(task).await.expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
        assert!(!executed.load(Ordering::SeqCst));
    }
}