// grafeo_engine/transaction/parallel.rs
//! Block-STM inspired parallel transaction execution.
//!
//! Executes a batch of operations in parallel optimistically, validates for conflicts,
//! and re-executes conflicting transactions. This is inspired by Aptos Block-STM and
//! provides significant speedup for batch-heavy workloads like ETL imports.
//!
//! # Algorithm
//!
//! The execution follows four phases:
//!
//! 1. **Optimistic Execution**: Execute all operations in parallel without locking.
//!    Each operation tracks its read and write sets.
//!
//! 2. **Validation**: Check if any read was invalidated by a concurrent write from
//!    an earlier transaction in the batch.
//!
//! 3. **Re-execution**: Re-execute invalidated transactions with knowledge of
//!    their dependencies.
//!
//! 4. **Commit**: Apply all writes in transaction order for determinism.
//!
//! # Performance
//!
//! | Conflict Rate | Expected Speedup |
//! |---------------|------------------|
//! | 0% | 3-4x on 4 cores |
//! | <10% | 2-3x |
//! | >30% | Falls back to sequential |
//!
//! # Example
//!
//! ```no_run
//! use grafeo_engine::transaction::parallel::{ParallelExecutor, BatchRequest};
//!
//! let executor = ParallelExecutor::new(4); // 4 workers
//!
//! let batch = BatchRequest::new(vec![
//!     "CREATE (n:Person {id: 1})",
//!     "CREATE (n:Person {id: 2})",
//!     "CREATE (n:Person {id: 3})",
//! ]);
//!
//! let result = executor.execute_batch(batch, |_idx, _op, _result| {
//!     // execute each operation against the store
//! });
//! assert!(result.all_succeeded());
//! ```

