#[cfg(feature = "std")]
use std::sync::{Mutex, MutexGuard, TryLockError};
#[cfg(feature = "std")]
use std::time::{Duration, Instant};
use torsh_core::error::TorshError;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Result type returned by the lock helpers in this module; failures are
/// reported as [`TorshError`].
pub type LockResult<T> = Result<T, TorshError>;
/// Retry policy for the polling lock helpers on `SafeLock`.
///
/// A lock attempt gives up as soon as *either* limit is reached: the total
/// time spent waiting reaches `max_wait`, or the number of failed
/// `try_lock` attempts reaches `max_retries`.
#[cfg(feature = "std")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LockTimeout {
    /// Upper bound on the total time spent waiting for the lock.
    pub max_wait: Duration,
    /// Base pause between consecutive `try_lock` attempts.
    pub retry_interval: Duration,
    /// Maximum number of failed `try_lock` attempts before giving up.
    pub max_retries: usize,
}

// `LockTimeout` and the `Duration` import only exist with the `std` feature,
// so this impl must be gated identically or non-`std` builds fail to compile.
#[cfg(feature = "std")]
impl Default for LockTimeout {
    /// 5 s total wait, 10 ms base retry interval, at most 500 attempts.
    fn default() -> Self {
        Self {
            max_wait: Duration::from_millis(5000),
            retry_interval: Duration::from_millis(10),
            max_retries: 500,
        }
    }
}
pub struct SafeLock;
impl SafeLock {
pub fn acquire_with_timeout<'a, T>(
mutex: &'a Mutex<T>,
timeout: Option<LockTimeout>,
operation_name: &str,
) -> LockResult<MutexGuard<'a, T>> {
let timeout = timeout.unwrap_or_default();
let start_time = Instant::now();
let mut attempts = 0;
loop {
match mutex.try_lock() {
Ok(guard) => return Ok(guard),
Err(TryLockError::WouldBlock) => {
attempts += 1;
let elapsed = start_time.elapsed();
if elapsed >= timeout.max_wait || attempts >= timeout.max_retries {
return Err(TorshError::BackendError(format!(
"Failed to acquire lock for '{}' after {} attempts ({:?}). Possible deadlock detected.",
operation_name, attempts, elapsed
)));
}
let backoff = timeout.retry_interval * (1 << (attempts.min(6))); let jitter = Duration::from_millis((attempts % 5) as u64); let sleep_duration = (backoff + jitter).min(Duration::from_millis(100));
std::thread::sleep(sleep_duration);
}
Err(TryLockError::Poisoned(err)) => {
return Err(TorshError::BackendError(format!(
"Lock for '{}' is poisoned: {}",
operation_name, err
)));
}
}
}
}
pub fn acquire_blocking<'a, T>(
mutex: &'a Mutex<T>,
timeout: Option<LockTimeout>,
operation_name: &str,
) -> LockResult<MutexGuard<'a, T>> {
let timeout = timeout.unwrap_or_default();
let start_time = Instant::now();
let mut attempts = 0;
loop {
match mutex.try_lock() {
Ok(guard) => return Ok(guard),
Err(TryLockError::WouldBlock) => {
attempts += 1;
let elapsed = start_time.elapsed();
if elapsed >= timeout.max_wait || attempts >= timeout.max_retries {
return Err(TorshError::BackendError(format!(
"Failed to acquire lock for '{}' after {} attempts ({:?}). Possible deadlock detected.",
operation_name, attempts, elapsed
)));
}
let sleep_duration = timeout.retry_interval.min(Duration::from_millis(50));
std::thread::sleep(sleep_duration);
}
Err(TryLockError::Poisoned(err)) => {
return Err(TorshError::BackendError(format!(
"Lock for '{}' is poisoned: {}",
operation_name, err
)));
}
}
}
}
pub fn try_acquire<'a, T>(
mutex: &'a Mutex<T>,
operation_name: &str,
) -> LockResult<Option<MutexGuard<'a, T>>> {
match mutex.try_lock() {
Ok(guard) => Ok(Some(guard)),
Err(TryLockError::WouldBlock) => Ok(None),
Err(TryLockError::Poisoned(err)) => Err(TorshError::BackendError(format!(
"Lock for '{}' is poisoned: {}",
operation_name, err
))),
}
}
}
/// Global acquisition rank for the crate's lock classes.
///
/// Comparison follows the explicit discriminants below: a lower-ranked lock
/// is meant to be taken before a higher-ranked one.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum LockOrder {
    /// Rank 0 — lowest rank.
    GlobalConfig = 0,
    /// Rank 1.
    MemoryManager = 1,
    /// Rank 2.
    TransferCache = 2,
    /// Rank 3.
    Statistics = 3,
    /// Rank 4.
    TempBuffers = 4,
    /// Rank 5 — highest rank.
    DebugTracking = 5,
}
/// Records which [`LockOrder`] classes are currently held and rejects
/// acquisitions that would go against the global ordering.
pub struct LockOrderValidator {
    /// Held lock classes, kept sorted in ascending order.
    current_locks: Vec<LockOrder>,
}

