1use {
2 crate::{
3 lock::Lock, top_level_id::TopLevelId, AccessKind, GraphNode, ResourceKey, TransactionId,
4 },
5 ahash::AHashMap,
6 std::collections::{hash_map::Entry, BinaryHeap},
7};
8
9pub struct PrioGraph<
16 Id: TransactionId,
17 Rk: ResourceKey,
18 Tl: TopLevelId<Id>,
19 Pfn: Fn(&Id, &GraphNode<Id>) -> Tl,
20> {
21 locks: AHashMap<Rk, Lock<Id>>,
23 nodes: AHashMap<Id, GraphNode<Id>>,
26 main_queue: BinaryHeap<Tl>,
28 top_level_prioritization_fn: Pfn,
30}
31
32impl<
33 Id: TransactionId,
34 Rk: ResourceKey,
35 Tl: TopLevelId<Id>,
36 Pfn: Fn(&Id, &GraphNode<Id>) -> Tl,
37 > PrioGraph<Id, Rk, Tl, Pfn>
38{
39 pub fn natural_batches(
44 iter: impl IntoIterator<Item = (Id, impl IntoIterator<Item = (Rk, AccessKind)>)>,
45 top_level_prioritization_fn: Pfn,
46 ) -> Vec<Vec<Id>> {
47 let mut graph = PrioGraph::new(top_level_prioritization_fn);
49 for (id, tx) in iter.into_iter() {
50 graph.insert_transaction(id, tx);
51 }
52
53 graph.make_natural_batches()
54 }
55
56 pub fn new(top_level_prioritization_fn: Pfn) -> Self {
58 Self {
59 locks: AHashMap::new(),
60 nodes: AHashMap::new(),
61 main_queue: BinaryHeap::new(),
62 top_level_prioritization_fn,
63 }
64 }
65
66 pub fn clear(&mut self) {
68 self.main_queue.clear();
69 self.locks.clear();
70 self.nodes.clear();
71 }
72
73 pub fn make_natural_batches(&mut self) -> Vec<Vec<Id>> {
79 let mut batches = vec![];
81
82 while !self.main_queue.is_empty() {
83 let mut batch = Vec::new();
84 while let Some(id) = self.pop() {
85 batch.push(id);
86 }
87
88 for id in &batch {
89 self.unblock(id);
90 }
91
92 batches.push(batch);
93 }
94
95 batches
96 }
97
98 pub fn insert_transaction(&mut self, id: Id, tx: impl IntoIterator<Item = (Rk, AccessKind)>) {
101 let mut node = GraphNode {
102 active: true,
103 edges: Vec::new(),
104 blocked_by_count: 0,
105 };
106
107 let mut block_tx = |blocking_id: Id| {
108 if blocking_id == id {
111 return;
112 }
113
114 let Some(blocking_tx_node) = self.nodes.get_mut(&blocking_id) else {
115 panic!("blocking node must exist");
116 };
117
118 if blocking_tx_node.active {
120 if blocking_tx_node.try_add_edge(id) {
123 node.blocked_by_count += 1;
124 }
125 }
126 };
127
128 for (resource_key, access_kind) in tx.into_iter() {
129 match self.locks.entry(resource_key) {
130 Entry::Vacant(entry) => {
131 entry.insert(match access_kind {
132 AccessKind::Read => Lock::Read(vec![id], None),
133 AccessKind::Write => Lock::Write(id),
134 });
135 }
136 Entry::Occupied(mut entry) => match access_kind {
137 AccessKind::Read => {
138 if let Some(blocking_tx) = entry.get_mut().add_read(id) {
139 block_tx(blocking_tx);
140 }
141 }
142 AccessKind::Write => {
143 if let Some(blocking_txs) = entry.get_mut().add_write(id) {
144 for blocking_tx in blocking_txs {
145 block_tx(blocking_tx);
146 }
147 }
148 }
149 },
150 }
151 }
152
153 self.nodes.insert(id, node);
154
155 if self.nodes.get(&id).unwrap().blocked_by_count == 0 {
157 self.main_queue.push(self.create_top_level_id(id));
158 }
159 }
160
161 pub fn is_empty(&self) -> bool {
163 self.main_queue.is_empty()
164 }
165
166 pub fn pop_and_unblock(&mut self) -> Option<(Id, Vec<Id>)> {
170 let id = self.pop()?;
171 Some((id, self.unblock(&id)))
172 }
173
174 pub fn pop(&mut self) -> Option<Id> {
177 self.main_queue.pop().map(|top_level_id| top_level_id.id())
178 }
179
180 pub fn unblock(&mut self, id: &Id) -> Vec<Id> {
187 let Some(node) = self.nodes.get_mut(id) else {
189 panic!("node must exist");
190 };
191 assert_eq!(node.blocked_by_count, 0, "node must be unblocked");
192
193 node.active = false;
194 let edges = core::mem::take(&mut node.edges);
195
196 for blocked_tx in edges.iter() {
198 let blocked_tx_node = self
199 .nodes
200 .get_mut(blocked_tx)
201 .expect("blocked_tx must exist");
202 blocked_tx_node.blocked_by_count -= 1;
203
204 if blocked_tx_node.blocked_by_count == 0 {
205 self.main_queue.push(self.create_top_level_id(*blocked_tx));
206 }
207 }
208
209 edges
210 }
211
212 pub fn is_blocked(&self, id: Id) -> bool {
215 self.nodes
216 .get(&id)
217 .map(|node| node.active && node.blocked_by_count != 0)
218 .unwrap_or_default()
219 }
220
221 fn create_top_level_id(&self, id: Id) -> Tl {
222 (self.top_level_prioritization_fn)(&id, self.nodes.get(&id).unwrap())
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    pub type TxId = u64;

    pub type Account = u64;

    /// A test transaction described only by the accounts it locks.
    pub struct Tx {
        read_locked_resources: Vec<Account>,
        write_locked_resources: Vec<Account>,
    }

    impl Tx {
        /// Every account this transaction touches: write locks first, then read locks.
        fn resources(&self) -> impl Iterator<Item = (Account, AccessKind)> + '_ {
            let writes = self
                .write_locked_resources
                .iter()
                .map(|&account| (account, AccessKind::Write));
            let reads = self
                .read_locked_resources
                .iter()
                .map(|&account| (account, AccessKind::Read));

            writes.chain(reads)
        }
    }

    /// Expands `(ids, read_accounts, write_accounts)` groups into transactions.
    ///
    /// Returns a lookup table plus all ids sorted in descending order. The
    /// tests treat a higher id as a higher priority, so that is the insertion order.
    fn setup_test(
        transaction_groups: impl IntoIterator<Item = (Vec<TxId>, Vec<Account>, Vec<Account>)>,
    ) -> (AHashMap<TxId, Tx>, Vec<TxId>) {
        let mut lookup = AHashMap::new();
        let mut ordered_ids = Vec::new();

        for (group_ids, reads, writes) in transaction_groups {
            for &tx_id in &group_ids {
                ordered_ids.push(tx_id);
                lookup.insert(
                    tx_id,
                    Tx {
                        read_locked_resources: reads.clone(),
                        write_locked_resources: writes.clone(),
                    },
                );
            }
        }

        ordered_ids.sort_by_key(|&tx_id| std::cmp::Reverse(tx_id));

        (lookup, ordered_ids)
    }

    /// Pairs each id in `reverse_priority_order_ids` with its resource accesses.
    fn create_lookup_iterator<'a>(
        transaction_lookup_table: &'a AHashMap<TxId, Tx>,
        reverse_priority_order_ids: &'a [TxId],
    ) -> impl Iterator<Item = (TxId, impl IntoIterator<Item = (Account, AccessKind)> + 'a)> + 'a
    {
        reverse_priority_order_ids.iter().map(move |tx_id| {
            let tx = transaction_lookup_table
                .get(tx_id)
                .expect("id must exist");
            (*tx_id, tx.resources())
        })
    }

    impl TopLevelId<TxId> for TxId {
        fn id(&self) -> TxId {
            *self
        }
    }

    /// Prioritizes purely by id: a larger id pops first.
    fn test_top_level_priority_fn(id: &TxId, _node: &GraphNode<TxId>) -> TxId {
        *id
    }

    /// Builds the transactions described by `transaction_groups` and returns
    /// their natural batches.
    fn natural_batches_for(
        transaction_groups: impl IntoIterator<Item = (Vec<TxId>, Vec<Account>, Vec<Account>)>,
    ) -> Vec<Vec<TxId>> {
        let (lookup, ordered_ids) = setup_test(transaction_groups);
        PrioGraph::natural_batches(
            create_lookup_iterator(&lookup, &ordered_ids),
            test_top_level_priority_fn,
        )
    }

    // 3 -> 2 -> 1: a single chain of writers on account 0.
    #[test]
    fn test_simple_queue() {
        let batches = natural_batches_for([(vec![3, 2, 1], vec![], vec![0])]);
        assert_eq!(batches, [[3], [2], [1]]);
    }

    // Three independent writer chains on accounts 0, 1 and 2.
    #[test]
    fn test_multiple_separate_queues() {
        let batches = natural_batches_for([
            (vec![8, 4, 2, 1], vec![], vec![0]),
            (vec![7, 5, 3], vec![], vec![1]),
            (vec![6], vec![], vec![2]),
        ]);
        assert_eq!(batches, [vec![8, 7, 6], vec![5, 4], vec![3, 2], vec![1]]);
    }

    // Chains on accounts 0 and 1 join at transactions writing both accounts.
    #[test]
    fn test_joining_queues() {
        let batches = natural_batches_for([
            (vec![6, 3], vec![], vec![0]),
            (vec![5, 4], vec![], vec![1]),
            (vec![2, 1], vec![], vec![0, 1]),
        ]);
        assert_eq!(batches, [vec![6, 5], vec![4, 3], vec![2], vec![1]]);
    }

    // A chain writing accounts 0 and 1 forks into one chain per account.
    #[test]
    fn test_forking_queues() {
        let batches = natural_batches_for([
            (vec![6, 5], vec![], vec![0, 1]),
            (vec![2, 1], vec![], vec![0]),
            (vec![4, 3], vec![], vec![1]),
        ]);
        assert_eq!(batches, [vec![6], vec![5], vec![4, 2], vec![3, 1]]);
    }

    // A shared chain forks per account, rejoins at 4, then forks again.
    #[test]
    fn test_forking_and_joining() {
        let batches = natural_batches_for([
            (vec![5, 2, 1], vec![], vec![0]),
            (vec![9, 8, 4], vec![], vec![0, 1]),
            (vec![7, 6, 3], vec![], vec![1]),
        ]);
        assert_eq!(
            batches,
            [
                vec![9],
                vec![8],
                vec![7, 5],
                vec![6],
                vec![4],
                vec![3, 2],
                vec![1]
            ]
        );
    }

    // Both chains read account 0; shared reads must not link the chains.
    #[test]
    fn test_shared_read_account_no_conflicts() {
        let batches = natural_batches_for([
            (vec![8, 6, 4, 2], vec![0], vec![1]),
            (vec![7, 5, 3, 1], vec![0], vec![2]),
        ]);
        assert_eq!(batches, [vec![8, 7], vec![6, 5], vec![4, 3], vec![2, 1]]);
    }

    // Reading and writing the same account must not make a transaction block itself.
    #[test]
    fn test_self_conflicting() {
        let batches = natural_batches_for([(vec![1], vec![0], vec![0])]);
        assert_eq!(batches, [vec![1]]);
    }

    // A transaction that reads and writes account 0 still blocks a later reader.
    #[test]
    fn test_self_conflicting_write_priority() {
        let batches =
            natural_batches_for([(vec![2], vec![0], vec![0]), (vec![1], vec![0], vec![])]);
        assert_eq!(batches, [vec![2], vec![1]]);
    }

    // Two readers wait on the earlier writer, but not on each other.
    #[test]
    fn test_write_read_read_conflict() {
        let batches =
            natural_batches_for([(vec![3], vec![], vec![0]), (vec![2, 1], vec![0], vec![])]);
        assert_eq!(batches, [vec![3], vec![2, 1]]);
    }
}