// autoagents_core/agent/actor.rs
1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::hooks::AgentHooks;
6use crate::agent::state::AgentState;
7use crate::agent::task::Task;
8use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
9use crate::channel::Sender;
10use crate::error::Error;
11#[cfg(not(target_arch = "wasm32"))]
12use crate::runtime::TypedRuntime;
13use async_trait::async_trait;
14use autoagents_protocol::Event;
15#[cfg(target_arch = "wasm32")]
16use futures::SinkExt;
17use futures::Stream;
18#[cfg(not(target_arch = "wasm32"))]
19use ractor::Actor;
20#[cfg(not(target_arch = "wasm32"))]
21use ractor::{ActorProcessingErr, ActorRef};
22use serde_json::Value;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26/// Marker type for actor-based agents.
27///
28/// Actor agents run inside a runtime, can subscribe to topics, receive
29/// messages, and emit protocol `Event`s for streaming updates.
30pub struct ActorAgent {}
31
32impl AgentType for ActorAgent {
33    fn type_name() -> &'static str {
34        "protocol_agent"
35    }
36}
37
38/// Handle for an actor-based agent that contains both the agent and the
39/// address of its actor. Use `addr()` to send messages directly or publish
40/// `Task`s to subscribed `Topic<Task>`.
41#[cfg(not(target_arch = "wasm32"))]
42#[derive(Clone)]
43pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
44    pub agent: Arc<BaseAgent<T, ActorAgent>>,
45    pub actor_ref: ActorRef<Task>,
46}
47
48#[cfg(not(target_arch = "wasm32"))]
49impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
50    /// Get the actor reference (`ActorRef<Task>`) for direct messaging.
51    pub fn addr(&self) -> ActorRef<Task> {
52        self.actor_ref.clone()
53    }
54
55    /// Get a clone of the agent reference for querying metadata or invoking
56    /// methods that require `Arc<BaseAgent<..>>`.
57    pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
58        self.agent.clone()
59    }
60}
61
62#[cfg(not(target_arch = "wasm32"))]
63impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("AgentHandle")
66            .field("agent", &self.agent)
67            .finish()
68    }
69}
70
71#[cfg(not(target_arch = "wasm32"))]
72#[derive(Debug)]
73pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
74    pub Arc<BaseAgent<T, ActorAgent>>,
75);
76
77#[cfg(not(target_arch = "wasm32"))]
78impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
79
80#[cfg(not(target_arch = "wasm32"))]
81impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
82where
83    T: Send + Sync + 'static,
84    serde_json::Value: From<<T as AgentExecutor>::Output>,
85    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
86{
87    /// Build the BaseAgent and return a wrapper that includes the actor reference
88    pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
89        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
90            "LLM provider is required".to_string(),
91        ))?;
92        let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
93            "Runtime should be defined".into(),
94        ))?;
95        let tx = runtime.tx();
96
97        let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
98            BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
99        );
100
101        // Create agent actor
102        let agent_actor = AgentActor(agent.clone());
103        let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
104            .await
105            .map_err(AgentBuildError::SpawnError)?
106            .0;
107
108        // Subscribe to topics
109        for topic in self.subscribed_topics {
110            runtime.subscribe(&topic, actor_ref.clone()).await?;
111        }
112
113        Ok(ActorAgentHandle { agent, actor_ref })
114    }
115
116    pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
117        self.subscribed_topics.push(topic);
118        self
119    }
120}
121
122#[cfg(not(target_arch = "wasm32"))]
123impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
124    pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
125        self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
126    }
127
128    pub async fn run(
129        self: Arc<Self>,
130        task: Task,
131    ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
132    where
133        Value: From<<T as AgentExecutor>::Output>,
134        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
135    {
136        let submission_id = task.submission_id;
137        let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
138
139        let context = self.create_context();
140
141        //Run Hook
142        let hook_outcome = self.inner.on_run_start(&task, &context).await;
143        match hook_outcome {
144            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
145            HookOutcome::Continue => {}
146        }
147
148        // Execute the agent's logic using the executor
149        match self.inner().execute(&task, context.clone()).await {
150            Ok(output) => {
151                let value: Value = output.clone().into();
152
153                #[cfg(not(target_arch = "wasm32"))]
154                tx.send(Event::TaskComplete {
155                    sub_id: submission_id,
156                    actor_id: self.id,
157                    actor_name: self.name().to_string(),
158                    result: serde_json::to_string_pretty(&value)
159                        .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?,
160                })
161                .await
162                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
163
164                //Extract Agent output into the desired type
165                let agent_out: <T as AgentDeriveT>::Output = output.into();
166
167                //Run On complete Hook
168                self.inner
169                    .on_run_complete(&task, &agent_out, &context)
170                    .await;
171
172                Ok(agent_out)
173            }
174            Err(e) => {
175                #[cfg(not(target_arch = "wasm32"))]
176                tx.send(Event::TaskError {
177                    sub_id: submission_id,
178                    actor_id: self.id,
179                    error: e.to_string(),
180                })
181                .await
182                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
183                Err(RunnableAgentError::ExecutorError(e.to_string()))
184            }
185        }
186    }
187
188    pub async fn run_stream(
189        self: Arc<Self>,
190        task: Task,
191    ) -> Result<
192        std::pin::Pin<
193            Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
194        >,
195        RunnableAgentError,
196    >
197    where
198        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
199    {
200        // let submission_id = task.submission_id;
201        let context = self.create_context();
202
203        // Execute the agent's streaming logic using the executor
204        match self.inner().execute_stream(&task, context).await {
205            Ok(stream) => {
206                use futures::StreamExt;
207                // Transform the stream to convert agent output to TaskResult
208                let transformed_stream = stream.map(move |result| {
209                    match result {
210                        Ok(output) => Ok(output.into()),
211                        Err(e) => {
212                            // Handle error
213                            let error_msg = e.to_string();
214                            Err(RunnableAgentError::ExecutorError(error_msg))
215                        }
216                    }
217                });
218
219                Ok(Box::pin(transformed_stream))
220            }
221            Err(e) => {
222                // Send error event for stream creation failure
223                Err(RunnableAgentError::ExecutorError(e.to_string()))
224            }
225        }
226    }
227}
228
229#[cfg(not(target_arch = "wasm32"))]
230#[async_trait]
231impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
232where
233    T: Send + Sync + 'static,
234    serde_json::Value: From<<T as AgentExecutor>::Output>,
235    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
236{
237    type Msg = Task;
238    type State = AgentState;
239    type Arguments = ();
240
241    async fn pre_start(
242        &self,
243        _myself: ActorRef<Self::Msg>,
244        _args: Self::Arguments,
245    ) -> Result<Self::State, ActorProcessingErr> {
246        Ok(AgentState::new())
247    }
248
249    async fn post_stop(
250        &self,
251        _myself: ActorRef<Self::Msg>,
252        _state: &mut Self::State,
253    ) -> Result<(), ActorProcessingErr> {
254        //Run Hook
255        self.0.inner().on_agent_shutdown().await;
256        Ok(())
257    }
258
259    async fn handle(
260        &self,
261        _myself: ActorRef<Self::Msg>,
262        message: Self::Msg,
263        _state: &mut Self::State,
264    ) -> Result<(), ActorProcessingErr> {
265        let agent = self.0.clone();
266        let task = message;
267
268        //Run agent
269        if agent.stream() {
270            let _ = agent.run_stream(task).await?;
271            Ok(())
272        } else {
273            let _ = agent.run(task).await?;
274            Ok(())
275        }
276    }
277}
278
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;
    use crate::actor::{LocalTransport, Topic, Transport};
    use crate::runtime::{Runtime, RuntimeError};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use crate::utils::BoxEventStream;
    use async_trait::async_trait;
    use futures::stream;
    use std::any::{Any, TypeId};
    use std::sync::Arc;
    use tokio::sync::{Mutex, mpsc};

    /// Minimal `Runtime` that records topic subscriptions and ignores publishes.
    #[derive(Debug)]
    struct TestRuntime {
        /// `(topic name, topic type)` pairs passed to `subscribe_any`.
        subscribed: Arc<Mutex<Vec<(String, TypeId)>>>,
        /// Event sender handed to agents. Its receiver is dropped in `new`.
        tx: mpsc::Sender<Event>,
    }

    impl TestRuntime {
        fn new() -> Self {
            let (tx, _rx) = mpsc::channel(4);
            Self {
                subscribed: Arc::new(Mutex::new(Vec::new())),
                tx,
            }
        }
    }

    #[async_trait]
    impl Runtime for TestRuntime {
        fn id(&self) -> autoagents_protocol::RuntimeID {
            autoagents_protocol::RuntimeID::new_v4()
        }

        async fn subscribe_any(
            &self,
            topic_name: &str,
            topic_type: TypeId,
            _actor: Arc<dyn crate::actor::AnyActor>,
        ) -> Result<(), RuntimeError> {
            let mut subscribed = self.subscribed.lock().await;
            subscribed.push((topic_name.to_string(), topic_type));
            Ok(())
        }

        async fn publish_any(
            &self,
            _topic_name: &str,
            _topic_type: TypeId,
            _message: Arc<dyn Any + Send + Sync>,
        ) -> Result<(), RuntimeError> {
            Ok(())
        }

        fn tx(&self) -> mpsc::Sender<Event> {
            self.tx.clone()
        }

        async fn transport(&self) -> Arc<dyn Transport> {
            Arc::new(LocalTransport)
        }

        async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
            None
        }

        async fn subscribe_events(&self) -> BoxEventStream<Event> {
            Box::pin(stream::empty())
        }

        async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_actor_builder_requires_llm() {
        let mock = MockAgentImpl::new("agent", "desc");
        let runtime = Arc::new(TestRuntime::new());
        let err = AgentBuilder::<_, ActorAgent>::new(mock)
            .runtime(runtime)
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_actor_builder_requires_runtime() {
        let mock = MockAgentImpl::new("agent", "desc");
        let llm = Arc::new(MockLLMProvider);
        let err = AgentBuilder::<_, ActorAgent>::new(mock)
            .llm(llm)
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_actor_builder_subscribes_topics() {
        let mock = MockAgentImpl::new("agent", "desc");
        let llm = Arc::new(MockLLMProvider);
        let runtime = Arc::new(TestRuntime::new());
        let topic = Topic::<Task>::new("jobs");

        let handle = AgentBuilder::<_, ActorAgent>::new(mock)
            .llm(llm)
            .runtime(runtime.clone())
            .subscribe(topic)
            .build()
            .await
            .expect("build should succeed");

        {
            let subscribed = runtime.subscribed.lock().await;
            assert_eq!(subscribed.len(), 1);
            assert_eq!(subscribed[0].0, "jobs");
        }

        // `build` spawned a named actor; stop it so it does not outlive this
        // test while keeping its name taken.
        handle.actor_ref.stop(None);
    }

    #[tokio::test]
    async fn test_actor_agent_tx_missing_returns_error() {
        let mock = MockAgentImpl::new("agent", "desc");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _rx) = mpsc::channel(2);
        let mut agent = BaseAgent::<_, ActorAgent>::new(mock, llm, None, tx, false)
            .await
            .unwrap();
        agent.tx = None;
        let err = agent.tx().unwrap_err();
        assert!(matches!(err, RunnableAgentError::EmptyTx));
    }
}