1use std::collections::HashSet;
50use std::sync::Arc;
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53use grafeo_common::types::EpochId;
54use grafeo_common::utils::hash::FxHashMap;
55use parking_lot::{Mutex, RwLock};
56use rayon::prelude::*;
57
58use super::EntityId;
59
/// Upper bound on validation/re-execution rounds; operations still
/// conflicting after this many rounds are marked [`ExecutionStatus::Failed`].
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially — parallel overhead
/// (pool dispatch, locking, validation) is not worth it.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// If more than this fraction of a batch conflicts after the optimistic
/// pass, fall back to full sequential execution.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;

/// If one conflict cluster contains more than this fraction of the invalid
/// operations, clustering buys little parallelism — use round-based
/// re-execution instead.
const CLUSTER_SKIP_THRESHOLD: f64 = 0.8;
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExecutionStatus {
    /// Completed with no detected conflict.
    Success,
    /// Read an entity written by an earlier-indexed operation; must be
    /// revalidated (and usually re-executed).
    NeedsRevalidation,
    /// Transient state set while an operation is being run again after a
    /// conflict (overwritten with `Success`/`Failed` once resolved).
    Reexecuted,
    /// Permanently failed; see [`ExecutionResult::error`] for the reason.
    Failed,
}
87
/// Per-operation execution record: the outcome plus the read/write sets
/// observed during execution, which drive conflict detection.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of this operation in the originating batch.
    pub batch_index: usize,
    /// Current outcome of the operation.
    pub status: ExecutionStatus,
    /// Entities read by this operation, paired with the epoch they were read at.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written by this operation.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes this one observed
    /// (populated during validation).
    pub dependencies: Vec<usize>,
    /// Number of times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Error message when `status == ExecutionStatus::Failed`.
    pub error: Option<String>,
}
106
107impl ExecutionResult {
108 fn new(batch_index: usize) -> Self {
110 Self {
111 batch_index,
112 status: ExecutionStatus::Success,
113 read_set: HashSet::new(),
114 write_set: HashSet::new(),
115 dependencies: Vec::new(),
116 reexecution_count: 0,
117 error: None,
118 }
119 }
120
121 pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
123 self.read_set.insert((entity, epoch));
124 }
125
126 pub fn record_write(&mut self, entity: EntityId) {
128 self.write_set.insert(entity);
129 }
130
131 pub fn mark_needs_revalidation(&mut self) {
133 self.status = ExecutionStatus::NeedsRevalidation;
134 }
135
136 pub fn mark_reexecuted(&mut self) {
138 self.status = ExecutionStatus::Reexecuted;
139 self.reexecution_count += 1;
140 }
141
142 pub fn mark_failed(&mut self, error: String) {
144 self.status = ExecutionStatus::Failed;
145 self.error = Some(error);
146 }
147}
148
/// An ordered batch of operation strings submitted for execution together.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations, in submission order.
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Builds a batch from anything convertible into `String`s.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let operations = operations.into_iter().map(Into::into).collect();
        Self { operations }
    }

    /// Number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Whether the batch holds no operations at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
