Skip to main content

autoagents_core/agent/
direct.rs

1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use autoagents_protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{Receiver, Sender, channel};
12
13#[cfg(not(target_arch = "wasm32"))]
14use crate::event_fanout::EventFanout;
15use crate::utils::{BoxEventStream, receiver_into_stream};
16#[cfg(not(target_arch = "wasm32"))]
17use futures_util::stream;
18
19/// Marker type for direct (non-actor) agents.
20///
21/// Direct agents execute immediately within the caller's task without
22/// requiring a runtime or event wiring. Use this for simple one-shot
23/// invocations and unit tests.
24pub struct DirectAgent {}
25
26impl AgentType for DirectAgent {
27    fn type_name() -> &'static str {
28        "direct_agent"
29    }
30}
31
32/// Handle for a direct agent containing the agent instance and an event stream
33/// receiver. Use `agent.run(...)` for one-shot calls or `agent.run_stream(...)`
34/// to receive streaming outputs.
35pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
36    pub agent: BaseAgent<T, DirectAgent>,
37    pub rx: BoxEventStream<Event>,
38    #[cfg(not(target_arch = "wasm32"))]
39    fanout: Option<EventFanout>,
40}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
43    pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
44        Self {
45            agent,
46            rx,
47            #[cfg(not(target_arch = "wasm32"))]
48            fanout: None,
49        }
50    }
51
52    #[cfg(not(target_arch = "wasm32"))]
53    pub fn subscribe_events(&mut self) -> BoxEventStream<Event> {
54        if let Some(fanout) = &self.fanout {
55            return fanout.subscribe();
56        }
57
58        let stream = std::mem::replace(&mut self.rx, Box::pin(stream::empty::<Event>()));
59        let fanout = EventFanout::new(stream, DEFAULT_CHANNEL_BUFFER);
60        self.rx = fanout.subscribe();
61        let stream = fanout.subscribe();
62        self.fanout = Some(fanout);
63        stream
64    }
65}
66
67impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
68    /// Build the BaseAgent and return a wrapper
69    #[allow(clippy::result_large_err)]
70    pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
71        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
72            "LLM provider is required".to_string(),
73        ))?;
74        let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
75        let agent: BaseAgent<T, DirectAgent> =
76            BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
77        let stream = receiver_into_stream(rx);
78        Ok(DirectAgentHandle::new(agent, stream))
79    }
80}
81
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
83    /// Execute the agent for a single task and return the final agent output.
84    pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
85    where
86        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87    {
88        let context = self.create_context();
89
90        //Run Hook
91        let hook_outcome = self.inner.on_run_start(&task, &context).await;
92        match hook_outcome {
93            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
94            HookOutcome::Continue => {}
95        }
96
97        // Execute the agent's logic using the executor
98        match self.inner().execute(&task, context.clone()).await {
99            Ok(output) => {
100                let output: <T as AgentExecutor>::Output = output;
101
102                //Extract Agent output into the desired type
103                let agent_out: <T as AgentDeriveT>::Output = output.into();
104
105                //Run On complete Hook
106                self.inner
107                    .on_run_complete(&task, &agent_out, &context)
108                    .await;
109                Ok(agent_out)
110            }
111            Err(e) => {
112                // Send error event
113                Err(RunnableAgentError::ExecutorError(e.to_string()))
114            }
115        }
116    }
117
118    /// Execute the agent with streaming enabled and receive a stream of
119    /// partial outputs which culminate in a final chunk with `done=true`.
120    pub async fn run_stream(
121        &self,
122        task: Task,
123    ) -> Result<
124        std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
125        RunnableAgentError,
126    >
127    where
128        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
129    {
130        let context = self.create_context();
131
132        //Run Hook
133        let hook_outcome = self.inner.on_run_start(&task, &context).await;
134        match hook_outcome {
135            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
136            HookOutcome::Continue => {}
137        }
138
139        // Execute the agent's streaming logic using the executor
140        match self.inner().execute_stream(&task, context.clone()).await {
141            Ok(stream) => {
142                use futures::StreamExt;
143                // Convert the stream output
144                let transformed_stream = stream.map(move |result| match result {
145                    Ok(output) => Ok(output.into()),
146                    Err(e) => {
147                        let error_msg = e.to_string();
148                        Err(RunnableAgentError::ExecutorError(error_msg).into())
149                    }
150                });
151
152                Ok(Box::pin(transformed_stream))
153            }
154            Err(e) => {
155                // Send error event for stream creation failure
156                Err(RunnableAgentError::ExecutorError(e.to_string()))
157            }
158        }
159    }
160}
161
#[cfg(test)]
mod tests {
    //! Unit tests for direct (non-actor) agent construction and execution,
    //! covering build validation, `run`, `run_stream`, and hook-driven aborts.

