1use std::{
2 cell::UnsafeCell,
3 ops,
4 sync::atomic::{AtomicBool, AtomicUsize, Ordering},
5};
6
7use crate::{maker::*, Lockable, LockableNew};
8
/// A mutual-exclusion lock that busy-waits (spins) on an atomic flag instead of
/// blocking the thread.
pub struct SpinMutex<T> {
    /// `true` while some guard holds the lock.
    flag: AtomicBool,
    /// Token of the currently registered guard, or `0` when no guard is registered.
    guard: AtomicUsize,
    /// The protected value; handed out only through a `SpinMutexGuard`.
    data: UnsafeCell<T>,
}
18
19impl<T> LockableNew for SpinMutex<T> {
20 type Value = T;
21
22 fn new(t: T) -> Self {
24 Self {
25 flag: AtomicBool::new(false),
26 data: t.into(),
27 guard: AtomicUsize::new(0),
28 }
29 }
30}
31
32impl<T: Default> Default for SpinMutex<T> {
33 fn default() -> Self {
34 Self::new(Default::default())
35 }
36}
37
38impl<T> SpinMutex<T> {
39 pub const fn const_new(t: T) -> Self {
40 Self {
41 flag: AtomicBool::new(false),
42 data: UnsafeCell::new(t),
43 guard: AtomicUsize::new(0),
44 }
45 }
46 #[cold]
47 fn lockable(&self) {
48 while self.flag.load(Ordering::Relaxed) {}
49 }
50}
51
52impl<T> Lockable for SpinMutex<T> {
53 type GuardMut<'a> = SpinMutexGuard<'a, T>
54 where
55 Self: 'a;
56
57 #[inline]
58 fn lock(&self) -> Self::GuardMut<'_> {
59 while self
60 .flag
61 .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
62 .is_err()
63 {
64 self.lockable();
65 }
66
67 let mut guard = SpinMutexGuard {
68 locker: self,
69 ptr: 0,
70 };
71
72 guard.ptr = &guard as *const _ as usize;
73
74 self.guard
75 .compare_exchange(
76 0,
77 &guard as *const _ as usize,
78 Ordering::Acquire,
79 Ordering::Relaxed,
80 )
81 .expect("Set guard ptr error");
82
83 guard
84 }
85
86 #[inline]
87 fn try_lock(&self) -> Option<Self::GuardMut<'_>> {
88 if self
89 .flag
90 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
91 .is_ok()
92 {
93 let mut guard = SpinMutexGuard {
94 locker: self,
95 ptr: 0,
96 };
97
98 guard.ptr = &guard as *const _ as usize;
99
100 self.guard
101 .compare_exchange(
102 0,
103 &guard as *const _ as usize,
104 Ordering::Acquire,
105 Ordering::Relaxed,
106 )
107 .expect("Set guard ptr error");
108
109 Some(guard)
110 } else {
111 None
112 }
113 }
114
115 #[inline]
116 fn unlock(guard: Self::GuardMut<'_>) -> &Self {
117 let locker = guard.locker;
118
119 drop(guard);
120
121 locker
122 }
123}
124
125pub struct SpinMutexGuard<'a, T> {
127 locker: &'a SpinMutex<T>,
129 ptr: usize,
130}
131
132impl<'a, T> Drop for SpinMutexGuard<'a, T> {
133 fn drop(&mut self) {
134 self.locker
135 .guard
136 .compare_exchange(self.ptr, 0, Ordering::Release, Ordering::Relaxed)
137 .expect("Unset guard ptr error");
138
139 self.locker.flag.store(false, Ordering::Release);
140 }
141}
142
143impl<'a, T> SpinMutexGuard<'a, T> {
144 #[inline]
145 fn deref_check(&self) {
146 }
152}
153
154impl<'a, T> ops::Deref for SpinMutexGuard<'a, T> {
155 type Target = T;
156
157 #[inline]
158 fn deref(&self) -> &Self::Target {
159 self.deref_check();
160 unsafe { &*self.locker.data.get() }
161 }
162}
163
164impl<'a, T> ops::DerefMut for SpinMutexGuard<'a, T> {
165 #[inline]
166 fn deref_mut(&mut self) -> &mut Self::Target {
167 self.deref_check();
168 unsafe { &mut *self.locker.data.get() }
169 }
170}
171
172unsafe impl<T: Send> Send for SpinMutex<T> {}
175unsafe impl<T: Send> Sync for SpinMutex<T> {}
176
177unsafe impl<'a, T: Send> Send for SpinMutexGuard<'a, T> {}
179unsafe impl<'a, T: Sync> Sync for SpinMutexGuard<'a, T> {}
180
/// Async variant of [`SpinMutex`], built with `AsyncLockableMaker`.
///
/// The value lock is a `SpinMutex<T>`; the second parameter is a `SpinMutex` around
/// `DefaultAsyncLockableMediator` (presumably guarding the waiter/mediator state —
/// NOTE(review): confirm against `AsyncLockableMaker`).
pub type AsyncSpinMutex<T> =
    AsyncLockableMaker<SpinMutex<T>, SpinMutex<DefaultAsyncLockableMediator>>;
184
#[cfg(test)]
mod tests {
    use std::{
        sync::Arc,
        time::{Duration, Instant},
    };

    use futures::{executor::ThreadPool, task::SpawnExt};

    use crate::{AsyncLockable, AsyncSpinMutex};

    /// `loops` tasks on a 10-thread pool each increment a shared counter `loops`
    /// times; the final count must be `loops * loops`.
    #[futures_test::test]
    async fn test_async_lock() {
        let loops = 1000;

        let pool = ThreadPool::builder().pool_size(10).create().unwrap();
        let shared = Arc::new(AsyncSpinMutex::new(0));

        let join_handles: Vec<_> = (0..loops)
            .map(|_| {
                let task_shared = Arc::clone(&shared);

                pool.spawn_with_handle(async move {
                    // Take and immediately release the lock once before counting.
                    AsyncSpinMutex::unlock(task_shared.lock().await);

                    for _ in 0..loops {
                        let mut counter = task_shared.lock().await;
                        *counter += 1;
                        AsyncSpinMutex::unlock(counter);
                    }
                })
                .unwrap()
            })
            .collect();

        for handle in join_handles {
            handle.await;
        }

        assert_eq!(*shared.lock().await, loops * loops);
    }

    /// Reports the average time of an uncontended `lock().await`.
    #[futures_test::test]
    async fn bench_async_lock() {
        let loops = 1000000;

        let shared = AsyncSpinMutex::new(0);

        let mut total = Duration::from_secs(0);

        for _ in 0..loops {
            let started = Instant::now();
            let mut guard = shared.lock().await;
            total += started.elapsed();

            *guard += 1;
        }

        assert_eq!(*shared.lock().await, loops);

        println!("bench_async_lock: {:?}", total / loops);
    }
}