1use std::{
2 collections::HashMap,
3 ops::{self, DerefMut},
4 task::Waker,
5};
6
7#[cfg(feature = "trace_lock")]
8use std::{fmt::Debug, panic::Location};
9
10use super::*;
11
12pub struct AsyncLockableMaker<Locker, Wakers> {
14 inner_locker: Locker,
15 wakers: Wakers,
16 #[cfg(feature = "trace_lock")]
17 id: usize,
18}
19
20impl<Locker, Wakers> Default for AsyncLockableMaker<Locker, Wakers>
21where
22 Locker: Default,
23 Wakers: Default,
24{
25 fn default() -> Self {
26 #[cfg(feature = "trace_lock")]
27 let id = {
28 use rand::{thread_rng, RngCore};
29
30 let mut buf = [0u8; 8];
31 thread_rng().fill_bytes(&mut buf);
32
33 u64::from_be_bytes(buf) as usize
34 };
35
36 Self {
37 inner_locker: Default::default(),
38 wakers: Default::default(),
39 #[cfg(feature = "trace_lock")]
40 id,
41 }
42 }
43}
44
45impl<Locker, Wakers> AsyncLockableMaker<Locker, Wakers>
46where
47 Locker: LockableNew,
48 Wakers: Default,
49{
50 pub fn new(value: Locker::Value) -> Self {
51 #[cfg(feature = "trace_lock")]
52 let id = {
53 use rand::{thread_rng, RngCore};
54
55 let mut buf = [0u8; 8];
56 thread_rng().fill_bytes(&mut buf);
57
58 u64::from_be_bytes(buf) as usize
59 };
60
61 Self {
62 inner_locker: Locker::new(value),
63 wakers: Default::default(),
64 #[cfg(feature = "trace_lock")]
65 id,
66 }
67 }
68}
69
70impl<Locker, Wakers, Mediator> AsyncLockable for AsyncLockableMaker<Locker, Wakers>
71where
72 Locker: Lockable + Send + Sync,
73 for<'a> Locker::GuardMut<'a>: Send + Unpin,
74 Wakers: Lockable + Send + Sync,
75 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
76 Mediator: AsyncLockableMediator + 'static,
77{
78 type GuardMut<'a>= AsyncLockableMakerGuard<'a, Locker, Wakers,Mediator>
79 where
80 Self: 'a;
81
82 type GuardMutFuture<'a> = AsyncLockableMakerFuture<'a, Locker, Wakers,Mediator>
83 where
84 Self: 'a;
85
86 #[track_caller]
87 fn lock(&self) -> Self::GuardMutFuture<'_> {
88 AsyncLockableMakerFuture {
89 locker: self,
90 wait_key: None,
91 #[cfg(feature = "trace_lock")]
92 caller: Location::caller(),
93 #[cfg(feature = "trace_lock")]
94 id: self.id,
95 }
96 }
97
98 fn unlock<'a>(guard: Self::GuardMut<'a>) -> &'a Self {
99 let locker = guard.locker;
100
101 drop(guard);
102
103 locker
104 }
105}
106
107pub struct AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
109where
110 Locker: Lockable,
111 Wakers: Lockable,
112 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
113 Mediator: AsyncLockableMediator,
114{
115 locker: &'a AsyncLockableMaker<Locker, Wakers>,
116 inner_guard: Option<Locker::GuardMut<'a>>,
117 #[cfg(feature = "trace_lock")]
118 caller: &'static Location<'static>,
119 #[cfg(feature = "trace_lock")]
120 id: usize,
121}
122
123impl<'a, Locker, Wakers, Mediator> AsyncGuardMut<'a>
124 for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
125where
126 Locker: Lockable + Send + Sync,
127 for<'b> Locker::GuardMut<'b>: Send + Unpin,
128 Wakers: Lockable + Send + Sync,
129 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
130 Mediator: AsyncLockableMediator + 'static,
131{
132 type Locker = AsyncLockableMaker<Locker, Wakers>;
133}
134
135impl<'a, Locker, Wakers, Mediator, T> ops::Deref
136 for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
137where
138 Locker: Lockable,
139 for<'c> Locker::GuardMut<'c>: ops::Deref<Target = T>,
140 Wakers: Lockable,
141 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
142 Mediator: AsyncLockableMediator,
143{
144 type Target = T;
145
146 fn deref(&self) -> &Self::Target {
147 self.inner_guard.as_deref().unwrap()
148 }
149}
150
151impl<'a, Locker, Wakers, Mediator, T> ops::DerefMut
152 for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
153where
154 Locker: Lockable,
155 for<'c> Locker::GuardMut<'c>: ops::DerefMut<Target = T>,
156 Wakers: Lockable,
157 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
158 Mediator: AsyncLockableMediator,
159{
160 fn deref_mut(&mut self) -> &mut Self::Target {
161 self.inner_guard.as_deref_mut().unwrap()
162 }
163}
164
165impl<'a, Locker, Wakers, Mediator> Drop for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
166where
167 Locker: Lockable,
168 Wakers: Lockable,
169 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
170 Mediator: AsyncLockableMediator,
171{
172 fn drop(&mut self) {
173 if let Some(guard) = self.inner_guard.take() {
174 drop(guard);
175
176 let mut wakers = self.locker.wakers.lock();
177
178 #[cfg(feature = "trace_lock")]
179 wakers.notify_one(self.id, self.caller);
180
181 #[cfg(not(feature = "trace_lock"))]
182 wakers.notify_all();
183 }
184 }
185}
186
187pub struct AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
189where
190 Locker: Lockable,
191 Wakers: Lockable,
192 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
193 Mediator: AsyncLockableMediator,
194{
195 locker: &'a AsyncLockableMaker<Locker, Wakers>,
196 wait_key: Option<usize>,
197 #[cfg(feature = "trace_lock")]
198 caller: &'static Location<'static>,
199 #[cfg(feature = "trace_lock")]
200 id: usize,
201}
202
203#[cfg(feature = "trace_lock")]
204impl<'a, Locker, Wakers, Mediator> Debug for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
205where
206 Locker: Lockable,
207 Wakers: Lockable,
208 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
209 Mediator: AsyncLockableMediator,
210{
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 write!(
213 f,
214 "mutex({}) caller: {}({})",
215 self.id,
216 self.caller.file(),
217 self.caller.line()
218 )
219 }
220}
221
222impl<'a, Locker, Wakers, Mediator> std::future::Future
223 for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
224where
225 Locker: Lockable,
226 Wakers: Lockable,
227 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
228 Mediator: AsyncLockableMediator,
229{
230 type Output = AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>;
231
232 fn poll(
233 mut self: std::pin::Pin<&mut Self>,
234 cx: &mut std::task::Context<'_>,
235 ) -> std::task::Poll<Self::Output> {
236 if let Some(guard) = self.locker.inner_locker.try_lock() {
237 #[cfg(feature = "trace_lock")]
238 log::trace!("{:?}, locked", self);
239
240 return std::task::Poll::Ready(AsyncLockableMakerGuard {
241 locker: self.locker,
242 inner_guard: Some(guard),
243 #[cfg(feature = "trace_lock")]
244 caller: self.caller,
245 #[cfg(feature = "trace_lock")]
246 id: self.id,
247 });
248 }
249
250 let mut wakers = self.locker.wakers.lock();
251
252 if let Some(guard) = self.locker.inner_locker.try_lock() {
255 #[cfg(feature = "trace_lock")]
256 log::trace!("{:?}, locked", self);
257
258 return std::task::Poll::Ready(AsyncLockableMakerGuard {
259 locker: self.locker,
260 inner_guard: Some(guard),
261 #[cfg(feature = "trace_lock")]
262 caller: self.caller,
263 #[cfg(feature = "trace_lock")]
264 id: self.id,
265 });
266 }
267
268 #[cfg(feature = "trace_lock")]
269 {
270 self.wait_key = Some(wakers.wait_lockable(cx, self.caller));
271 }
272
273 #[cfg(not(feature = "trace_lock"))]
274 {
275 self.wait_key = Some(wakers.wait_lockable(cx));
276 }
277
278 std::task::Poll::Pending
279 }
280}
281
282pub struct DefaultAsyncLockableMediator {
283 key_next: usize,
284 #[cfg(feature = "trace_lock")]
285 wakers: HashMap<usize, (&'static Location<'static>, Waker)>,
286 #[cfg(not(feature = "trace_lock"))]
287 wakers: HashMap<usize, Waker>,
288}
289
290impl Default for DefaultAsyncLockableMediator {
291 fn default() -> Self {
292 Self {
293 key_next: 0,
294 wakers: HashMap::new(),
295 }
296 }
297}
298
299impl AsyncLockableMediator for DefaultAsyncLockableMediator {
300 #[cfg(not(feature = "trace_lock"))]
301 fn wait_lockable(&mut self, cx: &mut std::task::Context<'_>) -> usize {
302 let key = self.key_next;
303 self.key_next += 1;
304
305 self.wakers.insert(key, cx.waker().clone());
306
307 key
308 }
309
310 #[cfg(feature = "trace_lock")]
311 fn wait_lockable(
312 &mut self,
313 cx: &mut std::task::Context<'_>,
314 tracer: &'static Location<'static>,
315 ) -> usize {
316 let key = self.key_next;
317 self.key_next += 1;
318
319 log::trace!(
320 "async lock pending,caller: {}({}), ptr={:?}",
321 tracer.file(),
322 tracer.line(),
323 self as *mut Self,
324 );
325
326 self.wakers.insert(key, (tracer, cx.waker().clone()));
327
328 key
329 }
330
331 fn cancel(&mut self, key: usize) -> bool {
332 #[cfg(feature = "trace_lock")]
333 {
334 if let Some((tracer, _)) = self.wakers.remove(&key) {
335 log::trace!(
336 "async locked remove pending,caller: {}({}), ptr={:?}",
337 tracer.file(),
338 tracer.line(),
339 self as *mut Self,
340 );
341
342 true
343 } else {
344 false
345 }
346 }
347
348 #[cfg(not(feature = "trace_lock"))]
349 {
350 return self.wakers.remove(&key).is_some();
351 }
352 }
353 #[cfg(feature = "trace_lock")]
354 fn notify_one(&mut self, id: usize, tracer: &'static std::panic::Location<'static>) {
355 log::trace!(
356 "mutex({}) caller: {}({}), notify one pending({}) waker",
357 id,
358 tracer.file(),
359 tracer.line(),
360 self.wakers.len(),
361 );
362
363 let mut keys = self.wakers.keys().cloned().collect::<Vec<_>>();
364
365 if !keys.is_empty() {
366 keys.sort();
367
368 let (waker_tracer, waker) = self.wakers.remove(&keys[0]).unwrap();
369 log::trace!(
370 "mutex({}) caller: {}({}), notify waker: {}({})",
371 id,
372 tracer.file(),
373 tracer.line(),
374 waker_tracer.file(),
375 waker_tracer.line(),
376 );
377
378 waker.wake();
379
380 log::trace!(
381 "mutex({}) caller: {}({}), notify waker: {}({}) -- success",
382 id,
383 tracer.file(),
384 tracer.line(),
385 waker_tracer.file(),
386 waker_tracer.line(),
387 );
388 }
389 }
390
391 #[cfg(not(feature = "trace_lock"))]
392 fn notify_one(&mut self) {
393 let mut keys = self.wakers.keys().cloned().collect::<Vec<_>>();
394
395 if !keys.is_empty() {
396 keys.sort();
397
398 let waker = self.wakers.remove(&keys[0]).unwrap();
399
400 waker.wake();
401 }
402 }
403
404 fn notify_all(&mut self) {
405 #[cfg(feature = "trace_lock")]
406 for (_, (_, waker)) in self.wakers.drain() {
407 waker.wake();
408 }
409
410 #[cfg(not(feature = "trace_lock"))]
411 for (_, waker) in self.wakers.drain() {
412 waker.wake();
413 }
414 }
415}