1use std::collections::HashSet;
50use std::sync::Arc;
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53use grafeo_common::types::EpochId;
54use grafeo_common::utils::hash::FxHashMap;
55use parking_lot::{Mutex, RwLock};
56use rayon::prelude::*;
57
58use super::EntityId;
59
/// Maximum number of optimistic re-execution rounds before the operations
/// that still conflict are marked as failed.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially; the parallel setup cost
/// would outweigh any benefit.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// If more than this fraction of operations conflict after the first
/// parallel pass, the whole batch is re-run sequentially instead.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
68
/// Outcome of executing a single operation within a batch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// The operation completed with no detected conflict.
    Success,
    /// The operation read an entity written by an earlier-indexed operation
    /// and must be re-executed.
    NeedsRevalidation,
    /// The operation has been re-executed after a conflict.
    Reexecuted,
    /// The operation failed permanently (error, or re-execution rounds exhausted).
    Failed,
}
81
/// Per-operation execution record: final status plus the read/write sets
/// used for conflict detection.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Position of this operation within its batch.
    pub batch_index: usize,
    /// Current status of the operation.
    pub status: ExecutionStatus,
    /// Entities read by the operation, paired with the epoch observed at read time.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written by the operation.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier operations whose writes conflicted with our reads.
    pub dependencies: Vec<usize>,
    /// Number of times this operation has been re-executed.
    pub reexecution_count: usize,
    /// Error message, set when `status` is `Failed`.
    pub error: Option<String>,
}
100
101impl ExecutionResult {
102 fn new(batch_index: usize) -> Self {
104 Self {
105 batch_index,
106 status: ExecutionStatus::Success,
107 read_set: HashSet::new(),
108 write_set: HashSet::new(),
109 dependencies: Vec::new(),
110 reexecution_count: 0,
111 error: None,
112 }
113 }
114
115 pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
117 self.read_set.insert((entity, epoch));
118 }
119
120 pub fn record_write(&mut self, entity: EntityId) {
122 self.write_set.insert(entity);
123 }
124
125 pub fn mark_needs_revalidation(&mut self) {
127 self.status = ExecutionStatus::NeedsRevalidation;
128 }
129
130 pub fn mark_reexecuted(&mut self) {
132 self.status = ExecutionStatus::Reexecuted;
133 self.reexecution_count += 1;
134 }
135
136 pub fn mark_failed(&mut self, error: String) {
138 self.status = ExecutionStatus::Failed;
139 self.error = Some(error);
140 }
141}
142
/// An ordered batch of operations to execute together.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Operation payloads, in submission order.
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Creates a batch from any iterable of string-like operations.
    ///
    /// Generalized from `Vec<impl Into<String>>` to `impl IntoIterator` so
    /// arrays, iterators, and adapters work without an intermediate `Vec`;
    /// every existing `Vec`-based caller still compiles unchanged.
    pub fn new(operations: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            operations: operations.into_iter().map(Into::into).collect(),
        }
    }

    /// Returns the number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` when the batch contains no operations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}
