autoagents_core/agent/
executor.rs

1use crate::agent::runnable::AgentState;
2use crate::memory::MemoryProvider;
3use crate::protocol::Event;
4use crate::runtime::Task;
5use crate::{agent::base::AgentConfig, tool::ToolT};
6use async_trait::async_trait;
7use autoagents_llm::LLMProvider;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use serde_json::Value;
11use std::error::Error;
12use std::fmt::Debug;
13use std::sync::Arc;
14use tokio::sync::{mpsc, RwLock};
15
/// Outcome of processing a single turn in the agent's execution.
///
/// `Clone`, `PartialEq` and `Eq` are derived (each available whenever `T`
/// implements the same trait) so that turn results can be duplicated and
/// compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnResult<T> {
    /// Continue processing with optional intermediate data
    Continue(Option<T>),
    /// Final result obtained
    Complete(T),
}
24
/// Configuration for executors
///
/// Derives `PartialEq`/`Eq` so two configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorConfig {
    /// Turn limit carried by this configuration.
    ///
    /// NOTE(review): nothing in this file enforces the limit; enforcement is
    /// presumably left to each `AgentExecutor` implementation — confirm
    /// against the concrete executors.
    pub max_turns: usize,
}

impl ExecutorConfig {
    /// Value of `max_turns` produced by [`ExecutorConfig::default`].
    pub const DEFAULT_MAX_TURNS: usize = 10;
}

impl Default for ExecutorConfig {
    /// Returns a configuration whose `max_turns` is
    /// [`ExecutorConfig::DEFAULT_MAX_TURNS`].
    fn default() -> Self {
        Self {
            max_turns: Self::DEFAULT_MAX_TURNS,
        }
    }
}
36
/// Base trait for agent execution strategies
///
/// Executors are responsible for implementing the specific execution logic
/// for agents, such as ReAct loops, chain-of-thought, or custom patterns.
///
/// Implementors must be `Send + Sync + 'static`. The `#[async_trait]`
/// attribute is what lets `execute` be declared as an `async fn` here.
#[async_trait]
pub trait AgentExecutor: Send + Sync + 'static {
    /// Value produced by a successful `execute` call.
    ///
    /// The bounds require it to be serde-serializable and -deserializable,
    /// cloneable, `Send + Sync`, and convertible into a `serde_json::Value`.
    type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Into<Value>;
    /// Error returned when `execute` fails; must implement
    /// `std::error::Error` and be `Send + Sync + 'static`.
    type Error: Error + Send + Sync + 'static;

    /// Get the configuration for this executor
    ///
    /// Returned by value, so each call yields an owned `ExecutorConfig`.
    fn config(&self) -> ExecutorConfig;

    /// Execute the agent with the given task
    ///
    /// # Arguments
    ///
    /// * `llm` - Shared handle to the LLM provider.
    /// * `memory` - Optional shared memory provider behind an async `RwLock`;
    ///   `None` when the caller supplies no memory.
    /// * `tools` - Owned list of boxed tools handed to the executor.
    /// * `agent_config` - Borrowed configuration of the agent being run.
    /// * `task` - The task to process, taken by value.
    /// * `state` - Shared agent state behind an async `RwLock`.
    /// * `tx_event` - Sender half of a tokio `mpsc` channel of `Event`s
    ///   (presumably for reporting progress — confirm against implementors).
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementation fails; the conditions
    /// are defined by each implementor.
    // Silences clippy's `too_many_arguments` lint: this method takes eight
    // parameters counting `&self`.
    #[allow(clippy::too_many_arguments)]
    async fn execute(
        &self,
        llm: Arc<dyn LLMProvider>,
        memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
        tools: Vec<Box<dyn ToolT>>,
        agent_config: &AgentConfig,
        task: Task,
        state: Arc<RwLock<AgentState>>,
        tx_event: mpsc::Sender<Event>,
    ) -> Result<Self::Output, Self::Error>;
}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::base::AgentConfig;
    use crate::agent::runnable::AgentState;
    use crate::memory::MemoryProvider;
    use crate::protocol::Event;
    use crate::runtime::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
        LLMProvider, ToolCall,
    };
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use tokio::sync::{mpsc, RwLock};
    use uuid::Uuid;

    /// Minimal output type satisfying the `AgentExecutor::Output` bounds.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        /// Serializes the output to JSON, falling back to `Value::Null` if
        /// serialization fails.
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    /// Error type returned by `MockExecutor`.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Executor stub that either echoes the task prompt or fails on demand.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        /// Creates a mock with `max_turns` fixed at 5.
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        /// Creates a non-failing mock with the given turn limit.
        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        /// Ignores every collaborator except `task`; returns an error when
        /// `should_fail` is set, otherwise echoes the prompt.
        async fn execute(
            &self,
            _llm: Arc<dyn LLMProvider>,
            _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
            _tools: Vec<Box<dyn ToolT>>,
            _agent_config: &AgentConfig,
            task: Task,
            _state: Arc<RwLock<AgentState>>,
            _tx_event: mpsc::Sender<Event>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub returning canned responses.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response stub carrying optional text and no tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    /// Builds the agent configuration shared by the executor tests.
    fn test_agent_config() -> AgentConfig {
        AgentConfig {
            name: "test".to_string(),
            id: Uuid::new_v4(),
            description: "test agent".to_string(),
            output_schema: None,
        }
    }

    /// Runs `executor` with the mock LLM, no memory, no tools, a fresh agent
    /// state and the prompt `"test task"`.
    async fn run_with_mocks(executor: &MockExecutor) -> Result<TestOutput, TestError> {
        let llm: Arc<dyn LLMProvider> = Arc::new(MockLLMProvider);
        let agent_config = test_agent_config();
        let task = Task::new("test task", None);
        let state = Arc::new(RwLock::new(AgentState::new()));
        // Bind the receiver so the channel stays open for the whole call.
        let (tx_event, _rx_event) = mpsc::channel(100);

        executor
            .execute(llm, None, Vec::new(), &agent_config, task, state, tx_event)
            .await
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        if let TurnResult::Continue(Some(data)) = result {
            assert_eq!(data, "partial");
        } else {
            panic!("Expected Continue variant");
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        assert!(
            matches!(result, TurnResult::Continue(None)),
            "Expected Continue(None) variant"
        );
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        if let TurnResult::Complete(data) = result {
            assert_eq!(data, "final");
        } else {
            panic!("Expected Complete variant");
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let output = run_with_mocks(&MockExecutor::new(false))
            .await
            .expect("mock executor should succeed when should_fail is false");
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let error = run_with_mocks(&MockExecutor::new(true))
            .await
            .expect_err("mock executor should fail when should_fail is true");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        assert_eq!(executor.config().max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        assert_eq!(executor.config().max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}