    use super::*;
    use crate::agent::hooks::HookOutcome;
    use crate::agent::output::AgentOutputT;
    use crate::agent::task::Task;
    use crate::agent::{Context, ExecutorConfig};
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, TestAgentOutput, TestError};
    use crate::tool::ToolT;
    use async_trait::async_trait;
    use futures::StreamExt;
    use serde_json::Value;
    use std::sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    };

    #[tokio::test]
    async fn test_direct_agent_build_requires_llm() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        let err = match AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .build()
            .await
        {
            Ok(_) => panic!("expected missing llm error"),
            Err(err) => err,
        };

        assert!(matches!(err, crate::error::Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_direct_agent_run_success() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent");
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("hello");
        let result = handle.agent.run(task).await.expect("run should succeed");
        assert_eq!(result.result, "Processed: hello");
    }

    #[tokio::test]
    async fn test_direct_agent_run_executor_error() {
        let mock_agent = MockAgentImpl::new("direct", "direct agent").with_failure(true);
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(mock_agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("fail");
        let err = handle.agent.run(task).await.expect_err("expected error");
        assert!(matches!(err, RunnableAgentError::ExecutorError(_)));
    }

    /// Agent that relies on the default `execute_stream` implementation.
    #[derive(Clone, Debug)]
    struct StreamAgent;

    #[async_trait]
    impl AgentDeriveT for StreamAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "stream agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "stream_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for StreamAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            Ok(TestAgentOutput {
                result: format!("Streamed: {}", task.prompt),
            })
        }
    }

    impl AgentHooks for StreamAgent {}

    #[tokio::test]
    async fn test_direct_agent_run_stream_default_executes_once() {
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(StreamAgent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("stream");
        let stream = handle
            .agent
            .run_stream(task)
            .await
            .expect("stream should succeed");
        let outputs: Vec<_> = stream.collect().await;
        assert_eq!(outputs.len(), 1);
        let output = outputs[0].as_ref().expect("expected Ok output");
        assert_eq!(output.result, "Streamed: stream");
    }

    /// Agent whose `on_run_start` hook always aborts; `executed` records
    /// whether `execute` was ever reached.
    #[derive(Debug)]
    struct AbortAgent {
        executed: Arc<AtomicBool>,
    }

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &'static str {
            "abort agent"
        }

        fn output_schema(&self) -> Option<Value> {
            Some(TestAgentOutput::structured_output_format())
        }

        fn name(&self) -> &'static str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentExecutor for AbortAgent {
        type Output = TestAgentOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig::default()
        }

        async fn execute(
            &self,
            _task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            self.executed.store(true, Ordering::SeqCst);
            Ok(TestAgentOutput {
                result: "should-not-run".to_string(),
            })
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    #[tokio::test]
    async fn test_direct_agent_run_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        let err = handle.agent.run(task).await.expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
        assert!(!executed.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn test_direct_agent_run_stream_aborts_before_execute() {
        let executed = Arc::new(AtomicBool::new(false));
        let agent = AbortAgent {
            executed: Arc::clone(&executed),
        };
        let llm = Arc::new(ConfigurableLLMProvider::default());
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let task = Task::new("abort");
        // The boxed stream in the Ok variant is not `Debug`, so `expect_err`
        // cannot be used here.
        let err = match handle.agent.run_stream(task).await {
            Ok(_) => panic!("expected abort"),
            Err(err) => err,
        };
        assert!(matches!(err, RunnableAgentError::Abort));
        assert!(!executed.load(Ordering::SeqCst));
    }
}