// persy 0.7.0
//
// Transactional Persistence Engine
// Documentation
use crate::{PRes, PersyError};
use std::{
    collections::{hash_map::Entry, HashMap},
    sync::{Arc, Condvar, Mutex, MutexGuard},
    time::Duration,
};

/// Bookkeeping stored for one locked key: who holds it and how to wake waiters.
struct RwLockVar {
    /// `true` when the entry was created for a writer.
    write: bool,
    /// Number of readers currently registered on the entry.
    read_count: u32,
    /// Condition variable that waiting threads clone and block on.
    cond: Arc<Condvar>,
}

impl RwLockVar {
    /// Builds an entry with the given state and a brand new condition variable.
    fn with_state(write: bool, read_count: u32) -> Self {
        Self {
            write,
            read_count,
            cond: Arc::new(Condvar::new()),
        }
    }

    /// Entry for a key taken by a writer: write flag set, no readers.
    fn new_write() -> RwLockVar {
        Self::with_state(true, 0)
    }

    /// Entry for a key taken by its first reader: write flag clear, one reader.
    fn new_read() -> RwLockVar {
        Self::with_state(false, 1)
    }

    /// Registers one more reader on the entry.
    fn inc_read(&mut self) {
        self.read_count += 1;
    }

    /// Unregisters one reader and returns `true` when no reader is left.
    fn dec_read(&mut self) -> bool {
        let remaining = self.read_count - 1;
        self.read_count = remaining;
        remaining == 0
    }
}

pub struct RwLockManager<T>
where
    T: std::cmp::Eq,
    T: std::hash::Hash,
    T: Clone,
{
    locks: Mutex<HashMap<T, RwLockVar>>,
}

impl<T> Default for RwLockManager<T>
where
    T: std::cmp::Eq,
    T: std::hash::Hash,
    T: Clone,
{
    fn default() -> Self {
        RwLockManager {
            locks: Mutex::new(HashMap::<T, RwLockVar>::new()),
        }
    }
}

impl<T> RwLockManager<T>
where
    T: std::cmp::Eq,
    T: std::hash::Hash,
    T: Clone,
{
    pub fn lock_all_write(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
        let mut locked = Vec::new();
        for single in to_lock {
            let mut lock_manager = self.locks.lock()?;
            loop {
                let cond = match lock_manager.entry(single.clone()) {
                    Entry::Occupied(o) => o.get().cond.clone(),
                    Entry::Vacant(v) => {
                        let lock = RwLockVar::new_write();
                        v.insert(lock);
                        locked.push(single.clone());
                        break;
                    }
                };
                match cond.wait_timeout(lock_manager, timeout) {
                    Ok((guard, timedout)) => {
                        lock_manager = guard;
                        if timedout.timed_out() {
                            RwLockManager::unlock_all_write_with_guard(&mut lock_manager, &locked);
                            return Err(PersyError::TransactionTimeout);
                        }
                    }
                    Err(x) => {
                        // TODO: Check this, it may not be possible to unlock, but may be safe
                        // anyway because no-one can actually lock anything.
                        self.unlock_all_write(&locked)?;
                        return Err(PersyError::from(x));
                    }
                }
            }
        }
        Ok(())
    }
    pub fn lock_all_read(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
        let mut locked = Vec::new();
        for single in to_lock {
            let mut lock_manager = self.locks.lock()?;
            loop {
                let cond;
                match lock_manager.entry(single.clone()) {
                    Entry::Occupied(mut o) => {
                        if o.get().write {
                            cond = o.get().cond.clone();
                        } else {
                            o.get_mut().inc_read();
                            locked.push(single.clone());
                            break;
                        }
                    }
                    Entry::Vacant(v) => {
                        v.insert(RwLockVar::new_read());
                        locked.push(single.clone());
                        break;
                    }
                };
                match cond.wait_timeout(lock_manager, timeout) {
                    Ok((guard, timedout)) => {
                        lock_manager = guard;
                        if timedout.timed_out() {
                            RwLockManager::unlock_all_read_with_guard(&mut lock_manager, &locked);
                            return Err(PersyError::TransactionTimeout);
                        }
                    }
                    Err(x) => {
                        // TODO: Check this, it may not be possible to unlock, but may be safe
                        // anyway because no-one can actually lock anything.
                        self.unlock_all_read(&locked)?;
                        return Err(PersyError::from(x));
                    }
                }
            }
        }
        Ok(())
    }

    fn unlock_all_read_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
        for single in to_unlock {
            if let Entry::Occupied(mut lock) = lock_manager.entry(single.clone()) {
                if lock.get_mut().dec_read() {
                    let cond = lock.get().cond.clone();
                    lock.remove();
                    cond.notify_all();
                }
            }
        }
    }
    pub fn unlock_all_read(&self, to_unlock: &[T]) -> PRes<()> {
        let mut lock_manager = self.locks.lock()?;
        RwLockManager::unlock_all_read_with_guard(&mut lock_manager, to_unlock);
        Ok(())
    }

    fn unlock_all_write_with_guard(lock_manager: &mut MutexGuard<HashMap<T, RwLockVar>>, to_unlock: &[T]) {
        for single in to_unlock {
            if let Some(lock) = lock_manager.remove(single) {
                lock.cond.notify_all();
            }
        }
    }

    pub fn unlock_all_write(&self, to_unlock: &[T]) -> PRes<()> {
        let mut lock_manager = self.locks.lock()?;
        RwLockManager::unlock_all_write_with_guard(&mut lock_manager, to_unlock);
        Ok(())
    }
}
/// Table of per-key exclusive locks that can be acquired with a timeout.
///
/// A key is locked while it has an entry in the internal map; the value of the
/// entry is the condition variable on which the threads waiting for the key
/// block. Releasing a key removes the entry and notifies all its waiters.
pub struct LockManager<T>
where
    T: std::cmp::Eq,
    T: std::hash::Hash,
    T: Clone,
{
    /// Currently locked keys with the condition variable of their waiters.
    locks: Mutex<HashMap<T, Arc<Condvar>>>,
}

