Skip to main content

grafeo_engine/transaction/
parallel.rs

1//! Block-STM inspired parallel transaction execution.
2//!
3//! Executes a batch of operations in parallel optimistically, validates for conflicts,
4//! and re-executes conflicting transactions. This is inspired by Aptos Block-STM and
5//! provides significant speedup for batch-heavy workloads like ETL imports.
6//!
7//! # Algorithm
8//!
9//! The execution follows four phases:
10//!
11//! 1. **Optimistic Execution**: Execute all operations in parallel without locking.
12//!    Each operation tracks its read and write sets.
13//!
14//! 2. **Validation**: Check if any read was invalidated by a concurrent write from
15//!    an earlier transaction in the batch.
16//!
17//! 3. **Re-execution**: Re-execute invalidated transactions with knowledge of
18//!    their dependencies.
19//!
//! 4. **Commit**: Gather the per-operation results in batch (transaction) order for
//!    determinism. The executor itself does not apply the recorded writes.
21//!
22//! # Performance
23//!
24//! | Conflict Rate | Expected Speedup |
25//! |---------------|------------------|
26//! | 0% | 3-4x on 4 cores |
27//! | <10% | 2-3x |
28//! | >30% | Falls back to sequential |
29//!
30//! # Example
31//!
32//! ```ignore
33//! use grafeo_engine::transaction::parallel::{ParallelExecutor, BatchRequest};
34//!
35//! let executor = ParallelExecutor::new(4); // 4 workers
36//!
37//! let batch = BatchRequest::new(vec![
38//!     "CREATE (n:Person {id: 1})",
39//!     "CREATE (n:Person {id: 2})",
40//!     "CREATE (n:Person {id: 3})",
41//! ]);
42//!
//! let result = executor.execute_batch(batch, |index, query, exec| {
//!     // Run `query`, recording every entity it reads and writes on `exec`.
//! });
//! assert!(result.all_succeeded());
45//! ```
46
47use std::collections::HashSet;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use grafeo_common::types::EpochId;
52use grafeo_common::utils::hash::FxHashMap;
53use parking_lot::{Mutex, RwLock};
54use rayon::prelude::*;
55
56use super::EntityId;
57
/// Upper bound on re-execution rounds; operations still conflicting after this
/// many rounds are reported as failed.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Batches smaller than this run sequentially, since parallel overhead would dominate.
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// Fraction of conflicting operations above which the batch is re-run sequentially.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.30;
66
/// Outcome of running one operation of a batch.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExecutionStatus {
    /// The operation ran and its result is currently considered valid.
    Success,
    /// A possible conflict was detected; the operation has to be validated again
    /// (normally by re-executing it).
    NeedsRevalidation,
    /// The operation has been run again after a conflict.
    Reexecuted,
    /// The operation reported an error.
    Failed,
}
79
80/// Result of executing a single operation in the batch.
81#[derive(Debug)]
82pub struct ExecutionResult {
83    /// Index in the batch (for ordering).
84    pub batch_index: usize,
85    /// Execution status.
86    pub status: ExecutionStatus,
87    /// Entities read during execution (entity_id, epoch_read_at).
88    pub read_set: HashSet<(EntityId, EpochId)>,
89    /// Entities written during execution.
90    pub write_set: HashSet<EntityId>,
91    /// Dependencies on earlier transactions in the batch.
92    pub dependencies: Vec<usize>,
93    /// Number of times this operation was re-executed.
94    pub reexecution_count: usize,
95    /// Error message if failed.
96    pub error: Option<String>,
97}
98
99impl ExecutionResult {
100    /// Creates a new execution result.
101    fn new(batch_index: usize) -> Self {
102        Self {
103            batch_index,
104            status: ExecutionStatus::Success,
105            read_set: HashSet::new(),
106            write_set: HashSet::new(),
107            dependencies: Vec::new(),
108            reexecution_count: 0,
109            error: None,
110        }
111    }
112
113    /// Records a read operation.
114    pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
115        self.read_set.insert((entity, epoch));
116    }
117
118    /// Records a write operation.
119    pub fn record_write(&mut self, entity: EntityId) {
120        self.write_set.insert(entity);
121    }
122
123    /// Marks as needing revalidation.
124    pub fn mark_needs_revalidation(&mut self) {
125        self.status = ExecutionStatus::NeedsRevalidation;
126    }
127
128    /// Marks as reexecuted.
129    pub fn mark_reexecuted(&mut self) {
130        self.status = ExecutionStatus::Reexecuted;
131        self.reexecution_count += 1;
132    }
133
134    /// Marks as failed with an error.
135    pub fn mark_failed(&mut self, error: String) {
136        self.status = ExecutionStatus::Failed;
137        self.error = Some(error);
138    }
139}
140
/// Operations (query strings) to be executed together as one batch.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// Query strings, in batch order.
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Builds a batch, converting every element into an owned `String`.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let mut converted: Vec<String> = Vec::with_capacity(operations.len());
        for op in operations {
            converted.push(op.into());
        }
        Self {
            operations: converted,
        }
    }

    /// Number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        let ops = &self.operations;
        ops.len()
    }

    /// `true` when the batch holds no operations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
