1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! Implementations of [`KeyMutex`] and [`KeyRwLock`] using `std::sync::{Mutex, RwLock}`.

use crate::{
    inner::{KeyRef, LockMap},
    Empty,
};
use guardian::{ArcMutexGuardian, ArcRwLockReadGuardian, ArcRwLockWriteGuardian};
use std::{
    hash::Hash,
    mem::ManuallyDrop,
    ops::{Deref, DerefMut},
    sync::{self, LockResult, PoisonError},
};

/// Runs `f` on the guard held by a [`LockResult`] without losing the poison flag.
///
/// A poisoned `std::sync` result still carries a usable guard. That guard is extracted,
/// passed through `f`, and the output is re-wrapped in a new [`PoisonError`]. Callers
/// therefore observe poisoning exactly as the underlying primitive reported it.
fn map_guard<G, R>(result: LockResult<G>, f: impl FnOnce(G) -> R) -> LockResult<R> {
    let (inner, was_poisoned) = match result {
        Ok(inner) => (inner, false),
        Err(poison) => (poison.into_inner(), true),
    };
    let mapped = f(inner);
    if was_poisoned {
        Err(PoisonError::new(mapped))
    } else {
        Ok(mapped)
    }
}

decl_mutex_guard!(ArcMutexGuardian);
/// A collection of independent [`std::sync::Mutex`]es, one per key, each protecting a `V`.
///
/// Each `lock` call obtains the key's mutex from the inner `LockMap`. The module's
/// `drop_only_if_empty` test exercises the cleanup rule: after a guard is released, the
/// entry is kept only while its value is non-empty.
#[derive(Clone, Default)]
pub struct KeyMutex<K: Eq + Hash, V>(LockMap<K, sync::Mutex<V>>);
impl<K: Eq + Hash + Clone, V: Empty + Default> KeyMutex<K, V> {
    /// Creates a `KeyMutex` that tracks no keys.
    pub fn new() -> Self {
        KeyMutex(LockMap::new())
    }

    /// Locks the mutex associated with `key` and returns a guard owning that lock.
    ///
    /// The mutex is obtained from the inner map and locked via `ArcMutexGuardian::take`.
    ///
    /// # Errors
    ///
    /// If the mutex is poisoned, the guard is still produced, wrapped in a [`PoisonError`].
    pub fn lock(&self, key: K) -> LockResult<OwnedMutexGuard<K, V>> {
        let mutex = self.0.obtain(key.clone());
        let acquired = ArcMutexGuardian::take(mutex);
        map_guard(acquired, |guardian| OwnedMutexGuard {
            key_ref: KeyRef::new(&self.0, key),
            guard: ManuallyDrop::new(guardian),
        })
    }

    /// Returns how many keys are currently tracked.
    pub fn len(&self) -> usize {
        let KeyMutex(map) = self;
        map.len()
    }

    /// Returns `true` when no keys are currently tracked.
    pub fn is_empty(&self) -> bool {
        let KeyMutex(map) = self;
        map.is_empty()
    }
}

decl_rwlock_guard!(ArcRwLockReadGuardian, ArcRwLockWriteGuardian);
/// A collection of independent [`std::sync::RwLock`]s, one per key, each protecting a `V`.
///
/// Each `read`/`write` call obtains the key's lock from the inner `LockMap`. The returned
/// guard owns both the lock and a reference to the key.
#[derive(Clone, Default)]
pub struct KeyRwLock<K: Eq + Hash, V>(LockMap<K, sync::RwLock<V>>);
impl<K: Eq + Hash + Clone, V: Empty + Default> KeyRwLock<K, V> {
    /// Creates a `KeyRwLock` that tracks no keys.
    pub fn new() -> Self {
        KeyRwLock(LockMap::new())
    }

    /// Takes a read guard (via `ArcRwLockReadGuardian::take`) on the lock for `key`.
    ///
    /// # Errors
    ///
    /// If the lock is poisoned, the guard is still produced, wrapped in a [`PoisonError`].
    pub fn read(&self, key: K) -> LockResult<OwnedReadGuard<K, V>> {
        let rwlock = self.0.obtain(key.clone());
        let acquired = ArcRwLockReadGuardian::take(rwlock);
        map_guard(acquired, |guardian| OwnedReadGuard {
            key: KeyRef::new(&self.0, key),
            guard: ManuallyDrop::new(guardian),
        })
    }

    /// Takes a write guard (via `ArcRwLockWriteGuardian::take`) on the lock for `key`.
    ///
    /// # Errors
    ///
    /// If the lock is poisoned, the guard is still produced, wrapped in a [`PoisonError`].
    pub fn write(&self, key: K) -> LockResult<OwnedWriteGuard<K, V>> {
        let rwlock = self.0.obtain(key.clone());
        let acquired = ArcRwLockWriteGuardian::take(rwlock);
        map_guard(acquired, |guardian| OwnedWriteGuard {
            key: KeyRef::new(&self.0, key),
            guard: ManuallyDrop::new(guardian),
        })
    }

    /// Returns how many keys are currently tracked.
    pub fn len(&self) -> usize {
        let KeyRwLock(map) = self;
        map.len()
    }

    /// Returns `true` when no keys are currently tracked.
    pub fn is_empty(&self) -> bool {
        let KeyRwLock(map) = self;
        map.is_empty()
    }
}

impl<K: Eq + Hash, V> Empty for KeyMutex<K, V> {
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

impl<K: Eq + Hash, V> Empty for KeyRwLock<K, V> {
    fn is_empty(&self) -> bool {
        self.0.is_empty()
    }
}

#[cfg(test)]
mod test {
    use crate::KeyMutex;
    use std::collections::BTreeSet;

    /// An entry survives while its value is non-empty and is evicted once emptied.
    #[test]
    fn drop_only_if_empty() {
        let map = KeyMutex::<u32, BTreeSet<String>>::new();

        {
            let mut guard = map.lock(1).unwrap();
            for word in ["Hello", "World"] {
                guard.insert(word.to_owned());
            }
        }
        // The value still holds data, so the entry must be retained.
        assert_eq!(map.len(), 1);

        {
            let mut guard = map.lock(1).unwrap();
            assert_eq!(guard.len(), 2);
            guard.clear();
        }
        // The value is now empty, so releasing the guard must evict the entry.
        assert_eq!(map.len(), 0);
    }
}