hala_sync/
maker.rs

1use std::{
2    collections::HashMap,
3    ops::{self, DerefMut},
4    task::Waker,
5};
6
7#[cfg(feature = "trace_lock")]
8use std::{fmt::Debug, panic::Location};
9
10use super::*;
11
/// Type factory for [`AsyncLockable`].
///
/// Combines a synchronous lock (`Locker`) that guards the user's value with a
/// second lock (`Wakers`) around the mediator that records tasks waiting for
/// the first lock.
///
/// `Default` is derived: it is available when both `Locker: Default` and
/// `Wakers: Default`, and default-constructs each field.
#[derive(Default)]
pub struct AsyncLockableMaker<Locker, Wakers> {
    /// Lock guarding the protected value.
    inner_locker: Locker,
    /// Lock around the waiter mediator used to park and wake lock futures.
    wakers: Wakers,
}
30
31impl<Locker, Wakers> AsyncLockableMaker<Locker, Wakers>
32where
33    Locker: LockableNew,
34    Wakers: Default,
35{
36    pub fn new(value: Locker::Value) -> Self {
37        Self {
38            inner_locker: Locker::new(value),
39            wakers: Default::default(),
40        }
41    }
42}
43
44impl<Locker, Wakers, Mediator> AsyncLockable for AsyncLockableMaker<Locker, Wakers>
45where
46    Locker: Lockable + Send + Sync,
47    for<'a> Locker::GuardMut<'a>: Send + Unpin,
48    Wakers: Lockable + Send + Sync,
49    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
50    Mediator: AsyncLockableMediator + 'static,
51{
52    type GuardMut<'a>= AsyncLockableMakerGuard<'a, Locker, Wakers,Mediator>
53    where
54        Self: 'a;
55
56    type GuardMutFuture<'a> = AsyncLockableMakerFuture<'a, Locker, Wakers,Mediator>
57    where
58        Self: 'a;
59
60    #[track_caller]
61    fn lock(&self) -> Self::GuardMutFuture<'_> {
62        AsyncLockableMakerFuture {
63            locker: self,
64            wait_key: None,
65            #[cfg(feature = "trace_lock")]
66            caller: Location::caller(),
67        }
68    }
69
70    fn unlock<'a>(guard: Self::GuardMut<'a>) -> &'a Self {
71        let locker = guard.locker;
72
73        drop(guard);
74
75        locker
76    }
77}
78
79/// RAII `Guard` type for [`AsyncLockableMaker`]
80pub struct AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
81where
82    Locker: Lockable,
83    Wakers: Lockable,
84    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
85    Mediator: AsyncLockableMediator,
86{
87    locker: &'a AsyncLockableMaker<Locker, Wakers>,
88    inner_guard: Option<Locker::GuardMut<'a>>,
89}
90
91impl<'a, Locker, Wakers, Mediator> AsyncGuardMut<'a>
92    for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
93where
94    Locker: Lockable + Send + Sync,
95    for<'b> Locker::GuardMut<'b>: Send + Unpin,
96    Wakers: Lockable + Send + Sync,
97    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
98    Mediator: AsyncLockableMediator + 'static,
99{
100    type Locker = AsyncLockableMaker<Locker, Wakers>;
101}
102
103impl<'a, Locker, Wakers, Mediator, T> ops::Deref
104    for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
105where
106    Locker: Lockable,
107    for<'c> Locker::GuardMut<'c>: ops::Deref<Target = T>,
108    Wakers: Lockable,
109    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
110    Mediator: AsyncLockableMediator,
111{
112    type Target = T;
113
114    fn deref(&self) -> &Self::Target {
115        self.inner_guard.as_deref().unwrap()
116    }
117}
118
119impl<'a, Locker, Wakers, Mediator, T> ops::DerefMut
120    for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
121where
122    Locker: Lockable,
123    for<'c> Locker::GuardMut<'c>: ops::DerefMut<Target = T>,
124    Wakers: Lockable,
125    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
126    Mediator: AsyncLockableMediator,
127{
128    fn deref_mut(&mut self) -> &mut Self::Target {
129        self.inner_guard.as_deref_mut().unwrap()
130    }
131}
132
133impl<'a, Locker, Wakers, Mediator> Drop for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
134where
135    Locker: Lockable,
136    Wakers: Lockable,
137    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
138    Mediator: AsyncLockableMediator,
139{
140    fn drop(&mut self) {
141        if let Some(guard) = self.inner_guard.take() {
142            drop(guard);
143
144            let mut wakers = self.locker.wakers.lock();
145
146            wakers.notify_all();
147        }
148    }
149}
150
151/// Future created by [`lock`](AsyncLockableMaker::lock) function.
152pub struct AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
153where
154    Locker: Lockable,
155    Wakers: Lockable,
156    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
157    Mediator: AsyncLockableMediator,
158{
159    locker: &'a AsyncLockableMaker<Locker, Wakers>,
160    wait_key: Option<usize>,
161    #[cfg(feature = "trace_lock")]
162    caller: &'static Location<'static>,
163}
164
#[cfg(feature = "trace_lock")]
impl<'a, Locker, Wakers, Mediator> Debug for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
where
    Locker: Lockable,
    Wakers: Lockable,
    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
    Mediator: AsyncLockableMediator,
{
    /// Renders the future as the source location of its `lock` call,
    /// in the form `caller: <file>(<line>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let site = self.caller;
        let (file, line) = (site.file(), site.line());
        f.write_fmt(format_args!("caller: {}({})", file, line))
    }
}
177
178impl<'a, Locker, Wakers, Mediator> std::future::Future
179    for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
180where
181    Locker: Lockable,
182    Wakers: Lockable,
183    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
184    Mediator: AsyncLockableMediator,
185{
186    type Output = AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>;
187
188    fn poll(
189        mut self: std::pin::Pin<&mut Self>,
190        cx: &mut std::task::Context<'_>,
191    ) -> std::task::Poll<Self::Output> {
192        #[cfg(feature = "trace_lock")]
193        log::trace!("async try lock, {:?}", self);
194
195        if let Some(guard) = self.locker.inner_locker.try_lock() {
196            #[cfg(feature = "trace_lock")]
197            log::trace!("async locked, {:?}", self);
198
199            return std::task::Poll::Ready(AsyncLockableMakerGuard {
200                locker: self.locker,
201                inner_guard: Some(guard),
202            });
203        }
204
205        let mut wakers = self.locker.wakers.lock();
206        self.wait_key = Some(wakers.wait_lockable(cx));
207
208        // Ensure that we haven't raced `MutexGuard::drop`'s unlock path by
209        // attempting to acquire the lock again
210        if let Some(guard) = self.locker.inner_locker.try_lock() {
211            #[cfg(feature = "trace_lock")]
212            log::trace!("async locked, {:?}", self);
213
214            wakers.cancel(self.wait_key.take().unwrap());
215
216            return std::task::Poll::Ready(AsyncLockableMakerGuard {
217                locker: self.locker,
218                inner_guard: Some(guard),
219            });
220        }
221
222        #[cfg(feature = "trace_lock")]
223        log::trace!("async lock pending, {:?}", self);
224
225        std::task::Poll::Pending
226    }
227}
228
229pub struct DefaultAsyncLockableMediator {
230    key_next: usize,
231    wakers: HashMap<usize, Waker>,
232}
233
234impl Default for DefaultAsyncLockableMediator {
235    fn default() -> Self {
236        Self {
237            key_next: 0,
238            wakers: HashMap::new(),
239        }
240    }
241}
242
243impl AsyncLockableMediator for DefaultAsyncLockableMediator {
244    fn wait_lockable(&mut self, cx: &mut std::task::Context<'_>) -> usize {
245        let key = self.key_next;
246        self.key_next += 1;
247
248        self.wakers.insert(key, cx.waker().clone());
249
250        key
251    }
252
253    fn cancel(&mut self, key: usize) -> bool {
254        self.wakers.remove(&key).is_some()
255    }
256
257    fn notify_one(&mut self) {
258        if let Some(key) = self.wakers.keys().next().map(|key| *key) {
259            self.wakers.remove(&key).unwrap().wake();
260        }
261    }
262
263    fn notify_all(&mut self) {
264        for (_, waker) in self.wakers.drain() {
265            waker.wake();
266        }
267    }
268}