// persy 1.5.2
//
// Transactional Persistence Engine
// Documentation
use crate::error::TimeoutError;
use std::{
    borrow::Borrow,
    cmp::Eq,
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    sync::{Arc, Condvar, Mutex, MutexGuard},
    time::{Duration, Instant},
};

/// Lock-table entry for one key of [`RwLockManager`].
///
/// A write entry is created with `write == true` and `read_count == 0`; a read
/// entry starts with a single reader and is shared by later readers.
struct RwLockVar {
    /// `true` when the entry was created by a writer.
    write: bool,
    /// Number of readers currently sharing the entry.
    read_count: u32,
    /// Condition variable that threads waiting for this key block on.
    cond: Arc<Condvar>,
}
impl RwLockVar {
    /// Builds an entry with the given mode and reader count and a fresh condvar.
    fn with_state(write: bool, read_count: u32) -> RwLockVar {
        let cond = Arc::new(Condvar::new());
        RwLockVar { write, read_count, cond }
    }

    /// Entry for a key taken exclusively by a writer.
    fn new_write() -> RwLockVar {
        Self::with_state(true, 0)
    }

    /// Entry for a key taken by its first reader.
    fn new_read() -> RwLockVar {
        Self::with_state(false, 1)
    }

    /// Registers one more reader on this entry.
    fn inc_read(&mut self) {
        let next = self.read_count + 1;
        self.read_count = next;
    }

    /// Drops one reader; returns `true` once no readers remain.
    fn dec_read(&mut self) -> bool {
        let remaining = self.read_count - 1;
        self.read_count = remaining;
        remaining == 0
    }
}

pub struct RwLockManager<T> {
    locks: Mutex<HashMap<T, RwLockVar>>,
}

impl<T> Default for RwLockManager<T> {
    fn default() -> Self {
        RwLockManager {
            locks: Mutex::new(HashMap::<T, RwLockVar>::new()),
        }
    }
}

impl<T> RwLockManager<T>
where
    T: Eq + Hash + Clone,
{
    pub fn lock_all_write(&self, to_lock: &[T], timeout: Duration) -> Result<(), TimeoutError> {
        let mut locked = Vec::with_capacity(to_lock.len());
        for single in to_lock {
            let mut lock_manager = self.locks.lock().expect("lock not poisoned");
            loop {
                let cond = match lock_manager.entry(single.clone()) {
                    Entry::Occupied(o) => o.get().cond.clone(),
                    Entry::Vacant(v) => {
                        let lock = RwLockVar::new_write();
                        v.insert(lock);
                        locked.push(single.clone());
                        break;
                    }
                };
                let (guard, timedout) = cond.wait_timeout(lock_manager, timeout).expect("lock not poisoned");
                lock_manager = guard;
                if timedout.timed_out() {
                    RwLockManager::unlock_all_write_with_guard(&mut lock_manager, &locked);
                    return Err(TimeoutError::LockTimeout);
                }
            }
        }
        Ok(())
    }
    pub fn lock_all_read(&self, to_lock: &[T], timeout: Duration) -> Result<(), TimeoutError> {
        let mut locked = Vec::with_capacity(to_lock.len());
        for single in to_lock {
            let mut lock_manager = self.locks.lock().expect("lock not poisoned");
            loop {
                let cond;
                match lock_manager.entry(single.clone()) {
                    Entry::Occupied(mut o) => {
                        if o.get().write {
                            cond = o.get().cond.clone();
                        } else {
                            o.get_mut().inc_read();
                            locked.push(single.clone());
                            break;
                        }
                    }
                    Entry::Vacant(v) => {
                        v.insert(RwLockVar::new_read());
                        locked.push(single.clone());
                        break;
                    }
                };
                let (guard, timedout) = cond.wait_timeout(lock_manager, timeout).expect("lock not poisoned");
                lock_manager = guard;
                if timedout.timed_out() {
                    RwLockManager::unlock_all_read_with_guard(&mut lock_manager, &locked);
                    return Err(TimeoutError::LockTimeout);
                }
            }
        }
        Ok(())
    }

    fn unlock_all_read_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
        for single in to_unlock {
            if let Entry::Occupied(mut lock) = lock_manager.entry(single.clone()) {
                if lock.get_mut().dec_read() {
                    let cond = lock.get().cond.clone();
                    lock.remove();
                    cond.notify_all();
                }
            }
        }
    }
    pub fn unlock_all_read(&self, to_unlock: &[T]) {
        let mut lock_manager = self.locks.lock().expect("lock not poisoned");
        RwLockManager::unlock_all_read_with_guard(&mut lock_manager, to_unlock);
    }

    fn unlock_all_write_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
        for single in to_unlock {
            if let Some(lock) = lock_manager.remove(single) {
                lock.cond.notify_all();
            }
        }
    }

    pub fn unlock_all_write(&self, to_unlock: &[T]) {
        let mut lock_manager = self.locks.lock().expect("lock not poisoned");
        RwLockManager::unlock_all_write_with_guard(&mut lock_manager, to_unlock);
    }
}

/// Lock-table entry for one key held by [`LockManager`].
struct LockInfo {
    /// Condition variable that threads waiting for this key block on.
    var: Arc<Condvar>,
}
impl LockInfo {
    /// Creates an entry with its own fresh condition variable.
    fn new() -> Self {
        let var = Arc::new(Condvar::new());
        LockInfo { var }
    }
}

pub struct LockManager<T> {
    locks: Mutex<HashMap<T, LockInfo>>,
}

