1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::hooks::AgentHooks;
6use crate::agent::state::AgentState;
7use crate::agent::task::Task;
8use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
9use crate::channel::Sender;
10use crate::error::Error;
11#[cfg(not(target_arch = "wasm32"))]
12use crate::runtime::TypedRuntime;
13use async_trait::async_trait;
14use autoagents_protocol::Event;
15#[cfg(target_arch = "wasm32")]
16use futures::SinkExt;
17use futures::Stream;
18#[cfg(not(target_arch = "wasm32"))]
19use ractor::Actor;
20#[cfg(not(target_arch = "wasm32"))]
21use ractor::{ActorProcessingErr, ActorRef};
22use serde_json::Value;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26pub struct ActorAgent {}
31
32impl AgentType for ActorAgent {
33 fn type_name() -> &'static str {
34 "protocol_agent"
35 }
36}
37
38#[cfg(not(target_arch = "wasm32"))]
42#[derive(Clone)]
43pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
44 pub agent: Arc<BaseAgent<T, ActorAgent>>,
45 pub actor_ref: ActorRef<Task>,
46}
47
48#[cfg(not(target_arch = "wasm32"))]
49impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
50 pub fn addr(&self) -> ActorRef<Task> {
52 self.actor_ref.clone()
53 }
54
55 pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
58 self.agent.clone()
59 }
60}
61
62#[cfg(not(target_arch = "wasm32"))]
63impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("AgentHandle")
66 .field("agent", &self.agent)
67 .finish()
68 }
69}
70
71#[cfg(not(target_arch = "wasm32"))]
72#[derive(Debug)]
73pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
74 pub Arc<BaseAgent<T, ActorAgent>>,
75);
76
77#[cfg(not(target_arch = "wasm32"))]
78impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
79
80#[cfg(not(target_arch = "wasm32"))]
81impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
82where
83 T: Send + Sync + 'static,
84 serde_json::Value: From<<T as AgentExecutor>::Output>,
85 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
86{
87 pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
89 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
90 "LLM provider is required".to_string(),
91 ))?;
92 let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
93 "Runtime should be defined".into(),
94 ))?;
95 let tx = runtime.tx();
96
97 let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
98 BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
99 );
100
101 let agent_actor = AgentActor(agent.clone());
103 let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
104 .await
105 .map_err(AgentBuildError::SpawnError)?
106 .0;
107
108 for topic in self.subscribed_topics {
110 runtime.subscribe(&topic, actor_ref.clone()).await?;
111 }
112
113 Ok(ActorAgentHandle { agent, actor_ref })
114 }
115
116 pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
117 self.subscribed_topics.push(topic);
118 self
119 }
120}
121
122#[cfg(not(target_arch = "wasm32"))]
123impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
124 pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
125 self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
126 }
127
128 pub async fn run(
129 self: Arc<Self>,
130 task: Task,
131 ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
132 where
133 Value: From<<T as AgentExecutor>::Output>,
134 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
135 {
136 let submission_id = task.submission_id;
137 let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
138
139 let context = self.create_context();
140
141 let hook_outcome = self.inner.on_run_start(&task, &context).await;
143 match hook_outcome {
144 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
145 HookOutcome::Continue => {}
146 }
147
148 match self.inner().execute(&task, context.clone()).await {
150 Ok(output) => {
151 let value: Value = output.clone().into();
152
153 #[cfg(not(target_arch = "wasm32"))]
154 tx.send(Event::TaskComplete {
155 sub_id: submission_id,
156 actor_id: self.id,
157 actor_name: self.name().to_string(),
158 result: serde_json::to_string_pretty(&value)
159 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?,
160 })
161 .await
162 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
163
164 let agent_out: <T as AgentDeriveT>::Output = output.into();
166
167 self.inner
169 .on_run_complete(&task, &agent_out, &context)
170 .await;
171
172 Ok(agent_out)
173 }
174 Err(e) => {
175 #[cfg(not(target_arch = "wasm32"))]
176 tx.send(Event::TaskError {
177 sub_id: submission_id,
178 actor_id: self.id,
179 error: e.to_string(),
180 })
181 .await
182 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
183 Err(RunnableAgentError::ExecutorError(e.to_string()))
184 }
185 }
186 }
187
188 pub async fn run_stream(
189 self: Arc<Self>,
190 task: Task,
191 ) -> Result<
192 std::pin::Pin<
193 Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
194 >,
195 RunnableAgentError,
196 >
197 where
198 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
199 {
200 let context = self.create_context();
202
203 match self.inner().execute_stream(&task, context).await {
205 Ok(stream) => {
206 use futures::StreamExt;
207 let transformed_stream = stream.map(move |result| {
209 match result {
210 Ok(output) => Ok(output.into()),
211 Err(e) => {
212 let error_msg = e.to_string();
214 Err(RunnableAgentError::ExecutorError(error_msg))
215 }
216 }
217 });
218
219 Ok(Box::pin(transformed_stream))
220 }
221 Err(e) => {
222 Err(RunnableAgentError::ExecutorError(e.to_string()))
224 }
225 }
226 }
227}
228
229#[cfg(not(target_arch = "wasm32"))]
230#[async_trait]
231impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
232where
233 T: Send + Sync + 'static,
234 serde_json::Value: From<<T as AgentExecutor>::Output>,
235 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
236{
237 type Msg = Task;
238 type State = AgentState;
239 type Arguments = ();
240
241 async fn pre_start(
242 &self,
243 _myself: ActorRef<Self::Msg>,
244 _args: Self::Arguments,
245 ) -> Result<Self::State, ActorProcessingErr> {
246 Ok(AgentState::new())
247 }
248
249 async fn post_stop(
250 &self,
251 _myself: ActorRef<Self::Msg>,
252 _state: &mut Self::State,
253 ) -> Result<(), ActorProcessingErr> {
254 self.0.inner().on_agent_shutdown().await;
256 Ok(())
257 }
258
259 async fn handle(
260 &self,
261 _myself: ActorRef<Self::Msg>,
262 message: Self::Msg,
263 _state: &mut Self::State,
264 ) -> Result<(), ActorProcessingErr> {
265 let agent = self.0.clone();
266 let task = message;
267
268 if agent.stream() {
270 let _ = agent.run_stream(task).await?;
271 Ok(())
272 } else {
273 let _ = agent.run(task).await?;
274 Ok(())
275 }
276 }
277}
278
// Unit tests for the actor-backed `AgentBuilder::build` and `BaseAgent::tx`.
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;
    use crate::actor::{LocalTransport, Topic, Transport};
    use crate::runtime::{Runtime, RuntimeError};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use crate::utils::BoxEventStream;
    use async_trait::async_trait;
    use futures::stream;
    use std::any::{Any, TypeId};
    use std::sync::Arc;
    use tokio::sync::{Mutex, mpsc};

    /// Minimal `Runtime` stub: records every `subscribe_any` call and otherwise
    /// does nothing (publishing is a no-op, there are no events).
    #[derive(Debug)]
    struct TestRuntime {
        // `(topic_name, topic_type)` pairs passed to `subscribe_any`, in call order.
        subscribed: Arc<Mutex<Vec<(String, TypeId)>>>,
        // Sender returned from `tx()`; `build` hands it to the agent.
        tx: mpsc::Sender<Event>,
    }

    impl TestRuntime {
        fn new() -> Self {
            // `_rx` is dropped when `new` returns, so sends on `tx` will fail.
            // Presumably fine because these tests never run a task; revisit if one does.
            let (tx, _rx) = mpsc::channel(4);
            Self {
                subscribed: Arc::new(Mutex::new(Vec::new())),
                tx,
            }
        }
    }

    #[async_trait]
    impl Runtime for TestRuntime {
        // Returns a fresh random id on every call; tests do not rely on a stable id.
        fn id(&self) -> autoagents_protocol::RuntimeID {
            autoagents_protocol::RuntimeID::new_v4()
        }

        // Records the subscription instead of wiring up the actor.
        async fn subscribe_any(
            &self,
            topic_name: &str,
            topic_type: TypeId,
            _actor: Arc<dyn crate::actor::AnyActor>,
        ) -> Result<(), RuntimeError> {
            let mut subscribed = self.subscribed.lock().await;
            subscribed.push((topic_name.to_string(), topic_type));
            Ok(())
        }

        async fn publish_any(
            &self,
            _topic_name: &str,
            _topic_type: TypeId,
            _message: Arc<dyn Any + Send + Sync>,
        ) -> Result<(), RuntimeError> {
            Ok(())
        }

        fn tx(&self) -> mpsc::Sender<Event> {
            self.tx.clone()
        }

        async fn transport(&self) -> Arc<dyn Transport> {
            Arc::new(LocalTransport)
        }

        async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
            None
        }

        async fn subscribe_events(&self) -> BoxEventStream<Event> {
            Box::pin(stream::empty())
        }

        async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    // `build` must fail with an `AgentBuildError` when no LLM provider is set.
    #[tokio::test]
    async fn test_actor_builder_requires_llm() {
        let mock = MockAgentImpl::new("agent", "desc");
        let runtime = Arc::new(TestRuntime::new());
        let err = AgentBuilder::<_, ActorAgent>::new(mock)
            .runtime(runtime)
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, Error::AgentBuildError(_)));
    }

    // `build` must fail with an `AgentBuildError` when no runtime is set.
    #[tokio::test]
    async fn test_actor_builder_requires_runtime() {
        let mock = MockAgentImpl::new("agent", "desc");
        let llm = Arc::new(MockLLMProvider);
        let err = AgentBuilder::<_, ActorAgent>::new(mock)
            .llm(llm)
            .build()
            .await
            .unwrap_err();
        assert!(matches!(err, Error::AgentBuildError(_)));
    }

    // Topics registered with `subscribe` reach the runtime's `subscribe_any`
    // during `build`. This test spawns a real actor named "agent".
    #[tokio::test]
    async fn test_actor_builder_subscribes_topics() {
        let mock = MockAgentImpl::new("agent", "desc");
        let llm = Arc::new(MockLLMProvider);
        let runtime = Arc::new(TestRuntime::new());
        let topic = Topic::<Task>::new("jobs");

        let _handle = AgentBuilder::<_, ActorAgent>::new(mock)
            .llm(llm)
            .runtime(runtime.clone())
            .subscribe(topic)
            .build()
            .await
            .expect("build should succeed");

        let subscribed = runtime.subscribed.lock().await;
        assert_eq!(subscribed.len(), 1);
        assert_eq!(subscribed[0].0, "jobs");
    }

    // `BaseAgent::tx` reports `EmptyTx` once the sender has been cleared.
    #[tokio::test]
    async fn test_actor_agent_tx_missing_returns_error() {
        let mock = MockAgentImpl::new("agent", "desc");
        let llm = Arc::new(MockLLMProvider);
        let (tx, _rx) = mpsc::channel(2);
        let mut agent = BaseAgent::<_, ActorAgent>::new(mock, llm, None, tx, false)
            .await
            .unwrap();
        agent.tx = None;
        let err = agent.tx().unwrap_err();
        assert!(matches!(err, RunnableAgentError::EmptyTx));
    }
}