use crate::core::types::TrackingResult;
use std::sync::{Mutex, RwLock};
/// Mutex locking that reports failures as [`TrackingResult`] errors instead of
/// exposing `std::sync::PoisonError` / `std::sync::TryLockError` to callers.
///
/// This module implements it for `std::sync::Mutex<T>`.
pub trait SafeLock<T> {
/// Blocks until the lock is acquired and returns its guard.
///
/// # Errors
/// The `Mutex` implementation in this module returns
/// `TrackingError::LockError` when the mutex is poisoned.
fn safe_lock(&self) -> TrackingResult<std::sync::MutexGuard<'_, T>>;
/// Attempts to acquire the lock without blocking.
///
/// Returns `Ok(Some(guard))` on success and `Ok(None)` when the lock is
/// currently held elsewhere; contention is not treated as an error.
///
/// # Errors
/// The `Mutex` implementation in this module returns
/// `TrackingError::LockError` when the mutex is poisoned.
fn try_safe_lock(&self) -> TrackingResult<Option<std::sync::MutexGuard<'_, T>>>;
}
impl<T> SafeLock<T> for Mutex<T> {
    /// Blocks on `Mutex::lock`.
    ///
    /// The only failure `Mutex::lock` reports is poisoning; it is converted into
    /// a `TrackingError::LockError` whose message begins with
    /// "Failed to acquire mutex lock: ".
    fn safe_lock(&self) -> TrackingResult<std::sync::MutexGuard<'_, T>> {
        match self.lock() {
            Ok(guard) => Ok(guard),
            Err(e) => Err(crate::core::types::TrackingError::LockError(format!(
                "Failed to acquire mutex lock: {e}",
            ))),
        }
    }

    /// Non-blocking acquisition built on `Mutex::try_lock`.
    ///
    /// `WouldBlock` (the lock is held elsewhere) maps to `Ok(None)`. Only a
    /// poisoned mutex is an error, reported as a `TrackingError::LockError`
    /// prefixed with "Mutex poisoned: ".
    fn try_safe_lock(&self) -> TrackingResult<Option<std::sync::MutexGuard<'_, T>>> {
        self.try_lock().map(Some).or_else(|failure| match failure {
            std::sync::TryLockError::WouldBlock => Ok(None),
            std::sync::TryLockError::Poisoned(e) => Err(
                crate::core::types::TrackingError::LockError(format!("Mutex poisoned: {e}")),
            ),
        })
    }
}
/// Reader/writer locking that reports failures as [`TrackingResult`] errors
/// instead of exposing `std::sync::PoisonError` to callers.
///
/// This module implements it for `std::sync::RwLock<T>`.
pub trait SafeRwLock<T> {
/// Blocks until shared (read) access is acquired and returns the read guard.
///
/// # Errors
/// The `RwLock` implementation in this module returns
/// `TrackingError::LockError` when the lock is poisoned.
fn safe_read(&self) -> TrackingResult<std::sync::RwLockReadGuard<'_, T>>;
/// Blocks until exclusive (write) access is acquired and returns the write guard.
///
/// # Errors
/// The `RwLock` implementation in this module returns
/// `TrackingError::LockError` when the lock is poisoned.
fn safe_write(&self) -> TrackingResult<std::sync::RwLockWriteGuard<'_, T>>;
}
impl<T> SafeRwLock<T> for RwLock<T> {
    /// Blocks on `RwLock::read`.
    ///
    /// A poisoned lock is reported as a `TrackingError::LockError` whose
    /// message begins with "Failed to acquire read lock: ".
    fn safe_read(&self) -> TrackingResult<std::sync::RwLockReadGuard<'_, T>> {
        match self.read() {
            Ok(reader) => Ok(reader),
            Err(e) => Err(crate::core::types::TrackingError::LockError(format!(
                "Failed to acquire read lock: {e}",
            ))),
        }
    }

    /// Blocks on `RwLock::write`.
    ///
    /// A poisoned lock is reported as a `TrackingError::LockError` whose
    /// message begins with "Failed to acquire write lock: ".
    fn safe_write(&self) -> TrackingResult<std::sync::RwLockWriteGuard<'_, T>> {
        match self.write() {
            Ok(writer) => Ok(writer),
            Err(e) => Err(crate::core::types::TrackingError::LockError(format!(
                "Failed to acquire write lock: {e}",
            ))),
        }
    }
}
/// Acquires a lock through its `safe_lock()` method and unwraps the guard with `?`.
///
/// `safe_lock!(m)` expands to `m.safe_lock()?`, which places two requirements
/// on the call site:
/// - a `safe_lock` method must be callable on the argument. For `Mutex`, this
///   means the `SafeLock` trait must be imported where the macro is used,
///   because the expansion does not name the trait.
/// - the enclosing function or closure must return a `Result` whose error type
///   the lock error converts into via `?`.
#[macro_export]
macro_rules! safe_lock {
    ($lock:expr) => {
        $lock.safe_lock()?
    };
}
#[cfg(test)]
mod tests {
    // Unit tests for the SafeLock / SafeRwLock helpers and the `safe_lock!` macro.
    // Cross-thread tests synchronize explicitly (channels, barriers) instead of
    // relying on sleeps, so they are deterministic on loaded machines.
    use super::*;
    use std::sync::{mpsc, Arc, Barrier, Mutex, RwLock};
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_safe_mutex_lock() {
        let mutex = Mutex::new(42);
        let guard = mutex.safe_lock().unwrap();
        assert_eq!(*guard, 42);
    }

