// reifydb_sub_flow/worker/parallel.rs
use crossbeam_channel::bounded;
5use reifydb_core::{interface::FlowId, log_trace};
6use reifydb_engine::StandardCommandTransaction;
7use reifydb_sub_api::{SchedulerService, TaskContext, task_once};
8
9use super::{UnitOfWork, UnitsOfWork, WorkerPool};
10use crate::{engine::FlowEngine, transaction::FlowTransaction};
11
12pub struct ParallelWorkerPool {
18 scheduler: SchedulerService,
19}
20
21impl ParallelWorkerPool {
22 pub fn new(scheduler: SchedulerService) -> Self {
24 Self {
25 scheduler,
26 }
27 }
28}
29
30impl WorkerPool for ParallelWorkerPool {
31 fn process(
32 &self,
33 txn: &mut StandardCommandTransaction,
34 units: UnitsOfWork,
35 engine: &FlowEngine,
36 ) -> crate::Result<()> {
37 if units.is_empty() {
38 return Ok(());
39 }
40
41 let units_of_work = units.into_inner();
42 let mut txns: Vec<(Vec<UnitOfWork>, FlowTransaction)> = Vec::with_capacity(units_of_work.len());
43
44 for flow_units in units_of_work {
45 if !flow_units.is_empty() {
46 let flow_id = flow_units[0].flow_id;
48 for unit in &flow_units {
49 assert_eq!(
50 unit.flow_id, flow_id,
51 "INVARIANT VIOLATED: Flow units contain mixed flow_ids - expected {:?}, got {:?}. \
52 Each Vec should contain units for exactly one flow.",
53 flow_id, unit.flow_id
54 );
55 }
56
57 let first_version = flow_units[0].version;
58 let flow_txn = FlowTransaction::new(txn, first_version);
59 txns.push((flow_units, flow_txn));
60 }
61 }
62
63 {
66 use std::collections::HashSet;
67 let mut flow_ids_in_tasks = HashSet::new();
68
69 for (flow_units, _) in &txns {
70 let flow_id = flow_units[0].flow_id;
71 assert!(
72 !flow_ids_in_tasks.contains(&flow_id),
73 "INVARIANT VIOLATED: flow_id {:?} will be processed by multiple parallel tasks. \
74 This will cause keyspace overlap as multiple FlowTransactions write to the same keys.",
75 flow_id
76 );
77 flow_ids_in_tasks.insert(flow_id);
78 }
79 }
80
81 let (result_tx, result_rx) = bounded(txns.len());
82
83 for (seq, (flow_units, mut flow_txn)) in txns.into_iter().enumerate() {
84 let result_tx = result_tx.clone();
85 let engine = engine.clone();
86 let flow_id = flow_units[0].flow_id;
87 let versions: Vec<_> = flow_units.iter().map(|u| u.version.0).collect();
88
89 log_trace!("[PARALLEL] SUBMIT seq={} flow={:?} versions={:?}", seq, flow_id, versions);
90
91 let task = task_once!(
92 "flow-processing",
93 High,
94 move |_ctx: &TaskContext| -> reifydb_core::Result<()> {
95 process(&mut flow_txn, flow_units, &engine)?;
96 let _ = result_tx.send(Ok((flow_id, flow_txn)));
97 Ok(())
98 }
99 );
100
101 self.scheduler.once(task)?;
102 }
103
104 drop(result_tx);
106
107 let mut completed: Vec<(FlowId, FlowTransaction)> = Vec::new();
108 let mut recv_seq = 0;
109 while let Ok(result) = result_rx.recv() {
110 match result {
111 Ok((flow_id, flow_txn)) => {
112 log_trace!("[PARALLEL] RECV seq={} flow={:?}", recv_seq, flow_id);
113 recv_seq += 1;
114 completed.push((flow_id, flow_txn));
115 }
116 Err(e) => return e,
117 }
118 }
119
120 completed.sort_by_key(|(flow_id, _)| *flow_id);
122
123 for (seq, (flow_id, mut flow)) in completed.into_iter().enumerate() {
125 log_trace!("[PARALLEL] COMMIT seq={} flow={:?}", seq, flow_id);
126 flow.commit(txn)?;
127 }
128
129 Ok(())
130 }
131
132 fn name(&self) -> &str {
133 "parallel-worker-pool"
134 }
135}
136
137fn process(flow_txn: &mut FlowTransaction, flow_units: Vec<UnitOfWork>, engine: &FlowEngine) -> crate::Result<()> {
139 for unit in flow_units {
141 if flow_txn.version() != unit.version {
143 flow_txn.update_version(unit.version)?;
144 }
145
146 for change in unit.source_changes {
148 engine.process(flow_txn, change)?;
149 }
150 }
151
152 Ok(())
153}