ic_cdk_executor/
lib.rs

1//! An async executor for [`ic-cdk`](https://docs.rs/ic-cdk). Most users should not use this crate directly.
2
3use std::cell::{Cell, RefCell};
4use std::collections::VecDeque;
5use std::future::Future;
6use std::mem;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll, Wake, Waker};
10
11use slotmap::{new_key_type, SlotMap};
12
13/// Spawn an asynchronous task to run in the background.
14pub fn spawn<F: 'static + Future<Output = ()>>(future: F) {
15    let in_query = match CONTEXT.get() {
16        AsyncContext::None => panic!("`spawn` can only be called from an executor context"),
17        AsyncContext::Query => true,
18        AsyncContext::Update | AsyncContext::Cancel => false,
19        AsyncContext::FromTask => unreachable!("FromTask"),
20    };
21    let pinned_future = Box::pin(future);
22    let task = Task {
23        future: pinned_future,
24        query: in_query,
25    };
26    let task_id = TASKS.with_borrow_mut(|tasks| tasks.insert(task));
27    WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(task_id));
28}
29
30/// Execute an update function in a context that allows calling [`spawn`] and notifying wakers.
31pub fn in_executor_context<R>(f: impl FnOnce() -> R) -> R {
32    let _guard = ContextGuard::new(AsyncContext::Update);
33    let res = f();
34    poll_all();
35    res
36}
37
38/// Execute a composite query function in a context that allows calling [`spawn`] and notifying wakers.
39pub fn in_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
40    let _guard = ContextGuard::new(AsyncContext::Query);
41    let res = f();
42    poll_all();
43    res
44}
45
46/// Execute an inter-canister-call callback in a context that allows calling [`spawn`] and notifying wakers.
47pub fn in_callback_executor_context(f: impl FnOnce()) {
48    let _guard = ContextGuard::new(AsyncContext::FromTask);
49    f();
50    poll_all();
51}
52
53/// Execute an inter-canister-call callback in a context that allows calling [`spawn`] and notifying wakers,
54/// but will cancel every awoken future.
55pub fn in_callback_cancellation_context(f: impl FnOnce()) {
56    let _guard = ContextGuard::new(AsyncContext::Cancel);
57    f();
58}
59
60/// Tells you whether the current async fn is being canceled due to a trap/panic.
61pub fn is_recovering_from_trap() -> bool {
62    matches!(CONTEXT.get(), AsyncContext::Cancel)
63}
64
65fn poll_all() {
66    let in_query = match CONTEXT.get() {
67        AsyncContext::Query => true,
68        AsyncContext::Update | AsyncContext::Cancel => false,
69        AsyncContext::None => panic!("tasks can only be polled in an executor context"),
70        AsyncContext::FromTask => unreachable!("FromTask"),
71    };
72    let mut ineligible = vec![];
73    while let Some(task_id) = WAKEUP.with_borrow_mut(|queue| queue.pop_front()) {
74        // Temporarily remove the task from the table. We need to execute it while `TASKS` is not borrowed, because it may schedule more tasks.
75        let Some(mut task) = TASKS.with_borrow_mut(|tasks| tasks.get_mut(task_id).map(mem::take))
76        else {
77            // The task is dropped on the first callback that panics, but the last callback is the one that sets the flag.
78            // So if multiple calls are sent concurrently, the waker will be asked to wake a future that no longer exists.
79            // This should be the only possible case in which this happens.
80            panic!("Call already trapped");
81            // This also should not happen because the CallFuture handles this itself. But FuturesUnordered introduces some chaos.
82        };
83        if in_query && !task.query {
84            TASKS.with_borrow_mut(|tasks| tasks[task_id] = task);
85            ineligible.push(task_id);
86            continue;
87        }
88        let waker = Waker::from(Arc::new(TaskWaker {
89            task_id,
90            query: task.query,
91        }));
92        let poll = task.future.as_mut().poll(&mut Context::from_waker(&waker));
93        match poll {
94            Poll::Pending => {
95                // more to do, put the task back in the table
96                TASKS.with_borrow_mut(|tasks| {
97                    if let Some(t) = tasks.get_mut(task_id) {
98                        *t = task;
99                    }
100                });
101            }
102            Poll::Ready(()) => {
103                // task complete, remove its entry from the table fully
104                TASKS.with_borrow_mut(|tasks| tasks.remove(task_id));
105            }
106        }
107    }
108    if !ineligible.is_empty() {
109        WAKEUP.with_borrow_mut(|wakeup| wakeup.extend(ineligible));
110    }
111}
112
113new_key_type! {
114    struct TaskId;
115}
116
117thread_local! {
118    static TASKS: RefCell<SlotMap<TaskId, Task>> = <_>::default();
119    static WAKEUP: RefCell<VecDeque<TaskId>> = <_>::default();
120    static CONTEXT: Cell<AsyncContext> = <_>::default();
121}
122
/// Which kind of executor context, if any, is active on the current thread.
#[derive(Clone, Copy, Default)]
enum AsyncContext {
    /// No executor context is active. This is the initial state, and the one restored when a
    /// context guard is dropped.
    #[default]
    None,
    /// Running an update entry point, or a callback whose woken task is an update task.
    Update,
    /// Running a composite query entry point, or a callback whose woken task is a query task.
    Query,
    /// Running a callback before any waker has fired. The first waker notified switches the
    /// context to `Update` or `Query`, depending on its task.
    FromTask,
    /// Running a callback during trap recovery; notified wakers cancel their tasks.
    Cancel,
}
132
/// A spawned unit of work: its boxed future plus the kind of context it was spawned from.
struct Task {
    /// The future driving this task.
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// `true` if the task was spawned from a query context. Tasks with `false` are not polled
    /// while a query context is active.
    query: bool,
}