168
169/// Result of executing a batch of operations.
170#[derive(Debug)]
171pub struct BatchResult {
172    /// Results for each operation (in order).
173    pub results: Vec<ExecutionResult>,
174    /// Total number of successful operations.
175    pub success_count: usize,
176    /// Total number of failed operations.
177    pub failure_count: usize,
178    /// Total number of re-executions performed.
179    pub reexecution_count: usize,
180    /// Whether parallel execution was used (vs fallback to sequential).
181    pub parallel_executed: bool,
182}
183
184impl BatchResult {
185    /// Returns true if all operations succeeded.
186    #[must_use]
187    pub fn all_succeeded(&self) -> bool {
188        self.failure_count == 0
189    }
190
191    /// Returns the indices of failed operations.
192    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
193        self.results
194            .iter()
195            .filter(|r| r.status == ExecutionStatus::Failed)
196            .map(|r| r.batch_index)
197    }
198}
199
200/// Tracks which entities have been written by which batch index.
201#[derive(Debug, Default)]
202struct WriteTracker {
203    /// Entity -> batch index that wrote it.
204    writes: RwLock<FxHashMap<EntityId, usize>>,
205}
206
207impl WriteTracker {
208    /// Records a write by a batch index.
209    /// Keeps track of the earliest writer for conflict detection.
210    fn record_write(&self, entity: EntityId, batch_index: usize) {
211        let mut writes = self.writes.write();
212        writes
213            .entry(entity)
214            .and_modify(|existing| *existing = (*existing).min(batch_index))
215            .or_insert(batch_index);
216    }
217
218    /// Checks if an entity was written by an earlier transaction.
219    fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
220        let writes = self.writes.read();
221        if let Some(&writer) = writes.get(entity)
222            && writer < batch_index
223        {
224            return Some(writer);
225        }
226        None
227    }
228
229    /// Clears all writes.
230    #[allow(dead_code)]
231    fn clear(&self) {
232        self.writes.write().clear();
233    }
234}
235
236/// Block-STM inspired parallel transaction executor.
237///
238/// Executes batches of operations in parallel with optimistic concurrency control.
239pub struct ParallelExecutor {
240    /// Number of worker threads.
241    num_workers: usize,
242    /// Thread pool for parallel execution.
243    pool: rayon::ThreadPool,
244}
245
246impl ParallelExecutor {
247    /// Creates a new parallel executor with the specified number of workers.
248    ///
249    /// # Panics
250    ///
251    /// Panics if num_workers is 0.
252    pub fn new(num_workers: usize) -> Self {
253        assert!(num_workers > 0, "num_workers must be positive");
254
255        let pool = rayon::ThreadPoolBuilder::new()
256            .num_threads(num_workers)
257            .build()
258            .expect("Failed to build thread pool");
259
260        Self { num_workers, pool }
261    }
262
263    /// Creates a parallel executor with the default number of workers (number of CPUs).
264    #[must_use]
265    pub fn default_workers() -> Self {
266        // Use rayon's default parallelism which is based on num_cpus
267        Self::new(rayon::current_num_threads().max(1))
268    }
269
270    /// Returns the number of workers.
271    #[must_use]
272    pub fn num_workers(&self) -> usize {
273        self.num_workers
274    }
275
276    /// Executes a batch of operations in parallel.
277    ///
278    /// Operations are executed optimistically in parallel, validated for conflicts,
279    /// and re-executed as needed. The final result maintains deterministic ordering.
280    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
281    where
282        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
283    {
284        let n = batch.len();
285
286        // Handle empty or small batches
287        if n == 0 {
288            return BatchResult {
289                results: Vec::new(),
290                success_count: 0,
291                failure_count: 0,
292                reexecution_count: 0,
293                parallel_executed: false,
294            };
295        }
296
297        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
298            return self.execute_sequential(batch, execute_fn);
299        }
300
301        // Phase 1: Optimistic parallel execution
302        let write_tracker = Arc::new(WriteTracker::default());
303        let results: Vec<Mutex<ExecutionResult>> = (0..n)
304            .map(|i| Mutex::new(ExecutionResult::new(i)))
305            .collect();
306
307        self.pool.install(|| {
308            batch
309                .operations
310                .par_iter()
311                .enumerate()
312                .for_each(|(idx, op)| {
313                    let mut result = results[idx].lock();
314                    execute_fn(idx, op, &mut result);
315
316                    // Record writes to tracker
317                    for entity in &result.write_set {
318                        write_tracker.record_write(*entity, idx);
319                    }
320                });
321        });
322
323        // Phase 2: Validation
324        let mut invalid_indices = Vec::new();
325
326        for (idx, result_mutex) in results.iter().enumerate() {
327            let mut result = result_mutex.lock();
328
329            // Collect entities to check (to avoid borrow issues)
330            let read_entities: Vec<EntityId> =
331                result.read_set.iter().map(|(entity, _)| *entity).collect();
332
333            // Check if any of our reads were invalidated by an earlier write
334            for entity in read_entities {
335                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
336                    result.mark_needs_revalidation();
337                    result.dependencies.push(writer);
338                }
339            }
340
341            if result.status == ExecutionStatus::NeedsRevalidation {
342                invalid_indices.push(idx);
343            }
344        }
345
346        // Check conflict rate
347        let conflict_rate = invalid_indices.len() as f64 / n as f64;
348        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
349            // Too many conflicts - fall back to sequential
350            return self.execute_sequential(batch, execute_fn);
351        }
352
353        // Phase 3: Re-execution of conflicting transactions
354        let total_reexecutions = AtomicUsize::new(0);
355
356        for round in 0..MAX_REEXECUTION_ROUNDS {
357            if invalid_indices.is_empty() {
358                break;
359            }
360
361            // Re-execute invalid transactions
362            let still_invalid: Vec<usize> = self.pool.install(|| {
363                invalid_indices
364                    .par_iter()
365                    .filter_map(|&idx| {
366                        let mut result = results[idx].lock();
367
368                        // Clear previous state
369                        result.read_set.clear();
370                        result.write_set.clear();
371                        result.dependencies.clear();
372
373                        // Re-execute
374                        execute_fn(idx, &batch.operations[idx], &mut result);
375                        result.mark_reexecuted();
376                        total_reexecutions.fetch_add(1, Ordering::Relaxed);
377
378                        // Collect entities for re-validation
379                        let read_entities: Vec<EntityId> =
380                            result.read_set.iter().map(|(entity, _)| *entity).collect();
381
382                        // Re-validate
383                        for entity in read_entities {
384                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
385                            {
386                                result.mark_needs_revalidation();
387                                result.dependencies.push(writer);
388                                return Some(idx);
389                            }
390                        }
391
392                        result.status = ExecutionStatus::Success;
393                        None
394                    })
395                    .collect()
396            });
397
398            invalid_indices = still_invalid;
399
400            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
401                // Max rounds reached, mark remaining as failed
402                for idx in &invalid_indices {
403                    let mut result = results[*idx].lock();
404                    result.mark_failed("Max re-execution rounds reached".to_string());
405                }
406            }
407        }
408
409        // Phase 4: Collect results
410        let mut final_results: Vec<ExecutionResult> =
411            results.into_iter().map(|m| m.into_inner()).collect();
412
413        // Sort by batch index to maintain order
414        final_results.sort_by_key(|r| r.batch_index);
415
416        let success_count = final_results
417            .iter()
418            .filter(|r| r.status != ExecutionStatus::Failed)
419            .count();
420
421        BatchResult {
422            failure_count: n - success_count,
423            success_count,
424            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
425            parallel_executed: true,
426            results: final_results,
427        }
428    }
429
430    /// Executes a batch sequentially (fallback for high conflict scenarios).
431    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
432    where
433        F: Fn(usize, &str, &mut ExecutionResult),
434    {
435        let mut results = Vec::with_capacity(batch.len());
436
437        for (idx, op) in batch.operations.iter().enumerate() {
438            let mut result = ExecutionResult::new(idx);
439            execute_fn(idx, op, &mut result);
440            results.push(result);
441        }
442
443        let success_count = results
444            .iter()
445            .filter(|r| r.status != ExecutionStatus::Failed)
446            .count();
447
448        BatchResult {
449            failure_count: results.len() - success_count,
450            success_count,
451            reexecution_count: 0,
452            parallel_executed: false,
453            results,
454        }
455    }
456}
457
458impl Default for ParallelExecutor {
459    fn default() -> Self {
460        Self::default_workers()
461    }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    /// Shorthand for the node entity with the given raw id.
    fn node(id: u64) -> EntityId {
        EntityId::Node(NodeId::new(id))
    }

    #[test]
    fn test_empty_batch() {
        let outcome = ParallelExecutor::new(4)
            .execute_batch(BatchRequest::new(Vec::<String>::new()), |_, _, _| {});

        assert!(outcome.results.is_empty());
        assert!(outcome.all_succeeded());
    }

    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let outcome = executor.execute_batch(
            BatchRequest::new(vec!["CREATE (n:Test)"]),
            |_, _, exec| exec.record_write(node(1)),
        );

        assert_eq!(outcome.results.len(), 1);
        assert!(outcome.all_succeeded());
        // Below the parallel threshold, so the sequential path runs.
        assert!(!outcome.parallel_executed);
    }

    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let queries: Vec<String> = (1..=5)
            .map(|i| format!("CREATE (n{i}:Test {{id: {i}}})"))
            .collect();
        let calls = AtomicU64::new(0);

        let outcome = executor.execute_batch(BatchRequest::new(queries), |idx, _, exec| {
            // A distinct entity per operation, so nothing can conflict.
            exec.record_write(node(idx as u64));
            calls.fetch_add(1, Ordering::Relaxed);
        });

        assert!(outcome.parallel_executed);
        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 5);
        assert_eq!(outcome.reexecution_count, 0);
        assert_eq!(calls.load(Ordering::Relaxed), 5);
    }

    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let queries: Vec<String> = (1..=5)
            .map(|v| format!("UPDATE (n:Test) SET n.value = {v}"))
            .collect();
        let shared = node(100);

        let outcome = executor.execute_batch(BatchRequest::new(queries), |_, _, exec| {
            // Every operation reads and writes the same entity.
            exec.record_read(shared, EpochId::new(0));
            exec.record_write(shared);
            // Give the operations a chance to overlap.
            thread::sleep(Duration::from_micros(10));
        });

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 5);
        // Conflicts were either resolved by re-execution or triggered the sequential fallback.
        assert!(!outcome.parallel_executed || outcome.reexecution_count > 0);
    }

    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let queries: Vec<String> = (1..=10).map(|i| format!("op{i}")).collect();

        // Every operation writes its own entity and reads nothing, so there are no
        // read/write conflicts at all.
        let outcome = executor.execute_batch(BatchRequest::new(queries), |idx, _, exec| {
            exec.record_write(node(idx as u64));
        });

        assert!(outcome.all_succeeded());
        assert_eq!(outcome.results.len(), 10);
        assert!(outcome.parallel_executed);
        assert_eq!(outcome.reexecution_count, 0);
    }

    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let queries: Vec<String> = (0..8).map(|i| format!("op{i}")).collect();

        let outcome = executor.execute_batch(BatchRequest::new(queries), |idx, _, exec| {
            exec.record_write(node(idx as u64));
        });

        for (position, entry) in outcome.results.iter().enumerate() {
            assert_eq!(
                entry.batch_index, position,
                "Result at position {} has wrong batch_index",
                position
            );
        }
    }

    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let outcome = executor.execute_batch(batch, |idx, op, exec| match op {
            "fail" => exec.mark_failed("Intentional failure".to_string()),
            _ => exec.record_write(node(idx as u64)),
        });

        assert!(!outcome.all_succeeded());
        assert_eq!((outcome.success_count, outcome.failure_count), (4, 1));
        assert_eq!(outcome.failed_indices().collect::<Vec<usize>>(), vec![1]);
    }

    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();
        for (id, writer) in [(1, 0), (2, 1), (1, 2)] {
            tracker.record_write(node(id), writer);
        }

        // Entity 1 keeps its earliest writer (0) even though index 2 also wrote it.
        assert_eq!(tracker.was_written_by_earlier(&node(1), 3), Some(0));
        // Entity 2 was written by index 1.
        assert_eq!(tracker.was_written_by_earlier(&node(2), 2), Some(1));
        // Nothing precedes index 0.
        assert_eq!(tracker.was_written_by_earlier(&node(1), 0), None);
    }

    #[test]
    fn test_batch_request() {
        let populated = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert!(!populated.is_empty());
        assert_eq!(populated.len(), 3);

        let empty = BatchRequest::new(Vec::<String>::new());
        assert!(empty.is_empty());
    }

    #[test]
    fn test_execution_result() {
        let mut entry = ExecutionResult::new(5);

        assert_eq!(entry.batch_index, 5);
        assert_eq!(entry.status, ExecutionStatus::Success);
        assert!(entry.read_set.is_empty() && entry.write_set.is_empty());

        entry.record_read(node(1), EpochId::new(10));
        entry.record_write(node(2));
        assert_eq!((entry.read_set.len(), entry.write_set.len()), (1, 1));

        entry.mark_needs_revalidation();
        assert_eq!(entry.status, ExecutionStatus::NeedsRevalidation);

        entry.mark_reexecuted();
        assert_eq!(entry.status, ExecutionStatus::Reexecuted);
        assert_eq!(entry.reexecution_count, 1);
    }
}