1use std::collections::HashSet;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use grafeo_common::types::EpochId;
52use grafeo_common::utils::hash::FxHashMap;
53use parking_lot::{Mutex, RwLock};
54use rayon::prelude::*;
55
56use super::EntityId;
57
/// Upper bound on optimistic re-execution rounds; operations still
/// conflicting after this many rounds are marked `Failed`.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially — the parallel scheduling
/// overhead outweighs any gain for tiny batches.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// If more than this fraction of operations conflict after the first
/// parallel pass, the whole batch falls back to sequential execution.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// Executed without any detected conflict.
    Success,
    /// A read conflicted with an earlier batch entry's write; must be re-run.
    NeedsRevalidation,
    /// Was re-executed at least once after a conflict.
    Reexecuted,
    /// Execution failed (operation error or re-execution budget exhausted).
    Failed,
}
79
/// Per-operation execution record: status plus the read/write sets used
/// for conflict detection between batch entries.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of the operation within the originating batch.
    pub batch_index: usize,
    /// Current execution status.
    pub status: ExecutionStatus,
    /// `(entity, epoch)` pairs read during execution.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written during execution.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes this one read.
    pub dependencies: Vec<usize>,
    /// Number of times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Error message, set when `status` is `Failed`.
    pub error: Option<String>,
}
98
99impl ExecutionResult {
100 fn new(batch_index: usize) -> Self {
102 Self {
103 batch_index,
104 status: ExecutionStatus::Success,
105 read_set: HashSet::new(),
106 write_set: HashSet::new(),
107 dependencies: Vec::new(),
108 reexecution_count: 0,
109 error: None,
110 }
111 }
112
113 pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
115 self.read_set.insert((entity, epoch));
116 }
117
118 pub fn record_write(&mut self, entity: EntityId) {
120 self.write_set.insert(entity);
121 }
122
123 pub fn mark_needs_revalidation(&mut self) {
125 self.status = ExecutionStatus::NeedsRevalidation;
126 }
127
128 pub fn mark_reexecuted(&mut self) {
130 self.status = ExecutionStatus::Reexecuted;
131 self.reexecution_count += 1;
132 }
133
134 pub fn mark_failed(&mut self, error: String) {
136 self.status = ExecutionStatus::Failed;
137 self.error = Some(error);
138 }
139}
140
/// A group of textual operations submitted for execution as one unit.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Operations in submission order.
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Builds a batch from any collection of string-convertible operations.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let operations = operations.into_iter().map(Into::into).collect();
        Self { operations }
    }

    /// Number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// `true` when the batch contains no operations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}
