ic_cdk_executor/
lib.rs

1//! An async executor for [`ic-cdk`](https://docs.rs/ic-cdk). Most users should not use this crate directly.
2
3use std::cell::{Cell, RefCell};
4use std::collections::VecDeque;
5use std::future::Future;
6use std::mem;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll, Wake, Waker};
10
11use slotmap::{new_key_type, SlotMap};
12
13/// Spawn an asynchronous task to run in the background.
14pub fn spawn<F: 'static + Future<Output = ()>>(future: F) {
15    let in_query = match CONTEXT.get() {
16        AsyncContext::None => panic!("`spawn` can only be called from an executor context"),
17        AsyncContext::Query => true,
18        AsyncContext::Update => false,
19        AsyncContext::Cancel => panic!("`spawn` cannot be called during panic recovery"),
20        AsyncContext::FromTask => unreachable!("FromTask"),
21    };
22    let pinned_future = Box::pin(future);
23    let task = Task {
24        future: pinned_future,
25        query: in_query,
26    };
27    let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
28    WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(task_id));
29}
30
31/// Execute an update function in a context that allows calling [`spawn`] and notifying wakers.
32pub fn in_executor_context<R>(f: impl FnOnce() -> R) -> R {
33    let _guard = ContextGuard::new(AsyncContext::Update);
34    let res = f();
35    poll_all();
36    res
37}
38
39/// Execute a composite query function in a context that allows calling [`spawn`] and notifying wakers.
40pub fn in_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
41    let _guard = ContextGuard::new(AsyncContext::Query);
42    let res = f();
43    poll_all();
44    res
45}
46
47/// Execute an inter-canister-call callback in a context that allows calling [`spawn`] and notifying wakers.
48pub fn in_callback_executor_context(f: impl FnOnce()) {
49    let _guard = ContextGuard::new(AsyncContext::FromTask);
50    f();
51    poll_all();
52}
53
54/// Execute an inter-canister-call callback in a context that allows calling [`spawn`] and notifying wakers,
55/// but will cancel every awoken future.
56pub fn in_callback_cancellation_context(f: impl FnOnce()) {
57    let _guard = ContextGuard::new(AsyncContext::Cancel);
58    f();
59}
60
61/// Tells you whether the current async fn is being canceled due to a trap/panic.
62pub fn is_recovering_from_trap() -> bool {
63    matches!(CONTEXT.get(), AsyncContext::Cancel)
64}
65
66fn poll_all() {
67    let in_query = match CONTEXT.get() {
68        AsyncContext::Query => true,
69        AsyncContext::Update => false,
70        AsyncContext::None => panic!("tasks can only be polled in an executor context"),
71        AsyncContext::FromTask => unreachable!("FromTask"),
72        AsyncContext::Cancel => unreachable!("poll_all should not be called during panic recovery"),
73    };
74    let mut ineligible = vec![];
75    while let Some(task_id) = WAKEUP.with_borrow_mut(|queue| queue.pop_front()) {
76        // Temporarily remove the task from the table. We need to execute it while `TASKS` is not borrowed, because it may schedule more tasks.
77        let Some(mut task) = TASKS.with_borrow_mut(|tasks| tasks.get_mut(task_id).map(mem::take))
78        else {
79            // This waker handle appears to be dead. The most likely cause is that the method returned before
80            // a canceled call came back.
81            continue;
82            // In the case that a task panicked and that's why it's missing, but it was in an earlier callback so a later
83            // one tries to re-wake, the responsibility for re-trapping lies with CallFuture.
84        };
85        if in_query && !task.query {
86            TASKS.with_borrow_mut(|tasks| tasks[task_id] = task);
87            ineligible.push(task_id);
88            continue;
89        }
90        let waker = Waker::from(Arc::new(TaskWaker {
91            task_id,
92            query: task.query,
93        }));
94        let poll = task.future.as_mut().poll(&mut Context::from_waker(&waker));
95        match poll {
96            Poll::Pending => {
97                // more to do, put the task back in the table
98                TASKS.with_borrow_mut(|tasks| {
99                    if let Some(t) = tasks.get_mut(task_id) {
100                        *t = task;
101                    }
102                });
103            }
104            Poll::Ready(()) => {
105                // task complete, remove its entry from the table fully
106                TASKS.with_borrow_mut(|tasks| tasks.remove(task_id));
107            }
108        }
109    }
110    if !ineligible.is_empty() {
111        WAKEUP.with_borrow_mut(|wakeup| wakeup.extend(ineligible));
112    }
113}
114
new_key_type! {
    /// Key of a spawned task in the `TASKS` slot map; also the id queued in `WAKEUP` and carried
    /// by the task's `TaskWaker`.
    struct TaskId;
}
118
119thread_local! {
120    static TASKS: RefCell<SlotMap<TaskId, Task>> = <_>::default();
121    static WAKEUP: RefCell<VecDeque<TaskId>> = <_>::default();
122    static CONTEXT: Cell<AsyncContext> = <_>::default();
123}
124
/// The kind of executor context (if any) the current code is running in.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
enum AsyncContext {
    /// Outside every executor context; `spawn` and waking a task both panic here.
    #[default]
    None,
    /// An update context: every woken task may be polled.
    Update,
    /// A (composite) query context: only tasks spawned from a query context are polled.
    Query,
    /// An inter-canister-call callback that has not woken a waker yet; the first `TaskWaker`
    /// woken here replaces this with `Update` or `Query`.
    FromTask,
    /// Trap/panic recovery: woken tasks are dropped instead of polled.
    Cancel,
}
134
/// A spawned background task.
struct Task {
    /// The type-erased future driving the task.
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// `true` if the task was spawned from a query context, and so may be polled in one.
    query: bool,
}

