autoagents_core/agent/direct.rs

1use crate::agent::base::AgentType;
2use crate::agent::error::{AgentBuildError, RunnableAgentError};
3use crate::agent::task::Task;
4use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, AgentHooks, BaseAgent, HookOutcome};
5use crate::error::Error;
6use crate::protocol::Event;
7use futures::Stream;
8
9use crate::agent::constants::DEFAULT_CHANNEL_BUFFER;
10
11use crate::channel::{channel, Receiver, Sender};
12
13use crate::utils::{receiver_into_stream, BoxEventStream};
14
15pub struct DirectAgent {}
16
17impl AgentType for DirectAgent {
18    fn type_name() -> &'static str {
19        "direct_agent"
20    }
21}
22
23/// Handle for an agent that includes both the agent and its actor reference
24pub struct DirectAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
25    pub agent: BaseAgent<T, DirectAgent>,
26    pub rx: BoxEventStream<Event>,
27}
28
29impl<T: AgentDeriveT + AgentExecutor + AgentHooks> DirectAgentHandle<T> {
30    pub fn new(agent: BaseAgent<T, DirectAgent>, rx: BoxEventStream<Event>) -> Self {
31        Self { agent, rx }
32    }
33}
34
35impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, DirectAgent> {
36    /// Build the BaseAgent and return a wrapper
37    #[allow(clippy::result_large_err)]
38    pub async fn build(self) -> Result<DirectAgentHandle<T>, Error> {
39        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
40            "LLM provider is required".to_string(),
41        ))?;
42        let (tx, rx): (Sender<Event>, Receiver<Event>) = channel(DEFAULT_CHANNEL_BUFFER);
43        let agent: BaseAgent<T, DirectAgent> =
44            BaseAgent::<T, DirectAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?;
45        let stream = receiver_into_stream(rx);
46        Ok(DirectAgentHandle::new(agent, stream))
47    }
48}
49
50impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, DirectAgent> {
51    pub async fn run(&self, task: Task) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
52    where
53        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
54    {
55        let context = self.create_context();
56
57        //Run Hook
58        let hook_outcome = self.inner.on_run_start(&task, &context).await;
59        match hook_outcome {
60            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
61            HookOutcome::Continue => {}
62        }
63
64        // Execute the agent's logic using the executor
65        match self.inner().execute(&task, context.clone()).await {
66            Ok(output) => {
67                let output: <T as AgentExecutor>::Output = output;
68
69                //Extract Agent output into the desired type
70                let agent_out: <T as AgentDeriveT>::Output = output.into();
71
72                //Run On complete Hook
73                self.inner
74                    .on_run_complete(&task, &agent_out, &context)
75                    .await;
76                Ok(agent_out)
77            }
78            Err(e) => {
79                // Send error event
80                Err(RunnableAgentError::ExecutorError(e.to_string()))
81            }
82        }
83    }
84
85    pub async fn run_stream(
86        &self,
87        task: Task,
88    ) -> Result<
89        std::pin::Pin<Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, Error>> + Send>>,
90        RunnableAgentError,
91    >
92    where
93        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
94    {
95        let context = self.create_context();
96
97        //Run Hook
98        let hook_outcome = self.inner.on_run_start(&task, &context).await;
99        match hook_outcome {
100            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
101            HookOutcome::Continue => {}
102        }
103
104        // Execute the agent's streaming logic using the executor
105        match self.inner().execute_stream(&task, context.clone()).await {
106            Ok(stream) => {
107                use futures::StreamExt;
108                // Convert the stream output
109                let transformed_stream = stream.map(move |result| match result {
110                    Ok(output) => Ok(output.into()),
111                    Err(e) => {
112                        let error_msg = e.to_string();
113                        Err(RunnableAgentError::ExecutorError(error_msg).into())
114                    }
115                });
116
117                Ok(Box::pin(transformed_stream))
118            }
119            Err(e) => {
120                // Send error event for stream creation failure
121                Err(RunnableAgentError::ExecutorError(e.to_string()))
122            }
123        }
124    }
125}