Skip to main content

autoagents_core/agent/
actor.rs

1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::executor::event_helper::EventHelper;
6use crate::agent::hooks::AgentHooks;
7use crate::agent::state::AgentState;
8use crate::agent::task::Task;
9use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
10use crate::channel::Sender;
11use crate::error::Error;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15use autoagents_protocol::Event;
16#[cfg(target_arch = "wasm32")]
17use futures::SinkExt;
18use futures::Stream;
19#[cfg(not(target_arch = "wasm32"))]
20use ractor::Actor;
21#[cfg(not(target_arch = "wasm32"))]
22use ractor::{ActorProcessingErr, ActorRef};
23use serde_json::Value;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27/// Marker type for actor-based agents.
28///
29/// Actor agents run inside a runtime, can subscribe to topics, receive
30/// messages, and emit protocol `Event`s for streaming updates.
31pub struct ActorAgent {}
32
33impl AgentType for ActorAgent {
34    fn type_name() -> &'static str {
35        "protocol_agent"
36    }
37}
38
39/// Handle for an actor-based agent that contains both the agent and the
40/// address of its actor. Use `addr()` to send messages directly or publish
41/// `Task`s to subscribed `Topic<Task>`.
42#[cfg(not(target_arch = "wasm32"))]
43#[derive(Clone)]
44pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
45    pub agent: Arc<BaseAgent<T, ActorAgent>>,
46    pub actor_ref: ActorRef<Task>,
47}
48
49#[cfg(not(target_arch = "wasm32"))]
50impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
51    /// Get the actor reference (`ActorRef<Task>`) for direct messaging.
52    pub fn addr(&self) -> ActorRef<Task> {
53        self.actor_ref.clone()
54    }
55
56    /// Get a clone of the agent reference for querying metadata or invoking
57    /// methods that require `Arc<BaseAgent<..>>`.
58    pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
59        self.agent.clone()
60    }
61}
62
63#[cfg(not(target_arch = "wasm32"))]
64impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("AgentHandle")
67            .field("agent", &self.agent)
68            .finish()
69    }
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73#[derive(Debug)]
74pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
75    pub Arc<BaseAgent<T, ActorAgent>>,
76);
77
78#[cfg(not(target_arch = "wasm32"))]
79impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
80
81#[cfg(not(target_arch = "wasm32"))]
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
83where
84    T: Send + Sync + 'static,
85    serde_json::Value: From<<T as AgentExecutor>::Output>,
86    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87{
88    /// Build the BaseAgent and return a wrapper that includes the actor reference
89    pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
90        let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
91            "LLM provider is required".to_string(),
92        ))?;
93        let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
94            "Runtime should be defined".into(),
95        ))?;
96        let tx = runtime.tx();
97
98        let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
99            BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
100        );
101
102        // Create agent actor
103        let agent_actor = AgentActor(agent.clone());
104        let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
105            .await
106            .map_err(AgentBuildError::SpawnError)?
107            .0;
108
109        // Subscribe to topics
110        for topic in self.subscribed_topics {
111            runtime.subscribe(&topic, actor_ref.clone()).await?;
112        }
113
114        Ok(ActorAgentHandle { agent, actor_ref })
115    }
116
117    pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
118        self.subscribed_topics.push(topic);
119        self
120    }
121}
122
123#[cfg(not(target_arch = "wasm32"))]
124impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
125    pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
126        self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
127    }
128
129    pub async fn run(
130        self: Arc<Self>,
131        task: Task,
132    ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
133    where
134        Value: From<<T as AgentExecutor>::Output>,
135        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
136    {
137        let submission_id = task.submission_id;
138        let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
139        let tx_event = Some(tx.clone());
140
141        let context = self.create_context();
142
143        //Run Hook
144        let hook_outcome = self.inner.on_run_start(&task, &context).await;
145        match hook_outcome {
146            HookOutcome::Abort => return Err(RunnableAgentError::Abort),
147            HookOutcome::Continue => {}
148        }
149
150        // Execute the agent's logic using the executor
151        match self.inner().execute(&task, context.clone()).await {
152            Ok(output) => {
153                let value: Value = output.clone().into();
154                #[cfg(not(target_arch = "wasm32"))]
155                EventHelper::send_task_completed_value(
156                    &tx_event,
157                    submission_id,
158                    self.id,
159                    self.name().to_string(),
160                    &value,
161                )
162                .await
163                .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
164
165                //Extract Agent output into the desired type
166                let agent_out: <T as AgentDeriveT>::Output = output.into();
167
168                //Run On complete Hook
169                self.inner
170                    .on_run_complete(&task, &agent_out, &context)
171                    .await;
172
173                Ok(agent_out)
174            }
175            Err(e) => {
176                #[cfg(not(target_arch = "wasm32"))]
177                EventHelper::send_task_error(&tx_event, submission_id, self.id, e.to_string())
178                    .await;
179                Err(RunnableAgentError::ExecutorError(e.to_string()))
180            }
181        }
182    }
183
184    pub async fn run_stream(
185        self: Arc<Self>,
186        task: Task,
187    ) -> Result<
188        std::pin::Pin<
189            Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
190        >,
191        RunnableAgentError,
192    >
193    where
194        <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
195    {
196        // let submission_id = task.submission_id;
197        let context = self.create_context();
198
199        // Execute the agent's streaming logic using the executor
200        match self.inner().execute_stream(&task, context).await {
201            Ok(stream) => {
202                use futures::StreamExt;
203                // Transform the stream to convert agent output to TaskResult
204                let transformed_stream = stream.map(move |result| {
205                    match result {
206                        Ok(output) => Ok(output.into()),
207                        Err(e) => {
208                            // Handle error
209                            let error_msg = e.to_string();
210                            Err(RunnableAgentError::ExecutorError(error_msg))
211                        }
212                    }
213                });
214
215                Ok(Box::pin(transformed_stream))
216            }
217            Err(e) => {
218                // Send error event for stream creation failure
219                Err(RunnableAgentError::ExecutorError(e.to_string()))
220            }
221        }
222    }
223}
224
225#[cfg(not(target_arch = "wasm32"))]
226#[async_trait]
227impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
228where
229    T: Send + Sync + 'static,
230    serde_json::Value: From<<T as AgentExecutor>::Output>,
231    <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
232{
233    type Msg = Task;
234    type State = AgentState;
235    type Arguments = ();
236
237    async fn pre_start(
238        &self,
239        _myself: ActorRef<Self::Msg>,
240        _args: Self::Arguments,
241    ) -> Result<Self::State, ActorProcessingErr> {
242        Ok(AgentState::new())
243    }
244
245    async fn post_stop(
246        &self,
247        _myself: ActorRef<Self::Msg>,
248        _state: &mut Self::State,
249    ) -> Result<(), ActorProcessingErr> {
250        //Run Hook
251        self.0.inner().on_agent_shutdown().await;
252        Ok(())
253    }
254
255    async fn handle(
256        &self,
257        _myself: ActorRef<Self::Msg>,
258        message: Self::Msg,
259        _state: &mut Self::State,
260    ) -> Result<(), ActorProcessingErr> {
261        let agent = self.0.clone();
262        let task = message;
263
264        //Run agent
265        if agent.stream() {
266            let _ = agent.run_stream(task).await?;
267            Ok(())
268        } else {
269            let _ = agent.run(task).await?;
270            Ok(())
271        }
272    }
273}
274
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;
    use crate::actor::{LocalTransport, Topic, Transport};
    use crate::runtime::{Runtime, RuntimeError};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use crate::utils::BoxEventStream;
    use async_trait::async_trait;
    use futures::stream;
    use std::any::{Any, TypeId};
    use std::sync::Arc;
    use tokio::sync::{Mutex, mpsc};

    /// In-memory `Runtime` stub that records topic subscriptions and treats
    /// every other operation as a successful no-op.
    #[derive(Debug)]
    struct TestRuntime {
        /// `(topic name, topic type)` pairs seen by `subscribe_any`, in call order.
        subscribed: Arc<Mutex<Vec<(String, TypeId)>>>,
        /// Sender handed out by `tx()`; its receiver is dropped in `new`.
        tx: mpsc::Sender<Event>,
    }

    impl TestRuntime {
        fn new() -> Self {
            let (event_tx, _event_rx) = mpsc::channel(4);
            let subscribed = Arc::new(Mutex::new(Vec::new()));
            Self {
                subscribed,
                tx: event_tx,
            }
        }
    }

    #[async_trait]
    impl Runtime for TestRuntime {
        fn id(&self) -> autoagents_protocol::RuntimeID {
            autoagents_protocol::RuntimeID::new_v4()
        }

        async fn subscribe_any(
            &self,
            topic_name: &str,
            topic_type: TypeId,
            _actor: Arc<dyn crate::actor::AnyActor>,
        ) -> Result<(), RuntimeError> {
            self.subscribed
                .lock()
                .await
                .push((topic_name.to_owned(), topic_type));
            Ok(())
        }

        async fn publish_any(
            &self,
            _topic_name: &str,
            _topic_type: TypeId,
            _message: Arc<dyn Any + Send + Sync>,
        ) -> Result<(), RuntimeError> {
            Ok(())
        }

        fn tx(&self) -> mpsc::Sender<Event> {
            self.tx.clone()
        }

        async fn transport(&self) -> Arc<dyn Transport> {
            Arc::new(LocalTransport)
        }

        async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
            None
        }

        async fn subscribe_events(&self) -> BoxEventStream<Event> {
            Box::pin(stream::empty())
        }

        async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    /// Building without an LLM provider must fail with a build error.
    #[tokio::test]
    async fn test_actor_builder_requires_llm() {
        let test_runtime = Arc::new(TestRuntime::new());
        let outcome = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .runtime(test_runtime)
            .build()
            .await;
        assert!(matches!(outcome, Err(Error::AgentBuildError(_))));
    }

    /// Building without a runtime must fail with a build error.
    #[tokio::test]
    async fn test_actor_builder_requires_runtime() {
        let provider = Arc::new(MockLLMProvider);
        let outcome = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .llm(provider)
            .build()
            .await;
        assert!(matches!(outcome, Err(Error::AgentBuildError(_))));
    }

    /// Every topic registered on the builder is subscribed through the runtime.
    #[tokio::test]
    async fn test_actor_builder_subscribes_topics() {
        let provider = Arc::new(MockLLMProvider);
        let test_runtime = Arc::new(TestRuntime::new());

        let _handle = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .llm(provider)
            .runtime(test_runtime.clone())
            .subscribe(Topic::<Task>::new("jobs"))
            .build()
            .await
            .expect("build should succeed");

        let recorded = test_runtime.subscribed.lock().await;
        let topic_names: Vec<&str> = recorded.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(topic_names, ["jobs"]);
    }

    /// `tx()` reports `EmptyTx` once the agent's sender has been cleared.
    #[tokio::test]
    async fn test_actor_agent_tx_missing_returns_error() {
        let provider = Arc::new(MockLLMProvider);
        let (event_tx, _event_rx) = mpsc::channel(2);
        let mut agent = BaseAgent::<_, ActorAgent>::new(
            MockAgentImpl::new("agent", "desc"),
            provider,
            None,
            event_tx,
            false,
        )
        .await
        .unwrap();

        agent.tx = None;
        assert!(matches!(agent.tx(), Err(RunnableAgentError::EmptyTx)));
    }
}