impl LockOrderValidator {
    /// Create a validator that holds no locks.
    pub fn new() -> Self {
        LockOrderValidator {
            current_locks: Vec::new(),
        }
    }

    /// Returns `true` when no held lock ranks strictly above `order`.
    ///
    /// Taking a class equal to one already held is allowed.
    pub fn can_acquire(&self, order: LockOrder) -> bool {
        !self.current_locks.iter().any(|&held| held > order)
    }

    /// Record that a lock of class `order` has been taken.
    ///
    /// # Errors
    ///
    /// Returns `TorshError::BackendError` (and leaves the held set untouched)
    /// if some held lock ranks above `order`.
    pub fn acquired(&mut self, order: LockOrder) -> Result<(), TorshError> {
        if self.can_acquire(order) {
            self.current_locks.push(order);
            self.current_locks.sort();
            Ok(())
        } else {
            Err(TorshError::BackendError(format!(
                "Lock order violation: trying to acquire {:?} while holding {:?}",
                order, self.current_locks
            )))
        }
    }

    /// Forget one held entry of class `order`; a no-op if none is held.
    pub fn released(&mut self, order: LockOrder) {
        let slot = self.current_locks.iter().position(|held| *held == order);
        if let Some(index) = slot {
            self.current_locks.remove(index);
        }
    }

    /// The currently held lock classes, in ascending order.
    pub fn current_locks(&self) -> &[LockOrder] {
        self.current_locks.as_slice()
    }
}

impl Default for LockOrderValidator {
    fn default() -> Self {
        LockOrderValidator::new()
    }
}
/// A [`MutexGuard`] tagged with the [`LockOrder`] class of its mutex.
///
/// On drop, the guard first calls `released(order)` on the optional
/// [`LockOrderValidator`] and then (when its fields are dropped) unlocks the
/// mutex. Dereferences to the protected value.
///
/// `new` does not call `LockOrderValidator::acquired`; callers are expected to
/// register `order` with the validator themselves before constructing the
/// guard. NOTE(review): if they do not, dropping the guard may remove another
/// held entry of the same class.
#[must_use = "if unused the Mutex will immediately unlock"]
pub struct ScopedLockGuard<'a, T> {
    _guard: MutexGuard<'a, T>,
    order: LockOrder,
    validator: Option<&'a mut LockOrderValidator>,
}

impl<'a, T> ScopedLockGuard<'a, T> {
    /// Wrap `guard`, remembering `order` so that `validator` (if any) is told
    /// about the release when this value is dropped.
    pub fn new(
        guard: MutexGuard<'a, T>,
        order: LockOrder,
        validator: Option<&'a mut LockOrderValidator>,
    ) -> Self {
        Self {
            _guard: guard,
            order,
            validator,
        }
    }
}

impl<T> Drop for ScopedLockGuard<'_, T> {
    fn drop(&mut self) {
        // Runs before the fields are dropped, i.e. before the mutex unlocks.
        if let Some(validator) = self.validator.as_deref_mut() {
            validator.released(self.order);
        }
    }
}

impl<T> std::ops::Deref for ScopedLockGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        &self._guard
    }
}

