autoagents_core/agent/actor.rs

1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::hooks::AgentHooks;
6use crate::agent::state::AgentState;
7use crate::agent::task::Task;
8use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
9use crate::channel::Sender;
10use crate::error::Error;
11use crate::protocol::Event;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15#[cfg(target_arch = "wasm32")]
16use futures::SinkExt;
17use futures::Stream;
18#[cfg(not(target_arch = "wasm32"))]
19use ractor::Actor;
20#[cfg(not(target_arch = "wasm32"))]
21use ractor::{ActorProcessingErr, ActorRef};
22use serde_json::Value;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26pub struct ActorAgent {}
27
28impl AgentType for ActorAgent {
29    fn type_name() -> &'static str {
30        "protocol_agent"
31    }
32}
33
34/// Handle for an agent that includes both the agent and its actor reference
35#[cfg(not(target_arch = "wasm32"))]
36#[derive(Clone)]
37pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
38    pub agent: Arc<BaseAgent<T, ActorAgent>>,
39    pub actor_ref: ActorRef<Task>,
40}
41
42#[cfg(not(target_arch = "wasm32"))]
43impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
44    /// Get the actor reference for direct messaging
45    pub fn addr(&self) -> ActorRef<Task> {
46        self.actor_ref.clone()
47    }
48
49    /// Get the agent reference
50    pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
51        self.agent.clone()
52    }
53}
54
55#[cfg(not(target_arch = "wasm32"))]
56impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
57    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("AgentHandle")
59            .field("agent", &self.agent)
60            .finish()
61    }
62}
63
64#[cfg(not(target_arch = "wasm32"))]
65#[derive(Debug)]
66pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
67    pub Arc<BaseAgent<T, ActorAgent>>,
68);
69
70#[cfg(not(target_arch = "wasm32"))]
71impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
72
73#[cfg(not(target_arch = "wasm32"))]
74impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
75where
76    T: Send + Sync + 'static,
77    serde_json::Value: From<<T as AgentExecutor>::Output>,
78    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
79{
80    /// Build the BaseAgent and return a wrapper that includes the actor reference
81    pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
82        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
83            "LLM provider is required".to_string(),
84        ))?;
85        let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
86            "Runtime should be defined".into(),
87        ))?;
88        let tx = runtime.tx();
89
90        let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
91            BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
92        );
93
94        // Create agent actor
95        let agent_actor = AgentActor(agent.clone());
96        let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
97            .await
98            .map_err(AgentBuildError::SpawnError)?
99            .0;
100
101        // Subscribe to topics
102        for topic in self.subscribed_topics {
103            runtime.subscribe(&topic, actor_ref.clone()).await?;
104        }
105
106        Ok(ActorAgentHandle { agent, actor_ref })
107    }
108
109    pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
110        self.subscribed_topics.push(topic);
111        self
112    }
113}
114
115#[cfg(not(target_arch = "wasm32"))]
116impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
117    pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
118        self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
119    }
120
121    pub async fn run(
122        self: Arc<Self>,
123        task: Task,
124    ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
125    where
126        Value: From<<T as AgentExecutor>::Output>,
127        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
128    {
129        let submission_id = task.submission_id;
130        let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
131
132        let context = self.create_context();
133
134        //Run Hook
135        let hook_outcome = self.inner.on_run_start(&task, &context).await;
136        match hook_outcome {
137            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
138            HookOutcome::Continue => {}
139        }
140
141        // Execute the agent's logic using the executor
142        match self.inner().execute(&task, context.clone()).await {
143            Ok(output) => {
144                let value: Value = output.clone().into();
145
146                #[cfg(not(target_arch = "wasm32"))]
147                tx.send(Event::TaskComplete {
148                    sub_id: submission_id,
149                    actor_id: self.id,
150                    actor_name: self.name().to_string(),
151                    result: serde_json::to_string_pretty(&value)
152                        .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?,
153                })
154                .await
155                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
156
157                //Extract Agent output into the desired type
158                let agent_out: <T as AgentDeriveT>::Output = output.into();
159
160                //Run On complete Hook
161                self.inner
162                    .on_run_complete(&task, &agent_out, &context)
163                    .await;
164
165                Ok(agent_out)
166            }
167            Err(e) => {
168                #[cfg(not(target_arch = "wasm32"))]
169                tx.send(Event::TaskError {
170                    sub_id: submission_id,
171                    actor_id: self.id,
172                    error: e.to_string(),
173                })
174                .await
175                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
176                Err(RunnableAgentError::ExecutorError(e.to_string()))
177            }
178        }
179    }
180
181    pub async fn run_stream(
182        self: Arc<Self>,
183        task: Task,
184    ) -> Result<
185        std::pin::Pin<
186            Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
187        >,
188        RunnableAgentError,
189    >
190    where
191        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
192    {
193        // let submission_id = task.submission_id;
194        let context = self.create_context();
195
196        // Execute the agent's streaming logic using the executor
197        match self.inner().execute_stream(&task, context).await {
198            Ok(stream) => {
199                use futures::StreamExt;
200                // Transform the stream to convert agent output to TaskResult
201                let transformed_stream = stream.map(move |result| {
202                    match result {
203                        Ok(output) => Ok(output.into()),
204                        Err(e) => {
205                            // Handle error
206                            let error_msg = e.to_string();
207                            Err(RunnableAgentError::ExecutorError(error_msg))
208                        }
209                    }
210                });
211
212                Ok(Box::pin(transformed_stream))
213            }
214            Err(e) => {
215                // Send error event for stream creation failure
216                Err(RunnableAgentError::ExecutorError(e.to_string()))
217            }
218        }
219    }
220}
221
222#[cfg(not(target_arch = "wasm32"))]
223#[async_trait]
224impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
225where
226    T: Send + Sync + 'static,
227    serde_json::Value: From<<T as AgentExecutor>::Output>,
228    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
229{
230    type Msg = Task;
231    type State = AgentState;
232    type Arguments = ();
233
234    async fn pre_start(
235        &self,
236        _myself: ActorRef<Self::Msg>,
237        _args: Self::Arguments,
238    ) -> Result<Self::State, ActorProcessingErr> {
239        Ok(AgentState::new())
240    }
241
242    async fn post_stop(
243        &self,
244        _myself: ActorRef<Self::Msg>,
245        _state: &mut Self::State,
246    ) -> Result<(), ActorProcessingErr> {
247        //Run Hook
248        self.0.inner().on_agent_shutdown().await;
249        Ok(())
250    }
251
252    async fn handle(
253        &self,
254        _myself: ActorRef<Self::Msg>,
255        message: Self::Msg,
256        _state: &mut Self::State,
257    ) -> Result<(), ActorProcessingErr> {
258        let agent = self.0.clone();
259        let task = message;
260
261        //Run agent
262        if agent.stream() {
263            let _ = agent.run_stream(task).await?;
264            Ok(())
265        } else {
266            let _ = agent.run(task).await?;
267            Ok(())
268        }
269    }
270}