/// An inert placeholder: a future that never completes, flagged as an update task.
///
/// `poll_all` swaps this into a task's slot (via `mem::take`) while the real task is polled.
impl Default for Task {
    fn default() -> Self {
        let never: Pin<Box<dyn Future<Output = ()>>> = Box::pin(std::future::pending::<()>());
        Task {
            query: false,
            future: never,
        }
    }
}
148
149struct ContextGuard(());
150
151impl ContextGuard {
152    fn new(context: AsyncContext) -> Self {
153        CONTEXT.with(|context_var| {
154            assert!(
155                matches!(context_var.get(), AsyncContext::None),
156                "in_*_context called within an existing async context"
157            );
158            context_var.set(context);
159            Self(())
160        })
161    }
162}
163
164impl Drop for ContextGuard {
165    fn drop(&mut self) {
166        CONTEXT.set(AsyncContext::None);
167    }
168}
169
170/// Waker implementation for executing futures produced by `call`/`call_raw`/etc.
171///
172/// *Almost* a straightforward executor, i.e. wakeups are addressed immediately for everything,
173/// except it attempts to clean up tasks whose execution has trapped - see `call::is_recovering_from_trap`.
174#[derive(Clone)]
175struct TaskWaker {
176    task_id: TaskId,
177    query: bool,
178}
179
180impl Wake for TaskWaker {
181    fn wake(self: Arc<Self>) {
182        let context = CONTEXT.get();
183        assert!(
184            context != AsyncContext::None,
185            "wakers cannot be called outside an executor context"
186        );
187        if context == AsyncContext::Cancel {
188            // This task is recovering from a trap. We cancel it to run destructors.
189            let _task = TASKS.with_borrow_mut(|tasks| tasks.remove(self.task_id));
190            // _task must be dropped *outside* with_borrow_mut - its destructor may (inadvisably) schedule tasks
191        } else {
192            WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(self.task_id));
193            if context == AsyncContext::FromTask {
194                if self.query {
195                    CONTEXT.set(AsyncContext::Query)
196                } else {
197                    CONTEXT.set(AsyncContext::Update)
198                }
199            }
200        }
201    }
202}