1use std::collections::HashSet;
50use std::sync::Arc;
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53use grafeo_common::types::EpochId;
54use grafeo_common::utils::hash::FxHashMap;
55use parking_lot::{Mutex, RwLock};
56use rayon::prelude::*;
57
58use super::EntityId;
59
/// Maximum number of re-execution rounds before still-conflicted operations
/// are marked failed.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially; the parallel bookkeeping would
/// outweigh the benefit.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// If more than this fraction of a batch conflicts after the first parallel
/// pass, the whole batch is re-run sequentially.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExecutionStatus {
    /// The operation completed without a detected conflict.
    Success,
    /// A read overlapped a write from an earlier-indexed operation; the
    /// operation must be re-validated (and possibly re-executed).
    NeedsRevalidation,
    /// The operation was re-executed after a conflict was detected.
    Reexecuted,
    /// The operation failed — either marked so by the executor callback or
    /// because the re-execution round budget was exhausted.
    Failed,
}
82
/// Per-operation execution record: status plus the read/write sets used for
/// conflict detection.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of this operation within the submitted batch.
    pub batch_index: usize,
    /// Current status of the operation.
    pub status: ExecutionStatus,
    /// Entities read by the operation, with the epoch each was read at.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written by the operation.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes conflicted with this
    /// operation's reads.
    pub dependencies: Vec<usize>,
    /// Number of times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Error message; set when `status` is `Failed`.
    pub error: Option<String>,
}
101
102impl ExecutionResult {
103 fn new(batch_index: usize) -> Self {
105 Self {
106 batch_index,
107 status: ExecutionStatus::Success,
108 read_set: HashSet::new(),
109 write_set: HashSet::new(),
110 dependencies: Vec::new(),
111 reexecution_count: 0,
112 error: None,
113 }
114 }
115
116 pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
118 self.read_set.insert((entity, epoch));
119 }
120
121 pub fn record_write(&mut self, entity: EntityId) {
123 self.write_set.insert(entity);
124 }
125
126 pub fn mark_needs_revalidation(&mut self) {
128 self.status = ExecutionStatus::NeedsRevalidation;
129 }
130
131 pub fn mark_reexecuted(&mut self) {
133 self.status = ExecutionStatus::Reexecuted;
134 self.reexecution_count += 1;
135 }
136
137 pub fn mark_failed(&mut self, error: String) {
139 self.status = ExecutionStatus::Failed;
140 self.error = Some(error);
141 }
142}
143
/// A batch of operations submitted for (potentially parallel) execution.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Operations to execute, in submission order.
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Builds a request from anything convertible into owned strings.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let operations = operations.into_iter().map(|op| op.into()).collect();
        Self { operations }
    }

    /// Returns the number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` when the batch contains no operations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
