use std::cell::RefCell;
use std::future::Future;
use std::marker::PhantomData;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, LocalKey, ThreadId};
use crate::builtin::{Callable, Variant};
use crate::private::handle_panic;
/// Registers `future` as a task in the thread-local async runtime and polls it once right away.
///
/// The returned [`TaskHandle`] can be used to cancel the task or to check whether it is still
/// pending.
///
/// # Panics
/// Panics if `crate::init::is_main_thread()` reports that the caller is not on the main thread.
#[doc(alias = "async")]
pub fn spawn(future: impl Future<Output = ()> + 'static) -> TaskHandle {
    assert!(
        crate::init::is_main_thread(),
        "spawn() can only be used on the main thread"
    );

    let (task_handle, godot_waker) = ASYNC_RUNTIME.with_runtime_mut(move |rt| {
        // `add_task` pins and boxes the future itself, so it is handed over unboxed to avoid a
        // second allocation and an extra pointer indirection on every poll.
        let task_handle = rt.add_task(future);
        let godot_waker = Arc::new(GodotWaker::new(
            task_handle.index,
            task_handle.id,
            thread::current().id(),
        ));

        (task_handle, godot_waker)
    });

    // The runtime borrow has been released here, so the first poll may access the runtime again.
    poll_future(godot_waker);
    task_handle
}
/// Handle that identifies a task stored in the thread-local async runtime.
///
/// It is returned by `spawn` and offers `cancel` and `is_pending`.
pub struct TaskHandle {
/// Index of the task's slot inside `AsyncRuntime::tasks`.
index: usize,
/// Id assigned to the task when it was added. Slots are reused, so the id is compared against
/// the slot's id to find out whether the slot still belongs to this task.
id: u64,
/// Raw-pointer marker that keeps the handle from being `Send` or `Sync`.
_no_send_sync: PhantomData<*const ()>,
}
impl TaskHandle {
    /// Creates a handle for the task with the given `id` stored in slot `index`.
    fn new(index: usize, id: u64) -> Self {
        Self {
            index,
            id,
            _no_send_sync: PhantomData,
        }
    }

    /// Cancels the task if it is still waiting to be polled.
    ///
    /// Does nothing when the slot does not exist, when the task is already gone, or when the slot
    /// has meanwhile been handed to a different task.
    ///
    /// # Panics
    /// Panics when called for the task that is currently being polled, i.e. from inside its own
    /// future.
    pub fn cancel(self) {
        let cancelled = ASYNC_RUNTIME.with_runtime_mut(|rt| {
            let slot = rt.tasks.get_mut(self.index)?;

            // Slots are recycled. Ownership has to be checked before looking at the state,
            // otherwise a stale handle would trip the `Polling` panic below whenever the task
            // that inherited the slot happens to be running.
            if slot.id != self.id {
                return None;
            }

            match slot.value {
                FutureSlotState::Empty => {
                    panic!("Future slot is empty when canceling it! This is a bug!")
                }
                FutureSlotState::Gone => None,
                FutureSlotState::Polling => panic!("Can not cancel future from inside it!"),
                FutureSlotState::Pending(_) => {
                    Some(std::mem::replace(&mut slot.value, FutureSlotState::Gone))
                }
            }
        });

        // The future is dropped only after the runtime borrow has ended, so a destructor that
        // accesses the runtime does not run into an already borrowed `RefCell`.
        drop(cancelled);
    }

