autoagents_core/agent/prebuilt/executor/
basic.rs

1use crate::agent::hooks::HookOutcome;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, ExecutorConfig};
4use crate::tool::{ToolCallResult, ToolT};
5use async_trait::async_trait;
6use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
7use autoagents_llm::ToolCall;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15/// Output of the Basic executor
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BasicAgentOutput {
18    pub response: String,
19    pub done: bool,
20}
21
22impl From<BasicAgentOutput> for Value {
23    fn from(output: BasicAgentOutput) -> Self {
24        serde_json::to_value(output).unwrap_or(Value::Null)
25    }
26}
27impl From<BasicAgentOutput> for String {
28    fn from(output: BasicAgentOutput) -> Self {
29        output.response
30    }
31}
32
33/// Error type for Basic executor
34#[derive(Debug, thiserror::Error)]
35pub enum BasicExecutorError {
36    #[error("LLM error: {0}")]
37    LLMError(String),
38
39    #[error("Other error: {0}")]
40    Other(String),
41}
42
43/// Wrapper type for Basic executor
44#[derive(Debug)]
45pub struct BasicAgent<T: AgentDeriveT> {
46    inner: Arc<T>,
47}
48
49impl<T: AgentDeriveT> Clone for BasicAgent<T> {
50    fn clone(&self) -> Self {
51        Self {
52            inner: Arc::clone(&self.inner),
53        }
54    }
55}
56
57impl<T: AgentDeriveT> BasicAgent<T> {
58    pub fn new(inner: T) -> Self {
59        Self {
60            inner: Arc::new(inner),
61        }
62    }
63}
64
65impl<T: AgentDeriveT> Deref for BasicAgent<T> {
66    type Target = T;
67
68    fn deref(&self) -> &Self::Target {
69        &self.inner
70    }
71}
72
73/// Implement AgentDeriveT for the wrapper by delegating to the inner type
74#[async_trait]
75impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
76    type Output = <T as AgentDeriveT>::Output;
77
78    fn description(&self) -> &'static str {
79        self.inner.description()
80    }
81
82    fn output_schema(&self) -> Option<Value> {
83        self.inner.output_schema()
84    }
85
86    fn name(&self) -> &'static str {
87        self.inner.name()
88    }
89
90    fn tools(&self) -> Vec<Box<dyn ToolT>> {
91        self.inner.tools()
92    }
93}
94
95#[async_trait]
96impl<T> AgentHooks for BasicAgent<T>
97where
98    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
99{
100    async fn on_agent_create(&self) {
101        self.inner.on_agent_create().await
102    }
103
104    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
105        self.inner.on_run_start(task, ctx).await
106    }
107
108    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
109        self.inner.on_run_complete(task, result, ctx).await
110    }
111
112    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
113        self.inner.on_turn_start(turn_index, ctx).await
114    }
115
116    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
117        self.inner.on_turn_complete(turn_index, ctx).await
118    }
119
120    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
121        self.inner.on_tool_call(tool_call, ctx).await
122    }
123
124    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
125        self.inner.on_tool_start(tool_call, ctx).await
126    }
127
128    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
129        self.inner.on_tool_result(tool_call, result, ctx).await
130    }
131
132    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
133        self.inner.on_tool_error(tool_call, err, ctx).await
134    }
135    async fn on_agent_shutdown(&self) {
136        self.inner.on_agent_shutdown().await
137    }
138}
139
140/// Implementation of AgentExecutor for the BasicExecutorWrapper
141#[async_trait]
142impl<T: AgentDeriveT> AgentExecutor for BasicAgent<T> {
143    type Output = BasicAgentOutput;
144    type Error = BasicExecutorError;
145
146    fn config(&self) -> ExecutorConfig {
147        ExecutorConfig { max_turns: 1 }
148    }
149
150    async fn execute(
151        &self,
152        task: &Task,
153        context: Arc<Context>,
154    ) -> Result<Self::Output, Self::Error> {
155        let mut messages = vec![ChatMessage {
156            role: ChatRole::System,
157            message_type: MessageType::Text,
158            content: context.config().description.clone(),
159        }];
160
161        let chat_msg = ChatMessage {
162            role: ChatRole::User,
163            message_type: MessageType::Text,
164            content: task.prompt.clone(),
165        };
166        messages.push(chat_msg);
167        let response = context
168            .llm()
169            .chat(&messages, None, context.config().output_schema.clone())
170            .await
171            .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
172        let response_text = response.text().unwrap_or_default();
173        Ok(BasicAgentOutput {
174            response: response_text,
175            done: true,
176        })
177    }
178
179    async fn execute_stream(
180        &self,
181        task: &Task,
182        context: Arc<Context>,
183    ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
184    {
185        use futures::StreamExt;
186
187        let mut messages = vec![ChatMessage {
188            role: ChatRole::System,
189            message_type: MessageType::Text,
190            content: context.config().description.clone(),
191        }];
192
193        let chat_msg = ChatMessage {
194            role: ChatRole::User,
195            message_type: MessageType::Text,
196            content: task.prompt.clone(),
197        };
198        messages.push(chat_msg);
199
200        let stream = context
201            .llm()
202            .chat_stream_struct(&messages, None, context.config().output_schema.clone())
203            .await
204            .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
205
206        let mapped_stream = stream.map(|chunk_result| match chunk_result {
207            Ok(chunk) => {
208                let content = chunk
209                    .choices
210                    .first()
211                    .and_then(|choice| choice.delta.content.as_ref())
212                    .map_or("", |v| v)
213                    .to_string();
214
215                Ok(BasicAgentOutput {
216                    response: content,
217                    done: false,
218                })
219            }
220            Err(e) => Err(BasicExecutorError::LLMError(e.to_string())),
221        });
222
223        Ok(Box::pin(mapped_stream))
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    /// Builds the wrapped mock agent shared by the tests in this module.
    fn wrapped_mock_agent() -> BasicAgent<MockAgentImpl> {
        BasicAgent::new(MockAgentImpl::new("test_agent", "Test agent description"))
    }

    /// The wrapper reports the inner agent's name and description.
    #[test]
    fn test_basic_agent_creation() {
        let agent = wrapped_mock_agent();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.description(), "Test agent description");
    }

    /// A clone of the wrapper exposes the same metadata as the original.
    #[test]
    fn test_basic_agent_clone() {
        let original = wrapped_mock_agent();
        let duplicate = original.clone();

        assert_eq!(duplicate.name(), "test_agent");
        assert_eq!(duplicate.description(), "Test agent description");
    }

    /// `BasicAgentOutput` converts into both a JSON object and a `String`.
    #[test]
    fn test_basic_agent_output_conversions() {
        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };

        // Conversion to a JSON value (a clone is used so `output` stays usable).
        let as_json: Value = Value::from(output.clone());
        assert!(as_json.is_object());

        // Conversion to a plain string consumes the output.
        let as_text = String::from(output);
        assert_eq!(as_text, "Test response");
    }

    /// A one-shot run against the mock LLM returns its reply with `done` set.
    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use crate::protocol::ActorID;

        let agent = wrapped_mock_agent();

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "test_agent".to_string(),
            description: "Test agent description".to_string(),
            output_schema: None,
        };
        let llm = Arc::new(MockLLMProvider {});
        let context = Arc::new(Context::new(llm, None).with_config(config));

        let task = Task::new("Test task");
        let result = agent.execute(&task, context).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    /// The Basic executor is configured for exactly one turn.
    #[test]
    fn test_executor_config() {
        let agent = wrapped_mock_agent();

        let executor_config = agent.config();
        assert_eq!(executor_config.max_turns, 1);
    }
}