autoagents_core/agent/executor/
mod.rs1pub mod event_helper;
2pub mod memory_helper;
3pub mod tool_processor;
4
5use crate::agent::context::Context;
6use crate::agent::task::Task;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use std::error::Error;
12use std::fmt::Debug;
13use std::sync::Arc;
14
/// Outcome of a single executor turn.
///
/// NOTE(review): the semantics described on the variants are inferred from
/// their names; confirm against the executors that construct `TurnResult`.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers can copy and compare
/// turn outcomes; each impl only requires the matching bound on `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnResult<T> {
    /// The run is not finished yet; an intermediate value may be attached.
    Continue(Option<T>),
    /// The run finished with the given final value.
    Complete(T),
}
23
/// Settings an [`AgentExecutor`] reports through [`AgentExecutor::config`].
#[derive(Clone, Debug)]
pub struct ExecutorConfig {
    /// Maximum number of turns configured for a run. How (and whether) this
    /// limit is enforced is up to each executor implementation.
    pub max_turns: usize,
}

impl Default for ExecutorConfig {
    /// Builds a configuration whose `max_turns` is 10.
    fn default() -> Self {
        let max_turns = 10;
        ExecutorConfig { max_turns }
    }
}
35
36#[async_trait]
41pub trait AgentExecutor: Send + Sync + 'static {
42 type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Debug;
43 type Error: Error + Send + Sync + 'static;
44
45 fn config(&self) -> ExecutorConfig;
46
47 async fn execute(
48 &self,
49 task: &Task,
50 context: Arc<Context>,
51 ) -> Result<Self::Output, Self::Error>;
52
53 async fn execute_stream(
54 &self,
55 task: &Task,
56 context: Arc<Context>,
57 ) -> Result<
58 std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
59 Self::Error,
60 > {
61 let context_clone = context.clone();
63 let result = self.execute(task, context_clone).await;
64 let stream = futures::stream::iter(vec![result]);
65 Ok(Box::pin(stream))
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::context::Context;
    use crate::agent::task::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        LLMProvider, ToolCall,
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
    };
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Executor stub whose `execute` either echoes the task prompt or fails.
    ///
    /// It deliberately does NOT override `execute_stream`, so the stream tests
    /// below exercise the trait's default implementation.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub returning fixed responses; only needed to build a `Context`.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        match result {
            TurnResult::Continue(Some(data)) => assert_eq!(data, "partial"),
            _ => panic!("Expected Continue variant"),
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        match result {
            TurnResult::Continue(None) => {}
            _ => panic!("Expected Continue(None) variant"),
        }
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        match result {
            TurnResult::Complete(data) => assert_eq!(data, "final"),
            _ => panic!("Expected Complete variant"),
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        let output = result.expect("mock execution should succeed");
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let result = executor.execute(&task, Arc::new(context)).await;

        let error = result.expect_err("mock execution should fail");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    /// The trait's default `execute_stream` yields exactly one `Ok` item.
    #[tokio::test]
    async fn test_default_execute_stream_yields_single_output() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("stream task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let mut stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream never returns the outer error");

        let first = stream.next().await.expect("stream should yield one item");
        let output = first.expect("mock execution should succeed");
        assert_eq!(output.message, "Processed: stream task");
        assert!(stream.next().await.is_none(), "stream should end after one item");
    }

    /// A failing `execute` is surfaced as the stream's item, not as the outer error.
    #[tokio::test]
    async fn test_default_execute_stream_surfaces_error_as_item() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("stream task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let mut stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream never returns the outer error");

        let first = stream.next().await.expect("stream should yield one item");
        let error = first.expect_err("mock execution should fail");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
        assert!(stream.next().await.is_none(), "stream should end after one item");
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}