171
/// Aggregate outcome of executing a batch.
#[derive(Debug)]
pub struct BatchResult {
    /// Per-operation results, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Number of operations that did not fail.
    pub success_count: usize,
    /// Number of operations with status `Failed`.
    pub failure_count: usize,
    /// Total number of re-executions performed across the batch.
    pub reexecution_count: usize,
    /// Whether the batch completed on the parallel path (as opposed to the
    /// sequential path or a sequential fallback).
    pub parallel_executed: bool,
}
186
187impl BatchResult {
188 #[must_use]
190 pub fn all_succeeded(&self) -> bool {
191 self.failure_count == 0
192 }
193
194 pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
196 self.results
197 .iter()
198 .filter(|r| r.status == ExecutionStatus::Failed)
199 .map(|r| r.batch_index)
200 }
201}
202
/// Tracks, for each entity, the earliest batch index that wrote it.
#[derive(Debug, Default)]
struct WriteTracker {
    // entity -> minimum batch index that recorded a write to it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
209
210impl WriteTracker {
211 fn record_write(&self, entity: EntityId, batch_index: usize) {
214 let mut writes = self.writes.write();
215 writes
216 .entry(entity)
217 .and_modify(|existing| *existing = (*existing).min(batch_index))
218 .or_insert(batch_index);
219 }
220
221 fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
223 let writes = self.writes.read();
224 if let Some(&writer) = writes.get(entity)
225 && writer < batch_index
226 {
227 return Some(writer);
228 }
229 None
230 }
231}
232
/// Executes batches of operations optimistically in parallel, detecting
/// read/write conflicts against earlier-indexed operations and re-executing
/// conflicted ones.
pub struct ParallelExecutor {
    /// Number of worker threads in the dedicated pool.
    num_workers: usize,
    /// Dedicated rayon thread pool used for batch execution.
    pool: rayon::ThreadPool,
}
242
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated rayon pool with
    /// `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is zero, or if the thread pool cannot be
    /// built.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates an executor sized to rayon's current global thread count
    /// (at least one worker).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Returns the number of worker threads in this executor's pool.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes `batch` with optimistic concurrency control.
    ///
    /// Every operation first runs once in parallel while `execute_fn`
    /// records its reads and writes on the supplied [`ExecutionResult`].
    /// An operation whose read set overlaps the write set of an
    /// earlier-indexed operation is re-executed, for up to
    /// [`MAX_REEXECUTION_ROUNDS`] rounds; operations still conflicted after
    /// the final round are marked [`ExecutionStatus::Failed`].
    ///
    /// Batches shorter than [`MIN_BATCH_SIZE_FOR_PARALLEL`] run
    /// sequentially. If, after the first parallel pass, more than
    /// [`MAX_CONFLICT_RATE_FOR_PARALLEL`] of the operations conflict, the
    /// whole batch is re-run sequentially instead.
    ///
    /// `execute_fn(index, operation, result)` is called from pool workers
    /// and must therefore be `Sync + Send`.
    ///
    /// NOTE(review): on the high-conflict fallback and during re-execution
    /// rounds, `execute_fn` is invoked more than once for the same
    /// operation — callers with non-idempotent side effects should confirm
    /// this is acceptable.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Nothing to do: report an empty, all-succeeded batch.
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        // Too small to amortize the parallel bookkeeping.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let write_tracker = Arc::new(WriteTracker::default());
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        // Pass 1: run every operation in parallel, recording each
        // operation's writes in the shared tracker (which keeps the
        // earliest writer index per entity).
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Validation: an operation conflicts when it read an entity that an
        // earlier-indexed operation wrote.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            // NOTE: the read set is keyed by (entity, epoch), so an entity
            // read at several epochs can push the same writer index into
            // `dependencies` more than once.
            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Too contended: discard the parallel results and redo the whole
        // batch sequentially (this re-invokes `execute_fn` for every
        // operation and reports `parallel_executed: false`).
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let total_reexecutions = AtomicUsize::new(0);

        // Re-execution rounds: re-run only the conflicted operations, in
        // parallel, until none conflict or the round budget is exhausted.
        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Start from a clean slate so `execute_fn` records
                        // fresh read/write sets for this run.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        // NOTE(review): writes made during re-execution are
                        // not re-recorded in `write_tracker`; conflicts are
                        // checked against the first-pass write sets only —
                        // confirm this is the intended validation model.
                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        // Clean re-run: the transient `Reexecuted` status is
                        // overwritten here, so only `reexecution_count`
                        // records that a re-run happened.
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            // Budget exhausted: fail whatever is still conflicted.
            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        // Unwrap the per-operation mutexes. `results` was built in index
        // order and `into_iter` preserves it, so the sort below is a
        // defensive no-op.
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Runs every operation in order on the calling thread.
    ///
    /// Used for small batches and as the fallback for highly conflicted
    /// ones. No conflict detection is performed: statuses are whatever
    /// `execute_fn` sets, and `parallel_executed` is reported as `false`.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
455
456impl Default for ParallelExecutor {
457 fn default() -> Self {
458 Self::default_workers()
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch short-circuits: no results, counted as all-succeeded.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // A single operation is below MIN_BATCH_SIZE_FOR_PARALLEL, so it takes
    // the sequential path.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        assert!(!result.parallel_executed);
    }

    // Disjoint write sets: every operation runs once, in parallel, with no
    // re-execution.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // All operations read and write the same entity: either re-execution
    // happened or the executor fell back to the sequential path.
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Small sleep to encourage overlapping execution across workers.
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // Writes to distinct entities with no reads: no conflicts, no re-runs.
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results must come back ordered by batch index regardless of which
    // worker finished first.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A failure marked by the callback is counted and reported by index.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // The tracker keeps the earliest writer per entity and only reports
    // writers strictly before the queried index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2);

        // Entity 1: earliest writer (0) wins over the later write at 2.
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // A writer at the same index as the query is not "earlier".
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    // Basic construction/size accessors of BatchRequest.
    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    // State transitions and set recording on a single ExecutionResult.
    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}