ic_cdk_executor/
machinery.rs

1use std::{
2    cell::{Cell, RefCell},
3    collections::VecDeque,
4    future::Future,
5    mem::take,
6    pin::Pin,
7    sync::{Arc, Once},
8    task::{Context, Poll, Wake, Waker},
9};
10
11use slotmap::{Key, SecondaryMap, SlotMap, new_key_type};
12use smallvec::SmallVec;
13
/// Represents an active canister method.
///
/// Instances are stored in the `METHODS` table and addressed by `MethodId`. An entry is removed by
/// `enter_current_method` once its closure has returned and `handles` is zero.
#[derive(Clone, Debug)]
pub(crate) struct MethodContext {
    /// Whether this method is an update or a query.
    pub(crate) kind: ContextKind,
    /// The number of handles to this method context. When this drops to zero, the method context gets deleted.
    /// The refcount is managed by `MethodHandle`: incremented in `MethodHandle::for_method`, decremented when
    /// the handle is dropped.
    pub(crate) handles: usize,
    /// An index for Task.method_binding; all protected tasks attached to this method.
    /// IDs are appended by `spawn_protected` and the whole list is drained when the method's tasks are canceled.
    /// Entries are not pruned when an individual task finishes, so this may contain IDs no longer in `TASKS`.
    pub(crate) tasks: SmallVec<[TaskId; 4]>,
}
25
26impl MethodContext {
27    pub(crate) fn new_update() -> Self {
28        Self {
29            kind: ContextKind::Update,
30            handles: 0,
31            tasks: SmallVec::new(),
32        }
33    }
34    pub(crate) fn new_query() -> Self {
35        Self {
36            kind: ContextKind::Query,
37            handles: 0,
38            tasks: SmallVec::new(),
39        }
40    }
41}
42
/// The kind of canister method a `MethodContext` belongs to.
///
/// `poll_all` only drains the migratory wakeup queue for `Update`, and `spawn_migratory` panics for `Query`.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub(crate) enum ContextKind {
    /// An update method.
    Update,
    /// A query method.
    Query,
}
48
// Slotmap key types for the `METHODS` and `TASKS` tables.
//
// Null method ID corresponds to 'null context', used for migratory tasks.
// (`in_null_context` enters it, and `spawn_migratory` sets it as the task's `set_current_method_var`.)
// Null task ID is an error.
new_key_type! {
    pub(crate) struct MethodId;
    pub(crate) struct TaskId;
}
55
// Executor state. Everything is thread-local and uses `Cell`/`RefCell` for interior mutability.
thread_local! {
    // global: list of all method contexts currently active
    pub(crate) static METHODS: RefCell<SlotMap<MethodId, MethodContext>> = RefCell::default();
    // global: list of all tasks currently spawned
    // (while a task is being polled by `poll_all`, its slot temporarily holds a placeholder)
    pub(crate) static TASKS: RefCell<SlotMap<TaskId, Task>> = RefCell::default();
    // global: map of methods to their protected tasks that have been woken up
    // (FIFO per method: pushed by `spawn_protected` and `TaskWaker::wake`, popped by `poll_all`)
    pub(crate) static PROTECTED_WAKEUPS: RefCell<SecondaryMap<MethodId, VecDeque<TaskId>>> = RefCell::default();
    // global: list of migratory tasks that have been woken up
    // (FIFO: pushed by `spawn_migratory` and `TaskWaker::wake`, popped by `poll_all` in non-query contexts)
    pub(crate) static MIGRATORY_WAKEUPS: RefCell<VecDeque<TaskId>> = const { RefCell::new(VecDeque::new()) };
    // dynamically scoped: the current method context. None means a context function was not called (which is a user error),
    // vs null which means no method in particular.
    // Set by `enter_current_method`, and temporarily overridden by `poll_all` while a task is being polled.
    pub(crate) static CURRENT_METHOD: Cell<Option<MethodId>> = const { Cell::new(None) };
    // dynamically scoped: whether we are currently recovering from a trap
    // (true only for the duration of the closure passed to `in_trap_recovery_context_for`)
    pub(crate) static RECOVERING: Cell<bool> = const { Cell::new(false) };
    // dynamically scoped: the current task ID, or None if a task is not running
    // (set by `poll_all` immediately around each call to a task's `poll`)
    pub(crate) static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
}
73
/// A registered task in the executor.
///
/// Tasks live in the `TASKS` table and are created by `spawn_protected` / `spawn_migratory`.
pub(crate) struct Task {
    /// Should be `TaskFuture` in all cases
    /// (NOTE(review): `TaskFuture` is not defined in this file, and the spawn functions here box the
    /// caller's future directly — confirm against callers.)
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// If Some, this task will always resume during that method, regardless of where the waker is woken from.
    /// If None, this task will resume wherever it is awoken from.
    /// `TaskWaker::wake` reads this field to pick between the protected and migratory wakeup queues.
    method_binding: Option<MethodId>,
    /// While this task is executing, `CURRENT_METHOD` will be set to this value.
    /// `spawn_protected` sets it to the bound method; `spawn_migratory` sets it to the null method ID.
    set_current_method_var: MethodId,
}
84
85// Actually using the default value would be a memory leak. This only exists for `take`.
86impl Default for Task {
87    fn default() -> Self {
88        Self {
89            future: Box::pin(std::future::pending()),
90            method_binding: None,
91            set_current_method_var: MethodId::null(),
92        }
93    }
94}
95
96/// Execute an update function in a context that allows calling [`spawn_protected`] and [`spawn_migratory`].
97pub fn in_tracking_executor_context<R>(f: impl FnOnce() -> R) -> R {
98    setup_panic_hook();
99    let method = METHODS.with_borrow_mut(|methods| methods.insert(MethodContext::new_update()));
100    let guard = MethodHandle::for_method(method);
101    enter_current_method(guard, || {
102        let res = f();
103        poll_all();
104        res
105    })
106}
107
108/// Execute a function in a context that is not tracked across callbacks, able to call [`spawn_migratory`]
109/// but not [`spawn_protected`].
110#[expect(dead_code)] // not used in current null context code but may be used for other things in the future
111pub(crate) fn in_null_context<R>(f: impl FnOnce() -> R) -> R {
112    setup_panic_hook();
113    let guard = MethodHandle::for_method(MethodId::null());
114    enter_current_method(guard, || {
115        let res = f();
116        poll_all();
117        res
118    })
119}
120
121/// Execute a query function in a context that allows calling [`spawn_protected`] but not [`spawn_migratory`].
122pub fn in_tracking_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
123    setup_panic_hook();
124    let method = METHODS.with_borrow_mut(|methods| methods.insert(MethodContext::new_query()));
125    let guard = MethodHandle::for_method(method);
126    enter_current_method(guard, || {
127        let res = f();
128        poll_all();
129        res
130    })
131}
132
133/// Execute an inter-canister call callback in the context of the method that made it.
134pub fn in_callback_executor_context_for<R>(
135    method_handle: MethodHandle,
136    f: impl FnOnce() -> R,
137) -> R {
138    setup_panic_hook();
139    enter_current_method(method_handle, || {
140        let res = f();
141        poll_all();
142        res
143    })
144}
145
146/// Enters a trap/panic recovery context for calling [`cancel_all_tasks_attached_to_current_method`] in.
147pub fn in_trap_recovery_context_for<R>(method: MethodHandle, f: impl FnOnce() -> R) -> R {
148    setup_panic_hook();
149    enter_current_method(method, || {
150        RECOVERING.set(true);
151        let res = f();
152        RECOVERING.set(false);
153        res
154    })
155}
156
157/// Cancels all tasks made with [`spawn_protected`] attached to the current method.
158pub fn cancel_all_tasks_attached_to_current_method() {
159    let Some(method_id) = CURRENT_METHOD.get() else {
160        panic!(
161            "`cancel_all_tasks_attached_to_current_method` can only be called within a method context"
162        );
163    };
164    cancel_all_tasks_attached_to_method(method_id);
165}
166
167/// Cancels all tasks made with [`spawn_protected`] attached to the given method.
168fn cancel_all_tasks_attached_to_method(method_id: MethodId) {
169    let Some(to_cancel) = METHODS.with_borrow_mut(|methods| {
170        methods
171            .get_mut(method_id)
172            .map(|method| take(&mut method.tasks))
173    }) else {
174        return; // method context null or already deleted
175    };
176    let _tasks = TASKS.with(|tasks| {
177        let Ok(mut tasks) = tasks.try_borrow_mut() else {
178            panic!(
179                "`cancel_all_tasks_attached_to_current_method` cannot be called from an async task"
180            );
181        };
182        let mut canceled = Vec::with_capacity(to_cancel.len());
183        for task_id in to_cancel {
184            canceled.push(tasks.remove(task_id));
185        }
186        canceled
187    });
188    drop(_tasks); // always run task destructors outside of a refcell borrow
189}
190
191/// Removes a specific task. Use this instead of `remove` for guaranteed drop order.
192pub(crate) fn delete_task(task_id: TaskId) {
193    let _task = TASKS.with_borrow_mut(|tasks| tasks.remove(task_id));
194    drop(_task); // always run task destructors outside of a refcell borrow
195}
196
197/// Cancels a specific task by its handle.
198pub fn cancel_task(task_handle: &TaskHandle) {
199    delete_task(task_handle.task_id);
200}
201
202/// Returns true if tasks are being canceled due to a trap or panic.
203pub fn is_recovering_from_trap() -> bool {
204    RECOVERING.get()
205}
206
207/// Produces a handle to the current method context.
208///
209/// The method is active as long as the handle is alive.
210pub fn extend_current_method_context() -> MethodHandle {
211    setup_panic_hook();
212    let Some(method_id) = CURRENT_METHOD.get() else {
213        panic!("`extend_method_context` can only be called within a tracking executor context");
214    };
215    MethodHandle::for_method(method_id)
216}
217
218/// Polls all tasks that have been woken up. Called after all context closures besides cancelation.
219///
220/// Should never be called inside a task, because it should only be called inside a context closure, and context closures
221/// should only be at the top level of an entrypoint.
222pub(crate) fn poll_all() {
223    let Some(method_id) = CURRENT_METHOD.get() else {
224        panic!("tasks can only be polled within an executor context");
225    };
226    let kind = METHODS
227        .with_borrow(|methods| methods.get(method_id).map(|m| m.kind))
228        .unwrap_or(ContextKind::Update);
229    fn pop_wakeup(method_id: MethodId, update: bool) -> Option<TaskId> {
230        if let Some(task_id) = PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
231            wakeups
232                .get_mut(method_id)
233                .and_then(|queue| queue.pop_front())
234        }) {
235            Some(task_id)
236        } else if update {
237            MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| unattached.pop_front())
238        } else {
239            None
240        }
241    }
242    while let Some(task_id) = pop_wakeup(method_id, kind == ContextKind::Update) {
243        // Temporarily remove the task from the table. We need to execute it while `TASKS` is not borrowed, because it may schedule more tasks.
244        let Some(mut task) = TASKS.with_borrow_mut(|tasks| tasks.get_mut(task_id).map(take)) else {
245            // This waker handle appears to be dead. The most likely cause is that the method returned before
246            // a canceled call came back.
247            continue;
248            // In the case that a task panicked and that's why it's missing, but it was in an earlier callback so a later
249            // one tries to re-wake, the responsibility for re-trapping lies with CallFuture.
250        };
251        let waker = Waker::from(Arc::new(TaskWaker { task_id }));
252        let prev_current_method_var = CURRENT_METHOD.replace(Some(task.set_current_method_var));
253        CURRENT_TASK_ID.set(Some(task_id));
254        let poll = task.future.as_mut().poll(&mut Context::from_waker(&waker));
255        CURRENT_TASK_ID.set(None);
256        CURRENT_METHOD.set(prev_current_method_var);
257        match poll {
258            Poll::Pending => {
259                // more to do, put the task back in the table
260                TASKS.with_borrow_mut(|tasks| {
261                    if let Some(t) = tasks.get_mut(task_id) {
262                        *t = task;
263                    }
264                });
265            }
266            Poll::Ready(()) => {
267                // task complete, remove its entry from the table fully
268                delete_task(task_id);
269            }
270        }
271    }
272}
273
274/// Begin a context closure for the given method. Destroys the method afterwards if there are no outstanding handles.
275pub(crate) fn enter_current_method<R>(method_guard: MethodHandle, f: impl FnOnce() -> R) -> R {
276    CURRENT_METHOD.with(|context_var| {
277        assert!(
278            context_var.get().is_none(),
279            "in_*_context called within an existing async context"
280        );
281        context_var.set(Some(method_guard.method_id));
282    });
283    let r = f();
284    drop(method_guard); // drop the guard *before* the method freeing logic, but *after* the in-context code
285    let method_id = CURRENT_METHOD.replace(None);
286    if let Some(method_id) = method_id {
287        let handles = METHODS.with_borrow_mut(|methods| methods.get(method_id).map(|m| m.handles));
288        if handles == Some(0) {
289            cancel_all_tasks_attached_to_method(method_id);
290            METHODS.with_borrow_mut(|methods| methods.remove(method_id));
291        }
292    }
293    r
294}
295
/// A handle to a method context. If the function returns and all handles have been dropped, the method is considered returned.
///
/// This should be created before performing an inter-canister call via [`extend_current_method_context`],
/// threaded through the `env` parameter, and then used when calling [`in_callback_executor_context_for`] or
/// [`in_trap_recovery_context_for`]. Failure to track this properly may result in unexpected cancellation of tasks.
///
/// Creating a handle increments the method's `handles` count and dropping it decrements the count again;
/// the type is deliberately not `Clone`.
#[derive(Debug)]
pub struct MethodHandle {
    /// The method context this handle refers to. May be the null ID, in which case no refcount is involved.
    method_id: MethodId,
}
305
306impl MethodHandle {
307    /// Creates a live handle for the given method.
308    pub(crate) fn for_method(method_id: MethodId) -> Self {
309        if method_id.is_null() {
310            return Self { method_id };
311        }
312        METHODS.with_borrow_mut(|methods| {
313            let Some(method) = methods.get_mut(method_id) else {
314                panic!("internal error: method context deleted while in use (for_method)");
315            };
316            method.handles += 1;
317        });
318        Self { method_id }
319    }
320}
321
322impl Drop for MethodHandle {
323    fn drop(&mut self) {
324        METHODS.with_borrow_mut(|methods| {
325            if let Some(method) = methods.get_mut(self.method_id) {
326                method.handles -= 1;
327            }
328        })
329    }
330}
331
/// A handle to a spawned task.
///
/// Returned by [`spawn_protected`] and [`spawn_migratory`], and accepted by [`cancel_task`].
/// Dropping the handle does not affect the task (the type has no `Drop` implementation).
#[derive(Debug)]
pub struct TaskHandle {
    /// Key of the task in the `TASKS` table.
    task_id: TaskId,
}
337
338impl TaskHandle {
339    /// A handle to the task currently executing, or None if no task is executing.
340    pub fn current() -> Option<Self> {
341        let task_id = CURRENT_TASK_ID.get()?;
342        Some(Self { task_id })
343    }
344}
345
/// The waker handed to tasks when they are polled by `poll_all`.
///
/// Waking it enqueues `task_id` for polling (see the `Wake` implementation).
pub(crate) struct TaskWaker {
    /// The task to enqueue when this waker is woken.
    pub(crate) task_id: TaskId,
}
349
350impl Wake for TaskWaker {
351    fn wake(self: Arc<Self>) {
352        TASKS.with_borrow_mut(|tasks| {
353            if let Some(task) = tasks.get(self.task_id) {
354                if let Some(method_id) = task.method_binding {
355                    PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
356                        if let Some(entry) = wakeups.entry(method_id) {
357                            entry.or_default().push_back(self.task_id);
358                        }
359                    });
360                } else {
361                    MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| {
362                        unattached.push_back(self.task_id);
363                    });
364                }
365            }
366        })
367    }
368}
369
370/// Spawns a task that can migrate between methods.
371///
372/// When the task is awoken, it will run in the context of the method that woke it.
373pub fn spawn_migratory(f: impl Future<Output = ()> + 'static) -> TaskHandle {
374    setup_panic_hook();
375    let Some(method_id) = CURRENT_METHOD.get() else {
376        panic!("`spawn_*` can only be called within an executor context");
377    };
378    if is_recovering_from_trap() {
379        panic!("tasks cannot be spawned while recovering from a trap");
380    }
381    let kind = METHODS
382        .with_borrow(|methods| methods.get(method_id).map(|m| m.kind))
383        .unwrap_or(ContextKind::Update);
384    if kind == ContextKind::Query {
385        panic!("unprotected spawns cannot be made within a query context");
386    }
387    let task = Task {
388        future: Box::pin(f),
389        method_binding: None,
390        set_current_method_var: MethodId::null(),
391    };
392    let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
393    MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| {
394        unattached.push_back(task_id);
395    });
396    TaskHandle { task_id }
397}
398
399/// Spawns a task attached to the current method.
400///
401/// When the task is awoken, if a different method is currently running, the task will not run until the method
402/// it is attached to continues. If the attached method returns before the task completes, the task will be canceled.
403pub fn spawn_protected(f: impl Future<Output = ()> + 'static) -> TaskHandle {
404    setup_panic_hook();
405    if is_recovering_from_trap() {
406        panic!("tasks cannot be spawned while recovering from a trap");
407    }
408    let Some(method_id) = CURRENT_METHOD.get() else {
409        panic!("`spawn_*` can only be called within an executor context");
410    };
411    if method_id.is_null() {
412        panic!("`spawn_protected` cannot be called outside of a tracked method context");
413    }
414    let task = Task {
415        future: Box::pin(f),
416        method_binding: Some(method_id),
417        set_current_method_var: method_id,
418    };
419    let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
420    METHODS.with_borrow_mut(|methods| {
421        let Some(method) = methods.get_mut(method_id) else {
422            panic!("internal error: method context deleted while in use (spawn_protected)");
423        };
424        method.tasks.push(task_id);
425    });
426    PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
427        let Some(entry) = wakeups.entry(method_id) else {
428            panic!("internal error: method context deleted while in use (spawn_protected)");
429        };
430        entry.or_default().push_back(task_id);
431    });
432    TaskHandle { task_id }
433}
434
435fn setup_panic_hook() {
436    static SETUP: Once = Once::new();
437    SETUP.call_once(|| {
438        std::panic::set_hook(Box::new(|info| {
439            let file = info.location().unwrap().file();
440            let line = info.location().unwrap().line();
441            let col = info.location().unwrap().column();
442
443            let msg = match info.payload().downcast_ref::<&'static str>() {
444                Some(s) => *s,
445                None => match info.payload().downcast_ref::<String>() {
446                    Some(s) => &s[..],
447                    None => "Box<Any>",
448                },
449            };
450
451            let err_info = format!("Panicked at '{msg}', {file}:{line}:{col}");
452            ic0::debug_print(err_info.as_bytes());
453            ic0::trap(err_info.as_bytes());
454        }));
455    });
456}