1use super::dependency_graph::DependencyGraph;
4use super::types::{BatchMode, BatchOperation, BatchRequest, BatchResponse, OperationResult};
5use serde_json::Value;
6use std::collections::HashSet;
7
/// Stateless executor that runs a [`BatchRequest`] and produces a [`BatchResponse`].
///
/// The type carries no data, so it is cheap to copy and print.
#[derive(Debug, Clone, Copy)]
pub struct BatchExecutor;
10
11impl BatchExecutor {
12 pub fn new() -> Self {
14 Self
15 }
16
17 pub async fn execute<F, Fut>(
19 &self,
20 request: BatchRequest,
21 executor_fn: F,
22 ) -> Result<BatchResponse, String>
23 where
24 F: Fn(String, Value) -> Fut + Send + Sync + Clone + 'static,
25 Fut: std::future::Future<Output = Result<Value, (i32, String)>> + Send,
26 {
27 let start_time = std::time::Instant::now();
28
29 let graph = DependencyGraph::new(request.operations)?;
31 let total_operations = graph.len();
32
33 let mut completed = HashSet::new();
34 let mut results = Vec::new();
35 let mut parallel_count = 0;
36 let mut sequential_count = 0;
37
38 match request.mode {
40 BatchMode::Sequential => {
41 for op in graph.operations_in_order() {
43 let result = self.execute_operation(&op, &executor_fn).await;
44 completed.insert(op.id.clone());
45 results.push(result);
46 sequential_count += 1;
47 }
48 }
49 BatchMode::FailFast => {
50 for op in graph.operations_in_order() {
52 let result = self.execute_operation(&op, &executor_fn).await;
53 let success = result.success;
54 completed.insert(op.id.clone());
55 results.push(result);
56 sequential_count += 1;
57
58 if !success {
59 break;
60 }
61 }
62 }
63 BatchMode::Parallel => {
64 while completed.len() < total_operations {
66 let ready = graph.get_ready_operations(&completed);
67
68 if ready.is_empty() {
69 break; }
71
72 let batch_size = ready.len().min(request.max_parallel);
74 let batch: Vec<_> = ready.into_iter().take(batch_size).collect();
75
76 let mut handles = Vec::new();
77 for op in batch {
78 let op_clone = op.clone();
79 let executor_fn_clone = executor_fn.clone();
80 let handle = tokio::spawn(async move {
81 Self::execute_single_operation(&op_clone, executor_fn_clone).await
82 });
83 handles.push((op.id.clone(), handle));
84 }
85
86 for (id, handle) in handles {
88 match handle.await {
89 Ok(result) => {
90 completed.insert(id);
91 results.push(result);
92 parallel_count += 1;
93 }
94 Err(e) => {
95 results.push(OperationResult {
97 id: id.clone(),
98 success: false,
99 result: None,
100 error: Some(super::types::OperationError {
101 code: -32603,
102 message: format!("Operation panicked: {}", e),
103 details: None,
104 }),
105 duration_ms: 0,
106 });
107 completed.insert(id);
108 }
109 }
110 }
111 }
112 }
113 }
114
115 let total_duration_ms = start_time.elapsed().as_millis() as u64;
116 let success_count = results.iter().filter(|r| r.success).count();
117 let failure_count = results.len() - success_count;
118
119 let avg_duration_ms = if !results.is_empty() {
120 results.iter().map(|r| r.duration_ms).sum::<u64>() as f64 / results.len() as f64
121 } else {
122 0.0
123 };
124
125 Ok(BatchResponse {
126 results,
127 total_duration_ms,
128 success_count,
129 failure_count,
130 stats: super::types::BatchStats {
131 total_operations,
132 parallel_executed: parallel_count,
133 sequential_executed: sequential_count,
134 avg_duration_ms,
135 },
136 })
137 }
138
139 async fn execute_single_operation<F, Fut>(
141 op: &BatchOperation,
142 executor_fn: F,
143 ) -> OperationResult
144 where
145 F: Fn(String, Value) -> Fut,
146 Fut: std::future::Future<Output = Result<Value, (i32, String)>>,
147 {
148 let start = std::time::Instant::now();
149
150 match executor_fn(op.tool.clone(), op.arguments.clone()).await {
151 Ok(result) => OperationResult {
152 id: op.id.clone(),
153 success: true,
154 result: Some(result),
155 error: None,
156 duration_ms: start.elapsed().as_millis() as u64,
157 },
158 Err((code, message)) => OperationResult {
159 id: op.id.clone(),
160 success: false,
161 result: None,
162 error: Some(super::types::OperationError {
163 code,
164 message,
165 details: None,
166 }),
167 duration_ms: start.elapsed().as_millis() as u64,
168 },
169 }
170 }
171
172 async fn execute_operation<F, Fut>(
174 &self,
175 op: &BatchOperation,
176 executor_fn: F,
177 ) -> OperationResult
178 where
179 F: Fn(String, Value) -> Fut,
180 Fut: std::future::Future<Output = Result<Value, (i32, String)>>,
181 {
182 Self::execute_single_operation(op, executor_fn).await
183 }
184}
185
186impl Default for BatchExecutor {
187 fn default() -> Self {
188 Self
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::sync::{Arc, Mutex};

    /// Builds a `BatchOperation` with null arguments and the given dependency ids.
    fn operation(id: &str, tool: &str, depends_on: &[&str]) -> BatchOperation {
        BatchOperation {
            id: id.to_string(),
            tool: tool.to_string(),
            arguments: Value::Null,
            depends_on: depends_on.iter().map(|dep| dep.to_string()).collect(),
        }
    }

    #[tokio::test]
    async fn test_execute_empty_batch() {
        let executor = BatchExecutor::new();
        let request = BatchRequest {
            operations: vec![],
            mode: BatchMode::Parallel,
            max_parallel: 10,
        };

        let result = executor
            .execute(request, |_, _| async { Ok(Value::Null) })
            .await
            .unwrap();

        assert!(result.results.is_empty());
        assert_eq!(result.success_count, 0);
        assert_eq!(result.failure_count, 0);
    }

    #[tokio::test]
    async fn test_execute_sequential_batch() {
        let executor = BatchExecutor::new();
        let request = BatchRequest {
            operations: vec![
                operation("op1", "tool1", &[]),
                operation("op2", "tool2", &[]),
            ],
            mode: BatchMode::Sequential,
            max_parallel: 10,
        };

        let call_count = Arc::new(Mutex::new(0));
        let counter = Arc::clone(&call_count);

        let result = executor
            .execute(request, move |_, _| {
                *counter.lock().unwrap() += 1;
                async move { Ok(Value::Null) }
            })
            .await
            .unwrap();

        assert_eq!(result.success_count, 2);
        assert_eq!(result.stats.sequential_executed, 2);
        assert_eq!(*call_count.lock().unwrap(), 2);
    }

    #[tokio::test]
    async fn test_execute_parallel_batch() {
        let executor = BatchExecutor::new();
        let request = BatchRequest {
            operations: vec![
                operation("op1", "tool1", &[]),
                operation("op2", "tool2", &[]),
            ],
            mode: BatchMode::Parallel,
            max_parallel: 10,
        };

        let start = std::time::Instant::now();
        let result = executor
            .execute(request, |_, _| async { Ok(Value::Null) })
            .await
            .unwrap();
        let duration = start.elapsed();

        assert_eq!(result.success_count, 2);
        assert_eq!(result.stats.parallel_executed, 2);
        assert!(duration.as_millis() < 100);
    }

    #[tokio::test]
    async fn test_execute_with_dependency() {
        let executor = BatchExecutor::new();
        let request = BatchRequest {
            operations: vec![
                operation("op1", "tool1", &[]),
                operation("op2", "tool2", &["op1"]),
            ],
            mode: BatchMode::Parallel,
            max_parallel: 10,
        };

        // The executor function receives the *tool* name, not the operation id.
        let finished = Arc::new(Mutex::new(Vec::<String>::new()));
        let recorder = Arc::clone(&finished);

        let result = executor
            .execute(request, move |tool, _| {
                let recorder = Arc::clone(&recorder);
                async move {
                    // Delay the dependency. If op2 did not wait for op1, op2 would
                    // finish first and the recorded order would be reversed.
                    if tool == "tool1" {
                        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                    }
                    recorder.lock().unwrap().push(tool);
                    Ok(Value::Null)
                }
            })
            .await
            .unwrap();

        assert_eq!(result.success_count, 2);
        assert_eq!(result.failure_count, 0);
        assert_eq!(
            *finished.lock().unwrap(),
            vec!["tool1".to_string(), "tool2".to_string()]
        );
        let ids: Vec<&str> = result.results.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["op1", "op2"]);
    }

    #[tokio::test]
    async fn test_fail_fast_mode() {
        let executor = BatchExecutor::new();
        let request = BatchRequest {
            operations: vec![
                operation("op1", "tool1", &[]),
                operation("op2", "tool2", &[]),
            ],
            mode: BatchMode::FailFast,
            max_parallel: 10,
        };

        let result = executor
            .execute(request, |tool, _| async move {
                Err((-32600, format!("Operation {} failed", tool)))
            })
            .await
            .unwrap();

        // Only the first operation runs; the batch stops on its failure.
        assert_eq!(result.results.len(), 1);
        assert!(!result.results[0].success);
        assert_eq!(result.failure_count, 1);
        assert_eq!(
            result.results[0].error.as_ref().map(|e| e.code),
            Some(-32600)
        );
    }
}
367}