use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use grafeo_common::types::EpochId;
use grafeo_common::utils::hash::FxHashMap;
use parking_lot::{Mutex, RwLock};
use rayon::prelude::*;
use super::EntityId;
/// Upper bound on optimistic re-execution rounds; entries still conflicting
/// after this many rounds are marked as failed.
const MAX_REEXECUTION_ROUNDS: usize = 10;
/// Batches smaller than this run sequentially — thread-pool overhead would
/// outweigh the parallelism.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;
/// If more than this fraction of the batch conflicts after the first parallel
/// pass, the whole batch is redone sequentially.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExecutionStatus {
    /// The operation completed without any detected conflict.
    Success,
    /// A read conflicted with a write from an earlier batch entry; the
    /// operation must be re-executed.
    NeedsRevalidation,
    /// The operation was re-executed after a conflict.
    Reexecuted,
    /// The operation failed, either via `mark_failed` from the execution
    /// callback or because the re-execution budget was exhausted.
    Failed,
}
/// Per-operation execution record: status plus the read/write sets used for
/// optimistic conflict detection.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of this operation within the submitted batch.
    pub batch_index: usize,
    /// Current status of the operation.
    pub status: ExecutionStatus,
    /// Entities read by the operation, with the epoch each read observed.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written by the operation.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes conflicted with this
    /// operation's reads.
    pub dependencies: Vec<usize>,
    /// How many times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Error message when `status` is [`ExecutionStatus::Failed`].
    pub error: Option<String>,
}
impl ExecutionResult {
fn new(batch_index: usize) -> Self {
Self {
batch_index,
status: ExecutionStatus::Success,
read_set: HashSet::new(),
write_set: HashSet::new(),
dependencies: Vec::new(),
reexecution_count: 0,
error: None,
}
}
pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
self.read_set.insert((entity, epoch));
}
pub fn record_write(&mut self, entity: EntityId) {
self.write_set.insert(entity);
}
pub fn mark_needs_revalidation(&mut self) {
self.status = ExecutionStatus::NeedsRevalidation;
}
pub fn mark_reexecuted(&mut self) {
self.status = ExecutionStatus::Reexecuted;
self.reexecution_count += 1;
}
pub fn mark_failed(&mut self, error: String) {
self.status = ExecutionStatus::Failed;
self.error = Some(error);
}
}
/// A batch of operations submitted for (potentially parallel) execution.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations to execute, in batch order.
    pub operations: Vec<String>,
}
impl BatchRequest {
    /// Builds a request from any collection of values convertible into
    /// `String` (e.g. `&str` literals).
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let operations = operations.into_iter().map(|op| op.into()).collect();
        Self { operations }
    }

    /// Number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` when the batch holds no operations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Aggregate outcome of executing a [`BatchRequest`].
#[derive(Debug)]
pub struct BatchResult {
    /// Per-operation results, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Number of operations that did not fail.
    pub success_count: usize,
    /// Number of operations that ended in [`ExecutionStatus::Failed`].
    pub failure_count: usize,
    /// Total re-executions performed across the whole batch.
    pub reexecution_count: usize,
    /// Whether the batch actually ran on the parallel path.
    pub parallel_executed: bool,
}
impl BatchResult {
    /// `true` when no operation in the batch failed.
    #[must_use]
    pub fn all_succeeded(&self) -> bool {
        self.failure_count == 0
    }

