autoagents_core/agent/executor/
mod.rs1pub mod event_helper;
2pub mod memory_helper;
3pub mod tool_processor;
4
5use crate::agent::context::Context;
6use crate::agent::task::Task;
7use async_trait::async_trait;
8use futures::Stream;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::error::Error;
12use std::fmt::Debug;
13use std::sync::Arc;
14
/// Outcome of a single executor turn.
///
/// Derives `Clone`/`PartialEq`/`Eq` (each bounded on `T`) so turn results can be
/// duplicated and compared directly instead of only pattern-matched.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnResult<T> {
    /// More turns are needed; may carry an intermediate (partial) value.
    Continue(Option<T>),
    /// Execution has finished and carries the final value.
    Complete(T),
}
23
/// Configuration reported by an executor through `AgentExecutor::config`.
#[derive(Clone, Debug)]
pub struct ExecutorConfig {
    /// Upper bound on the number of turns an executor should run.
    /// Enforcement is left to each implementation (NOTE(review): confirm callers honour it).
    pub max_turns: usize,
}

impl Default for ExecutorConfig {
    /// Builds a configuration that allows up to ten turns.
    fn default() -> Self {
        const DEFAULT_MAX_TURNS: usize = 10;

        ExecutorConfig {
            max_turns: DEFAULT_MAX_TURNS,
        }
    }
}
35
36#[async_trait]
41pub trait AgentExecutor: Send + Sync + 'static {
42 type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Debug;
43 type Error: Error + Send + Sync + 'static;
44
45 fn config(&self) -> ExecutorConfig;
46
47 async fn execute(
48 &self,
49 task: &Task,
50 context: Arc<Context>,
51 ) -> Result<Self::Output, Self::Error>;
52
53 async fn execute_stream(
54 &self,
55 task: &Task,
56 context: Arc<Context>,
57 ) -> Result<
58 std::pin::Pin<Box<dyn Stream<Item = Result<Self::Output, Self::Error>> + Send>>,
59 Self::Error,
60 > {
61 let context_clone = context.clone();
63 let result = self.execute(task, context_clone).await;
64 let stream = futures::stream::iter(vec![result]);
65 Ok(Box::pin(stream))
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::context::Context;
    use crate::agent::task::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
        LLMProvider, ToolCall,
    };
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::Value;
    use std::sync::Arc;
    use tokio::sync::mpsc;

    /// Serializable payload produced by `MockExecutor`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        fn from(output: TestOutput) -> Self {
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    /// Error type used by `MockExecutor`; it wraps a message and has no source.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Minimal executor that echoes the task prompt, or fails when `should_fail` is set.
    ///
    /// It deliberately does NOT override `execute_stream`, so the stream tests below
    /// exercise the trait's default implementation.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            task: &Task,
            _context: Arc<Context>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub returning fixed canned responses.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response stub carrying optional text and no tool calls.
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Debug for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "MockChatResponse")
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        match result {
            TurnResult::Continue(Some(data)) => assert_eq!(data, "partial"),
            _ => panic!("Expected Continue variant"),
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        match result {
            TurnResult::Continue(None) => {}
            _ => panic!("Expected Continue(None) variant"),
        }
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        match result {
            TurnResult::Complete(data) => assert_eq!(data, "final"),
            _ => panic!("Expected Complete variant"),
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let output = executor
            .execute(&task, Arc::new(context))
            .await
            .expect("execution should succeed");

        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("test task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let error = executor
            .execute(&task, Arc::new(context))
            .await
            .expect_err("execution should fail");

        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[tokio::test]
    async fn test_default_execute_stream_success() {
        let executor = MockExecutor::new(false);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("stream task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream never fails up front");
        let mut items = stream.collect::<Vec<_>>().await;

        // The default implementation wraps exactly one `execute` result.
        assert_eq!(items.len(), 1);
        let output = items
            .pop()
            .expect("one item was asserted above")
            .expect("stream item should be Ok");
        assert_eq!(output.message, "Processed: stream task");
    }

    #[tokio::test]
    async fn test_default_execute_stream_failure() {
        let executor = MockExecutor::new(true);
        let llm = Arc::new(MockLLMProvider);
        let task = Task::new("stream task");
        let (tx_event, _rx_event) = mpsc::channel(100);
        let context = Context::new(llm, Some(tx_event));

        // Execution errors are delivered as the stream item, not as the outer result.
        let stream = executor
            .execute_stream(&task, Arc::new(context))
            .await
            .expect("default execute_stream never fails up front");
        let mut items = stream.collect::<Vec<_>>().await;

        assert_eq!(items.len(), 1);
        let error = items
            .pop()
            .expect("one item was asserted above")
            .expect_err("stream item should be Err");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        assert!(serialized.contains("test message"));
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}