use crate::backend::strategy::L2BackendStrategy;
use crate::error::Result;
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::OwnedRwLockWriteGuard;
use tokio::sync::RwLock;
use tracing::{debug, error, warn};
/// Local record of a lock acquired through the L2 backend.
#[derive(Clone, Debug)]
struct LockInfo {
    /// Token returned by the backend when the lock was granted.
    lock_value: String,
    /// Local instant at (or after) which the lock is considered lapsed.
    expire_at: std::time::Instant,
    /// TTL in seconds requested at acquisition time.
    ttl: u64,
}

impl LockInfo {
    /// Builds a record whose local deadline is `ttl` seconds from now.
    fn new(lock_value: String, ttl: u64) -> Self {
        let lifetime = Duration::from_secs(ttl);
        let expire_at = std::time::Instant::now() + lifetime;
        LockInfo {
            lock_value,
            expire_at,
            ttl,
        }
    }

    /// Returns `true` once the local deadline has been reached (the boundary counts as expired).
    fn is_expired(&self) -> bool {
        self.expire_at <= std::time::Instant::now()
    }
}
#[derive(Clone)]
pub struct ShardedLockManager {
num_shards: usize,
shards: Vec<Arc<DashMap<String, LockInfo>>>,
local_cache: Vec<Arc<DashMap<String, (String, std::time::Instant)>>>,
l2_backend: Arc<dyn L2BackendStrategy>,
lock_counter: Arc<AtomicU64>,
unlock_counter: Arc<AtomicU64>,
current_shard: Arc<AtomicUsize>,
}
impl ShardedLockManager {
pub fn new(l2_backend: Arc<dyn L2BackendStrategy>, num_shards: usize) -> Self {
let shards: Vec<_> = (0..num_shards)
.map(|_| Arc::new(DashMap::new()))
.collect();
let local_cache: Vec<_> = (0..num_shards)
.map(|_| Arc::new(DashMap::new()))
.collect();
Self {
num_shards,
shards,
local_cache,
l2_backend,
lock_counter: Arc::new(AtomicU64::new(0)),
unlock_counter: Arc::new(AtomicU64::new(0)),
current_shard: Arc::new(AtomicUsize::new(0)),
}
}
pub fn default_num_shards() -> usize {
std::thread::available_parallelism()
.map(|p| p.get() * 2)
.unwrap_or(16)
}
fn shard_for_key(&self, key: &str) -> usize {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
std::hash::Hash::hash(key, &mut hasher);
let hash = std::hash::Hasher::finish(&hasher);
(hash as usize) % self.num_shards
}
pub async fn try_lock(
&self,
key: &str,
ttl: u64,
retry_interval: u64,
max_retries: u32,
) -> Result<Option<String>> {
let shard_idx = self.shard_for_key(key);
let shard = &self.shards[shard_idx];
let cache = &self.local_cache[shard_idx];
let lock_value = format!("{}-{}", uuid::Uuid::new_v4(), std::process::id());
if let Some((cached_value, expire_at)) = cache.get(key) {
if std::time::Instant::now() < *expire_at {
return Ok(Some(cached_value.clone()));
}
}
for attempt in 0..=max_retries {
match self.l2_backend.lock(key, ttl).await {
Ok(Some(value)) => {
self.lock_counter.fetch_add(1, Ordering::Relaxed);
cache.insert(
key.to_string(),
(value.clone(), std::time::Instant::now() + Duration::from_secs(ttl / 2)),
);
shard.insert(key.to_string(), LockInfo::new(value.clone(), ttl));
debug!("Acquired lock for key {} on attempt {}", key, attempt + 1);
return Ok(Some(value));
}
Ok(None) => {
if attempt < max_retries {
tokio::time::sleep(Duration::from_millis(retry_interval)).await;
}
}
Err(e) => {
error!("Failed to acquire lock for key {}: {}", key, e);
return Err(e);
}
}
}
debug!("Failed to acquire lock for key {} after {} attempts", key, max_retries + 1);
Ok(None)
}
pub async fn lock_with_timeout(
&self,
key: &str,
ttl: u64,
timeout_ms: u64,
) -> Result<Option<String>> {
let start = std::time::Instant::now();
let timeout_duration = Duration::from_millis(timeout_ms);
let retry_interval = 100;
let max_retries = (timeout_ms / retry_interval) as u32;
let shard_idx = self.shard_for_key(key);
let shard = &self.shards[shard_idx];
let cache = &self.local_cache[shard_idx];
let lock_value = format!("{}-{}", uuid::Uuid::new_v4(), std::process::id());
if let Some((cached_value, expire_at)) = cache.get(key) {
if std::time::Instant::now() < *expire_at {
return Ok(Some(cached_value.clone()));
}
}
loop {
if start.elapsed() >= timeout_duration {
debug!("Lock acquisition timed out for key {} after {}ms", key, timeout_ms);
return Ok(None);
}
match self.l2_backend.lock(key, ttl).await {
Ok(Some(value)) => {
self.lock_counter.fetch_add(1, Ordering::Relaxed);
cache.insert(
key.to_string(),
(value.clone(), std::time::Instant::now() + Duration::from_secs(ttl / 2)),
);
shard.insert(key.to_string(), LockInfo::new(value.clone(), ttl));
debug!("Acquired lock for key {}", key);
return Ok(Some(value));
}
Ok(None) => {
let remaining = timeout_duration.saturating_sub(start.elapsed());
let sleep_duration = Duration::from_millis(retry_interval).min(remaining);
if sleep_duration.is_zero() {
debug!("Failed to acquire lock for key {} within timeout", key);
return Ok(None);
}
tokio::time::sleep(sleep_duration).await;
}
Err(e) => {
error!("Failed to acquire lock for key {}: {}", key, e);
return Err(e);
}
}
}
}
pub async fn unlock(&self, key: &str, lock_value: &str) -> Result<bool> {
let shard_idx = self.shard_for_key(key);
let shard = &self.shards[shard_idx];
let cache = &self.local_cache[shard_idx];
if let Some(lock_info) = shard.get(key) {
if lock_info.lock_value != lock_value {
debug!("Lock value mismatch for key {}", key);
return Ok(false);
}
}
match self.l2_backend.unlock(key, lock_value).await {
Ok(released) => {
if released {
self.unlock_counter.fetch_add(1, Ordering::Relaxed);
shard.remove(key);
cache.remove(key);
debug!("Released lock for key {}", key);
} else {
debug!("Failed to release lock for key {}: lock not found or value mismatch", key);
}
Ok(released)
}
Err(e) => {
error!("Error releasing lock for key {}: {}", key, e);
Err(e)
}
}
}
pub async fn is_locked(&self, key: &str) -> bool {
let shard_idx = self.shard_for_key(key);
let shard = &self.shards[shard_idx];
if let Some(lock_info) = shard.get(key) {
if lock_info.is_expired() {
shard.remove(key);
return false;
}
return true;
}
false
}
pub async fn extend_lock(&self, key: &str, lock_value: &str, ttl: u64) -> Result<bool> {
let shard_idx = self.shard_for_key(key);
let shard = &self.shards[shard_idx];
let cache = &self.local_cache[shard_idx];
if let Some(lock_info) = shard.get(key) {
if lock_info.lock_value != lock_value {
return Ok(false);
}
if lock_info.is_expired() {
shard.remove(key);
cache.remove(key);
return Ok(false);
}
match self.l2_backend.lock(key, ttl).await {
Ok(Some(new_value)) => {
if new_value == lock_value {
let new_info = LockInfo::new(lock_value.to_string(), ttl);
shard.insert(key.to_string(), new_info);
cache.insert(
key.to_string(),
(lock_value.to_string(), std::time::Instant::now() + Duration::from_secs(ttl / 2)),
);
debug!("Extended lock for key {} with TTL {}s", key, ttl);
return Ok(true);
} else {
shard.remove(key);
cache.remove(key);
return Ok(false);
}
}
Ok(None) => {
shard.remove(key);
cache.remove(key);
return Ok(false);
}
Err(e) => return Err(e),
}
}
Ok(false)
}
pub fn stats(&self) -> LockStats {
let mut total_locks = 0;
let mut expired_locks = 0;
let mut local_cache_size = 0;
for shard in &self.shards {
total_locks += shard.len();
for entry in shard.iter() {
if entry.value().is_expired() {
expired_locks += 1;
}
}
}
for cache in &self.local_cache {
local_cache_size += cache.len();
}
LockStats {
total_locks,
expired_locks,
local_cache_size,
num_shards: self.num_shards,
locks_acquired: self.lock_counter.load(Ordering::Relaxed),
locks_released: self.unlock_counter.load(Ordering::Relaxed),
}
}
pub async fn cleanup(&self) {
let now = std::time::Instant::now();
for shard in &self.shards {
let expired_keys: Vec<_> = shard
.iter()
.filter(|e| e.value().expire_at <= now)
.map(|e| e.key().clone())
.collect();
for key in expired_keys {
shard.remove(&key);
}
}
for cache in &self.local_cache {
let expired_keys: Vec<_> = cache
.iter()
.filter(|e| e.value().1 <= now)
.map(|e| e.key().clone())
.collect();
for key in expired_keys {
cache.remove(&key);
}
}
debug!(
"Cleaned up expired locks: {} total locks across {} shards",
self.shards.iter().map(|s| s.len()).sum::<usize>(),
self.num_shards
);
}
}
/// Point-in-time snapshot of a lock manager's local bookkeeping and counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LockStats {
    /// Local lock records currently stored, including ones whose deadline has passed.
    pub total_locks: usize,
    /// Subset of `total_locks` whose local deadline has passed.
    pub expired_locks: usize,
    /// Entries currently stored in the local cache.
    pub local_cache_size: usize,
    /// Number of shards the manager uses.
    pub num_shards: usize,
    /// Successful acquisitions counted by the manager.
    pub locks_acquired: u64,
    /// Successful releases counted by the manager.
    pub locks_released: u64,
}
/// Owns a lock acquired through a [`ShardedLockManager`] and attempts to release it when dropped.
#[must_use = "dropping a LockGuard immediately releases the lock"]
pub struct LockGuard {
    manager: ShardedLockManager,
    key: String,
    lock_value: String,
}

