use crate::grid::protocol::LockMessage;
use dashmap::DashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio::sync::{broadcast, oneshot};
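// Coordinator-based distributed lock manager: every node broadcasts lock
// traffic on a shared `LockMessage` channel, but only the node whose id
// matches `coordinator_id` arbitrates grants and issues fencing tokens.
//
// Wiring sketch (channel capacity and ids are illustrative, not prescribed
// by this module):
//
//     let (tx, _rx) = broadcast::channel(1024);
//     let dlm = DistributedLockManager::new(node_id, coordinator_id, tx.subscribe(), tx.clone());
//     tokio::spawn(Arc::clone(&dlm).run_receiver_loop());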
/// Errors surfaced to callers of [`DistributedLockManager::acquire`].
#[derive(Debug)]
pub enum DlmError {
    /// The coordinator's reply did not arrive within the caller's timeout.
    Timeout,
    /// The coordinator answered, but the lock is already held under an
    /// unexpired lease.
    Denied,
    /// The outbound broadcast channel has no active receivers.
    NetworkError,
}
pub struct DistributedLockManager {
    pub node_id: u32,
    pub coordinator_id: u32,
    /// Inbound messages; wrapped in a `Mutex` so the receiver loop can run
    /// through `&self` while still getting `&mut` access to the receiver.
    lock_in_rx: tokio::sync::Mutex<broadcast::Receiver<LockMessage>>,
    lock_out_tx: broadcast::Sender<LockMessage>,
    /// Requests awaiting an `AcquireAck`, keyed by request id.
    pending_reqs: DashMap<u64, oneshot::Sender<(bool, u64)>>,
    /// Coordinator-only state: (table, key) -> (owner node, fencing token, lease expiry).
    granted_locks: DashMap<(String, Vec<u8>), (u32, u64, Instant)>,
    next_fencing_token: AtomicU64,
    next_req_id: AtomicU64,
}
impl DistributedLockManager {
    pub fn new(
        node_id: u32,
        coordinator_id: u32,
        lock_in_rx: broadcast::Receiver<LockMessage>,
        lock_out_tx: broadcast::Sender<LockMessage>,
    ) -> Arc<Self> {
        Arc::new(Self {
            node_id,
            coordinator_id,
            lock_in_rx: tokio::sync::Mutex::new(lock_in_rx),
            lock_out_tx,
            pending_reqs: DashMap::new(),
            granted_locks: DashMap::new(),
            next_fencing_token: AtomicU64::new(1),
            next_req_id: AtomicU64::new(1),
        })
    }
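    /// Acquire the lock on `(table, key)` for `lease_ms`, waiting at most
    /// `timeout` for the coordinator's reply. On success returns the fencing
    /// token that downstream writes should carry.
    ///
    /// Note: if the caller times out after the coordinator has already
    /// granted the lock, the grant is never observed here and the slot is
    /// only reclaimed when its lease expires.
    ///
    /// A minimal usage sketch (assumes a wired-up manager with a running
    /// receiver loop; `"orders"` and `b"order-42"` are illustrative names):
    ///
    /// ```ignore
    /// let token = dlm.acquire("orders", b"order-42", 5_000, Duration::from_secs(2)).await?;
    /// // ... do the protected work, passing `token` along with every write ...
    /// dlm.release("orders", b"order-42", token).await;
    /// ```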
    pub async fn acquire(
        &self,
        table: &str,
        key: &[u8],
        lease_ms: u64,
        timeout: Duration,
    ) -> Result<u64, DlmError> {
        let req_id = self.next_req_id.fetch_add(1, Ordering::SeqCst);
        // Register the reply slot before sending, so an ack that races the
        // send cannot be lost.
        let (tx, rx) = oneshot::channel();
        self.pending_reqs.insert(req_id, tx);
        let msg = LockMessage::Acquire {
            table: table.to_string(),
            key: key.to_vec(),
            lease_ms,
            node_id: self.node_id,
            req_id,
        };
        if self.lock_out_tx.send(msg).is_err() {
            self.pending_reqs.remove(&req_id);
            return Err(DlmError::NetworkError);
        }
        match tokio::time::timeout(timeout, rx).await {
            Ok(Ok((true, fencing_token))) => Ok(fencing_token),
            Ok(Ok((false, _))) => Err(DlmError::Denied),
            // Either the timeout elapsed or the pending entry's sender was
            // dropped before an ack arrived; stop waiting and clean up.
            _ => {
                self.pending_reqs.remove(&req_id);
                Err(DlmError::Timeout)
            }
        }
    }
    /// Best-effort release: fire-and-forget, since lease expiry on the
    /// coordinator is the backstop if this message is lost.
    pub async fn release(&self, table: &str, key: &[u8], fencing_token: u64) {
        let msg = LockMessage::Release {
            table: table.to_string(),
            key: key.to_vec(),
            fencing_token,
            node_id: self.node_id,
        };
        let _ = self.lock_out_tx.send(msg);
    }
    /// Coordinator-side grant logic: a lock is granted if the slot is free
    /// or the previous lease has expired, and every grant gets a fresh,
    /// monotonically increasing fencing token. Only ever called from the
    /// coordinator's single receiver-loop task, so the get-then-insert
    /// sequence below is not racy in practice.
    fn try_grant_lock(
        &self,
        table: String,
        key: Vec<u8>,
        node_id: u32,
        lease_ms: u64,
    ) -> (bool, u64) {
        let lock_key = (table, key);
        let now = Instant::now();
        let expiration = now + Duration::from_millis(lease_ms);
        let mut granted = false;
        let mut f_token = 0;
        if let Some(mut existing) = self.granted_locks.get_mut(&lock_key) {
            // Occupied: only take over the slot if the old lease expired.
            if now > existing.2 {
                f_token = self.next_fencing_token.fetch_add(1, Ordering::SeqCst);
                *existing = (node_id, f_token, expiration);
                granted = true;
            }
        } else {
            f_token = self.next_fencing_token.fetch_add(1, Ordering::SeqCst);
            self.granted_locks
                .insert(lock_key, (node_id, f_token, expiration));
            granted = true;
        }
        (granted, f_token)
    }
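    // Fencing sketch: a resource guarded by these tokens should reject any
    // request carrying a token lower than the highest one it has accepted
    // (`highest_seen` is hypothetical server-side state, not part of this
    // module):
    //
    //     if token < highest_seen { reject() } else { highest_seen = token; }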
    /// Drive message handling; spawn this once per node. Holds the receiver
    /// mutex for the lifetime of the loop.
    pub async fn run_receiver_loop(self: Arc<Self>) {
        let mut rx = self.lock_in_rx.lock().await;
        loop {
            let msg = match rx.recv().await {
                Ok(msg) => msg,
                // A lagged receiver has dropped old messages but is still
                // live; keep consuming instead of shutting the loop down.
                Err(broadcast::error::RecvError::Lagged(_)) => continue,
                // All senders are gone: nothing more will arrive.
                Err(broadcast::error::RecvError::Closed) => break,
            };
            match msg {
                LockMessage::Acquire {
                    table,
                    key,
                    lease_ms,
                    node_id,
                    req_id,
                } => {
                    // Only the coordinator arbitrates; other nodes ignore
                    // requests and act only on acks matching a pending id.
                    if self.node_id == self.coordinator_id {
                        let (granted, fencing_token) =
                            self.try_grant_lock(table, key, node_id, lease_ms);
                        let _ = self.lock_out_tx.send(LockMessage::AcquireAck {
                            req_id,
                            granted,
                            fencing_token,
                        });
                    }
                }
                LockMessage::AcquireAck {
                    req_id,
                    granted,
                    fencing_token,
                } => {
                    // Acks are matched by req_id alone; note that each node
                    // draws ids from its own counter, so ids can collide
                    // across nodes and misdeliver an ack.
                    if let Some((_, sender)) = self.pending_reqs.remove(&req_id) {
                        let _ = sender.send((granted, fencing_token));
                    }
                }
                LockMessage::Release {
                    table,
                    key,
                    fencing_token,
                    node_id,
                } => {
                    if self.node_id == self.coordinator_id {
                        // Only the current owner with a matching token may
                        // release; stale releases (e.g. after a lease was
                        // taken over) are ignored.
                        self.granted_locks
                            .remove_if(&(table, key), |_, (owner, token, _)| {
                                *owner == node_id && *token == fencing_token
                            });
                    }
                }
                LockMessage::Heartbeat {
                    node_id,
                    fencing_tokens,
                } => {
                    if self.node_id == self.coordinator_id {
                        // Extend the lease of every lock the heartbeating
                        // node still claims to hold.
                        let now = Instant::now();
                        let extension = Duration::from_millis(5000);
                        for mut entry in self.granted_locks.iter_mut() {
                            let (owner, token, exp) = entry.value_mut();
                            if *owner == node_id && fencing_tokens.contains(token) {
                                *exp = now + extension;
                            }
                        }
                    }
                }
            }
        }
    }
}
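
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke-test sketch, assuming `LockMessage: Clone` (already required by
    // the broadcast channel) and tokio's test macros: a node that is its own
    // coordinator should grant itself the first fencing token.
    #[tokio::test]
    async fn coordinator_grants_first_token() {
        let (tx, _rx) = broadcast::channel(64);
        let dlm = DistributedLockManager::new(1, 1, tx.subscribe(), tx.clone());
        tokio::spawn(Arc::clone(&dlm).run_receiver_loop());

        let token = dlm
            .acquire("t", b"k", 1_000, Duration::from_secs(1))
            .await
            .expect("coordinator should grant an uncontended lock");
        assert_eq!(token, 1);
        dlm.release("t", b"k", token).await;
    }
}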