1use crate::agent::executor::event_helper::EventHelper;
2use crate::agent::executor::turn_engine::{
3 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, TurnEngineOutput, record_task_state,
4};
5use crate::agent::hooks::HookOutcome;
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19
/// Output produced by [`BasicAgent`] when run through [`AgentExecutor`].
///
/// `execute` always returns a single value with `done == true`. In `execute_stream`,
/// each text delta is emitted with `done == false`, and the turn's final response is
/// emitted with `done == true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BasicAgentOutput {
    /// Text from the model: a delta for intermediate streaming chunks, the turn's
    /// final response otherwise.
    pub response: String,
    /// `true` for the final output of a run, `false` for intermediate streaming chunks.
    pub done: bool,
}
26
27impl From<BasicAgentOutput> for Value {
28 fn from(output: BasicAgentOutput) -> Self {
29 serde_json::to_value(output).unwrap_or(Value::Null)
30 }
31}
32impl From<BasicAgentOutput> for String {
33 fn from(output: BasicAgentOutput) -> Self {
34 output.response
35 }
36}
37
38impl BasicAgentOutput {
39 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
42 serde_json::from_str::<T>(&self.response)
43 }
44
45 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
48 where
49 T: for<'de> serde::Deserialize<'de>,
50 F: FnOnce(&str) -> T,
51 {
52 self.try_parse::<T>()
53 .unwrap_or_else(|_| fallback(&self.response))
54 }
55}
56
/// Errors returned by [`BasicAgent`]'s [`AgentExecutor`] implementation.
#[derive(Debug, thiserror::Error)]
pub enum BasicExecutorError {
    /// Failure reported by the LLM layer (converted from `TurnEngineError::LLMError`).
    #[error("LLM error: {0}")]
    LLMError(String),

