// grafeo_engine/transaction/parallel.rs
1//! Block-STM inspired parallel transaction execution.
2//!
3//! Executes a batch of operations in parallel optimistically, validates for conflicts,
4//! and re-executes conflicting transactions. This is inspired by Aptos Block-STM and
5//! provides significant speedup for batch-heavy workloads like ETL imports.
6//!
7//! # Algorithm
8//!
9//! The execution follows four phases:
10//!
11//! 1. **Optimistic Execution**: Execute all operations in parallel without locking.
12//!    Each operation tracks its read and write sets.
13//!
14//! 2. **Validation**: Check if any read was invalidated by a concurrent write from
15//!    an earlier transaction in the batch.
16//!
17//! 3. **Re-execution**: Re-execute invalidated transactions with knowledge of
18//!    their dependencies.
19//!
20//! 4. **Commit**: Apply all writes in transaction order for determinism.
21//!
22//! # Performance
23//!
24//! | Conflict Rate | Expected Speedup |
25//! |---------------|------------------|
26//! | 0% | 3-4x on 4 cores |
27//! | <10% | 2-3x |
28//! | >30% | Falls back to sequential |
29//!
30//! # Example
31//!
32//! ```ignore
33//! use grafeo_engine::transaction::parallel::{ParallelExecutor, BatchRequest};
34//!
35//! let executor = ParallelExecutor::new(4); // 4 workers
36//!
37//! let batch = BatchRequest::new(vec![
38//!     "CREATE (n:Person {id: 1})",
39//!     "CREATE (n:Person {id: 2})",
40//!     "CREATE (n:Person {id: 3})",
41//! ]);
42//!
//! let result = executor.execute_batch(batch, |idx, op, result| {
//!     // run `op` and record its reads/writes on `result`
//! });
44//! assert!(result.all_succeeded());
45//! ```
46
47use std::collections::HashSet;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicUsize, Ordering};
50
51use grafeo_common::types::EpochId;
52use grafeo_common::utils::hash::FxHashMap;
53use parking_lot::{Mutex, RwLock};
54use rayon::prelude::*;
55
56use super::EntityId;
57
/// Maximum number of re-execution attempts before giving up.
/// Operations still conflicting after this many rounds are marked `Failed`.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Minimum batch size to consider parallel execution (otherwise sequential is faster).
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// Maximum conflict rate before falling back to sequential execution.
/// Expressed as a fraction of the batch size (0.3 == 30% of operations invalid).
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
66
/// Status of an operation execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// Execution succeeded and is valid.
    Success,
    /// Execution needs re-validation due to potential conflicts.
    NeedsRevalidation,
    /// Execution was re-executed after conflict.
    /// (Transient in practice: a re-execution that passes validation is
    /// promoted back to `Success`; the retry count lives in
    /// `ExecutionResult::reexecution_count`.)
    Reexecuted,
    /// Execution failed with an error.
    Failed,
}
79
/// Result of executing a single operation in the batch.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Index in the batch (for ordering).
    pub batch_index: usize,
    /// Execution status.
    pub status: ExecutionStatus,
    /// Entities read during execution (entity_id, epoch_read_at).
    /// NOTE(review): conflict validation currently keys on the entity only;
    /// the recorded epoch is not consulted — confirm whether epoch-based
    /// validation is planned.
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written during execution.
    pub write_set: HashSet<EntityId>,
    /// Dependencies on earlier transactions in the batch.
    /// May contain duplicates (one entry is pushed per conflicting entity).
    pub dependencies: Vec<usize>,
    /// Number of times this operation was re-executed.
    pub reexecution_count: usize,
    /// Error message if failed.
    pub error: Option<String>,
}
98
99impl ExecutionResult {
100    /// Creates a new execution result.
101    fn new(batch_index: usize) -> Self {
102        Self {
103            batch_index,
104            status: ExecutionStatus::Success,
105            read_set: HashSet::new(),
106            write_set: HashSet::new(),
107            dependencies: Vec::new(),
108            reexecution_count: 0,
109            error: None,
110        }
111    }
112
113    /// Records a read operation.
114    pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
115        self.read_set.insert((entity, epoch));
116    }
117
118    /// Records a write operation.
119    pub fn record_write(&mut self, entity: EntityId) {
120        self.write_set.insert(entity);
121    }
122
123    /// Marks as needing revalidation.
124    pub fn mark_needs_revalidation(&mut self) {
125        self.status = ExecutionStatus::NeedsRevalidation;
126    }
127
128    /// Marks as reexecuted.
129    pub fn mark_reexecuted(&mut self) {
130        self.status = ExecutionStatus::Reexecuted;
131        self.reexecution_count += 1;
132    }
133
134    /// Marks as failed with an error.
135    pub fn mark_failed(&mut self, error: String) {
136        self.status = ExecutionStatus::Failed;
137        self.error = Some(error);
138    }
139}
140
/// A batch of operations to execute in parallel.
///
/// This is a plain container; conflict handling happens entirely inside
/// [`ParallelExecutor::execute_batch`].
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations to execute (as query strings).
    pub operations: Vec<String>,
}
147
148impl BatchRequest {
149    /// Creates a new batch request.
150    pub fn new(operations: Vec<impl Into<String>>) -> Self {
151        Self {
152            operations: operations.into_iter().map(Into::into).collect(),
153        }
154    }
155
156    /// Returns the number of operations.
157    #[must_use]
158    pub fn len(&self) -> usize {
159        self.operations.len()
160    }
161
162    /// Returns whether the batch is empty.
163    #[must_use]
164    pub fn is_empty(&self) -> bool {
165        self.operations.is_empty()
166    }
167}
168
/// Result of executing a batch of operations.
#[derive(Debug)]
pub struct BatchResult {
    /// Results for each operation (in order).
    pub results: Vec<ExecutionResult>,
    /// Total number of successful operations.
    /// Counts every result whose status is not `Failed`.
    pub success_count: usize,
    /// Total number of failed operations.
    pub failure_count: usize,
    /// Total number of re-executions performed.
    pub reexecution_count: usize,
    /// Whether parallel execution was used (vs fallback to sequential).
    pub parallel_executed: bool,
}
183
184impl BatchResult {
185    /// Returns true if all operations succeeded.
186    #[must_use]
187    pub fn all_succeeded(&self) -> bool {
188        self.failure_count == 0
189    }
190
191    /// Returns the indices of failed operations.
192    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
193        self.results
194            .iter()
195            .filter(|r| r.status == ExecutionStatus::Failed)
196            .map(|r| r.batch_index)
197    }
198}
199
/// Tracks which entities have been written by which batch index.
///
/// Shared across workers during optimistic execution. The map keeps the
/// *earliest* (lowest) batch index that wrote each entity, which is what
/// conflict validation compares against.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Entity -> lowest batch index that wrote it.
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
206
207impl WriteTracker {
208    /// Records a write by a batch index.
209    /// Keeps track of the earliest writer for conflict detection.
210    fn record_write(&self, entity: EntityId, batch_index: usize) {
211        let mut writes = self.writes.write();
212        writes
213            .entry(entity)
214            .and_modify(|existing| *existing = (*existing).min(batch_index))
215            .or_insert(batch_index);
216    }
217
218    /// Checks if an entity was written by an earlier transaction.
219    fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
220        let writes = self.writes.read();
221        if let Some(&writer) = writes.get(entity)
222            && writer < batch_index
223        {
224            return Some(writer);
225        }
226        None
227    }
228}
229
/// Block-STM inspired parallel transaction executor.
///
/// Executes batches of operations in parallel with optimistic concurrency
/// control: run everything, validate read sets against earlier writers,
/// re-execute conflicting operations, and fall back to sequential execution
/// when conflicts are too frequent.
pub struct ParallelExecutor {
    /// Number of worker threads.
    num_workers: usize,
    /// Dedicated rayon thread pool, used for both the optimistic pass and
    /// the re-execution rounds.
    pool: rayon::ThreadPool,
}
239
impl ParallelExecutor {
    /// Creates a new parallel executor with the specified number of workers.
    ///
    /// Builds a dedicated rayon thread pool with exactly `num_workers` threads.
    ///
    /// # Panics
    ///
    /// Panics if num_workers is 0, or if the thread pool cannot be built.
    pub fn new(num_workers: usize) -> Self {
        assert!(num_workers > 0, "num_workers must be positive");

        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_workers)
            .build()
            .expect("Failed to build thread pool");

