1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::executor::event_helper::EventHelper;
6use crate::agent::hooks::AgentHooks;
7use crate::agent::state::AgentState;
8use crate::agent::task::Task;
9use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
10use crate::channel::Sender;
11use crate::error::Error;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15use autoagents_protocol::Event;
16#[cfg(target_arch = "wasm32")]
17use futures::SinkExt;
18use futures::Stream;
19#[cfg(not(target_arch = "wasm32"))]
20use ractor::Actor;
21#[cfg(not(target_arch = "wasm32"))]
22use ractor::{ActorProcessingErr, ActorRef};
23use serde_json::Value;
24use std::fmt::Debug;
25use std::sync::Arc;
26
27pub struct ActorAgent {}
32
33impl AgentType for ActorAgent {
34 fn type_name() -> &'static str {
35 "protocol_agent"
36 }
37}
38
39#[cfg(not(target_arch = "wasm32"))]
43#[derive(Clone)]
44pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
45 pub agent: Arc<BaseAgent<T, ActorAgent>>,
46 pub actor_ref: ActorRef<Task>,
47}
48
49#[cfg(not(target_arch = "wasm32"))]
50impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
51 pub fn addr(&self) -> ActorRef<Task> {
53 self.actor_ref.clone()
54 }
55
56 pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
59 self.agent.clone()
60 }
61}
62
63#[cfg(not(target_arch = "wasm32"))]
64impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("AgentHandle")
67 .field("agent", &self.agent)
68 .finish()
69 }
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73#[derive(Debug)]
74pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
75 pub Arc<BaseAgent<T, ActorAgent>>,
76);
77
78#[cfg(not(target_arch = "wasm32"))]
79impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
80
81#[cfg(not(target_arch = "wasm32"))]
82impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
83where
84 T: Send + Sync + 'static,
85 serde_json::Value: From<<T as AgentExecutor>::Output>,
86 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
87 <T as AgentExecutor>::Error: Into<RunnableAgentError>,
88{
89 pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
91 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
92 "LLM provider is required".to_string(),
93 ))?;
94 let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
95 "Runtime should be defined".into(),
96 ))?;
97 let tx = runtime.tx();
98
99 let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
100 BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
101 );
102
103 let agent_actor = AgentActor(agent.clone());
105 let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
106 .await
107 .map_err(AgentBuildError::SpawnError)?
108 .0;
109
110 for topic in self.subscribed_topics {
112 runtime.subscribe(&topic, actor_ref.clone()).await?;
113 }
114
115 Ok(ActorAgentHandle { agent, actor_ref })
116 }
117
118 pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
119 self.subscribed_topics.push(topic);
120 self
121 }
122}
123
124#[cfg(not(target_arch = "wasm32"))]
125impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
126 pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
127 self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
128 }
129
130 pub async fn run(
131 self: Arc<Self>,
132 task: Task,
133 ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
134 where
135 Value: From<<T as AgentExecutor>::Output>,
136 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
137 <T as AgentExecutor>::Error: Into<RunnableAgentError>,
138 {
139 let submission_id = task.submission_id;
140 let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
141 let tx_event = Some(tx.clone());
142
143 let context = self.create_context();
144
145 let hook_outcome = self.inner.on_run_start(&task, &context).await;
147 match hook_outcome {
148 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
149 HookOutcome::Continue => {}
150 }
151
152 match self.inner().execute(&task, context.clone()).await {
154 Ok(output) => {
155 let value: Value = output.clone().into();
156 #[cfg(not(target_arch = "wasm32"))]
157 EventHelper::send_task_completed_value(
158 &tx_event,
159 submission_id,
160 self.id,
161 self.name().to_string(),
162 &value,
163 )
164 .await
165 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
166
167 let agent_out: <T as AgentDeriveT>::Output = output.into();
169
170 self.inner
172 .on_run_complete(&task, &agent_out, &context)
173 .await;
174
175 Ok(agent_out)
176 }
177 Err(e) => {
178 #[cfg(not(target_arch = "wasm32"))]
179 EventHelper::send_task_error(&tx_event, submission_id, self.id, e.to_string())
180 .await;
181 Err(e.into())
182 }
183 }
184 }
185
186 pub async fn run_stream(
187 self: Arc<Self>,
188 task: Task,
189 ) -> Result<
190 std::pin::Pin<
191 Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
192 >,
193 RunnableAgentError,
194 >
195 where
196 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
197 <T as AgentExecutor>::Error: Into<RunnableAgentError>,
198 {
199 let context = self.create_context();
201
202 match self.inner().execute_stream(&task, context).await {
204 Ok(stream) => {
205 use futures::StreamExt;
206 let transformed_stream = stream.map(move |result| {
208 match result {
209 Ok(output) => Ok(output.into()),
210 Err(e) => {
211 Err(e.into())
213 }
214 }
215 });
216
217 Ok(Box::pin(transformed_stream))
218 }
219 Err(e) => {
220 Err(e.into())
222 }
223 }
224 }
225}
226
227#[cfg(not(target_arch = "wasm32"))]
228#[async_trait]
229impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
230where
231 T: Send + Sync + 'static,
232 serde_json::Value: From<<T as AgentExecutor>::Output>,
233 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
234 <T as AgentExecutor>::Error: Into<RunnableAgentError>,
235{
236 type Msg = Task;
237 type State = AgentState;
238 type Arguments = ();
239
240 async fn pre_start(
241 &self,
242 _myself: ActorRef<Self::Msg>,
243 _args: Self::Arguments,
244 ) -> Result<Self::State, ActorProcessingErr> {
245 Ok(AgentState::new())
246 }
247
248 async fn post_stop(
249 &self,
250 _myself: ActorRef<Self::Msg>,
251 _state: &mut Self::State,
252 ) -> Result<(), ActorProcessingErr> {
253 self.0.inner().on_agent_shutdown().await;
255 Ok(())
256 }
257
258 async fn handle(
259 &self,
260 _myself: ActorRef<Self::Msg>,
261 message: Self::Msg,
262 _state: &mut Self::State,
263 ) -> Result<(), ActorProcessingErr> {
264 let agent = self.0.clone();
265 let task = message;
266
267 if agent.stream() {
269 let _ = agent.run_stream(task).await?;
270 Ok(())
271 } else {
272 let _ = agent.run(task).await?;
273 Ok(())
274 }
275 }
276}
277
#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
    use super::*;
    use crate::actor::{LocalTransport, Topic, Transport};
    use crate::runtime::{Runtime, RuntimeError};
    use crate::tests::{MockAgentImpl, MockLLMProvider};
    use crate::utils::BoxEventStream;
    use async_trait::async_trait;
    use futures::stream;
    use std::any::{Any, TypeId};
    use std::sync::Arc;
    use tokio::sync::{Mutex, mpsc};

    /// Minimal runtime that records topic subscriptions. All other operations
    /// are no-ops.
    #[derive(Debug)]
    struct TestRuntime {
        /// `(topic name, message TypeId)` pairs passed to `subscribe_any`.
        subscribed: Arc<Mutex<Vec<(String, TypeId)>>>,
        /// Event sender handed to agents. Its receiver is dropped in `new`.
        tx: mpsc::Sender<Event>,
    }

    impl TestRuntime {
        fn new() -> Self {
            let (event_tx, _event_rx) = mpsc::channel(4);
            TestRuntime {
                subscribed: Arc::new(Mutex::new(Vec::new())),
                tx: event_tx,
            }
        }
    }

    #[async_trait]
    impl Runtime for TestRuntime {
        fn id(&self) -> autoagents_protocol::RuntimeID {
            autoagents_protocol::RuntimeID::new_v4()
        }

        async fn subscribe_any(
            &self,
            topic_name: &str,
            topic_type: TypeId,
            _actor: Arc<dyn crate::actor::AnyActor>,
        ) -> Result<(), RuntimeError> {
            self.subscribed
                .lock()
                .await
                .push((topic_name.to_owned(), topic_type));
            Ok(())
        }

        async fn publish_any(
            &self,
            _topic_name: &str,
            _topic_type: TypeId,
            _message: Arc<dyn Any + Send + Sync>,
        ) -> Result<(), RuntimeError> {
            Ok(())
        }

        fn tx(&self) -> mpsc::Sender<Event> {
            self.tx.clone()
        }

        async fn transport(&self) -> Arc<dyn Transport> {
            Arc::new(LocalTransport)
        }

        async fn take_event_receiver(&self) -> Option<BoxEventStream<Event>> {
            None
        }

        async fn subscribe_events(&self) -> BoxEventStream<Event> {
            Box::pin(stream::empty())
        }

        async fn run(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }

        async fn stop(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_actor_builder_requires_llm() {
        let runtime = Arc::new(TestRuntime::new());
        let outcome = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .runtime(runtime)
            .build()
            .await;
        assert!(matches!(outcome, Err(Error::AgentBuildError(_))));
    }

    #[tokio::test]
    async fn test_actor_builder_requires_runtime() {
        let provider = Arc::new(MockLLMProvider);
        let outcome = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .llm(provider)
            .build()
            .await;
        assert!(matches!(outcome, Err(Error::AgentBuildError(_))));
    }

    #[tokio::test]
    async fn test_actor_builder_subscribes_topics() {
        let provider = Arc::new(MockLLMProvider);
        let runtime = Arc::new(TestRuntime::new());

        let _handle = AgentBuilder::<_, ActorAgent>::new(MockAgentImpl::new("agent", "desc"))
            .llm(provider)
            .runtime(runtime.clone())
            .subscribe(Topic::<Task>::new("jobs"))
            .build()
            .await
            .expect("build should succeed");

        let recorded = runtime.subscribed.lock().await;
        let names: Vec<&str> = recorded.iter().map(|(name, _)| name.as_str()).collect();
        assert_eq!(names, ["jobs"]);
    }

    #[tokio::test]
    async fn test_actor_agent_tx_missing_returns_error() {
        let provider = Arc::new(MockLLMProvider);
        let (event_tx, _event_rx) = mpsc::channel(2);
        let mut agent = BaseAgent::<_, ActorAgent>::new(
            MockAgentImpl::new("agent", "desc"),
            provider,
            None,
            event_tx,
            false,
        )
        .await
        .unwrap();

        agent.tx = None;
        assert!(matches!(agent.tx(), Err(RunnableAgentError::EmptyTx)));
    }
}