rasi_ext/utils/sync/
maker.rs

1use std::{
2    collections::HashMap,
3    ops::{self, DerefMut},
4    task::Waker,
5};
6
7#[cfg(feature = "trace_lock")]
8use std::{fmt::Debug, panic::Location};
9
10use super::*;
11
/// Type factory for [`AsyncLockable`]
///
/// Pairs a synchronous `Locker` (the lock that actually protects the value) with a
/// `Wakers` lock whose guard dereferences to the mediator tracking tasks that are
/// waiting for the `Locker`. Calling `lock()` through the [`AsyncLockable`] impl
/// returns an [`AsyncLockableMakerFuture`].
pub struct AsyncLockableMaker<Locker, Wakers> {
    /// Underlying lock; acquired with `try_lock` when a lock future is polled.
    inner_locker: Locker,
    /// Lock-protected mediator holding the wakers of pending lock futures.
    wakers: Wakers,
    /// Random identifier, drawn at construction, that is printed in `trace_lock` log output.
    #[cfg(feature = "trace_lock")]
    id: usize,
}
19
20impl<Locker, Wakers> Default for AsyncLockableMaker<Locker, Wakers>
21where
22    Locker: Default,
23    Wakers: Default,
24{
25    fn default() -> Self {
26        #[cfg(feature = "trace_lock")]
27        let id = {
28            use rand::{thread_rng, RngCore};
29
30            let mut buf = [0u8; 8];
31            thread_rng().fill_bytes(&mut buf);
32
33            u64::from_be_bytes(buf) as usize
34        };
35
36        Self {
37            inner_locker: Default::default(),
38            wakers: Default::default(),
39            #[cfg(feature = "trace_lock")]
40            id,
41        }
42    }
43}
44
45impl<Locker, Wakers> AsyncLockableMaker<Locker, Wakers>
46where
47    Locker: LockableNew,
48    Wakers: Default,
49{
50    pub fn new(value: Locker::Value) -> Self {
51        #[cfg(feature = "trace_lock")]
52        let id = {
53            use rand::{thread_rng, RngCore};
54
55            let mut buf = [0u8; 8];
56            thread_rng().fill_bytes(&mut buf);
57
58            u64::from_be_bytes(buf) as usize
59        };
60
61        Self {
62            inner_locker: Locker::new(value),
63            wakers: Default::default(),
64            #[cfg(feature = "trace_lock")]
65            id,
66        }
67    }
68}
69
70impl<Locker, Wakers, Mediator> AsyncLockable for AsyncLockableMaker<Locker, Wakers>
71where
72    Locker: Lockable + Send + Sync,
73    for<'a> Locker::GuardMut<'a>: Send + Unpin,
74    Wakers: Lockable + Send + Sync,
75    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
76    Mediator: AsyncLockableMediator + 'static,
77{
78    type GuardMut<'a>= AsyncLockableMakerGuard<'a, Locker, Wakers,Mediator>
79    where
80        Self: 'a;
81
82    type GuardMutFuture<'a> = AsyncLockableMakerFuture<'a, Locker, Wakers,Mediator>
83    where
84        Self: 'a;
85
86    #[track_caller]
87    fn lock(&self) -> Self::GuardMutFuture<'_> {
88        AsyncLockableMakerFuture {
89            locker: self,
90            wait_key: None,
91            #[cfg(feature = "trace_lock")]
92            caller: Location::caller(),
93            #[cfg(feature = "trace_lock")]
94            id: self.id,
95        }
96    }
97
98    fn unlock<'a>(guard: Self::GuardMut<'a>) -> &'a Self {
99        let locker = guard.locker;
100
101        drop(guard);
102
103        locker
104    }
105}
106
/// RAII `Guard` type for [`AsyncLockableMaker`]
///
/// Dereferences to the value protected by the inner locker. When dropped it first
/// releases the inner guard and then notifies waiting lock futures through the
/// maker's mediator.
pub struct AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
where
    Locker: Lockable,
    Wakers: Lockable,
    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
    Mediator: AsyncLockableMediator,
{
    /// The maker this guard was obtained from.
    locker: &'a AsyncLockableMaker<Locker, Wakers>,
    /// Guard of the inner locker: `Some` from construction until `Drop` takes it.
    inner_guard: Option<Locker::GuardMut<'a>>,
    /// Source location of the `lock()` call that produced this guard (tracing only).
    #[cfg(feature = "trace_lock")]
    caller: &'static Location<'static>,
    /// Copy of the maker's trace id (tracing only).
    #[cfg(feature = "trace_lock")]
    id: usize,
}
122
/// Declares [`AsyncLockableMaker`] as the locker type associated with
/// [`AsyncLockableMakerGuard`].
impl<'a, Locker, Wakers, Mediator> AsyncGuardMut<'a>
    for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
