1use crate::bytecode::{TaskHandle, Value};
2use hashbrown::HashMap;
3use std::collections::VecDeque;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, Condvar, Mutex};
8use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
9
/// Heap-allocated, pinned future that resolves to either `Ok(Value)` or an
/// `Err` carrying a message string.
///
/// The trait object has no `Send` bound, so a future of this type is not
/// required to be movable to another thread.
pub(crate) type AsyncValueFuture =
    Pin<Box<dyn Future<Output = std::result::Result<Value, String>>>>;
12
13pub(crate) struct AsyncRegistry {
14 pub(crate) next_id: u64,
15 pub(crate) pending: HashMap<u64, AsyncTaskEntry>,
16}
17
18impl AsyncRegistry {
19 pub(crate) fn new() -> Self {
20 Self {
21 next_id: 1,
22 pending: HashMap::new(),
23 }
24 }
25
26 pub(crate) fn register(&mut self, entry: AsyncTaskEntry) -> u64 {
27 let id = self.next_id;
28 self.next_id += 1;
29 self.pending.insert(id, entry);
30 id
31 }
32
33 pub(crate) fn has_pending_for(&self, handle: TaskHandle) -> bool {
34 self.pending.values().any(|entry| match entry.target {
35 AsyncTaskTarget::ScriptTask(existing) | AsyncTaskTarget::NativeTask(existing) => {
36 existing == handle
37 }
38 })
39 }
40
41 pub(crate) fn is_empty(&self) -> bool {
42 self.pending.is_empty()
43 }
44}
45
46pub(crate) struct AsyncTaskEntry {
47 pub(crate) target: AsyncTaskTarget,
48 pub(crate) future: AsyncValueFuture,
49 wake_flag: Arc<WakeFlag>,
50 immediate_poll: bool,
51}
52
53#[derive(Clone, Copy)]
54pub(crate) enum AsyncTaskTarget {
55 ScriptTask(TaskHandle),
56 NativeTask(TaskHandle),
57}
58
59impl AsyncTaskEntry {
60 pub(crate) fn new(target: AsyncTaskTarget, future: AsyncValueFuture) -> Self {
61 Self {
62 target,
63 future,
64 wake_flag: Arc::new(WakeFlag::new()),
65 immediate_poll: true,
66 }
67 }
68
69 pub(crate) fn take_should_poll(&mut self) -> bool {
70 if self.immediate_poll {
71 self.immediate_poll = false;
72 true
73 } else {
74 self.wake_flag.take()
75 }
76 }
77
78 pub(crate) fn make_waker(&self) -> Waker {
79 make_async_waker(&self.wake_flag)
80 }
81}
82
/// Wake-up flag shared by an `AsyncTaskEntry` and every [`Waker`] built from it.
///
/// Wakers only ever set the flag; its owner consumes it with [`WakeFlag::take`].
struct WakeFlag {
    /// Set by wakers, cleared by `take`.
    pending: AtomicBool,
}

impl WakeFlag {
    /// Creates a flag that starts out set, so the first `take` returns `true`.
    fn new() -> Self {
        Self {
            pending: AtomicBool::from(true),
        }
    }

    /// Clears the flag and reports whether it was set beforehand.
    fn take(&self) -> bool {
        // AND-ing with `false` always stores `false` and yields the prior value.
        self.pending.fetch_and(false, Ordering::SeqCst)
    }

    /// Records a wake-up.
    fn wake(&self) {
        self.pending.store(true, Ordering::SeqCst);
    }
}

/// Builds a [`Waker`] whose data pointer owns one strong reference to `flag`.
///
/// Waking it, by value or by reference, sets the flag. Cloning and dropping it
/// adjust the `Arc` strong count.
fn make_async_waker(flag: &Arc<WakeFlag>) -> Waker {
    let data = Arc::into_raw(Arc::clone(flag)).cast::<()>();
    let raw = RawWaker::new(data, &ASYNC_WAKER_VTABLE);
    // SAFETY: `data` comes from `Arc::<WakeFlag>::into_raw` and carries one
    // strong reference, which is the invariant every function in
    // `ASYNC_WAKER_VTABLE` relies on. `WakeFlag` holds only an `AtomicBool` and
    // `Arc` counts atomically, so the vtable is safe to call from any thread.
    unsafe { Waker::from_raw(raw) }
}