// `ShardedLockManager` wraps a `dyn L2BackendStrategy`, which is not required to be `Debug`,
// so `#[derive(Debug)]` cannot be used; the manager is omitted from the output.
impl std::fmt::Debug for LockGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LockGuard")
            .field("key", &self.key)
            .field("lock_value", &self.lock_value)
            .finish_non_exhaustive()
    }
}
impl Drop for LockGuard {
    /// Releases the lock in the background on the current Tokio runtime.
    ///
    /// `Drop` cannot await, so the unlock runs as a detached task. If no runtime is
    /// available, the release is skipped and a warning is logged. The backend lock
    /// presumably lapses once its TTL expires (confirm against the backend).
    fn drop(&mut self) {
        // The spawned future must be `'static`, so it cannot borrow `self`; move owned
        // copies in instead. Taking the strings is safe because `self` is being destroyed.
        let manager = self.manager.clone();
        let key = std::mem::take(&mut self.key);
        let lock_value = std::mem::take(&mut self.lock_value);
        // `tokio::spawn` panics outside a runtime; panicking in `drop` can abort the process.
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => {
                handle.spawn(async move {
                    if let Err(e) = manager.unlock(&key, &lock_value).await {
                        warn!("Failed to auto-release lock for key {}: {}", key, e);
                    }
                });
            }
            Err(_) => {
                warn!(
                    "No Tokio runtime available; lock for key {} was not auto-released",
                    key
                );
            }
        }
    }
}
impl LockGuard {
pub fn new(manager: ShardedLockManager, key: String, lock_value: String) -> Self {
Self {
manager,
key,
lock_value,
}
}
pub fn lock_value(&self) -> &str {
&self.lock_value
}
}