where
    Locker: Lockable + Send + Sync,
    for<'b> Locker::GuardMut<'b>: Send + Unpin,
    Wakers: Lockable + Send + Sync,
    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
    Mediator: AsyncLockableMediator + 'static,
{
    type Locker = AsyncLockableMaker<Locker, Wakers>;
}
134
135impl<'a, Locker, Wakers, Mediator, T> ops::Deref
136    for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
137where
138    Locker: Lockable,
139    for<'c> Locker::GuardMut<'c>: ops::Deref<Target = T>,
140    Wakers: Lockable,
141    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
142    Mediator: AsyncLockableMediator,
143{
144    type Target = T;
145
146    fn deref(&self) -> &Self::Target {
147        self.inner_guard.as_deref().unwrap()
148    }
149}
150
151impl<'a, Locker, Wakers, Mediator, T> ops::DerefMut
152    for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
153where
154    Locker: Lockable,
155    for<'c> Locker::GuardMut<'c>: ops::DerefMut<Target = T>,
156    Wakers: Lockable,
157    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
158    Mediator: AsyncLockableMediator,
159{
160    fn deref_mut(&mut self) -> &mut Self::Target {
161        self.inner_guard.as_deref_mut().unwrap()
162    }
163}
164
165impl<'a, Locker, Wakers, Mediator> Drop for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
166where
167    Locker: Lockable,
168    Wakers: Lockable,
169    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
170    Mediator: AsyncLockableMediator,
171{
172    fn drop(&mut self) {
173        if let Some(guard) = self.inner_guard.take() {
174            drop(guard);
175
176            let mut wakers = self.locker.wakers.lock();
177
178            #[cfg(feature = "trace_lock")]
179            wakers.notify_one(self.id, self.caller);
180
181            #[cfg(not(feature = "trace_lock"))]
182            wakers.notify_all();
183        }
184    }
185}
186
187/// Future created by [`lock`](AsyncLockableMaker::lock) function.
188pub struct AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
189where
190    Locker: Lockable,
191    Wakers: Lockable,
192    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
193    Mediator: AsyncLockableMediator,
194{
195    locker: &'a AsyncLockableMaker<Locker, Wakers>,
196    wait_key: Option<usize>,
197    #[cfg(feature = "trace_lock")]
198    caller: &'static Location<'static>,
199    #[cfg(feature = "trace_lock")]
200    id: usize,
201}
202
203#[cfg(feature = "trace_lock")]
204impl<'a, Locker, Wakers, Mediator> Debug for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
205where
206    Locker: Lockable,
207    Wakers: Lockable,
208    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
209    Mediator: AsyncLockableMediator,
210{
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        write!(
213            f,
214            "mutex({}) caller: {}({})",
215            self.id,
216            self.caller.file(),
217            self.caller.line()
218        )
219    }
220}
221
222impl<'a, Locker, Wakers, Mediator> std::future::Future
223    for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
224where
225    Locker: Lockable,
226    Wakers: Lockable,
227    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
228    Mediator: AsyncLockableMediator,
229{
230    type Output = AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>;
231
232    fn poll(
233        mut self: std::pin::Pin<&mut Self>,
234        cx: &mut std::task::Context<'_>,
235    ) -> std::task::Poll<Self::Output> {
236        if let Some(guard) = self.locker.inner_locker.try_lock() {
237            #[cfg(feature = "trace_lock")]
238            log::trace!("{:?}, locked", self);
239
240            return std::task::Poll::Ready(AsyncLockableMakerGuard {
241                locker: self.locker,
242                inner_guard: Some(guard),
243                #[cfg(feature = "trace_lock")]
244                caller: self.caller,
245                #[cfg(feature = "trace_lock")]
246                id: self.id,
247            });
248        }
249
250        let mut wakers = self.locker.wakers.lock();
251
252        // Ensure that we haven't raced `MutexGuard::drop`'s unlock path by
253        // attempting to acquire the lock again
254        if let Some(guard) = self.locker.inner_locker.try_lock() {
255            #[cfg(feature = "trace_lock")]
256            log::trace!("{:?}, locked", self);
257
258            return std::task::Poll::Ready(AsyncLockableMakerGuard {
259                locker: self.locker,
260                inner_guard: Some(guard),
261                #[cfg(feature = "trace_lock")]
262                caller: self.caller,
263                #[cfg(feature = "trace_lock")]
264                id: self.id,
265            });
266        }
267
268        #[cfg(feature = "trace_lock")]
269        {
270            self.wait_key = Some(wakers.wait_lockable(cx, self.caller));
271        }
272
273        #[cfg(not(feature = "trace_lock"))]
274        {
275            self.wait_key = Some(wakers.wait_lockable(cx));
276        }
277
278        std::task::Poll::Pending
279    }
280}
281
/// Default [`AsyncLockableMediator`]: stores the wakers of tasks waiting for a lock,
/// keyed by an increasing wait key.
#[derive(Debug)]
pub struct DefaultAsyncLockableMediator {
    /// Next wait key to hand out; incremented on every registration.
    key_next: usize,
    /// Registered waiters, each with the source location that requested the lock.
    #[cfg(feature = "trace_lock")]
    wakers: HashMap<usize, (&'static Location<'static>, Waker)>,
    /// Registered waiters keyed by wait key.
    #[cfg(not(feature = "trace_lock"))]
    wakers: HashMap<usize, Waker>,
}
289
290impl Default for DefaultAsyncLockableMediator {
291    fn default() -> Self {
292        Self {
293            key_next: 0,
294            wakers: HashMap::new(),
295        }
296    }
297}
298
299impl AsyncLockableMediator for DefaultAsyncLockableMediator {
300    #[cfg(not(feature = "trace_lock"))]
301    fn wait_lockable(&mut self, cx: &mut std::task::Context<'_>) -> usize {
302        let key = self.key_next;
303        self.key_next += 1;
304
305        self.wakers.insert(key, cx.waker().clone());
306
307        key
308    }
309
310    #[cfg(feature = "trace_lock")]
311    fn wait_lockable(
312        &mut self,
313        cx: &mut std::task::Context<'_>,
314        tracer: &'static Location<'static>,
315    ) -> usize {
316        let key = self.key_next;
317        self.key_next += 1;
318
319        log::trace!(
320            "async lock pending,caller: {}({}), ptr={:?}",
321            tracer.file(),
322            tracer.line(),
323            self as *mut Self,
324        );
325
326        self.wakers.insert(key, (tracer, cx.waker().clone()));
327
328        key
329    }
330
331    fn cancel(&mut self, key: usize) -> bool {
332        #[cfg(feature = "trace_lock")]
333        {
334            if let Some((tracer, _)) = self.wakers.remove(&key) {
335                log::trace!(
336                    "async locked remove pending,caller: {}({}), ptr={:?}",
337                    tracer.file(),
338                    tracer.line(),
339                    self as *mut Self,
340                );
341
342                true
343            } else {
344                false
345            }
346        }
347
348        #[cfg(not(feature = "trace_lock"))]
349        {
350            return self.wakers.remove(&key).is_some();
351        }
352    }
353    #[cfg(feature = "trace_lock")]
354    fn notify_one(&mut self, id: usize, tracer: &'static std::panic::Location<'static>) {
355        log::trace!(
356            "mutex({}) caller: {}({}), notify one pending({}) waker",
357            id,
358            tracer.file(),
359            tracer.line(),
360            self.wakers.len(),
361        );
362
363        let mut keys = self.wakers.keys().cloned().collect::<Vec<_>>();
364
365        if !keys.is_empty() {
366            keys.sort();
367
368            let (waker_tracer, waker) = self.wakers.remove(&keys[0]).unwrap();
369            log::trace!(
370                "mutex({}) caller: {}({}), notify waker: {}({})",
371                id,
372                tracer.file(),
373                tracer.line(),
374                waker_tracer.file(),
375                waker_tracer.line(),
376            );
377
378            waker.wake();
379
380            log::trace!(
381                "mutex({}) caller: {}({}), notify waker: {}({}) -- success",
382                id,
383                tracer.file(),
384                tracer.line(),
385                waker_tracer.file(),
386                waker_tracer.line(),
387            );
388        }
389    }
390
391    #[cfg(not(feature = "trace_lock"))]
392    fn notify_one(&mut self) {
393        let mut keys = self.wakers.keys().cloned().collect::<Vec<_>>();
394
395        if !keys.is_empty() {
396            keys.sort();
397
398            let waker = self.wakers.remove(&keys[0]).unwrap();
399
400            waker.wake();
401        }
402    }
403
404    fn notify_all(&mut self) {
405        #[cfg(feature = "trace_lock")]
406        for (_, (_, waker)) in self.wakers.drain() {
407            waker.wake();
408        }
409
410        #[cfg(not(feature = "trace_lock"))]
411        for (_, waker) in self.wakers.drain() {
412            waker.wake();
413        }
414    }
415}