Skip to main content

autoagents_core/agent/prebuilt/executor/basic.rs

1use crate::agent::executor::event_helper::EventHelper;
2use crate::agent::executor::turn_engine::{
3    TurnDelta, TurnEngine, TurnEngineConfig, TurnEngineError, TurnEngineOutput, record_task_state,
4};
5use crate::agent::hooks::HookOutcome;
6use crate::agent::task::Task;
7use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
8use crate::channel::channel;
9use crate::tool::{ToolCallResult, ToolT};
10use crate::utils::{receiver_into_stream, spawn_future};
11use async_trait::async_trait;
12use autoagents_llm::ToolCall;
13use futures::Stream;
14use serde::{Deserialize, Serialize};
15use serde_json::Value;
16use std::ops::Deref;
17use std::pin::Pin;
18use std::sync::Arc;
19
20/// Output of the Basic executor
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct BasicAgentOutput {
23    pub response: String,
24    pub done: bool,
25}
26
27impl From<BasicAgentOutput> for Value {
28    fn from(output: BasicAgentOutput) -> Self {
29        serde_json::to_value(output).unwrap_or(Value::Null)
30    }
31}
32impl From<BasicAgentOutput> for String {
33    fn from(output: BasicAgentOutput) -> Self {
34        output.response
35    }
36}
37
38impl BasicAgentOutput {
39    /// Try to parse the response string as structured JSON of type `T`.
40    /// Returns `serde_json::Error` if parsing fails.
41    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
42        serde_json::from_str::<T>(&self.response)
43    }
44
45    /// Parse the response string as structured JSON of type `T`, or map the raw
46    /// text into `T` using the provided fallback function if parsing fails.
47    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
48    where
49        T: for<'de> serde::Deserialize<'de>,
50        F: FnOnce(&str) -> T,
51    {
52        self.try_parse::<T>()
53            .unwrap_or_else(|_| fallback(&self.response))
54    }
55}
56
57/// Error type for Basic executor
58#[derive(Debug, thiserror::Error)]
59pub enum BasicExecutorError {
60    #[error("LLM error: {0}")]
61    LLMError(String),
62
63    #[error("Other error: {0}")]
64    Other(String),
65}
66
67impl From<TurnEngineError> for BasicExecutorError {
68    fn from(error: TurnEngineError) -> Self {
69        match error {
70            TurnEngineError::LLMError(err) => BasicExecutorError::LLMError(err),
71            TurnEngineError::Aborted => {
72                BasicExecutorError::Other("Run aborted by hook".to_string())
73            }
74            TurnEngineError::Other(err) => BasicExecutorError::Other(err),
75        }
76    }
77}
78
79/// Wrapper type for the single-turn Basic executor.
80///
81/// Use `BasicAgent<T>` when you want a single request/response interaction
82/// with optional streaming but without tool calling or multi-turn loops.
83#[derive(Debug)]
84pub struct BasicAgent<T: AgentDeriveT> {
85    inner: Arc<T>,
86}
87
88impl<T: AgentDeriveT> Clone for BasicAgent<T> {
89    fn clone(&self) -> Self {
90        Self {
91            inner: Arc::clone(&self.inner),
92        }
93    }
94}
95
96impl<T: AgentDeriveT> BasicAgent<T> {
97    pub fn new(inner: T) -> Self {
98        Self {
99            inner: Arc::new(inner),
100        }
101    }
102}
103
104impl<T: AgentDeriveT> Deref for BasicAgent<T> {
105    type Target = T;
106
107    fn deref(&self) -> &Self::Target {
108        &self.inner
109    }
110}
111
112/// Implement AgentDeriveT for the wrapper by delegating to the inner type
113#[async_trait]
114impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
115    type Output = <T as AgentDeriveT>::Output;
116
117    fn description(&self) -> &str {
118        self.inner.description()
119    }
120
121    fn output_schema(&self) -> Option<Value> {
122        self.inner.output_schema()
123    }
124
125    fn name(&self) -> &str {
126        self.inner.name()
127    }
128
129    fn tools(&self) -> Vec<Box<dyn ToolT>> {
130        self.inner.tools()
131    }
132}
133
134#[async_trait]
135impl<T> AgentHooks for BasicAgent<T>
136where
137    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
138{
139    async fn on_agent_create(&self) {
140        self.inner.on_agent_create().await
141    }
142
143    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
144        self.inner.on_run_start(task, ctx).await
145    }
146
147    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
148        self.inner.on_run_complete(task, result, ctx).await
149    }
150
151    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
152        self.inner.on_turn_start(turn_index, ctx).await
153    }
154
155    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
156        self.inner.on_turn_complete(turn_index, ctx).await
157    }
158
159    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
160        self.inner.on_tool_call(tool_call, ctx).await
161    }
162
163    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
164        self.inner.on_tool_start(tool_call, ctx).await
165    }
166
167    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
168        self.inner.on_tool_result(tool_call, result, ctx).await
169    }
170
171    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
172        self.inner.on_tool_error(tool_call, err, ctx).await
173    }
174    async fn on_agent_shutdown(&self) {
175        self.inner.on_agent_shutdown().await
176    }
177}
178
179/// Implementation of AgentExecutor for the BasicExecutorWrapper
180#[async_trait]
181impl<T: AgentDeriveT + AgentHooks> AgentExecutor for BasicAgent<T> {
182    type Output = BasicAgentOutput;
183    type Error = BasicExecutorError;
184
185    fn config(&self) -> ExecutorConfig {
186        ExecutorConfig { max_turns: 1 }
187    }
188
189    async fn execute(
190        &self,
191        task: &Task,
192        context: Arc<Context>,
193    ) -> Result<Self::Output, Self::Error> {
194        if self.on_run_start(task, &context).await == HookOutcome::Abort {
195            return Err(BasicExecutorError::Other("Run aborted by hook".to_string()));
196        }
197
198        record_task_state(&context, task);
199        let tx_event = context.tx().ok();
200        EventHelper::send_task_started(
201            &tx_event,
202            task.submission_id,
203            context.config().id,
204            context.config().name.clone(),
205            task.prompt.clone(),
206        )
207        .await;
208
209        let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
210        let mut turn_state = engine.turn_state(&context);
211        let turn_result = engine
212            .run_turn(
213                self,
214                task,
215                &context,
216                &mut turn_state,
217                0,
218                self.config().max_turns,
219            )
220            .await?;
221
222        let output = extract_turn_output(turn_result);
223
224        EventHelper::send_task_completed(
225            &tx_event,
226            task.submission_id,
227            context.config().id,
228            context.config().name.clone(),
229            output.response.clone(),
230        )
231        .await;
232
233        Ok(BasicAgentOutput {
234            response: output.response,
235            done: true,
236        })
237    }
238
239    async fn execute_stream(
240        &self,
241        task: &Task,
242        context: Arc<Context>,
243    ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
244    {
245        if self.on_run_start(task, &context).await == HookOutcome::Abort {
246            return Err(BasicExecutorError::Other("Run aborted by hook".to_string()));
247        }
248
249        record_task_state(&context, task);
250        let tx_event = context.tx().ok();
251        EventHelper::send_task_started(
252            &tx_event,
253            task.submission_id,
254            context.config().id,
255            context.config().name.clone(),
256            task.prompt.clone(),
257        )
258        .await;
259
260        let engine = TurnEngine::new(TurnEngineConfig::basic(self.config().max_turns));
261        let mut turn_state = engine.turn_state(&context);
262        let context_clone = context.clone();
263        let task = task.clone();
264        let executor = self.clone();
265
266        let (tx, rx) = channel::<Result<BasicAgentOutput, BasicExecutorError>>(100);
267
268        spawn_future(async move {
269            let turn_stream = engine
270                .run_turn_stream(
271                    executor,
272                    &task,
273                    context_clone.clone(),
274                    &mut turn_state,
275                    0,
276                    1,
277                )
278                .await;
279
280            let mut final_response = String::default();
281
282            match turn_stream {
283                Ok(mut stream) => {
284                    use futures::StreamExt;
285                    while let Some(delta_result) = stream.next().await {
286                        match delta_result {
287                            Ok(TurnDelta::Text(content)) => {
288                                let _ = tx
289                                    .send(Ok(BasicAgentOutput {
290                                        response: content,
291                                        done: false,
292                                    }))
293                                    .await;
294                            }
295                            Ok(TurnDelta::ToolResults(_)) => {}
296                            Ok(TurnDelta::Done(result)) => {
297                                let output = extract_turn_output(result);
298                                final_response = output.response.clone();
299                                let _ = tx
300                                    .send(Ok(BasicAgentOutput {
301                                        response: output.response,
302                                        done: true,
303                                    }))
304                                    .await;
305                                break;
306                            }
307                            Err(err) => {
308                                let _ = tx.send(Err(err.into())).await;
309                                return;
310                            }
311                        }
312                    }
313                }
314                Err(err) => {
315                    let _ = tx.send(Err(err.into())).await;
316                    return;
317                }
318            }
319
320            let tx_event = context_clone.tx().ok();
321            EventHelper::send_stream_complete(&tx_event, task.submission_id).await;
322            EventHelper::send_task_completed(
323                &tx_event,
324                task.submission_id,
325                context_clone.config().id,
326                context_clone.config().name.clone(),
327                final_response,
328            )
329            .await;
330        });
331
332        Ok(receiver_into_stream(rx))
333    }
334}
335
336fn extract_turn_output(
337    result: crate::agent::executor::TurnResult<TurnEngineOutput>,
338) -> TurnEngineOutput {
339    match result {
340        crate::agent::executor::TurnResult::Complete(output) => output,
341        crate::agent::executor::TurnResult::Continue(Some(output)) => output,
342        crate::agent::executor::TurnResult::Continue(None) => TurnEngineOutput {
343            response: String::default(),
344            tool_calls: Vec::default(),
345        },
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentConfig;
    use crate::agent::executor::TurnResult;
    use crate::tests::{ConfigurableLLMProvider, MockAgentImpl, MockLLMProvider, TestAgentOutput};
    use async_trait::async_trait;
    use autoagents_llm::chat::{StreamChoice, StreamDelta, StreamResponse};
    use autoagents_protocol::ActorID;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Agent whose `on_run_start` hook always aborts the run.
    #[derive(Debug, Clone)]
    struct AbortAgent;

    #[async_trait]
    impl AgentDeriveT for AbortAgent {
        type Output = TestAgentOutput;

        fn description(&self) -> &str {
            "abort"
        }

        fn output_schema(&self) -> Option<Value> {
            None
        }

        fn name(&self) -> &str {
            "abort_agent"
        }

        fn tools(&self) -> Vec<Box<dyn ToolT>> {
            Vec::new()
        }
    }

    #[async_trait]
    impl AgentHooks for AbortAgent {
        async fn on_run_start(&self, _task: &Task, _ctx: &Context) -> HookOutcome {
            HookOutcome::Abort
        }
    }

    /// Builds an `AgentConfig` with a fresh id and no output schema.
    fn agent_config(name: &str, description: &str) -> AgentConfig {
        AgentConfig {
            id: ActorID::new_v4(),
            name: name.to_string(),
            description: description.to_string(),
            output_schema: None,
        }
    }

    /// Builds a shared context backed by `MockLLMProvider`.
    fn mock_context(name: &str, description: &str) -> Arc<Context> {
        let llm = Arc::new(MockLLMProvider {});
        Arc::new(Context::new(llm, None).with_config(agent_config(name, description)))
    }

    /// Builds a final (`done == true`) output carrying `response`.
    fn finished_output(response: &str) -> BasicAgentOutput {
        BasicAgentOutput {
            response: response.to_string(),
            done: true,
        }
    }

    /// Builds a turn-engine output with `response` and no tool calls.
    fn engine_output(response: &str) -> TurnEngineOutput {
        TurnEngineOutput {
            response: response.to_string(),
            tool_calls: Vec::new(),
        }
    }

    /// Builds a single-choice stream chunk carrying `text` and no tool calls.
    fn text_chunk(text: &str) -> StreamResponse {
        StreamResponse {
            choices: vec![StreamChoice {
                delta: StreamDelta {
                    content: Some(text.to_string()),
                    tool_calls: None,
                },
            }],
            usage: None,
        }
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        let basic_agent =
            BasicAgent::new(MockAgentImpl::new("test_agent", "Test agent description"));
        let context = mock_context("test_agent", "Test agent description");
        let task = Task::new("Test task");

        let output = basic_agent
            .execute(&task, context)
            .await
            .expect("execute should succeed with the mock provider");

        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_basic_agent_metadata_and_output_conversion() {
        let basic_agent =
            BasicAgent::new(MockAgentImpl::new("test_agent", "Test agent description"));
        assert_eq!(basic_agent.config().max_turns, 1);

        let cloned = basic_agent.clone();
        assert_eq!(cloned.name(), "test_agent");
        assert_eq!(cloned.description(), "Test agent description");

        let output = finished_output("Test response");
        let as_json = Value::from(output.clone());
        assert_eq!(as_json["response"], "Test response");
        assert_eq!(String::from(output), "Test response");
    }

    #[test]
    fn test_basic_agent_output_try_parse_success() {
        #[derive(serde::Deserialize, PartialEq, Debug)]
        struct Data {
            name: String,
            value: i32,
        }

        let output = finished_output(r#"{"name":"test","value":42}"#);
        let parsed = output
            .try_parse::<Data>()
            .expect("payload is valid JSON for Data");
        assert_eq!(
            parsed,
            Data {
                name: "test".to_string(),
                value: 42,
            }
        );
    }

    #[test]
    fn test_basic_agent_output_try_parse_failure() {
        let output = finished_output("not json");
        assert!(output.try_parse::<Value>().is_err());
    }

    #[test]
    fn test_basic_agent_output_parse_or_map_fallback() {
        let output = finished_output("plain text");
        let mapped: String = output.parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(mapped, "PLAIN TEXT");
    }

    #[test]
    fn test_basic_agent_output_parse_or_map_success() {
        let output = finished_output(r#""hello""#);
        let mapped: String = output.parse_or_map(|raw| raw.to_uppercase());
        assert_eq!(mapped, "hello");
    }

    #[test]
    fn test_error_from_turn_engine_llm() {
        let err = BasicExecutorError::from(TurnEngineError::LLMError("bad".to_string()));
        assert!(matches!(err, BasicExecutorError::LLMError(_)));
        assert!(err.to_string().contains("bad"));
    }

    #[test]
    fn test_error_from_turn_engine_aborted() {
        let err = BasicExecutorError::from(TurnEngineError::Aborted);
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("aborted"));
    }

    #[test]
    fn test_error_from_turn_engine_other() {
        let err = BasicExecutorError::from(TurnEngineError::Other("misc".to_string()));
        assert!(matches!(err, BasicExecutorError::Other(_)));
        assert!(err.to_string().contains("misc"));
    }

    #[test]
    fn test_extract_turn_output_complete() {
        let output = extract_turn_output(TurnResult::Complete(engine_output("done")));
        assert_eq!(output.response, "done");
    }

    #[test]
    fn test_extract_turn_output_continue_some() {
        let output = extract_turn_output(TurnResult::Continue(Some(engine_output("partial"))));
        assert_eq!(output.response, "partial");
    }

    #[test]
    fn test_extract_turn_output_continue_none() {
        let output = extract_turn_output(TurnResult::Continue(None));
        assert!(output.response.is_empty());
        assert!(output.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_basic_agent_execute_stream_returns_output() {
        let llm = Arc::new(ConfigurableLLMProvider {
            structured_stream: vec![text_chunk("Hello "), text_chunk("world")],
            ..ConfigurableLLMProvider::default()
        });
        let basic_agent = BasicAgent::new(MockAgentImpl::new("stream_agent", "desc"));
        let context =
            Arc::new(Context::new(llm, None).with_config(agent_config("stream_agent", "desc")));
        let task = Task::new("Test task");

        let mut stream = basic_agent
            .execute_stream(&task, context)
            .await
            .expect("stream should start");

        // Drain items until the final (`done`) one arrives.
        let mut final_output = None;
        while let Some(item) = stream.next().await {
            let output = item.expect("stream item should be Ok");
            if output.done {
                final_output = Some(output);
                break;
            }
        }

        let output = final_output.expect("final output");
        assert_eq!(output.response, "Hello world");
        assert!(output.done);
    }

    #[tokio::test]
    async fn test_basic_agent_execute_aborts_on_hook() {
        let agent = BasicAgent::new(AbortAgent);
        let context = mock_context("abort_agent", "abort");
        let task = Task::new("abort");

        let err = agent.execute(&task, context).await.unwrap_err();
        assert!(err.to_string().contains("aborted"));
    }

    #[tokio::test]
    async fn test_basic_agent_execute_stream_aborts_on_hook() {
        let agent = BasicAgent::new(AbortAgent);
        let context = mock_context("abort_agent", "abort");
        let task = Task::new("abort");

        let err = agent
            .execute_stream(&task, context)
            .await
            .err()
            .expect("hook abort should fail before streaming");
        assert!(err.to_string().contains("aborted"));
    }
}