// autoagents_core/agent/prebuilt/executor/basic.rs

use crate::agent::hooks::HookOutcome;
use crate::agent::task::Task;
use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, EventHelper, ExecutorConfig};
use crate::tool::{ToolCallResult, ToolT};
use async_trait::async_trait;
use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
use autoagents_llm::ToolCall;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::Arc;

15/// Output of the Basic executor
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BasicAgentOutput {
18    pub response: String,
19    pub done: bool,
20}
21
22impl From<BasicAgentOutput> for Value {
23    fn from(output: BasicAgentOutput) -> Self {
24        serde_json::to_value(output).unwrap_or(Value::Null)
25    }
26}
27impl From<BasicAgentOutput> for String {
28    fn from(output: BasicAgentOutput) -> Self {
29        output.response
30    }
31}
32
33impl BasicAgentOutput {
34    /// Try to parse the response string as structured JSON of type `T`.
35    /// Returns `serde_json::Error` if parsing fails.
36    pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
37        serde_json::from_str::<T>(&self.response)
38    }
39
40    /// Parse the response string as structured JSON of type `T`, or map the raw
41    /// text into `T` using the provided fallback function if parsing fails.
42    pub fn parse_or_map<T, F>(&self, fallback: F) -> T
43    where
44        T: for<'de> serde::Deserialize<'de>,
45        F: FnOnce(&str) -> T,
46    {
47        self.try_parse::<T>()
48            .unwrap_or_else(|_| fallback(&self.response))
49    }
50}
51
52/// Error type for Basic executor
53#[derive(Debug, thiserror::Error)]
54pub enum BasicExecutorError {
55    #[error("LLM error: {0}")]
56    LLMError(String),
57
58    #[error("Other error: {0}")]
59    Other(String),
60}
61
62/// Wrapper type for the single-turn Basic executor.
63///
64/// Use `BasicAgent<T>` when you want a single request/response interaction
65/// with optional streaming but without tool calling or multi-turn loops.
66#[derive(Debug)]
67pub struct BasicAgent<T: AgentDeriveT> {
68    inner: Arc<T>,
69}
70
71impl<T: AgentDeriveT> Clone for BasicAgent<T> {
72    fn clone(&self) -> Self {
73        Self {
74            inner: Arc::clone(&self.inner),
75        }
76    }
77}
78
79impl<T: AgentDeriveT> BasicAgent<T> {
80    pub fn new(inner: T) -> Self {
81        Self {
82            inner: Arc::new(inner),
83        }
84    }
85}
86
87impl<T: AgentDeriveT> Deref for BasicAgent<T> {
88    type Target = T;
89
90    fn deref(&self) -> &Self::Target {
91        &self.inner
92    }
93}
94
95/// Implement AgentDeriveT for the wrapper by delegating to the inner type
96#[async_trait]
97impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
98    type Output = <T as AgentDeriveT>::Output;
99
100    fn description(&self) -> &'static str {
101        self.inner.description()
102    }
103
104    fn output_schema(&self) -> Option<Value> {
105        self.inner.output_schema()
106    }
107
108    fn name(&self) -> &'static str {
109        self.inner.name()
110    }
111
112    fn tools(&self) -> Vec<Box<dyn ToolT>> {
113        self.inner.tools()
114    }
115}
116
117#[async_trait]
118impl<T> AgentHooks for BasicAgent<T>
119where
120    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
121{
122    async fn on_agent_create(&self) {
123        self.inner.on_agent_create().await
124    }
125
126    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
127        self.inner.on_run_start(task, ctx).await
128    }
129
130    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
131        self.inner.on_run_complete(task, result, ctx).await
132    }
133
134    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
135        self.inner.on_turn_start(turn_index, ctx).await
136    }
137
138    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
139        self.inner.on_turn_complete(turn_index, ctx).await
140    }
141
142    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
143        self.inner.on_tool_call(tool_call, ctx).await
144    }
145
146    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
147        self.inner.on_tool_start(tool_call, ctx).await
148    }
149
150    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
151        self.inner.on_tool_result(tool_call, result, ctx).await
152    }
153
154    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
155        self.inner.on_tool_error(tool_call, err, ctx).await
156    }
157    async fn on_agent_shutdown(&self) {
158        self.inner.on_agent_shutdown().await
159    }
160}
161
162/// Implementation of AgentExecutor for the BasicExecutorWrapper
163#[async_trait]
164impl<T: AgentDeriveT> AgentExecutor for BasicAgent<T> {
165    type Output = BasicAgentOutput;
166    type Error = BasicExecutorError;
167
168    fn config(&self) -> ExecutorConfig {
169        ExecutorConfig { max_turns: 1 }
170    }
171
172    async fn execute(
173        &self,
174        task: &Task,
175        context: Arc<Context>,
176    ) -> Result<Self::Output, Self::Error> {
177        let tx_event = context.tx().ok();
178        EventHelper::send_task_started(
179            &tx_event,
180            task.submission_id,
181            context.config().id,
182            task.prompt.clone(),
183            context.config().name.clone(),
184        )
185        .await;
186
187        let mut messages = vec![ChatMessage {
188            role: ChatRole::System,
189            message_type: MessageType::Text,
190            content: context.config().description.clone(),
191        }];
192
193        let chat_msg = if let Some((mime, image_data)) = &task.image {
194            // Task has an image, create an Image message
195            ChatMessage {
196                role: ChatRole::User,
197                message_type: MessageType::Image((*mime, image_data.clone())),
198                content: task.prompt.clone(),
199            }
200        } else {
201            // Text-only task
202            ChatMessage {
203                role: ChatRole::User,
204                message_type: MessageType::Text,
205                content: task.prompt.clone(),
206            }
207        };
208        messages.push(chat_msg);
209        let response = context
210            .llm()
211            .chat(&messages, None, context.config().output_schema.clone())
212            .await
213            .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
214        let response_text = response.text().unwrap_or_default();
215        Ok(BasicAgentOutput {
216            response: response_text,
217            done: true,
218        })
219    }
220
221    async fn execute_stream(
222        &self,
223        task: &Task,
224        context: Arc<Context>,
225    ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
226    {
227        use futures::StreamExt;
228
229        let tx_event = context.tx().ok();
230        EventHelper::send_task_started(
231            &tx_event,
232            task.submission_id,
233            context.config().id,
234            task.prompt.clone(),
235            context.config().name.clone(),
236        )
237        .await;
238
239        let mut messages = vec![ChatMessage {
240            role: ChatRole::System,
241            message_type: MessageType::Text,
242            content: context.config().description.clone(),
243        }];
244
245        let chat_msg = if let Some((mime, image_data)) = &task.image {
246            // Task has an image, create an Image message
247            ChatMessage {
248                role: ChatRole::User,
249                message_type: MessageType::Image((*mime, image_data.clone())),
250                content: task.prompt.clone(),
251            }
252        } else {
253            // Text-only task
254            ChatMessage {
255                role: ChatRole::User,
256                message_type: MessageType::Text,
257                content: task.prompt.clone(),
258            }
259        };
260        messages.push(chat_msg);
261
262        let stream = context
263            .llm()
264            .chat_stream_struct(&messages, None, context.config().output_schema.clone())
265            .await
266            .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
267
268        let mapped_stream = stream.map(|chunk_result| match chunk_result {
269            Ok(chunk) => {
270                let content = chunk
271                    .choices
272                    .first()
273                    .and_then(|choice| choice.delta.content.as_ref())
274                    .map_or("", |v| v)
275                    .to_string();
276
277                Ok(BasicAgentOutput {
278                    response: content,
279                    done: false,
280                })
281            }
282            Err(e) => Err(BasicExecutorError::LLMError(e.to_string())),
283        });
284
285        Ok(Box::pin(mapped_stream))
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    /// Small structured payload used to exercise `try_parse` / `parse_or_map`.
    #[derive(Debug, PartialEq, serde::Deserialize)]
    struct Answer {
        value: u32,
    }

    fn output_with(response: &str) -> BasicAgentOutput {
        BasicAgentOutput {
            response: response.to_string(),
            done: true,
        }
    }

    #[test]
    fn test_basic_agent_creation() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        assert_eq!(basic_agent.name(), "test_agent");
        assert_eq!(basic_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_clone() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);
        let cloned_agent = basic_agent.clone();

        assert_eq!(cloned_agent.name(), "test_agent");
        assert_eq!(cloned_agent.description(), "Test agent description");
    }

    #[test]
    fn test_basic_agent_output_conversions() {
        let output = output_with("Test response");

        // Conversion to Value keeps both fields.
        let value: Value = output.clone().into();
        assert!(value.is_object());
        assert_eq!(value["response"], "Test response");
        assert_eq!(value["done"], true);

        // Conversion to String yields only the response text.
        let string: String = output.into();
        assert_eq!(string, "Test response");
    }

    #[test]
    fn test_try_parse_structured_response() {
        let output = output_with(r#"{"value": 42}"#);
        let parsed = output
            .try_parse::<Answer>()
            .expect("response is valid JSON for Answer");
        assert_eq!(parsed, Answer { value: 42 });
    }

    #[test]
    fn test_try_parse_rejects_plain_text() {
        let output = output_with("not json");
        assert!(output.try_parse::<Answer>().is_err());
    }

    #[test]
    fn test_parse_or_map_prefers_parsed_value() {
        let output = output_with(r#"{"value": 7}"#);
        let parsed = output.parse_or_map(|_| Answer { value: 0 });
        assert_eq!(parsed, Answer { value: 7 });
    }

    #[test]
    fn test_parse_or_map_falls_back_on_plain_text() {
        let output = output_with("not json");
        let mapped: Vec<String> = output.parse_or_map(|raw| vec![raw.to_owned()]);
        assert_eq!(mapped, vec!["not json".to_string()]);
    }

    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use crate::protocol::ActorID;

        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let llm = Arc::new(MockLLMProvider {});
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test_agent".to_string(),
            description: "Test agent description".to_string(),
            output_schema: None,
        };

        let context = Context::new(llm, None).with_config(config);

        let context_arc = Arc::new(context);
        let task = Task::new("Test task");
        let result = basic_agent.execute(&task, context_arc).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    #[test]
    fn test_executor_config() {
        let mock_agent = MockAgentImpl::new("test_agent", "Test agent description");
        let basic_agent = BasicAgent::new(mock_agent);

        let config = basic_agent.config();
        assert_eq!(config.max_turns, 1);
    }
}