use crate::{PRes, PersyError};
use std::{
collections::{hash_map::Entry, HashMap},
sync::{Arc, Condvar, Mutex, MutexGuard},
time::Duration,
};
/// Book-keeping for a single key held by `RwLockManager`.
///
/// `write` marks an exclusive holder, `read_count` counts shared holders, and
/// threads waiting for the key block on `cond` until the entry is released.
struct RwLockVar {
    write: bool,
    read_count: u32,
    cond: Arc<Condvar>,
}
impl RwLockVar {
    /// Builds an entry in the given state with a fresh condition variable.
    fn with_state(write: bool, read_count: u32) -> Self {
        Self {
            write,
            read_count,
            cond: Arc::new(Condvar::new()),
        }
    }
    /// Entry for a key that was just taken exclusively (no readers).
    fn new_write() -> RwLockVar {
        Self::with_state(true, 0)
    }
    /// Entry for a key that was just taken by its first reader.
    fn new_read() -> RwLockVar {
        Self::with_state(false, 1)
    }
    /// Registers one more reader on this entry.
    fn inc_read(&mut self) {
        self.read_count += 1;
    }
    /// Removes one reader; returns `true` once no readers remain.
    fn dec_read(&mut self) -> bool {
        let remaining = self.read_count - 1;
        self.read_count = remaining;
        remaining == 0
    }
}
/// Table of per-key read/write locks, all guarded by one mutex.
pub struct RwLockManager<T>
where
    T: std::cmp::Eq + std::hash::Hash + Clone,
{
    locks: Mutex<HashMap<T, RwLockVar>>,
}
impl<T> Default for RwLockManager<T>
where
    T: std::cmp::Eq + std::hash::Hash + Clone,
{
    /// Creates a manager in which no key is locked.
    fn default() -> Self {
        let locks = Mutex::new(HashMap::<T, RwLockVar>::new());
        Self { locks }
    }
}
impl<T> RwLockManager<T>
where
T: std::cmp::Eq,
T: std::hash::Hash,
T: Clone,
{
pub fn lock_all_write(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
let mut locked = Vec::new();
for single in to_lock {
let mut lock_manager = self.locks.lock()?;
loop {
let cond = match lock_manager.entry(single.clone()) {
Entry::Occupied(o) => o.get().cond.clone(),
Entry::Vacant(v) => {
let lock = RwLockVar::new_write();
v.insert(lock);
locked.push(single.clone());
break;
}
};
match cond.wait_timeout(lock_manager, timeout) {
Ok((guard, timedout)) => {
lock_manager = guard;
if timedout.timed_out() {
RwLockManager::unlock_all_write_with_guard(&mut lock_manager, &locked);
return Err(PersyError::TransactionTimeout);
}
}
Err(x) => {
self.unlock_all_write(&locked)?;
return Err(PersyError::from(x));
}
}
}
}
Ok(())
}
pub fn lock_all_read(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
let mut locked = Vec::new();
for single in to_lock {
let mut lock_manager = self.locks.lock()?;
loop {
let cond;
match lock_manager.entry(single.clone()) {
Entry::Occupied(mut o) => {
if o.get().write {
cond = o.get().cond.clone();
} else {
o.get_mut().inc_read();
locked.push(single.clone());
break;
}
}
Entry::Vacant(v) => {
v.insert(RwLockVar::new_read());
locked.push(single.clone());
break;
}
};
match cond.wait_timeout(lock_manager, timeout) {
Ok((guard, timedout)) => {
lock_manager = guard;
if timedout.timed_out() {
RwLockManager::unlock_all_read_with_guard(&mut lock_manager, &locked);
return Err(PersyError::TransactionTimeout);
}
}
Err(x) => {
self.unlock_all_read(&locked)?;
return Err(PersyError::from(x));
}
}
}
}
Ok(())
}
fn unlock_all_read_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
for single in to_unlock {
if let Entry::Occupied(mut lock) = lock_manager.entry(single.clone()) {
if lock.get_mut().dec_read() {
let cond = lock.get().cond.clone();
lock.remove();
cond.notify_all();
}
}
}
}
pub fn unlock_all_read(&self, to_unlock: &[T]) -> PRes<()> {
let mut lock_manager = self.locks.lock()?;
RwLockManager::unlock_all_read_with_guard(&mut lock_manager, to_unlock);
Ok(())
}
fn unlock_all_write_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
for single in to_unlock {
if let Some(lock) = lock_manager.remove(single) {
lock.cond.notify_all();
}
}
}
pub fn unlock_all_write(&self, to_unlock: &[T]) -> PRes<()> {
let mut lock_manager = self.locks.lock()?;
RwLockManager::unlock_all_write_with_guard(&mut lock_manager, to_unlock);
Ok(())
}
}
/// Table of per-key exclusive locks, all guarded by one mutex.
///
/// Each held key maps to the condition variable that its waiters block on.
pub struct LockManager<T>
where
    T: std::cmp::Eq + std::hash::Hash + Clone,
{
    locks: Mutex<HashMap<T, Arc<Condvar>>>,
}
impl<T> Default for LockManager<T>
where
    T: std::cmp::Eq + std::hash::Hash + Clone,
{
    /// Creates a manager in which no key is locked.
    fn default() -> Self {
        let locks = Mutex::new(HashMap::<T, Arc<Condvar>>::new());
        Self { locks }
    }
}
impl<T> LockManager<T>
where
T: std::cmp::Eq + std::hash::Hash + Clone,
{
pub fn lock_all(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
let mut locked = Vec::new();
for single in to_lock {
let cond = Arc::new(Condvar::new());
let mut lock_manager = self.locks.lock()?;
loop {
let cond = match lock_manager.entry(single.clone()) {
Entry::Occupied(o) => o.get().clone(),
Entry::Vacant(v) => {
v.insert(cond);
locked.push(single.clone());
break;
}
};
match cond.wait_timeout(lock_manager, timeout) {
Ok((guard, timedout)) => {
lock_manager = guard;
if timedout.timed_out() {
LockManager::unlock_all_with_guard(&mut lock_manager, locked.iter());
return Err(PersyError::TransactionTimeout);
}
}
Err(x) => {
self.unlock_all(&locked)?;
return Err(PersyError::from(x));
}
}
}
}
Ok(())
}
fn unlock_all_with_guard<'a, Q: 'a>(
lock_manager: &mut MutexGuard<HashMap<T, Arc<Condvar>>>,
to_unlock: impl Iterator<Item = &'a Q>,
) where
T: std::borrow::Borrow<Q>,
Q: std::hash::Hash + Eq,
{
for single in to_unlock {
if let Some(cond) = lock_manager.remove(single) {
cond.notify_all();
}
}
}
#[inline]
pub fn unlock_all<Q>(&self, to_unlock: &[Q]) -> PRes<()>
where
T: std::borrow::Borrow<Q>,
Q: std::hash::Hash + Eq,
{
self.unlock_all_iter(to_unlock.iter())
}
#[inline]
pub fn unlock_all_iter<'a, Q: 'a>(&self, to_unlock: impl Iterator<Item = &'a Q>) -> PRes<()>
where
T: std::borrow::Borrow<Q>,
Q: std::hash::Hash + Eq,
{
let mut lock_manager = self.locks.lock()?;
LockManager::unlock_all_with_guard(&mut lock_manager, to_unlock);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::{LockManager, RwLockManager};
    use std::time::Duration;

    #[test]
    fn test_lock_manager_unlock_if_lock_fail() {
        // Generous wait for acquisitions that must succeed, tiny wait for ones that must fail.
        let ok_wait = Duration::from_secs(1);
        let fail_wait = Duration::from_nanos(1);
        let manager: LockManager<_> = Default::default();

        manager.lock_all(&[5], ok_wait).expect("no issue here");
        // 5 is held, so this times out; 1 must be released again on the way out.
        assert!(manager.lock_all(&[1, 5], fail_wait).is_err());
        manager.lock_all(&[1], ok_wait).expect("no issue here");
        manager.unlock_all(&[1, 5]).expect("no issue here");
    }

    #[test]
    fn test_rw_lock_manager_unlock_if_lock_fail() {
        let ok_wait = Duration::from_secs(1);
        let fail_wait = Duration::from_nanos(1);
        let manager: RwLockManager<_> = Default::default();

        // Write blocked by write: the partial write lock on 1 must be rolled back.
        manager.lock_all_write(&[5], ok_wait).expect("no issue here");
        assert!(manager.lock_all_write(&[1, 5], fail_wait).is_err());
        manager.lock_all_write(&[1], ok_wait).expect("no issue here");
        manager.unlock_all_write(&[1, 5]).expect("no issue here");

        // Read blocked by write: the partial read lock on 1 must be rolled back.
        manager.lock_all_write(&[5], ok_wait).expect("no issue here");
        assert!(manager.lock_all_read(&[1, 5], fail_wait).is_err());
        manager.lock_all_write(&[1], ok_wait).expect("no issue here");
        manager.unlock_all_write(&[1, 5]).expect("no issue here");

        // Write blocked by read: the partial write lock on 1 must be rolled back.
        manager.lock_all_read(&[5], ok_wait).expect("no issue here");
        assert!(manager.lock_all_write(&[1, 5], fail_wait).is_err());
        manager.lock_all_read(&[1], ok_wait).expect("no issue here");
        manager.unlock_all_read(&[1, 5]).expect("no issue here");
    }
}