autoagents_core/agent/executor/
mod.rs1pub mod event_helper;
2pub mod memory_helper;
3pub mod memory_policy;
4pub mod tool_processor;
5pub mod turn_engine;
6
use crate::agent::context::Context;
use crate::agent::task::Task;
use async_trait::async_trait;
use futures::Stream;
use serde::Serialize;
use serde::de::DeserializeOwned;
use std::error::Error;
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;

/// Outcome of processing one turn of agent execution.
///
/// A turn either asks for execution to continue (optionally carrying a
/// partial value produced so far) or reports that it has completed with a
/// final value of type `T`.
#[derive(Debug)]
pub enum TurnResult<T> {
    /// Execution continues; may carry an intermediate value.
    Continue(Option<T>),
    /// Execution completed with this final value.
    Complete(T),
}

/// Configuration reported by an executor through `AgentExecutor::config`.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorConfig {
    /// Maximum number of turns. This type only stores the value; enforcing
    /// it is left to executor implementations.
    pub max_turns: usize,
}

impl ExecutorConfig {
    /// Turn limit used by [`ExecutorConfig::default`].
    pub const DEFAULT_MAX_TURNS: usize = 10;
}

impl Default for ExecutorConfig {
    /// Returns a config whose `max_turns` is [`ExecutorConfig::DEFAULT_MAX_TURNS`] (10).
    fn default() -> Self {
        Self {
            max_turns: Self::DEFAULT_MAX_TURNS,
        }
    }
}

38#[async_trait]
43pub trait AgentExecutor: Send + Sync + 'static {
44 type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Debug;
45 type Error: Error + Send + Sync + 'static;
46
47 fn config(&self) -> ExecutorConfig;
48
49 async fn execute(
50 &self,
51 task: &Task,
52 context: Arc<Context>,
53 ) -> Result<Self::Output, Self::Error>;
54
55 async fn execute_stream(
56 &self,
57 task: &Task,
58 context: Arc<Context>,
59 ) -> Result<
60 std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
61 Self::Error,
62 > {
63 let context_clone = context.clone();
65 let result = self.execute(task, context_clone).await;
66 let stream = futures::stream::iter(vec![result]);
67 Ok(Box::pin(stream))
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::context::Context;
    use crate::agent::task::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        LLMProvider, ToolCall,
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
    };
    use futures::{StreamExt, stream};
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Executor stub: echoes the task prompt, or fails when `should_fail` is set.
    /// Overrides `execute_stream` with its own single-item stream.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }

        async fn execute_stream(
            &self,
            task: &Task,
            context: Arc<Context>,
        ) -> Result<
            std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
            Self::Error,
        > {
            // `context` is already owned; move it instead of cloning the Arc.
            let result = self.execute(task, context).await;
            let stream = stream::once(async move { result });
            Ok(Box::pin(stream))
        }
    }

    /// Delegates `execute` to a [`MockExecutor`] but does NOT override
    /// `execute_stream`, so tests exercise the trait's default implementation.
    struct DefaultStreamExecutor(MockExecutor);

    #[async_trait]
    impl AgentExecutor for DefaultStreamExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            self.0.config()
        }

        async fn execute(
            &self,
            task: &Task,
            context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            self.0.execute(task, context).await
        }
    }

    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        match result {
            TurnResult::Continue(Some(data)) => assert_eq!(data, "partial"),
            _ => panic!("Expected Continue variant"),
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        match result {
            TurnResult::Continue(None) => {}
            _ => panic!("Expected Continue(None) variant"),
        }
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        match result {
            TurnResult::Complete(data) => assert_eq!(data, "final"),
            _ => panic!("Expected Complete variant"),
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        assert!(result.is_err());
        let error = result.unwrap_err();
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[tokio::test]
    async fn test_mock_executor_stream_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("execute_stream should return a stream");
        let items: Vec<_> = stream.collect().await;

        assert_eq!(items.len(), 1);
        let output = items
            .into_iter()
            .next()
            .expect("stream should yield one item")
            .expect("stream item should be Ok");
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_single_result() {
        let executor = DefaultStreamExecutor(MockExecutor::new(false));
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let mut stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream should return a stream");

        let first = stream
            .next()
            .await
            .expect("stream should yield one item")
            .expect("stream item should be Ok");
        assert_eq!(first.message, "Processed: test task");
        assert!(stream.next().await.is_none());
    }

    #[tokio::test]
    async fn test_default_execute_stream_yields_error_as_item() {
        let executor = DefaultStreamExecutor(MockExecutor::new(true));
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        // The default implementation reports failures as a stream item,
        // so the outer result is still Ok.
        let mut stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream should return a stream");

        let error = stream
            .next()
            .await
            .expect("stream should yield one item")
            .unwrap_err();
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}
415}