autoagents_core/agent/
actor.rs

1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::hooks::AgentHooks;
6use crate::agent::state::AgentState;
7use crate::agent::task::Task;
8use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
9use crate::channel::Sender;
10use crate::error::Error;
11use crate::protocol::Event;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15#[cfg(target_arch = "wasm32")]
16use futures::SinkExt;
17use futures::Stream;
18#[cfg(not(target_arch = "wasm32"))]
19use ractor::Actor;
20#[cfg(not(target_arch = "wasm32"))]
21use ractor::{ActorProcessingErr, ActorRef};
22use serde_json::Value;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26pub struct ActorAgent {}
27
28impl AgentType for ActorAgent {
29    fn type_name() -> &'static str {
30        "protocol_agent"
31    }
32}
33
34/// Handle for an agent that includes both the agent and its actor reference
35#[cfg(not(target_arch = "wasm32"))]
36#[derive(Clone)]
37pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks> {
38    pub agent: Arc<BaseAgent<T, ActorAgent>>,
39    pub actor_ref: ActorRef<Task>,
40}
41
42#[cfg(not(target_arch = "wasm32"))]
43impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
44    /// Get the actor reference for direct messaging
45    pub fn addr(&self) -> ActorRef<Task> {
46        self.actor_ref.clone()
47    }
48
49    /// Get the agent reference
50    pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
51        self.agent.clone()
52    }
53}
54
55#[cfg(not(target_arch = "wasm32"))]
56impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("AgentHandle")
59            .field("agent", &self.agent)
60            .finish()
61    }
62}
63
64#[cfg(not(target_arch = "wasm32"))]
65#[derive(Debug)]
66pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
67    pub Arc<BaseAgent<T, ActorAgent>>,
68);
69
70#[cfg(not(target_arch = "wasm32"))]
71impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
72
73#[cfg(not(target_arch = "wasm32"))]
74impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
75where
76    T: Send + Sync + 'static,
77    serde_json::Value: From<<T as AgentExecutor>::Output>,
78    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
79{
80    /// Build the BaseAgent and return a wrapper that includes the actor reference
81    pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
82        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
83            "LLM provider is required".to_string(),
84        ))?;
85        let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
86            "Runtime should be defined".into(),
87        ))?;
88        let tx = runtime.tx();
89
90        let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
91            BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
92        );
93
94        // Create agent actor
95        let agent_actor = AgentActor(agent.clone());
96        let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
97            .await
98            .map_err(AgentBuildError::SpawnError)?
99            .0;
100
101        // Subscribe to topics
102        for topic in self.subscribed_topics {
103            runtime.subscribe(&topic, actor_ref.clone()).await?;
104        }
105
106        Ok(ActorAgentHandle { agent, actor_ref })
107    }
108
109    pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
110        self.subscribed_topics.push(topic);
111        self
112    }
113}
114
115#[cfg(not(target_arch = "wasm32"))]
116impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
117    pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
118        self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
119    }
120
121    pub async fn run(
122        self: Arc<Self>,
123        task: Task,
124    ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
125    where
126        Value: From<<T as AgentExecutor>::Output>,
127        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
128    {
129        let submission_id = task.submission_id;
130        let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
131
132        let context = self.create_context();
133
134        //Run Hook
135        let hook_outcome = self.inner.on_run_start(&task, &context).await;
136        match hook_outcome {
137            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
138            HookOutcome::Continue => {}
139        }
140
141        // Execute the agent's logic using the executor
142        match self.inner().execute(&task, context.clone()).await {
143            Ok(output) => {
144                let value: Value = output.clone().into();
145
146                #[cfg(not(target_arch = "wasm32"))]
147                tx.send(Event::TaskComplete {
148                    sub_id: submission_id,
149                    result: serde_json::to_string_pretty(&value)
150                        .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?,
151                })
152                .await
153                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
154
155                //Extract Agent output into the desired type
156                let agent_out: <T as AgentDeriveT>::Output = output.into();
157
158                //Run On complete Hook
159                self.inner
160                    .on_run_complete(&task, &agent_out, &context)
161                    .await;
162
163                Ok(agent_out)
164            }
165            Err(e) => {
166                #[cfg(not(target_arch = "wasm32"))]
167                tx.send(Event::TaskError {
168                    sub_id: submission_id,
169                    error: e.to_string(),
170                })
171                .await
172                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
173                Err(RunnableAgentError::ExecutorError(e.to_string()))
174            }
175        }
176    }
177
178    pub async fn run_stream(
179        self: Arc<Self>,
180        task: Task,
181    ) -> Result<
182        std::pin::Pin<
183            Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
184        >,
185        RunnableAgentError,
186    >
187    where
188        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
189    {
190        // let submission_id = task.submission_id;
191        let context = self.create_context();
192
193        // Execute the agent's streaming logic using the executor
194        match self.inner().execute_stream(&task, context).await {
195            Ok(stream) => {
196                use futures::StreamExt;
197                // Transform the stream to convert agent output to TaskResult
198                let transformed_stream = stream.map(move |result| {
199                    match result {
200                        Ok(output) => Ok(output.into()),
201                        Err(e) => {
202                            // Handle error
203                            let error_msg = e.to_string();
204                            Err(RunnableAgentError::ExecutorError(error_msg))
205                        }
206                    }
207                });
208
209                Ok(Box::pin(transformed_stream))
210            }
211            Err(e) => {
212                // Send error event for stream creation failure
213                Err(RunnableAgentError::ExecutorError(e.to_string()))
214            }
215        }
216    }
217}
218
219#[cfg(not(target_arch = "wasm32"))]
220#[async_trait]
221impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
222where
223    T: Send + Sync + 'static,
224    serde_json::Value: From<<T as AgentExecutor>::Output>,
225    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
226{
227    type Msg = Task;
228    type State = AgentState;
229    type Arguments = ();
230
231    async fn pre_start(
232        &self,
233        _myself: ActorRef<Self::Msg>,
234        _args: Self::Arguments,
235    ) -> Result<Self::State, ActorProcessingErr> {
236        Ok(AgentState::new())
237    }
238
239    async fn post_stop(
240        &self,
241        _myself: ActorRef<Self::Msg>,
242        _state: &mut Self::State,
243    ) -> Result<(), ActorProcessingErr> {
244        //Run Hook
245        self.0.inner().on_agent_shutdown().await;
246        Ok(())
247    }
248
249    async fn handle(
250        &self,
251        _myself: ActorRef<Self::Msg>,
252        message: Self::Msg,
253        _state: &mut Self::State,
254    ) -> Result<(), ActorProcessingErr> {
255        let agent = self.0.clone();
256        let task = message;
257
258        //Run agent
259        if agent.stream() {
260            let _ = agent.run_stream(task).await?;
261            Ok(())
262        } else {
263            let _ = agent.run(task).await?;
264            Ok(())
265        }
266    }
267}