use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use grafeo_common::types::EpochId;
use grafeo_common::utils::hash::FxHashMap;
use parking_lot::{Mutex, RwLock};
use rayon::prelude::*;
use super::EntityId;
/// Upper bound on optimistic re-execution rounds before still-conflicting
/// operations are marked as `Failed`.
const MAX_REEXECUTION_ROUNDS: usize = 10;
/// Batches smaller than this run sequentially; parallel setup overhead is
/// not worth paying for tiny batches.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;
/// If more than this fraction of the batch conflicts after the first
/// parallel pass, the whole batch is re-run sequentially instead.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
/// If a single conflict cluster holds more than this fraction of the
/// invalid operations, cluster-parallel re-execution is skipped in favor
/// of round-based re-execution.
const CLUSTER_SKIP_THRESHOLD: f64 = 0.8;
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExecutionStatus {
    /// Executed without any detected conflict.
    Success,
    /// A read overlapped an earlier operation's write; must be re-executed.
    NeedsRevalidation,
    /// Transient state: the operation was re-executed after a conflict.
    Reexecuted,
    /// Execution failed permanently (see `ExecutionResult::error`).
    Failed,
}
/// Per-operation execution record: final status, the read/write footprint
/// observed during execution, and conflict-tracking metadata.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of this operation within the submitted batch.
    pub batch_index: usize,
    /// Current execution status.
    pub status: ExecutionStatus,
    /// Entities read (with the epoch at which each was read).
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written by this operation.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes invalidated us.
    pub dependencies: Vec<usize>,
    /// How many times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Failure message, set only when `status == Failed`.
    pub error: Option<String>,
}
impl ExecutionResult {
    /// Creates a fresh result for the operation at `batch_index`, starting
    /// in the `Success` state with empty read/write sets.
    fn new(batch_index: usize) -> Self {
        Self {
            batch_index,
            status: ExecutionStatus::Success,
            read_set: HashSet::default(),
            write_set: HashSet::default(),
            dependencies: vec![],
            reexecution_count: 0,
            error: None,
        }
    }
    /// Records that `entity` was read at `epoch`.
    pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
        let _ = self.read_set.insert((entity, epoch));
    }
    /// Records that `entity` was written.
    pub fn record_write(&mut self, entity: EntityId) {
        let _ = self.write_set.insert(entity);
    }
    /// Flags this operation as conflicting; it must be re-executed.
    pub fn mark_needs_revalidation(&mut self) {
        self.status = ExecutionStatus::NeedsRevalidation;
    }
    /// Transitions to `Reexecuted` and bumps the re-execution counter.
    pub fn mark_reexecuted(&mut self) {
        self.status = ExecutionStatus::Reexecuted;
        self.reexecution_count += 1;
    }
    /// Marks the operation as permanently failed with `error`.
    pub fn mark_failed(&mut self, error: String) {
        self.status = ExecutionStatus::Failed;
        self.error = Some(error);
    }
}
/// An ordered collection of operations submitted for batch execution.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations, in submission order.
    pub operations: Vec<String>,
}
impl BatchRequest {
    /// Builds a request from anything convertible into owned strings.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let operations = operations.into_iter().map(Into::into).collect();
        Self { operations }
    }
    /// Number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }
    /// Whether the batch contains no operations at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Aggregate outcome of executing a whole batch, including per-operation
