nt_memory/coordination/
locks.rs1use parking_lot::RwLock;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use uuid::Uuid;
8
/// Opaque handle identifying one acquisition of a lock.
///
/// Returned by `try_acquire`/`acquire` and passed back (as `&str`) to
/// `release`/`extend`. Tokens are generated as UUIDv4 strings.
pub type LockToken = String;
11
12#[derive(Debug, Clone)]
14struct LockState {
15 token: LockToken,
17
18 acquired_at: Instant,
20
21 ttl: Duration,
23}
24
25impl LockState {
26 fn is_expired(&self) -> bool {
27 self.acquired_at.elapsed() > self.ttl
28 }
29}
30
31pub struct DistributedLock {
33 locks: Arc<RwLock<HashMap<String, LockState>>>,
35
36 default_ttl: Duration,
38}
39
40impl DistributedLock {
41 pub fn new() -> Self {
43 Self {
44 locks: Arc::new(RwLock::new(HashMap::new())),
45 default_ttl: Duration::from_secs(30),
46 }
47 }
48
49 pub fn with_ttl(mut self, ttl: Duration) -> Self {
51 self.default_ttl = ttl;
52 self
53 }
54
55 pub async fn acquire(&self, resource: &str, timeout: Duration) -> anyhow::Result<LockToken> {
57 let start = Instant::now();
58
59 loop {
60 if let Some(token) = self.try_acquire(resource) {
62 return Ok(token);
63 }
64
65 if start.elapsed() >= timeout {
67 return Err(anyhow::anyhow!("Lock acquisition timeout"));
68 }
69
70 tokio::time::sleep(Duration::from_millis(10)).await;
72 }
73 }
74
75 pub fn try_acquire(&self, resource: &str) -> Option<LockToken> {
77 let mut locks = self.locks.write();
78
79 if let Some(state) = locks.get(resource) {
81 if !state.is_expired() {
82 return None; }
84 }
85
86 let token = Uuid::new_v4().to_string();
88
89 locks.insert(
90 resource.to_string(),
91 LockState {
92 token: token.clone(),
93 acquired_at: Instant::now(),
94 ttl: self.default_ttl,
95 },
96 );
97
98 tracing::debug!("Lock acquired: {} -> {}", resource, token);
99
100 Some(token)
101 }
102
103 pub async fn release(&self, token: &str) -> anyhow::Result<()> {
105 let mut locks = self.locks.write();
106
107 locks.retain(|_, state| state.token != token);
109
110 tracing::debug!("Lock released: {}", token);
111
112 Ok(())
113 }
114
115 pub fn is_locked(&self, resource: &str) -> bool {
117 let locks = self.locks.read();
118
119 if let Some(state) = locks.get(resource) {
120 !state.is_expired()
121 } else {
122 false
123 }
124 }
125
126 pub fn extend(&self, token: &str, additional: Duration) -> anyhow::Result<()> {
128 let mut locks = self.locks.write();
129
130 for state in locks.values_mut() {
131 if state.token == token {
132 state.ttl += additional;
133 return Ok(());
134 }
135 }
136
137 Err(anyhow::anyhow!("Lock token not found"))
138 }
139
140 pub fn cleanup_expired(&self) {
142 let mut locks = self.locks.write();
143 locks.retain(|_, state| !state.is_expired());
144 }
145
146 pub fn count(&self) -> usize {
148 self.locks.read().len()
149 }
150}
151
152impl Default for DistributedLock {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_lock_acquire_release() {
        let locks = DistributedLock::new();

        let token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        assert!(locks.is_locked("resource1"));
        assert_eq!(locks.count(), 1);

        locks.release(&token).await.unwrap();

        assert!(!locks.is_locked("resource1"));
        assert_eq!(locks.count(), 0);
    }

    #[tokio::test]
    async fn test_lock_timeout() {
        let locks = DistributedLock::new();

        let _token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        let result = locks
            .acquire("resource1", Duration::from_millis(100))
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_lock_expiration() {
        let locks = DistributedLock::new().with_ttl(Duration::from_millis(100));

        let token1 = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;

        let token2 = locks
            .acquire("resource1", Duration::from_millis(100))
            .await
            .unwrap();

        assert!(!token2.is_empty());
        // Re-acquiring an expired lock must issue a brand-new token.
        assert_ne!(token1, token2);
    }

    #[tokio::test]
    async fn test_multiple_resources() {
        let locks = DistributedLock::new();

        let token1 = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();
        let token2 = locks
            .acquire("resource2", Duration::from_secs(1))
            .await
            .unwrap();

        assert_eq!(locks.count(), 2);

        locks.release(&token1).await.unwrap();
        assert_eq!(locks.count(), 1);

        locks.release(&token2).await.unwrap();
        assert_eq!(locks.count(), 0);
    }

    #[tokio::test]
    async fn test_lock_extension() {
        let locks = DistributedLock::new().with_ttl(Duration::from_millis(100));

        let token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        locks.extend(&token, Duration::from_secs(10)).unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;

        assert!(locks.is_locked("resource1"));
    }

    #[tokio::test]
    async fn test_try_acquire_contention() {
        let locks = DistributedLock::new();

        let token = locks
            .try_acquire("resource1")
            .expect("a free resource must be acquirable");
        assert!(locks.try_acquire("resource1").is_none());

        locks.release(&token).await.unwrap();
        assert!(locks.try_acquire("resource1").is_some());
    }

    #[tokio::test]
    async fn test_stale_release_keeps_new_holder() {
        let locks = DistributedLock::new().with_ttl(Duration::from_millis(50));

        let stale = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;

        let current = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        // The expired holder's token no longer matches the entry, so it must
        // not remove the new holder's lock.
        locks.release(&stale).await.unwrap();
        assert_eq!(locks.count(), 1);

        locks.release(&current).await.unwrap();
        assert_eq!(locks.count(), 0);
    }

    #[tokio::test]
    async fn test_release_unknown_token_is_ok() {
        let locks = DistributedLock::new();

        assert!(locks.release("not-a-token").await.is_ok());
        assert_eq!(locks.count(), 0);
    }
}