// grafeo_engine/transaction/parallel.rs

1//! Block-STM inspired parallel transaction execution.
2//!
3//! Executes a batch of operations in parallel optimistically, validates for conflicts,
4//! and re-executes conflicting transactions. This is inspired by Aptos Block-STM and
5//! provides significant speedup for batch-heavy workloads like ETL imports.
6//!
7//! # Algorithm
8//!
9//! The execution follows four phases:
10//!
11//! 1. **Optimistic Execution**: Execute all operations in parallel without locking.
12//!    Each operation tracks its read and write sets.
13//!
14//! 2. **Validation**: Check if any read was invalidated by a concurrent write from
15//!    an earlier transaction in the batch.
16//!
17//! 3. **Re-execution**: Re-execute invalidated transactions with knowledge of
18//!    their dependencies.
19//!
20//! 4. **Commit**: Apply all writes in transaction order for determinism.
21//!
22//! # Performance
23//!
24//! | Conflict Rate | Expected Speedup |
25//! |---------------|------------------|
26//! | 0% | 3-4x on 4 cores |
27//! | <10% | 2-3x |
28//! | >30% | Falls back to sequential |
29//!
30//! # Example
31//!
32//! ```no_run
33//! use grafeo_engine::transaction::parallel::{ParallelExecutor, BatchRequest};
34//!
35//! let executor = ParallelExecutor::new(4); // 4 workers
36//!
37//! let batch = BatchRequest::new(vec![
38//!     "CREATE (n:Person {id: 1})",
39//!     "CREATE (n:Person {id: 2})",
40//!     "CREATE (n:Person {id: 3})",
41//! ]);
42//!
43//! let result = executor.execute_batch(batch, |_idx, _op, _result| {
44//!     // execute each operation against the store
45//! });
46//! assert!(result.all_succeeded());
47//! ```
48
49use std::collections::HashSet;
50use std::sync::Arc;
51use std::sync::atomic::{AtomicUsize, Ordering};
52
53use grafeo_common::types::EpochId;
54use grafeo_common::utils::hash::FxHashMap;
55use parking_lot::{Mutex, RwLock};
56use rayon::prelude::*;
57
58use super::EntityId;
59
/// Maximum number of re-execution attempts before giving up.
/// Operations that still conflict after this many rounds are marked `Failed`.
const MAX_REEXECUTION_ROUNDS: usize = 10;

/// Minimum batch size to consider parallel execution (otherwise sequential is faster,
/// since thread-pool scheduling overhead dominates for tiny batches).
const MIN_BATCH_SIZE_FOR_PARALLEL: usize = 4;

/// Maximum conflict rate (invalid transactions / batch size) observed in the first
/// validation pass before falling back to sequential execution.
const MAX_CONFLICT_RATE_FOR_PARALLEL: f64 = 0.3;
68
/// Status of an operation execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// Execution succeeded and is valid.
    Success,
    /// Execution needs re-validation due to potential conflicts
    /// (one of its reads was invalidated by an earlier write in the batch).
    NeedsRevalidation,
    /// Execution was re-executed after a conflict was detected.
    Reexecuted,
    /// Execution failed with an error (see `ExecutionResult::error`).
    Failed,
}
81
/// Result of executing a single operation in the batch.
///
/// The executor callback records reads and writes here so the parallel
/// executor can detect read/write conflicts between batch entries.
#[derive(Debug)]
pub struct ExecutionResult {
    /// Index in the batch (for ordering).
    pub batch_index: usize,
    /// Execution status.
    pub status: ExecutionStatus,
    /// Entities read during execution (entity_id, epoch_read_at).
    pub read_set: HashSet<(EntityId, EpochId)>,
    /// Entities written during execution.
    pub write_set: HashSet<EntityId>,
    /// Dependencies on earlier transactions in the batch.
    pub dependencies: Vec<usize>,
    /// Number of times this operation was re-executed.
    pub reexecution_count: usize,
    /// Error message if failed.
    pub error: Option<String>,
}
100
101impl ExecutionResult {
102    /// Creates a new execution result.
103    fn new(batch_index: usize) -> Self {
104        Self {
105            batch_index,
106            status: ExecutionStatus::Success,
107            read_set: HashSet::new(),
108            write_set: HashSet::new(),
109            dependencies: Vec::new(),
110            reexecution_count: 0,
111            error: None,
112        }
113    }
114
115    /// Records a read operation.
116    pub fn record_read(&mut self, entity: EntityId, epoch: EpochId) {
117        self.read_set.insert((entity, epoch));
118    }
119
120    /// Records a write operation.
121    pub fn record_write(&mut self, entity: EntityId) {
122        self.write_set.insert(entity);
123    }
124
125    /// Marks as needing revalidation.
126    pub fn mark_needs_revalidation(&mut self) {
127        self.status = ExecutionStatus::NeedsRevalidation;
128    }
129
130    /// Marks as reexecuted.
131    pub fn mark_reexecuted(&mut self) {
132        self.status = ExecutionStatus::Reexecuted;
133        self.reexecution_count += 1;
134    }
135
136    /// Marks as failed with an error.
137    pub fn mark_failed(&mut self, error: String) {
138        self.status = ExecutionStatus::Failed;
139        self.error = Some(error);
140    }
141}
142
/// A batch of operations to execute in parallel.
#[derive(Debug, Clone)]
pub struct BatchRequest {
    /// The operations to execute (as query strings).
    pub operations: Vec<String>,
}

