use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::storage::transaction::lock::{LockManager, LockMode, LockResult, TxnId};
/// A lockable resource, identified to the lock manager by the byte key from [`Resource::key`].
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub enum Resource {
    /// The single global resource.
    Global,
    /// A collection, identified by its name.
    Collection(String),
}

impl Resource {
    /// Encodes this resource as the byte key passed to the lock manager.
    ///
    /// `Global` encodes as `b"G/"`; `Collection(name)` encodes as `b"C/"` followed by the
    /// UTF-8 bytes of `name`. The two prefixes differ, so no collection key equals the
    /// global key.
    pub fn key(&self) -> Vec<u8> {
        let (prefix, name): (&[u8; 2], &str) = match self {
            Self::Global => (b"G/", ""),
            Self::Collection(name) => (b"C/", name.as_str()),
        };
        let mut encoded = Vec::with_capacity(prefix.len() + name.len());
        encoded.extend_from_slice(prefix);
        encoded.extend_from_slice(name.as_bytes());
        encoded
    }
}
/// Reasons a [`LockerGuard::acquire`] call can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AcquireError {
    /// The lock manager reported a deadlock; carries the transaction-id cycle it returned.
    Deadlock(Vec<TxnId>),
    /// The lock manager reported that the request timed out.
    Timeout,
    /// The lock manager reported that the per-transaction lock limit was exceeded.
    LockLimitExceeded,
    /// The guard already records `resource` in mode `held`, and `held` cannot be upgraded
    /// to `requested`.
    IncompatibleEscalation {
        resource: Resource,
        held: LockMode,
        requested: LockMode,
    },
}

impl std::fmt::Display for AcquireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Timeout => f.write_str("lock acquire timed out"),
            Self::LockLimitExceeded => f.write_str("per-txn lock limit exceeded"),
            Self::Deadlock(cycle) => write!(f, "deadlock detected (cycle: {:?})", cycle),
            Self::IncompatibleEscalation {
                resource,
                held,
                requested,
            } => write!(
                f,
                "cannot escalate lock on {:?}: held={:?} requested={:?}",
                resource, held, requested
            ),
        }
    }
}

impl std::error::Error for AcquireError {}
/// Process-wide counter behind [`fresh_txn_id`]; the first id handed out is 1.
static NEXT_TXN_ID: AtomicU64 = AtomicU64::new(1);

/// Returns a transaction id not previously returned in this process (until the `u64`
/// counter wraps).
///
/// `Relaxed` ordering is enough here: `fetch_add` is atomic, so concurrent callers always
/// receive distinct values; no other memory is synchronized through this counter.
pub fn fresh_txn_id() -> TxnId {
    AtomicU64::fetch_add(&NEXT_TXN_ID, 1, Ordering::Relaxed)
}
/// One lock recorded by a [`LockerGuard`]: the resource and the mode it was requested in.
#[derive(Clone, Debug)]
struct Held {
    resource: Resource,
    mode: LockMode,
}
/// Tracks the locks one transaction takes through a shared [`LockManager`] and releases
/// them when the guard is dropped.
pub struct LockerGuard {
    /// Manager that every acquire and release goes through.
    manager: Arc<LockManager>,
    /// Transaction id under which this guard requests locks; allocated in `new`.
    txn_id: TxnId,
    /// Locks this guard has recorded, each with the mode it was recorded in.
    held: Vec<Held>,
}
impl LockerGuard {
    /// Creates a guard with no recorded locks, bound to `manager` and to a freshly
    /// allocated transaction id.
    pub fn new(manager: Arc<LockManager>) -> Self {
        Self {
            manager,
            txn_id: fresh_txn_id(),
            held: Vec::with_capacity(2),
        }
    }

    /// Acquires `resource` in `mode` for this guard's transaction.
    ///
    /// - If the guard already records `resource` in exactly `mode`, this returns `Ok(())`
    ///   without contacting the manager.
    /// - If it records a different mode, the request is only forwarded when the recorded
    ///   mode's `can_upgrade_to` allows `mode`.
    ///
    /// The guard keeps at most one entry per resource. A successful upgrade overwrites the
    /// recorded mode instead of appending a second entry.
    ///
    /// # Errors
    ///
    /// - [`AcquireError::IncompatibleEscalation`] if the recorded mode cannot be upgraded to
    ///   `mode`; the manager is not contacted in that case.
    /// - [`AcquireError::Deadlock`], [`AcquireError::Timeout`] or
    ///   [`AcquireError::LockLimitExceeded`] when the manager returns the matching result.
    pub fn acquire(&mut self, resource: Resource, mode: LockMode) -> Result<(), AcquireError> {
        // Remember where an existing entry lives so a successful upgrade updates it in place
        // rather than leaving a stale, weaker entry that later lookups would find first.
        let existing = self.held.iter().position(|h| h.resource == resource);
        if let Some(idx) = existing {
            let already = self.held[idx].mode;
            if already == mode {
                return Ok(());
            }
            if !already.can_upgrade_to(&mode) {
                return Err(AcquireError::IncompatibleEscalation {
                    resource,
                    held: already,
                    requested: mode,
                });
            }
        }
        let key = resource.key();
        match self.manager.acquire(self.txn_id, &key, mode) {
            // NOTE(review): `Waiting` is treated as success (as before); confirm the manager
            // only returns it when the lock is effectively granted to this transaction.
            LockResult::Granted | LockResult::Upgraded | LockResult::Waiting => {
                match existing {
                    Some(idx) => self.held[idx].mode = mode,
                    None => self.held.push(Held { resource, mode }),
                }
                Ok(())
            }
            LockResult::Deadlock(cycle) => Err(AcquireError::Deadlock(cycle)),
            LockResult::Timeout => Err(AcquireError::Timeout),
            LockResult::LockLimitExceeded => Err(AcquireError::LockLimitExceeded),
            // Presumably the manager's state did not change here, so an existing local entry
            // is left untouched; a resource the guard did not know about yet is recorded.
            // NOTE(review): `TxnNotFound` is treated as success (as before) — confirm intent.
            LockResult::AlreadyHeld | LockResult::TxnNotFound => {
                if existing.is_none() {
                    self.held.push(Held { resource, mode });
                }
                Ok(())
            }
        }
    }

    /// Number of distinct resources this guard currently records.
    pub fn held_count(&self) -> usize {
        self.held.len()
    }

    /// Transaction id this guard uses for all its lock requests.
    pub fn txn_id(&self) -> TxnId {
        self.txn_id
    }
}
impl Drop for LockerGuard {
    /// Releases every recorded lock, most recently recorded first, then asks the manager to
    /// release everything still registered under this transaction id.
    fn drop(&mut self) {
        // `drain(..).rev()` visits entries last-to-first and leaves `held` empty afterwards.
        for Held { resource, .. } in self.held.drain(..).rev() {
            self.manager.release(self.txn_id, &resource.key());
        }
        // NOTE(review): presumably a catch-all for anything not tracked in `held`; confirm
        // whether the per-key releases above are still needed given this call.
        self.manager.release_all(self.txn_id);
    }
}