autoagents_core/agent/prebuilt/executor/
basic.rs

1use crate::agent::hooks::HookOutcome;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, AgentExecutor, AgentHooks, Context, EventHelper, ExecutorConfig};
4use crate::tool::{ToolCallResult, ToolT};
5use async_trait::async_trait;
6use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType};
7use autoagents_llm::ToolCall;
8use futures::Stream;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::ops::Deref;
12use std::pin::Pin;
13use std::sync::Arc;
14
15/// Output of the Basic executor
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct BasicAgentOutput {
18    pub response: String,
19    pub done: bool,
20}
21
22impl From<BasicAgentOutput> for Value {
23    fn from(output: BasicAgentOutput) -> Self {
24        serde_json::to_value(output).unwrap_or(Value::Null)
25    }
26}
27impl From<BasicAgentOutput> for String {
28    fn from(output: BasicAgentOutput) -> Self {
29        output.response
30    }
31}
32
33/// Error type for Basic executor
34#[derive(Debug, thiserror::Error)]
35pub enum BasicExecutorError {
36    #[error("LLM error: {0}")]
37    LLMError(String),
38
39    #[error("Other error: {0}")]
40    Other(String),
41}
42
43/// Wrapper type for Basic executor
44#[derive(Debug)]
45pub struct BasicAgent<T: AgentDeriveT> {
46    inner: Arc<T>,
47}
48
49impl<T: AgentDeriveT> Clone for BasicAgent<T> {
50    fn clone(&self) -> Self {
51        Self {
52            inner: Arc::clone(&self.inner),
53        }
54    }
55}
56
57impl<T: AgentDeriveT> BasicAgent<T> {
58    pub fn new(inner: T) -> Self {
59        Self {
60            inner: Arc::new(inner),
61        }
62    }
63}
64
65impl<T: AgentDeriveT> Deref for BasicAgent<T> {
66    type Target = T;
67
68    fn deref(&self) -> &Self::Target {
69        &self.inner
70    }
71}
72
73/// Implement AgentDeriveT for the wrapper by delegating to the inner type
74#[async_trait]
75impl<T: AgentDeriveT> AgentDeriveT for BasicAgent<T> {
76    type Output = <T as AgentDeriveT>::Output;
77
78    fn description(&self) -> &'static str {
79        self.inner.description()
80    }
81
82    fn output_schema(&self) -> Option<Value> {
83        self.inner.output_schema()
84    }
85
86    fn name(&self) -> &'static str {
87        self.inner.name()
88    }
89
90    fn tools(&self) -> Vec<Box<dyn ToolT>> {
91        self.inner.tools()
92    }
93}
94
95#[async_trait]
96impl<T> AgentHooks for BasicAgent<T>
97where
98    T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
99{
100    async fn on_agent_create(&self) {
101        self.inner.on_agent_create().await
102    }
103
104    async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
105        self.inner.on_run_start(task, ctx).await
106    }
107
108    async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
109        self.inner.on_run_complete(task, result, ctx).await
110    }
111
112    async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
113        self.inner.on_turn_start(turn_index, ctx).await
114    }
115
116    async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
117        self.inner.on_turn_complete(turn_index, ctx).await
118    }
119
120    async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
121        self.inner.on_tool_call(tool_call, ctx).await
122    }
123
124    async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
125        self.inner.on_tool_start(tool_call, ctx).await
126    }
127
128    async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
129        self.inner.on_tool_result(tool_call, result, ctx).await
130    }
131
132    async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
133        self.inner.on_tool_error(tool_call, err, ctx).await
134    }
135    async fn on_agent_shutdown(&self) {
136        self.inner.on_agent_shutdown().await
137    }
138}
139
140/// Implementation of AgentExecutor for the BasicExecutorWrapper
141#[async_trait]
142impl<T: AgentDeriveT> AgentExecutor for BasicAgent<T> {
143    type Output = BasicAgentOutput;
144    type Error = BasicExecutorError;
145
146    fn config(&self) -> ExecutorConfig {
147        ExecutorConfig { max_turns: 1 }
148    }
149
150    async fn execute(
151        &self,
152        task: &Task,
153        context: Arc<Context>,
154    ) -> Result<Self::Output, Self::Error> {
155        let tx_event = context.tx().ok();
156        EventHelper::send_task_started(
157            &tx_event,
158            task.submission_id,
159            context.config().id,
160            task.prompt.clone(),
161            context.config().name.clone(),
162        )
163        .await;
164
165        let mut messages = vec![ChatMessage {
166            role: ChatRole::System,
167            message_type: MessageType::Text,
168            content: context.config().description.clone(),
169        }];
170
171        let chat_msg = ChatMessage {
172            role: ChatRole::User,
173            message_type: MessageType::Text,
174            content: task.prompt.clone(),
175        };
176        messages.push(chat_msg);
177        let response = context
178            .llm()
179            .chat(&messages, None, context.config().output_schema.clone())
180            .await
181            .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
182        let response_text = response.text().unwrap_or_default();
183        Ok(BasicAgentOutput {
184            response: response_text,
185            done: true,
186        })
187    }
188
189    async fn execute_stream(
190        &self,
191        task: &Task,
192        context: Arc<Context>,
193    ) -> Result<Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>, Self::Error>
194    {
195        use futures::StreamExt;
196
197        let tx_event = context.tx().ok();
198        EventHelper::send_task_started(
199            &tx_event,
200            task.submission_id,
201            context.config().id,
202            task.prompt.clone(),
203            context.config().name.clone(),
204        )
205        .await;
206
207        let mut messages = vec![ChatMessage {
208            role: ChatRole::System,
209            message_type: MessageType::Text,
210            content: context.config().description.clone(),
211        }];
212
213        let chat_msg = ChatMessage {
214            role: ChatRole::User,
215            message_type: MessageType::Text,
216            content: task.prompt.clone(),
217        };
218        messages.push(chat_msg);
219
220        let stream = context
221            .llm()
222            .chat_stream_struct(&messages, None, context.config().output_schema.clone())
223            .await
224            .map_err(|e| BasicExecutorError::LLMError(e.to_string()))?;
225
226        let mapped_stream = stream.map(|chunk_result| match chunk_result {
227            Ok(chunk) => {
228                let content = chunk
229                    .choices
230                    .first()
231                    .and_then(|choice| choice.delta.content.as_ref())
232                    .map_or("", |v| v)
233                    .to_string();
234
235                Ok(BasicAgentOutput {
236                    response: content,
237                    done: false,
238                })
239            }
240            Err(e) => Err(BasicExecutorError::LLMError(e.to_string())),
241        });
242
243        Ok(Box::pin(mapped_stream))
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentDeriveT;
    use crate::tests::agent::MockAgentImpl;
    use autoagents_test_utils::llm::MockLLMProvider;
    use std::sync::Arc;

    const AGENT_NAME: &str = "test_agent";
    const AGENT_DESCRIPTION: &str = "Test agent description";

    /// Builds the mock-backed wrapper that every test starts from.
    fn sample_agent() -> BasicAgent<MockAgentImpl> {
        BasicAgent::new(MockAgentImpl::new(AGENT_NAME, AGENT_DESCRIPTION))
    }

    /// The wrapper reports the name and description of the wrapped agent.
    #[test]
    fn test_basic_agent_creation() {
        let agent = sample_agent();

        assert_eq!(agent.name(), "test_agent");
        assert_eq!(agent.description(), "Test agent description");
    }

    /// A clone reports the same name and description as the original.
    #[test]
    fn test_basic_agent_clone() {
        let agent = sample_agent();
        let duplicate = agent.clone();

        assert_eq!(duplicate.name(), "test_agent");
        assert_eq!(duplicate.description(), "Test agent description");
    }

    /// The output converts to a JSON object and to its response text.
    #[test]
    fn test_basic_agent_output_conversions() {
        let output = BasicAgentOutput {
            response: "Test response".to_string(),
            done: true,
        };

        // Conversion to Value
        let as_json: Value = output.clone().into();
        assert!(as_json.is_object());

        // Conversion to String
        let as_text: String = output.into();
        assert_eq!(as_text, "Test response");
    }

    /// `execute` returns the mock provider's response and marks it as done.
    #[tokio::test]
    async fn test_basic_agent_execute() {
        use crate::agent::task::Task;
        use crate::agent::{AgentConfig, Context};
        use crate::protocol::ActorID;

        let agent = sample_agent();

        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: AGENT_NAME.to_string(),
            description: AGENT_DESCRIPTION.to_string(),
            output_schema: None,
        };
        let llm = Arc::new(MockLLMProvider {});
        let context = Arc::new(Context::new(llm, None).with_config(config));

        let task = Task::new("Test task");
        let result = agent.execute(&task, context).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.response, "Mock response");
        assert!(output.done);
    }

    /// The Basic executor is configured for exactly one turn.
    #[test]
    fn test_executor_config() {
        let executor_config = sample_agent().config();
        assert_eq!(executor_config.max_turns, 1);
    }
}