ic_cdk_executor/
machinery.rs1use std::{
2 cell::{Cell, RefCell},
3 collections::VecDeque,
4 future::Future,
5 mem::take,
6 pin::Pin,
7 sync::{Arc, Once},
8 task::{Context, Poll, Wake, Waker},
9};
10
11use slotmap::{Key, SecondaryMap, SlotMap, new_key_type};
12use smallvec::SmallVec;
13
14#[derive(Clone, Debug)]
16pub(crate) struct MethodContext {
17 pub(crate) kind: ContextKind,
19 pub(crate) handles: usize,
22 pub(crate) tasks: SmallVec<[TaskId; 4]>,
24}
25
26impl MethodContext {
27 pub(crate) fn new_update() -> Self {
28 Self {
29 kind: ContextKind::Update,
30 handles: 0,
31 tasks: SmallVec::new(),
32 }
33 }
34 pub(crate) fn new_query() -> Self {
35 Self {
36 kind: ContextKind::Query,
37 handles: 0,
38 tasks: SmallVec::new(),
39 }
40 }
41}
42
/// The kind of entry point a [`MethodContext`] was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextKind {
    /// An update method: migratory tasks may be spawned and the migratory wakeup
    /// queue is drained when polling.
    Update,
    /// A query method: only protected tasks may be spawned, and only this method's
    /// own protected wakeups are drained when polling.
    Query,
}
48
49new_key_type! {
52 pub(crate) struct MethodId;
53 pub(crate) struct TaskId;
54}
55
56thread_local! {
57 pub(crate) static METHODS: RefCell<SlotMap<MethodId, MethodContext>> = RefCell::default();
59 pub(crate) static TASKS: RefCell<SlotMap<TaskId, Task>> = RefCell::default();
61 pub(crate) static PROTECTED_WAKEUPS: RefCell<SecondaryMap<MethodId, VecDeque<TaskId>>> = RefCell::default();
63 pub(crate) static MIGRATORY_WAKEUPS: RefCell<VecDeque<TaskId>> = const { RefCell::new(VecDeque::new()) };
65 pub(crate) static CURRENT_METHOD: Cell<Option<MethodId>> = const { Cell::new(None) };
68 pub(crate) static RECOVERING: Cell<bool> = const { Cell::new(false) };
70 pub(crate) static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
72}
73
74pub(crate) struct Task {
76 future: Pin<Box<dyn Future<Output = ()>>>,
78 method_binding: Option<MethodId>,
81 set_current_method_var: MethodId,
83}
84
85impl Default for Task {
87 fn default() -> Self {
88 Self {
89 future: Box::pin(std::future::pending()),
90 method_binding: None,
91 set_current_method_var: MethodId::null(),
92 }
93 }
94}
95
96pub fn in_tracking_executor_context<R>(f: impl FnOnce() -> R) -> R {
98 setup_panic_hook();
99 let method = METHODS.with_borrow_mut(|methods| methods.insert(MethodContext::new_update()));
100 let guard = MethodHandle::for_method(method);
101 enter_current_method(guard, || {
102 let res = f();
103 poll_all();
104 res
105 })
106}
107
108#[expect(dead_code)] pub(crate) fn in_null_context<R>(f: impl FnOnce() -> R) -> R {
112 setup_panic_hook();
113 let guard = MethodHandle::for_method(MethodId::null());
114 enter_current_method(guard, || {
115 let res = f();
116 poll_all();
117 res
118 })
119}
120
121pub fn in_tracking_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
123 setup_panic_hook();
124 let method = METHODS.with_borrow_mut(|methods| methods.insert(MethodContext::new_query()));
125 let guard = MethodHandle::for_method(method);
126 enter_current_method(guard, || {
127 let res = f();
128 poll_all();
129 res
130 })
131}
132
133pub fn in_callback_executor_context_for<R>(
135 method_handle: MethodHandle,
136 f: impl FnOnce() -> R,
137) -> R {
138 setup_panic_hook();
139 enter_current_method(method_handle, || {
140 let res = f();
141 poll_all();
142 res
143 })
144}
145
146pub fn in_trap_recovery_context_for<R>(method: MethodHandle, f: impl FnOnce() -> R) -> R {
148 setup_panic_hook();
149 enter_current_method(method, || {
150 RECOVERING.set(true);
151 let res = f();
152 RECOVERING.set(false);
153 res
154 })
155}
156
157pub fn cancel_all_tasks_attached_to_current_method() {
159 let Some(method_id) = CURRENT_METHOD.get() else {
160 panic!(
161 "`cancel_all_tasks_attached_to_current_method` can only be called within a method context"
162 );
163 };
164 cancel_all_tasks_attached_to_method(method_id);
165}
166
167fn cancel_all_tasks_attached_to_method(method_id: MethodId) {
169 let Some(to_cancel) = METHODS.with_borrow_mut(|methods| {
170 methods
171 .get_mut(method_id)
172 .map(|method| take(&mut method.tasks))
173 }) else {
174 return; };
176 let _tasks = TASKS.with(|tasks| {
177 let Ok(mut tasks) = tasks.try_borrow_mut() else {
178 panic!(
179 "`cancel_all_tasks_attached_to_current_method` cannot be called from an async task"
180 );
181 };
182 let mut canceled = Vec::with_capacity(to_cancel.len());
183 for task_id in to_cancel {
184 canceled.push(tasks.remove(task_id));
185 }
186 canceled
187 });
188 drop(_tasks); }
190
191pub(crate) fn delete_task(task_id: TaskId) {
193 let _task = TASKS.with_borrow_mut(|tasks| tasks.remove(task_id));
194 drop(_task); }
196
197pub fn cancel_task(task_handle: &TaskHandle) {
199 delete_task(task_handle.task_id);
200}
201
202pub fn is_recovering_from_trap() -> bool {
204 RECOVERING.get()
205}
206
207pub fn extend_current_method_context() -> MethodHandle {
211 setup_panic_hook();
212 let Some(method_id) = CURRENT_METHOD.get() else {
213 panic!("`extend_method_context` can only be called within a tracking executor context");
214 };
215 MethodHandle::for_method(method_id)
216}
217
218pub(crate) fn poll_all() {
223 let Some(method_id) = CURRENT_METHOD.get() else {
224 panic!("tasks can only be polled within an executor context");
225 };
226 let kind = METHODS
227 .with_borrow(|methods| methods.get(method_id).map(|m| m.kind))
228 .unwrap_or(ContextKind::Update);
229 fn pop_wakeup(method_id: MethodId, update: bool) -> Option<TaskId> {
230 if let Some(task_id) = PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
231 wakeups
232 .get_mut(method_id)
233 .and_then(|queue| queue.pop_front())
234 }) {
235 Some(task_id)
236 } else if update {
237 MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| unattached.pop_front())
238 } else {
239 None
240 }
241 }
242 while let Some(task_id) = pop_wakeup(method_id, kind == ContextKind::Update) {
243 let Some(mut task) = TASKS.with_borrow_mut(|tasks| tasks.get_mut(task_id).map(take)) else {
245 continue;
248 };
251 let waker = Waker::from(Arc::new(TaskWaker { task_id }));
252 let prev_current_method_var = CURRENT_METHOD.replace(Some(task.set_current_method_var));
253 CURRENT_TASK_ID.set(Some(task_id));
254 let poll = task.future.as_mut().poll(&mut Context::from_waker(&waker));
255 CURRENT_TASK_ID.set(None);
256 CURRENT_METHOD.set(prev_current_method_var);
257 match poll {
258 Poll::Pending => {
259 TASKS.with_borrow_mut(|tasks| {
261 if let Some(t) = tasks.get_mut(task_id) {
262 *t = task;
263 }
264 });
265 }
266 Poll::Ready(()) => {
267 delete_task(task_id);
269 }
270 }
271 }
272}
273
274pub(crate) fn enter_current_method<R>(method_guard: MethodHandle, f: impl FnOnce() -> R) -> R {
276 CURRENT_METHOD.with(|context_var| {
277 assert!(
278 context_var.get().is_none(),
279 "in_*_context called within an existing async context"
280 );
281 context_var.set(Some(method_guard.method_id));
282 });
283 let r = f();
284 drop(method_guard); let method_id = CURRENT_METHOD.replace(None);
286 if let Some(method_id) = method_id {
287 let handles = METHODS.with_borrow_mut(|methods| methods.get(method_id).map(|m| m.handles));
288 if handles == Some(0) {
289 cancel_all_tasks_attached_to_method(method_id);
290 METHODS.with_borrow_mut(|methods| methods.remove(method_id));
291 }
292 }
293 r
294}
295
296#[derive(Debug)]
302pub struct MethodHandle {
303 method_id: MethodId,
304}
305
306impl MethodHandle {
307 pub(crate) fn for_method(method_id: MethodId) -> Self {
309 if method_id.is_null() {
310 return Self { method_id };
311 }
312 METHODS.with_borrow_mut(|methods| {
313 let Some(method) = methods.get_mut(method_id) else {
314 panic!("internal error: method context deleted while in use (for_method)");
315 };
316 method.handles += 1;
317 });
318 Self { method_id }
319 }
320}
321
322impl Drop for MethodHandle {
323 fn drop(&mut self) {
324 METHODS.with_borrow_mut(|methods| {
325 if let Some(method) = methods.get_mut(self.method_id) {
326 method.handles -= 1;
327 }
328 })
329 }
330}
331
332#[derive(Debug)]
334pub struct TaskHandle {
335 task_id: TaskId,
336}
337
338impl TaskHandle {
339 pub fn current() -> Option<Self> {
341 let task_id = CURRENT_TASK_ID.get()?;
342 Some(Self { task_id })
343 }
344}
345
346pub(crate) struct TaskWaker {
347 pub(crate) task_id: TaskId,
348}
349
350impl Wake for TaskWaker {
351 fn wake(self: Arc<Self>) {
352 TASKS.with_borrow_mut(|tasks| {
353 if let Some(task) = tasks.get(self.task_id) {
354 if let Some(method_id) = task.method_binding {
355 PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
356 if let Some(entry) = wakeups.entry(method_id) {
357 entry.or_default().push_back(self.task_id);
358 }
359 });
360 } else {
361 MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| {
362 unattached.push_back(self.task_id);
363 });
364 }
365 }
366 })
367 }
368}
369
370pub fn spawn_migratory(f: impl Future<Output = ()> + 'static) -> TaskHandle {
374 setup_panic_hook();
375 let Some(method_id) = CURRENT_METHOD.get() else {
376 panic!("`spawn_*` can only be called within an executor context");
377 };
378 if is_recovering_from_trap() {
379 panic!("tasks cannot be spawned while recovering from a trap");
380 }
381 let kind = METHODS
382 .with_borrow(|methods| methods.get(method_id).map(|m| m.kind))
383 .unwrap_or(ContextKind::Update);
384 if kind == ContextKind::Query {
385 panic!("unprotected spawns cannot be made within a query context");
386 }
387 let task = Task {
388 future: Box::pin(f),
389 method_binding: None,
390 set_current_method_var: MethodId::null(),
391 };
392 let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
393 MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| {
394 unattached.push_back(task_id);
395 });
396 TaskHandle { task_id }
397}
398
399pub fn spawn_protected(f: impl Future<Output = ()> + 'static) -> TaskHandle {
404 setup_panic_hook();
405 if is_recovering_from_trap() {
406 panic!("tasks cannot be spawned while recovering from a trap");
407 }
408 let Some(method_id) = CURRENT_METHOD.get() else {
409 panic!("`spawn_*` can only be called within an executor context");
410 };
411 if method_id.is_null() {
412 panic!("`spawn_protected` cannot be called outside of a tracked method context");
413 }
414 let task = Task {
415 future: Box::pin(f),
416 method_binding: Some(method_id),
417 set_current_method_var: method_id,
418 };
419 let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
420 METHODS.with_borrow_mut(|methods| {
421 let Some(method) = methods.get_mut(method_id) else {
422 panic!("internal error: method context deleted while in use (spawn_protected)");
423 };
424 method.tasks.push(task_id);
425 });
426 PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
427 let Some(entry) = wakeups.entry(method_id) else {
428 panic!("internal error: method context deleted while in use (spawn_protected)");
429 };
430 entry.or_default().push_back(task_id);
431 });
432 TaskHandle { task_id }
433}
434
435fn setup_panic_hook() {
436 static SETUP: Once = Once::new();
437 SETUP.call_once(|| {
438 std::panic::set_hook(Box::new(|info| {
439 let file = info.location().unwrap().file();
440 let line = info.location().unwrap().line();
441 let col = info.location().unwrap().column();
442
443 let msg = match info.payload().downcast_ref::<&'static str>() {
444 Some(s) => *s,
445 None => match info.payload().downcast_ref::<String>() {
446 Some(s) => &s[..],
447 None => "Box<Any>",
448 },
449 };
450
451 let err_info = format!("Panicked at '{msg}', {file}:{line}:{col}");
452 ic0::debug_print(err_info.as_bytes());
453 ic0::trap(err_info.as_bytes());
454 }));
455 });
456}