grafeo_engine/transaction/parallel.rs

//! Block-STM inspired parallel transaction execution.
//!
//! Executes a batch of operations optimistically in parallel, validates for conflicts,
//! and re-executes conflicting transactions. This is inspired by Aptos Block-STM and
//! provides significant speedup for batch-heavy workloads such as ETL imports.
//!
//! # Algorithm
//!
//! Execution proceeds in four phases (see the sketch after this list):
//!
//! 1. **Optimistic Execution**: Execute all operations in parallel without locking.
//!    Each operation tracks its read and write sets.
//!
//! 2. **Validation**: Check whether any read was invalidated by a concurrent write from
//!    an earlier transaction in the batch.
//!
//! 3. **Re-execution**: Re-execute invalidated transactions with knowledge of
//!    their dependencies.
//!
//! 4. **Commit**: Apply all writes in transaction order for determinism.
//!
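//! A read conflicts when an earlier transaction in the batch wrote the same
//! entity. Below is a minimal sketch of the bookkeeping (`no_run`, and it
//! assumes `EntityId` is re-exported at `grafeo_engine::transaction::EntityId`):
//!
//! ```no_run
//! use grafeo_common::types::{EpochId, NodeId};
//! use grafeo_engine::transaction::EntityId;
//! use grafeo_engine::transaction::parallel::{BatchRequest, ParallelExecutor};
//!
//! let executor = ParallelExecutor::new(4);
//! let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3"]);
//!
//! let shared = EntityId::Node(NodeId::new(42));
//! let result = executor.execute_batch(batch, |idx, _op, res| {
//!     if idx <= 1 {
//!         // Operations 0 and 1 read and write the same entity: operation 1
//!         // fails validation against operation 0's write and is re-executed.
//!         res.record_read(shared, EpochId::new(0));
//!         res.record_write(shared);
//!     } else {
//!         // Operations 2 and 3 touch disjoint entities and validate cleanly.
//!         res.record_write(EntityId::Node(NodeId::new(idx as u64)));
//!     }
//! });
//! assert!(result.all_succeeded());
//! ```
//!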
//! # Performance
//!
//! | Conflict Rate | Expected Speedup |
//! |---------------|------------------|
//! | 0% | 3-4x on 4 cores |
//! | <10% | 2-3x |
//! | >30% | Falls back to sequential |
//!
//! # Example
//!
//! ```no_run
//! use grafeo_engine::transaction::parallel::{ParallelExecutor, BatchRequest};
//!
//! let executor = ParallelExecutor::new(4); // 4 workers
//!
//! let batch = BatchRequest::new(vec![
//!     "CREATE (n:Person {id: 1})",
//!     "CREATE (n:Person {id: 2})",
//!     "CREATE (n:Person {id: 3})",
//! ]);
//!
//! let result = executor.execute_batch(batch, |_idx, _op, _result| {
//!     // execute each operation against the store
//! });
//! assert!(result.all_succeeded());
//! ```

use std::collections::HashSet;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use grafeo_common::types::EpochId;
use grafeo_common::utils::hash::FxHashMap;
use parking_lot::{Mutex, RwLock};
use rayon::prelude::*;

use super::EntityId;

/// Maximum number of re-execution attempts before giving up.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Minimum batch size to consider parallel execution (otherwise sequential is faster).
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// Maximum conflict rate before falling back to sequential execution.
/// For example, a 10-operation batch falls back once more than 3 operations conflict.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;

/// Status of an operation execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExecutionStatus {
    /// Execution succeeded and is valid.
    Success,
    /// Execution needs re-validation due to potential conflicts.
    NeedsRevalidation,
    /// Execution was re-executed after a conflict.
    Reexecuted,
    /// Execution failed with an error.
    Failed,
}

/// Result of executing a single operation in the batch.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Index in the batch (for ordering).
    pub batch_index: usize,
    /// Execution status.
    pub status: ExecutionStatus,
    /// Entities read during execution, as `(entity_id, epoch_read_at)` pairs.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written during execution.
    pub write_set: HashSet<EntityId>,
    /// Dependencies on earlier transactions in the batch.
    pub dependencies: Vec<usize>,
    /// Number of times this operation was re-executed.
    pub reexecution_count: usize,
    /// Error message if failed.
    pub error: Option<String>,
}

impl ExecutionResult {
    /// Creates a new execution result.
    fn new(batch_index: usize) -> Self {
        Self {
            batch_index,
            status: ExecutionStatus::Success,
            read_set: HashSet::new(),
            write_set: HashSet::new(),
            dependencies: Vec::new(),
            reexecution_count: 0,
            error: None,
        }
    }

    /// Records a read operation.
    pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
        self.read_set.insert((entity, epoch));
    }

    /// Records a write operation.
    pub fn record_write(&mut self, entity: EntityId) {
        self.write_set.insert(entity);
    }

    /// Marks as needing revalidation.
    pub fn mark_needs_revalidation(&mut self) {
        self.status = ExecutionStatus::NeedsRevalidation;
    }

    /// Marks as re-executed.
    pub fn mark_reexecuted(&mut self) {
        self.status = ExecutionStatus::Reexecuted;
        self.reexecution_count += 1;
    }

    /// Marks as failed with an error.
    pub fn mark_failed(&mut self, error: String) {
        self.status = ExecutionStatus::Failed;
        self.error = Some(error);
    }
}

/// A batch of operations to execute in parallel.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations to execute (as query strings).
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Creates a new batch request.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        Self {
            operations: operations.into_iter().map(Into::into).collect(),
        }
    }

    /// Returns the number of operations.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns whether the batch is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}

/// Result of executing a batch of operations.
#[derive(Debug)]
pub struct BatchResult {
    /// Results for each operation (in order).
    pub results: Vec<ExecutionResult>,
    /// Total number of successful operations.
    pub success_count: usize,
    /// Total number of failed operations.
    pub failure_count: usize,
    /// Total number of re-executions performed.
    pub reexecution_count: usize,
    /// Whether parallel execution was used (vs fallback to sequential).
    pub parallel_executed: bool,
}

impl BatchResult {
    /// Returns true if all operations succeeded.
    #[must_use]
    pub fn all_succeeded(&self) -> bool {
        self.failure_count == 0
    }

    /// Returns the indices of failed operations.
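    ///
    /// A small sketch of inspecting failures (uses only this module's API):
    ///
    /// ```no_run
    /// # use grafeo_engine::transaction::parallel::{BatchRequest, ParallelExecutor};
    /// let executor = ParallelExecutor::new(2);
    /// let batch = BatchRequest::new(vec!["op0", "op1"]);
    /// let result = executor.execute_batch(batch, |_idx, op, res| {
    ///     if op == "op1" {
    ///         res.mark_failed("example failure".to_string());
    ///     }
    /// });
    /// let failed: Vec<usize> = result.failed_indices().collect();
    /// assert_eq!(failed, vec![1]);
    /// ```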
    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.results
            .iter()
            .filter(|r| r.status == ExecutionStatus::Failed)
            .map(|r| r.batch_index)
    }
}

/// Tracks which entities have been written by which batch index.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Entity -> earliest batch index that wrote it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}

impl WriteTracker {
    /// Records a write by a batch index.
    /// Keeps track of the earliest writer for conflict detection.
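    /// For example, writes to the same entity by batch indices 2 and then 0
    /// leave 0 recorded, so later readers see the earliest writer.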
    fn record_write(&self, entity: EntityId, batch_index: usize) {
        let mut writes = self.writes.write();
        writes
            .entry(entity)
            .and_modify(|existing| *existing = (*existing).min(batch_index))
            .or_insert(batch_index);
    }

    /// Checks whether an entity was written by an earlier transaction,
    /// returning the earliest writer's batch index if so.
    fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
        let writes = self.writes.read();
        if let Some(&writer) = writes.get(entity)
            && writer < batch_index
        {
            return Some(writer);
        }
        None
    }
}

/// Block-STM inspired parallel transaction executor.
///
/// Executes batches of operations in parallel with optimistic concurrency control.
pub struct ParallelExecutor {
    /// Number of worker threads.
    num_workers: usize,
    /// Thread pool for parallel execution.
    pool: rayon::ThreadPool,
}

impl ParallelExecutor {
    /// Creates a new parallel executor with the specified number of workers.
    ///
    /// # Panics
    ///
    /// Panics if `num_workers` is 0 or if the thread pool cannot be created.
    #[must_use]
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates a parallel executor with the default number of workers (number of CPUs).
    #[must_use]
    pub fn default_workers() -> Self {
        // Use rayon's default parallelism which is based on num_cpus
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Returns the number of workers.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes a batch of operations in parallel.
    ///
    /// Operations are executed optimistically in parallel, validated for conflicts,
    /// and re-executed as needed. The final result maintains deterministic ordering.
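    ///
    /// The closure receives the operation's batch index, the operation text,
    /// and a mutable [`ExecutionResult`] through which it must record every
    /// entity it reads or writes; conflict detection is only as good as this
    /// bookkeeping. A minimal sketch (entity construction is illustrative, and
    /// the `EntityId` re-export path is assumed):
    ///
    /// ```no_run
    /// # use grafeo_common::types::NodeId;
    /// # use grafeo_engine::transaction::EntityId;
    /// # use grafeo_engine::transaction::parallel::{BatchRequest, ParallelExecutor};
    /// let executor = ParallelExecutor::new(2);
    /// let batch = BatchRequest::new(vec!["CREATE (a)", "CREATE (b)"]);
    /// let result = executor.execute_batch(batch, |idx, _op, res| {
    ///     res.record_write(EntityId::Node(NodeId::new(idx as u64)));
    /// });
    /// assert!(result.all_succeeded());
    /// ```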
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Handle empty or small batches
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        // Phase 1: Optimistic parallel execution
        let write_tracker = Arc::new(WriteTracker::default());
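        // One lock per result slot: in this phase each index is locked only by
        // the worker executing it, so these mutexes are uncontended.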
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    // Record writes to tracker
                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Phase 2: Validation
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Collect entities to check (to avoid borrow issues)
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            // Check if any of our reads were invalidated by an earlier write.
            // (The recorded read epoch is not compared here: any earlier writer
            // in the batch invalidates the read.)
            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Check conflict rate
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            // Too many conflicts - fall back to sequential
            return self.execute_sequential(batch, execute_fn);
        }

        // Phase 3: Re-execution of conflicting transactions
        let total_reexecutions = AtomicUsize::new(0);

        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            // Transactions still unsettled in this round. A dependency on one
            // of these may change again, so it still counts as a conflict; a
            // dependency on an already-settled transaction is satisfied by
            // re-executing, so it must not keep this transaction invalid
            // forever (otherwise every conflict would exhaust all rounds).
            let unsettled: HashSet<usize> = invalid_indices.iter().copied().collect();

            // Re-execute invalid transactions
            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Clear previous state
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        // Re-execute
                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);
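                        // Note: the fresh write set is not re-recorded in
                        // `write_tracker`; this assumes a deterministic operation
                        // writes the same entities on every execution, so the
                        // entries recorded in phase 1 remain accurate.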

                        // Collect entities for re-validation
                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        // Re-validate: only a conflict with a still-unsettled
                        // earlier writer keeps this transaction invalid.
                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                                && unsettled.contains(&writer)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                // Max rounds reached, mark remaining as failed
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        // Phase 4: Collect results
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        // Sort by batch index to maintain order
        final_results.sort_by_key(|r| r.batch_index);

        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Executes a batch sequentially (fallback for high conflict scenarios).
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}

impl Default for ParallelExecutor {
    fn default() -> Self {
        Self::default_workers()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        // Small batch uses sequential execution
        assert!(!result.parallel_executed);
    }

    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to a different entity
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0); // No conflicts
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            // All operations read and write the same entity
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Simulate some work
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Either some operations were re-executed, or the conflict rate was
        // high enough to force the sequential fallback.
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        // All operations write to independent entities (no conflicts)
        // This tests parallel execution with no read-write conflicts

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to its own entity (no conflicts)
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        // Should be parallel since no conflicts
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        // Verify results are in order
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2); // Keeps earliest (0)

        // Entity 1 was first written by index 0 (earliest is kept)
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        // Entity 2 was written by index 1
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // Index 0 has no earlier writers
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}