1use crate::grid::protocol::LockMessage;
7use dashmap::DashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::{Duration, Instant};
11use tokio::sync::{broadcast, oneshot};
12
/// Reasons a distributed lock acquisition can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DlmError {
    /// No answer arrived from the coordinator before the caller's deadline.
    Timeout,
    /// The coordinator refused the request because the lock is still leased.
    Denied,
    /// The request could not be published on the outbound channel.
    NetworkError,
}

impl std::fmt::Display for DlmError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            DlmError::Timeout => "lock acquisition timed out",
            DlmError::Denied => "lock acquisition denied",
            DlmError::NetworkError => "failed to send lock request",
        };
        f.write_str(text)
    }
}

impl std::error::Error for DlmError {}
20
/// Coordinator-based distributed lock manager.
///
/// Every node holds one instance. Lock requests and releases are published on
/// `lock_out_tx`, and incoming protocol messages are consumed from `lock_in_rx` by the
/// receiver loop. Only the node whose `node_id` equals `coordinator_id` records grants
/// in `granted_locks` and answers `Acquire` requests.
pub struct DistributedLockManager {
    /// Identifier of this node; stamped on outgoing `Acquire`/`Release` messages.
    pub node_id: u32,
    /// Identifier of the node that arbitrates lock grants.
    pub coordinator_id: u32,

    /// Incoming protocol messages. Wrapped in an async mutex so the receiver loop can
    /// take mutable access through a shared `Arc<Self>` and hold it across `.await`.
    lock_in_rx: tokio::sync::Mutex<broadcast::Receiver<LockMessage>>,

    /// Outgoing protocol messages.
    lock_out_tx: broadcast::Sender<LockMessage>,

    /// In-flight `acquire` calls on this node, keyed by request id. Each sender is
    /// completed with `(granted, fencing_token)` when the matching `AcquireAck` arrives.
    pending_reqs: DashMap<u64, oneshot::Sender<(bool, u64)>>,

    /// Coordinator-only: current holder of each `(table, key)` lock, stored as
    /// `(owner node_id, fencing_token, lease expiration)`.
    granted_locks: DashMap<(String, Vec<u8>), (u32, u64, Instant)>,

    /// Coordinator-only: next fencing token to hand out; incremented on every grant.
    next_fencing_token: AtomicU64,
    /// Next request sequence number for `acquire` calls made by this node.
    next_req_id: AtomicU64,
}
44
45impl DistributedLockManager {
46 pub fn new(
47 node_id: u32,
48 coordinator_id: u32,
49 lock_in_rx: broadcast::Receiver<LockMessage>,
50 lock_out_tx: broadcast::Sender<LockMessage>,
51 ) -> Arc<Self> {
52 Arc::new(Self {
53 node_id,
54 coordinator_id,
55 lock_in_rx: tokio::sync::Mutex::new(lock_in_rx),
56 lock_out_tx,
57 pending_reqs: DashMap::new(),
58 granted_locks: DashMap::new(),
59 next_fencing_token: AtomicU64::new(1),
60 next_req_id: AtomicU64::new(1),
61 })
62 }
63
64 pub async fn acquire(
66 &self,
67 table: &str,
68 key: &[u8],
69 lease_ms: u64,
70 timeout: Duration,
71 ) -> Result<u64, DlmError> {
72 let req_id = self.next_req_id.fetch_add(1, Ordering::SeqCst);
73 let (tx, rx) = oneshot::channel();
74 self.pending_reqs.insert(req_id, tx);
75
76 let msg = LockMessage::Acquire {
77 table: table.to_string(),
78 key: key.to_vec(),
79 lease_ms,
80 node_id: self.node_id,
81 req_id,
82 };
83
84 if self.lock_out_tx.send(msg).is_err() {
85 self.pending_reqs.remove(&req_id);
86 return Err(DlmError::NetworkError);
87 }
88
89 match tokio::time::timeout(timeout, rx).await {
90 Ok(Ok((true, fencing_token))) => Ok(fencing_token),
91 Ok(Ok((false, _))) => Err(DlmError::Denied),
92 _ => {
93 self.pending_reqs.remove(&req_id);
94 Err(DlmError::Timeout)
95 }
96 }
97 }
98
99 pub async fn release(&self, table: &str, key: &[u8], fencing_token: u64) {
101 let msg = LockMessage::Release {
102 table: table.to_string(),
103 key: key.to_vec(),
104 fencing_token,
105 node_id: self.node_id,
106 };
107 let _ = self.lock_out_tx.send(msg);
108 }
109
110 fn try_grant_lock(
112 &self,
113 table: String,
114 key: Vec<u8>,
115 node_id: u32,
116 lease_ms: u64,
117 ) -> (bool, u64) {
118 let lock_key = (table, key);
119 let now = Instant::now();
120 let expiration = now + Duration::from_millis(lease_ms);
121
122 let mut granted = false;
123 let mut f_token = 0;
124
125 if let Some(mut existing) = self.granted_locks.get_mut(&lock_key) {
127 if now > existing.2 {
129 f_token = self.next_fencing_token.fetch_add(1, Ordering::SeqCst);
131 existing.0 = node_id;
132 existing.1 = f_token;
133 existing.2 = expiration;
134 granted = true;
135 }
136 } else {
137 f_token = self.next_fencing_token.fetch_add(1, Ordering::SeqCst);
139 self.granted_locks
140 .insert(lock_key, (node_id, f_token, expiration));
141 granted = true;
142 }
143
144 (granted, f_token)
145 }
146
147 pub async fn run_receiver_loop(self: Arc<Self>) {
149 let mut rx = self.lock_in_rx.lock().await;
150
151 while let Ok(msg) = rx.recv().await {
152 match msg {
153 LockMessage::Acquire {
154 table,
155 key,
156 lease_ms,
157 node_id,
158 req_id,
159 } => {
160 if self.node_id == self.coordinator_id {
161 let (granted, fencing_token) =
162 self.try_grant_lock(table, key, node_id, lease_ms);
163 let _ = self.lock_out_tx.send(LockMessage::AcquireAck {
164 req_id,
165 granted,
166 fencing_token,
167 });
168 }
169 }
170
171 LockMessage::AcquireAck {
172 req_id,
173 granted,
174 fencing_token,
175 } => {
176 if let Some((_, sender)) = self.pending_reqs.remove(&req_id) {
177 let _ = sender.send((granted, fencing_token));
178 }
179 }
180
181 LockMessage::Release {
182 table,
183 key,
184 fencing_token,
185 node_id,
186 } => {
187 if self.node_id == self.coordinator_id {
188 let lock_key = (table.clone(), key.clone());
189 if let Some(entry) = self.granted_locks.get(&lock_key)
190 && entry.0 == node_id
191 && entry.1 == fencing_token
192 {
193 drop(entry);
194 self.granted_locks.remove(&lock_key);
195 }
196 }
197 }
198
199 LockMessage::Heartbeat {
200 node_id,
201 fencing_tokens,
202 } => {
203 if self.node_id == self.coordinator_id {
204 let now = Instant::now();
205 let extension = Duration::from_millis(5000);
207
208 for mut entry in self.granted_locks.iter_mut() {
210 let (owner, token, exp) = entry.value_mut();
211 if *owner == node_id && fencing_tokens.contains(token) {
212 *exp = now + extension;
214 }
215 }
216 }
217 }
218 }
219 }
220 }
221}