    /// Iterates over the batch indices of the failed operations.
    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.results.iter().filter_map(|r| {
            if r.status == ExecutionStatus::Failed {
                Some(r.batch_index)
            } else {
                None
            }
        })
    }
}
/// Tracks, for each entity, the earliest batch index that wrote it.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Map from entity to the smallest batch index observed writing it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
impl WriteTracker {
fn record_write(&self, entity: EntityId, batch_index: usize) {
let mut writes = self.writes.write();
writes
.entry(entity)
.and_modify(|existing| *existing = (*existing).min(batch_index))
.or_insert(batch_index);
}
fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
let writes = self.writes.read();
if let Some(&writer) = writes.get(entity)
&& writer < batch_index
{
return Some(writer);
}
None
}
}
/// Executes batches of operations optimistically in parallel, falling back to
/// sequential execution for small or conflict-heavy batches.
pub struct ParallelExecutor {
    /// Number of worker threads the pool was built with.
    num_workers: usize,
    /// Dedicated rayon thread pool used for the parallel path.
    pool: rayon::ThreadPool,
}
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated pool of `num_workers`
    /// threads.
    ///
    /// # Panics
    /// Panics if `num_workers` is zero or the thread pool cannot be built.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("failed to build thread pool");
        Self { num_workers, pool }
    }

    /// Creates an executor sized to rayon's current global thread count
    /// (at least one worker).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Number of worker threads this executor was built with.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes `batch` with optimistic concurrency control.
    ///
    /// Every operation first runs in parallel, recording its read/write sets
    /// into its [`ExecutionResult`] via `execute_fn(batch_index, operation,
    /// result)`. A validation pass then flags any operation that read an
    /// entity written by an *earlier* batch entry; flagged operations are
    /// re-executed for up to [`MAX_REEXECUTION_ROUNDS`] rounds, after which
    /// survivors are marked failed. Batches smaller than
    /// [`MIN_BATCH_SIZE_FOR_PARALLEL`], or whose initial conflict rate
    /// exceeds [`MAX_CONFLICT_RATE_FOR_PARALLEL`], run sequentially instead.
    ///
    /// NOTE(review): the re-execution loop assumes `execute_fn` is
    /// deterministic — re-executed operations do not re-record their writes
    /// into the tracker, so a changed write set would go unnoticed. Confirm
    /// callers only pass deterministic closures.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();
        if n == 0 {
            // Nothing to do: report an empty, fully-successful batch.
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            // Too small to amortize pool overhead.
            return self.execute_sequential(batch, execute_fn);
        }
        let write_tracker = Arc::new(WriteTracker::default());
        // One mutex-wrapped slot per operation so each worker mutates only
        // its own entry.
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();
        // Pass 1: execute everything in parallel, recording each write with
        // the index of the operation that produced it.
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);
                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });
        // Validation: an operation conflicts if it read an entity that an
        // earlier batch entry wrote (the parallel pass may have run them in
        // any order, violating batch-order semantics).
        let mut invalid_indices = Vec::new();
        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();
            // Copy the entities out first so we can mutate `result` below.
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();
            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }
            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }
        // Too many conflicts: parallel re-execution would thrash, so redo the
        // whole batch sequentially, discarding the parallel results.
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }
        let total_reexecutions = AtomicUsize::new(0);
        // Re-execute conflicting operations until they validate cleanly or
        // the round budget runs out.
        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }
            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();
                        // Reset per-attempt state before re-running.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();
                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);
                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();
                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                // Still conflicting: keep it in the invalid set.
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }
                        // Validated cleanly this round.
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });
            invalid_indices = still_invalid;
            // Budget exhausted: anything still conflicting is a failure.
            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();
        // Defensive: `results` was built in index order and `into_iter`
        // preserves it, so this sort is a no-op.
        final_results.sort_by_key(|r| r.batch_index);
        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();
        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Sequential fallback: executes operations one at a time in batch
    /// order. No conflict detection is needed, so no re-executions occur.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());
        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }
        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();
        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
impl Default for ParallelExecutor {
    /// Equivalent to [`ParallelExecutor::default_workers`].
    fn default() -> Self {
        Self::default_workers()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch succeeds trivially and produces no results.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());
        let result = executor.execute_batch(batch, |_, _, _| {});
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // A single operation falls below MIN_BATCH_SIZE_FOR_PARALLEL and runs
    // sequentially.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);
        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        assert!(!result.parallel_executed);
    }

    // Disjoint write sets: parallel path, no re-executions.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);
        let counter = AtomicU64::new(0);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // Every operation reads and writes the same entity: either conflicts are
    // detected and re-executed, or the conflict-rate fallback goes
    // sequential.
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);
        let shared_entity = EntityId::Node(NodeId::new(100));
        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);
            thread::sleep(Duration::from_micros(10));
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // Write-only operations on distinct entities never conflict.
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });
        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results come back ordered by batch index regardless of execution order.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);
        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A single failure is counted and reported without affecting siblings.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);
        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });
        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);
        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // The tracker keeps the EARLIEST writer per entity and only reports
    // writers strictly before the querying index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();
        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2);
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());
        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);
        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());
        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));
        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);
        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);
        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}