use std::{
cell::{Cell, RefCell},
collections::VecDeque,
future::Future,
mem::take,
pin::Pin,
sync::{Arc, Once},
task::{Context, Poll, Wake, Waker},
};
use slotmap::{Key, SecondaryMap, SlotMap, new_key_type};
use smallvec::SmallVec;
/// Executor bookkeeping for one method invocation (update or query).
///
/// A context lives in `METHODS` until the last [`MethodHandle`] referring to it is gone at the
/// end of `enter_current_method`, at which point its protected tasks are cancelled.
#[derive(Debug, Clone)]
pub(crate) struct MethodContext {
    /// Whether this context behaves as an update or a query.
    pub(crate) kind: ContextKind,
    /// Number of live, counted [`MethodHandle`]s for this context.
    pub(crate) handles: usize,
    /// IDs of protected tasks spawned in this context. Entries are not pruned when a task
    /// finishes, so some IDs may already be stale.
    pub(crate) tasks: SmallVec<[TaskId; 4]>,
}
impl MethodContext {
pub(crate) fn new_update() -> Self {
Self {
kind: ContextKind::Update,
handles: 0,
tasks: SmallVec::new(),
}
}
pub(crate) fn new_query() -> Self {
Self {
kind: ContextKind::Query,
handles: 0,
tasks: SmallVec::new(),
}
}
}
/// The flavour of a method context.
///
/// In an `Update` context, `poll_all` also drains migratory wakeups and `spawn_migratory` is
/// permitted. In a `Query` context, only the method's own protected tasks are polled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ContextKind {
    /// An update method (also the assumed kind for untracked contexts).
    Update,
    /// A query method.
    Query,
}
// Slot-map key types.
// `MethodId` keys `METHODS` and `PROTECTED_WAKEUPS`. The null `MethodId` stands for "no tracked
// method"; it is used by `in_null_context` and by migratory tasks.
// `TaskId` keys `TASKS`.
new_key_type! {
pub(crate) struct MethodId;
pub(crate) struct TaskId;
}
// Per-thread executor state.
thread_local! {
// Live method contexts.
pub(crate) static METHODS: RefCell<SlotMap<MethodId, MethodContext>> = RefCell::default();
// Every spawned task that has not completed or been cancelled.
pub(crate) static TASKS: RefCell<SlotMap<TaskId, Task>> = RefCell::default();
// Per-method FIFO of woken protected tasks; drained by `poll_all` only while that method is current.
pub(crate) static PROTECTED_WAKEUPS: RefCell<SecondaryMap<MethodId, VecDeque<TaskId>>> = RefCell::default();
// FIFO of woken migratory tasks; drained by `poll_all` in update (or untracked) contexts.
pub(crate) static MIGRATORY_WAKEUPS: RefCell<VecDeque<TaskId>> = const { RefCell::new(VecDeque::new()) };
// The method whose context is active; `None` outside every executor context.
pub(crate) static CURRENT_METHOD: Cell<Option<MethodId>> = const { Cell::new(None) };
// `true` while inside `in_trap_recovery_context_for`; spawning is forbidden then.
pub(crate) static RECOVERING: Cell<bool> = const { Cell::new(false) };
// The task currently being polled by `poll_all`, if any.
pub(crate) static CURRENT_TASK_ID: Cell<Option<TaskId>> = const { Cell::new(None) };
}
/// A spawned future together with the scheduling metadata the executor needs.
pub(crate) struct Task {
    /// The pinned, type-erased future driven by `poll_all`.
    future: Pin<Box<dyn Future<Output = ()>>>,
    /// The method a protected task is bound to; `None` for migratory tasks. Decides which
    /// wakeup queue `TaskWaker` pushes onto.
    method_binding: Option<MethodId>,
    /// Value installed into `CURRENT_METHOD` for the duration of each poll of this task.
    set_current_method_var: MethodId,
}
impl Default for Task {
fn default() -> Self {
Self {
future: Box::pin(std::future::pending()),
method_binding: None,
set_current_method_var: MethodId::null(),
}
}
}
/// Runs `f` inside a brand-new tracked *update* method context, then polls every runnable task.
///
/// When `f` and the polling finish, the context is destroyed and its protected tasks are
/// cancelled, unless a handle from [`extend_current_method_context`] still keeps it alive.
///
/// # Panics
/// Panics if another executor context is already active on this thread.
pub fn in_tracking_executor_context<R>(f: impl FnOnce() -> R) -> R {
    setup_panic_hook();
    let handle = MethodHandle::for_method(
        METHODS.with_borrow_mut(|methods| methods.insert(MethodContext::new_update())),
    );
    enter_current_method(handle, || {
        let output = f();
        poll_all();
        output
    })
}
/// Runs `f` with the null (untracked) method as the current method, then polls runnable tasks.
///
/// No `MethodContext` is created, so `spawn_protected` is rejected inside `f`. Migratory
/// wakeups are still drained, because an unknown method is treated as an update context.
#[expect(dead_code)]
pub(crate) fn in_null_context<R>(f: impl FnOnce() -> R) -> R {
    setup_panic_hook();
    enter_current_method(MethodHandle::for_method(MethodId::null()), || {
        let output = f();
        poll_all();
        output
    })
}
/// Runs `f` inside a brand-new tracked *query* method context, then polls that method's tasks.
///
/// Only protected tasks bound to this query are polled; migratory wakeups are left queued, and
/// `spawn_migratory` panics inside `f`.
///
/// # Panics
/// Panics if another executor context is already active on this thread.
pub fn in_tracking_query_executor_context<R>(f: impl FnOnce() -> R) -> R {
    setup_panic_hook();
    let handle = MethodHandle::for_method(
        METHODS.with_borrow_mut(|methods| methods.insert(MethodContext::new_query())),
    );
    enter_current_method(handle, || {
        let output = f();
        poll_all();
        output
    })
}
/// Re-enters the method context held by `method_handle`, runs `f`, then polls runnable tasks.
///
/// The handle is consumed. If it was the last one for its method, the context is destroyed
/// and its protected tasks are cancelled once `f` and the polling finish.
///
/// # Panics
/// Panics if another executor context is already active on this thread.
pub fn in_callback_executor_context_for<R>(
    method_handle: MethodHandle,
    f: impl FnOnce() -> R,
) -> R {
    setup_panic_hook();
    let run_then_poll = || {
        let out = f();
        poll_all();
        out
    };
    enter_current_method(method_handle, run_then_poll)
}
/// Re-enters `method`'s context with the "recovering from trap" flag raised, and runs `f`.
///
/// Unlike the other entry points, this does not poll tasks. While `f` runs,
/// [`is_recovering_from_trap`] returns `true` and both spawn functions panic.
pub fn in_trap_recovery_context_for<R>(method: MethodHandle, f: impl FnOnce() -> R) -> R {
    setup_panic_hook();
    enter_current_method(method, || {
        RECOVERING.set(true);
        let outcome = f();
        RECOVERING.set(false);
        outcome
    })
}
/// Cancels (drops) every protected task spawned in the current method context.
///
/// # Panics
/// Panics when called outside any method context.
pub fn cancel_all_tasks_attached_to_current_method() {
    match CURRENT_METHOD.get() {
        Some(method_id) => cancel_all_tasks_attached_to_method(method_id),
        None => panic!(
            "`cancel_all_tasks_attached_to_current_method` can only be called within a method context"
        ),
    }
}
fn cancel_all_tasks_attached_to_method(method_id: MethodId) {
let Some(to_cancel) = METHODS.with_borrow_mut(|methods| {
methods
.get_mut(method_id)
.map(|method| take(&mut method.tasks))
}) else {
return; };
let _tasks = TASKS.with(|tasks| {
let Ok(mut tasks) = tasks.try_borrow_mut() else {
panic!(
"`cancel_all_tasks_attached_to_current_method` cannot be called from an async task"
);
};
let mut canceled = Vec::with_capacity(to_cancel.len());
for task_id in to_cancel {
canceled.push(tasks.remove(task_id));
}
canceled
});
drop(_tasks); }
/// Removes `task_id` from `TASKS`, if present, and drops it once the table borrow has ended.
pub(crate) fn delete_task(task_id: TaskId) {
    if let Some(removed) = TASKS.with_borrow_mut(|tasks| tasks.remove(task_id)) {
        drop(removed);
    }
}
pub fn cancel_task(task_handle: &TaskHandle) {
delete_task(task_handle.task_id);
}
pub fn is_recovering_from_trap() -> bool {
RECOVERING.get()
}
/// Returns a new handle to the current method context, keeping it alive after the current
/// entry point returns.
///
/// Pass the handle to [`in_callback_executor_context_for`] or
/// [`in_trap_recovery_context_for`] to re-enter the context later. If the current method is the
/// null (untracked) method, as inside a migratory task, the handle is not counted.
///
/// # Panics
/// Panics when called outside any executor context.
pub fn extend_current_method_context() -> MethodHandle {
    setup_panic_hook();
    CURRENT_METHOD
        .get()
        .map(MethodHandle::for_method)
        .unwrap_or_else(|| {
            panic!("`extend_method_context` can only be called within a tracking executor context")
        })
}
/// Polls woken tasks for the current method until no wakeups remain.
///
/// Protected wakeups queued for the current method are taken first. Migratory wakeups are
/// also drained, but only when the current context is an update context; an untracked or
/// already-deleted method counts as an update context. Each task is polled with
/// `CURRENT_METHOD` set to its own `set_current_method_var` and `CURRENT_TASK_ID` set to its ID.
///
/// # Panics
/// Panics if called outside any executor context.
pub(crate) fn poll_all() {
    let Some(method_id) = CURRENT_METHOD.get() else {
        panic!("tasks can only be polled within an executor context");
    };
    let kind = METHODS
        .with_borrow(|methods| methods.get(method_id).map(|m| m.kind))
        .unwrap_or(ContextKind::Update);

    /// Pops the next task ID to poll: the current method's protected queue first, then (in
    /// update contexts only) the migratory queue.
    fn pop_wakeup(method_id: MethodId, update: bool) -> Option<TaskId> {
        if let Some(task_id) = PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
            wakeups
                .get_mut(method_id)
                .and_then(|queue| queue.pop_front())
        }) {
            Some(task_id)
        } else if update {
            MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| unattached.pop_front())
        } else {
            None
        }
    }

    let is_update = kind == ContextKind::Update;
    while let Some(task_id) = pop_wakeup(method_id, is_update) {
        // Move only the future out of the table, so it can be polled while `TASKS` is not
        // borrowed (it may spawn, wake or cancel tasks). The task's metadata stays in place: if
        // the task wakes itself during this poll, `TaskWaker` must still see the real
        // `method_binding` and route the wakeup to the correct queue. A whole-`Task`
        // placeholder would make a protected task look migratory.
        let Some((mut future, set_current_method_var)) = TASKS.with_borrow_mut(|tasks| {
            tasks.get_mut(task_id).map(|task| {
                let placeholder: Pin<Box<dyn Future<Output = ()>>> =
                    Box::pin(std::future::pending());
                let future = std::mem::replace(&mut task.future, placeholder);
                (future, task.set_current_method_var)
            })
        }) else {
            // Stale wakeup: the task already completed or was cancelled.
            continue;
        };
        let waker = Waker::from(Arc::new(TaskWaker { task_id }));
        let prev_current_method_var = CURRENT_METHOD.replace(Some(set_current_method_var));
        CURRENT_TASK_ID.set(Some(task_id));
        let poll = future.as_mut().poll(&mut Context::from_waker(&waker));
        CURRENT_TASK_ID.set(None);
        CURRENT_METHOD.set(prev_current_method_var);
        match poll {
            Poll::Pending => {
                // Put the future back, unless the task was cancelled while it ran. An orphaned
                // future is handed out of the closure and dropped only after `TASKS` is
                // released, because its destructor may wake or cancel other tasks.
                let orphaned = TASKS.with_borrow_mut(|tasks| match tasks.get_mut(task_id) {
                    Some(task) => {
                        task.future = future;
                        None
                    }
                    None => Some(future),
                });
                drop(orphaned);
            }
            Poll::Ready(()) => delete_task(task_id),
        }
    }
}
pub(crate) fn enter_current_method<R>(method_guard: MethodHandle, f: impl FnOnce() -> R) -> R {
CURRENT_METHOD.with(|context_var| {
assert!(
context_var.get().is_none(),
"in_*_context called within an existing async context"
);
context_var.set(Some(method_guard.method_id));
});
let r = f();
drop(method_guard); let method_id = CURRENT_METHOD.replace(None);
if let Some(method_id) = method_id {
let handles = METHODS.with_borrow_mut(|methods| methods.get(method_id).map(|m| m.handles));
if handles == Some(0) {
cancel_all_tasks_attached_to_method(method_id);
METHODS.with_borrow_mut(|methods| methods.remove(method_id));
}
}
r
}
/// A counted reference that keeps a method context alive.
///
/// Creating one increments the method's `handles` count and dropping it decrements the count.
/// Handles to the null method are not counted. The type is not `Clone`, so each handle
/// accounts for exactly one count.
#[derive(Debug)]
pub struct MethodHandle {
    /// The method this handle refers to (possibly the null method).
    method_id: MethodId,
}
impl MethodHandle {
    /// Creates a handle for `method_id`, bumping its handle count unless it is the null method.
    ///
    /// # Panics
    /// Panics if `method_id` is non-null but no longer present in `METHODS`.
    pub(crate) fn for_method(method_id: MethodId) -> Self {
        if !method_id.is_null() {
            METHODS.with_borrow_mut(|methods| match methods.get_mut(method_id) {
                Some(method) => method.handles += 1,
                None => panic!("internal error: method context deleted while in use (for_method)"),
            });
        }
        Self { method_id }
    }
}
impl Drop for MethodHandle {
    /// Decrements the method's handle count. If the method no longer exists (or is the null
    /// method), this does nothing.
    fn drop(&mut self) {
        let id = self.method_id;
        METHODS.with_borrow_mut(|methods| {
            let Some(method) = methods.get_mut(id) else {
                return;
            };
            method.handles -= 1;
        })
    }
}
/// An opaque reference to a spawned task, usable with [`cancel_task`].
///
/// Holding a handle does not keep the task alive.
#[derive(Debug)]
pub struct TaskHandle {
    /// Key of the task in `TASKS`; may be stale once the task has finished.
    task_id: TaskId,
}
impl TaskHandle {
    /// Returns a handle to the task currently being polled, or `None` when called outside
    /// a task.
    pub fn current() -> Option<Self> {
        CURRENT_TASK_ID.get().map(|task_id| Self { task_id })
    }
}
/// The waker payload handed to each task poll. Waking re-queues the task it identifies.
pub(crate) struct TaskWaker {
    /// The task to re-queue on wake.
    pub(crate) task_id: TaskId,
}
impl Wake for TaskWaker {
    /// Re-queues the task; see [`TaskWaker::wake_by_ref`].
    fn wake(self: Arc<Self>) {
        self.wake_by_ref();
    }

