autoagents_core/agent/
actor.rs1#[cfg(not(target_arch = "wasm32"))]
2use crate::actor::Topic;
3use crate::agent::base::AgentType;
4use crate::agent::error::{AgentBuildError, RunnableAgentError};
5use crate::agent::hooks::AgentHooks;
6use crate::agent::state::AgentState;
7use crate::agent::task::Task;
8use crate::agent::{AgentBuilder, AgentDeriveT, AgentExecutor, BaseAgent, HookOutcome};
9use crate::channel::Sender;
10use crate::error::Error;
11use crate::protocol::Event;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::runtime::TypedRuntime;
14use async_trait::async_trait;
15#[cfg(target_arch = "wasm32")]
16use futures::SinkExt;
17use futures::Stream;
18#[cfg(not(target_arch = "wasm32"))]
19use ractor::Actor;
20#[cfg(not(target_arch = "wasm32"))]
21use ractor::{ActorProcessingErr, ActorRef};
22use serde_json::Value;
23use std::fmt::Debug;
24use std::sync::Arc;
25
26pub struct ActorAgent {}
27
28impl AgentType for ActorAgent {
29 fn type_name() -> &'static str {
30 "protocol_agent"
31 }
32}
33
34#[cfg(not(target_arch = "wasm32"))]
36#[derive(Clone)]
37pub struct ActorAgentHandle<T: AgentDeriveT + AgentExecutor + AgentHooks + Send + Sync> {
38 pub agent: Arc<BaseAgent<T, ActorAgent>>,
39 pub actor_ref: ActorRef<Task>,
40}
41
42#[cfg(not(target_arch = "wasm32"))]
43impl<T: AgentDeriveT + AgentExecutor + AgentHooks> ActorAgentHandle<T> {
44 pub fn addr(&self) -> ActorRef<Task> {
46 self.actor_ref.clone()
47 }
48
49 pub fn agent(&self) -> Arc<BaseAgent<T, ActorAgent>> {
51 self.agent.clone()
52 }
53}
54
55#[cfg(not(target_arch = "wasm32"))]
56impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Debug for ActorAgentHandle<T> {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("AgentHandle")
59 .field("agent", &self.agent)
60 .finish()
61 }
62}
63
64#[cfg(not(target_arch = "wasm32"))]
65#[derive(Debug)]
66pub struct AgentActor<T: AgentDeriveT + AgentExecutor + AgentHooks>(
67 pub Arc<BaseAgent<T, ActorAgent>>,
68);
69
70#[cfg(not(target_arch = "wasm32"))]
71impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentActor<T> {}
72
73#[cfg(not(target_arch = "wasm32"))]
74impl<T: AgentDeriveT + AgentExecutor + AgentHooks> AgentBuilder<T, ActorAgent>
75where
76 T: Send + Sync + 'static,
77 serde_json::Value: From<<T as AgentExecutor>::Output>,
78 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
79{
80 pub async fn build(self) -> Result<ActorAgentHandle<T>, Error> {
82 let llm = self.llm.ok_or(AgentBuildError::BuildFailure(
83 "LLM provider is required".to_string(),
84 ))?;
85 let runtime = self.runtime.ok_or(AgentBuildError::BuildFailure(
86 "Runtime should be defined".into(),
87 ))?;
88 let tx = runtime.tx();
89
90 let agent: Arc<BaseAgent<T, ActorAgent>> = Arc::new(
91 BaseAgent::<T, ActorAgent>::new(self.inner, llm, self.memory, tx, self.stream).await?,
92 );
93
94 let agent_actor = AgentActor(agent.clone());
96 let actor_ref = Actor::spawn(Some(agent_actor.0.name().into()), agent_actor, ())
97 .await
98 .map_err(AgentBuildError::SpawnError)?
99 .0;
100
101 for topic in self.subscribed_topics {
103 runtime.subscribe(&topic, actor_ref.clone()).await?;
104 }
105
106 Ok(ActorAgentHandle { agent, actor_ref })
107 }
108
109 pub fn subscribe(mut self, topic: Topic<Task>) -> Self {
110 self.subscribed_topics.push(topic);
111 self
112 }
113}
114
115#[cfg(not(target_arch = "wasm32"))]
116impl<T: AgentDeriveT + AgentExecutor + AgentHooks> BaseAgent<T, ActorAgent> {
117 pub fn tx(&self) -> Result<Sender<Event>, RunnableAgentError> {
118 self.tx.clone().ok_or(RunnableAgentError::EmptyTx)
119 }
120
121 pub async fn run(
122 self: Arc<Self>,
123 task: Task,
124 ) -> Result<<T as AgentDeriveT>::Output, RunnableAgentError>
125 where
126 Value: From<<T as AgentExecutor>::Output>,
127 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
128 {
129 let submission_id = task.submission_id;
130 let tx = self.tx().map_err(|_| RunnableAgentError::EmptyTx)?;
131
132 let context = self.create_context();
133
134 let hook_outcome = self.inner.on_run_start(&task, &context).await;
136 match hook_outcome {
137 HookOutcome::Abort => return Err(RunnableAgentError::Abort),
138 HookOutcome::Continue => {}
139 }
140
141 match self.inner().execute(&task, context.clone()).await {
143 Ok(output) => {
144 let value: Value = output.clone().into();
145
146 #[cfg(not(target_arch = "wasm32"))]
147 tx.send(Event::TaskComplete {
148 sub_id: submission_id,
149 actor_id: self.id,
150 actor_name: self.name().to_string(),
151 result: serde_json::to_string_pretty(&value)
152 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?,
153 })
154 .await
155 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
156
157 let agent_out: <T as AgentDeriveT>::Output = output.into();
159
160 self.inner
162 .on_run_complete(&task, &agent_out, &context)
163 .await;
164
165 Ok(agent_out)
166 }
167 Err(e) => {
168 #[cfg(not(target_arch = "wasm32"))]
169 tx.send(Event::TaskError {
170 sub_id: submission_id,
171 actor_id: self.id,
172 error: e.to_string(),
173 })
174 .await
175 .map_err(|e| RunnableAgentError::ExecutorError(e.to_string()))?;
176 Err(RunnableAgentError::ExecutorError(e.to_string()))
177 }
178 }
179 }
180
181 pub async fn run_stream(
182 self: Arc<Self>,
183 task: Task,
184 ) -> Result<
185 std::pin::Pin<
186 Box<dyn Stream<Item = Result<<T as AgentDeriveT>::Output, RunnableAgentError>> + Send>,
187 >,
188 RunnableAgentError,
189 >
190 where
191 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
192 {
193 let context = self.create_context();
195
196 match self.inner().execute_stream(&task, context).await {
198 Ok(stream) => {
199 use futures::StreamExt;
200 let transformed_stream = stream.map(move |result| {
202 match result {
203 Ok(output) => Ok(output.into()),
204 Err(e) => {
205 let error_msg = e.to_string();
207 Err(RunnableAgentError::ExecutorError(error_msg))
208 }
209 }
210 });
211
212 Ok(Box::pin(transformed_stream))
213 }
214 Err(e) => {
215 Err(RunnableAgentError::ExecutorError(e.to_string()))
217 }
218 }
219 }
220}
221
222#[cfg(not(target_arch = "wasm32"))]
223#[async_trait]
224impl<T: AgentDeriveT + AgentExecutor + AgentHooks> Actor for AgentActor<T>
225where
226 T: Send + Sync + 'static,
227 serde_json::Value: From<<T as AgentExecutor>::Output>,
228 <T as AgentDeriveT>::Output: From<<T as AgentExecutor>::Output>,
229{
230 type Msg = Task;
231 type State = AgentState;
232 type Arguments = ();
233
234 async fn pre_start(
235 &self,
236 _myself: ActorRef<Self::Msg>,
237 _args: Self::Arguments,
238 ) -> Result<Self::State, ActorProcessingErr> {
239 Ok(AgentState::new())
240 }
241
242 async fn post_stop(
243 &self,
244 _myself: ActorRef<Self::Msg>,
245 _state: &mut Self::State,
246 ) -> Result<(), ActorProcessingErr> {
247 self.0.inner().on_agent_shutdown().await;
249 Ok(())
250 }
251
252 async fn handle(
253 &self,
254 _myself: ActorRef<Self::Msg>,
255 message: Self::Msg,
256 _state: &mut Self::State,
257 ) -> Result<(), ActorProcessingErr> {
258 let agent = self.0.clone();
259 let task = message;
260
261 if agent.stream() {
263 let _ = agent.run_stream(task).await?;
264 Ok(())
265 } else {
266 let _ = agent.run(task).await?;
267 Ok(())
268 }
269 }
270}