impl Default for Task {
    /// Builds a placeholder update task whose future never completes. It fills a task's slot
    /// while the real task is temporarily taken out of the table (e.g. via `mem::take`).
    fn default() -> Self {
        let never: Pin<Box<dyn Future<Output = ()>>> = Box::pin(std::future::pending());
        Task {
            future: never,
            query: false,
        }
    }
}
146
147struct ContextGuard(());
148
149impl ContextGuard {
150    fn new(context: AsyncContext) -> Self {
151        CONTEXT.with(|context_var| {
152            assert!(
153                matches!(context_var.get(), AsyncContext::None),
154                "in_*_context called within an existing async context"
155            );
156            context_var.set(context);
157            Self(())
158        })
159    }
160}
161
162impl Drop for ContextGuard {
163    fn drop(&mut self) {
164        CONTEXT.set(AsyncContext::None);
165    }
166}
167
168/// Waker implementation for executing futures produced by `call`/`call_raw`/etc.
169///
170/// *Almost* a straightforward executor, i.e. wakeups are addressed immediately for everything,
171/// except it attempts to clean up tasks whose execution has trapped - see `call::is_recovering_from_trap`.
172#[derive(Clone)]
173struct TaskWaker {
174    task_id: TaskId,
175    query: bool,
176}
177
178impl Wake for TaskWaker {
179    fn wake(self: Arc<Self>) {
180        if matches!(CONTEXT.get(), AsyncContext::Cancel) {
181            // This task is recovering from a trap. We cancel it to run destructors.
182            let _task = TASKS.with_borrow_mut(|tasks| tasks.remove(self.task_id));
183            // _task must be dropped *outside* with_borrow_mut - its destructor may (inadvisably) schedule tasks
184        } else {
185            WAKEUP.with_borrow_mut(|wakeup| wakeup.push_back(self.task_id));
186            CONTEXT.with(|context| {
187                if matches!(context.get(), AsyncContext::FromTask) {
188                    if self.query {
189                        context.set(AsyncContext::Query)
190                    } else {
191                        context.set(AsyncContext::Update)
192                    }
193                }
194            })
195        }
196    }
197}