1use std::collections::HashSet;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use grafeo_common::types::EpochId;
52use grafeo_common::utils::hash::FxHashMap;
53use parking_lot::{Mutex, RwLock};
54use rayon::prelude::*;
55
56use super::EntityId;
57
/// Upper bound on optimistic re-execution rounds; operations still
/// conflicting after this many rounds are marked `Failed`.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially — the parallel machinery
/// (thread pool dispatch, per-result mutexes, validation pass) isn't worth it.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// If more than this fraction of the batch conflicts after the first parallel
/// pass, the whole batch is re-run sequentially instead of retried in parallel.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
66
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// The operation completed and no conflict was detected.
    Success,
    /// A read/write conflict with an earlier-indexed operation was detected;
    /// the operation must be re-executed before its result can be trusted.
    NeedsRevalidation,
    /// The operation has just been re-executed during conflict resolution.
    Reexecuted,
    /// The operation failed — either its own error or too many re-execution rounds.
    Failed,
}
79
/// Per-operation result plus the read/write footprint used for conflict detection.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of this operation within the submitted batch.
    pub batch_index: usize,
    /// Current status of the operation.
    pub status: ExecutionStatus,
    /// `(entity, epoch)` pairs this operation read.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities this operation wrote.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes this one conflicted with.
    pub dependencies: Vec<usize>,
    /// How many times this operation was re-executed.
    pub reexecution_count: usize,
    /// Error message, populated when `status` is `Failed`.
    pub error: Option<String>,
}
98
99impl ExecutionResult {
100 fn new(batch_index: usize) -> Self {
102 Self {
103 batch_index,
104 status: ExecutionStatus::Success,
105 read_set: HashSet::new(),
106 write_set: HashSet::new(),
107 dependencies: Vec::new(),
108 reexecution_count: 0,
109 error: None,
110 }
111 }
112
113 pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
115 self.read_set.insert((entity, epoch));
116 }
117
118 pub fn record_write(&mut self, entity: EntityId) {
120 self.write_set.insert(entity);
121 }
122
123 pub fn mark_needs_revalidation(&mut self) {
125 self.status = ExecutionStatus::NeedsRevalidation;
126 }
127
128 pub fn mark_reexecuted(&mut self) {
130 self.status = ExecutionStatus::Reexecuted;
131 self.reexecution_count += 1;
132 }
133
134 pub fn mark_failed(&mut self, error: String) {
136 self.status = ExecutionStatus::Failed;
137 self.error = Some(error);
138 }
139}
140
/// A batch of operations (as opaque strings) to be executed together.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations, in submission order.
    pub operations: Vec<String>,
}
147
148impl BatchRequest {
149 pub fn new(operations: Vec<impl Into<String>>) -> Self {
151 Self {
152 operations: operations.into_iter().map(Into::into).collect(),
153 }
154 }
155
156 #[must_use]
158 pub fn len(&self) -> usize {
159 self.operations.len()
160 }
161
162 #[must_use]
164 pub fn is_empty(&self) -> bool {
165 self.operations.is_empty()
166 }
167}
168
/// Aggregate outcome of executing a `BatchRequest`.
#[derive(Debug)]
pub struct BatchResult {
    /// Per-operation results, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Count of operations that did not end in `Failed`.
    pub success_count: usize,
    /// Count of operations that ended in `Failed`.
    pub failure_count: usize,
    /// Total number of re-executions performed across the batch.
    pub reexecution_count: usize,
    /// Whether the batch actually ran on the thread pool (vs. the sequential path).
    pub parallel_executed: bool,
}
183
184impl BatchResult {
185 #[must_use]
187 pub fn all_succeeded(&self) -> bool {
188 self.failure_count == 0
189 }
190
191 pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
193 self.results
194 .iter()
195 .filter(|r| r.status == ExecutionStatus::Failed)
196 .map(|r| r.batch_index)
197 }
198}
199
/// Tracks, for each entity, the earliest batch index that wrote it.
/// Shared across workers during a parallel batch run.
#[derive(Debug, Default)]
struct WriteTracker {
    // entity -> smallest batch index observed writing it
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
206
207impl WriteTracker {
208 fn record_write(&self, entity: EntityId, batch_index: usize) {
211 let mut writes = self.writes.write();
212 writes
213 .entry(entity)
214 .and_modify(|existing| *existing = (*existing).min(batch_index))
215 .or_insert(batch_index);
216 }
217
218 fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
220 let writes = self.writes.read();
221 if let Some(&writer) = writes.get(entity)
222 && writer < batch_index
223 {
224 return Some(writer);
225 }
226 None
227 }
228}
229
/// Executes batches of operations optimistically in parallel, detecting
/// read/write conflicts between operations and re-executing the conflicting
/// ones until they validate (or give up and mark them failed).
pub struct ParallelExecutor {
    // Number of worker threads in the dedicated pool.
    num_workers: usize,
    // Dedicated rayon pool used for batch execution.
    pool: rayon::ThreadPool,
}
239
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated rayon pool of `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is zero, or if the thread pool cannot be built.
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("Failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates an executor sized to rayon's current global thread count (at least 1).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Number of worker threads in this executor's pool.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes a batch, optimistically in parallel.
    ///
    /// Phases:
    /// 1. Run every operation in parallel, recording each one's writes in a
    ///    shared tracker.
    /// 2. Validate: an operation that read an entity written by an
    ///    earlier-indexed operation is marked `NeedsRevalidation`.
    /// 3. If the conflict rate exceeds `MAX_CONFLICT_RATE_FOR_PARALLEL`, fall
    ///    back to a full sequential re-run; otherwise re-execute conflicting
    ///    operations for up to `MAX_REEXECUTION_ROUNDS` rounds, failing any
    ///    that still conflict.
    ///
    /// `execute_fn` receives `(batch_index, operation, result)` and must
    /// record its reads and writes on `result`. It may be invoked more than
    /// once per operation (re-execution, or the sequential fallback after a
    /// parallel pass), so its side effects should be safe to repeat.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        // Small batches aren't worth the parallel overhead.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let write_tracker = Arc::new(WriteTracker::default());
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        // Phase 1: optimistic parallel execution; publish each op's writes
        // to the shared tracker so phase 2 can detect conflicts.
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Phase 2: validation. An op that read an entity written by an
        // earlier-indexed op may have observed a stale value.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Copy the entities out so we can mutate `result` while iterating.
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Too many conflicts: parallel retries would likely thrash, so redo
        // the whole batch sequentially.
        // NOTE(review): every op already ran once above, so side effects in
        // `execute_fn` are repeated by this fallback — confirm ops are
        // idempotent or side-effect-free until commit.
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let total_reexecutions = AtomicUsize::new(0);

        // Phase 3: bounded re-execution rounds for the conflicting subset.
        // NOTE(review): `write_tracker` keeps the first-round writes and is
        // never updated with writes from re-executed ops, so validation here
        // runs against a stale write map — confirm this is intended.
        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Reset the footprint before the retry re-records it.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        // Clean retry: final status is Success (overwriting
                        // `Reexecuted`); `reexecution_count` still records
                        // that a retry happened.
                        // NOTE(review): this means `Reexecuted` never appears
                        // in final results — confirm intended.
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            // Out of rounds: give up on anything still conflicting.
            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        // Results were created in index order, so this sort is a defensive no-op.
        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Sequential fallback: one pass in submission order, no conflict
    /// detection or re-execution (ordering makes conflicts impossible).
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
451
452impl Default for ParallelExecutor {
453 fn default() -> Self {
454 Self::default_workers()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch succeeds trivially and produces no results.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // A single op is below MIN_BATCH_SIZE_FOR_PARALLEL, so it runs sequentially.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        assert!(!result.parallel_executed);
    }

    // Disjoint write sets: everything runs in parallel with zero re-executions.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0); assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // All ops read+write the same entity: the executor must either re-execute
    // or fall back to the sequential path (which one depends on timing).
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Widen the race window so conflicts actually overlap.
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // Write-only ops on distinct entities never conflict (conflicts require reads).
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results come back ordered by batch_index regardless of worker scheduling.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A single failing op is counted and reported; the rest still succeed.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // WriteTracker keeps the earliest writer per entity and only reports
    // writers strictly before the querying batch index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2); assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // Batch 0 has no earlier writer.
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    // Basic BatchRequest accessors.
    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    // ExecutionResult state transitions and footprint recording.
    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}