Skip to main content

lust/embed/
async_runtime.rs

1use crate::bytecode::{TaskHandle, Value};
2use hashbrown::HashMap;
3use std::collections::VecDeque;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::sync::{Arc, Condvar, Mutex};
8use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
9
10pub(crate) type AsyncValueFuture =
11    Pin<Box<dyn Future<Output = std::result::Result<Value, String>>>>;
12
13pub(crate) struct AsyncRegistry {
14    pub(crate) next_id: u64,
15    pub(crate) pending: HashMap<u64, AsyncTaskEntry>,
16}
17
18impl AsyncRegistry {
19    pub(crate) fn new() -> Self {
20        Self {
21            next_id: 1,
22            pending: HashMap::new(),
23        }
24    }
25
26    pub(crate) fn register(&mut self, entry: AsyncTaskEntry) -> u64 {
27        let id = self.next_id;
28        self.next_id += 1;
29        self.pending.insert(id, entry);
30        id
31    }
32
33    pub(crate) fn has_pending_for(&self, handle: TaskHandle) -> bool {
34        self.pending.values().any(|entry| match entry.target {
35            AsyncTaskTarget::ScriptTask(existing) | AsyncTaskTarget::NativeTask(existing) => {
36                existing == handle
37            }
38        })
39    }
40
41    pub(crate) fn is_empty(&self) -> bool {
42        self.pending.is_empty()
43    }
44}
45
46pub(crate) struct AsyncTaskEntry {
47    pub(crate) target: AsyncTaskTarget,
48    pub(crate) future: AsyncValueFuture,
49    wake_flag: Arc<WakeFlag>,
50    immediate_poll: bool,
51}
52
53#[derive(Clone, Copy)]
54pub(crate) enum AsyncTaskTarget {
55    ScriptTask(TaskHandle),
56    NativeTask(TaskHandle),
57}
58
59impl AsyncTaskEntry {
60    pub(crate) fn new(target: AsyncTaskTarget, future: AsyncValueFuture) -> Self {
61        Self {
62            target,
63            future,
64            wake_flag: Arc::new(WakeFlag::new()),
65            immediate_poll: true,
66        }
67    }
68
69    pub(crate) fn take_should_poll(&mut self) -> bool {
70        if self.immediate_poll {
71            self.immediate_poll = false;
72            true
73        } else {
74            self.wake_flag.take()
75        }
76    }
77
78    pub(crate) fn make_waker(&self) -> Waker {
79        make_async_waker(&self.wake_flag)
80    }
81}
82
/// Wake-up flag shared between an async task entry and the wakers it hands
/// out. Wakers only ever raise it; the owner lowers it with [`WakeFlag::take`].
struct WakeFlag {
    /// `true` while a wake-up is outstanding.
    pending: AtomicBool,
}

impl WakeFlag {
    /// Creates a flag that starts raised, so the first `take` returns `true`.
    fn new() -> Self {
        WakeFlag {
            pending: AtomicBool::new(true),
        }
    }

    /// Lowers the flag and returns whether it was raised.
    fn take(&self) -> bool {
        self.pending.swap(false, Ordering::SeqCst)
    }

    /// Raises the flag.
    fn wake(&self) {
        self.pending.store(true, Ordering::SeqCst)
    }
}

/// Builds a [`Waker`] that raises `flag` when woken.
///
/// The waker owns one strong reference to `flag`. Cloning the waker adds a
/// reference and dropping it releases one.
fn make_async_waker(flag: &Arc<WakeFlag>) -> Waker {
    let data: *const () = Arc::into_raw(Arc::clone(flag)).cast();
    // SAFETY: `data` comes from `Arc::into_raw` on an `Arc<WakeFlag>` and owns
    // one strong count, which the vtable functions below take over. `WakeFlag`
    // only holds an `AtomicBool`, so it is `Send + Sync` as `Waker` requires.
    unsafe { Waker::from_raw(RawWaker::new(data, &ASYNC_WAKER_VTABLE)) }
}

