forest/message_pool/
mpool_locker.rs1use crate::shim::address::Address;
5use ahash::HashMap;
6use parking_lot::Mutex;
7use std::sync::Arc;
8use tokio::sync::OwnedMutexGuard;
9
10pub struct MpoolLocker {
23 inner: Mutex<HashMap<Address, Arc<tokio::sync::Mutex<()>>>>,
24}
25
26impl MpoolLocker {
27 pub fn new() -> Self {
29 Self {
30 inner: Mutex::new(HashMap::default()),
31 }
32 }
33
34 pub async fn take_lock(&self, addr: Address) -> OwnedMutexGuard<()> {
37 let mutex = {
38 let mut map = self.inner.lock();
39 map.retain(|_, v| Arc::strong_count(v) > 1);
40 map.entry(addr)
41 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
42 .clone()
43 };
44 mutex.lock_owned().await
45 }
46}
47
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::{Barrier, oneshot};
    use tokio::time::{Duration, timeout};

    #[tokio::test]
    async fn test_take_lock_serializes_same_address() {
        let locker = Arc::new(MpoolLocker::new());
        let addr = Address::new_id(1);

        // Handshakes: holder -> test ("lock held"), test -> holder ("release"),
        // contender -> test ("lock acquired").
        let (held_tx, held_rx) = oneshot::channel();
        let (release_tx, release_rx) = oneshot::channel();
        let (contender_tx, contender_rx) = oneshot::channel();

        let holder = tokio::spawn({
            let locker = Arc::clone(&locker);
            async move {
                let _guard = locker.take_lock(addr).await;
                let _ = held_tx.send(());
                let _ = release_rx.await;
            }
        });

        held_rx.await.unwrap();

        let contender = tokio::spawn({
            let locker = Arc::clone(&locker);
            async move {
                let _guard = locker.take_lock(addr).await;
                let _ = contender_tx.send(());
            }
        });

        // While the holder keeps its guard, the contender must stay blocked.
        let contender_blocked = timeout(Duration::from_millis(50), contender_rx)
            .await
            .is_err();
        assert!(
            contender_blocked,
            "second task should not acquire the same address lock while first holds it"
        );

        let _ = release_tx.send(());
        holder.await.unwrap();
        contender.await.unwrap();
    }

    #[tokio::test]
    async fn test_take_lock_allows_different_addresses() {
        let locker = Arc::new(MpoolLocker::new());
        let barrier = Arc::new(Barrier::new(2));

        // Each task waits at the barrier while holding its own address lock, so
        // the barrier only opens if both locks can be held simultaneously.
        let spawn_holder = |addr: Address| {
            let locker = Arc::clone(&locker);
            let barrier = Arc::clone(&barrier);
            tokio::spawn(async move {
                let _guard = locker.take_lock(addr).await;
                barrier.wait().await;
            })
        };
        let first = spawn_holder(Address::new_id(1));
        let second = spawn_holder(Address::new_id(2));

        timeout(Duration::from_millis(200), async {
            first.await.unwrap();
            second.await.unwrap();
        })
        .await
        .expect("different address locks should be acquired in parallel");
    }

    #[tokio::test]
    async fn test_take_lock_prunes_idle_entries() {
        let locker = MpoolLocker::new();
        let entry_count = || locker.inner.lock().len();

        let guard_a = locker.take_lock(Address::new_id(1)).await;
        assert_eq!(entry_count(), 1);
        drop(guard_a);

        // Locking another address must evict the now-idle entry for the first.
        let _guard_b = locker.take_lock(Address::new_id(2)).await;
        assert_eq!(
            entry_count(),
            1,
            "idle entry for addr_a should have been pruned"
        );
    }
}