autoagents_core/agent/executor/mod.rs

1pub mod event_helper;
2pub mod memory_helper;
3pub mod tool_processor;
4
5use crate::agent::context::Context;
6use crate::agent::task::Task;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::error::Error;
12use std::fmt::Debug;
13use std::sync::Arc;
14
/// Outcome of handling one turn of an agent's execution.
///
/// `Continue` signals that processing should go on (optionally carrying a
/// value produced during this turn); `Complete` carries the final result.
#[derive(Debug)]
pub enum TurnResult<T> {
    /// Keep processing; holds intermediate data if this turn produced any.
    Continue(Option<T>),
    /// Processing is finished; holds the final result.
    Complete(T),
}
23
/// Settings that describe how an executor should run.
///
/// Returned by [`AgentExecutor::config`].
#[derive(Clone, Debug)]
pub struct ExecutorConfig {
    /// Maximum number of turns for an execution. Nothing in this module
    /// enforces the limit; honouring it is left to executor implementations.
    pub max_turns: usize,
}
29
30impl Default for ExecutorConfig {
31    fn default() -> Self {
32        Self { max_turns: 10 }
33    }
34}
35
36/// Base trait for agent execution strategies
37///
38/// Executors are responsible for implementing the specific execution logic
39/// for agents, such as ReAct loops, chain-of-thought, or custom patterns.
40#[async_trait]
41pub trait AgentExecutor: Send + Sync + 'static {
42    type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Debug;
43    type Error: Error + Send + Sync + 'static;
44
45    fn config(&self) -> ExecutorConfig;
46
47    async fn execute(
48        &self,
49        task: &Task,
50        context: Arc<Context>,
51    ) -> Result<Self::Output, Self::Error>;
52
53    async fn execute_stream(
54        &self,
55        task: &Task,
56        context: Arc<Context>,
57    ) -> Result<
58        std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
59        Self::Error,
60    > {
61        // Default fallback to self.execute with final result as a single-item stream
62        let context_clone = context.clone();
63        let result = self.execute(task, context_clone).await;
64        let stream = futures::stream::iter(vec![result]);
65        Ok(Box::pin(stream))
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::context::Context;
    use crate::agent::task::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
        LLMProvider, ToolCall,
    };
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Serializable output produced by [`MockExecutor`].
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    /// Error returned by [`MockExecutor`] when it is configured to fail.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Minimal executor that echoes the task prompt, or fails when `should_fail` is set.
    ///
    /// It deliberately does NOT override `execute_stream`, so the streaming tests
    /// below exercise the trait's default implementation.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        /// Creates a mock with `max_turns = 5`.
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        /// Creates a succeeding mock with a custom turn limit.
        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub returning fixed responses; only used to build a `Context`.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response stub carrying optional text and no tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        match result {
            TurnResult::Continue(Some(data)) => assert_eq!(data, "partial"),
            _ => panic!("Expected Continue variant"),
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        match result {
            TurnResult::Continue(None) => {}
            _ => panic!("Expected Continue(None) variant"),
        }
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        match result {
            TurnResult::Complete(data) => assert_eq!(data, "final"),
            _ => panic!("Expected Complete variant"),
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_single_ok_item() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("stream task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream should not fail up front");
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        let output = items
            .into_iter()
            .next()
            .expect("stream should yield exactly one item")
            .expect("successful execution should yield an Ok item");
        assert_eq!(output.message, "Processed: stream task");
    }

    #[tokio::test]
    async fn test_default_execute_stream_wraps_error_as_item() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("stream task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        // The default implementation reports execution failures as a stream
        // item, not as an up-front `Err`.
        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream should not fail up front");
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        let error = items
            .into_iter()
            .next()
            .expect("stream should yield exactly one item")
            .expect_err("failed execution should yield an Err item");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}