1use crate::agent::base::AgentConfig;
2use crate::agent::runnable::AgentState;
3use crate::memory::MemoryProvider;
4use crate::protocol::Event;
5use crate::runtime::Task;
6use async_trait::async_trait;
7use autoagents_llm::{LLMProvider, ToolT};
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use serde_json::Value;
11use std::error::Error;
12use std::fmt::Debug;
13use std::sync::Arc;
14use tokio::sync::{mpsc, RwLock};
15
/// Outcome of a single turn of a multi-turn executor loop.
///
/// NOTE(review): how each variant is acted upon is decided by the executor
/// that produces it; this enum only carries the data.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnResult<T> {
    /// Another turn should follow; optionally carries an intermediate value.
    Continue(Option<T>),
    /// The work is finished; carries the final value.
    Complete(T),
}
24
/// Configuration shared by [`AgentExecutor`] implementations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutorConfig {
    /// Maximum number of turns an executor should run for a single task.
    ///
    /// NOTE(review): this type only carries the limit; enforcing it is up to
    /// each executor implementation.
    pub max_turns: usize,
}

impl ExecutorConfig {
    /// Turn limit used by [`ExecutorConfig::default`].
    pub const DEFAULT_MAX_TURNS: usize = 10;
}

impl Default for ExecutorConfig {
    /// Returns a config limited to [`ExecutorConfig::DEFAULT_MAX_TURNS`] turns.
    fn default() -> Self {
        Self {
            max_turns: Self::DEFAULT_MAX_TURNS,
        }
    }
}
36
#[async_trait]
/// Executes [`Task`]s on behalf of an agent.
///
/// Implementors decide how to use the LLM, memory and tools handed to
/// [`AgentExecutor::execute`]; the trait fixes only the inputs and the typed
/// output/error of an execution.
pub trait AgentExecutor: Send + Sync + 'static {
    /// Value produced by a successful execution. It must be serde
    /// (de)serializable, cloneable, thread-safe and convertible into a
    /// `serde_json::Value` via `Into<Value>`.
    type Output: Serialize + DeserializeOwned + Clone + Send + Sync + Into<Value>;
    /// Error returned when an execution fails; a thread-safe `std::error::Error`.
    type Error: Error + Send + Sync + 'static;

    /// Returns the [`ExecutorConfig`] for this executor.
    ///
    /// NOTE(review): whether and how `max_turns` is honoured is up to each
    /// implementor; this trait does not enforce it.
    fn config(&self) -> ExecutorConfig;

    /// Executes `task` for the agent described by `agent_config`.
    ///
    /// # Arguments
    ///
    /// * `llm` - LLM provider the executor may call.
    /// * `memory` - Optional memory provider shared behind a `tokio` `RwLock`;
    ///   `None` when no memory is supplied.
    /// * `tools` - Tools made available to this execution, passed by value.
    /// * `agent_config` - Configuration of the agent being run.
    /// * `task` - The task to execute (consumed).
    /// * `state` - Agent state shared behind a `tokio` `RwLock`.
    /// * `tx_event` - Sender for [`Event`]s; presumably used to report progress
    ///   to the runtime. NOTE(review): confirm which events implementors must emit.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when execution fails; the concrete failure modes
    /// are defined by each implementor.
    // `self` plus seven inputs exceeds clippy's default `too_many_arguments`
    // threshold of 7; the lint is silenced rather than bundling the inputs.
    #[allow(clippy::too_many_arguments)]
    async fn execute(
        &self,
        llm: Arc<dyn LLMProvider>,
        memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
        tools: Vec<Box<dyn ToolT>>,
        agent_config: &AgentConfig,
        task: Task,
        state: Arc<RwLock<AgentState>>,
        tx_event: mpsc::Sender<Event>,
    ) -> Result<Self::Output, Self::Error>;
}
62
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::base::AgentConfig;
    use crate::agent::runnable::AgentState;
    use crate::memory::MemoryProvider;
    use crate::protocol::Event;
    use crate::runtime::Task;
    use async_trait::async_trait;
    use autoagents_llm::{
        chat::{ChatMessage, ChatProvider, ChatResponse, StructuredOutputFormat},
        completion::{CompletionProvider, CompletionRequest, CompletionResponse},
        embedding::EmbeddingProvider,
        error::LLMError,
        models::ModelsProvider,
        LLMProvider, ToolCall, ToolT,
    };
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;
    use tokio::sync::{mpsc, RwLock};
    use uuid::Uuid;

    /// Minimal serde-compatible output produced by [`MockExecutor`].
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestOutput {
        message: String,
    }

    impl From<TestOutput> for Value {
        fn from(output: TestOutput) -> Self {
            // Satisfies the `Into<Value>` bound on `AgentExecutor::Output`;
            // a serialization failure degrades to `Value::Null`.
            serde_json::to_value(output).unwrap_or(Value::Null)
        }
    }

    /// Error returned by [`MockExecutor`] when it is configured to fail.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    /// Executor stub that echoes the task prompt, or fails on demand.
    struct MockExecutor {
        should_fail: bool,
        max_turns: usize,
    }

    impl MockExecutor {
        /// Creates a mock with the test default turn limit of 5.
        fn new(should_fail: bool) -> Self {
            Self {
                should_fail,
                max_turns: 5,
            }
        }

        /// Creates a succeeding mock with a custom turn limit.
        fn with_max_turns(max_turns: usize) -> Self {
            Self {
                should_fail: false,
                max_turns,
            }
        }
    }

    #[async_trait]
    impl AgentExecutor for MockExecutor {
        type Output = TestOutput;
        type Error = TestError;

        fn config(&self) -> ExecutorConfig {
            ExecutorConfig {
                max_turns: self.max_turns,
            }
        }

        async fn execute(
            &self,
            _llm: Arc<dyn LLMProvider>,
            _memory: Option<Arc<RwLock<Box<dyn MemoryProvider>>>>,
            _tools: Vec<Box<dyn ToolT>>,
            _agent_config: &AgentConfig,
            task: Task,
            _state: Arc<RwLock<AgentState>>,
            _tx_event: mpsc::Sender<Event>,
        ) -> Result<Self::Output, Self::Error> {
            if self.should_fail {
                return Err(TestError::TestError("Mock execution failed".to_string()));
            }

            Ok(TestOutput {
                message: format!("Processed: {}", task.prompt),
            })
        }
    }

    /// LLM provider stub that returns fixed canned responses.
    struct MockLLMProvider;

    #[async_trait]
    impl ChatProvider for MockLLMProvider {
        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<Box<dyn ChatResponse>, LLMError> {
            Ok(Box::new(MockChatResponse {
                text: Some("Mock response".to_string()),
            }))
        }
    }

    #[async_trait]
    impl CompletionProvider for MockLLMProvider {
        async fn complete(
            &self,
            _req: &CompletionRequest,
            _json_schema: Option<StructuredOutputFormat>,
        ) -> Result<CompletionResponse, LLMError> {
            Ok(CompletionResponse {
                text: "Mock completion".to_string(),
            })
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockLLMProvider {
        async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(vec![vec![0.1, 0.2, 0.3]])
        }
    }

    #[async_trait]
    impl ModelsProvider for MockLLMProvider {}

    impl LLMProvider for MockLLMProvider {}

    /// Chat response stub carrying optional text and never any tool calls.
    #[derive(Debug)]
    struct MockChatResponse {
        text: Option<String>,
    }

    impl ChatResponse for MockChatResponse {
        fn text(&self) -> Option<String> {
            self.text.clone()
        }

        fn tool_calls(&self) -> Option<Vec<ToolCall>> {
            None
        }
    }

    impl std::fmt::Display for MockChatResponse {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.text.as_deref().unwrap_or(""))
        }
    }

    /// Builds the agent configuration shared by the executor tests.
    fn test_agent_config() -> AgentConfig {
        AgentConfig {
            name: "test".to_string(),
            id: Uuid::new_v4(),
            description: "test agent".to_string(),
            output_schema: None,
        }
    }

    /// Runs `executor` on a "test task" prompt with mock LLM, no memory and no tools.
    async fn run_mock_executor(executor: &MockExecutor) -> Result<TestOutput, TestError> {
        let llm = Arc::new(MockLLMProvider);
        let tools = Vec::new();
        let agent_config = test_agent_config();
        let task = Task::new("test task", None);
        let state = Arc::new(RwLock::new(AgentState::new()));
        // Keep the receiver alive for the duration of the call so sends cannot fail.
        let (tx_event, _rx_event) = mpsc::channel(100);

        executor
            .execute(llm, None, tools, &agent_config, task, state, tx_event)
            .await
    }

    #[test]
    fn test_executor_config_default() {
        let config = ExecutorConfig::default();
        assert_eq!(config.max_turns, 10);
    }

    #[test]
    fn test_executor_config_custom() {
        let config = ExecutorConfig { max_turns: 5 };
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_executor_config_clone() {
        let config = ExecutorConfig { max_turns: 15 };
        let cloned = config.clone();
        assert_eq!(config.max_turns, cloned.max_turns);
    }

    #[test]
    fn test_executor_config_debug() {
        let config = ExecutorConfig { max_turns: 20 };
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("ExecutorConfig"));
        assert!(debug_str.contains("20"));
    }

    #[test]
    fn test_turn_result_continue() {
        let result = TurnResult::<String>::Continue(Some("partial".to_string()));
        match result {
            TurnResult::Continue(Some(data)) => assert_eq!(data, "partial"),
            _ => panic!("Expected Continue variant"),
        }
    }

    #[test]
    fn test_turn_result_continue_none() {
        let result = TurnResult::<String>::Continue(None);
        match result {
            TurnResult::Continue(None) => {}
            _ => panic!("Expected Continue(None) variant"),
        }
    }

    #[test]
    fn test_turn_result_complete() {
        let result = TurnResult::Complete("final".to_string());
        match result {
            TurnResult::Complete(data) => assert_eq!(data, "final"),
            _ => panic!("Expected Complete variant"),
        }
    }

    #[test]
    fn test_turn_result_debug() {
        let result = TurnResult::Complete("test".to_string());
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("Complete"));
        assert!(debug_str.contains("test"));
    }

    #[tokio::test]
    async fn test_mock_executor_success() {
        // `expect` reports the actual error on failure, unlike `assert!(is_ok())`.
        let output = run_mock_executor(&MockExecutor::new(false))
            .await
            .expect("mock executor should succeed");
        assert_eq!(output.message, "Processed: test task");
    }

    #[tokio::test]
    async fn test_mock_executor_failure() {
        let error = run_mock_executor(&MockExecutor::new(true))
            .await
            .expect_err("mock executor should fail");
        assert_eq!(error.to_string(), "Test error: Mock execution failed");
    }

    #[test]
    fn test_mock_executor_config() {
        let executor = MockExecutor::with_max_turns(3);
        let config = executor.config();
        assert_eq!(config.max_turns, 3);
    }

    #[test]
    fn test_mock_executor_config_default() {
        let executor = MockExecutor::new(false);
        let config = executor.config();
        assert_eq!(config.max_turns, 5);
    }

    #[test]
    fn test_test_output_serialization() {
        let output = TestOutput {
            message: "test message".to_string(),
        };
        let serialized = serde_json::to_string(&output).unwrap();
        // Pin the exact JSON, not just a substring, so field renames are caught.
        assert_eq!(serialized, r#"{"message":"test message"}"#);
    }

    #[test]
    fn test_test_output_deserialization() {
        let json = r#"{"message":"test message"}"#;
        let output: TestOutput = serde_json::from_str(json).unwrap();
        assert_eq!(output.message, "test message");
    }

    #[test]
    fn test_test_output_clone() {
        let output = TestOutput {
            message: "original".to_string(),
        };
        let cloned = output.clone();
        assert_eq!(output.message, cloned.message);
    }

    #[test]
    fn test_test_output_debug() {
        let output = TestOutput {
            message: "debug test".to_string(),
        };
        let debug_str = format!("{output:?}");
        assert!(debug_str.contains("TestOutput"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_output_into_value() {
        let output = TestOutput {
            message: "value test".to_string(),
        };
        let value: Value = output.into();
        assert_eq!(value["message"], "value test");
    }

    #[test]
    fn test_test_error_display() {
        let error = TestError::TestError("display test".to_string());
        assert_eq!(error.to_string(), "Test error: display test");
    }

    #[test]
    fn test_test_error_debug() {
        let error = TestError::TestError("debug test".to_string());
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("TestError"));
        assert!(debug_str.contains("debug test"));
    }

    #[test]
    fn test_test_error_source() {
        let error = TestError::TestError("source test".to_string());
        assert!(error.source().is_none());
    }
}