nt_memory/coordination/
locks.rs

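//! Distributed lock implementation for critical sections.
//!
//! Lock state is kept in an in-process `HashMap` behind an `Arc<RwLock<..>>`, so
//! exclusivity only holds among callers that share the same `DistributedLock` value.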
2
3use parking_lot::RwLock;
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7use uuid::Uuid;
8
/// Token identifying one successful lock acquisition; it must be presented to
/// release or extend that lock.
pub type LockToken = String;

/// Bookkeeping record for the lock held on a single resource.
#[derive(Debug, Clone)]
struct LockState {
    /// Token issued to the current holder.
    token: LockToken,

    /// Moment the lock was taken; the TTL is measured from this instant.
    acquired_at: Instant,

    /// How long after `acquired_at` the lock remains valid.
    ttl: Duration,
}

impl LockState {
    /// Reports whether strictly more than `ttl` has passed since `acquired_at`.
    fn is_expired(&self) -> bool {
        self.ttl < self.acquired_at.elapsed()
    }
}
30
31/// Distributed lock manager
32pub struct DistributedLock {
33    /// Active locks
34    locks: Arc<RwLock<HashMap<String, LockState>>>,
35
36    /// Default TTL
37    default_ttl: Duration,
38}
39
40impl DistributedLock {
41    /// Create new lock manager
42    pub fn new() -> Self {
43        Self {
44            locks: Arc::new(RwLock::new(HashMap::new())),
45            default_ttl: Duration::from_secs(30),
46        }
47    }
48
49    /// Configure default TTL
50    pub fn with_ttl(mut self, ttl: Duration) -> Self {
51        self.default_ttl = ttl;
52        self
53    }
54
55    /// Acquire lock on resource
56    pub async fn acquire(&self, resource: &str, timeout: Duration) -> anyhow::Result<LockToken> {
57        let start = Instant::now();
58
59        loop {
60            // Try to acquire
61            if let Some(token) = self.try_acquire(resource) {
62                return Ok(token);
63            }
64
65            // Check timeout
66            if start.elapsed() >= timeout {
67                return Err(anyhow::anyhow!("Lock acquisition timeout"));
68            }
69
70            // Wait and retry
71            tokio::time::sleep(Duration::from_millis(10)).await;
72        }
73    }
74
75    /// Try to acquire lock (non-blocking)
76    pub fn try_acquire(&self, resource: &str) -> Option<LockToken> {
77        let mut locks = self.locks.write();
78
79        // Check if lock exists and is valid
80        if let Some(state) = locks.get(resource) {
81            if !state.is_expired() {
82                return None; // Lock held by someone else
83            }
84        }
85
86        // Acquire lock
87        let token = Uuid::new_v4().to_string();
88
89        locks.insert(
90            resource.to_string(),
91            LockState {
92                token: token.clone(),
93                acquired_at: Instant::now(),
94                ttl: self.default_ttl,
95            },
96        );
97
98        tracing::debug!("Lock acquired: {} -> {}", resource, token);
99
100        Some(token)
101    }
102
103    /// Release lock
104    pub async fn release(&self, token: &str) -> anyhow::Result<()> {
105        let mut locks = self.locks.write();
106
107        // Find and remove lock with matching token
108        locks.retain(|_, state| state.token != token);
109
110        tracing::debug!("Lock released: {}", token);
111
112        Ok(())
113    }
114
115    /// Check if resource is locked
116    pub fn is_locked(&self, resource: &str) -> bool {
117        let locks = self.locks.read();
118
119        if let Some(state) = locks.get(resource) {
120            !state.is_expired()
121        } else {
122            false
123        }
124    }
125
126    /// Extend lock TTL
127    pub fn extend(&self, token: &str, additional: Duration) -> anyhow::Result<()> {
128        let mut locks = self.locks.write();
129
130        for state in locks.values_mut() {
131            if state.token == token {
132                state.ttl += additional;
133                return Ok(());
134            }
135        }
136
137        Err(anyhow::anyhow!("Lock token not found"))
138    }
139
140    /// Cleanup expired locks
141    pub fn cleanup_expired(&self) {
142        let mut locks = self.locks.write();
143        locks.retain(|_, state| !state.is_expired());
144    }
145
146    /// Get lock count
147    pub fn count(&self) -> usize {
148        self.locks.read().len()
149    }
150}
151
152impl Default for DistributedLock {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_lock_acquire_release() {
        let locks = DistributedLock::new();

        let token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        assert!(locks.is_locked("resource1"));
        assert_eq!(locks.count(), 1);

        locks.release(&token).await.unwrap();

        assert!(!locks.is_locked("resource1"));
        assert_eq!(locks.count(), 0);
    }

    #[tokio::test]
    async fn test_try_acquire_is_non_blocking() {
        let locks = DistributedLock::new();

        let token = locks
            .try_acquire("resource1")
            .expect("a free resource must be acquirable");

        // While held, a second non-blocking attempt fails immediately.
        assert!(locks.try_acquire("resource1").is_none());

        locks.release(&token).await.unwrap();
        assert!(locks.try_acquire("resource1").is_some());
    }

    #[tokio::test]
    async fn test_lock_timeout() {
        let locks = DistributedLock::new();

        let _token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        // A second acquisition must time out while the first lock is held...
        let result = locks
            .acquire("resource1", Duration::from_millis(100))
            .await;
        assert!(result.is_err());

        // ...and the failed attempt must not disturb the existing lock.
        assert!(locks.is_locked("resource1"));
        assert_eq!(locks.count(), 1);
    }

    #[tokio::test]
    async fn test_lock_expiration() {
        let locks = DistributedLock::new().with_ttl(Duration::from_millis(100));

        let token1 = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(!locks.is_locked("resource1"));

        // The expired lock can be taken over, and the new holder gets a new token.
        let token2 = locks
            .acquire("resource1", Duration::from_millis(100))
            .await
            .unwrap();

        assert!(!token2.is_empty());
        assert_ne!(token1, token2);
    }

    #[tokio::test]
    async fn test_multiple_resources() {
        let locks = DistributedLock::new();

        let token1 = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();
        let token2 = locks
            .acquire("resource2", Duration::from_secs(1))
            .await
            .unwrap();

        assert_eq!(locks.count(), 2);

        locks.release(&token1).await.unwrap();
        assert_eq!(locks.count(), 1);
        assert!(!locks.is_locked("resource1"));
        assert!(locks.is_locked("resource2"));

        locks.release(&token2).await.unwrap();
        assert_eq!(locks.count(), 0);
    }

    #[tokio::test]
    async fn test_release_unknown_token_is_noop() {
        let locks = DistributedLock::new();

        let _token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        locks.release("not-a-real-token").await.unwrap();

        assert!(locks.is_locked("resource1"));
        assert_eq!(locks.count(), 1);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let locks = DistributedLock::new().with_ttl(Duration::from_millis(50));

        let _token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;

        // Expired entries linger in the table until they are cleaned up.
        assert_eq!(locks.count(), 1);
        locks.cleanup_expired();
        assert_eq!(locks.count(), 0);
    }

    #[tokio::test]
    async fn test_lock_extension() {
        let locks = DistributedLock::new().with_ttl(Duration::from_millis(100));

        let token = locks
            .acquire("resource1", Duration::from_secs(1))
            .await
            .unwrap();

        locks.extend(&token, Duration::from_secs(10)).unwrap();

        // Wait beyond the original TTL; the extension keeps the lock alive.
        tokio::time::sleep(Duration::from_millis(150)).await;
        assert!(locks.is_locked("resource1"));

        // Unknown tokens cannot be extended.
        assert!(locks
            .extend("not-a-real-token", Duration::from_secs(1))
            .is_err());
    }
}