Skip to main content

autoagents_core/agent/
direct.rs

1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use autoagents_protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{Receiver, Sender, channel};
12
13#[cfg(not(target_arch = "wasm32"))]
14use crate::event_fanout::EventFanout;
15use crate::utils::{BoxEventStream, receiver_into_stream};
16#[cfg(not(target_arch = "wasm32"))]
17use futures_util::stream;
18
19/// Marker type for direct (non-actor) agents.
20///
21/// Direct agents execute immediately within the caller's task without
22/// requiring a runtime or event wiring. Use this for simple one-shot
23/// invocations and unit tests.
24pub struct DirectAgent {}
25
26impl AgentType for DirectAgent {
27    fn type_name() -> &'static str {
28        "direct_agent"
29    }
30}
31
32/// Handle for a direct agent containing the agent instance and an event stream
33/// receiver. Use `agent.run(...)` for one-shot calls or `agent.run_stream(...)`
34/// to receive streaming outputs.
35pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
36    pub agent: BaseAgent<T, DirectAgent>,
37    pub rx: BoxEventStream<Event>,
38    #[cfg(not(target_arch = "wasm32"))]
39    fanout: Option<EventFanout>,
40}
41
42impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
43    pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
44        Self {
45            agent,
46            rx,
47            #[cfg(not(target_arch = "wasm32"))]
48            fanout: None,
49        }
50    }
51
52    #[cfg(not(target_arch = "wasm32"))]
53    pub fn subscribe_events(&mut self) -> BoxEventStream<Event> {
54        if let Some(fanout) = &self.fanout {
55            return fanout.subscribe();
56        }
57
58        let stream = std::mem::replace(&mut self.rx, Box::pin(stream::empty::<Event>()));
59        let fanout = EventFanout::new(stream, DEFAULT_CHANNEL_BUFFER);
60        self.rx = fanout.subscribe();
61        let stream = fanout.subscribe();
62        self.fanout = Some(fanout);
63        stream
64    }
65}
66
67impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
68    /// Build the BaseAgent and return a wrapper
69    #[allow(clippy::result_large_err)]
70    pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
71        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
72            "LLM provider is required".to_string(),
73        ))?;
74        let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
75        let agent: BaseAgent<T, DirectAgent> =
76            BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
77        let stream = receiver_into_stream(rx);
78        Ok(DirectAgentHandle::new(agent, stream))
79    }
80}
81
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
83    /// Execute the agent for a single task and return the final agent output.
84    pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
85    where
86        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87    {
88        let context = self.create_context();
89
90        //Run Hook
91        let hook_outcome = self.inner.on_run_start(&task, &context).await;
92        match hook_outcome {
93            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
94            HookOutcome::Continue => {}
95        }
96
97        // Execute the agent's logic using the executor
98        match self.inner().execute(&task, context.clone()).await {
99            Ok(output) => {
100                let output: <T as AgentExecutor>::Output = output;
101
102                //Extract Agent output into the desired type
103                let agent_out: <T as AgentDeriveT>::Output = output.into();
104
105                //Run On complete Hook
106                self.inner
107                    .on_run_complete(&task, &agent_out, &context)
108                    .await;
109                Ok(agent_out)
110            }
111            Err(e) => {
112                // Send error event
113                Err(RunnableAgentError::ExecutorError(e.to_string()))
114            }
115        }
116    }
117
118    /// Execute the agent with streaming enabled and receive a stream of
119    /// partial outputs which culminate in a final chunk with `done=true`.
120    pub async fn run_stream(
121        &self,
122        task: Task,
123    ) -> Result<
124        std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
125        RunnableAgentError,
126    >
127    where
128        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
129    {
130        let context = self.create_context();
131
132        //Run Hook
133        let hook_outcome = self.inner.on_run_start(&task, &context).await;
134        match hook_outcome {
135            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
136            HookOutcome::Continue => {}
137        }
138
139        // Execute the agent's streaming logic using the executor
140        match self.inner().execute_stream(&task, context.clone()).await {
141            Ok(stream) => {
142                use futures::StreamExt;
143                // Convert the stream output
144                let transformed_stream = stream.map(move |result| match result {
145                    Ok(output) => Ok(output.into()),
146                    Err(e) => {
147                        let error_msg = e.to_string();
148                        Err(RunnableAgentError::ExecutorError(error_msg).into())
149                    }
150                });
151
152                Ok(Box::pin(transformed_stream))
153            }
154            Err(e) => {
155                // Send error event for stream creation failure
156                Err(RunnableAgentError::ExecutorError(e.to_string()))
157            }
158        }
159    }
160}