    /// Returns `true` while the task's slot still belongs to this task and the task is either
    /// waiting to be polled or currently being polled.
    ///
    /// # Panics
    /// Panics if the runtime has no slot at the handle's index.
    pub fn is_pending(&self) -> bool {
        ASYNC_RUNTIME.with_runtime(|rt| {
            let slot = rt
                .tasks
                .get(self.index)
                .unwrap_or_else(|| unreachable!("missing future slot at index {}", self.index));

            slot.id == self.id
                && matches!(
                    slot.value,
                    FutureSlotState::Pending(_) | FutureSlotState::Polling
                )
        })
    }
}
// Panic message used by `with_runtime` / `with_runtime_mut` when the thread-local runtime holds
// `None`, which is the state `cleanup()` leaves behind.
const ASYNC_RUNTIME_DEINIT_PANIC_MESSAGE: &str = "The async runtime is being accessed after it has been deinitialized. This should not be possible and is most likely a bug.";
// One runtime per thread. It starts out as `Some(AsyncRuntime::new())`; the `Option` allows
// `cleanup()` to replace it with `None`.
thread_local! {
static ASYNC_RUNTIME: RefCell<Option<AsyncRuntime>> = RefCell::new(Some(AsyncRuntime::new()));
}
/// Deinitializes the async runtime of the current thread by replacing it with `None`.
///
/// Afterwards every access through `with_runtime` / `with_runtime_mut` panics with
/// `ASYNC_RUNTIME_DEINIT_PANIC_MESSAGE`.
pub(crate) fn cleanup() {
ASYNC_RUNTIME.set(None);
}
/// Reports whether the id of the given task has been recorded in the runtime's set of panicked
/// tasks.
///
/// Only available with the `trace` feature. The handle is consumed by the call.
#[cfg(feature = "trace")]
#[cfg_attr(published_docs, doc(cfg(feature = "trace")))]
pub fn has_godot_task_panicked(task_handle: TaskHandle) -> bool {
    let task_id = task_handle.id;

    ASYNC_RUNTIME.with_runtime(|runtime| runtime.panicked_tasks.contains(&task_id))
}
/// State of a task slot, also used as the result of taking a task out of its slot for polling.
enum FutureSlotState<T> {
/// No slot content is available. `AsyncRuntime::take_task_for_polling` reports this when the
/// requested index has no slot; `FutureSlot::is_empty` treats it as free.
Empty,
/// The task that occupied the slot has been cleared. Also reported by
/// `FutureSlot::take_for_polling` when the slot holds a task with a different id.
Gone,
/// The slot holds a value that is waiting to be polled.
Pending(T),
/// The value has been taken out of the slot and is currently being polled.
Polling,
}
/// A reusable storage slot for one task.
struct FutureSlot<T> {
/// Current state of the slot, including the stored value while it is pending.
value: FutureSlotState<T>,
/// Id of the task the slot was last assigned to; used to detect stale handles and wakers.
id: u64,
}
impl<T> FutureSlot<T> {
    /// Builds a slot that holds `value` for the task `id`, ready to be polled.
    fn pending(id: u64, value: T) -> Self {
        let value = FutureSlotState::Pending(value);
        Self { value, id }
    }

    /// Returns `true` if the slot can be handed to a new task.
    fn is_empty(&self) -> bool {
        match self.value {
            FutureSlotState::Empty | FutureSlotState::Gone => true,
            FutureSlotState::Pending(_) | FutureSlotState::Polling => false,
        }
    }

    /// Drops whatever the slot holds and marks it as `Gone`.
    fn clear(&mut self) {
        self.value = FutureSlotState::Gone;
    }

    /// Takes the stored value out for polling if the slot is pending and belongs to task `id`.
    ///
    /// In that case the slot switches to `Polling` and the previous `Pending` state is returned.
    /// In every other case the slot is left untouched: a pending slot of another task is reported
    /// as `Gone`, all remaining states are reported as they are.
    fn take_for_polling(&mut self, id: u64) -> FutureSlotState<T> {
        // `None` means "nothing to report, hand out the stored value".
        let report = match &self.value {
            FutureSlotState::Pending(_) if self.id == id => None,
            FutureSlotState::Pending(_) | FutureSlotState::Gone => Some(FutureSlotState::Gone),
            FutureSlotState::Empty => Some(FutureSlotState::Empty),
            FutureSlotState::Polling => Some(FutureSlotState::Polling),
        };

        match report {
            Some(state) => state,
            None => std::mem::replace(&mut self.value, FutureSlotState::Polling),
        }
    }

    /// Puts `value` back into a slot that is currently in the `Polling` state.
    ///
    /// # Panics
    /// Panics if the slot is unoccupied or already holds a pending value.
    fn park(&mut self, value: T) {
        match self.value {
            FutureSlotState::Polling => self.value = FutureSlotState::Pending(value),
            FutureSlotState::Pending(_) => panic!(
                "cannot park future in slot, which is already occupied by a different future"
            ),
            FutureSlotState::Empty | FutureSlotState::Gone => {
                panic!("cannot park future in slot which is unoccupied")
            }
        }
    }
}
/// Storage for all tasks that were added on the current thread.
#[derive(Default)]
struct AsyncRuntime {
/// Task slots; a slot whose task is gone is reused by `add_task` before the vector grows.
tasks: Vec<FutureSlot<Pin<Box<dyn Future<Output = ()>>>>>,
/// Id that will be handed to the next task; incremented on every `next_id` call.
next_task_id: u64,
/// Ids of tasks recorded through `track_panic` (only with the `trace` feature).
#[cfg(feature = "trace")] #[cfg_attr(published_docs, doc(cfg(feature = "trace")))]
panicked_tasks: std::collections::HashSet<u64>,
}
impl AsyncRuntime {
    /// Creates an empty runtime with room for 16 task slots and ids starting at 0.
    fn new() -> Self {
        Self {
            tasks: Vec::with_capacity(16),
            next_task_id: 0,
            #[cfg(feature = "trace")]
            #[cfg_attr(published_docs, doc(cfg(feature = "trace")))]
            panicked_tasks: std::collections::HashSet::default(),
        }
    }

