autoagents_core/agent/prebuilt/executor/
basic.rs1use crate::agent::executor::event_helper::EventHelper;
2use crate::agent::executor::turn_engine::{
3 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, TurnEngineOutput, record_task_state,
4};
5use crate::agent::hooks::HookOutcome;
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BasicAgentOutput {
23 pub response: String,
24 pub done: bool,
25}
26
27impl From<BasicAgentOutput> for Value {
28 fn from(output: BasicAgentOutput) -> Self {
29 serde_json::to_value(output).unwrap_or(Value::Null)
30 }
31}
32impl From<BasicAgentOutput> for String {
33 fn from(output: BasicAgentOutput) -> Self {
34 output.response
35 }
36}
37
38impl BasicAgentOutput {
39 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
42 serde_json::from_str::<T>(&self.response)
43 }
44
45 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
48 where
49 T: for<'de> serde::Deserialize<'de>,
50 F: FnOnce(&str) -> T,
51 {
52 self.try_parse::<T>()
53 .unwrap_or_else(|_| fallback(&self.response))
54 }
55}
56
57#[derive(Debug, thiserror::Error)]
59pub enum BasicExecutorError {
60 #[error("LLM error: {0}")]
61 LLMError(String),
62
63 #[error("Other error: {0}")]
64 Other(String),
65}
66
67impl From<TurnEngineError> for BasicExecutorError {
68 fn from(error: TurnEngineError) -> Self {
69 match error {
70 TurnEngineError::LLMError(err) => BasicExecutorError::LLMError(err),
71 TurnEngineError::Aborted => {
72 BasicExecutorError::Other("Run aborted by hook".to_string())
73 }
74 TurnEngineError::Other(err) => BasicExecutorError::Other(err),
75 }
76 }
77}
78
79#[derive(Debug)]
84pub struct BasicAgent<T: AgentDeriveT> {
85 inner: Arc<T>,
86}
87
88impl<T: AgentDeriveT> Clone for BasicAgent<T> {
89 fn clone(&self) -> Self {
90 Self {
91 inner: Arc::clone(&self.inner),
92 }
93 }
94}
95
96impl<T: AgentDeriveT> BasicAgent<T> {
97 pub fn new(inner: T) -> Self {
98 Self {
99 inner: Arc::new(inner),
100 }
101 }
102}
103
104impl<T: AgentDeriveT> Deref for BasicAgent<T> {
105 type Target = T;
106
107 fn deref(&self) -> &Self::Target {
108 &self.inner
109 }
110}
111
112#[async_trait]
114impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
115 type Output = <T as AgentDeriveT>::Output;
116
117 fn description(&self) -> &str {
118 self.inner.description()
119 }
120
121 fn output_schema(&self) -> Option<Value> {
122 self.inner.output_schema()
123 }
124
125 fn name(&self) -> &str {
126 self.inner.name()
127 }
128
129 fn tools(&self) -> Vec<Box<dyn ToolT>> {
130 self.inner.tools()
131 }
132}
133
134#[async_trait]
135impl<T> AgentHooks for BasicAgent<T>
136where
137 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
138{
139 async fn on_agent_create(&self) {
140 self.inner.on_agent_create().await
141 }
142
143 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
144 self.inner.on_run_start(task, ctx).await
145 }
146
147 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
148 self.inner.on_run_complete(task, result, ctx).await
149 }
150
151 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
152 self.inner.on_turn_start(turn_index, ctx).await
153 }
154
155 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
156 self.inner.on_turn_complete(turn_index, ctx).await
157 }
158
159 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
160 self.inner.on_tool_call(tool_call, ctx).await
161 }
162
163 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
164 self.inner.on_tool_start(tool_call, ctx).await
165 }
166
167 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
168 self.inner.on_tool_result(tool_call, result, ctx).await
169 }
170
171 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
172 self.inner.on_tool_error(tool_call, err, ctx).await
173 }
174 async fn on_agent_shutdown(&self) {
175 self.inner.on_agent_shutdown().await
176 }
177}
178
179#[async_trait]
181impl<T: AgentDeriveT + AgentHooks> AgentExecutor for BasicAgent<T> {
182 type Output = BasicAgentOutput;
183 type Error = BasicExecutorError;
184
185 fn config(&self) -> ExecutorConfig {
186 ExecutorConfig { max_turns: 1 }
187 }
188
189 async fn execute(
190 &self,
191 task: &Task,
192 context: Arc<Context>,
193 ) -> Result<Self::Output, Self::Error> {
194 if self.on_run_start(task, &context).await == HookOutcome::Abort {
195 return Err(BasicExecutorError::Other("Run aborted by hook".to_string()));
196 }
197
198 record_task_state(&context, task);
199 let tx_event = context.tx().ok();
200 EventHelper::send_task_started(
201 &tx_event,
202 task.submission_id,
203 context.config().id,
204 context.config().name.clone(),
205 task.prompt.clone(),
206 )
207 .await;
208
209 let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
210 let mut turn_state = engine.turn_state(&context);
211 let turn_result = engine
212 .run_turn(
213 self,
214 task,
215 &context,
216 &mut turn_state,
217 0,
218 self.config().max_turns,
219 )
220 .await?;
221
222 let output = extract_turn_output(turn_result);
223
224 EventHelper::send_task_completed(
225 &tx_event,
226 task.submission_id,
227 context.config().id,
228 context.config().name.clone(),
229 output.response.clone(),
230 )
231 .await;
232
233 Ok(BasicAgentOutput {
234 response: output.response,
235 done: true,
236 })
237 }
238
239 async fn execute_stream(
240 &self,
241 task: &Task,
242 context: Arc<Context>,
243 ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
244 {
245 if self.on_run_start(task, &context).await == HookOutcome::Abort {
246 return Err(BasicExecutorError::Other("Run aborted by hook".to_string()));
247 }
248
249 record_task_state(&context, task);
250 let tx_event = context.tx().ok();
251 EventHelper::send_task_started(
252 &tx_event,
253 task.submission_id,
254 context.config().id,
255 context.config().name.clone(),
256 task.prompt.clone(),
257 )
258 .await;
259
260 let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
261 let mut turn_state = engine.turn_state(&context);
262 let context_clone = context.clone();
263 let task = task.clone();
264 let executor = self.clone();
265
266 let (tx, rx) = channel::<Result<BasicAgentOutput, BasicExecutorError>>(100);
267
268 spawn_future(async move {
269 let turn_stream = engine
270 .run_turn_stream(
271 executor,
272 &task,
273 context_clone.clone(),
274 &mut turn_state,
275 0,
276 1,
277 )
278 .await;
279
280 let mut final_response = String::new();
281
282 match turn_stream {
283 Ok(mut stream) => {
284 use futures::StreamExt;
285 while let Some(delta_result) = stream.next().await {
286 match delta_result {
287 Ok(TurnDelta::Text(content)) => {
288 let _ = tx
289 .send(Ok(BasicAgentOutput {
290 response: content,
291 done: false,
292 }))
293 .await;
294 }
295 Ok(TurnDelta::ToolResults(_)) => {}
296 Ok(TurnDelta::Done(result)) => {
297 let output = extract_turn_output(result);
298 final_response = output.response.clone();
299 let _ = tx
300 .send(Ok(BasicAgentOutput {
301 response: output.response,
302 done: true,
303 }))
304 .await;
305 break;
306 }
307 Err(err) => {
308 let _ = tx.send(Err(err.into())).await;
309 return;
310 }
311 }
312 }
313 }
314 Err(err) => {
315 let _ = tx.send(Err(err.into())).await;
316 return;
317 }
318 }
319
320 let tx_event = context_clone.tx().ok();
321 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
322 EventHelper::send_task_completed(
323 &tx_event,
324 task.submission_id,
325 context_clone.config().id,
326 context_clone.config().name.clone(),
327 final_response,
328 )
329 .await;
330 });
331
332 Ok(receiver_into_stream(rx))
333 }
334}
335
336fn extract_turn_output(
337 result: crate::agent::executor::TurnResult<TurnEngineOutput>,
338) -> TurnEngineOutput {
339 match result {
340 crate::agent::executor::TurnResult::Complete(output) => output,
341 crate::agent::executor::TurnResult::Continue(Some(output)) => output,
342 crate::agent::executor::TurnResult::Continue(None) => TurnEngineOutput {
343 response: String::new(),
344 tool_calls: Vec::new(),
345 },
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    #[test]
    fn test_basic_agent_creation() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        assert_eq!(basic_agent.name(), "test_agent");
        assert_eq!(basic_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_clone() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);
        let cloned_agent = basic_agent.clone();

        assert_eq!(cloned_agent.name(), "test_agent");
        assert_eq!(cloned_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_output_conversions() {
        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };

        // JSON conversion must keep both fields, not just produce an object.
        let value: Value = output.clone().into();
        assert!(value.is_object());
        assert_eq!(value["response"].as_str(), Some("Test response"));
        assert_eq!(value["done"].as_bool(), Some(true));

        let string: String = output.into();
        assert_eq!(string, "Test response");
    }

    #[test]
    fn test_basic_agent_output_try_parse() {
        let output = BasicAgentOutput {
            response: r#"{"answer": 42}"#.to_string(),
            done: true,
        };
        let parsed: Value = output.try_parse().expect("response is valid JSON");
        assert_eq!(parsed["answer"].as_u64(), Some(42));

        let plain = BasicAgentOutput {
            response: "plain text".to_string(),
            done: false,
        };
        assert!(plain.try_parse::<Value>().is_err());
    }

    #[test]
    fn test_basic_agent_output_parse_or_map() {
        let json = BasicAgentOutput {
            response: "[1, 2, 3]".to_string(),
            done: true,
        };
        let parsed: Vec<u32> = json.parse_or_map(|_| Vec::new());
        assert_eq!(parsed, vec![1, 2, 3]);

        // The fallback receives the raw response text when parsing fails.
        let text = BasicAgentOutput {
            response: "not json".to_string(),
            done: true,
        };
        let mapped: Vec<u32> = text.parse_or_map(|raw| vec![raw.len() as u32]);
        assert_eq!(mapped, vec![8]);
    }

    #[test]
    fn test_turn_engine_error_conversion() {
        let llm: BasicExecutorError = TurnEngineError::LLMError("boom".to_string()).into();
        assert!(matches!(llm, BasicExecutorError::LLMError(ref m) if m == "boom"));

        let aborted: BasicExecutorError = TurnEngineError::Aborted.into();
        assert_eq!(aborted.to_string(), "Other error: Run aborted by hook");

        let other: BasicExecutorError = TurnEngineError::Other("oops".to_string()).into();
        assert!(matches!(other, BasicExecutorError::Other(ref m) if m == "oops"));
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;

        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test_agent".to_string(),
            description: "Test agent description".to_string(),
            output_schema: None,
        };

        let context = Context::new(llm, None).with_config(config);

        let context_arc = Arc::new(context);
        let task = Task::new("Test task");
        let result = basic_agent.execute(&task, context_arc).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_executor_config() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let config = basic_agent.config();
        assert_eq!(config.max_turns, 1);
    }
}