1use crate::agent::executor::event_helper::EventHelper;
2use crate::agent::executor::turn_engine::{
3 TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, TurnEngineOutput, record_task_state,
4};
5use crate::agent::hooks::HookOutcome;
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use autoagents_llm::error::LLMError;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::ops::Deref;
18use std::pin::Pin;
19use std::sync::Arc;
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct BasicAgentOutput {
24 pub response: String,
25 pub done: bool,
26}
27
28impl From<BasicAgentOutput> for Value {
29 fn from(output: BasicAgentOutput) -> Self {
30 serde_json::to_value(output).unwrap_or(Value::Null)
31 }
32}
33impl From<BasicAgentOutput> for String {
34 fn from(output: BasicAgentOutput) -> Self {
35 output.response
36 }
37}
38
39impl BasicAgentOutput {
40 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
43 serde_json::from_str::<T>(&self.response)
44 }
45
46 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
49 where
50 T: for<'de> serde::Deserialize<'de>,
51 F: FnOnce(&str) -> T,
52 {
53 self.try_parse::<T>()
54 .unwrap_or_else(|_| fallback(&self.response))
55 }
56}
57
58#[derive(Debug, thiserror::Error)]
60pub enum BasicExecutorError {
61 #[error("LLM error: {0}")]
62 LLMError(
63 #[from]
64 #[source]
65 LLMError,
66 ),
67
68 #[error("Other error: {0}")]
69 Other(String),
70}
71
72impl From<TurnEngineError> for BasicExecutorError {
73 fn from(error: TurnEngineError) -> Self {
74 match error {
75 TurnEngineError::LLMError(err) => err.into(),
76 TurnEngineError::Aborted => {
77 BasicExecutorError::Other("Run aborted by hook".to_string())
78 }
79 TurnEngineError::Other(err) => BasicExecutorError::Other(err),
80 }
81 }
82}
83
84#[derive(Debug)]
89pub struct BasicAgent<T: AgentDeriveT> {
90 inner: Arc<T>,
91}
92
93impl<T: AgentDeriveT> Clone for BasicAgent<T> {
94 fn clone(&self) -> Self {
95 Self {
96 inner: Arc::clone(&self.inner),
97 }
98 }
99}
100
101impl<T: AgentDeriveT> BasicAgent<T> {
102 pub fn new(inner: T) -> Self {
103 Self {
104 inner: Arc::new(inner),
105 }
106 }
107}
108
109impl<T: AgentDeriveT> Deref for BasicAgent<T> {
110 type Target = T;
111
112 fn deref(&self) -> &Self::Target {
113 &self.inner
114 }
115}
116
117#[async_trait]
119impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
120 type Output = <T as AgentDeriveT>::Output;
121
122 fn description(&self) -> &str {
123 self.inner.description()
124 }
125
126 fn output_schema(&self) -> Option<Value> {
127 self.inner.output_schema()
128 }
129
130 fn name(&self) -> &str {
131 self.inner.name()
132 }
133
134 fn tools(&self) -> Vec<Box<dyn ToolT>> {
135 self.inner.tools()
136 }
137}
138
139#[async_trait]
140impl<T> AgentHooks for BasicAgent<T>
141where
142 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
143{
144 async fn on_agent_create(&self) {
145 self.inner.on_agent_create().await
146 }
147
148 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
149 self.inner.on_run_start(task, ctx).await
150 }
151
152 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
153 self.inner.on_run_complete(task, result, ctx).await
154 }
155
156 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
157 self.inner.on_turn_start(turn_index, ctx).await
158 }
159
160 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
161 self.inner.on_turn_complete(turn_index, ctx).await
162 }
163
164 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
165 self.inner.on_tool_call(tool_call, ctx).await
166 }
167
168 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
169 self.inner.on_tool_start(tool_call, ctx).await
170 }
171
172 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
173 self.inner.on_tool_result(tool_call, result, ctx).await
174 }
175
176 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
177 self.inner.on_tool_error(tool_call, err, ctx).await
178 }
179 async fn on_agent_shutdown(&self) {
180 self.inner.on_agent_shutdown().await
181 }
182}
183
184#[async_trait]
186impl<T: AgentDeriveT + AgentHooks> AgentExecutor for BasicAgent<T> {
187 type Output = BasicAgentOutput;
188 type Error = BasicExecutorError;
189
190 fn config(&self) -> ExecutorConfig {
191 ExecutorConfig { max_turns: 1 }
192 }
193
194 async fn execute(
195 &self,
196 task: &Task,
197 context: Arc<Context>,
198 ) -> Result<Self::Output, Self::Error> {
199 record_task_state(&context, task);
200 let tx_event = context.tx().ok();
201 EventHelper::send_task_started(
202 &tx_event,
203 task.submission_id,
204 context.config().id,
205 context.config().name.clone(),
206 task.prompt.clone(),
207 )
208 .await;
209
210 let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
211 let mut turn_state = engine.turn_state(&context);
212 let turn_result = engine
213 .run_turn(
214 self,
215 task,
216 &context,
217 &mut turn_state,
218 0,
219 self.config().max_turns,
220 )
221 .await?;
222
223 let output = extract_turn_output(turn_result);
224
225 Ok(BasicAgentOutput {
226 response: output.response,
227 done: true,
228 })
229 }
230
231 async fn execute_stream(
232 &self,
233 task: &Task,
234 context: Arc<Context>,
235 ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
236 {
237 record_task_state(&context, task);
238 let tx_event = context.tx().ok();
239 EventHelper::send_task_started(
240 &tx_event,
241 task.submission_id,
242 context.config().id,
243 context.config().name.clone(),
244 task.prompt.clone(),
245 )
246 .await;
247
248 let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
249 let mut turn_state = engine.turn_state(&context);
250 let context_clone = context.clone();
251 let task = task.clone();
252 let executor = self.clone();
253
254 let (tx, rx) = channel::<Result<BasicAgentOutput, BasicExecutorError>>(100);
255
256 spawn_future(async move {
257 let turn_stream = engine
258 .run_turn_stream(
259 executor,
260 &task,
261 context_clone.clone(),
262 &mut turn_state,
263 0,
264 1,
265 )
266 .await;
267
268 let mut final_response = String::default();
269 match turn_stream {
270 Ok(mut stream) => {
271 use futures::StreamExt;
272 while let Some(delta_result) = stream.next().await {
273 match delta_result {
274 Ok(TurnDelta::Text(content)) => {
275 let _ = tx
276 .send(Ok(BasicAgentOutput {
277 response: content,
278 done: false,
279 }))
280 .await;
281 }
282 Ok(TurnDelta::ReasoningContent(_)) => {}
283 Ok(TurnDelta::ToolResults(_)) => {}
284 Ok(TurnDelta::Done(result)) => {
285 let output = extract_turn_output(result);
286 final_response = output.response.clone();
287 let _ = tx
288 .send(Ok(BasicAgentOutput {
289 response: output.response,
290 done: true,
291 }))
292 .await;
293 break;
294 }
295 Err(err) => {
296 let _ = tx.send(Err(err.into())).await;
297 return;
298 }
299 }
300 }
301 }
302 Err(err) => {
303 let _ = tx.send(Err(err.into())).await;
304 return;
305 }
306 }
307
308 let tx_event = context_clone.tx().ok();
309 EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
310 let output = BasicAgentOutput {
311 response: final_response,
312 done: true,
313 };
314 let result =
315 serde_json::to_string_pretty(&output).unwrap_or_else(|_| output.response.clone());
316 EventHelper::send_task_completed(
317 &tx_event,
318 task.submission_id,
319 context_clone.config().id,
320 context_clone.config().name.clone(),
321 result,
322 )
323 .await;
324 });
325
326 Ok(receiver_into_stream(rx))
327 }
328}
329
330fn extract_turn_output(
331 result: crate::agent::executor::TurnResult<TurnEngineOutput>,
332) -> TurnEngineOutput {
333 match result {
334 crate::agent::executor::TurnResult::Complete(output) => output,
335 crate::agent::executor::TurnResult::Continue(Some(output)) => output,
336 crate::agent::executor::TurnResult::Continue(None) => TurnEngineOutput {
337 response: String::default(),
338 reasoning_content: String::default(),
339 tool_calls: Vec::default(),
340 },
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentConfig;
    use crate::agent::AgentDeriveT;
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider};
    use async_trait::async_trait;
    use autoagents_llm::chat::{StreamChoice, StreamDelta, StreamResponse};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Agent whose `on_run_start` hook always requests an abort.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = String;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Builds an agent config with a fresh id and no output schema.
    fn agent_config(name: &str, description: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: description.to_string(),
            output_schema: None,
        }
    }

    /// Builds a single-choice stream chunk without tool calls or usage data.
    fn stream_chunk(content: Option<&str>, reasoning_content: Option<&str>) -> StreamResponse {
        StreamResponse {
            choices: vec![StreamChoice {
                delta: StreamDelta {
                    content: content.map(String::from),
                    reasoning_content: reasoning_content.map(String::from),
                    tool_calls: None,
                },
            }],
            usage: None,
        }
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        let basic_agent = BasicAgent::new(MockAgentImpl::new(
            "test_agent",
            "Test agent description",
        ));

        let llm = Arc::new(MockLLMProvider {});
        let config = agent_config("test_agent", "Test agent description");
        let context_arc = Arc::new(Context::new(llm, None).with_config(config));

        let task = Task::new("Test task");
        let result = basic_agent.execute(&task, context_arc).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_basic_agent_metadata_and_output_conversion() {
        let basic_agent = BasicAgent::new(MockAgentImpl::new(
            "test_agent",
            "Test agent description",
        ));

        assert_eq!(basic_agent.config().max_turns, 1);

        let cloned = basic_agent.clone();
        assert_eq!(cloned.name(), "test_agent");
        assert_eq!(cloned.description(), "Test agent description");

        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };
        let as_value: Value = output.clone().into();
        assert_eq!(as_value["response"], "Test response");
        let as_string: String = output.into();
        assert_eq!(as_string, "Test response");
    }

    #[test]
    fn test_basic_agent_output_try_parse_success() {
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct Data {
            name: String,
            value: i32,
        }

        let output = BasicAgentOutput {
            response: r#"{"name":"test","value":42}"#.to_string(),
            done: true,
        };
        let parsed: Data = output.try_parse().unwrap();
        let expected = Data {
            name: "test".to_string(),
            value: 42,
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_basic_agent_output_try_parse_failure() {
        let output = BasicAgentOutput {
            response: "not json".to_string(),
            done: true,
        };
        assert!(output.try_parse::<serde_json::Value>().is_err());
    }

    #[test]
    fn test_basic_agent_output_parse_or_map_fallback() {
        let output = BasicAgentOutput {
            response: "plain text".to_string(),
            done: true,
        };
        let mapped: String = output.parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(mapped, "PLAIN TEXT");
    }

    #[test]
    fn test_basic_agent_output_parse_or_map_success() {
        let output = BasicAgentOutput {
            response: r#""hello""#.to_string(),
            done: true,
        };
        let mapped: String = output.parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(mapped, "hello");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let source = TurnEngineError::LLMError(LLMError::Generic("bad".to_string()));
        let err: BasicExecutorError = source.into();
        assert!(matches!(err, BasicExecutorError::LLMError(_)));
        assert!(err.to_string().contains("bad"));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err: BasicExecutorError = TurnEngineError::Aborted.into();
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("aborted"));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err: BasicExecutorError = TurnEngineError::Other("misc".to_string()).into();
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("misc"));
    }

    #[test]
    fn test_extract_turn_output_complete() {
        let engine_output = crate::agent::executor::turn_engine::TurnEngineOutput {
            response: "done".to_string(),
            reasoning_content: String::default(),
            tool_calls: Vec::new(),
        };
        let result = crate::agent::executor::TurnResult::Complete(engine_output);
        let output = extract_turn_output(result);
        assert_eq!(output.response, "done");
    }

    #[test]
    fn test_extract_turn_output_continue_some() {
        let engine_output = crate::agent::executor::turn_engine::TurnEngineOutput {
            response: "partial".to_string(),
            reasoning_content: String::default(),
            tool_calls: Vec::new(),
        };
        let result = crate::agent::executor::TurnResult::Continue(Some(engine_output));
        let output = extract_turn_output(result);
        assert_eq!(output.response, "partial");
    }

    #[test]
    fn test_extract_turn_output_continue_none() {
        let result = crate::agent::executor::TurnResult::Continue(None);
        let output = extract_turn_output(result);
        assert!(output.response.is_empty());
        assert!(output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_basic_agent_execute_stream_returns_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![
                stream_chunk(Some("Hello "), None),
                stream_chunk(Some("world"), None),
            ],
            ..ConfigurableLLMProvider::default()
        });

        let basic_agent = BasicAgent::new(MockAgentImpl::new("stream_agent", "desc"));
        let config = agent_config("stream_agent", "desc");
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("Test task");

        let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();
        let mut final_output = None;
        // Stop at the first item flagged as done; earlier items are text deltas.
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_basic_agent_execute_stream_ignores_reasoning_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![
                stream_chunk(None, Some("plan")),
                stream_chunk(Some("done"), None),
            ],
            ..ConfigurableLLMProvider::default()
        });

        let basic_agent = BasicAgent::new(MockAgentImpl::new("stream_agent_reasoning", "desc"));
        let config = agent_config("stream_agent_reasoning", "desc");
        let context = Arc::new(Context::new(llm, None).with_config(config));
        let task = Task::new("Test task");

        let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();
        let mut outputs = Vec::new();
        while let Some(item) = stream.next().await {
            outputs.push(item.unwrap());
        }

        // One text delta plus the final item; the reasoning chunk yields nothing.
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].response, "done");
        assert!(!outputs[0].done);
        assert_eq!(outputs[1].response, "done");
        assert!(outputs[1].done);
    }

    #[tokio::test]
    async fn test_basic_agent_run_aborts_on_hook() {
        use crate::agent::AgentBuilder;
        use crate::agent::direct::DirectAgent;
        use crate::agent::error::RunnableAgentError;

        let agent = BasicAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let err = handle
            .agent
            .run(Task::new("abort"))
            .await
            .expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
    }

    #[tokio::test]
    async fn test_basic_agent_run_stream_aborts_on_hook() {
        use crate::agent::AgentBuilder;
        use crate::agent::direct::DirectAgent;
        use crate::agent::error::RunnableAgentError;

        let agent = BasicAgent::new(AbortAgent);
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(agent)
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let Err(err) = handle.agent.run_stream(Task::new("abort")).await else {
            panic!("expected abort");
        };
        assert!(matches!(err, RunnableAgentError::Abort));
    }
}