impl<T> std::ops::DerefMut for ScopedLockGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self._guard
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{mpsc, Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_lock_timeout_configuration() {
        let timeout = LockTimeout::default();
        assert_eq!(timeout.max_wait, Duration::from_millis(5000));
        assert_eq!(timeout.retry_interval, Duration::from_millis(10));
        assert_eq!(timeout.max_retries, 500);
    }

    #[test]
    fn test_safe_lock_try_acquire() {
        let mutex = Mutex::new(42);
        let result = SafeLock::try_acquire(&mutex, "test");
        assert!(result.is_ok());
        let guard = result.unwrap();
        assert!(guard.is_some());
        assert_eq!(*guard.unwrap(), 42);
    }

    #[test]
    fn test_safe_lock_try_acquire_when_held() {
        let mutex = Mutex::new(7);
        let _held = mutex.lock().expect("fresh mutex cannot be poisoned");
        let result = SafeLock::try_acquire(&mutex, "held");
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_safe_lock_uncontended_acquire() {
        let mutex = Mutex::new(1);
        let guard = SafeLock::acquire_with_timeout(&mutex, None, "free")
            .expect("uncontended lock should be acquired");
        assert_eq!(*guard, 1);
    }

    #[test]
    fn test_safe_lock_contention() {
        let mutex = Arc::new(Mutex::new(0));
        let holder_mutex = Arc::clone(&mutex);
        let (locked_tx, locked_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let handle = thread::spawn(move || {
            let _guard = holder_mutex.lock().expect("lock should not be poisoned");
            locked_tx.send(()).expect("main thread should be waiting");
            // Hold the lock until the main thread is done; an Err here just
            // means the main thread went away, which also releases us.
            let _ = release_rx.recv();
        });
        // Handshake instead of a sleep: the holder definitely owns the lock now.
        locked_rx.recv().expect("holder thread should acquire the lock");

        let timeout = LockTimeout {
            max_wait: Duration::from_millis(50),
            retry_interval: Duration::from_millis(5),
            max_retries: 10,
        };
        let result = SafeLock::acquire_blocking(&mutex, Some(timeout), "test_contention");
        assert!(result.is_err());

        let timeout = LockTimeout {
            max_wait: Duration::from_millis(30),
            retry_interval: Duration::from_millis(5),
            max_retries: 10,
        };
        let result = SafeLock::acquire_with_timeout(&mutex, Some(timeout), "test_contention");
        assert!(result.is_err());

        release_tx.send(()).expect("holder thread should still be waiting");
        handle.join().unwrap();
    }

    #[test]
    fn test_safe_lock_reports_poisoned_mutex() {
        let mutex = Arc::new(Mutex::new(0));
        let poisoner_mutex = Arc::clone(&mutex);
        let poisoner = thread::spawn(move || {
            let _guard = poisoner_mutex.lock().expect("lock should not be poisoned yet");
            panic!("intentional panic to poison the mutex");
        });
        assert!(poisoner.join().is_err());
        assert!(mutex.is_poisoned());
        assert!(SafeLock::try_acquire(&mutex, "poisoned").is_err());
        assert!(SafeLock::acquire_blocking(&mutex, None, "poisoned").is_err());
        assert!(SafeLock::acquire_with_timeout(&mutex, None, "poisoned").is_err());
    }

    #[test]
    fn test_lock_order_validator() {
        let mut validator = LockOrderValidator::new();
        assert!(validator.can_acquire(LockOrder::GlobalConfig));
        validator.acquired(LockOrder::GlobalConfig).unwrap();
        assert!(validator.can_acquire(LockOrder::MemoryManager));
        validator.acquired(LockOrder::MemoryManager).unwrap();
        assert!(validator.can_acquire(LockOrder::Statistics));
        validator.acquired(LockOrder::Statistics).unwrap();
        assert!(!validator.can_acquire(LockOrder::TransferCache));
        assert!(validator.acquired(LockOrder::TransferCache).is_err());
        validator.released(LockOrder::Statistics);
        validator.released(LockOrder::MemoryManager);
        validator.released(LockOrder::GlobalConfig);
        assert_eq!(validator.current_locks().len(), 0);
    }

    #[test]
    fn test_scoped_lock_guard_releases_on_drop() {
        let mutex = Mutex::new(5);
        let mut validator = LockOrderValidator::new();
        validator.acquired(LockOrder::Statistics).unwrap();
        {
            let mut guard = ScopedLockGuard::new(
                mutex.lock().expect("fresh mutex cannot be poisoned"),
                LockOrder::Statistics,
                Some(&mut validator),
            );
            assert_eq!(*guard, 5);
            *guard += 1;
        }
        assert!(validator.current_locks().is_empty());
        assert_eq!(*mutex.lock().unwrap(), 6);
    }

    #[test]
    fn test_lock_order_enum() {
        assert!(LockOrder::GlobalConfig < LockOrder::MemoryManager);
        assert!(LockOrder::MemoryManager < LockOrder::TransferCache);
        assert!(LockOrder::TransferCache < LockOrder::Statistics);
        assert!(LockOrder::Statistics < LockOrder::TempBuffers);
        assert!(LockOrder::TempBuffers < LockOrder::DebugTracking);
    }
}