1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::executor::event_helper::EventHelper;
6use crate::agent::hooks::AgentHooks;
7use crate::agent::state::AgentState;
8use crate::agent::task::Task;
9use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
10use crate::channel::Sender;
11use crate::error::Error;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15use autoagents_protocol::Event;
16#[cfg(target_arch = "wasm32")]
17use futures::SinkExt;
18use futures::Stream;
19#[cfg(not(target_arch = "wasm32"))]
20use ractor::Actor;
21#[cfg(not(target_arch = "wasm32"))]
22use ractor::{ActorProcessingErr, ActorRef};
23use serde_json::Value;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27pub struct ActorAgent {}
32
33impl AgentType for ActorAgent {
34 fn type_name() -> &'static str {
35 "protocol_agent"
36 }
37}
38
39#[cfg(not(target_arch = "wasm32"))]
43#[derive(Clone)]
44pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
45 pub agent: Arc<BaseAgent<T, ActorAgent>>,
46 pub actor_ref: ActorRef<Task>,
47}
48
49#[cfg(not(target_arch = "wasm32"))]
50impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
51 pub fn addr(&self) -> ActorRef<Task> {
53 self.actor_ref.clone()
54 }
55
56 pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
59 self.agent.clone()
60 }
61}
62
63#[cfg(not(target_arch = "wasm32"))]
64impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("AgentHandle")
67 .field("agent", &self.agent)
68 .finish()
69 }
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73#[derive(Debug)]
74pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
75 pub Arc<BaseAgent<T, ActorAgent>>,
76);
77
78#[cfg(not(target_arch = "wasm32"))]
79impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
80
81#[cfg(not(target_arch = "wasm32"))]
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
83where
84 T: Send + Sync + 'static,
85 serde_json::Value: From<<T as AgentExecutor>::Output>,
86 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87{
88 pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
90 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
91 "LLM provider is required".to_string(),
92 ))?;
93 let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
94 "Runtime should be defined".into(),
95 ))?;
96 let tx = runtime.tx();
97
98 let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
99 BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
100 );
101
102 let agent_actor = AgentActor(agent.clone());
104 let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
105 .await
106 .map_err(AgentBuildError::SpawnError)?
107 .0;
108
109 for topic in self.subscribed_topics {
111 runtime.subscribe(&topic, actor_ref.clone()).await?;
112 }
113
114 Ok(ActorAgentHandle { agent, actor_ref })
115 }
116
117 pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
118 self.subscribed_topics.push(topic);
119 self
120 }
121}
122
123#[cfg(not(target_arch = "wasm32"))]
124impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
125 pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
126 self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
127 }
128
129 pub async fn run(
130 self: Arc<Self>,
131 task: Task,
132 ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
133 where
134 Value: From<<T as AgentExecutor>::Output>,
135 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
136 {
137 let submission_id = task.submission_id;
138 let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
139 let tx_event = Some(tx.clone());
140
141 let context = self.create_context();
142
143 let hook_outcome = self.inner.on_run_start(&task, &context).await;
145 match hook_outcome {
146 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
147 HookOutcome::Continue => {}
148 }
149
150 match self.inner().execute(&task, context.clone()).await {
152 Ok(output) => {
153 let value: Value = output.clone().into();
154 #[cfg(not(target_arch = "wasm32"))]
155 EventHelper::send_task_completed_value(
156 &tx_event,
157 submission_id,
158 self.id,
159 self.name().to_string(),
160 &value,
161 )
162 .await
163 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
164
165 let agent_out: <T as AgentDeriveT>::Output = output.into();
167
168 self.inner
170 .on_run_complete(&task, &agent_out, &context)
171 .await;
172
173 Ok(agent_out)
174 }
175 Err(e) => {
176 #[cfg(not(target_arch = "wasm32"))]
177 EventHelper::send_task_error(&tx_event, submission_id, self.id, e.to_string())
178 .await;
179 Err(RunnableAgentError::ExecutorError(e.to_string()))
180 }
181 }
182 }
183
184 pub async fn run_stream(
185 self: Arc<Self>,
186 task: Task,
187 ) -> Result<
188 std::pin::Pin<
189 Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
190 >,
191 RunnableAgentError,
192 >
193 where
194 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
195 {
196 let context = self.create_context();
198
199 match self.inner().execute_stream(&task, context).await {
201 Ok(stream) => {
202 use futures::StreamExt;
203 let transformed_stream = stream.map(move |result| {
205 match result {
206 Ok(output) => Ok(output.into()),
207 Err(e) => {
208 let error_msg = e.to_string();
210 Err(RunnableAgentError::ExecutorError(error_msg))
211 }
212 }
213 });
214
215 Ok(Box::pin(transformed_stream))
216 }
217 Err(e) => {
218 Err(RunnableAgentError::ExecutorError(e.to_string()))
220 }
221 }
222 }
223}
224
225#[cfg(not(target_arch = "wasm32"))]
226#[async_trait]
227impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
228where
229 T: Send + Sync + 'static,
230 serde_json::Value: From<<T as AgentExecutor>::Output>,
231 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
232{
233 type Msg = Task;
234 type State = AgentState;
235 type Arguments = ();
236
237 async fn pre_start(
238 &self,
239 _myself: ActorRef<Self::Msg>,
240 _args: Self::Arguments,
241 ) -> Result<Self::State, ActorProcessingErr> {
242 Ok(AgentState::new())
243 }
244
245 async fn post_stop(
246 &self,
247 _myself: ActorRef<Self::Msg>,
248 _state: &mut Self::State,
249 ) -> Result<(), ActorProcessingErr> {
250 self.0.inner().on_agent_shutdown().await;
252 Ok(())
253 }
254
255 async fn handle(
256 &self,
257 _myself: ActorRef<Self::Msg>,
258 message: Self::Msg,
259 _state: &mut Self::State,
260 ) -> Result<(), ActorProcessingErr> {
261 let agent = self.0.clone();
262 let task = message;
263
264 if agent.stream() {
266 let _ = agent.run_stream(task).await?;
267 Ok(())
268 } else {
269 let _ = agent.run(task).await?;
270 Ok(())
271 }
272 }
273}
274
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;
    use crate::actor::{LocalTransport, Topic, Transport};
    use crate::runtime::{Runtime, RuntimeError};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use crate::utils::BoxEventStream;
    use async_trait::async_trait;
    use futures::stream;
    use std::any::{Any, TypeId};
    use std::sync::Arc;
    use tokio::sync::{Mutex, mpsc};

    /// Minimal `Runtime` test double: it records topic subscriptions and otherwise does
    /// nothing.
    #[derive(Debug)]
    struct TestRuntime {
        /// One `(topic name, topic TypeId)` entry per `subscribe_any` call, in call order.
        subscribed: Arc<Mutex<Vec<(String, TypeId)>>>,
        /// Event sender handed to agents. Its receiver is dropped in `new`.
        tx: mpsc::Sender<Event>,
    }

    impl TestRuntime {
        fn new() -> Self {
            let (tx, _rx) = mpsc::channel(4);
            let subscribed = Arc::new(Mutex::new(Vec::new()));
            Self { subscribed, tx }
        }
    }

    #[async_trait]
    impl Runtime for TestRuntime {
        fn id(&self) -> autoagents_protocol::RuntimeID {
            // A fresh id per call; none of these tests depend on it being stable.
            autoagents_protocol::RuntimeID::new_v4()
        }

        async fn subscribe_any(
            &self,
            topic_name: &str,
            topic_type: TypeId,
            _actor: Arc<dyn crate::actor::AnyActor>,
        ) -> Result<(), RuntimeError> {
            self.subscribed
                .lock()
                .await
                .push((topic_name.to_owned(), topic_type));
            Ok(())
        }

        async fn publish_any(
            &self,
            _topic_name: &str,
            _topic_type: TypeId,
            _message: Arc<dyn Any + Send + Sync>,
        ) -> Result<(), RuntimeError> {
            Ok(())
        }

        fn tx(&self) -> mpsc::Sender<Event> {
            self.tx.clone()
        }

        async fn transport(&self) -> Arc<dyn Transport> {
            Arc::new(LocalTransport)
        }

        async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
            None
        }

        async fn subscribe_events(&self) -> BoxEventStream<Event> {
            Box::pin(stream::empty())
        }

        async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_actor_builder_requires_llm() {
        let runtime = Arc::new(TestRuntime::new());
        let outcome = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .runtime(runtime)
            .build()
            .await;
        assert!(matches!(outcome, Err(Error::AgentBuildError(_))));
    }

    #[tokio::test]
    async fn test_actor_builder_requires_runtime() {
        let llm = Arc::new(MockLLMProvider);
        let outcome = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .llm(llm)
            .build()
            .await;
        assert!(matches!(outcome, Err(Error::AgentBuildError(_))));
    }

    #[tokio::test]
    async fn test_actor_builder_subscribes_topics() {
        let llm = Arc::new(MockLLMProvider);
        let runtime = Arc::new(TestRuntime::new());

        let _handle = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .llm(llm)
            .runtime(runtime.clone())
            .subscribe(Topic::<Task>::new("jobs"))
            .build()
            .await
            .expect("build should succeed");

        let recorded = runtime.subscribed.lock().await;
        let topic_names: Vec<&str> = recorded.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(topic_names, ["jobs"]);
    }

    #[tokio::test]
    async fn test_actor_agent_tx_missing_returns_error() {
        let llm = Arc::new(MockLLMProvider);
        let (tx, _rx) = mpsc::channel(2);
        let mut agent =
            BaseAgent::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"), llm, None, tx, false)
                .await
                .unwrap();

        // Simulate an agent whose event sender was never configured.
        agent.tx = None;
        assert!(matches!(agent.tx(), Err(RunnableAgentError::EmptyTx)));
    }
}