    /// Re-queues the task on the queue that matches its binding.
    ///
    /// A protected task goes onto its method's `PROTECTED_WAKEUPS` queue; a migratory task goes
    /// onto `MIGRATORY_WAKEUPS`. If the task no longer exists (finished or cancelled), nothing
    /// happens. Overriding this avoids the `Arc` clone the default `wake_by_ref` performs.
    fn wake_by_ref(self: &Arc<Self>) {
        // Only a shared borrow is needed to read the binding. It is released before any queue
        // is touched.
        let binding = TASKS.with_borrow(|tasks| {
            tasks
                .get(self.task_id)
                .map(|task| task.method_binding)
        });
        match binding {
            // The task is gone; a stale wake is harmless.
            None => {}
            Some(Some(method_id)) => PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| {
                // `entry` yields `None` for keys that are null or stale in this map.
                if let Some(entry) = wakeups.entry(method_id) {
                    entry.or_default().push_back(self.task_id);
                }
            }),
            Some(None) => MIGRATORY_WAKEUPS.with_borrow_mut(|unattached| {
                unattached.push_back(self.task_id);
            }),
        }
    }
}
/// Spawns `f` as a migratory task that is not bound to any method.
///
/// The task is queued immediately and is polled by the next `poll_all` in an update (or
/// untracked) context, with the null method as current. It is not cancelled when the spawning
/// method's context ends.
///
/// # Panics
/// Panics outside an executor context, during trap recovery, or inside a query context.
pub fn spawn_migratory(f: impl Future<Output = ()> + 'static) -> TaskHandle {
    setup_panic_hook();
    let method_id = match CURRENT_METHOD.get() {
        Some(id) => id,
        None => panic!("`spawn_*` can only be called within an executor context"),
    };
    if is_recovering_from_trap() {
        panic!("tasks cannot be spawned while recovering from a trap");
    }
    // Unknown or null methods are treated as update contexts, hence `is_some_and`.
    let in_query = METHODS.with_borrow(|methods| {
        methods
            .get(method_id)
            .is_some_and(|m| m.kind == ContextKind::Query)
    });
    if in_query {
        panic!("unprotected spawns cannot be made within a query context");
    }
    let task_id = TASKS.with_borrow_mut(|tasks| {
        tasks.insert(Task {
            future: Box::pin(f),
            method_binding: None,
            set_current_method_var: MethodId::null(),
        })
    });
    MIGRATORY_WAKEUPS.with_borrow_mut(|queue| queue.push_back(task_id));
    TaskHandle { task_id }
}
/// Spawns `f` as a protected task bound to the current method.
///
/// The task is queued on the method's own wakeup queue and runs with that method as current.
/// It is recorded in the method's task list, so it is cancelled when the context is torn down
/// or via [`cancel_all_tasks_attached_to_current_method`].
///
/// # Panics
/// Panics during trap recovery, outside an executor context, or when the current method is the
/// null (untracked) method.
pub fn spawn_protected(f: impl Future<Output = ()> + 'static) -> TaskHandle {
    setup_panic_hook();
    if is_recovering_from_trap() {
        panic!("tasks cannot be spawned while recovering from a trap");
    }
    let method_id = CURRENT_METHOD
        .get()
        .unwrap_or_else(|| panic!("`spawn_*` can only be called within an executor context"));
    if method_id.is_null() {
        panic!("`spawn_protected` cannot be called outside of a tracked method context");
    }
    let task_id = TASKS.with_borrow_mut(|tasks| {
        tasks.insert(Task {
            future: Box::pin(f),
            method_binding: Some(method_id),
            set_current_method_var: method_id,
        })
    });
    METHODS.with_borrow_mut(|methods| match methods.get_mut(method_id) {
        Some(method) => method.tasks.push(task_id),
        None => panic!("internal error: method context deleted while in use (spawn_protected)"),
    });
    PROTECTED_WAKEUPS.with_borrow_mut(|wakeups| match wakeups.entry(method_id) {
        Some(entry) => entry.or_default().push_back(task_id),
        None => panic!("internal error: method context deleted while in use (spawn_protected)"),
    });
    TaskHandle { task_id }
}
/// Installs, at most once per process, a panic hook that prints and then traps.
///
/// The hook formats the panic message and source location, writes the result with
/// `ic0::debug_print`, and then calls `ic0::trap` with the same bytes.
fn setup_panic_hook() {
    static SETUP: Once = Once::new();
    SETUP.call_once(|| {
        std::panic::set_hook(Box::new(|info| {
            let location = info.location().unwrap();
            let (file, line, col) = (location.file(), location.line(), location.column());
            let payload = info.payload();
            // Panic payloads are usually `&'static str` (literal message) or `String`
            // (formatted message); anything else gets a generic label.
            let msg: &str = if let Some(s) = payload.downcast_ref::<&'static str>() {
                *s
            } else if let Some(s) = payload.downcast_ref::<String>() {
                s.as_str()
            } else {
                "Box<Any>"
            };
            let err_info = format!("Panicked at '{msg}', {file}:{line}:{col}");
            ic0::debug_print(err_info.as_bytes());
            ic0::trap(err_info.as_bytes());
        }));
    });
}