170
/// Aggregate outcome of executing a `BatchRequest`.
#[derive(Debug)]
pub struct BatchResult {
    /// Per-operation results, ordered by batch index.
    pub results: Vec<ExecutionResult>,
    /// Number of operations whose final status is not `Failed`.
    pub success_count: usize,
    /// Number of operations whose final status is `Failed`.
    pub failure_count: usize,
    /// Total re-execution attempts performed across the batch.
    pub reexecution_count: usize,
    /// `true` when the parallel path ran to completion; `false` for the
    /// sequential path (small batch or high-conflict fallback).
    pub parallel_executed: bool,
}
185
186impl BatchResult {
187 #[must_use]
189 pub fn all_succeeded(&self) -> bool {
190 self.failure_count == 0
191 }
192
193 pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
195 self.results
196 .iter()
197 .filter(|r| r.status == ExecutionStatus::Failed)
198 .map(|r| r.batch_index)
199 }
200}
201
/// Shared record of which batch operation wrote each entity during a run.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Maps an entity to the smallest batch index that wrote it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
208
209impl WriteTracker {
210 fn record_write(&self, entity: EntityId, batch_index: usize) {
213 let mut writes = self.writes.write();
214 writes
215 .entry(entity)
216 .and_modify(|existing| *existing = (*existing).min(batch_index))
217 .or_insert(batch_index);
218 }
219
220 fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
222 let writes = self.writes.read();
223 if let Some(&writer) = writes.get(entity)
224 && writer < batch_index
225 {
226 return Some(writer);
227 }
228 None
229 }
230}
231
/// Executes batches of operations optimistically in parallel, re-running
/// operations whose reads conflict with writes from earlier-indexed ones.
pub struct ParallelExecutor {
    /// Number of threads the dedicated pool was built with.
    num_workers: usize,
    /// Dedicated rayon pool; batch work runs inside it via `install`.
    pool: rayon::ThreadPool,
}
241
impl ParallelExecutor {
    /// Creates an executor backed by a dedicated rayon pool of `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is zero, or if the thread pool cannot be built.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates an executor sized to rayon's current thread count (at least 1).
    #[must_use]
    pub fn default_workers() -> Self {
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Returns the number of worker threads this executor was built with.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes every operation in `batch`, optimistically in parallel.
    ///
    /// Pipeline: (1) run all operations concurrently, recording the earliest
    /// writer of each entity; (2) invalidate any operation that read an entity
    /// written by an earlier-indexed operation; (3) if too many operations
    /// conflicted, fall back to a full sequential re-run; otherwise (4)
    /// re-execute the invalidated operations for up to
    /// `MAX_REEXECUTION_ROUNDS` rounds, failing whatever still conflicts.
    ///
    /// `execute_fn(index, operation, result)` must record its reads and
    /// writes on `result`. NOTE(review): `execute_fn` may run more than once
    /// per operation (re-execution rounds, or the sequential fallback after a
    /// completed parallel pass) — confirm repeated invocation is safe.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Empty batch: nothing ran, nothing failed.
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        // Small batches aren't worth the parallel machinery.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        let write_tracker = Arc::new(WriteTracker::default());
        // One mutex-wrapped slot per operation; each rayon task locks only its own.
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        // Phase 1: optimistic parallel execution; every write is recorded in
        // the shared tracker, which keeps the earliest writer per entity.
        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Phase 2: serial validation — an operation is invalid if it read an
        // entity that a lower-indexed operation wrote.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Copy entities out first so `result` can be mutated in the loop.
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    // NOTE(review): reads of the same entity at different
                    // epochs can push duplicate dependency entries.
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Phase 3: highly contended batch — one sequential re-run is cheaper
        // than many re-execution rounds. Phase-1 results are discarded.
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        // Atomic because re-execution rounds bump it from parallel tasks.
        let total_reexecutions = AtomicUsize::new(0);

        // Phase 4: bounded re-execution of the conflicting operations.
        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Drop stale bookkeeping before the retry.
                        // NOTE(review): entries this operation put into
                        // `write_tracker` during phase 1 are NOT removed, and
                        // writes recorded by the retry are not added — this
                        // assumes write sets are stable across re-execution;
                        // confirm with the caller's contract.
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        // Re-validate; keep the index for another round if a
                        // conflict with an earlier writer remains.
                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        // NOTE(review): a clean retry ends as `Success`, so the
                        // `Reexecuted` status set above is transient; only
                        // `reexecution_count` records that retries happened.
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            // Out of rounds: whatever still conflicts becomes a hard failure.
            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        // Unwrap the mutexes. Slots were created in index order already; the
        // sort re-asserts that ordering for the returned results.
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Runs the batch one operation at a time on the calling thread.
    ///
    /// No conflict detection is performed: with a single executor running in
    /// index order there is nothing to invalidate, so `reexecution_count` is
    /// always 0 and `parallel_executed` is `false`.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        // An operation only counts as failed if `execute_fn` marked it so.
        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
454
impl Default for ParallelExecutor {
    /// Equivalent to [`ParallelExecutor::default_workers`].
    fn default() -> Self {
        Self::default_workers()
    }
}
460
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    /// An empty batch produces an empty, successful result.
    #[test]
    fn test_empty_batch() {
        let exec = ParallelExecutor::new(4);
        let empty = BatchRequest::new(Vec::<String>::new());

        let outcome = exec.execute_batch(empty, |_, _, _| {});

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 0);
    }