impl BatchRequest {
    /// Builds a batch from anything convertible into query strings.
    pub fn new(operations: Vec<impl Into<String>>) -> Self {
        let operations = operations.into_iter().map(Into::into).collect();
        Self { operations }
    }

    /// Number of operations in the batch.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// True when the batch holds no operations.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
170
/// Result of executing a batch of operations.
#[derive(Debug)]
pub struct BatchResult {
    /// Results for each operation (in batch-index order).
    pub results: Vec<ExecutionResult>,
    /// Total number of successful operations (any status other than `Failed`).
    pub success_count: usize,
    /// Total number of failed operations.
    pub failure_count: usize,
    /// Total number of re-executions performed across all rounds.
    pub reexecution_count: usize,
    /// Whether parallel execution was used (vs fallback to sequential).
    pub parallel_executed: bool,
}
185
186impl BatchResult {
187    /// Returns true if all operations succeeded.
188    #[must_use]
189    pub fn all_succeeded(&self) -> bool {
190        self.failure_count == 0
191    }
192
193    /// Returns the indices of failed operations.
194    pub fn failed_indices(&self) -> impl Iterator<Item = usize> + '_ {
195        self.results
196            .iter()
197            .filter(|r| r.status == ExecutionStatus::Failed)
198            .map(|r| r.batch_index)
199    }
200}
201
/// Tracks which entities have been written by which batch index.
///
/// Shared across worker threads during a batch; guarded by an `RwLock`
/// since validation reads are expected to outnumber write registrations.
#[derive(Debug, Default)]
struct WriteTracker {
    /// Entity -> batch index that wrote it (the earliest writer wins).
    writes: RwLock<FxHashMap<EntityId, usize>>,
}
208
209impl WriteTracker {
210    /// Records a write by a batch index.
211    /// Keeps track of the earliest writer for conflict detection.
212    fn record_write(&self, entity: EntityId, batch_index: usize) {
213        let mut writes = self.writes.write();
214        writes
215            .entry(entity)
216            .and_modify(|existing| *existing = (*existing).min(batch_index))
217            .or_insert(batch_index);
218    }
219
220    /// Checks if an entity was written by an earlier transaction.
221    fn was_written_by_earlier(&self, entity: &EntityId, batch_index: usize) -> Option<usize> {
222        let writes = self.writes.read();
223        if let Some(&writer) = writes.get(entity)
224            && writer < batch_index
225        {
226            return Some(writer);
227        }
228        None
229    }
230}
231
/// Block-STM inspired parallel transaction executor.
///
/// Executes batches of operations in parallel with optimistic concurrency
/// control, using a dedicated rayon thread pool so batch work does not
/// compete with the global rayon pool.
pub struct ParallelExecutor {
    /// Number of worker threads.
    num_workers: usize,
    /// Dedicated thread pool for parallel execution.
    pool: rayon::ThreadPool,
}
241
242impl ParallelExecutor {
243    /// Creates a new parallel executor with the specified number of workers.
244    ///
245    /// # Panics
246    ///
247    /// Panics if num_workers is 0.
248    pub fn new(num_workers: usize) -> Self {
249        assert!(num_workers > 0, "num_workers must be positive");
250
251        let pool = rayon::ThreadPoolBuilder::new()
252            .num_threads(num_workers)
253            .build()
254            .expect("Failed to build thread pool");
255
256        Self { num_workers, pool }
257    }
258
259    /// Creates a parallel executor with the default number of workers (number of CPUs).
260    #[must_use]
261    pub fn default_workers() -> Self {
262        // Use rayon's default parallelism which is based on num_cpus
263        Self::new(rayon::current_num_threads().max(1))
264    }
265
266    /// Returns the number of workers.
267    #[must_use]
268    pub fn num_workers(&self) -> usize {
269        self.num_workers
270    }
271
272    /// Executes a batch of operations in parallel.
273    ///
274    /// Operations are executed optimistically in parallel, validated for conflicts,
275    /// and re-executed as needed. The final result maintains deterministic ordering.
276    pub fn execute_batch<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
277    where
278        F: Fn(usize, &str, &mut ExecutionResult) + Sync + Send,
279    {
280        let n = batch.len();
281
282        // Handle empty or small batches
283        if n == 0 {
284            return BatchResult {
285                results: Vec::new(),
286                success_count: 0,
287                failure_count: 0,
288                reexecution_count: 0,
289                parallel_executed: false,
290            };
291        }
292
293        if n < MIN_BATCH_SIZE_FOR_PARALLEL {
294            return self.execute_sequential(batch, execute_fn);
295        }
296
297        // Phase 1: Optimistic parallel execution
298        let write_tracker = Arc::new(WriteTracker::default());
299        let results: Vec<Mutex<ExecutionResult>> = (0..n)
300            .map(|i| Mutex::new(ExecutionResult::new(i)))
301            .collect();
302
303        self.pool.install(|| {
304            batch
305                .operations
306                .par_iter()
307                .enumerate()
308                .for_each(|(idx, op)| {
309                    let mut result = results[idx].lock();
310                    execute_fn(idx, op, &mut result);
311
312                    // Record writes to tracker
313                    for entity in &result.write_set {
314                        write_tracker.record_write(*entity, idx);
315                    }
316                });
317        });
318
319        // Phase 2: Validation
320        let mut invalid_indices = Vec::new();
321
322        for (idx, result_mutex) in results.iter().enumerate() {
323            let mut result = result_mutex.lock();
324
325            // Collect entities to check (to avoid borrow issues)
326            let read_entities: Vec<EntityId> =
327                result.read_set.iter().map(|(entity, _)| *entity).collect();
328
329            // Check if any of our reads were invalidated by an earlier write
330            for entity in read_entities {
331                if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx) {
332                    result.mark_needs_revalidation();
333                    result.dependencies.push(writer);
334                }
335            }
336
337            if result.status == ExecutionStatus::NeedsRevalidation {
338                invalid_indices.push(idx);
339            }
340        }
341
342        // Check conflict rate
343        let conflict_rate = invalid_indices.len() as f64 / n as f64;
344        if conflict_rate > MAX_CONFLICT_RATE_FOR_PARALLEL {
345            // Too many conflicts - fall back to sequential
346            return self.execute_sequential(batch, execute_fn);
347        }
348
349        // Phase 3: Re-execution of conflicting transactions
350        let total_reexecutions = AtomicUsize::new(0);
351
352        for round in 0..MAX_REEXECUTION_ROUNDS {
353            if invalid_indices.is_empty() {
354                break;
355            }
356
357            // Re-execute invalid transactions
358            let still_invalid: Vec<usize> = self.pool.install(|| {
359                invalid_indices
360                    .par_iter()
361                    .filter_map(|&idx| {
362                        let mut result = results[idx].lock();
363
364                        // Clear previous state
365                        result.read_set.clear();
366                        result.write_set.clear();
367                        result.dependencies.clear();
368
369                        // Re-execute
370                        execute_fn(idx, &batch.operations[idx], &mut result);
371                        result.mark_reexecuted();
372                        total_reexecutions.fetch_add(1, Ordering::Relaxed);
373
374                        // Collect entities for re-validation
375                        let read_entities: Vec<EntityId> =
376                            result.read_set.iter().map(|(entity, _)| *entity).collect();
377
378                        // Re-validate
379                        for entity in read_entities {
380                            if let Some(writer) = write_tracker.was_written_by_earlier(&entity, idx)
381                            {
382                                result.mark_needs_revalidation();
383                                result.dependencies.push(writer);
384                                return Some(idx);
385                            }
386                        }
387
388                        result.status = ExecutionStatus::Success;
389                        None
390                    })
391                    .collect()
392            });
393
394            invalid_indices = still_invalid;
395
396            if round == MAX_REEXECUTION_ROUNDS - 1 && !invalid_indices.is_empty() {
397                // Max rounds reached, mark remaining as failed
398                for idx in &invalid_indices {
399                    let mut result = results[*idx].lock();
400                    result.mark_failed("Max re-execution rounds reached".to_string());
401                }
402            }
403        }
404
405        // Phase 4: Collect results
406        let mut final_results: Vec<ExecutionResult> =
407            results.into_iter().map(|m| m.into_inner()).collect();
408
409        // Sort by batch index to maintain order
410        final_results.sort_by_key(|r| r.batch_index);
411
412        let success_count = final_results
413            .iter()
414            .filter(|r| r.status != ExecutionStatus::Failed)
415            .count();
416
417        BatchResult {
418            failure_count: n - success_count,
419            success_count,
420            reexecution_count: total_reexecutions.load(Ordering::Relaxed),
421            parallel_executed: true,
422            results: final_results,
423        }
424    }
425
426    /// Executes a batch sequentially (fallback for high conflict scenarios).
427    fn execute_sequential<F>(&self, batch: BatchRequest, execute_fn: F) -> BatchResult
428    where
429        F: Fn(usize, &str, &mut ExecutionResult),
430    {
431        let mut results = Vec::with_capacity(batch.len());
432
433        for (idx, op) in batch.operations.iter().enumerate() {
434            let mut result = ExecutionResult::new(idx);
435            execute_fn(idx, op, &mut result);
436            results.push(result);
437        }
438
439        let success_count = results
440            .iter()
441            .filter(|r| r.status != ExecutionStatus::Failed)
442            .count();
443
444        BatchResult {
445            failure_count: results.len() - success_count,
446            success_count,
447            reexecution_count: 0,
448            parallel_executed: false,
449            results,
450        }
451    }
452}
453
impl Default for ParallelExecutor {
    /// Delegates to [`ParallelExecutor::default_workers`] (one worker per CPU).
    fn default() -> Self {
        Self::default_workers()
    }
}
459
#[cfg(test)]
mod tests {
    use super::*;
    use grafeo_common::types::NodeId;
    use std::sync::atomic::AtomicU64;
    use std::thread;
    use std::time::Duration;