/// Vtable `clone`: takes an extra strong reference for the new waker.
///
/// # Safety
///
/// `ptr` must come from `Arc::<WakeFlag>::into_raw`, and the waker being cloned
/// must still own its strong reference.
unsafe fn async_waker_clone(ptr: *const ()) -> RawWaker {
    // SAFETY: upheld by the caller contract above; the count is at least 1.
    unsafe { Arc::increment_strong_count(ptr.cast::<WakeFlag>()) };
    // `Arc::into_raw` on the new reference would yield this same address.
    RawWaker::new(ptr, &ASYNC_WAKER_VTABLE)
}

/// Vtable `wake`: sets the flag and releases the waker's strong reference.
///
/// # Safety
///
/// Same contract as [`async_waker_clone`]. The waker's reference is consumed.
unsafe fn async_waker_wake(ptr: *const ()) {
    // SAFETY: reclaims the one strong reference the waker owns; it is released
    // when `flag` goes out of scope at the end of this function.
    let flag = unsafe { Arc::from_raw(ptr.cast::<WakeFlag>()) };
    WakeFlag::wake(&flag);
}

/// Vtable `wake_by_ref`: sets the flag without touching the reference count.
///
/// # Safety
///
/// Same contract as [`async_waker_clone`].
unsafe fn async_waker_wake_by_ref(ptr: *const ()) {
    // SAFETY: `Arc::into_raw` points at the inner `WakeFlag`, and the waker's
    // strong reference keeps it alive for the duration of this borrow.
    let flag = unsafe { &*ptr.cast::<WakeFlag>() };
    flag.wake();
}

/// Vtable `drop`: releases the waker's strong reference.
///
/// # Safety
///
/// Same contract as [`async_waker_clone`]. The waker's reference is consumed.
unsafe fn async_waker_drop(ptr: *const ()) {
    // SAFETY: gives back exactly the one strong reference this waker owned.
    unsafe { Arc::decrement_strong_count(ptr.cast::<WakeFlag>()) };
}

/// Vtable for wakers produced by [`make_async_waker`]. Every entry expects a data
/// pointer obtained from `Arc::<WakeFlag>::into_raw` that owns one strong reference.
static ASYNC_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    async_waker_clone,
    async_waker_wake,
    async_waker_wake_by_ref,
    async_waker_drop,
);
140
141struct AsyncTaskQueueInner<Args, R> {
142 queue: Mutex<VecDeque<PendingAsyncTask<Args, R>>>,
143 condvar: Condvar,
144}
145
146pub struct AsyncTaskQueue<Args, R> {
147 inner: Arc<AsyncTaskQueueInner<Args, R>>,
148}
149
150impl<Args, R> Clone for AsyncTaskQueue<Args, R> {
151 fn clone(&self) -> Self {
152 Self {
153 inner: Arc::clone(&self.inner),
154 }
155 }
156}
157
158impl<Args, R> AsyncTaskQueue<Args, R> {
159 pub fn new() -> Self {
160 Self {
161 inner: Arc::new(AsyncTaskQueueInner {
162 queue: Mutex::new(VecDeque::new()),
163 condvar: Condvar::new(),
164 }),
165 }
166 }
167
168 pub fn push(&self, task: PendingAsyncTask<Args, R>) {
169 let mut guard = self.inner.queue.lock().unwrap();
170 guard.push_back(task);
171 self.inner.condvar.notify_one();
172 }
173
174 pub fn pop(&self) -> Option<PendingAsyncTask<Args, R>> {
175 let mut guard = self.inner.queue.lock().unwrap();
176 guard.pop_front()
177 }
178
179 pub fn pop_blocking(&self) -> PendingAsyncTask<Args, R> {
180 let mut guard = self.inner.queue.lock().unwrap();
181 loop {
182 if let Some(task) = guard.pop_front() {
183 return task;
184 }
185 guard = self.inner.condvar.wait(guard).unwrap();
186 }
187 }
188
189 pub fn len(&self) -> usize {
190 let guard = self.inner.queue.lock().unwrap();
191 guard.len()
192 }
193
194 pub fn is_empty(&self) -> bool {
195 let guard = self.inner.queue.lock().unwrap();
196 guard.is_empty()
197 }
198}
199
200pub struct PendingAsyncTask<Args, R> {
201 task: TaskHandle,
202 args: Args,
203 completer: AsyncCompleter<R>,
204}
205
206impl<Args, R> PendingAsyncTask<Args, R> {
207 pub(crate) fn new(task: TaskHandle, args: Args, completer: AsyncCompleter<R>) -> Self {
208 Self {
209 task,
210 args,
211 completer,
212 }
213 }
214
215 pub fn task(&self) -> TaskHandle {
216 self.task
217 }
218
219 pub fn args(&self) -> &Args {
220 &self.args
221 }
222
223 pub fn complete_ok(self, value: R) {
224 self.completer.complete_ok(value);
225 }
226
227 pub fn complete_err(self, message: impl Into<String>) {
228 self.completer.complete_err(message);
229 }
230}
231
/// State shared between an `AsyncCompleter` and its `AsyncSignalFuture`.
///
/// Lock order: whenever both mutexes are held at once, `result` is acquired
/// before `waker`. `AsyncSignalFuture::poll` relies on this to register its
/// waker atomically with respect to a completer publishing a result.
struct AsyncSignalState<T> {
    /// Delivered outcome; `Some` from completion until the future takes it.
    result: Mutex<Option<std::result::Result<T, String>>>,
    /// Waker registered by the most recent `Pending` poll, if not yet used.
    waker: Mutex<Option<Waker>>,
}

