Skip to main content

autoagents_core/agent/prebuilt/executor/
basic.rs

1use crate::agent::executor::event_helper::EventHelper;
2use crate::agent::executor::turn_engine::{
3    TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, TurnEngineOutput, record_task_state,
4};
5use crate::agent::hooks::HookOutcome;
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use autoagents_llm::error::LLMError;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17use std::ops::Deref;
18use std::pin::Pin;
19use std::sync::Arc;
20
21/// Output of the Basic executor
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct BasicAgentOutput {
24    pub response: String,
25    pub done: bool,
26}
27
28impl From<BasicAgentOutput> for Value {
29    fn from(output: BasicAgentOutput) -> Self {
30        serde_json::to_value(output).unwrap_or(Value::Null)
31    }
32}
33impl From<BasicAgentOutput> for String {
34    fn from(output: BasicAgentOutput) -> Self {
35        output.response
36    }
37}
38
39impl BasicAgentOutput {
40    /// Try to parse the response string as structured JSON of type `T`.
41    /// Returns `serde_json::Error` if parsing fails.
42    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
43        serde_json::from_str::<T>(&self.response)
44    }
45
46    /// Parse the response string as structured JSON of type `T`, or map the raw
47    /// text into `T` using the provided fallback function if parsing fails.
48    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
49    where
50        T: for<'de> serde::Deserialize<'de>,
51        F: FnOnce(&str) -> T,
52    {
53        self.try_parse::<T>()
54            .unwrap_or_else(|_| fallback(&self.response))
55    }
56}
57
58/// Error type for Basic executor
59#[derive(Debug, thiserror::Error)]
60pub enum BasicExecutorError {
61    #[error("LLM error: {0}")]
62    LLMError(
63        #[from]
64        #[source]
65        LLMError,
66    ),
67
68    #[error("Other error: {0}")]
69    Other(String),
70}
71
72impl From<TurnEngineError> for BasicExecutorError {
73    fn from(error: TurnEngineError) -> Self {
74        match error {
75            TurnEngineError::LLMError(err) => err.into(),
76            TurnEngineError::Aborted => {
77                BasicExecutorError::Other("Run aborted by hook".to_string())
78            }
79            TurnEngineError::Other(err) => BasicExecutorError::Other(err),
80        }
81    }
82}
83
84/// Wrapper type for the single-turn Basic executor.
85///
86/// Use `BasicAgent<T>` when you want a single request/response interaction
87/// with optional streaming but without tool calling or multi-turn loops.
88#[derive(Debug)]
89pub struct BasicAgent<T: AgentDeriveT> {
90    inner: Arc<T>,
91}
92
93impl<T: AgentDeriveT> Clone for BasicAgent<T> {
94    fn clone(&self) -> Self {
95        Self {
96            inner: Arc::clone(&self.inner),
97        }
98    }
99}
100
101impl<T: AgentDeriveT> BasicAgent<T> {
102    pub fn new(inner: T) -> Self {
103        Self {
104            inner: Arc::new(inner),
105        }
106    }
107}
108
109impl<T: AgentDeriveT> Deref for BasicAgent<T> {
110    type Target = T;
111
112    fn deref(&self) -> &Self::Target {
113        &self.inner
114    }
115}
116
117/// Implement AgentDeriveT for the wrapper by delegating to the inner type
118#[async_trait]
119impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
120    type Output = <T as AgentDeriveT>::Output;
121
122    fn description(&self) -> &str {
123        self.inner.description()
124    }
125
126    fn output_schema(&self) -> Option<Value> {
127        self.inner.output_schema()
128    }
129
130    fn name(&self) -> &str {
131        self.inner.name()
132    }
133
134    fn tools(&self) -> Vec<Box<dyn ToolT>> {
135        self.inner.tools()
136    }
137}
138
139#[async_trait]
140impl<T> AgentHooks for BasicAgent<T>
141where
142    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
143{
144    async fn on_agent_create(&self) {
145        self.inner.on_agent_create().await
146    }
147
148    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
149        self.inner.on_run_start(task, ctx).await
150    }
151
152    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
153        self.inner.on_run_complete(task, result, ctx).await
154    }
155
156    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
157        self.inner.on_turn_start(turn_index, ctx).await
158    }
159
160    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
161        self.inner.on_turn_complete(turn_index, ctx).await
162    }
163
164    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
165        self.inner.on_tool_call(tool_call, ctx).await
166    }
167
168    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
169        self.inner.on_tool_start(tool_call, ctx).await
170    }
171
172    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
173        self.inner.on_tool_result(tool_call, result, ctx).await
174    }
175
176    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
177        self.inner.on_tool_error(tool_call, err, ctx).await
178    }
179    async fn on_agent_shutdown(&self) {
180        self.inner.on_agent_shutdown().await
181    }
182}
183
184/// Implementation of AgentExecutor for the BasicExecutorWrapper
185#[async_trait]
186impl<T: AgentDeriveT + AgentHooks> AgentExecutor for BasicAgent<T> {
187    type Output = BasicAgentOutput;
188    type Error = BasicExecutorError;
189
190    fn config(&self) -> ExecutorConfig {
191        ExecutorConfig { max_turns: 1 }
192    }
193
194    async fn execute(
195        &self,
196        task: &Task,
197        context: Arc<Context>,
198    ) -> Result<Self::Output, Self::Error> {
199        record_task_state(&context, task);
200        let tx_event = context.tx().ok();
201        EventHelper::send_task_started(
202            &tx_event,
203            task.submission_id,
204            context.config().id,
205            context.config().name.clone(),
206            task.prompt.clone(),
207        )
208        .await;
209
210        let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
211        let mut turn_state = engine.turn_state(&context);
212        let turn_result = engine
213            .run_turn(
214                self,
215                task,
216                &context,
217                &mut turn_state,
218                0,
219                self.config().max_turns,
220            )
221            .await?;
222
223        let output = extract_turn_output(turn_result);
224
225        Ok(BasicAgentOutput {
226            response: output.response,
227            done: true,
228        })
229    }
230
231    async fn execute_stream(
232        &self,
233        task: &Task,
234        context: Arc<Context>,
235    ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
236    {
237        record_task_state(&context, task);
238        let tx_event = context.tx().ok();
239        EventHelper::send_task_started(
240            &tx_event,
241            task.submission_id,
242            context.config().id,
243            context.config().name.clone(),
244            task.prompt.clone(),
245        )
246        .await;
247
248        let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
249        let mut turn_state = engine.turn_state(&context);
250        let context_clone = context.clone();
251        let task = task.clone();
252        let executor = self.clone();
253
254        let (tx, rx) = channel::<Result<BasicAgentOutput, BasicExecutorError>>(100);
255
256        spawn_future(async move {
257            let turn_stream = engine
258                .run_turn_stream(
259                    executor,
260                    &task,
261                    context_clone.clone(),
262                    &mut turn_state,
263                    0,
264                    1,
265                )
266                .await;
267
268            let mut final_response = String::default();
269            match turn_stream {
270                Ok(mut stream) => {
271                    use futures::StreamExt;
272                    while let Some(delta_result) = stream.next().await {
273                        match delta_result {
274                            Ok(TurnDelta::Text(content)) => {
275                                let _ = tx
276                                    .send(Ok(BasicAgentOutput {
277                                        response: content,
278                                        done: false,
279                                    }))
280                                    .await;
281                            }
282                            Ok(TurnDelta::ReasoningContent(_)) => {}
283                            Ok(TurnDelta::ToolResults(_)) => {}
284                            Ok(TurnDelta::Done(result)) => {
285                                let output = extract_turn_output(result);
286                                final_response = output.response.clone();
287                                let _ = tx
288                                    .send(Ok(BasicAgentOutput {
289                                        response: output.response,
290                                        done: true,
291                                    }))
292                                    .await;
293                                break;
294                            }
295                            Err(err) => {
296                                let _ = tx.send(Err(err.into())).await;
297                                return;
298                            }
299                        }
300                    }
301                }
302                Err(err) => {
303                    let _ = tx.send(Err(err.into())).await;
304                    return;
305                }
306            }
307
308            let tx_event = context_clone.tx().ok();
309            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
310            let output = BasicAgentOutput {
311                response: final_response,
312                done: true,
313            };
314            let result =
315                serde_json::to_string_pretty(&output).unwrap_or_else(|_| output.response.clone());
316            EventHelper::send_task_completed(
317                &tx_event,
318                task.submission_id,
319                context_clone.config().id,
320                context_clone.config().name.clone(),
321                result,
322            )
323            .await;
324        });
325
326        Ok(receiver_into_stream(rx))
327    }
328}
329
330fn extract_turn_output(
331    result: crate::agent::executor::TurnResult<TurnEngineOutput>,
332) -> TurnEngineOutput {
333    match result {
334        crate::agent::executor::TurnResult::Complete(output) => output,
335        crate::agent::executor::TurnResult::Continue(Some(output)) => output,
336        crate::agent::executor::TurnResult::Continue(None) => TurnEngineOutput {
337            response: String::default(),
338            reasoning_content: String::default(),
339            tool_calls: Vec::default(),
340        },
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::agent::direct::DirectAgent;
    use crate::agent::error::RunnableAgentError;
    use crate::agent::executor::TurnResult;
    use crate::agent::{AgentBuilder, AgentConfig, Context};
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider};
    use async_trait::async_trait;
    use autoagents_llm::chat::{StreamChoice, StreamDelta, StreamResponse};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Agent whose `on_run_start` hook always requests an abort.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = String;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            vec![]
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Agent configuration with a fresh id and no output schema.
    fn agent_config(name: &str, description: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: description.to_string(),
            output_schema: None,
        }
    }

    /// Single-choice stream chunk with the given text and reasoning content.
    fn stream_chunk(content: Option<&str>, reasoning_content: Option<&str>) -> StreamResponse {
        StreamResponse {
            choices: vec![StreamChoice {
                delta: StreamDelta {
                    content: content.map(String::from),
                    reasoning_content: reasoning_content.map(String::from),
                    tool_calls: None,
                },
            }],
            usage: None,
        }
    }

    /// Engine output holding only a response text.
    fn turn_output(response: &str) -> TurnEngineOutput {
        TurnEngineOutput {
            response: response.to_string(),
            reasoning_content: String::default(),
            tool_calls: Vec::new(),
        }
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        let basic_agent =
            BasicAgent::new(MockAgentImpl::new("test_agent", "Test agent description"));

        let llm = Arc::new(MockLLMProvider {});
        let context = Context::new(llm, None)
            .with_config(agent_config("test_agent", "Test agent description"));
        let context_arc = Arc::new(context);
        let task = Task::new("Test task");

        let result = basic_agent.execute(&task, context_arc).await;

        let output = result.expect("execute should succeed");
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_basic_agent_metadata_and_output_conversion() {
        let basic_agent =
            BasicAgent::new(MockAgentImpl::new("test_agent", "Test agent description"));

        assert_eq!(basic_agent.config().max_turns, 1);

        let cloned = basic_agent.clone();
        assert_eq!(cloned.name(), "test_agent");
        assert_eq!(cloned.description(), "Test agent description");

        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };
        let value = Value::from(output.clone());
        assert_eq!(value["response"], "Test response");
        let string = String::from(output);
        assert_eq!(string, "Test response");
    }

    #[test]
    fn test_basic_agent_output_try_parse_success() {
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct Data {
            name: String,
            value: i32,
        }

        let output = BasicAgentOutput {
            response: r#"{"name":"test","value":42}"#.to_string(),
            done: true,
        };
        let expected = Data {
            name: "test".to_string(),
            value: 42,
        };

        let parsed: Data = output.try_parse().unwrap();
        assert_eq!(parsed, expected);
    }

    #[test]
    fn test_basic_agent_output_try_parse_failure() {
        let output = BasicAgentOutput {
            response: "not json".to_string(),
            done: true,
        };
        assert!(output.try_parse::<serde_json::Value>().is_err());
    }

    #[test]
    fn test_basic_agent_output_parse_or_map_fallback() {
        let output = BasicAgentOutput {
            response: "plain text".to_string(),
            done: true,
        };
        let result: String = output.parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(result, "PLAIN TEXT");
    }

    #[test]
    fn test_basic_agent_output_parse_or_map_success() {
        let output = BasicAgentOutput {
            response: r#""hello""#.to_string(),
            done: true,
        };
        let result: String = output.parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(result, "hello");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let source = TurnEngineError::LLMError(LLMError::Generic("bad".to_string()));
        let err = BasicExecutorError::from(source);
        assert!(matches!(err, BasicExecutorError::LLMError(_)));
        assert!(err.to_string().contains("bad"));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err = BasicExecutorError::from(TurnEngineError::Aborted);
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("aborted"));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err = BasicExecutorError::from(TurnEngineError::Other("misc".to_string()));
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("misc"));
    }

    #[test]
    fn test_extract_turn_output_complete() {
        let output = extract_turn_output(TurnResult::Complete(turn_output("done")));
        assert_eq!(output.response, "done");
    }

    #[test]
    fn test_extract_turn_output_continue_some() {
        let output = extract_turn_output(TurnResult::Continue(Some(turn_output("partial"))));
        assert_eq!(output.response, "partial");
    }

    #[test]
    fn test_extract_turn_output_continue_none() {
        let output = extract_turn_output(TurnResult::Continue(None));
        assert!(output.response.is_empty());
        assert!(output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_basic_agent_execute_stream_returns_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![
                stream_chunk(Some("Hello "), None),
                stream_chunk(Some("world"), None),
            ],
            ..ConfigurableLLMProvider::default()
        });

        let basic_agent = BasicAgent::new(MockAgentImpl::new("stream_agent", "desc"));
        let context =
            Arc::new(Context::new(llm, None).with_config(agent_config("stream_agent", "desc")));
        let task = Task::new("Test task");

        let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();

        // Stop at the first item flagged as done.
        let mut final_output = None;
        while let Some(item) = stream.next().await {
            let output = item.unwrap();
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_basic_agent_execute_stream_ignores_reasoning_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![
                stream_chunk(None, Some("plan")),
                stream_chunk(Some("done"), None),
            ],
            ..ConfigurableLLMProvider::default()
        });

        let basic_agent = BasicAgent::new(MockAgentImpl::new("stream_agent_reasoning", "desc"));
        let context = Arc::new(
            Context::new(llm, None).with_config(agent_config("stream_agent_reasoning", "desc")),
        );
        let task = Task::new("Test task");

        let mut stream = basic_agent.execute_stream(&task, context).await.unwrap();
        let mut outputs = Vec::new();
        while let Some(item) = stream.next().await {
            outputs.push(item.unwrap());
        }

        // One text chunk plus the final item; the reasoning chunk is not surfaced.
        assert_eq!(outputs.len(), 2);
        assert_eq!(outputs[0].response, "done");
        assert!(!outputs[0].done);
        assert_eq!(outputs[1].response, "done");
        assert!(outputs[1].done);
    }

    #[tokio::test]
    async fn test_basic_agent_run_aborts_on_hook() {
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(BasicAgent::new(AbortAgent))
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        let err = handle
            .agent
            .run(Task::new("abort"))
            .await
            .expect_err("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
    }

    #[tokio::test]
    async fn test_basic_agent_run_stream_aborts_on_hook() {
        let llm = Arc::new(MockLLMProvider {});
        let handle = AgentBuilder::<_, DirectAgent>::new(BasicAgent::new(AbortAgent))
            .llm(llm)
            .build()
            .await
            .expect("build should succeed");

        // `.err()` avoids requiring `Debug` on the success (stream) type.
        let err = handle
            .agent
            .run_stream(Task::new("abort"))
            .await
            .err()
            .expect("expected abort");
        assert!(matches!(err, RunnableAgentError::Abort));
    }
}