prio_graph/
prio_graph.rs

1use {
2    crate::{
3        lock::Lock, top_level_id::TopLevelId, AccessKind, GraphNode, ResourceKey, TransactionId,
4    },
5    ahash::AHashMap,
6    std::collections::{hash_map::Entry, BinaryHeap},
7};
8
9/// A directed acyclic graph where edges are only present between nodes if
10/// that node is the next-highest priority node for a particular resource.
11/// Resources can be either read or write locked with write locks being
12/// exclusive.
13/// `Transaction`s are inserted into the graph and then popped in time-priority order.
14/// Between conflicting transactions, the first to be inserted will always have higher priority.
15pub struct PrioGraph<
16    Id: TransactionId,
17    Rk: ResourceKey,
18    Tl: TopLevelId<Id>,
19    Pfn: Fn(&Id, &GraphNode<Id>) -> Tl,
20> {
21    /// Locked resources and which transaction holds them.
22    locks: AHashMap<Rk, Lock<Id>>,
23    /// Graph edges and count of edges into each node. The count is used
24    /// to detect joins.
25    nodes: AHashMap<Id, GraphNode<Id>>,
26    /// Main queue - currently unblocked transactions.
27    main_queue: BinaryHeap<Tl>,
28    /// Priority modification for top-level transactions.
29    top_level_prioritization_fn: Pfn,
30}
31
32impl<
33        Id: TransactionId,
34        Rk: ResourceKey,
35        Tl: TopLevelId<Id>,
36        Pfn: Fn(&Id, &GraphNode<Id>) -> Tl,
37    > PrioGraph<Id, Rk, Tl, Pfn>
38{
39    /// Drains all transactions from the primary queue into a batch.
40    /// Then, for each transaction in the batch, unblock transactions it was blocking.
41    /// If any of those transactions are now unblocked, add them to the main queue.
42    /// Repeat until the main queue is empty.
43    pub fn natural_batches(
44        iter: impl IntoIterator<Item = (Id, impl IntoIterator<Item = (Rk, AccessKind)>)>,
45        top_level_prioritization_fn: Pfn,
46    ) -> Vec<Vec<Id>> {
47        // Insert all transactions into the graph.
48        let mut graph = PrioGraph::new(top_level_prioritization_fn);
49        for (id, tx) in iter.into_iter() {
50            graph.insert_transaction(id, tx);
51        }
52
53        graph.make_natural_batches()
54    }
55
56    /// Create a new priority graph.
57    pub fn new(top_level_prioritization_fn: Pfn) -> Self {
58        Self {
59            locks: AHashMap::new(),
60            nodes: AHashMap::new(),
61            main_queue: BinaryHeap::new(),
62            top_level_prioritization_fn,
63        }
64    }
65
66    /// Clear the graph.
67    pub fn clear(&mut self) {
68        self.main_queue.clear();
69        self.locks.clear();
70        self.nodes.clear();
71    }
72
73    /// Make natural batches from the transactions already inserted into the graph.
74    /// Drains all transactions from the primary queue into a batch.
75    /// Then, for each transaction in the batch, unblock transactions it was blocking.
76    /// If any of those transactions are now unblocked, add them to the main queue.
77    /// Repeat until the main queue is empty.
78    pub fn make_natural_batches(&mut self) -> Vec<Vec<Id>> {
79        // Create natural batches by manually popping without unblocking at each level.
80        let mut batches = vec![];
81
82        while !self.main_queue.is_empty() {
83            let mut batch = Vec::new();
84            while let Some(id) = self.pop() {
85                batch.push(id);
86            }
87
88            for id in &batch {
89                self.unblock(id);
90            }
91
92            batches.push(batch);
93        }
94
95        batches
96    }
97
98    /// Insert a transaction into the graph with the given `Id`.
99    /// `Transaction`s should be inserted in priority order.
100    pub fn insert_transaction(&mut self, id: Id, tx: impl IntoIterator<Item = (Rk, AccessKind)>) {
101        let mut node = GraphNode {
102            active: true,
103            edges: Vec::new(),
104            blocked_by_count: 0,
105        };
106
107        let mut block_tx = |blocking_id: Id| {
108            // If the blocking transaction is the same as the current transaction, do nothing.
109            // This indicates the transaction has multiple accesses to the same resource.
110            if blocking_id == id {
111                return;
112            }
113
114            let Some(blocking_tx_node) = self.nodes.get_mut(&blocking_id) else {
115                panic!("blocking node must exist");
116            };
117
118            // If the node isn't active then we only do chain tracking.
119            if blocking_tx_node.active {
120                // Add edges to the current node.
121                // If it is a unique edge, increment the blocked_by_count for the current node.
122                if blocking_tx_node.try_add_edge(id) {
123                    node.blocked_by_count += 1;
124                }
125            }
126        };
127
128        for (resource_key, access_kind) in tx.into_iter() {
129            match self.locks.entry(resource_key) {
130                Entry::Vacant(entry) => {
131                    entry.insert(match access_kind {
132                        AccessKind::Read => Lock::Read(vec![id], None),
133                        AccessKind::Write => Lock::Write(id),
134                    });
135                }
136                Entry::Occupied(mut entry) => match access_kind {
137                    AccessKind::Read => {
138                        if let Some(blocking_tx) = entry.get_mut().add_read(id) {
139                            block_tx(blocking_tx);
140                        }
141                    }
142                    AccessKind::Write => {
143                        if let Some(blocking_txs) = entry.get_mut().add_write(id) {
144                            for blocking_tx in blocking_txs {
145                                block_tx(blocking_tx);
146                            }
147                        }
148                    }
149                },
150            }
151        }
152
153        self.nodes.insert(id, node);
154
155        // If the node is not blocked, add it to the main queue.
156        if self.nodes.get(&id).unwrap().blocked_by_count == 0 {
157            self.main_queue.push(self.create_top_level_id(id));
158        }
159    }
160
161    /// Returns true if the main queue is empty.
162    pub fn is_empty(&self) -> bool {
163        self.main_queue.is_empty()
164    }
165
166    /// Combination of `pop` and `unblock`.
167    /// Returns None if the queue is empty.
168    /// Returns the `Id` of the popped node, and the set of unblocked `Id`s.
169    pub fn pop_and_unblock(&mut self) -> Option<(Id, Vec<Id>)> {
170        let id = self.pop()?;
171        Some((id, self.unblock(&id)))
172    }
173
174    /// Pop the highest priority node id from the main queue.
175    /// Returns None if the queue is empty.
176    pub fn pop(&mut self) -> Option<Id> {
177        self.main_queue.pop().map(|top_level_id| top_level_id.id())
178    }
179
180    /// This will unblock transactions that were blocked by this transaction.
181    /// Returns the set of `Id`s that were unblocked.
182    ///
183    /// Panics:
184    ///     - Node does not exist.
185    ///     - If the node.blocked_by_count != 0
186    pub fn unblock(&mut self, id: &Id) -> Vec<Id> {
187        // If the node is already removed, do nothing.
188        let Some(node) = self.nodes.get_mut(id) else {
189            panic!("node must exist");
190        };
191        assert_eq!(node.blocked_by_count, 0, "node must be unblocked");
192
193        node.active = false;
194        let edges = core::mem::take(&mut node.edges);
195
196        // Unblock transactions that were blocked by this node.
197        for blocked_tx in edges.iter() {
198            let blocked_tx_node = self
199                .nodes
200                .get_mut(blocked_tx)
201                .expect("blocked_tx must exist");
202            blocked_tx_node.blocked_by_count -= 1;
203
204            if blocked_tx_node.blocked_by_count == 0 {
205                self.main_queue.push(self.create_top_level_id(*blocked_tx));
206            }
207        }
208
209        edges
210    }
211
212    /// Returns whether the given `Id` is at the top level of the graph, i.e. not blocked.
213    /// If the node does not exist, returns false.
214    pub fn is_blocked(&self, id: Id) -> bool {
215        self.nodes
216            .get(&id)
217            .map(|node| node.active && node.blocked_by_count != 0)
218            .unwrap_or_default()
219    }
220
221    fn create_top_level_id(&self, id: Id) -> Tl {
222        (self.top_level_prioritization_fn)(&id, self.nodes.get(&id).unwrap())
223    }
224}
225
226#[cfg(test)]
227mod tests {
228    use super::*;
229
    /// Transaction id type used by the tests in this module.
    pub type TxId = u64;

    /// Resource key (account) type used by the tests in this module.
    pub type Account = u64;
233
    /// Minimal test transaction: just the accounts it read-locks and write-locks.
    pub struct Tx {
        /// Accounts accessed with `AccessKind::Read`.
        read_locked_resources: Vec<Account>,
        /// Accounts accessed with `AccessKind::Write`.
        write_locked_resources: Vec<Account>,
    }
238
239    impl Tx {
240        fn resources(&self) -> impl Iterator<Item = (Account, AccessKind)> + '_ {
241            let write_locked_resources = self
242                .write_locked_resources
243                .iter()
244                .cloned()
245                .map(|rk| (rk, AccessKind::Write));
246            let read_locked_resources = self
247                .read_locked_resources
248                .iter()
249                .cloned()
250                .map(|rk| (rk, AccessKind::Read));
251
252            write_locked_resources.chain(read_locked_resources)
253        }
254    }
255
256    // Take in groups of transactions, where each group is a set of transaction ids,
257    // and the read and write locked resources for each transaction.
258    fn setup_test(
259        transaction_groups: impl IntoIterator<Item = (Vec<TxId>, Vec<Account>, Vec<Account>)>,
260    ) -> (AHashMap<TxId, Tx>, Vec<TxId>) {
261        let mut transaction_lookup_table = AHashMap::new();
262        let mut priority_ordered_ids = vec![];
263        for (ids, read_accounts, write_accounts) in transaction_groups {
264            for id in &ids {
265                priority_ordered_ids.push(*id);
266                transaction_lookup_table.insert(
267                    *id,
268                    Tx {
269                        read_locked_resources: read_accounts.clone(),
270                        write_locked_resources: write_accounts.clone(),
271                    },
272                );
273            }
274        }
275
276        // Sort in reverse priority order - highest priority first.
277        priority_ordered_ids.sort_by(|a, b| b.cmp(a));
278
279        (transaction_lookup_table, priority_ordered_ids)
280    }
281
282    fn create_lookup_iterator<'a>(
283        transaction_lookup_table: &'a AHashMap<TxId, Tx>,
284        reverse_priority_order_ids: &'a [TxId],
285    ) -> impl Iterator<Item = (TxId, impl IntoIterator<Item = (Account, AccessKind)> + 'a)> + 'a
286    {
287        reverse_priority_order_ids.iter().map(|id| {
288            (
289                *id,
290                transaction_lookup_table
291                    .get(id)
292                    .expect("id must exist")
293                    .resources(),
294            )
295        })
296    }
297
    // In the tests a transaction id is its own top-level id, so the main queue
    // orders entries by the id value itself.
    impl TopLevelId<TxId> for TxId {
        fn id(&self) -> TxId {
            *self
        }
    }
303
    // Prioritization function used by all tests: returns the id unchanged and
    // ignores the graph node.
    fn test_top_level_priority_fn(id: &TxId, _node: &GraphNode<TxId>) -> TxId {
        *id
    }
307
308    #[test]
309    fn test_simple_queue() {
310        // Setup:
311        // 3 -> 2 -> 1
312        // batches: [3], [2], [1]
313        let (transaction_lookup_table, transaction_queue) =
314            setup_test([(vec![3, 2, 1], vec![], vec![0])]);
315        let batches = PrioGraph::natural_batches(
316            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
317            test_top_level_priority_fn,
318        );
319        assert_eq!(batches, [[3], [2], [1]]);
320    }
321
322    #[test]
323    fn test_multiple_separate_queues() {
324        // Setup:
325        // 8 -> 4 -> 2 -> 1
326        // 7 -> 5 -> 3
327        // 6
328        // batches: [8, 7, 6], [4, 5], [2, 3], [1]
329        let (transaction_lookup_table, transaction_queue) = setup_test([
330            (vec![8, 4, 2, 1], vec![], vec![0]),
331            (vec![7, 5, 3], vec![], vec![1]),
332            (vec![6], vec![], vec![2]),
333        ]);
334        let batches = PrioGraph::natural_batches(
335            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
336            test_top_level_priority_fn,
337        );
338        assert_eq!(batches, [vec![8, 7, 6], vec![5, 4], vec![3, 2], vec![1]]);
339    }
340
341    #[test]
342    fn test_joining_queues() {
343        // Setup:
344        // 6 -> 3
345        //        \
346        //          -> 2 -> 1
347        //        /
348        // 5 -> 4
349        // batches: [6, 5], [3, 4], [2], [1]
350        let (transaction_lookup_table, transaction_queue) = setup_test([
351            (vec![6, 3], vec![], vec![0]),
352            (vec![5, 4], vec![], vec![1]),
353            (vec![2, 1], vec![], vec![0, 1]),
354        ]);
355        let batches = PrioGraph::natural_batches(
356            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
357            test_top_level_priority_fn,
358        );
359        assert_eq!(batches, [vec![6, 5], vec![4, 3], vec![2], vec![1]]);
360    }
361
362    #[test]
363    fn test_forking_queues() {
364        // Setup:
365        //         -> 2 -> 1
366        //        /
367        // 6 -> 5
368        //        \
369        //         -> 4 -> 3
370        // batches: [6], [5], [4, 2], [3, 1]
371        let (transaction_lookup_table, transaction_queue) = setup_test([
372            (vec![6, 5], vec![], vec![0, 1]),
373            (vec![2, 1], vec![], vec![0]),
374            (vec![4, 3], vec![], vec![1]),
375        ]);
376        let batches = PrioGraph::natural_batches(
377            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
378            test_top_level_priority_fn,
379        );
380        assert_eq!(batches, [vec![6], vec![5], vec![4, 2], vec![3, 1]]);
381    }
382
383    #[test]
384    fn test_forking_and_joining() {
385        // Setup:
386        //         -> 5 ----          -> 2 -> 1
387        //        /          \      /
388        // 9 -> 8              -> 4
389        //        \          /      \
390        //         -> 7 -> 6          -> 3
391        // batches: [9], [8], [7, 5], [6], [4], [3, 2], [1]
392        let (transaction_lookup_table, transaction_queue) = setup_test([
393            (vec![5, 2, 1], vec![], vec![0]),
394            (vec![9, 8, 4], vec![], vec![0, 1]),
395            (vec![7, 6, 3], vec![], vec![1]),
396        ]);
397        let batches = PrioGraph::natural_batches(
398            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
399            test_top_level_priority_fn,
400        );
401        assert_eq!(
402            batches,
403            [
404                vec![9],
405                vec![8],
406                vec![7, 5],
407                vec![6],
408                vec![4],
409                vec![3, 2],
410                vec![1]
411            ]
412        );
413    }
414
415    #[test]
416    fn test_shared_read_account_no_conflicts() {
417        // Setup:
418        //   - all transactions read-lock account 0.
419        // 8 -> 6 -> 4 -> 2
420        // 7 -> 5 -> 3 -> 1
421        // Batches: [8, 7], [6, 5], [4, 3], [2, 1]
422        let (transaction_lookup_table, transaction_queue) = setup_test([
423            (vec![8, 6, 4, 2], vec![0], vec![1]),
424            (vec![7, 5, 3, 1], vec![0], vec![2]),
425        ]);
426        let batches = PrioGraph::natural_batches(
427            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
428            test_top_level_priority_fn,
429        );
430        assert_eq!(batches, [vec![8, 7], vec![6, 5], vec![4, 3], vec![2, 1]]);
431    }
432
433    #[test]
434    fn test_self_conflicting() {
435        // Setup:
436        //   - transaction read and write locks account 0.
437        // 1
438        // Batches: [1]
439        let (transaction_lookup_table, transaction_queue) =
440            setup_test([(vec![1], vec![0], vec![0])]);
441        let batches = PrioGraph::natural_batches(
442            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
443            test_top_level_priority_fn,
444        );
445        assert_eq!(batches, [vec![1]]);
446    }
447
448    #[test]
449    fn test_self_conflicting_write_priority() {
450        // Setup:
451        //   - transaction 2 read and write locks account 0.
452        // 2 --> 1
453        // Batches: [2, 1]
454        let (transaction_lookup_table, transaction_queue) =
455            setup_test([(vec![2], vec![0], vec![0]), (vec![1], vec![0], vec![])]);
456        let batches = PrioGraph::natural_batches(
457            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
458            test_top_level_priority_fn,
459        );
460        assert_eq!(batches, [vec![2], vec![1]]);
461    }
462
463    #[test]
464    fn test_write_read_read_conflict() {
465        // Setup:
466        //  - W --> R
467        //      \
468        //       -> R
469        // - all transactions using same account 0.
470        // Batches: [3], [2, 1]
471        let (transaction_lookup_table, transaction_queue) =
472            setup_test([(vec![3], vec![], vec![0]), (vec![2, 1], vec![0], vec![])]);
473        let batches = PrioGraph::natural_batches(
474            create_lookup_iterator(&transaction_lookup_table, &transaction_queue),
475            test_top_level_priority_fn,
476        );
477        assert_eq!(batches, [vec![3], vec![2, 1]]);
478    }
479}