Skip to main content

autoagents_core/agent/
direct.rs

1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use autoagents_protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{Receiver, Sender, channel};
12
13#[cfg(not(target_arch = "wasm32"))]
14use crate::event_fanout::EventFanout;
15use crate::utils::{BoxEventStream, receiver_into_stream};
16#[cfg(not(target_arch = "wasm32"))]
17use futures_util::stream;
18
19/// Marker type for direct (non-actor) agents.
20///
21/// Direct agents execute immediately within the caller's task without
22/// requiring a runtime or event wiring. Use this for simple one-shot
23/// invocations and unit tests.
24pub struct DirectAgent {}
25
26impl AgentType for DirectAgent {
27    fn type_name() -> &'static str {
28        "direct_agent"
29    }
30}
31
32/// Handle for a direct agent containing the agent instance and an event stream
33/// receiver. Use `agent.run(...)` for one-shot calls or `agent.run_stream(...)`
34/// to receive streaming outputs.
35pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
36    pub agent: BaseAgent<T, DirectAgent>,
37    pub rx: BoxEventStream<Event>,
38    #[cfg(not(target_arch = "wasm32"))]
39    fanout: Option<EventFanout>,
40}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
43    pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
44        Self {
45            agent,
46            rx,
47            #[cfg(not(target_arch = "wasm32"))]
48            fanout: None,
49        }
50    }
51
52    #[cfg(not(target_arch = "wasm32"))]
53    pub fn subscribe_events(&mut self) -> BoxEventStream<Event> {
54        if let Some(fanout) = &self.fanout {
55            return fanout.subscribe();
56        }
57
58        let stream = std::mem::replace(&mut self.rx, Box::pin(stream::empty::<Event>()));
59        let fanout = EventFanout::new(stream, DEFAULT_CHANNEL_BUFFER);
60        self.rx = fanout.subscribe();
61        let stream = fanout.subscribe();
62        self.fanout = Some(fanout);
63        stream
64    }
65}
66
67impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
68    /// Build the BaseAgent and return a wrapper
69    #[allow(clippy::result_large_err)]
70    pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
71        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
72            "LLM provider is required".to_string(),
73        ))?;
74        let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
75        let agent: BaseAgent<T, DirectAgent> =
76            BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
77        let stream = receiver_into_stream(rx);
78        Ok(DirectAgentHandle::new(agent, stream))
79    }
80}
81
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
83    /// Execute the agent for a single task and return the final agent output.
84    pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
85    where
86        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87        <T as AgentExecutor>::Error: Into<RunnableAgentError>,
88    {
89        let context = self.create_context();
90
91        //Run Hook
92        let hook_outcome = self.inner.on_run_start(&task, &context).await;
93        match hook_outcome {
94            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
95            HookOutcome::Continue => {}
96        }
97
98        // Execute the agent's logic using the executor
99        match self.inner().execute(&task, context.clone()).await {
100            Ok(output) => {
101                let output: <T as AgentExecutor>::Output = output;
102
103                //Extract Agent output into the desired type
104                let agent_out: <T as AgentDeriveT>::Output = output.into();
105
106                //Run On complete Hook
107                self.inner
108                    .on_run_complete(&task, &agent_out, &context)
109                    .await;
110                Ok(agent_out)
111            }
112            Err(e) => {
113                // Send error event
114                Err(e.into())
115            }
116        }
117    }
118
119    /// Execute the agent with streaming enabled and receive a stream of
120    /// partial outputs which culminate in a final chunk with `done=true`.
121    pub async fn run_stream(
122        &self,
123        task: Task,
124    ) -> Result<
125        std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
126        RunnableAgentError,
127    >
128    where
129        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
130        <T as AgentExecutor>::Error: Into<RunnableAgentError>,
131    {
132        let context = self.create_context();
133
134        //Run Hook
135        let hook_outcome = self.inner.on_run_start(&task, &context).await;
136        match hook_outcome {
137            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
138            HookOutcome::Continue => {}
139        }
140
141        // Execute the agent's streaming logic using the executor
142        match self.inner().execute_stream(&task, context.clone()).await {
143            Ok(stream) => {
144                use futures::TryStreamExt;
145                // Convert stream output/error without returning large Result err types from closures.
146                let transformed_stream = stream
147                    .map_ok(Into::into)
148                    .map_err(Into::<RunnableAgentError>::into)
149                    .map_err(Error::from);
150
151                Ok(Box::pin(transformed_stream))
152            }
153            Err(e) => {
154                // Send error event for stream creation failure
155                Err(e.into())
156            }
157        }
158    }
159}
160
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::hooks::HookOutcome;
    use crate::agent::output::AgentOutputT;
    use crate::agent::prebuilt::executor::{
        BasicAgent as StableBasicAgent, BasicAgentOutput, ReActAgent as StableReActAgent,
        ReActAgentOutput,
    };
    use crate::agent::task::Task;
    use crate::agent::{Context, ExecutorConfig};
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, TestAgentOutput, TestError};
    use crate::tool::ToolT;
    use async_trait::async_trait;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    };

    #[tokio::test]
    async fn test_direct_agent_build_requires_llm() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        let err = match AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .build()
            .await
        {
            Ok(_) => panic!("expected missing llm error"),
            Err(err) => err,
        };

        assert!(matches!(err, crate::error::Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_direct_agent_run_success() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");
        assert_eq!(result.result, "Processed: hello");
    }

    #[tokio::test]
    async fn test_direct_agent_run_executor_error() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent").with_failure(true);
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("fail");
        let err = handle.agent.run(task).await.expect_err("expected error");
        assert!(matches!(err, RunnableAgentError::ExecutorError(_)));
    }

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct HookCountOutput {
        result: String,
    }

    impl AgentOutputT for HookCountOutput {
        fn output_schema() -> &'static str {
            r#"{"type":"object","properties":{"result":{"type":"string"}},"required":["result"]}"#
        }

        fn structured_output_format() -> Value {
            serde_json::json!({
                "name": "HookCountOutput",
                "description": "Hook count output",
                "schema": {
                    "type": "object",
                    "properties": {
                        "result": {"type": "string"}
                    },
                    "required": ["result"]
                },
                "strict": true
            })
        }
    }

    impl From<BasicAgentOutput> for HookCountOutput {
        fn from(output: BasicAgentOutput) -> Self {
            Self {
                result: output.response,
            }
        }
    }

    impl From<ReActAgentOutput> for HookCountOutput {
        fn from(output: ReActAgentOutput) -> Self {
            Self {
                result: output.response,
            }
        }
    }

    /// Agent whose `on_run_start` hook counts how many times it was invoked.
    #[derive(Debug, Clone)]
    struct CountingHookAgent {
        on_run_start_calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl AgentDeriveT for CountingHookAgent {
        type Output = HookCountOutput;

        fn description(&self) -> &'static str {
            "counting hook agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(serde_json::json!({
                "type": "object",
                "properties": {"result": {"type": "string"}},
                "required": ["result"]
            }))
        }

        fn name(&self) -> &'static str {
            "counting_hook_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for CountingHookAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            self.on_run_start_calls.fetch_add(1, Ordering::SeqCst);
            HookOutcome::Continue
        }
    }

    #[tokio::test]
    async fn test_direct_basic_agent_run_calls_on_run_start_once() {
        let calls = Arc::new(AtomicUsize::new(0));
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle =
            AgentBuilder::<_, DirectAgent>::new(StableBasicAgent::new(CountingHookAgent {
                on_run_start_calls: Arc::clone(&calls),
            }))
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");

        assert_eq!(result.result, "Mock response");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_direct_react_agent_run_calls_on_run_start_once() {
        let calls = Arc::new(AtomicUsize::new(0));
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle =
            AgentBuilder::<_, DirectAgent>::new(StableReActAgent::new(CountingHookAgent {
                on_run_start_calls: Arc::clone(&calls),
            }))
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");

        assert_eq!(result.result, "Mock response");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Agent relying on the default `execute_stream` implementation.
    #[derive(Clone, Debug)]
    struct StreamAgent;

    #[async_trait]
    impl AgentDeriveT for StreamAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "stream agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "stream_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for StreamAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            Ok(TestAgentOutput {
                result: format!("Streamed: {}", task.prompt),
            })
        }
    }

    impl AgentHooks for StreamAgent {}

    #[tokio::test]
    async fn test_direct_agent_run_stream_default_executes_once() {
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(StreamAgent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("stream");
        let stream = handle
            .agent
            .run_stream(task)
            .await
            .expect("stream should succeed");
        let outputs: Vec<_> = stream.collect().await;
        assert_eq!(outputs.len(), 1);
        let output = outputs[0].as_ref().expect("expected Ok output");
        assert_eq!(output.result, "Streamed: stream");
    }

    /// Agent whose `on_run_start` hook always aborts; `executed` records
    /// whether `execute` was ever reached.
    #[derive(Debug)]
    struct AbortAgent {
        executed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "abort agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for AbortAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(TestAgentOutput {
                result: "should-not-run".to_string(),
            })
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    #[tokio::test]
    async fn test_direct_agent_run_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        let err = handle.agent.run(task).await.expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
        assert!(!executed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_direct_agent_run_stream_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        // The boxed stream in the Ok variant is not `Debug`, so match instead of `expect_err`.
        let result = handle.agent.run_stream(task).await;
        assert!(matches!(result, Err(RunnableAgentError::Abort)));
        assert!(!executed.load(Ordering::SeqCst));
    }
}