impl<T> Default for LockManager<T>
where
    T: std::cmp::Eq,
    T: std::hash::Hash,
    T: Clone,
{
    /// Creates a manager with no locked key.
    fn default() -> Self {
        LockManager {
            locks: Mutex::new(HashMap::new()),
        }
    }
}

impl<T> LockManager<T>
where
    T: std::cmp::Eq + std::hash::Hash + Clone,
{
    /// Acquires an exclusive lock on every key of `to_lock`, in slice order.
    ///
    /// When a key is already locked the call waits on the condition variable
    /// of that key. `timeout` bounds the total time spent waiting for each
    /// single key, no matter how many times the thread is woken up in between.
    ///
    /// # Errors
    ///
    /// * `PersyError::TransactionTimeout` when a key does not become free
    ///   within `timeout`; the keys already acquired by this call are released
    ///   before returning.
    /// * The conversion of the poison error through `PersyError::from` when the
    ///   internal mutex is poisoned. If the poisoning is detected after a wait,
    ///   the keys already acquired by this call are released before returning.
    pub fn lock_all(&self, to_lock: &[T], timeout: Duration) -> PRes<()> {
        let mut locked = Vec::with_capacity(to_lock.len());
        for single in to_lock {
            // Allocated before taking the mutex to keep the critical section short.
            let fresh = Arc::new(Condvar::new());
            let started = std::time::Instant::now();
            let mut lock_manager = self.locks.lock()?;
            loop {
                let cond = match lock_manager.entry(single.clone()) {
                    Entry::Occupied(o) => Arc::clone(o.get()),
                    Entry::Vacant(v) => {
                        v.insert(fresh);
                        locked.push(single.clone());
                        break;
                    }
                };
                // The entry is re-checked above after every wake-up, so the
                // call gives up only when the key is still taken and the time
                // budget of this key is exhausted.
                let elapsed = started.elapsed();
                if elapsed >= timeout {
                    Self::unlock_all_with_guard(&mut lock_manager, locked.iter());
                    return Err(PersyError::TransactionTimeout);
                }
                match cond.wait_timeout(lock_manager, timeout - elapsed) {
                    Ok((guard, _)) => lock_manager = guard,
                    Err(mut x) => {
                        // The poison error still owns the mutex guard: locking
                        // the mutex again from this thread would never return,
                        // so the clean up goes through the guard held by `x`.
                        Self::unlock_all_with_guard(&mut x.get_mut().0, locked.iter());
                        return Err(PersyError::from(x));
                    }
                }
            }
        }
        Ok(())
    }

    /// Removes the entry of every key yielded by `to_unlock` using an already
    /// held guard and notifies the threads waiting on it. Keys without an
    /// entry are skipped.
    fn unlock_all_with_guard<'a, Q: 'a>(
        lock_manager: &mut MutexGuard<HashMap<T, Arc<Condvar>>>,
        to_unlock: impl Iterator<Item = &'a Q>,
    ) where
        T: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        for single in to_unlock {
            if let Some(cond) = lock_manager.remove(single) {
                cond.notify_all();
            }
        }
    }

    /// Releases the locks held on the keys of the slice `to_unlock`.
    ///
    /// # Errors
    ///
    /// Returns the conversion of the poison error when the internal mutex is
    /// poisoned.
    #[inline]
    pub fn unlock_all<Q>(&self, to_unlock: &[Q]) -> PRes<()>
    where
        T: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        self.unlock_all_iter(to_unlock.iter())
    }

    /// Releases the locks held on the keys yielded by `to_unlock`.
    ///
    /// # Errors
    ///
    /// Returns the conversion of the poison error when the internal mutex is
    /// poisoned.
    #[inline]
    pub fn unlock_all_iter<'a, Q: 'a>(&self, to_unlock: impl Iterator<Item = &'a Q>) -> PRes<()>
    where
        T: std::borrow::Borrow<Q>,
        Q: std::hash::Hash + Eq,
    {
        let mut lock_manager = self.locks.lock()?;
        Self::unlock_all_with_guard(&mut lock_manager, to_unlock);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::{LockManager, RwLockManager};
    use std::time::Duration;

    /// Timeout used when the acquisition is expected to succeed.
    const PATIENT: Duration = Duration::from_secs(1);
    /// Timeout used when the acquisition is expected to time out.
    const IMPATIENT: Duration = Duration::from_nanos(1);

    /// A failed multi-key lock must release the keys it had already taken.
    #[test]
    fn test_lock_manager_unlock_if_lock_fail() {
        let manager: LockManager<_> = Default::default();
        manager.lock_all(&[5], PATIENT).expect("no issue here");

        // Key 5 is taken, so the call fails after having locked key 1.
        let contended = manager.lock_all(&[1, 5], IMPATIENT);
        assert!(contended.is_err());

        // Key 1 can be locked again only if the failed call released it.
        manager.lock_all(&[1], PATIENT).expect("no issue here");
        manager.unlock_all(&[1, 5]).expect("no issue here");
    }

    /// Same check as above for every read/write combination.
    #[test]
    fn test_rw_lock_manager_unlock_if_lock_fail() {
        let manager: RwLockManager<_> = Default::default();

        // A write attempt blocked by a writer.
        manager.lock_all_write(&[5], PATIENT).expect("no issue here");
        let write_vs_write = manager.lock_all_write(&[1, 5], IMPATIENT);
        assert!(write_vs_write.is_err());
        manager.lock_all_write(&[1], PATIENT).expect("no issue here");
        manager.unlock_all_write(&[1, 5]).expect("no issue here");

        // A read attempt blocked by a writer.
        manager.lock_all_write(&[5], PATIENT).expect("no issue here");
        let read_vs_write = manager.lock_all_read(&[1, 5], IMPATIENT);
        assert!(read_vs_write.is_err());
        manager.lock_all_write(&[1], PATIENT).expect("no issue here");
        manager.unlock_all_write(&[1, 5]).expect("no issue here");

        // A write attempt blocked by a reader.
        manager.lock_all_read(&[5], PATIENT).expect("no issue here");
        let write_vs_read = manager.lock_all_write(&[1, 5], IMPATIENT);
        assert!(write_vs_read.is_err());
        manager.lock_all_read(&[1], PATIENT).expect("no issue here");
        manager.unlock_all_read(&[1, 5]).expect("no issue here");
    }
}