1use std::{
2 collections::HashMap,
3 ops::{self, DerefMut},
4 task::Waker,
5};
6
7#[cfg(feature = "trace_lock")]
8use std::{fmt::Debug, panic::Location};
9
10use super::*;
11
/// Builds an asynchronous lock out of two synchronous pieces.
///
/// `inner_locker` is the lock that is actually acquired (via `try_lock`) when a lock future is
/// polled. `wakers` is a second lock-protected value in which pending lock futures register
/// their task wakers and through which a dropped guard notifies them.
///
/// `Default` is derived, so it is available exactly when both `Locker` and `Wakers` implement
/// `Default`; each field is then built from its own `Default` implementation.
#[derive(Default)]
pub struct AsyncLockableMaker<Locker, Wakers> {
    /// The synchronous lock that pending futures try to acquire.
    inner_locker: Locker,
    /// Lock-protected registry of the wakers of pending lock futures.
    wakers: Wakers,
}
30
31impl<Locker, Wakers> AsyncLockableMaker<Locker, Wakers>
32where
33 Locker: LockableNew,
34 Wakers: Default,
35{
36 pub fn new(value: Locker::Value) -> Self {
37 Self {
38 inner_locker: Locker::new(value),
39 wakers: Default::default(),
40 }
41 }
42}
43
44impl<Locker, Wakers, Mediator> AsyncLockable for AsyncLockableMaker<Locker, Wakers>
45where
46 Locker: Lockable + Send + Sync,
47 for<'a> Locker::GuardMut<'a>: Send + Unpin,
48 Wakers: Lockable + Send + Sync,
49 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
50 Mediator: AsyncLockableMediator + 'static,
51{
52 type GuardMut<'a>= AsyncLockableMakerGuard<'a, Locker, Wakers,Mediator>
53 where
54 Self: 'a;
55
56 type GuardMutFuture<'a> = AsyncLockableMakerFuture<'a, Locker, Wakers,Mediator>
57 where
58 Self: 'a;
59
60 #[track_caller]
61 fn lock(&self) -> Self::GuardMutFuture<'_> {
62 AsyncLockableMakerFuture {
63 locker: self,
64 wait_key: None,
65 #[cfg(feature = "trace_lock")]
66 caller: Location::caller(),
67 }
68 }
69
70 fn unlock<'a>(guard: Self::GuardMut<'a>) -> &'a Self {
71 let locker = guard.locker;
72
73 drop(guard);
74
75 locker
76 }
77}
78
79pub struct AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
81where
82 Locker: Lockable,
83 Wakers: Lockable,
84 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
85 Mediator: AsyncLockableMediator,
86{
87 locker: &'a AsyncLockableMaker<Locker, Wakers>,
88 inner_guard: Option<Locker::GuardMut<'a>>,
89}
90
91impl<'a, Locker, Wakers, Mediator> AsyncGuardMut<'a>
92 for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
93where
94 Locker: Lockable + Send + Sync,
95 for<'b> Locker::GuardMut<'b>: Send + Unpin,
96 Wakers: Lockable + Send + Sync,
97 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
98 Mediator: AsyncLockableMediator + 'static,
99{
100 type Locker = AsyncLockableMaker<Locker, Wakers>;
101}
102
103impl<'a, Locker, Wakers, Mediator, T> ops::Deref
104 for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
105where
106 Locker: Lockable,
107 for<'c> Locker::GuardMut<'c>: ops::Deref<Target = T>,
108 Wakers: Lockable,
109 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
110 Mediator: AsyncLockableMediator,
111{
112 type Target = T;
113
114 fn deref(&self) -> &Self::Target {
115 self.inner_guard.as_deref().unwrap()
116 }
117}
118
119impl<'a, Locker, Wakers, Mediator, T> ops::DerefMut
120 for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
121where
122 Locker: Lockable,
123 for<'c> Locker::GuardMut<'c>: ops::DerefMut<Target = T>,
124 Wakers: Lockable,
125 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
126 Mediator: AsyncLockableMediator,
127{
128 fn deref_mut(&mut self) -> &mut Self::Target {
129 self.inner_guard.as_deref_mut().unwrap()
130 }
131}
132
133impl<'a, Locker, Wakers, Mediator> Drop for AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>
134where
135 Locker: Lockable,
136 Wakers: Lockable,
137 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
138 Mediator: AsyncLockableMediator,
139{
140 fn drop(&mut self) {
141 if let Some(guard) = self.inner_guard.take() {
142 drop(guard);
143
144 let mut wakers = self.locker.wakers.lock();
145
146 wakers.notify_all();
147 }
148 }
149}
150
151pub struct AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
153where
154 Locker: Lockable,
155 Wakers: Lockable,
156 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
157 Mediator: AsyncLockableMediator,
158{
159 locker: &'a AsyncLockableMaker<Locker, Wakers>,
160 wait_key: Option<usize>,
161 #[cfg(feature = "trace_lock")]
162 caller: &'static Location<'static>,
163}
164
/// Trace-only `Debug` output: prints where the lock was requested, as `file(line)`.
#[cfg(feature = "trace_lock")]
impl<'a, Locker, Wakers, Mediator> Debug for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
where
    Locker: Lockable,
    Wakers: Lockable,
    for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
    Mediator: AsyncLockableMediator,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (file, line) = (self.caller.file(), self.caller.line());

        write!(f, "caller: {}({})", file, line)
    }
}
177
178impl<'a, Locker, Wakers, Mediator> std::future::Future
179 for AsyncLockableMakerFuture<'a, Locker, Wakers, Mediator>
180where
181 Locker: Lockable,
182 Wakers: Lockable,
183 for<'b> Wakers::GuardMut<'b>: DerefMut<Target = Mediator>,
184 Mediator: AsyncLockableMediator,
185{
186 type Output = AsyncLockableMakerGuard<'a, Locker, Wakers, Mediator>;
187
188 fn poll(
189 mut self: std::pin::Pin<&mut Self>,
190 cx: &mut std::task::Context<'_>,
191 ) -> std::task::Poll<Self::Output> {
192 #[cfg(feature = "trace_lock")]
193 log::trace!("async try lock, {:?}", self);
194
195 if let Some(guard) = self.locker.inner_locker.try_lock() {
196 #[cfg(feature = "trace_lock")]
197 log::trace!("async locked, {:?}", self);
198
199 return std::task::Poll::Ready(AsyncLockableMakerGuard {
200 locker: self.locker,
201 inner_guard: Some(guard),
202 });
203 }
204
205 let mut wakers = self.locker.wakers.lock();
206 self.wait_key = Some(wakers.wait_lockable(cx));
207
208 if let Some(guard) = self.locker.inner_locker.try_lock() {
211 #[cfg(feature = "trace_lock")]
212 log::trace!("async locked, {:?}", self);
213
214 wakers.cancel(self.wait_key.take().unwrap());
215
216 return std::task::Poll::Ready(AsyncLockableMakerGuard {
217 locker: self.locker,
218 inner_guard: Some(guard),
219 });
220 }
221
222 #[cfg(feature = "trace_lock")]
223 log::trace!("async lock pending, {:?}", self);
224
225 std::task::Poll::Pending
226 }
227}
228
/// Waker registry backed by a `HashMap`, keyed by a monotonically assigned `usize`.
///
/// `Default` is derived: a fresh mediator starts with `key_next == 0` and no registered
/// wakers.
#[derive(Default)]
pub struct DefaultAsyncLockableMediator {
    /// Key that the next registration will receive.
    key_next: usize,
    /// Currently registered wakers, indexed by the key handed out at registration.
    wakers: HashMap<usize, Waker>,
}
242
243impl AsyncLockableMediator for DefaultAsyncLockableMediator {
244 fn wait_lockable(&mut self, cx: &mut std::task::Context<'_>) -> usize {
245 let key = self.key_next;
246 self.key_next += 1;
247
248 self.wakers.insert(key, cx.waker().clone());
249
250 key
251 }
252
253 fn cancel(&mut self, key: usize) -> bool {
254 self.wakers.remove(&key).is_some()
255 }
256
257 fn notify_one(&mut self) {
258 if let Some(key) = self.wakers.keys().next().map(|key| *key) {
259 self.wakers.remove(&key).unwrap().wake();
260 }
261 }
262
263 fn notify_all(&mut self) {
264 for (_, waker) in self.wakers.drain() {
265 waker.wake();
266 }
267 }
268}