fast_async_mutex/
mutex.rs1use crate::inner::Inner;
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{Context, Poll};
7
8#[derive(Debug)]
10pub struct Mutex<T: ?Sized> {
11 inner: Inner<T>,
12}
13
14impl<T> Mutex<T> {
15 #[inline]
17 pub const fn new(data: T) -> Mutex<T> {
18 Mutex {
19 inner: Inner::new(data),
20 }
21 }
22}
23
24impl<T: ?Sized> Mutex<T> {
25 #[inline]
42 pub const fn lock(&self) -> MutexGuardFuture<T> {
43 MutexGuardFuture {
44 mutex: &self,
45 is_realized: false,
46 }
47 }
48
49 #[inline]
67 pub fn lock_owned(self: &Arc<Self>) -> MutexOwnedGuardFuture<T> {
68 MutexOwnedGuardFuture {
69 mutex: self.clone(),
70 is_realized: false,
71 }
72 }
73}
74
75#[derive(Debug)]
79pub struct MutexGuard<'a, T: ?Sized> {
80 mutex: &'a Mutex<T>,
81}
82
83#[derive(Debug)]
84pub struct MutexGuardFuture<'a, T: ?Sized> {
85 mutex: &'a Mutex<T>,
86 is_realized: bool,
87}
88
89#[derive(Debug)]
94pub struct MutexOwnedGuard<T: ?Sized> {
95 mutex: Arc<Mutex<T>>,
96}
97
98#[derive(Debug)]
99pub struct MutexOwnedGuardFuture<T: ?Sized> {
100 mutex: Arc<Mutex<T>>,
101 is_realized: bool,
102}
103
104impl<'a, T: ?Sized> Future for MutexGuardFuture<'a, T> {
105 type Output = MutexGuard<'a, T>;
106
107 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
108 if self.mutex.inner.try_acquire() {
109 self.is_realized = true;
110 Poll::Ready(MutexGuard { mutex: self.mutex })
111 } else {
112 self.mutex.inner.store_waker(cx.waker());
113 Poll::Pending
114 }
115 }
116}
117
118impl<T: ?Sized> Future for MutexOwnedGuardFuture<T> {
119 type Output = MutexOwnedGuard<T>;
120
121 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
122 if self.mutex.inner.try_acquire() {
123 self.is_realized = true;
124 Poll::Ready(MutexOwnedGuard {
125 mutex: self.mutex.clone(),
126 })
127 } else {
128 self.mutex.inner.store_waker(cx.waker());
129 Poll::Pending
130 }
131 }
132}
133
134crate::impl_send_sync_mutex!(Mutex, MutexGuard, MutexOwnedGuard);
135
136crate::impl_deref_mut!(MutexGuard, 'a);
137crate::impl_deref_mut!(MutexOwnedGuard);
138
139crate::impl_drop_guard!(MutexGuard, 'a, unlock);
140crate::impl_drop_guard!(MutexOwnedGuard, unlock);
141crate::impl_drop_guard_future!(MutexGuardFuture, 'a, unlock);
142crate::impl_drop_guard_future!(MutexOwnedGuardFuture, unlock);
143
#[cfg(test)]
mod tests {
    use crate::mutex::{Mutex, MutexGuard, MutexOwnedGuard};
    use futures::executor::block_on;
    use futures::{FutureExt, StreamExt, TryStreamExt};
    use std::ops::AddAssign;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// 10 000 interleaved borrowing lockers each increment the counter once;
    /// no increment may be lost.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex() {
        let c = Mutex::new(0);

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: MutexGuard<i32> = c.lock().await;
                *co += 1;
            })
            .await;

        let co = c.lock().await;
        assert_eq!(*co, 10000)
    }

    /// Guards are acquired one after another by `.then` and then held across
    /// a sleep (longer for earlier guards); every increment must still land.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_mutex_delay() {
        let expected_result = 100;
        let c = Mutex::new(0);

        futures::stream::iter(0..expected_result)
            .then(|i| c.lock().map(move |co| (i, co)))
            .for_each_concurrent(None, |(i, mut co)| async move {
                sleep(Duration::from_millis(expected_result - i)).await;
                *co += 1;
            })
            .await;

        let co = c.lock().await;
        assert_eq!(*co, expected_result)
    }

    /// Same as `test_mutex`, but through `Arc` + `lock_owned`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 12)]
    async fn test_owned_mutex() {
        let c = Arc::new(Mutex::new(0));

        futures::stream::iter(0..10000)
            .for_each_concurrent(None, |_| async {
                let mut co: MutexOwnedGuard<i32> = c.lock_owned().await;
                *co += 1;
            })
            .await;

        let co = c.lock_owned().await;
        assert_eq!(*co, 10000)
    }

    /// The guard gives mutable access to a non-`Copy` value.
    #[tokio::test]
    async fn test_container() {
        let c = Mutex::new(String::from("lol"));

        let mut co: MutexGuard<String> = c.lock().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    /// Lock attempts made while the lock is held must time out, and cancelling
    /// those pending lock futures must not stop the mutex from being locked
    /// again once the original guard is dropped.
    #[tokio::test]
    async fn test_timeout() {
        let c = Mutex::new(String::from("lol"));

        let co: MutexGuard<String> = c.lock().await;

        futures::stream::iter(0..10000i32)
            .then(|_| tokio::time::timeout(Duration::from_nanos(1), c.lock()))
            .try_for_each_concurrent(None, |_c| futures::future::ok(()))
            .await
            .expect_err("a lock attempt must time out while the lock is held");

        drop(co);

        let mut co: MutexGuard<String> = c.lock().await;
        co.add_assign("lol");

        assert_eq!(*co, "lollol");
    }

    /// 100 OS threads, each running its own executor, increment a shared counter.
    #[test]
    fn multithreading_test() {
        let num = 100;
        let mutex = Arc::new(Mutex::new(0));
        let ths: Vec<_> = (0..num)
            .map(|_| {
                let mutex = mutex.clone();
                std::thread::spawn(move || {
                    block_on(async {
                        let mut lock = mutex.lock().await;
                        *lock += 1;
                    })
                })
            })
            .collect();

        for thread in ths {
            thread.join().unwrap();
        }

        block_on(async {
            let lock = mutex.lock().await;
            assert_eq!(num, *lock)
        })
    }
}