176
/// Aggregate outcome of a batch execution, with per-operation results and
/// summary metrics about conflicts and re-execution.
#[derive(Debug)]
pub struct BatchResult {
    /// One result per operation, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Count of operations whose final status is not `Failed`.
    pub success_count: usize,
    /// Count of operations whose final status is `Failed`.
    pub failure_count: usize,
    /// Total re-execution attempts performed across the batch.
    pub reexecution_count: usize,
    /// Whether the batch went through the parallel path (vs. sequential fallback).
    pub parallel_executed: bool,
    /// Number of conflict clusters detected among invalid operations.
    pub conflict_cluster_count: usize,
    /// Size of the largest conflict cluster (0 when there were no conflicts).
    pub largest_cluster_size: usize,
}
195
196impl BatchResult {
197 #[must_use]
199 pub fn all_succeeded(&self) -> bool {
200 self.failure_count == 0
201 }
202
203 pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
205 self.results
206 .iter()
207 .filter(|r| r.status == ExecutionStatus::Failed)
208 .map(|r| r.batch_index)
209 }
210}
211
/// Tracks, per entity, the earliest (smallest) batch index that wrote it
/// during the optimistic parallel pass; used to detect read-after-write
/// conflicts against earlier-indexed operations.
#[derive(Debug, Default)]
struct WriteTracker {
    // entity -> smallest batch index observed writing it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
218
219impl WriteTracker {
220 fn record_write(&self, entity: EntityId, batch_index: usize) {
223 let mut writes = self.writes.write();
224 writes
225 .entry(entity)
226 .and_modify(|existing| *existing = (*existing).min(batch_index))
227 .or_insert(batch_index);
228 }
229
230 fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
232 let writes = self.writes.read();
233 if let Some(&writer) = writes.get(entity)
234 && writer < batch_index
235 {
236 return Some(writer);
237 }
238 None
239 }
240}
241
/// Union-find (disjoint-set) structure used to group conflicting
/// transactions into clusters that must be re-executed together.
struct ConflictPartitioner {
    // parent[i] is i's parent in the forest; a root points to itself.
    parent: Vec<usize>,
    // Upper bound on subtree height, for union by rank.
    rank: Vec<usize>,
}
251
252impl ConflictPartitioner {
253 fn new(n: usize) -> Self {
255 Self {
256 parent: (0..n).collect(),
257 rank: vec![0; n],
258 }
259 }
260
261 fn find(&mut self, x: usize) -> usize {
263 if self.parent[x] != x {
264 self.parent[x] = self.find(self.parent[x]);
265 }
266 self.parent[x]
267 }
268
269 fn union(&mut self, a: usize, b: usize) {
271 let ra = self.find(a);
272 let rb = self.find(b);
273 if ra == rb {
274 return;
275 }
276 match self.rank[ra].cmp(&self.rank[rb]) {
277 std::cmp::Ordering::Less => self.parent[ra] = rb,
278 std::cmp::Ordering::Greater => self.parent[rb] = ra,
279 std::cmp::Ordering::Equal => {
280 self.parent[rb] = ra;
281 self.rank[ra] += 1;
282 }
283 }
284 }
285
286 fn partition(
294 read_sets: &[HashSet<(EntityId, EpochId)>],
295 write_sets: &[HashSet<EntityId>],
296 invalid_indices: &[usize],
297 ) -> (Vec<Vec<usize>>, usize) {
298 if invalid_indices.is_empty() {
299 return (Vec::new(), 0);
300 }
301
302 let index_to_compact: FxHashMap<usize, usize> = invalid_indices
304 .iter()
305 .enumerate()
306 .map(|(compact, &orig)| (orig, compact))
307 .collect();
308
309 let n = invalid_indices.len();
310 let mut uf = ConflictPartitioner::new(n);
311
312 let mut entity_writers: FxHashMap<EntityId, Vec<usize>> = FxHashMap::default();
314
315 for &orig_idx in invalid_indices {
316 let compact = index_to_compact[&orig_idx];
317 for entity in &write_sets[orig_idx] {
318 entity_writers.entry(*entity).or_default().push(compact);
319 }
320 }
321
322 for &orig_idx in invalid_indices {
324 let compact = index_to_compact[&orig_idx];
325
326 for (entity, _epoch) in &read_sets[orig_idx] {
328 if let Some(writers) = entity_writers.get(entity) {
329 for &writer_compact in writers {
330 if writer_compact != compact {
331 uf.union(compact, writer_compact);
332 }
333 }
334 }
335 }
336
337 for entity in &write_sets[orig_idx] {
339 if let Some(writers) = entity_writers.get(entity) {
340 for &writer_compact in writers {
341 if writer_compact != compact {
342 uf.union(compact, writer_compact);
343 }
344 }
345 }
346 }
347 }
348
349 let mut cluster_map: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
351 for (compact, &orig_idx) in invalid_indices.iter().enumerate() {
352 let root = uf.find(compact);
353 cluster_map.entry(root).or_default().push(orig_idx);
354 }
355
356 let mut clusters: Vec<Vec<usize>> = cluster_map.into_values().collect();
357
358 for cluster in &mut clusters {
360 cluster.sort_unstable();
361 }
362
363 let largest = clusters.iter().map(Vec::len).max().unwrap_or(0);
364 (clusters, largest)
365 }
366}
367
/// Executes batches of operations optimistically in parallel, detecting
/// read-write conflicts and re-executing conflicting operations until the
/// batch is consistent (or a round limit is hit).
pub struct ParallelExecutor {
    /// Number of worker threads in the dedicated pool.
    num_workers: usize,
    /// Dedicated rayon pool, so batch work does not contend with the
    /// global rayon pool.
    pool: rayon::ThreadPool,
}
377
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated pool of `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is zero or the rayon pool cannot be built.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates an executor sized to rayon's current thread count (at least 1).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Number of worker threads in this executor's pool.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes `batch` with optimistic concurrency control.
    ///
    /// `execute_fn` is invoked as `(batch_index, operation, result)` and must
    /// record its reads and writes on `result`. All operations first run in
    /// parallel; any operation that read an entity written by an
    /// earlier-indexed operation is then re-executed, either grouped into
    /// conflict clusters (clusters run in parallel, members sequentially) or
    /// in bounded validation rounds. Batches that are too small, or whose
    /// conflict rate exceeds `MAX_CONFLICT_RATE_FOR_PARALLEL`, fall back to
    /// sequential execution.
    ///
    /// NOTE(review): `execute_fn` may be called more than once per operation
    /// (optimistic pass + re-execution, or pass + sequential fallback), so it
    /// should be idempotent or defer side effects — confirm with callers.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Empty batch: nothing to run, report trivial success.
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
                conflict_cluster_count: 0,
                largest_cluster_size: 0,
            };
        }

        // Tiny batches aren't worth the parallel machinery.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        // Phase 1: optimistic parallel execution, recording each
        // operation's writes in the shared tracker as it finishes.
        let write_tracker = Arc::new(WriteTracker::default());
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Phase 2: validate — an operation is invalid if it read an entity
        // that some earlier-indexed operation wrote.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Snapshot the read entities to avoid borrowing `result` twice.
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Too many conflicts: redo the whole batch sequentially.
        // NOTE(review): the optimistic pass has already run `execute_fn`
        // once per operation at this point; the sequential fallback runs it
        // again — see the idempotency note above.
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let total_reexecutions = AtomicUsize::new(0);

        // Phase 3: cluster the invalid operations by conflict relations.
        let all_read_sets: Vec<HashSet<(EntityId, EpochId)>> =
            results.iter().map(|r| r.lock().read_set.clone()).collect();
        let all_write_sets: Vec<HashSet<EntityId>> =
            results.iter().map(|r| r.lock().write_set.clone()).collect();

        let (clusters, largest_cluster) =
            ConflictPartitioner::partition(&all_read_sets, &all_write_sets, &invalid_indices);

        // Clustering only pays off if no single cluster dominates.
        let use_clusters = !clusters.is_empty()
            && (largest_cluster as f64 / invalid_indices.len().max(1) as f64)
                <= CLUSTER_SKIP_THRESHOLD;

        if use_clusters {
            // Clusters are mutually conflict-free, so they run in parallel;
            // members within a cluster run sequentially in index order.
            self.pool.install(|| {
                clusters.par_iter().for_each(|cluster| {
                    for &idx in cluster {
                        let mut result = results[idx].lock();

                        // Discard state from the invalidated run.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        for entity in &result.write_set {
                            write_tracker.record_write(*entity, idx);
                        }

                        // Cluster members are marked Success here without a
                        // further validation pass.
                        result.status = ExecutionStatus::Success;
                    }
                });
            });
        } else {
            // Round-based re-execution: re-run all invalid operations in
            // parallel, revalidate, and repeat up to MAX_REEXECUTION_ROUNDS.
            for round in 0..MAX_REEXECUTION_ROUNDS {
                if invalid_indices.is_empty() {
                    break;
                }

                let still_invalid: Vec<usize> = self.pool.install(|| {
                    invalid_indices
                        .par_iter()
                        .filter_map(|&idx| {
                            let mut result = results[idx].lock();

                            // Discard state from the invalidated run.
                            result.read_set.clear();
                            result.write_set.clear();
                            result.dependencies.clear();

                            execute_fn(idx, &batch.operations[idx], &mut result);
                            result.mark_reexecuted();
                            total_reexecutions.fetch_add(1, Ordering::Relaxed);

                            // Revalidate against the tracker; still-conflicting
                            // operations carry over to the next round.
                            // NOTE(review): unlike the cluster path, writes made
                            // during this re-execution are not recorded back into
                            // `write_tracker` — confirm that is intentional.
                            let read_entities: Vec<EntityId> =
                                result.read_set.iter().map(|(entity, _)| *entity).collect();

                            for entity in read_entities {
                                if let Some(writer) =
                                    write_tracker.was_written_by_earlier(&entity, idx)
                                {
                                    result.mark_needs_revalidation();
                                    result.dependencies.push(writer);
                                    return Some(idx);
                                }
                            }

                            result.status = ExecutionStatus::Success;
                            None
                        })
                        .collect()
                });

                invalid_indices = still_invalid;

                // Out of rounds: give up on whatever is still conflicting.
                if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                    for idx in &invalid_indices {
                        let mut result = results[*idx].lock();
                        result.mark_failed("Max re-execution rounds reached".to_string());
                    }
                }
            }
        }

        // Assemble final results in batch order and compute summary counts.
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            conflict_cluster_count: clusters.len(),
            largest_cluster_size: largest_cluster,
            results: final_results,
        }
    }

    /// Runs the batch one operation at a time, in index order — used for
    /// small batches and as the high-conflict fallback. No conflict
    /// detection is performed (execution order makes it unnecessary).
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            conflict_cluster_count: 0,
            largest_cluster_size: 0,
            results,
        }
    }
}
639
640impl Default for ParallelExecutor {
641 fn default() -> Self {
642 Self::default_workers()
643 }
644}
645
#[cfg(test)]
mod tests {
    // Unit tests for the parallel executor, the conflict partitioner
    // (union-find), and the supporting data types.
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        // Below MIN_BATCH_SIZE_FOR_PARALLEL -> sequential path.
        assert!(!result.parallel_executed);
    }

    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0); assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Either conflicts triggered re-execution, or the high conflict
        // rate pushed the batch down the sequential fallback.
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        // The later write (index 2) must not displace the earlier one (0).
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2); assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }

    #[test]
    fn test_partitioner_empty() {
        let (clusters, largest) = ConflictPartitioner::partition(&[], &[], &[]);
        assert!(clusters.is_empty());
        assert_eq!(largest, 0);
    }

    #[test]
    fn test_partitioner_disjoint_clusters() {
        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));

        let read_sets = vec![
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::new(),
            HashSet::from([(entity_b, EpochId::new(0))]),
            HashSet::new(),
        ];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
            HashSet::from([entity_b]),
            HashSet::from([entity_b]),
        ];

        let invalid = vec![0, 1, 2, 3];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);

        assert_eq!(clusters.len(), 2, "should produce 2 disjoint clusters");
        assert_eq!(largest, 2, "each cluster has 2 transactions");

        let all: HashSet<usize> = clusters.iter().flat_map(|c| c.iter().copied()).collect();
        assert_eq!(all, HashSet::from([0, 1, 2, 3]));
    }

    #[test]
    fn test_partitioner_single_cluster() {
        let entity_a = EntityId::Node(NodeId::new(42));

        let read_sets = vec![
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
        ];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
        ];

        let invalid = vec![0, 1, 2];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);

        assert_eq!(clusters.len(), 1, "all share the same entity");
        assert_eq!(largest, 3);
        assert_eq!(clusters[0], vec![0, 1, 2]);
    }

    #[test]
    fn test_partitioner_chain_merges() {
        // 0 and 2 only conflict transitively through 1 (which writes both
        // entities) — union-find must merge the whole chain.
        let entity_a = EntityId::Node(NodeId::new(10));
        let entity_b = EntityId::Node(NodeId::new(20));

        let read_sets = vec![HashSet::new(), HashSet::new(), HashSet::new()];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a, entity_b]),
            HashSet::from([entity_b]),
        ];

        let invalid = vec![0, 1, 2];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);

        assert_eq!(clusters.len(), 1, "chain should merge into one cluster");
        assert_eq!(largest, 3);
    }

    #[test]
    fn test_partitioner_read_write_conflict() {
        let entity_a = EntityId::Node(NodeId::new(50));

        let read_sets = vec![HashSet::new(), HashSet::from([(entity_a, EpochId::new(0))])];
        let write_sets = vec![HashSet::from([entity_a]), HashSet::new()];

        let invalid = vec![0, 1];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);

        assert_eq!(clusters.len(), 1, "read-write overlap merges clusters");
        assert_eq!(largest, 2);
    }

    #[test]
    fn test_partitioner_subset_of_transactions() {
        let entity_a = EntityId::Node(NodeId::new(1));
        let entity_b = EntityId::Node(NodeId::new(2));

        let read_sets = vec![
            HashSet::new(),
            HashSet::new(),
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::new(),
            HashSet::new(),
            HashSet::from([(entity_b, EpochId::new(0))]),
        ];
        let write_sets = vec![
            HashSet::new(),
            HashSet::new(),
            HashSet::from([entity_a]),
            HashSet::new(),
            HashSet::new(),
            HashSet::from([entity_b]),
        ];

        let invalid = vec![2, 5];
        let (clusters, _) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);

        assert_eq!(
            clusters.len(),
            2,
            "non-overlapping invalid txns form separate clusters"
        );
    }

    #[test]
    fn test_cluster_based_reexecution() {
        // Two small conflict groups plus independent ops: conflict rate stays
        // under the sequential-fallback threshold, exercising clustering.
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "g1_op1", "g1_op2", "g2_op1", "g2_op2", "ind1", "ind2", "ind3", "ind4",
        ]);

        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));

        let result = executor.execute_batch(batch, |idx, _, result| {
            match idx {
                0 | 1 => {
                    result.record_read(entity_a, EpochId::new(0));
                    result.record_write(entity_a);
                }
                2 | 3 => {
                    result.record_read(entity_b, EpochId::new(0));
                    result.record_write(entity_b);
                }
                _ => {
                    result.record_write(EntityId::Node(NodeId::new(idx as u64 + 1000)));
                }
            }
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 8);
        assert!(result.parallel_executed);
    }

    #[test]
    fn test_cluster_metrics_reported() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["a", "b", "c", "d", "e", "f", "g", "h"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        assert_eq!(result.conflict_cluster_count, 0);
        assert_eq!(result.largest_cluster_size, 0);
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_union_find_correctness() {
        let mut uf = ConflictPartitioner::new(6);

        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(4, 5);

        assert_eq!(uf.find(0), uf.find(1));
        assert_eq!(uf.find(2), uf.find(3));
        assert_eq!(uf.find(4), uf.find(5));
        assert_ne!(uf.find(0), uf.find(2));
        assert_ne!(uf.find(0), uf.find(4));

        uf.union(1, 3);
        assert_eq!(uf.find(0), uf.find(2));
        assert_eq!(uf.find(0), uf.find(3));
        assert_ne!(uf.find(0), uf.find(4));
    }

    #[test]
    fn test_cluster_reexecution_resolves_conflicts() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));

        let result = executor.execute_batch(batch, |idx, _, result| match idx {
            0 | 1 => {
                result.record_read(entity_a, EpochId::new(0));
                result.record_write(entity_a);
            }
            2 | 3 => {
                result.record_read(entity_b, EpochId::new(0));
                result.record_write(entity_b);
            }
            _ => {
                result.record_write(EntityId::Node(NodeId::new(idx as u64 + 1000)));
            }
        });

        assert!(result.all_succeeded(), "all operations should succeed");
        assert!(result.parallel_executed, "should use parallel execution");
        assert!(
            result.conflict_cluster_count > 0,
            "should detect conflict clusters for shared entities"
        );
        assert!(
            result.reexecution_count > 0,
            "should re-execute conflicting operations"
        );
    }

    #[test]
    fn test_large_single_cluster_falls_back() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let shared = EntityId::Node(NodeId::new(999));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });

        assert!(result.all_succeeded(), "all operations should succeed");
        // Depending on thread timing, the 100% conflict rate may already have
        // forced the sequential fallback; only check cluster size otherwise.
        if result.parallel_executed {
            assert!(
                result.largest_cluster_size >= 8,
                "largest cluster should exceed 80% threshold, got {}",
                result.largest_cluster_size
            );
        }
    }

    #[test]
    fn test_sequential_fallback_high_conflict_rate() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..5).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let shared = EntityId::Node(NodeId::new(42));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });

        assert!(result.all_succeeded(), "all operations should succeed");
        if result.parallel_executed {
            assert!(
                result.reexecution_count > 0,
                "parallel path with 100% conflicts must trigger re-execution"
            );
        }
    }

    #[test]
    fn test_cluster_skip_threshold_large_cluster() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let shared = EntityId::Node(NodeId::new(1));
        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });

        assert!(result.all_succeeded());
        assert!(
            result.largest_cluster_size >= result.conflict_cluster_count,
            "largest cluster should dominate"
        );
    }

    #[test]
    fn test_multiple_disjoint_clusters() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = if idx % 2 == 0 { entity_a } else { entity_b };
            result.record_read(entity, EpochId::new(0));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        if result.conflict_cluster_count > 1 {
            assert!(
                result.largest_cluster_size < 8,
                "with disjoint conflicts, no single cluster should contain all transactions"
            );
        }
    }

    #[test]
    fn test_batch_result_metrics_fields() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64 + 1000));
            result.record_write(entity);
        });

        assert!(result.all_succeeded(), "conflict-free batch should succeed");
        assert_eq!(result.success_count, 10);
        assert_eq!(result.failure_count, 0);
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        assert_eq!(result.conflict_cluster_count, 0);
        assert_eq!(result.largest_cluster_size, 0);
    }

    #[test]
    fn test_no_conflicts_no_reexecution() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(
            result.reexecution_count, 0,
            "no conflicts means no re-execution"
        );
        assert_eq!(result.conflict_cluster_count, 0);
    }

    #[test]
    fn test_max_reexecution_rounds_reached() {
        // Every operation always conflicts, so the executor either falls
        // back to sequential or exhausts its re-execution budget.
        let executor = ParallelExecutor::new(2);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);

        let shared = EntityId::Node(NodeId::new(999));
        let call_count = AtomicUsize::new(0);

        let result = executor.execute_batch(batch, |_idx, _, result| {
            call_count.fetch_add(1, Ordering::Relaxed);
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });

        assert!(result.all_succeeded());

        let total_calls = call_count.load(Ordering::Relaxed);
        assert!(
            total_calls >= 10,
            "expected at least 10 calls (one per op), got {total_calls}"
        );
    }

    #[test]
    fn test_small_batch_uses_sequential() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["a", "b", "c"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        assert!(result.all_succeeded());
        assert!(
            !result.parallel_executed,
            "batch of 3 should use sequential"
        );
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_conflict_partitioner_preserves_dependency_order() {
        let entity_a = EntityId::Node(NodeId::new(1));

        let read_sets = vec![
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
        ];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
        ];

        // Out-of-order input: members must still come back sorted.
        let invalid = vec![2, 0, 1];
        let (clusters, _) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);

        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0], vec![0, 1, 2]);
    }

    #[test]
    fn test_write_tracker_no_earlier_writer_for_unwritten_entity() {
        let tracker = WriteTracker::default();
        tracker.record_write(EntityId::Node(NodeId::new(1)), 5);

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(99)), 10),
            None
        );
    }

    #[test]
    fn test_execution_result_mark_failed() {
        let mut result = ExecutionResult::new(0);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.error.is_none());

        result.mark_failed("test error".to_string());
        assert_eq!(result.status, ExecutionStatus::Failed);
        assert_eq!(result.error.as_deref(), Some("test error"));
    }

    #[test]
    fn test_parallel_executor_num_workers() {
        let executor = ParallelExecutor::new(8);
        assert_eq!(executor.num_workers(), 8);
    }

    #[test]
    fn test_default_workers() {
        let executor = ParallelExecutor::default_workers();
        assert!(executor.num_workers() >= 1);
    }

    #[test]
    fn test_batch_result_failed_indices_empty_when_all_succeed() {
        let executor = ParallelExecutor::new(2);
        let batch = BatchRequest::new(vec!["a", "b", "c", "d"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        let failed: Vec<usize> = result.failed_indices().collect();
        assert!(failed.is_empty());
    }

    #[test]
    fn test_batch_result_multiple_failures() {
        let executor = ParallelExecutor::new(2);
        let batch = BatchRequest::new(vec!["ok1", "fail1", "ok2", "fail2", "ok3"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op.starts_with("fail") {
                result.mark_failed(format!("error at {idx}"));
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64 + 500)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 2);
        assert_eq!(result.success_count, 3);
        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1, 3]);
    }
}