1use std::collections::HashSet;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::Wallet;
6use crate::persist::models::{RoundStateId, StoredRoundState};
7
8#[derive(Clone)]
9pub(crate) struct RoundStateLockIndex {
10 locked: Arc<parking_lot::Mutex<HashSet<RoundStateId>>>,
11}
12
13impl RoundStateLockIndex {
14 pub fn new() -> Self {
15 Self {
16 locked: Arc::new(parking_lot::Mutex::new(HashSet::new())),
17 }
18 }
19
20 pub(crate) fn try_lock(&self, round_state: RoundStateId) -> Option<RoundStateGuard> {
21 let mut index_lock = self.locked.lock();
22 if index_lock.insert(round_state) {
23 Some(RoundStateGuard { index: self.clone(), round_state })
24 } else {
25 None
26 }
27 }
28
29 pub(crate) async fn wait_lock(&self, round_state: RoundStateId) -> anyhow::Result<RoundStateGuard> {
31 let mut attempts = 0;
32 loop {
33 if let Some(guard) = self.try_lock(round_state) {
34 return Ok(guard);
35 }
36 attempts += 1;
37 if attempts > 100 {
39 bail!("Timed out waiting for lock on round state {}", round_state);
40 }
41 tokio::time::sleep(Duration::from_millis(100)).await;
42 }
43 }
44}
45
46pub struct RoundStateGuard {
47 index: RoundStateLockIndex,
48 round_state: RoundStateId,
49}
50
51impl std::ops::Drop for RoundStateGuard {
52 fn drop(&mut self) {
53 assert!(self.index.locked.lock().remove(&self.round_state),
54 "RoundStateGuard already unlocked",
55 );
56 }
57}
58
59impl Wallet {
60 pub async fn lock_wait_round_state(&self, id: RoundStateId) -> anyhow::Result<Option<StoredRoundState>> {
65 let guard = self.round_state_lock_index.wait_lock(id).await?;
66
67 if let Some(state) = self.db.get_round_state_by_id(id).await? {
68 return Ok(Some(state.lock(guard)));
69 }
70
71 Ok(None)
72 }
73}
74
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `try_lock` exclusivity, release on drop, independence of
    /// different ids, and sharing of lock state between clones of the index.
    #[test]
    fn round_state_lock() {
        let index = RoundStateLockIndex::new();

        let guard = index.try_lock(RoundStateId(1));
        assert!(guard.is_some(), "first lock should succeed");

        let guard2 = index.try_lock(RoundStateId(1));
        assert!(guard2.is_none(), "second lock should fail");

        drop(guard);
        // The temporary guard is dropped at the end of this statement.
        assert!(index.try_lock(RoundStateId(1)).is_some(), "lock should succeed after drop");

        let guard3 = index.try_lock(RoundStateId(2));
        assert!(guard3.is_some(), "second lock should succeed");

        let cloned = index.clone();
        let id = RoundStateId(1);
        let guard4 = cloned.try_lock(id);
        assert!(guard4.is_some(), "cloned index should share lock state");
        assert!(index.try_lock(id).is_none(), "original should prevent lock");

        drop(guard4);
        let guard5 = index.try_lock(id);
        assert!(guard5.is_some(), "lock should succeed on original index after drop");
        assert!(cloned.try_lock(id).is_none(), "cloned index should prevent lock");
    }

    /// A task blocked in `wait_lock` must acquire the lock once the current
    /// holder drops its guard, and must then actually hold it.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn lock_wait_succeeds_after_guard_dropped() {
        let index = RoundStateLockIndex::new();
        let guard = index.try_lock(RoundStateId(1)).unwrap();

        let cloned = index.clone();
        let handle = tokio::spawn(async move {
            cloned.wait_lock(RoundStateId(1)).await
        });

        // Let the waiter start polling while the lock is still held.
        tokio::time::sleep(Duration::from_millis(150)).await;
        drop(guard);

        // Unwrap every layer (timeout, spawned task, `wait_lock` itself) so that an
        // `Err` from `wait_lock` fails the test instead of passing silently.
        let waiter_guard = tokio::time::timeout(Duration::from_secs(2), handle)
            .await
            .expect("lock_wait should complete after guard is dropped")
            .expect("lock_wait task should not panic")
            .expect("lock_wait should acquire the lock");

        assert!(
            index.try_lock(RoundStateId(1)).is_none(),
            "waiter should now hold the lock",
        );
        drop(waiter_guard);
        assert!(
            index.try_lock(RoundStateId(1)).is_some(),
            "lock should be free after the waiter drops its guard",
        );
    }
}