1use std::cell::{Cell, RefCell};
4use std::collections::VecDeque;
5use std::future::Future;
6use std::mem;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll, Wake, Waker};
10
11use slotmap::{new_key_type, SlotMap};
12
13pub fn spawn<F: 'static + Future<Output = ()>>(future: F) {
15 let in_query = match CONTEXT.get() {
16 AsyncContext::None => panic!("`spawn` can only be called from an executor context"),
17 AsyncContext::Query => true,
18 AsyncContext::Update | AsyncContext::Cancel => false,
19 AsyncContext::FromTask => unreachable!("FromTask"),
20 };
21 let pinned_future = Box::pin(future);
22 let task = Task {
23 future: pinned_future,
24 query: in_query,
25 };
26 let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
27 WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(task_id));
28}
29
30pub fn in_executor_context<R>(f: impl FnOnce() -> R) -> R {
32 let _guard = ContextGuard::new(AsyncContext::Update);
33 let res = f();
34 poll_all();
35 res
36}
37
38pub fn in_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
40 let _guard = ContextGuard::new(AsyncContext::Query);
41 let res = f();
42 poll_all();
43 res
44}
45
46pub fn in_callback_executor_context(f: impl FnOnce()) {
48 let _guard = ContextGuard::new(AsyncContext::FromTask);
49 f();
50 poll_all();
51}
52
53pub fn in_callback_cancellation_context(f: impl FnOnce()) {
56 let _guard = ContextGuard::new(AsyncContext::Cancel);
57 f();
58}
59
60pub fn is_recovering_from_trap() -> bool {
62 matches!(CONTEXT.get(), AsyncContext::Cancel)
63}
64
65fn poll_all() {
66 let in_query = match CONTEXT.get() {
67 AsyncContext::Query => true,
68 AsyncContext::Update | AsyncContext::Cancel => false,
69 AsyncContext::None => panic!("tasks can only be polled in an executor context"),
70 AsyncContext::FromTask => unreachable!("FromTask"),
71 };
72 let mut ineligible = vec![];
73 while let Some(task_id) = WAKEUP.with_borrow_mut(|queue| queue.pop_front()) {
74 let Some(mut task) = TASKS.with_borrow_mut(|tasks| tasks.get_mut(task_id).map(mem::take))
76 else {
77 panic!("Call already trapped");
81 };
83 if in_query && !task.query {
84 TASKS.with_borrow_mut(|tasks| tasks[task_id] = task);
85 ineligible.push(task_id);
86 continue;
87 }
88 let waker = Waker::from(Arc::new(TaskWaker {
89 task_id,
90 query: task.query,
91 }));
92 let poll = task.future.as_mut().poll(&mut Context::from_waker(&waker));
93 match poll {
94 Poll::Pending => {
95 TASKS.with_borrow_mut(|tasks| {
97 if let Some(t) = tasks.get_mut(task_id) {
98 *t = task;
99 }
100 });
101 }
102 Poll::Ready(()) => {
103 TASKS.with_borrow_mut(|tasks| tasks.remove(task_id));
105 }
106 }
107 }
108 if !ineligible.is_empty() {
109 WAKEUP.with_borrow_mut(|wakeup| wakeup.extend(ineligible));
110 }
111}
112
// Key type identifying a task in the `TASKS` slot map; the same ids are queued in `WAKEUP`.
new_key_type! {
    struct TaskId;
}
116
117thread_local! {
118 static TASKS: RefCell<SlotMap<TaskId, Task>> = <_>::default();
119 static WAKEUP: RefCell<VecDeque<TaskId>> = <_>::default();
120 static CONTEXT: Cell<AsyncContext> = <_>::default();
121}
122
/// The kind of executor context that is currently active on this thread.
#[derive(Clone, Copy, Default)]
enum AsyncContext {
    /// No executor context is active; spawning and polling are not allowed.
    #[default]
    None,
    /// Set by `in_executor_context`; every woken task may be polled.
    Update,
    /// Set by `in_query_executor_context`; only query tasks may be polled.
    Query,
    /// Set by `in_callback_executor_context`; replaced with `Query` or `Update`
    /// as soon as a task waker is woken.
    FromTask,
    /// Set by `in_callback_cancellation_context`; woken tasks are removed
    /// instead of being scheduled.
    Cancel,
}
132
/// A spawned future together with the kind of context it was spawned in.
struct Task {
    /// The boxed, pinned future that is polled to drive the task.
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// `true` if the task was spawned in a query context.
    query: bool,
}

/// The default task is a placeholder: a non-query task whose future never
/// completes. It is what `mem::take` leaves in a slot while the real task is
/// being polled.
impl Default for Task {
    fn default() -> Self {
        let never = std::future::pending::<()>();
        Task {
            future: Box::pin(never),
            query: false,
        }
    }
}
146
147struct ContextGuard(());
148
149impl ContextGuard {
150 fn new(context: AsyncContext) -> Self {
151 CONTEXT.with(|context_var| {
152 assert!(
153 matches!(context_var.get(), AsyncContext::None),
154 "in_*_context called within an existing async context"
155 );
156 context_var.set(context);
157 Self(())
158 })
159 }
160}
161
162impl Drop for ContextGuard {
163 fn drop(&mut self) {
164 CONTEXT.set(AsyncContext::None);
165 }
166}
167
168#[derive(Clone)]
173struct TaskWaker {
174 task_id: TaskId,
175 query: bool,
176}
177
178impl Wake for TaskWaker {
179 fn wake(self: Arc<Self>) {
180 if matches!(CONTEXT.get(), AsyncContext::Cancel) {
181 let _task = TASKS.with_borrow_mut(|tasks| tasks.remove(self.task_id));
183 } else {
185 WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(self.task_id));
186 CONTEXT.with(|context| {
187 if matches!(context.get(), AsyncContext::FromTask) {
188 if self.query {
189 context.set(AsyncContext::Query)
190 } else {
191 context.set(AsyncContext::Update)
192 }
193 }
194 })
195 }
196 }
197}