    /// Any other failure, including a run aborted by the `on_run_start` hook or by the
    /// turn engine (`TurnEngineError::Aborted`).
    #[error("Other error: {0}")]
    Other(String),
}
66
67impl From<TurnEngineError> for BasicExecutorError {
68 fn from(error: TurnEngineError) -> Self {
69 match error {
70 TurnEngineError::LLMError(err) => BasicExecutorError::LLMError(err),
71 TurnEngineError::Aborted => {
72 BasicExecutorError::Other("Run aborted by hook".to_string())
73 }
74 TurnEngineError::Other(err) => BasicExecutorError::Other(err),
75 }
76 }
77}
78
/// Executor wrapper around a user-defined agent `T` that runs a single turn
/// (its `config()` reports `max_turns: 1`).
///
/// The wrapped agent lives behind an [`Arc`], so cloning a `BasicAgent` only clones the
/// handle and every clone shares the same inner agent. Agent metadata and lifecycle
/// hooks are delegated to `T`.
#[derive(Debug)]
pub struct BasicAgent<T: AgentDeriveT> {
    inner: Arc<T>,
}
87
88impl<T: AgentDeriveT> Clone for BasicAgent<T> {
89 fn clone(&self) -> Self {
90 Self {
91 inner: Arc::clone(&self.inner),
92 }
93 }
94}
95
96impl<T: AgentDeriveT> BasicAgent<T> {
97 pub fn new(inner: T) -> Self {
98 Self {
99 inner: Arc::new(inner),
100 }
101 }
102}
103
104impl<T: AgentDeriveT> Deref for BasicAgent<T> {
105 type Target = T;
106
107 fn deref(&self) -> &Self::Target {
108 &self.inner
109 }
110}
111
112#[async_trait]
114impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
115 type Output = <T as AgentDeriveT>::Output;
116
117 fn description(&self) -> &str {
118 self.inner.description()
119 }
120
121 fn output_schema(&self) -> Option<Value> {
122 self.inner.output_schema()
123 }
124
125 fn name(&self) -> &str {
126 self.inner.name()
127 }
128
129 fn tools(&self) -> Vec<Box<dyn ToolT>> {
130 self.inner.tools()
131 }
132}
133
134#[async_trait]
135impl<T> AgentHooks for BasicAgent<T>
136where
137 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
138{
139 async fn on_agent_create(&self) {
140 self.inner.on_agent_create().await
141 }
142
143 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
144 self.inner.on_run_start(task, ctx).await
145 }
146
147 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
148 self.inner.on_run_complete(task, result, ctx).await
149 }
150
151 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
152 self.inner.on_turn_start(turn_index, ctx).await
153 }
154
155 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
156 self.inner.on_turn_complete(turn_index, ctx).await
157 }
158
159 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
160 self.inner.on_tool_call(tool_call, ctx).await
161 }
162
163 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
164 self.inner.on_tool_start(tool_call, ctx).await
165 }
166
167 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
168 self.inner.on_tool_result(tool_call, result, ctx).await
169 }
170
171 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
172 self.inner.on_tool_error(tool_call, err, ctx).await
173 }
174 async fn on_agent_shutdown(&self) {
175 self.inner.on_agent_shutdown().await
176 }
177}
178
179#[async_trait]
181impl<T: AgentDeriveT + AgentHooks> AgentExecutor for BasicAgent<T> {
182 type Output = BasicAgentOutput;
183 type Error = BasicExecutorError;
184
185 fn config(&self) -> ExecutorConfig {
186 ExecutorConfig { max_turns: 1 }
187 }
188
189 async fn execute(
190 &self,
191 task: &Task,
192 context: Arc<Context>,
193 ) -> Result<Self::Output, Self::Error> {
194 if self.on_run_start(task, &context).await == HookOutcome::Abort {
195 return Err(BasicExecutorError::Other("Run aborted by hook".to_string()));
196 }
197
198 record_task_state(&context, task);
199 let tx_event = context.tx().ok();
200 EventHelper::send_task_started(
201 &tx_event,
202 task.submission_id,
203 context.config().id,
204 context.config().name.clone(),
205 task.prompt.clone(),
206 )
207 .await;
208
209 let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
210 let mut turn_state = engine.turn_state(&context);
211 let turn_result = engine
212 .run_turn(
213 self,
214 task,
215 &context,
216 &mut turn_state,
217 0,
218 self.config().max_turns,
219 )
220 .await?;
221
222 let output = extract_turn_output(turn_result);
223
224 Ok(BasicAgentOutput {
225 response: output.response,
226 done: true,
227 })
228 }
229
230 async fn execute_stream(
231 &self,
232 task: &Task,
233 context: Arc<Context>,
234 ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
235 {
236 if self.on_run_start(task, &context).await == HookOutcome::Abort {
237 return Err(BasicExecutorError::Other("Run aborted by hook".to_string()));
238 }
239
240 record_task_state(&context, task);
241 let tx_event = context.tx().ok();
242 EventHelper::send_task_started(
243 &tx_event,
244 task.submission_id,
245 context.config().id,
246 context.config().name.clone(),
247 task.prompt.clone(),
248 )
249 .await;
250
251 let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
252 let mut turn_state = engine.turn_state(&context);
253 let context_clone = context.clone();
254 let task = task.clone();
255 let executor = self.clone();
256
257 let (tx, rx) = channel::<Result<BasicAgentOutput, BasicExecutorError>>(100);
258
259 spawn_future(async move {
260 let turn_stream = engine
261 .run_turn_stream(
262 executor,
263 &task,
264 context_clone.clone(),
265 &mut turn_state,
266 0,
267 1,
268 )
269 .await;
270
271 let mut final_response = String::default();
272
273 match turn_stream {
274 Ok(mut stream) => {
275 use futures::StreamExt;
276 while let Some(delta_result) = stream.next().await {
277 match delta_result {
278 Ok(TurnDelta::Text(content)) => {
279 let _ = tx
280 .send(Ok(BasicAgentOutput {
281 response: content,
282 done: false,
283 }))
284 .await;
285 }
286 Ok(TurnDelta::ToolResults(_)) => {}
287 Ok(TurnDelta::Done(result)) => {
288 let output = extract_turn_output(result);
289 final_response = output.response.clone();
290 let _ = tx
291 .send(Ok(BasicAgentOutput {
292 response: output.response,
293 done: true,
294 }))
295 .await;
296 break;
297 }
298 Err(err) => {
299 let _ = tx.send(Err(err.into())).await;
300 return;
301 }
302 }
303 }
304 }
305 Err(err) => {
306 let _ = tx.send(Err(err.into())).await;
307 return;
308 }
309 }
310
311 let tx_event = context_clone.tx().ok();
312 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
313 let output = BasicAgentOutput {
314 response: final_response,
315 done: true,
316 };
317 let result =
318 serde_json::to_string_pretty(&output).unwrap_or_else(|_| output.response.clone());
319 EventHelper::send_task_completed(
320 &tx_event,
321 task.submission_id,
322 context_clone.config().id,
323 context_clone.config().name.clone(),
324 result,
325 )
326 .await;
327 });
328
329 Ok(receiver_into_stream(rx))
330 }
331}
332
333fn extract_turn_output(
334 result: crate::agent::executor::TurnResult<TurnEngineOutput>,
335) -> TurnEngineOutput {
336 match result {
337 crate::agent::executor::TurnResult::Complete(output) => output,
338 crate::agent::executor::TurnResult::Continue(Some(output)) => output,
339 crate::agent::executor::TurnResult::Continue(None) => TurnEngineOutput {
340 response: String::default(),
341 tool_calls: Vec::default(),
342 },
343 }
344}
345
// Unit tests for `BasicAgent`, `BasicAgentOutput`, the `TurnEngineError` conversion
// and `extract_turn_output`. LLM providers and agents come from the crate's shared
// test mocks (`crate::tests`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider, TestAgentOutput};
    use async_trait::async_trait;
    use autoagents_llm::chat::{StreamChoice, StreamDelta, StreamResponse};
    use std::sync::Arc;

    // Minimal agent whose `on_run_start` hook always returns `HookOutcome::Abort`.
    // It is used to check that both `execute` and `execute_stream` fail before
    // running a turn.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        // Every other hook keeps the trait's default implementation.
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    // Non-streaming execution returns the mock provider's response, marked as done.
    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;

        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test_agent".to_string(),
            description: "Test agent description".to_string(),
            output_schema: None,
        };

        let context = Context::new(llm, None).with_config(config);

        let context_arc = Arc::new(context);
        let task = Task::new("Test task");
        let result = basic_agent.execute(&task, context_arc).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    // The config reports a single turn, clones delegate metadata to the inner agent,
    // and both `From` conversions expose the response text.
    #[test]
    fn test_basic_agent_metadata_and_output_conversion() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let config = basic_agent.config();
        assert_eq!(config.max_turns, 1);

        let cloned = basic_agent.clone();
        assert_eq!(cloned.name(), "test_agent");
        assert_eq!(cloned.description(), "Test agent description");

        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };
        let value: Value = output.clone().into();
        assert_eq!(value["response"], "Test response");
        let string: String = output.into();
        assert_eq!(string, "Test response");
    }

    // A JSON object response deserializes into a matching struct.
    #[test]
    fn test_basic_agent_output_try_parse_success() {
        let output = BasicAgentOutput {
            response: r#"{"name":"test","value":42}"#.to_string(),
            done: true,
        };
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct Data {
            name: String,
            value: i32,
        }
        let parsed: Data = output.try_parse().unwrap();
        assert_eq!(
            parsed,
            Data {
                name: "test".to_string(),
                value: 42
            }
        );
    }

    // Non-JSON text fails to parse, even into a generic `serde_json::Value`.
    #[test]
    fn test_basic_agent_output_try_parse_failure() {
        let output = BasicAgentOutput {
            response: "not json".to_string(),
            done: true,
        };
        let result = output.try_parse::<serde_json::Value>();
        assert!(result.is_err());
    }

    // When parsing fails, `parse_or_map` applies the fallback to the raw text.
    #[test]
    fn test_basic_agent_output_parse_or_map_fallback() {
        let output = BasicAgentOutput {
            response: "plain text".to_string(),
            done: true,
        };
        let result: String = output.parse_or_map(|s| s.to_uppercase());
        assert_eq!(result, "PLAIN TEXT");
    }

    // A JSON string literal parses directly, so the fallback is not applied.
    #[test]
    fn test_basic_agent_output_parse_or_map_success() {
        let output = BasicAgentOutput {
            response: r#""hello""#.to_string(),
            done: true,
        };
        let result: String = output.parse_or_map(|s| s.to_uppercase());
        assert_eq!(result, "hello");
    }

    // `TurnEngineError::LLMError` maps to `BasicExecutorError::LLMError` and keeps the message.
    #[test]
    fn test_error_from_turn_engine_llm() {
        let err: BasicExecutorError = TurnEngineError::LLMError("bad".to_string()).into();
        assert!(matches!(err, BasicExecutorError::LLMError(_)));
        assert!(err.to_string().contains("bad"));
    }

    // `TurnEngineError::Aborted` maps to `Other` with an "aborted" message.
    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err: BasicExecutorError = TurnEngineError::Aborted.into();
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("aborted"));
    }

    // `TurnEngineError::Other` maps to `Other` and keeps the message.
    #[test]
    fn test_error_from_turn_engine_other() {
        let err: BasicExecutorError = TurnEngineError::Other("misc".to_string()).into();
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("misc"));
    }

    // `Complete` yields its carried output.
    #[test]
    fn test_extract_turn_output_complete() {
        let result = crate::agent::executor::TurnResult::Complete(
            crate::agent::executor::turn_engine::TurnEngineOutput {
                response: "done".to_string(),
                tool_calls: Vec::new(),
            },
        );
        let output = extract_turn_output(result);
        assert_eq!(output.response, "done");
    }

    // `Continue(Some(_))` yields its carried output.
    #[test]
    fn test_extract_turn_output_continue_some() {
        let result = crate::agent::executor::TurnResult::Continue(Some(
            crate::agent::executor::turn_engine::TurnEngineOutput {
                response: "partial".to_string(),
                tool_calls: Vec::new(),
            },
        ));
        let output = extract_turn_output(result);
        assert_eq!(output.response, "partial");
    }

    // `Continue(None)` yields an empty output.
    #[test]
    fn test_extract_turn_output_continue_none() {
        let result = crate::agent::executor::TurnResult::Continue(None);
        let output = extract_turn_output(result);
        assert!(output.response.is_empty());
        assert!(output.tool_calls.is_empty());
    }

    // Two streamed chunks ("Hello " and "world") should appear in the final
    // `done == true` item as "Hello world".
    #[tokio::test]
    async fn test_basic_agent_execute_stream_returns_output() {
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;
        use futures::StreamExt;

        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![
                StreamResponse {
                    choices: vec![StreamChoice {
                        delta: StreamDelta {
                            content: Some("Hello ".to_string()),
                            tool_calls: None,
                        },
                    }],
                    usage: None,
                },
                StreamResponse {
                    choices: vec![StreamChoice {
                        delta: StreamDelta {
                            content: Some("world".to_string()),
                            tool_calls: None,
                        },
                    }],
                    usage: None,
                },
            ],
            ..ConfigurableLLMProvider::default()
        });

        let mock_agent = MockAgentImpl::new("stream_agent", "desc");
        let basic_agent = BasicAgent::new(mock_agent);
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "stream_agent".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("Test task");

        let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();
        let mut final_output = None;
        // Skip intermediate deltas and keep the first item flagged as done.
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    // An aborting `on_run_start` hook makes `execute` return an "aborted" error.
    #[tokio::test]
    async fn test_basic_agent_execute_aborts_on_hook() {
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;

        let agent = BasicAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "abort_agent".to_string(),
            description: "abort".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("abort");

        let err = agent.execute(&task, context).await.unwrap_err();
        assert!(err.to_string().contains("aborted"));
    }

    // An aborting `on_run_start` hook makes `execute_stream` fail up front, before
    // any stream is returned.
    #[tokio::test]
    async fn test_basic_agent_execute_stream_aborts_on_hook() {
        use crate::agent::{AgentConfig, Context};
        use autoagents_protocol::ActorID;

        let agent = BasicAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "abort_agent".to_string(),
            description: "abort".to_string(),
            output_schema: None,
        };
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("abort");

        // `.err().unwrap()` rather than `unwrap_err()`: the boxed stream in the `Ok`
        // position does not implement `Debug`.
        let err = agent.execute_stream(&task, context).await.err().unwrap();
        assert!(err.to_string().contains("aborted"));
    }
}