Skip to main content

autoagents_core/agent/
actor.rs

1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::executor::event_helper::EventHelper;
6use crate::agent::hooks::AgentHooks;
7use crate::agent::state::AgentState;
8use crate::agent::task::Task;
9use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
10use crate::channel::Sender;
11use crate::error::Error;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15use autoagents_protocol::Event;
16#[cfg(target_arch = "wasm32")]
17use futures::SinkExt;
18use futures::Stream;
19#[cfg(not(target_arch = "wasm32"))]
20use ractor::Actor;
21#[cfg(not(target_arch = "wasm32"))]
22use ractor::{ActorProcessingErr, ActorRef};
23use serde_json::Value;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27/// Marker type for actor-based agents.
28///
29/// Actor agents run inside a runtime, can subscribe to topics, receive
30/// messages, and emit protocol `Event`s for streaming updates.
31pub struct ActorAgent {}
32
33impl AgentType for ActorAgent {
34    fn type_name() -> &'static str {
35        "protocol_agent"
36    }
37}
38
39/// Handle for an actor-based agent that contains both the agent and the
40/// address of its actor. Use `addr()` to send messages directly or publish
41/// `Task`s to subscribed `Topic<Task>`.
42#[cfg(not(target_arch = "wasm32"))]
43#[derive(Clone)]
44pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
45    pub agent: Arc<BaseAgent<T, ActorAgent>>,
46    pub actor_ref: ActorRef<Task>,
47}
48
49#[cfg(not(target_arch = "wasm32"))]
50impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
51    /// Get the actor reference (`ActorRef<Task>`) for direct messaging.
52    pub fn addr(&self) -> ActorRef<Task> {
53        self.actor_ref.clone()
54    }
55
56    /// Get a clone of the agent reference for querying metadata or invoking
57    /// methods that require `Arc<BaseAgent<..>>`.
58    pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
59        self.agent.clone()
60    }
61}
62
63#[cfg(not(target_arch = "wasm32"))]
64impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("AgentHandle")
67            .field("agent", &self.agent)
68            .finish()
69    }
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73#[derive(Debug)]
74pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
75    pub Arc<BaseAgent<T, ActorAgent>>,
76);
77
78#[cfg(not(target_arch = "wasm32"))]
79impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
80
81#[cfg(not(target_arch = "wasm32"))]
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
83where
84    T: Send + Sync + 'static,
85    serde_json::Value: From<<T as AgentExecutor>::Output>,
86    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87    <T as AgentExecutor>::Error: Into<RunnableAgentError>,
88{
89    /// Build the BaseAgent and return a wrapper that includes the actor reference
90    pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
91        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
92            "LLM provider is required".to_string(),
93        ))?;
94        let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
95            "Runtime should be defined".into(),
96        ))?;
97        let tx = runtime.tx();
98
99        let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
100            BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
101        );
102
103        // Create agent actor
104        let agent_actor = AgentActor(agent.clone());
105        let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
106            .await
107            .map_err(AgentBuildError::SpawnError)?
108            .0;
109
110        // Subscribe to topics
111        for topic in self.subscribed_topics {
112            runtime.subscribe(&topic, actor_ref.clone()).await?;
113        }
114
115        Ok(ActorAgentHandle { agent, actor_ref })
116    }
117
118    pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
119        self.subscribed_topics.push(topic);
120        self
121    }
122}
123
124#[cfg(not(target_arch = "wasm32"))]
125impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
126    pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
127        self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
128    }
129
130    pub async fn run(
131        self: Arc<Self>,
132        task: Task,
133    ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
134    where
135        Value: From<<T as AgentExecutor>::Output>,
136        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
137        <T as AgentExecutor>::Error: Into<RunnableAgentError>,
138    {
139        let submission_id = task.submission_id;
140        let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
141        let tx_event = Some(tx.clone());
142
143        let context = self.create_context();
144
145        //Run Hook
146        let hook_outcome = self.inner.on_run_start(&task, &context).await;
147        match hook_outcome {
148            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
149            HookOutcome::Continue => {}
150        }
151
152        // Execute the agent's logic using the executor
153        match self.inner().execute(&task, context.clone()).await {
154            Ok(output) => {
155                let value: Value = output.clone().into();
156                #[cfg(not(target_arch = "wasm32"))]
157                EventHelper::send_task_completed_value(
158                    &tx_event,
159                    submission_id,
160                    self.id,
161                    self.name().to_string(),
162                    &value,
163                )
164                .await
165                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
166
167                //Extract Agent output into the desired type
168                let agent_out: <T as AgentDeriveT>::Output = output.into();
169
170                //Run On complete Hook
171                self.inner
172                    .on_run_complete(&task, &agent_out, &context)
173                    .await;
174
175                Ok(agent_out)
176            }
177            Err(e) => {
178                #[cfg(not(target_arch = "wasm32"))]
179                EventHelper::send_task_error(&tx_event, submission_id, self.id, e.to_string())
180                    .await;
181                Err(e.into())
182            }
183        }
184    }
185
186    pub async fn run_stream(
187        self: Arc<Self>,
188        task: Task,
189    ) -> Result<
190        std::pin::Pin<
191            Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
192        >,
193        RunnableAgentError,
194    >
195    where
196        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
197        <T as AgentExecutor>::Error: Into<RunnableAgentError>,
198    {
199        // let submission_id = task.submission_id;
200        let context = self.create_context();
201
202        // Execute the agent's streaming logic using the executor
203        match self.inner().execute_stream(&task, context).await {
204            Ok(stream) => {
205                use futures::StreamExt;
206                // Transform the stream to convert agent output to TaskResult
207                let transformed_stream = stream.map(move |result| {
208                    match result {
209                        Ok(output) => Ok(output.into()),
210                        Err(e) => {
211                            // Handle error
212                            Err(e.into())
213                        }
214                    }
215                });
216
217                Ok(Box::pin(transformed_stream))
218            }
219            Err(e) => {
220                // Send error event for stream creation failure
221                Err(e.into())
222            }
223        }
224    }
225}
226
227#[cfg(not(target_arch = "wasm32"))]
228#[async_trait]
229impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
230where
231    T: Send + Sync + 'static,
232    serde_json::Value: From<<T as AgentExecutor>::Output>,
233    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
234    <T as AgentExecutor>::Error: Into<RunnableAgentError>,
235{
236    type Msg = Task;
237    type State = AgentState;
238    type Arguments = ();
239
240    async fn pre_start(
241        &self,
242        _myself: ActorRef<Self::Msg>,
243        _args: Self::Arguments,
244    ) -> Result<Self::State, ActorProcessingErr> {
245        Ok(AgentState::new())
246    }
247
248    async fn post_stop(
249        &self,
250        _myself: ActorRef<Self::Msg>,
251        _state: &mut Self::State,
252    ) -> Result<(), ActorProcessingErr> {
253        //Run Hook
254        self.0.inner().on_agent_shutdown().await;
255        Ok(())
256    }
257
258    async fn handle(
259        &self,
260        _myself: ActorRef<Self::Msg>,
261        message: Self::Msg,
262        _state: &mut Self::State,
263    ) -> Result<(), ActorProcessingErr> {
264        let agent = self.0.clone();
265        let task = message;
266
267        //Run agent
268        if agent.stream() {
269            let _ = agent.run_stream(task).await?;
270            Ok(())
271        } else {
272            let _ = agent.run(task).await?;
273            Ok(())
274        }
275    }
276}
277
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;
    use crate::actor::{LocalTransport, Topic, Transport};
    use crate::runtime::{Runtime, RuntimeError};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use crate::utils::BoxEventStream;
    use async_trait::async_trait;
    use futures::stream;
    use std::any::{Any, TypeId};
    use std::sync::Arc;
    use tokio::sync::{Mutex, mpsc};

    /// Minimal `Runtime` test double that records topic subscriptions and
    /// otherwise does nothing.
    #[derive(Debug)]
    struct TestRuntime {
        // `(topic name, topic type)` for every `subscribe_any` call, in call order.
        subscribed: Arc<Mutex<Vec<(String, TypeId)>>>,
        // Sender returned by `tx()`; its receiver is dropped in `new`.
        tx: mpsc::Sender<Event>,
    }

    impl TestRuntime {
        fn new() -> Self {
            // Only the sending half is kept; the receiver goes out of scope here.
            let (tx, _receiver) = mpsc::channel(4);
            let subscribed = Arc::new(Mutex::new(Vec::new()));
            TestRuntime { subscribed, tx }
        }
    }

    #[async_trait]
    impl Runtime for TestRuntime {
        fn id(&self) -> autoagents_protocol::RuntimeID {
            // A fresh id on every call; none of these tests compare ids.
            autoagents_protocol::RuntimeID::new_v4()
        }

        async fn subscribe_any(
            &self,
            topic_name: &str,
            topic_type: TypeId,
            _actor: Arc<dyn crate::actor::AnyActor>,
        ) -> Result<(), RuntimeError> {
            self.subscribed
                .lock()
                .await
                .push((topic_name.to_owned(), topic_type));
            Ok(())
        }

        async fn publish_any(
            &self,
            _topic_name: &str,
            _topic_type: TypeId,
            _message: Arc<dyn Any + Send + Sync>,
        ) -> Result<(), RuntimeError> {
            Ok(())
        }

        fn tx(&self) -> mpsc::Sender<Event> {
            mpsc::Sender::clone(&self.tx)
        }

        async fn transport(&self) -> Arc<dyn Transport> {
            Arc::new(LocalTransport)
        }

        async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
            None
        }

        async fn subscribe_events(&self) -> BoxEventStream<Event> {
            Box::pin(stream::empty())
        }

        async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_actor_builder_requires_llm() {
        // A runtime is supplied but no LLM provider.
        let agent_impl = MockAgentImpl::new("agent", "desc");
        let rt = Arc::new(TestRuntime::new());
        let outcome = AgentBuilder::<_, ActorAgent>::new(agent_impl)
            .runtime(rt)
            .build()
            .await;
        assert!(matches!(outcome.unwrap_err(), Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_actor_builder_requires_runtime() {
        // An LLM provider is supplied but no runtime.
        let agent_impl = MockAgentImpl::new("agent", "desc");
        let provider = Arc::new(MockLLMProvider);
        let outcome = AgentBuilder::<_, ActorAgent>::new(agent_impl)
            .llm(provider)
            .build()
            .await;
        assert!(matches!(outcome.unwrap_err(), Error::AgentBuildError(_)));
    }

    #[tokio::test]
    async fn test_actor_builder_subscribes_topics() {
        let agent_impl = MockAgentImpl::new("agent", "desc");
        let provider = Arc::new(MockLLMProvider);
        let rt = Arc::new(TestRuntime::new());
        let jobs = Topic::<Task>::new("jobs");

        let _handle = AgentBuilder::<_, ActorAgent>::new(agent_impl)
            .llm(provider)
            .runtime(Arc::clone(&rt))
            .subscribe(jobs)
            .build()
            .await
            .expect("build should succeed");

        // Exactly one subscription, to the "jobs" topic.
        let recorded = rt.subscribed.lock().await;
        let names: Vec<&str> = recorded.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, vec!["jobs"]);
    }

    #[tokio::test]
    async fn test_actor_agent_tx_missing_returns_error() {
        let agent_impl = MockAgentImpl::new("agent", "desc");
        let provider = Arc::new(MockLLMProvider);
        let (sender, _receiver) = mpsc::channel(2);
        let mut agent = BaseAgent::<_, ActorAgent>::new(agent_impl, provider, None, sender, false)
            .await
            .expect("agent construction should succeed");

        // Remove the sender and check that `tx()` reports it as missing.
        agent.tx = None;
        assert!(matches!(agent.tx().unwrap_err(), RunnableAgentError::EmptyTx));
    }
}