        Self { num_workers, pool }
    }

    /// Creates a parallel executor with the default number of workers (number of CPUs).
    #[must_use]
    pub fn default_workers() -> Self {
        // Use rayon's default parallelism which is based on num_cpus
        // (clamped to at least 1 to satisfy `new`'s assertion).
        Self::new(rayon::current_num_threads().max(1))
    }

    /// Returns the number of workers.
    #[must_use]
    pub fn num_workers(&self) -> usize {
        self.num_workers
    }

    /// Executes a batch of operations in parallel.
    ///
    /// Operations are executed optimistically in parallel, validated for conflicts,
    /// and re-executed as needed. The final result maintains deterministic ordering.
    ///
    /// `execute_fn` is called as `execute_fn(batch_index, operation, result)` and
    /// is expected to record the operation's reads and writes on `result`.
    ///
    /// NOTE(review): when the conflict rate exceeds
    /// `MAX_CONFLICT_RATE_FOR_PARALLEL`, the whole batch is re-run sequentially
    /// even though `execute_fn` already ran once for every operation — this
    /// assumes `execute_fn` is idempotent. Confirm that contract with callers.
    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
    {
        let n = batch.len();

        // Handle empty or small batches
        if n == 0 {
            return BatchResult {
                results: Vec::new(),
                success_count: 0,
                failure_count: 0,
                reexecution_count: 0,
                parallel_executed: false,
            };
        }

        // Below this size, scheduling overhead outweighs any parallel win.
        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
            return self.execute_sequential(batch, execute_fn);
        }

        // Phase 1: Optimistic parallel execution
        let write_tracker = Arc::new(WriteTracker::default());
        let results: Vec<Mutex<ExecutionResult>> = (0..n)
            .map(|i| Mutex::new(ExecutionResult::new(i)))
            .collect();

        self.pool.install(|| {
            batch
                .operations
                .par_iter()
                .enumerate()
                .for_each(|(idx, op)| {
                    // Each parallel task owns a distinct index, so this lock
                    // is uncontended here; it exists to make the slot Sync.
                    let mut result = results[idx].lock();
                    execute_fn(idx, op, &mut result);

                    // Record writes to tracker
                    for entity in &result.write_set {
                        write_tracker.record_write(*entity, idx);
                    }
                });
        });

        // Phase 2: Validation. An operation is invalid if any entity it read
        // was written by an operation with a *lower* batch index.
        let mut invalid_indices = Vec::new();

        for (idx, result_mutex) in results.iter().enumerate() {
            let mut result = result_mutex.lock();

            // Collect entities to check (to avoid borrow issues)
            let read_entities: Vec<EntityId> =
                result.read_set.iter().map(|(entity, _)| *entity).collect();

            // Check if any of our reads were invalidated by an earlier write.
            // Only the entity is compared; the epoch in the read set is not
            // consulted here.
            for entity in read_entities {
                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
                    result.mark_needs_revalidation();
                    // One entry per conflicting entity, so the same writer
                    // may be pushed more than once.
                    result.dependencies.push(writer);
                }
            }

            if result.status == ExecutionStatus::NeedsRevalidation {
                invalid_indices.push(idx);
            }
        }

        // Check conflict rate
        let conflict_rate = invalid_indices.len() as f64 / n as f64;
        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
            // Too many conflicts - fall back to sequential
            // (see NOTE(review) above: execute_fn runs a second time here).
            return self.execute_sequential(batch, execute_fn);
        }

        // Phase 3: Re-execution of conflicting transactions
        let total_reexecutions = AtomicUsize::new(0);

        for round in 0..MAX_REEXECUTION_ROUNDS {
            if invalid_indices.is_empty() {
                break;
            }

            // Re-execute invalid transactions
            let still_invalid: Vec<usize> = self.pool.install(|| {
                invalid_indices
                    .par_iter()
                    .filter_map(|&idx| {
                        let mut result = results[idx].lock();

                        // Clear previous state
                        result.read_set.clear();
                        result.write_set.clear();
                        result.dependencies.clear();

                        // Re-execute
                        execute_fn(idx, &batch.operations[idx], &mut result);
                        result.mark_reexecuted();
                        total_reexecutions.fetch_add(1, Ordering::Relaxed);

                        // NOTE(review): the fresh write_set is not recorded
                        // back into write_tracker, and the writes registered
                        // during the first pass are never removed. If
                        // re-execution changes an operation's write set,
                        // subsequent validation runs against stale data —
                        // confirm that write sets are stable across retries.

                        // Collect entities for re-validation
                        let read_entities: Vec<EntityId> =
                            result.read_set.iter().map(|(entity, _)| *entity).collect();

                        // Re-validate
                        for entity in read_entities {
                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
                            {
                                result.mark_needs_revalidation();
                                result.dependencies.push(writer);
                                return Some(idx);
                            }
                        }

                        // Validation passed: promote to Success. This is why
                        // `Reexecuted` never appears in final results — the
                        // retry total is preserved in `reexecution_count`.
                        result.status = ExecutionStatus::Success;
                        None
                    })
                    .collect()
            });

            invalid_indices = still_invalid;

            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
                // Max rounds reached, mark remaining as failed
                for idx in &invalid_indices {
                    let mut result = results[*idx].lock();
                    result.mark_failed("Max re-execution rounds reached".to_string());
                }
            }
        }

        // Phase 4: Collect results
        let mut final_results: Vec<ExecutionResult> =
            results.into_iter().map(|m| m.into_inner()).collect();

        // Sort by batch index to maintain order (defensive: `results` was
        // built in index order, so this is currently a no-op).
        final_results.sort_by_key(|r| r.batch_index);

        // Everything not explicitly Failed counts as a success.
        let success_count = final_results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: n - success_count,
            success_count,
            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
            parallel_executed: true,
            results: final_results,
        }
    }

    /// Executes a batch sequentially (fallback for high conflict scenarios).
    ///
    /// Also used for batches smaller than `MIN_BATCH_SIZE_FOR_PARALLEL`.
    /// Operations run one after another in batch order, so no conflict
    /// detection or re-execution is needed.
    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
    where
        F: Fn(usize, &str, &mut ExecutionResult),
    {
        let mut results = Vec::with_capacity(batch.len());

        for (idx, op) in batch.operations.iter().enumerate() {
            let mut result = ExecutionResult::new(idx);
            execute_fn(idx, op, &mut result);
            results.push(result);
        }

        let success_count = results
            .iter()
            .filter(|r| r.status != ExecutionStatus::Failed)
            .count();

        BatchResult {
            failure_count: results.len() - success_count,
            success_count,
            reexecution_count: 0,
            parallel_executed: false,
            results,
        }
    }
}
451
impl Default for ParallelExecutor {
    /// Equivalent to [`ParallelExecutor::default_workers`]: one worker per
    /// CPU thread as reported by rayon, minimum 1.
    fn default() -> Self {
        Self::default_workers()
    }
}
457
#[cfg(test)]
mod tests {
    // These tests drive the executor with stub `execute_fn` closures that
    // record reads/writes directly instead of running real queries.
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        // Small batch uses sequential execution
        assert!(!result.parallel_executed);
    }

    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        // Counts how many times execute_fn ran (should be exactly once per op).
        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to a different entity
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0); // No conflicts
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            // All operations read and write the same entity
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Simulate some work
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Some operations should have been re-executed due to conflicts
        // (or, if the conflict rate tripped the threshold, the executor
        // fell back to sequential — both outcomes are acceptable).
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        // All operations write to independent entities (no conflicts)
        // This tests parallel execution with no read-write conflicts

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to its own entity (no conflicts)
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        // Should be parallel since no conflicts
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        // Verify results are in order (deterministic output regardless of
        // the parallel execution schedule).
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2); // Keeps earliest (0)

        // Entity 1 was first written by index 0 (earliest is kept)
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        // Entity 2 was written by index 1
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // Index 0 has no earlier writers
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}