hala_sync/spin.rs

use std::{
    cell::UnsafeCell,
    ops,
    sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};

use crate::{maker::*, Lockable, LockableNew};

/// A spin-style mutex implementation that does not track any thread-specific data.
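///
/// A minimal usage sketch (assuming the crate is published as `hala_sync`
/// and re-exports `SpinMutex`, `Lockable`, and `LockableNew` at its root,
/// as the imports above suggest):
///
/// ```no_run
/// use hala_sync::{Lockable, LockableNew, SpinMutex};
///
/// let mutex = SpinMutex::new(0u32);
///
/// // The guard releases the lock when it goes out of scope.
/// *mutex.lock() += 1;
///
/// assert_eq!(*mutex.lock(), 1);
/// ```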
pub struct SpinMutex<T> {
    /// The lock status flag.
    flag: AtomicBool,
    /// Address of the [`SpinMutexGuard`] that currently owns the lock, or 0 if the lock is free.
    guard: AtomicUsize,
    /// UnsafeCell holding the protected data.
    data: UnsafeCell<T>,
}

impl<T> LockableNew for SpinMutex<T> {
    type Value = T;

    /// Creates a new mutex in an unlocked state ready for use.
    fn new(t: T) -> Self {
        Self {
            flag: AtomicBool::new(false),
            data: t.into(),
            guard: AtomicUsize::new(0),
        }
    }
}

impl<T: Default> Default for SpinMutex<T> {
    fn default() -> Self {
        Self::new(Default::default())
    }
}

impl<T> SpinMutex<T> {
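    /// Creates a new mutex in an unlocked state, usable in `const` contexts.
    ///
    /// A minimal `static` initialization sketch (assuming the crate is
    /// published as `hala_sync`):
    ///
    /// ```no_run
    /// use hala_sync::{Lockable, SpinMutex};
    ///
    /// static COUNTER: SpinMutex<u32> = SpinMutex::const_new(0);
    ///
    /// *COUNTER.lock() += 1;
    /// ```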
    pub const fn const_new(t: T) -> Self {
        Self {
            flag: AtomicBool::new(false),
            data: UnsafeCell::new(t),
            guard: AtomicUsize::new(0),
        }
    }

    /// Spins until the lock appears free; [`Lockable::lock`] retries its
    /// compare-exchange once this returns. The relaxed read-only loop
    /// avoids hammering the cache line with compare-exchange traffic.
    #[cold]
    fn lockable(&self) {
        while self.flag.load(Ordering::Relaxed) {
            std::hint::spin_loop();
        }
    }
}

impl<T> Lockable for SpinMutex<T> {
    type GuardMut<'a> = SpinMutexGuard<'a, T>
    where
        Self: 'a;

    #[inline]
    fn lock(&self) -> Self::GuardMut<'_> {
        // Fast path: try to flip the flag directly; on contention, spin in
        // the cold helper until the lock looks free, then retry.
        while self
            .flag
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            self.lockable();
        }

        let mut guard = SpinMutexGuard {
            locker: self,
            ptr: 0,
        };

        // Register the guard's construction-time address in the mutex so
        // `Drop` can verify it undoes the same registration. The guard is
        // moved when returned, so `ptr` intentionally keeps this original
        // address rather than the guard's final location.
        guard.ptr = &guard as *const _ as usize;

        self.guard
            .compare_exchange(0, guard.ptr, Ordering::Acquire, Ordering::Relaxed)
            .expect("Set guard ptr error");

        guard
    }
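    /// Non-blocking variant of [`Lockable::lock`]. A minimal sketch
    /// (assuming the crate is published as `hala_sync`):
    ///
    /// ```no_run
    /// use hala_sync::{Lockable, LockableNew, SpinMutex};
    ///
    /// let mutex = SpinMutex::new(0u32);
    ///
    /// let guard = mutex.lock();
    /// // Already held, so `try_lock` returns `None` instead of spinning.
    /// assert!(mutex.try_lock().is_none());
    ///
    /// drop(guard);
    /// assert!(mutex.try_lock().is_some());
    /// ```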
    #[inline]
    fn try_lock(&self) -> Option<Self::GuardMut<'_>> {
        // Single compare-exchange attempt; unlike `lock`, this never spins.
        if self
            .flag
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_ok()
        {
            let mut guard = SpinMutexGuard {
                locker: self,
                ptr: 0,
            };

            // Same guard registration as in `lock`.
            guard.ptr = &guard as *const _ as usize;

            self.guard
                .compare_exchange(0, guard.ptr, Ordering::Acquire, Ordering::Relaxed)
                .expect("Set guard ptr error");

            Some(guard)
        } else {
            None
        }
    }
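    /// Explicitly releases `guard`, returning a reference to the mutex it
    /// locked. A minimal sketch of the hand-back pattern (assuming the
    /// crate is published as `hala_sync`):
    ///
    /// ```no_run
    /// use hala_sync::{Lockable, LockableNew, SpinMutex};
    ///
    /// let mutex = SpinMutex::new(0u32);
    /// let guard = mutex.lock();
    ///
    /// // Equivalent to `drop(guard)`, but hands the mutex back for reuse.
    /// let mutex = SpinMutex::unlock(guard);
    /// assert!(mutex.try_lock().is_some());
    /// ```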
    #[inline]
    fn unlock(guard: Self::GuardMut<'_>) -> &Self {
        let locker = guard.locker;

        drop(guard);

        locker
    }
}

/// RAII guard providing scoped-lock semantics: the lock is released when
/// this guard is dropped.
pub struct SpinMutexGuard<'a, T> {
    /// A reference to the associated [`SpinMutex`].
    locker: &'a SpinMutex<T>,
    /// The guard's construction-time address, as registered in
    /// [`SpinMutex::guard`].
    ptr: usize,
}

impl<'a, T> Drop for SpinMutexGuard<'a, T> {
    fn drop(&mut self) {
        // Deregister this guard first, then clear the flag; the `Release`
        // stores make the protected data visible to the next locker.
        self.locker
            .guard
            .compare_exchange(self.ptr, 0, Ordering::Release, Ordering::Relaxed)
            .expect("Unset guard ptr error");

        self.locker.flag.store(false, Ordering::Release);
    }
}

impl<'a, T> SpinMutexGuard<'a, T> {
    /// Hook for checking that this guard is still the one registered in
    /// [`SpinMutex::guard`] before a deref; the assertion is currently
    /// left disabled.
    #[inline]
    fn deref_check(&self) {
        // assert_eq!(
        //     self.locker.guard.load(Ordering::Acquire),
        //     self.ptr,
        //     "fail to check constraint of deref/deref_mut ops"
        // );
    }
}

impl<'a, T> ops::Deref for SpinMutexGuard<'a, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        self.deref_check();
        // Safety: this guard holds the lock, so no aliasing mutable
        // reference to the data can exist.
        unsafe { &*self.locker.data.get() }
    }
}

impl<'a, T> ops::DerefMut for SpinMutexGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.deref_check();
        // Safety: as above; holding the lock guarantees exclusive access.
        unsafe { &mut *self.locker.data.get() }
    }
}

// These are the only places where `T: Send` matters; all other
// functionality works fine on a single thread.
unsafe impl<T: Send> Send for SpinMutex<T> {}
unsafe impl<T: Send> Sync for SpinMutex<T> {}

// Safe to send since we don't track any thread-specific details.
unsafe impl<'a, T: Send> Send for SpinMutexGuard<'a, T> {}
unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}

/// Futures-aware [`SpinMutex`] type.
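///
/// A minimal async sketch (assuming the crate is published as `hala_sync`
/// and, as in the tests below, re-exports `AsyncLockable`):
///
/// ```no_run
/// use hala_sync::{AsyncLockable, AsyncSpinMutex};
///
/// futures::executor::block_on(async {
///     let shared = AsyncSpinMutex::new(0u32);
///
///     // `lock` yields to the executor instead of spinning the thread.
///     let mut data = shared.lock().await;
///     *data += 1;
///
///     AsyncSpinMutex::unlock(data);
/// });
/// ```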
pub type AsyncSpinMutex<T> =
    AsyncLockableMaker<SpinMutex<T>, SpinMutex<DefaultAsyncLockableMediator>>;

#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    use futures::{executor::ThreadPool, task::SpawnExt};

    use crate::{AsyncLockable, AsyncSpinMutex};

    #[futures_test::test]
    async fn test_async_lock() {
        let loops = 1000;

        let pool = ThreadPool::builder().pool_size(10).create().unwrap();

        let shared = Arc::new(AsyncSpinMutex::new(0));

        let mut join_handles = vec![];

        for _ in 0..loops {
            let shared = shared.clone();

            join_handles.push(
                pool.spawn_with_handle(async move {
                    let mut data = shared.lock().await;

                    AsyncSpinMutex::unlock(data);

                    for _ in 0..loops {
                        data = shared.lock().await;

                        *data += 1;

                        AsyncSpinMutex::unlock(data);
                    }
                })
                .unwrap(),
            );
        }

        for join in join_handles {
            join.await
        }

        assert_eq!(*shared.lock().await, loops * loops);
    }
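
    // A minimal sketch of direct, non-async usage across OS threads. This
    // test is an illustrative addition, not part of the original suite.
    #[test]
    fn test_sync_lock() {
        use std::thread;

        use crate::{Lockable, LockableNew, SpinMutex};

        let threads = 8u32;
        let loops = 1000u32;

        let shared = Arc::new(SpinMutex::new(0u32));

        let mut handles = vec![];

        for _ in 0..threads {
            let shared = shared.clone();

            handles.push(thread::spawn(move || {
                for _ in 0..loops {
                    // The guard drops at the end of the statement,
                    // releasing the lock.
                    *shared.lock() += 1;
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(*shared.lock(), threads * loops);
    }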

    #[futures_test::test]
    async fn bench_async_lock() {
        let loops = 1_000_000;

        let shared = AsyncSpinMutex::new(0);

        let mut duration = Duration::from_secs(0);

        for _ in 0..loops {
            let start = Instant::now();
            let mut shared = shared.lock().await;
            duration += start.elapsed();

            *shared += 1;
        }

        assert_eq!(*shared.lock().await, loops);

        println!("bench_async_lock: {:?}", duration / loops);
    }
}