// autoagents_core/agent/executor/mod.rs

1pub mod event_helper;
2pub mod memory_helper;
3pub mod tool_processor;
4
5use crate::agent::context::Context;
6use crate::agent::task::Task;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use std::error::Error;
12use std::fmt::Debug;
13use std::sync::Arc;
14
/// Result of processing a single turn in the agent's execution.
///
/// `Continue` signals that processing should go on, optionally carrying
/// intermediate data; `Complete` carries the final result.
// `Clone`/`PartialEq`/`Eq` are derived so callers can copy and compare turn
// outcomes; each derive only applies when `T` itself implements the trait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnResult<T> {
    /// Continue processing with optional intermediate data.
    Continue(Option<T>),
    /// Final result obtained.
    Complete(T),
}
23
/// Configuration for executors.
///
/// Returned by [`AgentExecutor::config`]. The [`Default`] configuration allows
/// 10 turns.
// `PartialEq`/`Eq` let callers and tests compare whole configurations instead
// of field-by-field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorConfig {
    /// Maximum number of turns the executor is configured to run.
    ///
    /// This type only carries the value; enforcing the limit is left to the
    /// `AgentExecutor` implementation that reads it.
    pub max_turns: usize,
}

impl Default for ExecutorConfig {
    /// Returns a configuration with `max_turns` set to 10.
    fn default() -> Self {
        Self { max_turns: 10 }
    }
}
35
36/// Base trait for agent execution strategies
37///
38/// Executors are responsible for implementing the specific execution logic
39/// for agents, such as ReAct loops, chain-of-thought, or custom patterns.
40#[async_trait]
41pub trait AgentExecutor: Send + Sync + 'static {
42    type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Debug;
43    type Error: Error + Send + Sync + 'static;
44
45    fn config(&self) -> ExecutorConfig;
46
47    async fn execute(
48        &self,
49        task: &Task,
50        context: Arc<Context>,
51    ) -> Result<Self::Output, Self::Error>;
52
53    async fn execute_stream(
54        &self,
55        task: &Task,
56        context: Arc<Context>,
57    ) -> Result<
58        std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
59        Self::Error,
60    > {
61        // Default fallback to self.execute with final result as a single-item stream
62        let context_clone = context.clone();
63        let result = self.execute(task, context_clone).await;
64        let stream = futures::stream::iter(vec![result]);
65        Ok(Box::pin(stream))
66    }
67}
68
#[cfg(test)]
mod tests {
    // Unit tests for this module's executor primitives: `ExecutorConfig`,
    // `TurnResult`, and `AgentExecutor` (including the trait's default
    // `execute_stream`). The tests are driven by in-memory mocks.

    use super::*;
    use crate::agent::context::Context;
    use crate::agent::task::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        LLMProvider, ToolCall,
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
    };
    // `StreamExt` supplies `collect`, used to drain the streams under test.
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Serializable output produced by `MockExecutor`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    /// Error returned by `MockExecutor` when it is configured to fail.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Executor that echoes the task prompt, or fails when `should_fail` is set.
    ///
    /// It deliberately does NOT override `execute_stream`, so the streaming
    /// tests below exercise the trait's default implementation.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        /// Creates an executor with `max_turns` = 5.
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        /// Creates a succeeding executor with a custom turn limit.
        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub that returns canned responses for every capability.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response stub carrying optional text and never any tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        match result {
            TurnResult::Continue(Some(data)) => assert_eq!(data, "partial"),
            _ => panic!("Expected Continue variant"),
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        match result {
            TurnResult::Continue(None) => {}
            _ => panic!("Expected Continue(None) variant"),
        }
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        match result {
            TurnResult::Complete(data) => assert_eq!(data, "final"),
            _ => panic!("Expected Complete variant"),
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_single_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        // Keep the receiver alive for the whole test so the sender stays usable.
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream reports failures as stream items");
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        match &items[0] {
            Ok(output) => assert_eq!(output.message, "Processed: test task"),
            Err(err) => panic!("Expected an Ok item, got error: {err}"),
        }
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_single_error() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        // The default implementation never fails eagerly: the execution error
        // must surface as the stream's only item.
        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream reports failures as stream items");
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        match &items[0] {
            Err(err) => assert_eq!(err.to_string(), "Test error: Mock execution failed"),
            Ok(output) => panic!("Expected an Err item, got {output:?}"),
        }
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}