    /// A one-element batch is below the parallel threshold and runs sequentially.
    #[test]
    fn test_single_operation() {
        let exec = ParallelExecutor::new(4);
        let single = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let outcome = exec.execute_batch(single, |_, _, res| {
            res.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 1);
        assert!(!outcome.parallel_executed);
    }

    /// Disjoint write sets run fully in parallel with no re-execution.
    #[test]
    fn test_independent_operations() {
        let exec = ParallelExecutor::new(4);
        let ops = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let calls = AtomicU64::new(0);

        let outcome = exec.execute_batch(ops, |idx, _, res| {
            // Each operation writes its own node, so no conflicts arise.
            res.record_write(EntityId::Node(NodeId::new(idx as u64)));
            calls.fetch_add(1, Ordering::Relaxed);
        });

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 5);
        assert_eq!(outcome.reexecution_count, 0);
        assert!(outcome.parallel_executed);
        assert_eq!(calls.load(Ordering::Relaxed), 5);
    }

    /// All operations touching one entity either re-execute or trigger the
    /// sequential fallback.
    #[test]
    fn test_conflicting_operations() {
        let exec = ParallelExecutor::new(4);
        let ops = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let hot_entity = EntityId::Node(NodeId::new(100));

        let outcome = exec.execute_batch(ops, |_idx, _, res| {
            // Every operation reads and writes the same entity.
            res.record_read(hot_entity, EpochId::new(0));
            res.record_write(hot_entity);

            thread::sleep(Duration::from_micros(10));
        });

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 5);
        assert!(outcome.reexecution_count > 0 || !outcome.parallel_executed);
    }

    /// Disjoint writes in a bigger batch stay parallel with zero retries.
    #[test]
    fn test_partial_conflicts() {
        let exec = ParallelExecutor::new(4);
        let ops = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        let outcome = exec.execute_batch(ops, |idx, _, res| {
            let target = EntityId::Node(NodeId::new(idx as u64));
            res.record_write(target);
        });

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 10);
        assert!(outcome.parallel_executed);
        assert_eq!(outcome.reexecution_count, 0);
    }

    /// Results come back ordered by their original batch index.
    #[test]
    fn test_execution_order_preserved() {
        let exec = ParallelExecutor::new(4);
        let ops =
            BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let outcome = exec.execute_batch(ops, |idx, _, res| {
            res.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        for (pos, res) in outcome.results.iter().enumerate() {
            assert_eq!(
                res.batch_index, pos,
                "Result at position {} has wrong batch_index",
                pos
            );
        }
    }

    /// A single failed operation is counted and reported by index.
    #[test]
    fn test_failure_handling() {
        let exec = ParallelExecutor::new(4);
        let ops =
            BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let outcome = exec.execute_batch(ops, |idx, op, res| {
            if op == "fail" {
                res.mark_failed("Intentional failure".to_string());
            } else {
                res.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!outcome.all_succeeded());
        assert_eq!(outcome.failure_count, 1);
        assert_eq!(outcome.success_count, 4);

        let failed: Vec<usize> = outcome.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    /// The tracker keeps the earliest writer and only reports strictly
    /// earlier batch indices.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        // A later write to node 1 must not displace the earlier writer.
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2);

        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    /// Basic length/emptiness accessors on `BatchRequest`.
    #[test]
    fn test_batch_request() {
        let filled = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(filled.len(), 3);
        assert!(!filled.is_empty());

        let empty = BatchRequest::new(Vec::<String>::new());
        assert!(empty.is_empty());
    }

    /// State transitions and bookkeeping on a single `ExecutionResult`.
    #[test]
    fn test_execution_result() {
        let mut res = ExecutionResult::new(5);

        assert_eq!(res.batch_index, 5);
        assert_eq!(res.status, ExecutionStatus::Success);
        assert!(res.read_set.is_empty());
        assert!(res.write_set.is_empty());

        res.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        res.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(res.read_set.len(), 1);
        assert_eq!(res.write_set.len(), 1);

        res.mark_needs_revalidation();
        assert_eq!(res.status, ExecutionStatus::NeedsRevalidation);

        res.mark_reexecuted();
        assert_eq!(res.status, ExecutionStatus::Reexecuted);
        assert_eq!(res.reexecution_count, 1);
    }
}