    #[test]
    fn test_safe_mutex_try_lock() {
        let mutex = Mutex::new(42);
        let guard = mutex
            .try_safe_lock()
            .unwrap()
            .expect("an uncontended mutex must be acquired by try_safe_lock");
        assert_eq!(*guard, 42);
    }

    #[test]
    fn test_safe_mutex_try_lock_would_block() {
        let mutex = Arc::new(Mutex::new(42));
        let mutex_clone = Arc::clone(&mutex);
        // The main thread keeps this guard alive until the end of the test, so
        // the spawned thread's try-lock is guaranteed to see contention.
        let _guard = mutex.safe_lock().unwrap();
        let handle = thread::spawn(move || {
            let result = mutex_clone.try_safe_lock().unwrap();
            result.is_none()
        });
        assert!(handle.join().unwrap());
    }

    #[test]
    fn test_safe_rwlock_read() {
        let rwlock = RwLock::new(42);
        let guard = rwlock.safe_read().unwrap();
        assert_eq!(*guard, 42);
    }

    #[test]
    fn test_safe_rwlock_write() {
        let rwlock = RwLock::new(42);
        let mut guard = rwlock.safe_write().unwrap();
        *guard = 100;
        drop(guard);
        let guard = rwlock.safe_read().unwrap();
        assert_eq!(*guard, 100);
    }

    #[test]
    fn test_safe_rwlock_multiple_readers() {
        const READERS: usize = 5;
        let rwlock = Arc::new(RwLock::new(42));
        // Each reader waits at the barrier while still holding its read guard, so
        // the barrier only releases once all READERS guards are held at the same
        // time, which proves that shared access really is shared.
        let barrier = Arc::new(Barrier::new(READERS));
        let handles: Vec<_> = (0..READERS)
            .map(|_| {
                let rwlock = Arc::clone(&rwlock);
                let barrier = Arc::clone(&barrier);
                thread::spawn(move || {
                    let guard = rwlock.safe_read().unwrap();
                    assert_eq!(*guard, 42);
                    barrier.wait();
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn test_safe_rwlock_writer_exclusivity() {
        let rwlock = Arc::new(RwLock::new(0));
        let rwlock_clone = Arc::clone(&rwlock);
        let (locked_tx, locked_rx) = mpsc::channel();
        let handle = thread::spawn(move || {
            let mut guard = rwlock_clone.safe_write().unwrap();
            *guard = 42;
            // Signal only after the write lock is held. This removes the race in
            // a sleep-based handshake, where the reader could run first and see 0.
            locked_tx.send(()).unwrap();
            // Widen the window in which the reader is blocked; the intermediate 42
            // must never be observable.
            thread::sleep(Duration::from_millis(50));
            *guard = 100;
        });
        locked_rx.recv().unwrap();
        // The writer holds the lock now, so this blocks until it finishes.
        let guard = rwlock.safe_read().unwrap();
        assert_eq!(*guard, 100);
        handle.join().unwrap();
    }

    #[test]
    fn test_concurrent_safe_operations() {
        let mutex = Arc::new(Mutex::new(0));
        let mut handles = vec![];
        for _ in 0..10 {
            let mutex_clone = Arc::clone(&mutex);
            let handle = thread::spawn(move || {
                let mut guard = mutex_clone.safe_lock().unwrap();
                *guard += 1;
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let guard = mutex.safe_lock().unwrap();
        assert_eq!(*guard, 10);
    }

    #[test]
    fn test_safe_lock_macro() {
        let mutex = Mutex::new(42);
        let result: Result<(), crate::core::types::TrackingError> = (|| {
            let guard = crate::safe_lock!(mutex);
            assert_eq!(*guard, 42);
            Ok(())
        })();
        assert!(result.is_ok());
    }

    #[test]
    fn test_error_handling() {
        let mutex = Mutex::new(42);
        let result = mutex.safe_lock();
        assert!(result.is_ok());
        let try_result = mutex.try_safe_lock();
        assert!(try_result.is_ok());
    }

    #[test]
    fn test_poisoned_mutex_reports_lock_error() {
        let mutex = Arc::new(Mutex::new(0));
        let poisoner = Arc::clone(&mutex);
        // Panicking while a guard is alive poisons the mutex.
        let outcome = thread::spawn(move || {
            let _guard = poisoner.lock().unwrap();
            panic!("intentional panic to poison the mutex");
        })
        .join();
        assert!(outcome.is_err());
        assert!(mutex.is_poisoned());
        assert!(matches!(
            mutex.safe_lock(),
            Err(crate::core::types::TrackingError::LockError(_))
        ));
        assert!(matches!(
            mutex.try_safe_lock(),
            Err(crate::core::types::TrackingError::LockError(_))
        ));
    }

    #[test]
    fn test_poisoned_rwlock_reports_lock_error() {
        let rwlock = Arc::new(RwLock::new(0));
        let poisoner = Arc::clone(&rwlock);
        // Only a panic under the *write* lock poisons an RwLock.
        let outcome = thread::spawn(move || {
            let _guard = poisoner.write().unwrap();
            panic!("intentional panic to poison the rwlock");
        })
        .join();
        assert!(outcome.is_err());
        assert!(rwlock.is_poisoned());
        assert!(matches!(
            rwlock.safe_read(),
            Err(crate::core::types::TrackingError::LockError(_))
        ));
        assert!(matches!(
            rwlock.safe_write(),
            Err(crate::core::types::TrackingError::LockError(_))
        ));
    }
}