    /// Hands out the next task id and advances the counter.
    fn next_id(&mut self) -> u64 {
        let issued = self.next_task_id;
        self.next_task_id = issued + 1;
        issued
    }

    /// Stores `future` as a pending task and returns a handle to it.
    ///
    /// The first free slot is reused; a new slot is appended only if none is free.
    fn add_task<F: Future<Output = ()> + 'static>(&mut self, future: F) -> TaskHandle {
        let id = self.next_id();
        let boxed: Pin<Box<dyn Future<Output = ()>>> = Box::pin(future);

        let index = match self.tasks.iter().position(|slot| slot.is_empty()) {
            Some(free) => {
                self.tasks[free] = FutureSlot::pending(id, boxed);
                free
            }
            None => {
                let appended = self.tasks.len();
                self.tasks.push(FutureSlot::pending(id, boxed));
                appended
            }
        };

        TaskHandle::new(index, id)
    }

    /// Takes the task `id` out of slot `index` for polling.
    ///
    /// Reports `Empty` when there is no slot at `index`; otherwise the result is whatever
    /// `FutureSlot::take_for_polling` reports.
    fn take_task_for_polling(
        &mut self,
        index: usize,
        id: u64,
    ) -> FutureSlotState<Pin<Box<dyn Future<Output = ()> + 'static>>> {
        match self.tasks.get_mut(index) {
            Some(slot) => slot.take_for_polling(id),
            None => FutureSlotState::Empty,
        }
    }

    /// Marks the slot at `index` as gone. Panics if `index` is out of bounds.
    fn clear_task(&mut self, index: usize) {
        let slot = &mut self.tasks[index];
        slot.clear();
    }

    /// Puts a polled future back into the slot at `index`. Panics if `index` is out of bounds or
    /// the slot is not in the polling state.
    fn park_task(&mut self, index: usize, future: Pin<Box<dyn Future<Output = ()>>>) {
        let slot = &mut self.tasks[index];
        slot.park(future);
    }

    /// Remembers that the task `task_id` panicked (only with the `trace` feature).
    #[cfg(feature = "trace")]
    #[cfg_attr(published_docs, doc(cfg(feature = "trace")))]
    fn track_panic(&mut self, task_id: u64) {
        self.panicked_tasks.insert(task_id);
    }
}
/// Convenience access to the `AsyncRuntime` stored behind a thread-local key.
trait WithRuntime {
/// Runs `f` with shared access to the runtime and returns its result.
fn with_runtime<R>(&'static self, f: impl FnOnce(&AsyncRuntime) -> R) -> R;
/// Runs `f` with exclusive access to the runtime and returns its result.
fn with_runtime_mut<R>(&'static self, f: impl FnOnce(&mut AsyncRuntime) -> R) -> R;
}
impl WithRuntime for LocalKey<RefCell<Option<AsyncRuntime>>> {
    /// Borrows the thread-local cell immutably and passes the contained runtime to `f`.
    ///
    /// Panics with `ASYNC_RUNTIME_DEINIT_PANIC_MESSAGE` if the cell holds `None`.
    fn with_runtime<R>(&'static self, f: impl FnOnce(&AsyncRuntime) -> R) -> R {
        self.with_borrow(|maybe_runtime| {
            f(maybe_runtime
                .as_ref()
                .expect(ASYNC_RUNTIME_DEINIT_PANIC_MESSAGE))
        })
    }

    /// Borrows the thread-local cell mutably and passes the contained runtime to `f`.
    ///
    /// Panics with `ASYNC_RUNTIME_DEINIT_PANIC_MESSAGE` if the cell holds `None`.
    fn with_runtime_mut<R>(&'static self, f: impl FnOnce(&mut AsyncRuntime) -> R) -> R {
        self.with_borrow_mut(|maybe_runtime| {
            f(maybe_runtime
                .as_mut()
                .expect(ASYNC_RUNTIME_DEINIT_PANIC_MESSAGE))
        })
    }
}
/// Polls the task referenced by `godot_waker` once and writes the outcome back to the runtime.
///
/// The future is taken out of its slot before it is polled and the runtime is not borrowed while
/// the poll runs. Afterwards the future is parked again if it is still pending, or its slot is
/// cleared if it finished or the poll failed.
///
/// # Panics
/// - If called on a thread other than the one stored in the waker.
/// - If the runtime reports the slot as `Empty`, or as `Polling` (the task is already being polled).
fn poll_future(godot_waker: Arc<GodotWaker>) {
let current_thread = thread::current().id();
assert_eq!(
godot_waker.thread_id, current_thread,
"trying to poll future on a different thread!\n Current thread: {:?}\n Future thread: {:?}",
current_thread, godot_waker.thread_id,
);
// The same waker is handed to the future so that it can schedule the next poll.
let waker = Waker::from(godot_waker.clone());
let mut ctx = Context::from_waker(&waker);
// Take the future out of the runtime; the mutable borrow ends with this closure.
let future = ASYNC_RUNTIME.with_runtime_mut(|rt| {
match rt.take_task_for_polling(godot_waker.runtime_index, godot_waker.task_id) {
FutureSlotState::Empty => {
panic!("Future slot is empty when waking it! This is a bug!");
}
// The task was cleared, or the slot now belongs to a task with a different id.
FutureSlotState::Gone => None,
FutureSlotState::Polling => {
unreachable!("the same GodotWaker has been called recursively");
}
FutureSlotState::Pending(future) => Some(future),
}
});
let Some(future) = future else {
// Nothing to poll any more.
return;
};
let error_context = || format!("async task #{}", godot_waker.task_id);
// The future is moved into the closure below and only handed back on success, so after a
// failed poll it cannot be observed again.
let mut future = AssertUnwindSafe(future);
// NOTE(review): `handle_panic` is defined elsewhere; an `Err` result is treated here as
// "the poll panicked" — confirm against its definition.
let panic_result = handle_panic(error_context, move || {
(future.as_mut().poll(&mut ctx), future)
});
let Ok((poll_result, future)) = panic_result else {
// Failed poll: record the task id (trace feature only) and free the slot.
ASYNC_RUNTIME.with_runtime_mut(|rt| {
#[cfg(feature = "trace")] #[cfg_attr(published_docs, doc(cfg(feature = "trace")))]
rt.track_panic(godot_waker.task_id);
rt.clear_task(godot_waker.runtime_index);
});
return;
};
// Store the result of the poll in the runtime.
ASYNC_RUNTIME.with_runtime_mut(|rt| match poll_result {
// Still pending: put the future back into its slot.
Poll::Pending => rt.park_task(godot_waker.runtime_index, future.0),
// Finished: free the slot.
Poll::Ready(()) => rt.clear_task(godot_waker.runtime_index),
});
}
/// Waker state that identifies one task of the async runtime.
struct GodotWaker {
/// Index of the task's slot in the runtime.
runtime_index: usize,
/// Id of the task; compared against the slot's id when the task is taken out for polling.
task_id: u64,
/// Thread the waker was created on; `poll_future` asserts that it runs on this thread.
thread_id: ThreadId,
}
impl GodotWaker {
    /// Creates a waker for the task `task_id` stored in slot `index`, bound to `thread_id`.
    fn new(index: usize, task_id: u64, thread_id: ThreadId) -> Self {
        let runtime_index = index;

        Self {
            runtime_index,
            task_id,
            thread_id,
        }
    }
}
impl Wake for GodotWaker {
/// Schedules another poll of the task by wrapping `poll_future` in a `Callable` and invoking it
/// through `call_deferred`, instead of polling directly.
fn wake(self: Arc<Self>) {
// The callback must be `FnMut`, but `poll_future` consumes the waker; keeping it in an
// `Option` lets the closure move it out on its first call.
let mut waker = Some(self);
// Identity function that only pins down the closure's signature for type inference.
fn callback_type_hint<F>(f: F) -> F
where
F: for<'a> FnMut(&'a [&Variant]) -> Variant,
{
f
}
// Which `Callable` constructor is used depends on the `experimental-threads` feature.
#[cfg(not(feature = "experimental-threads"))] #[cfg_attr(published_docs, doc(cfg(not(feature = "experimental-threads"))))]
let create_callable = Callable::from_fn;
#[cfg(feature = "experimental-threads")] #[cfg_attr(published_docs, doc(cfg(feature = "experimental-threads")))]
let create_callable = Callable::from_sync_fn;
let callable = create_callable(
"GodotWaker::wake",
callback_type_hint(move |_args| {
// A second invocation of the callable would find `None` and panic.
poll_future(waker.take().expect("Callable will never be called again"));
Variant::nil()
}),
);
// NOTE(review): presumably this makes the poll run later on the main thread — confirm
// against the `Callable::call_deferred` documentation.
callable.call_deferred(&[]);
}
}