impl<T> Default for LockManager<T> {
    fn default() -> Self {
        LockManager {
            locks: Mutex::new(HashMap::<T, LockInfo>::new()),
        }
    }
}

impl<T> LockManager<T>
where
    T: Eq + Hash + Clone,
{
    pub fn lock_all(&self, to_lock: &[T], timeout: Duration) -> Result<(), TimeoutError> {
        let mut locked = Vec::with_capacity(to_lock.len());
        for single in to_lock {
            let info = LockInfo::new();
            let mut lock_manager = self.locks.lock().expect("lock not poisoned");
            loop {
                let cond_var = match lock_manager.entry(single.clone()) {
                    Entry::Occupied(o) => o.get().var.clone(),
                    Entry::Vacant(v) => {
                        v.insert(info);
                        locked.push(single.clone());
                        break;
                    }
                };
                let (guard, timedout) = cond_var.wait_timeout(lock_manager, timeout).expect("lock not poisoned");
                lock_manager = guard;
                if timedout.timed_out() {
                    LockManager::unlock_all_with_guard(&mut lock_manager, locked.iter());
                    return Err(TimeoutError::LockTimeout);
                }
            }
        }
        Ok(())
    }

    fn unlock_all_with_guard<'a, Q: 'a>(
        lock_manager: &mut MutexGuard<HashMap<T, LockInfo>>,
        to_unlock: impl Iterator<Item = &'a Q>,
    ) where
        T: Borrow<Q>,
        Q: Hash + Eq,
    {
        for single in to_unlock {
            if let Some(cond) = lock_manager.remove(single) {
                cond.var.notify_all();
            }
        }
    }

    #[inline]
    pub fn unlock_all<Q>(&self, to_unlock: &[Q])
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.unlock_all_iter(to_unlock.iter())
    }

    #[inline]
    pub fn unlock_all_iter<'a, Q: 'a>(&self, to_unlock: impl Iterator<Item = &'a Q>)
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
    {
        let mut lock_manager = self.locks.lock().expect("lock not poisoned");
        LockManager::unlock_all_with_guard(&mut lock_manager, to_unlock);
    }
}

#[cfg(test)]
mod tests {
    use super::{LockManager, RwLockManager};
    use std::{sync::Arc, thread, time::Duration};

    #[test]
    fn test_lock_manager_unlock_if_lock_fail() {
        let manager: LockManager<_> = Default::default();
        manager.lock_all(&[5], Duration::new(1, 0)).expect("no issue here");
        assert!(manager.lock_all(&[1, 5], Duration::new(0, 1)).is_err());
        manager.lock_all(&[1], Duration::new(1, 0)).expect("no issue here");
        manager.unlock_all(&[1, 5]);
    }

    #[test]
    fn test_rw_lock_manager_unlock_if_lock_fail() {
        let manager: RwLockManager<_> = Default::default();
        manager
            .lock_all_write(&[5], Duration::new(1, 0))
            .expect("no issue here");
        assert!(manager.lock_all_write(&[1, 5], Duration::new(0, 1)).is_err());
        manager
            .lock_all_write(&[1], Duration::new(1, 0))
            .expect("no issue here");
        manager.unlock_all_write(&[1, 5]);

        manager
            .lock_all_write(&[5], Duration::new(1, 0))
            .expect("no issue here");
        assert!(manager.lock_all_read(&[1, 5], Duration::new(0, 1)).is_err());
        manager
            .lock_all_write(&[1], Duration::new(1, 0))
            .expect("no issue here");
        manager.unlock_all_write(&[1, 5]);

        manager.lock_all_read(&[5], Duration::new(1, 0)).expect("no issue here");
        assert!(manager.lock_all_write(&[1, 5], Duration::new(0, 1)).is_err());
        manager.lock_all_read(&[1], Duration::new(1, 0)).expect("no issue here");
        manager.unlock_all_read(&[1, 5]);
    }

    /// A waiter must be woken and acquire the key once another thread releases it.
    #[test]
    fn test_lock_manager_waits_for_release() {
        let manager: Arc<LockManager<u32>> = Arc::new(Default::default());
        manager.lock_all(&[7], Duration::new(1, 0)).expect("no issue here");
        let releaser = {
            let manager = Arc::clone(&manager);
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(20));
                manager.unlock_all(&[7]);
            })
        };
        manager
            .lock_all(&[7], Duration::new(10, 0))
            .expect("key released by the other thread");
        releaser.join().expect("releaser thread panicked");
        manager.unlock_all(&[7]);
    }

    /// Readers wait for a writer released by another thread, then share the key.
    #[test]
    fn test_rw_lock_manager_waits_for_release() {
        let manager: Arc<RwLockManager<u32>> = Arc::new(Default::default());
        manager
            .lock_all_write(&[3], Duration::new(1, 0))
            .expect("no issue here");
        let releaser = {
            let manager = Arc::clone(&manager);
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(20));
                manager.unlock_all_write(&[3]);
            })
        };
        manager
            .lock_all_read(&[3], Duration::new(10, 0))
            .expect("write lock released by the other thread");
        releaser.join().expect("releaser thread panicked");

        manager
            .lock_all_read(&[3], Duration::new(1, 0))
            .expect("readers share the key");
        assert!(manager.lock_all_write(&[3], Duration::new(0, 1)).is_err());
        manager.unlock_all_read(&[3, 3]);
        manager
            .lock_all_write(&[3], Duration::new(1, 0))
            .expect("all readers released");
        manager.unlock_all_write(&[3]);
    }
}