1use std::cell::{Cell, RefCell};
4use std::collections::VecDeque;
5use std::future::Future;
6use std::mem;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll, Wake, Waker};
10
11use slotmap::{new_key_type, SlotMap};
12
13pub fn spawn<F: 'static + Future<Output = ()>>(future: F) {
15 let in_query = match CONTEXT.get() {
16 AsyncContext::None => panic!("`spawn` can only be called from an executor context"),
17 AsyncContext::Query => true,
18 AsyncContext::Update => false,
19 AsyncContext::Cancel => panic!("`spawn` cannot be called during panic recovery"),
20 AsyncContext::FromTask => unreachable!("FromTask"),
21 };
22 let pinned_future = Box::pin(future);
23 let task = Task {
24 future: pinned_future,
25 query: in_query,
26 };
27 let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
28 WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(task_id));
29}
30
31pub fn in_executor_context<R>(f: impl FnOnce() -> R) -> R {
33 let _guard = ContextGuard::new(AsyncContext::Update);
34 let res = f();
35 poll_all();
36 res
37}
38
39pub fn in_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
41 let _guard = ContextGuard::new(AsyncContext::Query);
42 let res = f();
43 poll_all();
44 res
45}
46
47pub fn in_callback_executor_context(f: impl FnOnce()) {
49 let _guard = ContextGuard::new(AsyncContext::FromTask);
50 f();
51 poll_all();
52}
53
54pub fn in_callback_cancellation_context(f: impl FnOnce()) {
57 let _guard = ContextGuard::new(AsyncContext::Cancel);
58 f();
59}
60
61pub fn is_recovering_from_trap() -> bool {
63 matches!(CONTEXT.get(), AsyncContext::Cancel)
64}
65
66fn poll_all() {
67 let in_query = match CONTEXT.get() {
68 AsyncContext::Query => true,
69 AsyncContext::Update => false,
70 AsyncContext::None => panic!("tasks can only be polled in an executor context"),
71 AsyncContext::FromTask => unreachable!("FromTask"),
72 AsyncContext::Cancel => unreachable!("poll_all should not be called during panic recovery"),
73 };
74 let mut ineligible = vec![];
75 while let Some(task_id) = WAKEUP.with_borrow_mut(|queue| queue.pop_front()) {
76 let Some(mut task) = TASKS.with_borrow_mut(|tasks| tasks.get_mut(task_id).map(mem::take))
78 else {
79 continue;
82 };
85 if in_query && !task.query {
86 TASKS.with_borrow_mut(|tasks| tasks[task_id] = task);
87 ineligible.push(task_id);
88 continue;
89 }
90 let waker = Waker::from(Arc::new(TaskWaker {
91 task_id,
92 query: task.query,
93 }));
94 let poll = task.future.as_mut().poll(&mut Context::from_waker(&waker));
95 match poll {
96 Poll::Pending => {
97 TASKS.with_borrow_mut(|tasks| {
99 if let Some(t) = tasks.get_mut(task_id) {
100 *t = task;
101 }
102 });
103 }
104 Poll::Ready(()) => {
105 TASKS.with_borrow_mut(|tasks| tasks.remove(task_id));
107 }
108 }
109 }
110 if !ineligible.is_empty() {
111 WAKEUP.with_borrow_mut(|wakeup| wakeup.extend(ineligible));
112 }
113}
114
new_key_type! {
    /// Key identifying a spawned task in `TASKS`; it is also carried by `TaskWaker` and
    /// queued in `WAKEUP`.
    struct TaskId;
}
118
119thread_local! {
120 static TASKS: RefCell<SlotMap<TaskId, Task>> = <_>::default();
121 static WAKEUP: RefCell<VecDeque<TaskId>> = <_>::default();
122 static CONTEXT: Cell<AsyncContext> = <_>::default();
123}
124
/// Which executor entry point, if any, is currently running on this thread.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
enum AsyncContext {
    /// No executor entry point is active.
    #[default]
    None,
    /// Update context. Every queued task may be polled, and `spawn` creates update tasks.
    Update,
    /// Query context. Only query tasks are polled, and `spawn` creates query tasks.
    Query,
    /// Inside a callback before any task was woken. The first `TaskWaker` woken switches
    /// this to `Query` or `Update`.
    FromTask,
    /// Cancellation / panic recovery. Wakers drop their tasks instead of queueing them, and
    /// nothing is polled.
    Cancel,
}
134
/// A spawned task: its future plus whether it was spawned in a query context.
struct Task {
    /// The task's future. It is boxed and pinned, so the `Task` itself can be moved (e.g.
    /// taken out of `TASKS` for polling) without moving the future.
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// `true` if spawned in a query context. Non-query tasks are skipped when polling in a
    /// query context.
    query: bool,
}

/// A placeholder task whose future never completes, marked as non-query.
///
/// `poll_all` leaves one of these in a task's slot (via `mem::take`) while the real task is
/// being polled.
impl Default for Task {
    fn default() -> Self {
        let never: Pin<Box<dyn Future<Output = ()>>> = Box::pin(std::future::pending());
        Task {
            future: never,
            query: false,
        }
    }
}
148
149struct ContextGuard(());
150
151impl ContextGuard {
152 fn new(context: AsyncContext) -> Self {
153 CONTEXT.with(|context_var| {
154 assert!(
155 matches!(context_var.get(), AsyncContext::None),
156 "in_*_context called within an existing async context"
157 );
158 context_var.set(context);
159 Self(())
160 })
161 }
162}
163
164impl Drop for ContextGuard {
165 fn drop(&mut self) {
166 CONTEXT.set(AsyncContext::None);
167 }
168}
169
170#[derive(Clone)]
175struct TaskWaker {
176 task_id: TaskId,
177 query: bool,
178}
179
180impl Wake for TaskWaker {
181 fn wake(self: Arc<Self>) {
182 let context = CONTEXT.get();
183 assert!(
184 context != AsyncContext::None,
185 "wakers cannot be called outside an executor context"
186 );
187 if context == AsyncContext::Cancel {
188 let _task = TASKS.with_borrow_mut(|tasks| tasks.remove(self.task_id));
190 } else {
192 WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(self.task_id));
193 if context == AsyncContext::FromTask {
194 if self.query {
195 CONTEXT.set(AsyncContext::Query)
196 } else {
197 CONTEXT.set(AsyncContext::Update)
198 }
199 }
200 }
201 }
202}