Skip to main content

loro_internal/
sync.rs

1#[cfg(loom)]
2pub use loom::thread;
3#[cfg(not(loom))]
4pub use std::thread;
5
/// Raw lock primitives, taken from `loom` when the `loom` cfg is set.
///
/// The lock types are renamed to `RawMutex` / `RawRwLock` so the wrappers in
/// this file can use the plain `Mutex` / `RwLock` names.
#[cfg(loom)]
mod raw {
    pub use loom::sync::{LockResult, Mutex as RawMutex, RwLock as RawRwLock};
    pub use loom::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard};
}
/// Raw lock primitives, taken from the standard library when the `loom` cfg
/// is not set.
///
/// The lock types are renamed to `RawMutex` / `RawRwLock` so the wrappers in
/// this file can use the plain `Mutex` / `RwLock` names.
#[cfg(not(loom))]
mod raw {
    pub use std::sync::{LockResult, Mutex as RawMutex, RwLock as RawRwLock};
    pub use std::sync::{MutexGuard, RwLockReadGuard, RwLockWriteGuard};
}
20
21pub use raw::{MutexGuard, RwLockReadGuard, RwLockWriteGuard};
22
23#[cfg(loom)]
24pub use loom::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, AtomicU8, AtomicUsize};
25#[cfg(not(loom))]
26pub use std::sync::atomic::{AtomicBool, AtomicI64, AtomicU64, AtomicU8, AtomicUsize};
27
28#[cfg(loom)]
29pub(crate) use my_thread_local::ThreadLocal;
30#[cfg(not(loom))]
31pub(crate) use thread_local::ThreadLocal;
32
/// Unwraps the result of a lock operation, treating poisoning as fatal.
///
/// Returns the `Ok` payload unchanged.
///
/// # Panics
///
/// Panics with the message `poisoned {lock_kind}` when `result` is an error,
/// i.e. when the lock was poisoned.
fn expect_not_poisoned<T>(result: raw::LockResult<T>, lock_kind: &str) -> T {
    match result {
        Ok(value) => value,
        Err(_) => panic!("poisoned {lock_kind}"),
    }
}
36
37#[derive(Debug)]
38pub struct Mutex<T: ?Sized> {
39    inner: raw::RawMutex<T>,
40}
41
42impl<T> Mutex<T> {
43    pub fn new(value: T) -> Self {
44        Self {
45            inner: raw::RawMutex::new(value),
46        }
47    }
48}
49
50impl<T: ?Sized> Mutex<T> {
51    pub fn lock(&self) -> MutexGuard<'_, T> {
52        self.lock_with_kind("mutex")
53    }
54
55    pub(crate) fn lock_with_kind(&self, lock_kind: &str) -> MutexGuard<'_, T> {
56        expect_not_poisoned(self.inner.lock(), lock_kind)
57    }
58
59    pub(crate) fn is_locked(&self) -> bool {
60        self.inner.try_lock().is_err()
61    }
62}
63
64impl<T: Default> Default for Mutex<T> {
65    fn default() -> Self {
66        Self::new(T::default())
67    }
68}
69
70#[derive(Debug)]
71pub struct RwLock<T> {
72    inner: raw::RawRwLock<T>,
73}
74
75impl<T> RwLock<T> {
76    pub fn new(value: T) -> Self {
77        Self {
78            inner: raw::RawRwLock::new(value),
79        }
80    }
81
82    pub fn into_inner(self) -> T {
83        expect_not_poisoned(self.inner.into_inner(), "rwlock")
84    }
85}
86
87impl<T> RwLock<T> {
88    pub fn read(&self) -> RwLockReadGuard<'_, T> {
89        expect_not_poisoned(self.inner.read(), "rwlock")
90    }
91
92    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
93        expect_not_poisoned(self.inner.write(), "rwlock")
94    }
95}
96
97impl<T: Default> Default for RwLock<T> {
98    fn default() -> Self {
99        Self::new(T::default())
100    }
101}
102
/// Replacement for `thread_local::ThreadLocal`, compiled only when the `loom`
/// cfg is set.
#[cfg(loom)]
mod my_thread_local {
    use std::sync::Arc;

    use super::thread;
    use super::Mutex;
    use rustc_hash::FxHashMap;

    /// Per-thread values stored in one mutex-protected map keyed by thread id.
    ///
    /// Cloning yields a handle to the same underlying map.
    #[derive(Debug)]
    pub(crate) struct ThreadLocal<T> {
        content: Arc<Mutex<FxHashMap<thread::ThreadId, Arc<T>>>>,
    }

    impl<T: Default> ThreadLocal<T> {
        /// Creates a container with no per-thread entries.
        pub fn new() -> Self {
            let content = Arc::new(Mutex::new(FxHashMap::default()));
            Self { content }
        }

        /// Returns the calling thread's value, inserting `T::default()` the
        /// first time that thread asks for it.
        pub fn get_or_default(&self) -> Arc<T> {
            // Take the map lock first, then look up the current thread's slot.
            let mut slots = self.content.lock();
            let slot = slots.entry(thread::current().id()).or_default();
            Arc::clone(slot)
        }
    }

    impl<T> Clone for ThreadLocal<T> {
        /// Produces another handle to the same shared map.
        fn clone(&self) -> Self {
            let content = Arc::clone(&self.content);
            Self { content }
        }
    }
}
140
#[cfg(test)]
mod tests {
    use super::*;
    use std::panic::{catch_unwind, AssertUnwindSafe};

    /// Runs `poisoner` and discards the panic it raises, so a lock held inside
    /// it is left poisoned while the test itself keeps running.
    fn swallow_panic(poisoner: impl FnOnce()) {
        let _ = catch_unwind(AssertUnwindSafe(poisoner));
    }

    /// Locking a mutex after a holder panicked must panic with the mutex
    /// poison message.
    #[test]
    #[should_panic(expected = "poisoned mutex")]
    fn mutex_lock_panics_after_poison() {
        let lock = Mutex::new(7);
        swallow_panic(|| {
            let _guard = lock.lock();
            panic!("poison mutex");
        });

        let _reacquired = lock.lock();
    }

    /// Reading an rwlock after a writer panicked must panic with the rwlock
    /// poison message.
    #[test]
    #[should_panic(expected = "poisoned rwlock")]
    fn rwlock_read_panics_after_poison() {
        let lock = RwLock::new(7);
        swallow_panic(|| {
            let mut guard = lock.write();
            *guard = 9;
            panic!("poison rwlock");
        });

        let _reacquired = lock.read();
    }
}