// do_memory_mcp/batch/executor.rs

//! Batch executor for managing parallel execution

use super::dependency_graph::DependencyGraph;
use super::types::{BatchMode, BatchOperation, BatchRequest, BatchResponse, OperationResult};
use serde_json::Value;
use std::collections::HashSet;
/// Stateless coordinator that runs batches of operations sequentially,
/// fail-fast, or in parallel with dependency ordering.
pub struct BatchExecutor;
10
11impl BatchExecutor {
12    /// Create a new batch executor
13    pub fn new() -> Self {
14        Self
15    }
16
17    /// Execute a batch of operations with dependency management
18    pub async fn execute<F, Fut>(
19        &self,
20        request: BatchRequest,
21        executor_fn: F,
22    ) -> Result<BatchResponse, String>
23    where
24        F: Fn(String, Value) -> Fut + Send + Sync + Clone + 'static,
25        Fut: std::future::Future<Output = Result<Value, (i32, String)>> + Send,
26    {
27        let start_time = std::time::Instant::now();
28
29        // Build dependency graph
30        let graph = DependencyGraph::new(request.operations)?;
31        let total_operations = graph.len();
32
33        let mut completed = HashSet::new();
34        let mut results = Vec::new();
35        let mut parallel_count = 0;
36        let mut sequential_count = 0;
37
38        // Execute operations based on mode
39        match request.mode {
40            BatchMode::Sequential => {
41                // Execute all operations sequentially in insertion order
42                for op in graph.operations_in_order() {
43                    let result = self.execute_operation(&op, &executor_fn).await;
44                    completed.insert(op.id.clone());
45                    results.push(result);
46                    sequential_count += 1;
47                }
48            }
49            BatchMode::FailFast => {
50                // Execute operations sequentially in insertion order, stop on first failure
51                for op in graph.operations_in_order() {
52                    let result = self.execute_operation(&op, &executor_fn).await;
53                    let success = result.success;
54                    completed.insert(op.id.clone());
55                    results.push(result);
56                    sequential_count += 1;
57
58                    if !success {
59                        break;
60                    }
61                }
62            }
63            BatchMode::Parallel => {
64                // Execute operations respecting dependencies
65                while completed.len() < total_operations {
66                    let ready = graph.get_ready_operations(&completed);
67
68                    if ready.is_empty() {
69                        break; // No more operations can be executed
70                    }
71
72                    // Execute ready operations in parallel (up to max_parallel)
73                    let batch_size = ready.len().min(request.max_parallel);
74                    let batch: Vec<_> = ready.into_iter().take(batch_size).collect();
75
76                    let mut handles = Vec::new();
77                    for op in batch {
78                        let op_clone = op.clone();
79                        let executor_fn_clone = executor_fn.clone();
80                        let handle = tokio::spawn(async move {
81                            Self::execute_single_operation(&op_clone, executor_fn_clone).await
82                        });
83                        handles.push((op.id.clone(), handle));
84                    }
85
86                    // Wait for all operations in this batch to complete
87                    for (id, handle) in handles {
88                        match handle.await {
89                            Ok(result) => {
90                                completed.insert(id);
91                                results.push(result);
92                                parallel_count += 1;
93                            }
94                            Err(e) => {
95                                // Task panicked
96                                results.push(OperationResult {
97                                    id: id.clone(),
98                                    success: false,
99                                    result: None,
100                                    error: Some(super::types::OperationError {
101                                        code: -32603,
102                                        message: format!("Operation panicked: {}", e),
103                                        details: None,
104                                    }),
105                                    duration_ms: 0,
106                                });
107                                completed.insert(id);
108                            }
109                        }
110                    }
111                }
112            }
113        }
114
115        let total_duration_ms = start_time.elapsed().as_millis() as u64;
116        let success_count = results.iter().filter(|r| r.success).count();
117        let failure_count = results.len() - success_count;
118
119        let avg_duration_ms = if !results.is_empty() {
120            results.iter().map(|r| r.duration_ms).sum::<u64>() as f64 / results.len() as f64
121        } else {
122            0.0
123        };
124
125        Ok(BatchResponse {
126            results,
127            total_duration_ms,
128            success_count,
129            failure_count,
130            stats: super::types::BatchStats {
131                total_operations,
132                parallel_executed: parallel_count,
133                sequential_executed: sequential_count,
134                avg_duration_ms,
135            },
136        })
137    }
138
139    /// Execute a single operation
140    async fn execute_single_operation<F, Fut>(
141        op: &BatchOperation,
142        executor_fn: F,
143    ) -> OperationResult
144    where
145        F: Fn(String, Value) -> Fut,
146        Fut: std::future::Future<Output = Result<Value, (i32, String)>>,
147    {
148        let start = std::time::Instant::now();
149
150        match executor_fn(op.tool.clone(), op.arguments.clone()).await {
151            Ok(result) => OperationResult {
152                id: op.id.clone(),
153                success: true,
154                result: Some(result),
155                error: None,
156                duration_ms: start.elapsed().as_millis() as u64,
157            },
158            Err((code, message)) => OperationResult {
159                id: op.id.clone(),
160                success: false,
161                result: None,
162                error: Some(super::types::OperationError {
163                    code,
164                    message,
165                    details: None,
166                }),
167                duration_ms: start.elapsed().as_millis() as u64,
168            },
169        }
170    }
171
172    /// Execute a single operation (instance method for backward compatibility)
173    async fn execute_operation<F, Fut>(
174        &self,
175        op: &BatchOperation,
176        executor_fn: F,
177    ) -> OperationResult
178    where
179        F: Fn(String, Value) -> Fut,
180        Fut: std::future::Future<Output = Result<Value, (i32, String)>>,
181    {
182        Self::execute_single_operation(op, executor_fn).await
183    }
184}
185
186impl Default for BatchExecutor {
187    fn default() -> Self {
188        Self
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::sync::Arc;

    /// An empty batch completes with zero successes and zero failures.
    #[tokio::test]
    async fn test_execute_empty_batch() {
        let executor = BatchExecutor::new();
        let request = BatchRequest {
            operations: vec![],
            mode: BatchMode::Parallel,
            max_parallel: 10,
        };

        let result = executor
            .execute(request, |_, _| async { Ok(Value::Null) })
            .await
            .unwrap();

        assert_eq!(result.success_count, 0);
        assert_eq!(result.failure_count, 0);
    }

    /// Sequential mode invokes the executor function exactly once per operation.
    #[tokio::test]
    async fn test_execute_sequential_batch() {
        let executor = BatchExecutor::new();
        let operations = vec![
            BatchOperation {
                id: "op1".to_string(),
                tool: "tool1".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
            BatchOperation {
                id: "op2".to_string(),
                tool: "tool2".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
        ];

        let request = BatchRequest {
            operations,
            mode: BatchMode::Sequential,
            max_parallel: 10,
        };

        let call_count = Arc::new(std::sync::Mutex::new(0));
        let call_count_clone = Arc::clone(&call_count);

        let result = executor
            .execute(request, move |_, _| {
                // Increment under a single lock acquisition. The previous
                // read-unlock-relock-write pattern was a non-atomic
                // read-modify-write that could drop counts under contention.
                *call_count_clone.lock().unwrap() += 1;
                async move { Ok(Value::Null) }
            })
            .await
            .unwrap();

        assert_eq!(result.success_count, 2);
        assert_eq!(*call_count.lock().unwrap(), 2);
    }

    /// Independent operations in parallel mode all succeed promptly.
    #[tokio::test]
    async fn test_execute_parallel_batch() {
        let executor = BatchExecutor::new();
        let operations = vec![
            BatchOperation {
                id: "op1".to_string(),
                tool: "tool1".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
            BatchOperation {
                id: "op2".to_string(),
                tool: "tool2".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
        ];

        let request = BatchRequest {
            operations,
            mode: BatchMode::Parallel,
            max_parallel: 10,
        };

        let start = std::time::Instant::now();
        let result = executor
            .execute(request, |_, _| async { Ok(Value::Null) })
            .await
            .unwrap();
        let duration = start.elapsed();

        assert_eq!(result.success_count, 2);
        // Should complete in roughly 0ms since parallel
        assert!(duration.as_millis() < 100);
    }

    /// A dependent operation still completes after its (slow) dependency.
    #[tokio::test]
    async fn test_execute_with_dependency() {
        let executor = BatchExecutor::new();
        let operations = vec![
            BatchOperation {
                id: "op1".to_string(),
                tool: "tool1".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
            BatchOperation {
                id: "op2".to_string(),
                tool: "tool2".to_string(),
                arguments: Value::Null,
                depends_on: vec!["op1".to_string()],
            },
        ];

        let request = BatchRequest {
            operations,
            mode: BatchMode::Parallel,
            max_parallel: 10,
        };

        let result = executor
            .execute(request, |id, _| async move {
                if id == "op1" {
                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                }
                Ok(Value::Null)
            })
            .await
            .unwrap();

        // Both operations should succeed
        assert_eq!(result.success_count, 2);
        // Verify total matches
        assert_eq!(result.success_count + result.failure_count, 2);
    }

    /// Fail-fast mode stops after the first failing operation.
    #[tokio::test]
    async fn test_fail_fast_mode() {
        let executor = BatchExecutor::new();
        let operations = vec![
            BatchOperation {
                id: "op1".to_string(),
                tool: "tool1".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
            BatchOperation {
                id: "op2".to_string(),
                tool: "tool2".to_string(),
                arguments: Value::Null,
                depends_on: vec![],
            },
        ];

        let request = BatchRequest {
            operations,
            mode: BatchMode::FailFast,
            max_parallel: 10,
        };

        let result = executor
            .execute(request, |id, _| async move {
                Err((-32600, format!("Operation {} failed", id)))
            })
            .await
            .unwrap();

        // Only first operation should be executed in fail_fast mode
        assert_eq!(result.results.len(), 1);
        assert!(!result.results[0].success);
    }
}