/// Vtable `clone`: adds a strong reference and returns a raw waker sharing `ptr`.
///
/// # Safety
/// `ptr` must come from a waker built by `make_async_waker` (or a clone of
/// one) that still owns its strong reference.
unsafe fn async_waker_clone(ptr: *const ()) -> RawWaker {
    // SAFETY: per the contract above, `ptr` is an `Arc::into_raw` pointer to a
    // live `WakeFlag`.
    unsafe { Arc::increment_strong_count(ptr.cast::<WakeFlag>()) };
    RawWaker::new(ptr, &ASYNC_WAKER_VTABLE)
}

/// Vtable `wake`: raises the flag and releases the waker's strong reference.
///
/// # Safety
/// Same as [`async_waker_clone`]. Ownership of the reference passes to this call.
unsafe fn async_waker_wake(ptr: *const ()) {
    // SAFETY: `ptr` owns one strong count, which is reclaimed here and released
    // when `owned` goes out of scope.
    let owned = unsafe { Arc::from_raw(ptr.cast::<WakeFlag>()) };
    owned.wake();
}

/// Vtable `wake_by_ref`: raises the flag without touching the reference count.
///
/// # Safety
/// Same as [`async_waker_clone`].
unsafe fn async_waker_wake_by_ref(ptr: *const ()) {
    // SAFETY: `ptr` points at the `WakeFlag` inside a live `Arc`, so borrowing
    // it for the duration of this call is valid.
    let flag = unsafe { &*ptr.cast::<WakeFlag>() };
    flag.wake();
}

/// Vtable `drop`: releases the waker's strong reference.
///
/// # Safety
/// Same as [`async_waker_clone`]. Ownership of the reference passes to this call.
unsafe fn async_waker_drop(ptr: *const ()) {
    // SAFETY: `ptr` owns exactly one strong count, which is given up here.
    unsafe { Arc::decrement_strong_count(ptr.cast::<WakeFlag>()) };
}

static ASYNC_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    async_waker_clone,
    async_waker_wake,
    async_waker_wake_by_ref,
    async_waker_drop,
);
140
141struct AsyncTaskQueueInner<Args, R> {
142    queue: Mutex<VecDeque<PendingAsyncTask<Args, R>>>,
143    condvar: Condvar,
144}
145
146pub struct AsyncTaskQueue<Args, R> {
147    inner: Arc<AsyncTaskQueueInner<Args, R>>,
148}
149
150impl<Args, R> Clone for AsyncTaskQueue<Args, R> {
151    fn clone(&self) -> Self {
152        Self {
153            inner: Arc::clone(&self.inner),
154        }
155    }
156}
157
158impl<Args, R> AsyncTaskQueue<Args, R> {
159    pub fn new() -> Self {
160        Self {
161            inner: Arc::new(AsyncTaskQueueInner {
162                queue: Mutex::new(VecDeque::new()),
163                condvar: Condvar::new(),
164            }),
165        }
166    }
167
168    pub fn push(&self, task: PendingAsyncTask<Args, R>) {
169        let mut guard = self.inner.queue.lock().unwrap();
170        guard.push_back(task);
171        self.inner.condvar.notify_one();
172    }
173
174    pub fn pop(&self) -> Option<PendingAsyncTask<Args, R>> {
175        let mut guard = self.inner.queue.lock().unwrap();
176        guard.pop_front()
177    }
178
179    pub fn pop_blocking(&self) -> PendingAsyncTask<Args, R> {
180        let mut guard = self.inner.queue.lock().unwrap();
181        loop {
182            if let Some(task) = guard.pop_front() {
183                return task;
184            }
185            guard = self.inner.condvar.wait(guard).unwrap();
186        }
187    }
188
189    pub fn len(&self) -> usize {
190        let guard = self.inner.queue.lock().unwrap();
191        guard.len()
192    }
193
194    pub fn is_empty(&self) -> bool {
195        let guard = self.inner.queue.lock().unwrap();
196        guard.is_empty()
197    }
198}
199
200pub struct PendingAsyncTask<Args, R> {
201    task: TaskHandle,
202    args: Args,
203    completer: AsyncCompleter<R>,
204}
205
206impl<Args, R> PendingAsyncTask<Args, R> {
207    pub(crate) fn new(task: TaskHandle, args: Args, completer: AsyncCompleter<R>) -> Self {
208        Self {
209            task,
210            args,
211            completer,
212        }
213    }
214
215    pub fn task(&self) -> TaskHandle {
216        self.task
217    }
218
219    pub fn args(&self) -> &Args {
220        &self.args
221    }
222
223    pub fn complete_ok(self, value: R) {
224        self.completer.complete_ok(value);
225    }
226
227    pub fn complete_err(self, message: impl Into<String>) {
228        self.completer.complete_err(message);
229    }
230}
231
/// State shared between an [`AsyncCompleter`] and its [`AsyncSignalFuture`].
struct AsyncSignalState<T> {
    /// Delivered result. A stored result is never overwritten; the future
    /// removes it when it resolves.
    result: Mutex<Option<std::result::Result<T, String>>>,
    /// Waker registered by the future's most recent `Pending` poll.
    waker: Mutex<Option<Waker>>,
}