/// Sending half of a [`signal_pair`]. It delivers an `Ok`/`Err` outcome to the
/// paired [`AsyncSignalFuture`] and wakes whoever last polled it.
pub(crate) struct AsyncCompleter<T> {
    inner: Arc<AsyncSignalState<T>>,
}

// Manual impl so cloning does not require `T: Clone`; clones share one state.
impl<T> Clone for AsyncCompleter<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> AsyncCompleter<T> {
    /// Delivers a successful outcome. This is ignored while a previously
    /// delivered outcome is still waiting to be taken by the future.
    fn complete_ok(&self, value: T) {
        self.set_result(Ok(value));
    }

    /// Delivers a failure carrying `message`. This is ignored while a previously
    /// delivered outcome is still waiting to be taken by the future.
    fn complete_err(&self, message: impl Into<String>) {
        self.set_result(Err(message.into()));
    }

    /// Stores `value` unless an outcome is already waiting, then wakes the
    /// registered waker, if any.
    ///
    /// # Panics
    ///
    /// Panics if either mutex is poisoned.
    fn set_result(&self, value: std::result::Result<T, String>) {
        {
            let mut slot = self.inner.result.lock().unwrap();
            if slot.is_some() {
                // First completion wins; later ones are dropped.
                return;
            }
            *slot = Some(value);
        }

        // Take the waker in its own statement so the `waker` lock is released
        // before `wake()` runs. The waker is executor-supplied code and must not
        // run while we hold one of our locks.
        let waker = self.inner.waker.lock().unwrap().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

/// Receiving half of a [`signal_pair`]. It resolves to the outcome delivered by
/// an [`AsyncCompleter`].
pub(crate) struct AsyncSignalFuture<T> {
    inner: Arc<AsyncSignalState<T>>,
}

impl<T> Future for AsyncSignalFuture<T> {
    type Output = std::result::Result<T, String>;

    /// Takes the delivered outcome if present. Otherwise registers the current
    /// waker and returns `Poll::Pending`.
    ///
    /// Panics if either mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut result = self.inner.result.lock().unwrap();
        if let Some(outcome) = result.take() {
            return Poll::Ready(outcome);
        }

        // Register the waker while still holding the `result` lock. If we released
        // it first, a completer could store its result and find no waker in the
        // gap, and this task would never be woken.
        let mut waker_slot = self.inner.waker.lock().unwrap();
        match waker_slot.as_ref() {
            // Already registered for the same task: skip the clone.
            Some(existing) if existing.will_wake(cx.waker()) => {}
            _ => *waker_slot = Some(cx.waker().clone()),
        }
        Poll::Pending
    }
}

/// Creates a connected completer/future pair that share one signal state.
pub(crate) fn signal_pair<T>() -> (AsyncCompleter<T>, AsyncSignalFuture<T>) {
    let inner = Arc::new(AsyncSignalState {
        result: Mutex::new(None),
        waker: Mutex::new(None),
    });
    let completer = AsyncCompleter {
        inner: Arc::clone(&inner),
    };
    (completer, AsyncSignalFuture { inner })
}