49use std::collections::HashSet;
50use std::sync::Arc;
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53use grafeo_common::types::EpochId;
54use grafeo_common::utils::hash::FxHashMap;
55use parking_lot::{Mutex, RwLock};
56use rayon::prelude::*;
57
58use super::EntityId;
59
/// Maximum number of re-execution attempts before giving up and marking the
/// remaining conflicting operations as failed.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Minimum batch size to consider parallel execution (otherwise sequential is faster).
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// Maximum conflict rate (invalidated / total) before falling back to sequential execution.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
68
/// Status of an operation execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// Execution succeeded and is valid.
    Success,
    /// Execution needs re-validation due to potential conflicts.
    NeedsRevalidation,
    /// Execution was re-executed after conflict.
    ///
    /// NOTE: `execute_batch` overwrites this with `Success` once a retried
    /// operation validates, so check `reexecution_count` to detect retries.
    Reexecuted,
    /// Execution failed with an error (see `ExecutionResult::error`).
    Failed,
}
81
/// Result of executing a single operation in the batch.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Index in the batch (used to restore deterministic ordering).
    pub batch_index: usize,
    /// Execution status.
    pub status: ExecutionStatus,
    /// Entities read during execution, paired with the epoch they were read at.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written during execution.
    pub write_set: HashSet<EntityId>,
    /// Batch indices of earlier transactions whose writes invalidated this
    /// transaction's reads.
    pub dependencies: Vec<usize>,
    /// Number of times this operation was re-executed after conflicts.
    pub reexecution_count: usize,
    /// Error message if failed.
    pub error: Option<String>,
}
100
101impl ExecutionResult {
102    /// Creates a new execution result.
103    fn new(batch_index: usize) -> Self {
104        Self {
105            batch_index,
106            status: ExecutionStatus::Success,
107            read_set: HashSet::new(),
108            write_set: HashSet::new(),
109            dependencies: Vec::new(),
110            reexecution_count: 0,
111            error: None,
112        }
113    }
114
115    /// Records a read operation.
116    pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
117        self.read_set.insert((entity, epoch));
118    }
119
120    /// Records a write operation.
121    pub fn record_write(&mut self, entity: EntityId) {
122        self.write_set.insert(entity);
123    }
124
125    /// Marks as needing revalidation.
126    pub fn mark_needs_revalidation(&mut self) {
127        self.status = ExecutionStatus::NeedsRevalidation;
128    }
129
130    /// Marks as reexecuted.
131    pub fn mark_reexecuted(&mut self) {
132        self.status = ExecutionStatus::Reexecuted;
133        self.reexecution_count += 1;
134    }
135
136    /// Marks as failed with an error.
137    pub fn mark_failed(&mut self, error: String) {
138        self.status = ExecutionStatus::Failed;
139        self.error = Some(error);
140    }
141}
142
/// A batch of operations to execute in parallel.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations to execute (as query strings), in submission order.
    pub operations: Vec<String>,
}
149
150impl BatchRequest {
151    /// Creates a new batch request.
152    pub fn new(operations: Vec<impl Into<String>>) -> Self {
153        Self {
154            operations: operations.into_iter().map(Into::into).collect(),
155        }
156    }
157
158    /// Returns the number of operations.
159    #[must_use]
160    pub fn len(&self) -> usize {
161        self.operations.len()
162    }
163
164    /// Returns whether the batch is empty.
165    #[must_use]
166    pub fn is_empty(&self) -> bool {
167        self.operations.is_empty()
168    }
169}
170
/// Result of executing a batch of operations.
#[derive(Debug)]
pub struct BatchResult {
    /// Results for each operation, sorted by batch index.
    pub results: Vec<ExecutionResult>,
    /// Total number of operations whose status is not `Failed`.
    pub success_count: usize,
    /// Total number of failed operations.
    pub failure_count: usize,
    /// Total number of re-executions performed across all rounds.
    pub reexecution_count: usize,
    /// Whether parallel execution was used (vs fallback to sequential).
    pub parallel_executed: bool,
}
185
186impl BatchResult {
187    /// Returns true if all operations succeeded.
188    #[must_use]
189    pub fn all_succeeded(&self) -> bool {
190        self.failure_count == 0
191    }
192
193    /// Returns the indices of failed operations.
194    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
195        self.results
196            .iter()
197            .filter(|r| r.status == ExecutionStatus::Failed)
198            .map(|r| r.batch_index)
199    }
200}
201
/// Tracks which entities have been written by which batch index.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Entity -> earliest batch index that wrote it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
208
209impl WriteTracker {
210    /// Records a write by a batch index.
211    /// Keeps track of the earliest writer for conflict detection.
212    fn record_write(&self, entity: EntityId, batch_index: usize) {
213        let mut writes = self.writes.write();
214        writes
215            .entry(entity)
216            .and_modify(|existing| *existing = (*existing).min(batch_index))
217            .or_insert(batch_index);
218    }
219
220    /// Checks if an entity was written by an earlier transaction.
221    fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
222        let writes = self.writes.read();
223        if let Some(&writer) = writes.get(entity)
224            && writer < batch_index
225        {
226            return Some(writer);
227        }
228        None
229    }
230}
231
/// Block-STM inspired parallel transaction executor.
///
/// Executes batches of operations in parallel with optimistic concurrency control.
pub struct ParallelExecutor {
    /// Number of worker threads.
    num_workers: usize,
    /// Dedicated rayon thread pool used for both execution and re-execution phases.
    pool: rayon::ThreadPool,
}
241
242impl ParallelExecutor {
243    /// Creates a new parallel executor with the specified number of workers.
244    ///
245    /// # Panics
246    ///
247    /// Panics if `num_workers` is 0 or if the thread pool cannot be created.
248    #[must_use]
249    pub fn new(num_workers: usize) -> Self {
250        assert!(num_workers > 0, "num_workers must be positive");
251
252        let pool = rayon::ThreadPoolBuilder::new()
253            .num_threads(num_workers)
254            .build()
255            .expect("failed to build thread pool");
256
257        Self { num_workers, pool }
258    }
259
260    /// Creates a parallel executor with the default number of workers (number of CPUs).
261    #[must_use]
262    pub fn default_workers() -> Self {
263        // Use rayon's default parallelism which is based on num_cpus
264        Self::new(rayon::current_num_threads().max(1))
265    }
266
267    /// Returns the number of workers.
268    #[must_use]
269    pub fn num_workers(&self) -> usize {
270        self.num_workers
271    }
272
273    /// Executes a batch of operations in parallel.
274    ///
275    /// Operations are executed optimistically in parallel, validated for conflicts,
276    /// and re-executed as needed. The final result maintains deterministic ordering.
277    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
278    where
279        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
280    {
281        let n = batch.len();
282
283        // Handle empty or small batches
284        if n == 0 {
285            return BatchResult {
286                results: Vec::new(),
287                success_count: 0,
288                failure_count: 0,
289                reexecution_count: 0,
290                parallel_executed: false,
291            };
292        }
293
294        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
295            return self.execute_sequential(batch, execute_fn);
296        }
297
298        // Phase 1: Optimistic parallel execution
299        let write_tracker = Arc::new(WriteTracker::default());
300        let results: Vec<Mutex<ExecutionResult>> = (0..n)
301            .map(|i| Mutex::new(ExecutionResult::new(i)))
302            .collect();
303
304        self.pool.install(|| {
305            batch
306                .operations
307                .par_iter()
308                .enumerate()
309                .for_each(|(idx, op)| {
310                    let mut result = results[idx].lock();
311                    execute_fn(idx, op, &mut result);
312
313                    // Record writes to tracker
314                    for entity in &result.write_set {
315                        write_tracker.record_write(*entity, idx);
316                    }
317                });
318        });
319
320        // Phase 2: Validation
321        let mut invalid_indices = Vec::new();
322
323        for (idx, result_mutex) in results.iter().enumerate() {
324            let mut result = result_mutex.lock();
325
326            // Collect entities to check (to avoid borrow issues)
327            let read_entities: Vec<EntityId> =
328                result.read_set.iter().map(|(entity, _)| *entity).collect();
329
330            // Check if any of our reads were invalidated by an earlier write
331            for entity in read_entities {
332                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
333                    result.mark_needs_revalidation();
334                    result.dependencies.push(writer);
335                }
336            }
337
338            if result.status == ExecutionStatus::NeedsRevalidation {
339                invalid_indices.push(idx);
340            }
341        }
342
343        // Check conflict rate
344        let conflict_rate = invalid_indices.len() as f64 / n as f64;
345        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
346            // Too many conflicts - fall back to sequential
347            return self.execute_sequential(batch, execute_fn);
348        }
349
350        // Phase 3: Re-execution of conflicting transactions
351        let total_reexecutions = AtomicUsize::new(0);
352
353        for round in 0..MAX_REEXECUTION_ROUNDS {
354            if invalid_indices.is_empty() {
355                break;
356            }
357
358            // Re-execute invalid transactions
359            let still_invalid: Vec<usize> = self.pool.install(|| {
360                invalid_indices
361                    .par_iter()
362                    .filter_map(|&idx| {
363                        let mut result = results[idx].lock();
364
365                        // Clear previous state
366                        result.read_set.clear();
367                        result.write_set.clear();
368                        result.dependencies.clear();
369
370                        // Re-execute
371                        execute_fn(idx, &batch.operations[idx], &mut result);
372                        result.mark_reexecuted();
373                        total_reexecutions.fetch_add(1, Ordering::Relaxed);
374
375                        // Collect entities for re-validation
376                        let read_entities: Vec<EntityId> =
377                            result.read_set.iter().map(|(entity, _)| *entity).collect();
378
379                        // Re-validate
380                        for entity in read_entities {
381                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
382                            {
383                                result.mark_needs_revalidation();
384                                result.dependencies.push(writer);
385                                return Some(idx);
386                            }
387                        }
388
389                        result.status = ExecutionStatus::Success;
390                        None
391                    })
392                    .collect()
393            });
394
395            invalid_indices = still_invalid;
396
397            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
398                // Max rounds reached, mark remaining as failed
399                for idx in &invalid_indices {
400                    let mut result = results[*idx].lock();
401                    result.mark_failed("Max re-execution rounds reached".to_string());
402                }
403            }
404        }
405
406        // Phase 4: Collect results
407        let mut final_results: Vec<ExecutionResult> =
408            results.into_iter().map(|m| m.into_inner()).collect();
409
410        // Sort by batch index to maintain order
411        final_results.sort_by_key(|r| r.batch_index);
412
413        let success_count = final_results
414            .iter()
415            .filter(|r| r.status != ExecutionStatus::Failed)
416            .count();
417
418        BatchResult {
419            failure_count: n - success_count,
420            success_count,
421            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
422            parallel_executed: true,
423            results: final_results,
424        }
425    }
426
427    /// Executes a batch sequentially (fallback for high conflict scenarios).
428    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
429    where
430        F: Fn(usize, &str, &mut ExecutionResult),
431    {
432        let mut results = Vec::with_capacity(batch.len());
433
434        for (idx, op) in batch.operations.iter().enumerate() {
435            let mut result = ExecutionResult::new(idx);
436            execute_fn(idx, op, &mut result);
437            results.push(result);
438        }
439
440        let success_count = results
441            .iter()
442            .filter(|r| r.status != ExecutionStatus::Failed)
443            .count();
444
445        BatchResult {
446            failure_count: results.len() - success_count,
447            success_count,
448            reexecution_count: 0,
449            parallel_executed: false,
450            results,
451        }
452    }
453}
454
455impl Default for ParallelExecutor {
456    fn default() -> Self {
457        Self::default_workers()
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch short-circuits: no results, counted as all-succeeded.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // Batches below MIN_BATCH_SIZE_FOR_PARALLEL take the sequential path.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        // Small batch uses sequential execution
        assert!(!result.parallel_executed);
    }

    // Disjoint write sets: parallel path, no re-executions, every op runs once.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to a different entity
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0); // No conflicts
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // All ops read+write one shared entity: either re-executions happen or the
    // conflict-rate guard triggers the sequential fallback.
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            // All operations read and write the same entity
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Simulate some work
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Some operations should have been re-executed due to conflicts
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // Larger conflict-free batch still stays on the parallel path.
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        // All operations write to independent entities (no conflicts)
        // This tests parallel execution with no read-write conflicts

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to its own entity (no conflicts)
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        // Should be parallel since no conflicts
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results must come back sorted by batch index despite parallel execution.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        // Verify results are in order
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A failure in one op is isolated: counts and failed_indices reflect it.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // WriteTracker keeps the earliest writer per entity and only reports
    // writers strictly before the querying index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2); // Keeps earliest (0)

        // Entity 1 was first written by index 0 (earliest is kept)
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        // Entity 2 was written by index 1
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // Index 0 has no earlier writers
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    // Basic BatchRequest accessors.
    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    // ExecutionResult state machine: record/mark methods update the fields.
    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}