/// Sending half of a `signal_pair`: delivers a result to the paired
/// [`AsyncSignalFuture`].
///
/// Clones share the same state. While a delivered result is still unconsumed,
/// further completions are ignored.
pub(crate) struct AsyncCompleter<T> {
    inner: Arc<AsyncSignalState<T>>,
}

impl<T> Clone for AsyncCompleter<T> {
    /// Returns another completer for the same shared state.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}

impl<T> AsyncCompleter<T> {
    /// Delivers a successful `value`.
    fn complete_ok(&self, value: T) {
        self.set_result(Ok(value));
    }

    /// Delivers an error `message`.
    fn complete_err(&self, message: impl Into<String>) {
        self.set_result(Err(message.into()));
    }

    /// Stores `value` unless a result is already present, then wakes the
    /// registered waker, if any.
    ///
    /// # Panics
    /// Panics if either state mutex is poisoned.
    fn set_result(&self, value: std::result::Result<T, String>) {
        {
            let mut slot = self.inner.result.lock().unwrap();
            if slot.is_some() {
                return;
            }
            *slot = Some(value);
        }

        // Take the waker in its own statement so the `waker` guard is dropped
        // before `wake()` runs. As an `if let` scrutinee the guard would live
        // through the call, and a waker that synchronously re-polls the future
        // (which locks `waker` again) could deadlock.
        let waker = self.inner.waker.lock().unwrap().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
272
273pub(crate) struct AsyncSignalFuture<T> {
274    inner: Arc<AsyncSignalState<T>>,
275}
276
277impl<T> Future for AsyncSignalFuture<T> {
278    type Output = std::result::Result<T, String>;
279
280    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
281        {
282            let mut guard = self.inner.result.lock().unwrap();
283            if let Some(result) = guard.take() {
284                return Poll::Ready(result);
285            }
286        }
287
288        let mut waker_slot = self.inner.waker.lock().unwrap();
289        *waker_slot = Some(cx.waker().clone());
290        Poll::Pending
291    }
292}
293
294pub(crate) fn signal_pair<T>() -> (AsyncCompleter<T>, AsyncSignalFuture<T>) {
295    let inner = Arc::new(AsyncSignalState {
296        result: Mutex::new(None),
297        waker: Mutex::new(None),
298    });
299    (
300        AsyncCompleter {
301            inner: Arc::clone(&inner),
302        },
303        AsyncSignalFuture { inner },
304    )
305}