    // An empty batch returns immediately with zero results and no failures.
    #[test]
    fn test_empty_batch() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(Vec::<String>::new());

        let result = executor.execute_batch(batch, |_, _, _| {});

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 0);
    }

    // A batch below MIN_BATCH_SIZE_FOR_PARALLEL takes the sequential path.
    #[test]
    fn test_single_operation() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["CREATE (n:Test)"]);

        let result = executor.execute_batch(batch, |_, _, result| {
            result.record_write(EntityId::Node(NodeId::new(1)));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 1);
        // Small batch uses sequential execution
        assert!(!result.parallel_executed);
    }

    // Disjoint write sets: everything runs in parallel with no re-executions,
    // and the callback fires exactly once per operation.
    #[test]
    fn test_independent_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "CREATE (n1:Test {id: 1})",
            "CREATE (n2:Test {id: 2})",
            "CREATE (n3:Test {id: 3})",
            "CREATE (n4:Test {id: 4})",
            "CREATE (n5:Test {id: 5})",
        ]);

        let counter = AtomicU64::new(0);

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to a different entity
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            counter.fetch_add(1, Ordering::Relaxed);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        assert_eq!(result.reexecution_count, 0); // No conflicts
        assert!(result.parallel_executed);
        assert_eq!(counter.load(Ordering::Relaxed), 5);
    }

    // Every operation reads and writes the same entity: the batch either
    // re-executes some operations or falls back to sequential, but all must
    // still succeed.
    #[test]
    fn test_conflicting_operations() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "UPDATE (n:Test) SET n.value = 1",
            "UPDATE (n:Test) SET n.value = 2",
            "UPDATE (n:Test) SET n.value = 3",
            "UPDATE (n:Test) SET n.value = 4",
            "UPDATE (n:Test) SET n.value = 5",
        ]);

        let shared_entity = EntityId::Node(NodeId::new(100));

        let result = executor.execute_batch(batch, |_idx, _, result| {
            // All operations read and write the same entity
            result.record_read(shared_entity, EpochId::new(0));
            result.record_write(shared_entity);

            // Simulate some work
            thread::sleep(Duration::from_micros(10));
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 5);
        // Some operations should have been re-executed due to conflicts
        assert!(result.reexecution_count > 0 || !result.parallel_executed);
    }

    // A larger batch with fully independent writes stays on the parallel path.
    #[test]
    fn test_partial_conflicts() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec![
            "op1", "op2", "op3", "op4", "op5", "op6", "op7", "op8", "op9", "op10",
        ]);

        // All operations write to independent entities (no conflicts)
        // This tests parallel execution with no read-write conflicts

        let result = executor.execute_batch(batch, |idx, _, result| {
            // Each operation writes to its own entity (no conflicts)
            let entity = EntityId::Node(NodeId::new(idx as u64));
            result.record_write(entity);
        });

        assert!(result.all_succeeded());
        assert_eq!(result.results.len(), 10);
        // Should be parallel since no conflicts
        assert!(result.parallel_executed);
        assert_eq!(result.reexecution_count, 0);
    }

    // Results must come back sorted by batch_index regardless of the order
    // in which workers finished.
    #[test]
    fn test_execution_order_preserved() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["op0", "op1", "op2", "op3", "op4", "op5", "op6", "op7"]);

        let result = executor.execute_batch(batch, |idx, _, result| {
            result.record_write(EntityId::Node(NodeId::new(idx as u64)));
        });

        // Verify results are in order
        for (i, r) in result.results.iter().enumerate() {
            assert_eq!(
                r.batch_index, i,
                "Result at position {} has wrong batch_index",
                i
            );
        }
    }

    // A single failing operation is reported without affecting its siblings.
    #[test]
    fn test_failure_handling() {
        let executor = ParallelExecutor::new(4);
        let batch = BatchRequest::new(vec!["success1", "fail", "success2", "success3", "success4"]);

        let result = executor.execute_batch(batch, |idx, op, result| {
            if op == "fail" {
                result.mark_failed("Intentional failure".to_string());
            } else {
                result.record_write(EntityId::Node(NodeId::new(idx as u64)));
            }
        });

        assert!(!result.all_succeeded());
        assert_eq!(result.failure_count, 1);
        assert_eq!(result.success_count, 4);

        let failed: Vec<usize> = result.failed_indices().collect();
        assert_eq!(failed, vec![1]);
    }

    // WriteTracker keeps the earliest writer per entity and only reports
    // writers strictly earlier than the querying index.
    #[test]
    fn test_write_tracker() {
        let tracker = WriteTracker::default();

        tracker.record_write(EntityId::Node(NodeId::new(1)), 0);
        tracker.record_write(EntityId::Node(NodeId::new(2)), 1);
        tracker.record_write(EntityId::Node(NodeId::new(1)), 2); // Keeps earliest (0)

        // Entity 1 was first written by index 0 (earliest is kept)
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 3),
            Some(0)
        );

        // Entity 2 was written by index 1
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(2)), 2),
            Some(1)
        );

        // Index 0 has no earlier writers
        assert_eq!(
            tracker.was_written_by_earlier(&EntityId::Node(NodeId::new(1)), 0),
            None
        );
    }

    // BatchRequest length/emptiness accessors.
    #[test]
    fn test_batch_request() {
        let batch = BatchRequest::new(vec!["op1", "op2", "op3"]);
        assert_eq!(batch.len(), 3);
        assert!(!batch.is_empty());

        let empty_batch = BatchRequest::new(Vec::<String>::new());
        assert!(empty_batch.is_empty());
    }

    // ExecutionResult starts clean and its mark_*/record_* mutators update
    // status, sets, and counters as expected.
    #[test]
    fn test_execution_result() {
        let mut result = ExecutionResult::new(5);

        assert_eq!(result.batch_index, 5);
        assert_eq!(result.status, ExecutionStatus::Success);
        assert!(result.read_set.is_empty());
        assert!(result.write_set.is_empty());

        result.record_read(EntityId::Node(NodeId::new(1)), EpochId::new(10));
        result.record_write(EntityId::Node(NodeId::new(2)));

        assert_eq!(result.read_set.len(), 1);
        assert_eq!(result.write_set.len(), 1);

        result.mark_needs_revalidation();
        assert_eq!(result.status, ExecutionStatus::NeedsRevalidation);

        result.mark_reexecuted();
        assert_eq!(result.status, ExecutionStatus::Reexecuted);
        assert_eq!(result.reexecution_count, 1);
    }
}