1use std::collections::HashSet;
50use std::sync::Arc;
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53use grafeo_common::types::EpochId;
54use grafeo_common::utils::hash::FxHashMap;
55use parking_lot::{Mutex, RwLock};
56use rayon::prelude::*;
57
58use super::EntityId;
59
/// Maximum number of re-execution rounds before conflicting operations
/// are marked failed.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially; the coordination overhead
/// of the parallel path isn't worth it for tiny batches.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// If more than this fraction of operations conflict after the optimistic
/// first pass, the whole batch is re-run sequentially instead.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// The operation completed with no detected conflicts.
    Success,
    /// A conflicting earlier write was detected; the operation must be re-run.
    NeedsRevalidation,
    /// The operation has been re-executed after a conflict (transient state
    /// during re-execution rounds).
    Reexecuted,
    /// The operation failed — either an execution error or the
    /// re-execution round limit was reached.
    Failed,
}
81
/// Per-operation execution record: status, read/write footprint, and
/// conflict-tracking metadata.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of the operation in the submitted batch.
    pub batch_index: usize,
    /// Current status of the operation.
    pub status: ExecutionStatus,
    /// Entities read by the operation, with the epoch each was read at.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written by the operation.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes conflicted with
    /// this operation's reads.
    pub dependencies: Vec<usize>,
    /// Number of times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Error message when `status` is [`ExecutionStatus::Failed`].
    pub error: Option<String>,
}
100
101impl ExecutionResult {
102 fn new(batch_index: usize) -> Self {
104 Self {
105 batch_index,
106 status: ExecutionStatus::Success,
107 read_set: HashSet::new(),
108 write_set: HashSet::new(),
109 dependencies: Vec::new(),
110 reexecution_count: 0,
111 error: None,
112 }
113 }
114
115 pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
117 self.read_set.insert((entity, epoch));
118 }
119
120 pub fn record_write(&mut self, entity: EntityId) {
122 self.write_set.insert(entity);
123 }
124
125 pub fn mark_needs_revalidation(&mut self) {
127 self.status = ExecutionStatus::NeedsRevalidation;
128 }
129
130 pub fn mark_reexecuted(&mut self) {
132 self.status = ExecutionStatus::Reexecuted;
133 self.reexecution_count += 1;
134 }
135
136 pub fn mark_failed(&mut self, error: String) {
138 self.status = ExecutionStatus::Failed;
139 self.error = Some(error);
140 }
141}
142
/// A batch of operations submitted for (possibly parallel) execution.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Operation payloads, in submission order; the index into this vector
    /// is the operation's batch index.
    pub operations: Vec<String>,
}
149
150impl BatchRequest {
151 pub fn new(operations: Vec<impl Into<String>>) -> Self {
153 Self {
154 operations: operations.into_iter().map(Into::into).collect(),
155 }
156 }
157
158 #[must_use]
160 pub fn len(&self) -> usize {
161 self.operations.len()
162 }
163
164 #[must_use]
166 pub fn is_empty(&self) -> bool {
167 self.operations.is_empty()
168 }
169}
170
/// Aggregated outcome of executing a whole batch.
#[derive(Debug)]
pub struct BatchResult {
    /// Per-operation results, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Number of operations that did not fail.
    pub success_count: usize,
    /// Number of operations with [`ExecutionStatus::Failed`].
    pub failure_count: usize,
    /// Total number of re-executions performed across all operations.
    pub reexecution_count: usize,
    /// Whether the batch went through the parallel path (as opposed to the
    /// small-batch or high-conflict sequential path).
    pub parallel_executed: bool,
}
185
186impl BatchResult {
187 #[must_use]
189 pub fn all_succeeded(&self) -> bool {
190 self.failure_count == 0
191 }
192
193 pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
195 self.results
196 .iter()
197 .filter(|r| r.status == ExecutionStatus::Failed)
198 .map(|r| r.batch_index)
199 }
200}
201
/// Tracks, for each entity, the earliest batch index that wrote to it.
/// Used to detect read-after-write conflicts between batch operations.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Map from entity to the smallest batch index that wrote it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
208
209impl WriteTracker {
210 fn record_write(&self, entity: EntityId, batch_index: usize) {
213 let mut writes = self.writes.write();
214 writes
215 .entry(entity)
216 .and_modify(|existing| *existing = (*existing).min(batch_index))
217 .or_insert(batch_index);
218 }
219
220 fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
222 let writes = self.writes.read();
223 if let Some(&writer) = writes.get(entity)
224 && writer < batch_index
225 {
226 return Some(writer);
227 }
228 None
229 }
230}
231
/// Optimistic parallel batch executor.
///
/// Runs batch operations concurrently on a dedicated rayon pool, detects
/// read/write conflicts after the fact, and re-executes conflicting
/// operations until none remain or a round limit is reached.
pub struct ParallelExecutor {
    /// Number of worker threads in the dedicated pool.
    num_workers: usize,
    /// Dedicated thread pool used for parallel execution.
    pool: rayon::ThreadPool,
}
241
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated rayon pool with
    /// `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is zero, or if the thread pool cannot be
    /// built.
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("Failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates an executor sized to rayon's current global thread count
    /// (at least one worker).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Returns the number of worker threads in this executor's pool.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes all operations in `batch`, optimistically in parallel.
    ///
    /// `execute_fn` is invoked as `(batch_index, operation, result)` and is
    /// expected to record the entities it reads and writes on `result`.
    /// It may be invoked more than once per operation (conflict
    /// re-execution, or a wholesale sequential fallback), so it must be
    /// safe to retry.
    ///
    /// Flow:
    /// 1. Empty batches return immediately; batches smaller than
    ///    [`MIN_BATCH_SIZE_FOR_PARALLEL`] run sequentially.
    /// 2. All operations run once in parallel while a [`WriteTracker`]
    ///    records the earliest writer of every entity.
    /// 3. Any operation that read an entity written by an earlier-indexed
    ///    operation is flagged for re-execution.
    /// 4. If the conflict rate exceeds [`MAX_CONFLICT_RATE_FOR_PARALLEL`],
    ///    the whole batch is re-run sequentially instead.
    /// 5. Otherwise conflicting operations are re-executed in rounds, up to
    ///    [`MAX_REEXECUTION_ROUNDS`]; stragglers are marked failed.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Fast path: nothing to do.
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        // Tiny batches aren't worth the coordination overhead.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let write_tracker = Arc::new(WriteTracker::default());
        // One pre-allocated slot per operation; Mutex-wrapped so worker
        // threads can fill them independently.
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        // Optimistic first pass: run everything in parallel and record the
        // earliest writer of each written entity.
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Conflict detection: an operation is invalid if it read an entity
        // that an earlier-indexed operation wrote.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Snapshot the read entities so `result` can be mutated below.
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Too many conflicts: abandon parallelism and re-run everything
        // sequentially. NOTE(review): this invokes `execute_fn` a second
        // time for every operation, so side effects are repeated — callers
        // must tolerate retries.
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let total_reexecutions = AtomicUsize::new(0);

        // Re-execute conflicting operations in rounds until none remain or
        // the round budget is exhausted.
        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Start the retry from a clean read/write footprint.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        // Re-validate against the tracker. NOTE(review): the
                        // tracker is only populated during the first pass;
                        // writes made during re-execution rounds are never
                        // recorded, so later rounds validate against the
                        // first-pass write footprint — confirm intended.
                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        // Clean retry: overwrite the transient `Reexecuted`
                        // status (the per-result retry count is preserved).
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            // Budget exhausted: anything still conflicting is failed.
            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        // Unwrap the mutexes. Results were created in index order; the sort
        // is a defensive guarantee of that ordering.
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Runs every operation one after another on the calling thread.
    /// No conflict detection is needed because execution order matches
    /// batch order.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
453
impl Default for ParallelExecutor {
    /// Equivalent to [`ParallelExecutor::default_workers`].
    fn default() -> Self {
        Self::default_workers()
    }
}
459
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch succeeds trivially and produces no results.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // A single operation is below MIN_BATCH_SIZE_FOR_PARALLEL, so it takes
    // the sequential path.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        assert!(!result.parallel_executed);
    }

    // Operations writing disjoint entities run in parallel with no
    // re-executions; each operation runs exactly once.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0);
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // All operations read and write the same entity: the executor must
    // either re-execute or fall back to sequential; both outcomes succeed.
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Small delay to encourage actual thread interleaving.
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // NOTE(review): despite the name, every operation here writes a
    // distinct entity, so no conflicts occur and no re-execution happens.
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results must come back ordered by batch index even after parallel
    // execution.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A single failing operation is reported without affecting the others.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // The tracker keeps the earliest writer per entity and only reports
    // writers strictly earlier than the querying index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        // A later write to entity 1 must not displace the earlier index 0.
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2);
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // The writer itself (index 0) is not "earlier" than index 0.
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    // Basic BatchRequest construction and accessors.
    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    // ExecutionResult state transitions and read/write recording.
    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}