/// results and scheduling metrics.
#[derive(Debug)]
pub struct BatchResult {
    /// One result per operation, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Operations that did not end in `Failed`.
    pub success_count: usize,
    /// Operations that ended in `Failed`.
    pub failure_count: usize,
    /// Total number of re-executions performed across the batch.
    pub reexecution_count: usize,
    /// Whether the parallel path was taken (vs. the sequential fallback).
    pub parallel_executed: bool,
    /// Number of conflict clusters discovered.
    pub conflict_cluster_count: usize,
    /// Size of the biggest conflict cluster.
    pub largest_cluster_size: usize,
}
impl BatchResult {
    /// `true` when no operation in the batch ended in `Failed`.
    #[must_use]
    pub fn all_succeeded(&self) -> bool {
        self.failure_count == 0
    }
    /// Iterates over the batch indices of all failed operations.
    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.results.iter().filter_map(|r| {
            if r.status == ExecutionStatus::Failed {
                Some(r.batch_index)
            } else {
                None
            }
        })
    }
}
#[derive(Debug, Default)]
struct WriteTracker {
writes: RwLock<FxHashMap<EntityId, usize>>,
}
impl WriteTracker {
fn record_write(&self, entity: EntityId, batch_index: usize) {
let mut writes = self.writes.write();
writes
.entry(entity)
.and_modify(|existing| *existing = (*existing).min(batch_index))
.or_insert(batch_index);
}
fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
let writes = self.writes.read();
if let Some(&writer) = writes.get(entity)
&& writer < batch_index
{
return Some(writer);
}
None
}
}
/// Union-find (disjoint-set) structure used to group conflicting
/// operations into independently re-executable clusters.
struct ConflictPartitioner {
    // parent[i] == i marks a set root.
    parent: Vec<usize>,
    // Rank (tree-height bound) used to keep unions shallow.
    rank: Vec<usize>,
}
impl ConflictPartitioner {
    /// Creates `n` singleton sets.
    fn new(n: usize) -> Self {
        Self {
            parent: (0..n).collect(),
            rank: vec![0; n],
        }
    }
    /// Returns the representative of `x`'s set, compressing the path so
    /// every visited node ends up pointing directly at the root.
    fn find(&mut self, x: usize) -> usize {
        // First pass: walk up to the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second pass: repoint every node on the path straight at the root.
        let mut node = x;
        while self.parent[node] != root {
            node = std::mem::replace(&mut self.parent[node], root);
        }
        root
    }
    /// Merges the sets containing `a` and `b` using union by rank.
    fn union(&mut self, a: usize, b: usize) {
        let (ra, rb) = (self.find(a), self.find(b));
        if ra == rb {
            return;
        }
        if self.rank[ra] < self.rank[rb] {
            self.parent[ra] = rb;
        } else if self.rank[ra] > self.rank[rb] {
            self.parent[rb] = ra;
        } else {
            // Equal ranks: pick `ra` as the new root and grow its rank.
            self.parent[rb] = ra;
            self.rank[ra] += 1;
        }
    }
    /// Partitions the invalid operations into conflict clusters.
    ///
    /// Two operations land in the same cluster when one reads or writes an
    /// entity that another invalid operation writes. Returns the clusters
    /// (each sorted by original batch index) together with the size of the
    /// largest cluster; `(empty, 0)` when there are no invalid operations.
    fn partition(
        read_sets: &[HashSet<(EntityId, EpochId)>],
        write_sets: &[HashSet<EntityId>],
        invalid_indices: &[usize],
    ) -> (Vec<Vec<usize>>, usize) {
        if invalid_indices.is_empty() {
            return (Vec::new(), 0);
        }
        let n = invalid_indices.len();
        // Map original batch indices onto compact 0..n union-find ids.
        let index_to_compact: FxHashMap<usize, usize> = invalid_indices
            .iter()
            .enumerate()
            .map(|(compact, &orig)| (orig, compact))
            .collect();
        // Index every written entity by the compact ids that wrote it.
        let mut entity_writers: FxHashMap<EntityId, Vec<usize>> = FxHashMap::default();
        for &orig_idx in invalid_indices {
            let compact = index_to_compact[&orig_idx];
            for entity in &write_sets[orig_idx] {
                entity_writers.entry(*entity).or_default().push(compact);
            }
        }
        // Union each operation with every writer of an entity it touches
        // (both read-write and write-write overlaps merge clusters).
        let mut uf = Self::new(n);
        for &orig_idx in invalid_indices {
            let compact = index_to_compact[&orig_idx];
            let reads = read_sets[orig_idx].iter().map(|(entity, _)| entity);
            let writes = write_sets[orig_idx].iter();
            for entity in reads.chain(writes) {
                if let Some(writers) = entity_writers.get(entity) {
                    for &writer_compact in writers {
                        if writer_compact != compact {
                            uf.union(compact, writer_compact);
                        }
                    }
                }
            }
        }
        // Group original indices by their set representative.
        let mut cluster_map: FxHashMap<usize, Vec<usize>> = FxHashMap::default();
        for (compact, &orig_idx) in invalid_indices.iter().enumerate() {
            let root = uf.find(compact);
            cluster_map.entry(root).or_default().push(orig_idx);
        }
        let mut clusters: Vec<Vec<usize>> = cluster_map.into_values().collect();
        for cluster in &mut clusters {
            cluster.sort_unstable();
        }
        let largest = clusters.iter().map(Vec::len).max().unwrap_or(0);
        (clusters, largest)
    }
}
/// Optimistic parallel batch executor.
///
/// Operations run concurrently on a dedicated rayon pool; read/write
/// conflicts are detected afterwards from the recorded read/write sets,
/// and conflicting operations are re-executed (cluster-parallel or in
/// rounds). High-conflict batches fall back to sequential execution.
pub struct ParallelExecutor {
    // Thread count requested at construction; reported by `num_workers()`.
    num_workers: usize,
    // Dedicated pool so batch work does not contend with the global
    // rayon pool.
    pool: rayon::ThreadPool,
}
impl ParallelExecutor {
    /// Creates an executor backed by `num_workers` threads.
    ///
    /// # Panics
    /// Panics if `num_workers` is zero or the thread pool cannot be built.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("failed to build thread pool");
        Self { num_workers, pool }
    }
    /// Creates an executor sized to rayon's current thread count
    /// (at least one worker).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }
    /// Returns the number of worker threads backing this executor.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }
    /// Executes every operation in `batch`, invoking
    /// `execute_fn(index, operation, result)` for each one.
    ///
    /// Batches smaller than `MIN_BATCH_SIZE_FOR_PARALLEL` run sequentially.
    /// Otherwise all operations run optimistically in parallel, conflicts
    /// are detected via the recorded read/write sets, and conflicting
    /// operations are re-executed. If the conflict rate exceeds
    /// `MAX_CONFLICT_RATE_FOR_PARALLEL`, the whole batch is re-run
    /// sequentially instead.
    ///
    /// Note: `execute_fn` may be invoked more than once per operation
    /// (on re-execution, or when falling back to sequential after the
    /// parallel attempt), so it should tolerate repeated invocation.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();
        if n == 0 {
            // Nothing to do: empty metrics, trivially all-succeeded.
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
                conflict_cluster_count: 0,
                largest_cluster_size: 0,
            };
        }
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }
        let write_tracker = Arc::new(WriteTracker::default());
        // One pre-created, mutex-guarded result slot per operation.
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();
        // Phase 1: optimistic parallel execution, recording writes.
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);
                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });
        // Phase 2: validation — an operation is invalid when it read an
        // entity that an earlier-indexed operation wrote.
        let mut invalid_indices = Vec::new();
        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();
            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }
            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }
        // Too many conflicts: abandon the parallel attempt and re-run the
        // whole batch sequentially.
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }
        let total_reexecutions = AtomicUsize::new(0);
        // Snapshot read/write sets for clustering.
        let all_read_sets: Vec<HashSet<(EntityId, EpochId)>> =
            results.iter().map(|r| r.lock().read_set.clone()).collect();
        let all_write_sets: Vec<HashSet<EntityId>> =
            results.iter().map(|r| r.lock().write_set.clone()).collect();
        let (clusters, largest_cluster) =
            ConflictPartitioner::partition(&all_read_sets, &all_write_sets, &invalid_indices);
        // Clusters are only worth using when no single cluster dominates.
        let use_clusters = !clusters.is_empty()
            && (largest_cluster as f64 / invalid_indices.len().max(1) as f64)
                <= CLUSTER_SKIP_THRESHOLD;
        if use_clusters {
            // Phase 3a: re-execute clusters in parallel; operations within
            // a cluster run sequentially in batch-index order.
            self.pool.install(|| {
                clusters.par_iter().for_each(|cluster| {
                    for &idx in cluster {
                        let mut result = results[idx].lock();
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();
                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);
                        for entity in &result.write_set {
                            write_tracker.record_write(*entity, idx);
                        }
                        // Cluster members are re-run once and accepted.
                        result.status = ExecutionStatus::Success;
                    }
                });
            });
        } else {
            // Phase 3b: round-based re-execution of all invalid operations
            // until they validate or the round budget is exhausted.
            for round in 0..MAX_REEXECUTION_ROUNDS {
                if invalid_indices.is_empty() {
                    break;
                }
                let still_invalid: Vec<usize> = self.pool.install(|| {
                    invalid_indices
                        .par_iter()
                        .filter_map(|&idx| {
                            let mut result = results[idx].lock();
                            result.read_set.clear();
                            result.write_set.clear();
                            result.dependencies.clear();
                            execute_fn(idx, &batch.operations[idx], &mut result);
                            result.mark_reexecuted();
                            total_reexecutions.fetch_add(1, Ordering::Relaxed);
                            // Re-validate against earlier writers.
                            let read_entities: Vec<EntityId> =
                                result.read_set.iter().map(|(entity, _)| *entity).collect();
                            for entity in read_entities {
                                if let Some(writer) =
                                    write_tracker.was_written_by_earlier(&entity, idx)
                                {
                                    result.mark_needs_revalidation();
                                    result.dependencies.push(writer);
                                    return Some(idx);
                                }
                            }
                            result.status = ExecutionStatus::Success;
                            None
                        })
                        .collect()
                });
                invalid_indices = still_invalid;
                // Budget exhausted: mark survivors as failed.
                if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                    for idx in &invalid_indices {
                        let mut result = results[*idx].lock();
                        result.mark_failed("Max re-execution rounds reached".to_string());
                    }
                }
            }
        }
        // Unwrap the mutexes and assemble metrics.
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();
        final_results.sort_by_key(|r| r.batch_index);
        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();
        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            conflict_cluster_count: clusters.len(),
            largest_cluster_size: largest_cluster,
            results: final_results,
        }
    }
    /// Executes the batch one operation at a time on the calling thread.
    /// No conflict detection or re-execution is performed.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());
        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }
        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();
        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            conflict_cluster_count: 0,
            largest_cluster_size: 0,
            results,
        }
    }
}
impl Default for ParallelExecutor {
    /// Defaults to one worker per current rayon thread
    /// (see [`ParallelExecutor::default_workers`]).
    fn default() -> Self {
        Self::default_workers()
    }
}
// Unit tests for batch execution, conflict partitioning, and the
// supporting data structures.
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());
        let result = executor.execute_batch(batch, |_, _, _| {});
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);
        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        assert!(!result.parallel_executed);
    }

    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);
        let counter = AtomicU64::new(0);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Disjoint write sets: nothing should be re-executed.
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);
        let shared_entity = EntityId::Node(NodeId::new(100));
        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);
            // Small sleep widens the window for overlapping execution.
            thread::sleep(Duration::from_micros(10));
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);
        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });
        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);
        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();
        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2);
        // Entity 1 keeps its earliest writer (index 0).
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());
        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);
        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());
        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));
        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);
        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);
        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }

    #[test]
    fn test_partitioner_empty() {
        let (clusters, largest) = ConflictPartitioner::partition(&[], &[], &[]);
        assert!(clusters.is_empty());
        assert_eq!(largest, 0);
    }

    #[test]
    fn test_partitioner_disjoint_clusters() {
        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));
        let read_sets = vec![
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::new(),
            HashSet::from([(entity_b, EpochId::new(0))]),
            HashSet::new(),
        ];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
            HashSet::from([entity_b]),
            HashSet::from([entity_b]),
        ];
        let invalid = vec![0, 1, 2, 3];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);
        assert_eq!(clusters.len(), 2, "should produce 2 disjoint clusters");
        assert_eq!(largest, 2, "each cluster has 2 transactions");
        let all: HashSet<usize> = clusters.iter().flat_map(|c| c.iter().copied()).collect();
        assert_eq!(all, HashSet::from([0, 1, 2, 3]));
    }

    #[test]
    fn test_partitioner_single_cluster() {
        let entity_a = EntityId::Node(NodeId::new(42));
        let read_sets = vec![
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
        ];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
        ];
        let invalid = vec![0, 1, 2];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);
        assert_eq!(clusters.len(), 1, "all share the same entity");
        assert_eq!(largest, 3);
        assert_eq!(clusters[0], vec![0, 1, 2]);
    }

    #[test]
    fn test_partitioner_chain_merges() {
        let entity_a = EntityId::Node(NodeId::new(10));
        let entity_b = EntityId::Node(NodeId::new(20));
        let read_sets = vec![HashSet::new(), HashSet::new(), HashSet::new()];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a, entity_b]),
            HashSet::from([entity_b]),
        ];
        let invalid = vec![0, 1, 2];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);
        assert_eq!(clusters.len(), 1, "chain should merge into one cluster");
        assert_eq!(largest, 3);
    }

    #[test]
    fn test_partitioner_read_write_conflict() {
        let entity_a = EntityId::Node(NodeId::new(50));
        let read_sets = vec![HashSet::new(), HashSet::from([(entity_a, EpochId::new(0))])];
        let write_sets = vec![HashSet::from([entity_a]), HashSet::new()];
        let invalid = vec![0, 1];
        let (clusters, largest) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);
        assert_eq!(clusters.len(), 1, "read-write overlap merges clusters");
        assert_eq!(largest, 2);
    }

    #[test]
    fn test_partitioner_subset_of_transactions() {
        let entity_a = EntityId::Node(NodeId::new(1));
        let entity_b = EntityId::Node(NodeId::new(2));
        let read_sets = vec![
            HashSet::new(),
            HashSet::new(),
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::new(),
            HashSet::new(),
            HashSet::from([(entity_b, EpochId::new(0))]),
        ];
        let write_sets = vec![
            HashSet::new(),
            HashSet::new(),
            HashSet::from([entity_a]),
            HashSet::new(),
            HashSet::new(),
            HashSet::from([entity_b]),
        ];
        let invalid = vec![2, 5];
        let (clusters, _) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);
        assert_eq!(
            clusters.len(),
            2,
            "non-overlapping invalid txns form separate clusters"
        );
    }

    #[test]
    fn test_cluster_based_reexecution() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "g1_op1", "g1_op2", "g2_op1", "g2_op2", "ind1", "ind2", "ind3", "ind4",
        ]);
        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));
        let result = executor.execute_batch(batch, |idx, _, result| {
            match idx {
                0 | 1 => {
                    result.record_read(entity_a, EpochId::new(0));
                    result.record_write(entity_a);
                }
                2 | 3 => {
                    result.record_read(entity_b, EpochId::new(0));
                    result.record_write(entity_b);
                }
                _ => {
                    result.record_write(EntityId::Node(NodeId::new(idx as u64 + 1000)));
                }
            }
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 8);
        assert!(result.parallel_executed);
    }

    #[test]
    fn test_cluster_metrics_reported() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["a", "b", "c", "d", "e", "f", "g", "h"]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });
        assert_eq!(result.conflict_cluster_count, 0);
        assert_eq!(result.largest_cluster_size, 0);
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_union_find_correctness() {
        let mut uf = ConflictPartitioner::new(6);
        uf.union(0, 1);
        uf.union(2, 3);
        uf.union(4, 5);
        assert_eq!(uf.find(0), uf.find(1));
        assert_eq!(uf.find(2), uf.find(3));
        assert_eq!(uf.find(4), uf.find(5));
        assert_ne!(uf.find(0), uf.find(2));
        assert_ne!(uf.find(0), uf.find(4));
        uf.union(1, 3);
        assert_eq!(uf.find(0), uf.find(2));
        assert_eq!(uf.find(0), uf.find(3));
        assert_ne!(uf.find(0), uf.find(4));
    }

    #[test]
    fn test_cluster_reexecution_resolves_conflicts() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));
        let result = executor.execute_batch(batch, |idx, _, result| match idx {
            0 | 1 => {
                result.record_read(entity_a, EpochId::new(0));
                result.record_write(entity_a);
            }
            2 | 3 => {
                result.record_read(entity_b, EpochId::new(0));
                result.record_write(entity_b);
            }
            _ => {
                result.record_write(EntityId::Node(NodeId::new(idx as u64 + 1000)));
            }
        });
        assert!(result.all_succeeded(), "all operations should succeed");
        assert!(result.parallel_executed, "should use parallel execution");
        assert!(
            result.conflict_cluster_count > 0,
            "should detect conflict clusters for shared entities"
        );
        assert!(
            result.reexecution_count > 0,
            "should re-execute conflicting operations"
        );
    }

    #[test]
    fn test_large_single_cluster_falls_back() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let shared = EntityId::Node(NodeId::new(999));
        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });
        assert!(result.all_succeeded(), "all operations should succeed");
        if result.parallel_executed {
            assert!(
                result.largest_cluster_size >= 8,
                "largest cluster should exceed 80% threshold, got {}",
                result.largest_cluster_size
            );
        }
    }

    #[test]
    fn test_sequential_fallback_high_conflict_rate() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..5).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let shared = EntityId::Node(NodeId::new(42));
        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });
        assert!(result.all_succeeded(), "all operations should succeed");
        if result.parallel_executed {
            assert!(
                result.reexecution_count > 0,
                "parallel path with 100% conflicts must trigger re-execution"
            );
        }
    }

    #[test]
    fn test_cluster_skip_threshold_large_cluster() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let shared = EntityId::Node(NodeId::new(1));
        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });
        assert!(result.all_succeeded());
        assert!(
            result.largest_cluster_size >= result.conflict_cluster_count,
            "largest cluster should dominate"
        );
    }

    #[test]
    fn test_multiple_disjoint_clusters() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let entity_a = EntityId::Node(NodeId::new(100));
        let entity_b = EntityId::Node(NodeId::new(200));
        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = if idx % 2 == 0 { entity_a } else { entity_b };
            result.record_read(entity, EpochId::new(0));
            result.record_write(entity);
        });
        assert!(result.all_succeeded());
        if result.conflict_cluster_count > 1 {
            assert!(
                result.largest_cluster_size < 8,
                "with disjoint conflicts, no single cluster should contain all transactions"
            );
        }
    }

    #[test]
    fn test_batch_result_metrics_fields() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64 + 1000));
            result.record_write(entity);
        });
        assert!(result.all_succeeded(), "conflict-free batch should succeed");
        assert_eq!(result.success_count, 10);
        assert_eq!(result.failure_count, 0);
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        assert_eq!(result.conflict_cluster_count, 0);
        assert_eq!(result.largest_cluster_size, 0);
    }

    #[test]
    fn test_no_conflicts_no_reexecution() {
        let executor = ParallelExecutor::new(4);
        let ops: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });
        assert!(result.all_succeeded());
        assert_eq!(
            result.reexecution_count, 0,
            "no conflicts means no re-execution"
        );
        assert_eq!(result.conflict_cluster_count, 0);
    }

    #[test]
    fn test_max_reexecution_rounds_reached() {
        let executor = ParallelExecutor::new(2);
        let ops: Vec<String> = (0..10).map(|i| format!("op{i}")).collect();
        let batch = BatchRequest::new(ops);
        let shared = EntityId::Node(NodeId::new(999));
        let call_count = AtomicUsize::new(0);
        let result = executor.execute_batch(batch, |_idx, _, result| {
            call_count.fetch_add(1, Ordering::Relaxed);
            result.record_read(shared, EpochId::new(0));
            result.record_write(shared);
        });
        assert!(result.all_succeeded());
        let total_calls = call_count.load(Ordering::Relaxed);
        assert!(
            total_calls >= 10,
            "expected at least 10 calls (one per op), got {total_calls}"
        );
    }

    #[test]
    fn test_small_batch_uses_sequential() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["a", "b", "c"]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });
        assert!(result.all_succeeded());
        assert!(
            !result.parallel_executed,
            "batch of 3 should use sequential"
        );
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_conflict_partitioner_preserves_dependency_order() {
        let entity_a = EntityId::Node(NodeId::new(1));
        let read_sets = vec![
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
            HashSet::from([(entity_a, EpochId::new(0))]),
        ];
        let write_sets = vec![
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
            HashSet::from([entity_a]),
        ];
        // Invalid indices deliberately out of order; output must be sorted.
        let invalid = vec![2, 0, 1];
        let (clusters, _) = ConflictPartitioner::partition(&read_sets, &write_sets, &invalid);
        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0], vec![0, 1, 2]);
    }

    #[test]
    fn test_write_tracker_no_earlier_writer_for_unwritten_entity() {
        let tracker = WriteTracker::default();
        tracker.record_write(EntityId::Node(NodeId::new(1)), 5);
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(99)), 10),
            None
        );
    }

    #[test]
    fn test_execution_result_mark_failed() {
        let mut result = ExecutionResult::new(0);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.error.is_none());
        result.mark_failed("test error".to_string());
        assert_eq!(result.status, ExecutionStatus::Failed);
        assert_eq!(result.error.as_deref(), Some("test error"));
    }

    #[test]
    fn test_parallel_executor_num_workers() {
        let executor = ParallelExecutor::new(8);
        assert_eq!(executor.num_workers(), 8);
    }

    #[test]
    fn test_default_workers() {
        let executor = ParallelExecutor::default_workers();
        assert!(executor.num_workers() >= 1);
    }

    #[test]
    fn test_batch_result_failed_indices_empty_when_all_succeed() {
        let executor = ParallelExecutor::new(2);
        let batch = BatchRequest::new(vec!["a", "b", "c", "d"]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });
        let failed: Vec<usize> = result.failed_indices().collect();
        assert!(failed.is_empty());
    }

    #[test]
    fn test_batch_result_multiple_failures() {
        let executor = ParallelExecutor::new(2);
        let batch = BatchRequest::new(vec!["ok1", "fail1", "ok2", "fail2", "ok3"]);
        let result = executor.execute_batch(batch, |idx, op, result| {
            if op.starts_with("fail") {
                result.mark_failed(format!("error at {idx}"));
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64 + 500)));
            }
        });
        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 2);
        assert_eq!(result.success_count, 3);
        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1, 3]);
    }
}