Skip to main content

autoagents_core/agent/executor/
mod.rs

1pub mod event_helper;
2pub mod memory_helper;
3pub mod memory_policy;
4pub mod tool_processor;
5pub mod turn_engine;
6
7use crate::agent::context::Context;
8use crate::agent::task::Task;
9use async_trait::async_trait;
10use futures::Stream;
11use serde::Serialize;
12use serde::de::DeserializeOwned;
13use std::error::Error;
14use std::fmt::Debug;
15use std::sync::Arc;
16
/// Outcome of processing a single turn in the agent's execution.
///
/// This type only describes the outcome; how a caller reacts to each variant is
/// decided by the code driving the turns, not here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnResult<T> {
    /// Continue processing; carries intermediate data when the turn produced any.
    Continue(Option<T>),
    /// The final result has been obtained.
    Complete(T),
}
25
/// Configuration for executors.
///
/// Executors hand this out through `AgentExecutor::config`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorConfig {
    /// Turn limit for an execution. This module only stores the value; enforcing it
    /// is left to the executor implementations.
    pub max_turns: usize,
}

impl Default for ExecutorConfig {
    /// Returns a configuration with `max_turns` set to 10.
    fn default() -> Self {
        Self { max_turns: 10 }
    }
}
37
38/// Base trait for agent execution strategies
39///
40/// Executors are responsible for implementing the specific execution logic
41/// for agents, such as ReAct loops, chain-of-thought, or custom patterns.
42#[async_trait]
43pub trait AgentExecutor: Send + Sync + 'static {
44    type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Debug;
45    type Error: Error + Send + Sync + 'static;
46
47    fn config(&self) -> ExecutorConfig;
48
49    async fn execute(
50        &self,
51        task: &Task,
52        context: Arc<Context>,
53    ) -> Result<Self::Output, Self::Error>;
54
55    async fn execute_stream(
56        &self,
57        task: &Task,
58        context: Arc<Context>,
59    ) -> Result<
60        std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
61        Self::Error,
62    > {
63        // Default fallback to self.execute with final result as a single-item stream
64        let context_clone = context.clone();
65        let result = self.execute(task, context_clone).await;
66        let stream = futures::stream::iter(vec![result]);
67        Ok(Box::pin(stream))
68    }
69}
70
/// Unit tests for the executor primitives defined in this module.
///
/// The mocks below are deliberately minimal: they return canned values so the tests
/// can build a `Context` and drive an `AgentExecutor` without any real LLM backend.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::context::Context;
    use crate::agent::task::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        LLMProvider, ToolCall,
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
    };
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Output type produced by `MockExecutor`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        /// Converts the output to JSON, falling back to `Value::Null` if serialization fails.
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    /// Error type produced by `MockExecutor`.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Executor whose `execute` either echoes the task prompt or fails, depending on
    /// `should_fail`.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        /// Creates a mock with a turn limit of 5 that fails when `should_fail` is set.
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        /// Creates a succeeding mock with the given turn limit.
        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    // `execute_stream` is intentionally NOT overridden here, so the stream tests below
    // exercise the default implementation provided by the `AgentExecutor` trait.
    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub that answers every request with a fixed value.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        // Returns one fixed vector regardless of how many inputs were passed.
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response stub carrying optional text and never any tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        let TurnResult::Continue(Some(data)) = result else {
            panic!("Expected Continue variant");
        };
        assert_eq!(data, "partial");
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        assert!(
            matches!(result, TurnResult::Continue(None)),
            "Expected Continue(None) variant"
        );
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        let TurnResult::Complete(data) = result else {
            panic!("Expected Complete variant");
        };
        assert_eq!(data, "final");
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        // The receiver is kept alive for the duration of the test.
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let output = executor
            .execute(&task, Arc::new(context))
            .await
            .expect("mock executor was configured to succeed");

        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let error = executor
            .execute(&task, Arc::new(context))
            .await
            .expect_err("mock executor was configured to fail");

        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_single_result() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream should return a stream");
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        let output = items
            .into_iter()
            .next()
            .expect("length was checked above")
            .expect("mock executor was configured to succeed");
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_default_execute_stream_forwards_error_as_item() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        // The failure of `execute` must surface inside the stream, not as the outer `Err`.
        let mut stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream should return a stream even on failure");

        let error = stream
            .next()
            .await
            .expect("stream should yield exactly one item")
            .expect_err("mock executor was configured to fail");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}