168
/// Aggregate outcome of executing a [`BatchRequest`].
#[derive(Debug)]
pub struct BatchResult {
    /// Per-operation results, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Count of operations that did not end in `Failed`.
    pub success_count: usize,
    /// Count of operations that ended in `Failed`.
    pub failure_count: usize,
    /// Total number of re-execution attempts across all operations.
    pub reexecution_count: usize,
    /// Whether the batch took the parallel path (vs. sequential fallback).
    pub parallel_executed: bool,
}
183
184impl BatchResult {
185 #[must_use]
187 pub fn all_succeeded(&self) -> bool {
188 self.failure_count == 0
189 }
190
191 pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
193 self.results
194 .iter()
195 .filter(|r| r.status == ExecutionStatus::Failed)
196 .map(|r| r.batch_index)
197 }
198}
199
/// Tracks, per entity, the earliest batch index that wrote it.
/// Used to detect read-after-write conflicts between batch entries.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Entity -> smallest batch index that wrote it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
206
207impl WriteTracker {
208 fn record_write(&self, entity: EntityId, batch_index: usize) {
211 let mut writes = self.writes.write();
212 writes
213 .entry(entity)
214 .and_modify(|existing| *existing = (*existing).min(batch_index))
215 .or_insert(batch_index);
216 }
217
218 fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
220 let writes = self.writes.read();
221 if let Some(&writer) = writes.get(entity)
222 && writer < batch_index
223 {
224 return Some(writer);
225 }
226 None
227 }
228
229 #[allow(dead_code)]
231 fn clear(&self) {
232 self.writes.write().clear();
233 }
234}
235
/// Executes batches of operations optimistically in parallel, detecting
/// read/write conflicts between entries and re-executing the conflicted ones.
pub struct ParallelExecutor {
    /// Number of worker threads in the dedicated pool.
    num_workers: usize,
    /// Dedicated rayon thread pool used for all parallel phases.
    pool: rayon::ThreadPool,
}
245
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated pool of `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is zero or the rayon pool cannot be built.
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("Failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates an executor sized from rayon's current thread count
    /// (clamped to at least one worker).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Number of worker threads in this executor's pool.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes every operation in `batch`, optimistically in parallel.
    ///
    /// Protocol:
    /// 1. Run all operations concurrently, recording each one's read/write
    ///    sets and, per entity, the earliest batch index that wrote it.
    /// 2. Validate: any operation that read an entity written by an
    ///    *earlier* batch index is marked `NeedsRevalidation`.
    /// 3. If the conflict rate exceeds `MAX_CONFLICT_RATE_FOR_PARALLEL`,
    ///    re-run the whole batch sequentially. NOTE(review): at that point
    ///    `execute_fn` has already run once per operation, so any side
    ///    effects are applied twice on fallback — presumably callers'
    ///    closures must be idempotent/retryable; confirm.
    /// 4. Otherwise re-execute the conflicted subset for up to
    ///    `MAX_REEXECUTION_ROUNDS` rounds; survivors are marked `Failed`.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Trivial case: nothing to execute.
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        // Small batches aren't worth the parallel machinery.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let write_tracker = Arc::new(WriteTracker::default());
        // One mutex per slot: workers touch disjoint entries, so these
        // locks are uncontended during the parallel pass.
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        // Phase 1: optimistic parallel execution, recording write ownership.
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Phase 2: validation — collect operations whose reads were
        // clobbered by an earlier batch entry's write.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Snapshot the read entities first: we cannot mutate `result`
            // while borrowing its read_set.
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Too many conflicts: parallel retries would likely thrash, so run
        // the whole batch sequentially instead (discarding phase-1 results).
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let total_reexecutions = AtomicUsize::new(0);

        // Phase 3: bounded re-execution rounds for the conflicted subset.
        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Reset conflict-tracking state before re-running.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        // NOTE(review): writes produced during re-execution
                        // are never recorded into `write_tracker`, so an
                        // entity first written here stays invisible to later
                        // validation — confirm this is intentional.
                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        // Clean re-run: the conflict is resolved.
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            // Budget exhausted: whatever is still conflicted fails.
            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        // Unwrap the per-slot mutexes into plain results.
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        // Defensive: results were built in index order above, so this sort
        // is a no-op kept to guarantee the ordering contract.
        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Runs the batch one operation at a time on the calling thread.
    /// No conflict detection is needed because execution is serialized.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
457
458impl Default for ParallelExecutor {
459 fn default() -> Self {
460 Self::default_workers()
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch succeeds trivially with zero results.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // One operation is below MIN_BATCH_SIZE_FOR_PARALLEL, so the
    // sequential path must be taken.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        assert!(!result.parallel_executed);
    }

    // Disjoint write sets: runs in parallel with no re-execution.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes its own entity — no overlap.
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        // The closure ran exactly once per operation.
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // Every operation reads and writes the same entity: either some
    // re-execution happens, or the high conflict rate forced the
    // sequential fallback.
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Small sleep to encourage overlapping execution windows.
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Timing-dependent: accept either outcome of the race.
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // Per-index writes on a 10-op batch: no conflicts, stays parallel.
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results must come back ordered by batch_index regardless of the
    // order workers finished in.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A single failing operation is counted and reported by index while
    // the rest of the batch still succeeds.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // WriteTracker keeps the earliest writer per entity and only reports
    // writers strictly before the querying batch index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        // A later duplicate write must not displace the earliest writer (0).
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2);
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // Index 0 has no earlier writers by definition.
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    // Basic BatchRequest accessors.
    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    // ExecutionResult state transitions and set bookkeeping.
    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}