use std::{future::poll_fn, pin::Pin, task::Poll};
use nonmax::NonMaxU32;
use slab::Slab;
use crate::{
context::{MAX_TASK_ID, TaskId, current_tasks, current_wake_sets},
sync::oneshot,
};
/// Slab of type-erased task futures. `spawn` inserts each new task here, wrapped so
/// that its output is `()` (the real result travels over a oneshot channel instead).
pub(crate) type Tasks = Slab<Pin<Box<dyn Future<Output = ()> + 'static>>>;
/// Error yielded by awaiting a [`JoinHandle`] when the task's result channel reports an
/// error instead of a value — presumably because the task was dropped before it sent
/// its output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled before completing")
    }
}

// Lets `Cancelled` flow through `?` into `Box<dyn Error>` and similar error types.
impl std::error::Error for Cancelled {}
pub struct JoinHandle<T> {
rx: Option<oneshot::Receiver<T>>,
}
impl<T> Future for JoinHandle<T> {
type Output = Result<T, Cancelled>;
fn poll(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let rx = self
.rx
.as_mut()
.expect("JoinHandle polled after completion");
match rx.poll_recv(cx) {
Poll::Ready(result) => {
self.rx = None;
Poll::Ready(result.map_err(|_| Cancelled))
}
Poll::Pending => Poll::Pending,
}
}
}
/// Spawns `fut` as a new task on the current runtime and returns a handle to its output.
///
/// The future is wrapped so that its output is sent over a oneshot channel, boxed into the
/// current task slab, and its id is inserted into the `staging` wake set (presumably so the
/// executor polls it on a later pass — confirm in `crate::context`). Dropping the returned
/// [`JoinHandle`] does not remove the task; the send of its output simply fails and is ignored.
///
/// # Panics
///
/// Panics if the next free slab index exceeds `MAX_TASK_ID` or does not fit in a
/// `NonMaxU32`. The check happens before insertion, so nothing is left in the slab.
pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
    let (tx, rx) = oneshot::channel();
    let wrapper = async move {
        // A send error (e.g. the handle was dropped) is deliberately ignored.
        _ = tx.send(fut.await);
    };
    // NOTE(review): the safety contract of these accessors is defined in `crate::context`;
    // presumably they must only be called while the runtime is active on this thread — confirm.
    let (tasks, wake_sets) = unsafe { (current_tasks(), current_wake_sets()) };

    // Reserve the slot first so the id can be validated before anything is inserted.
    let slot = tasks.vacant_entry();
    let index = slot.key();
    assert!(
        index <= MAX_TASK_ID as usize,
        "task index {index} exceeds MAX_TASK_ID"
    );
    // Checked conversion instead of `new_unchecked`: stays sound regardless of the value
    // of `MAX_TASK_ID`, and never silently truncates the index.
    let raw_id = u32::try_from(index)
        .ok()
        .and_then(NonMaxU32::new)
        .expect("task index does not fit in a NonMaxU32");

    slot.insert(Box::pin(wrapper));
    wake_sets.staging.insert(TaskId::Task(raw_id));
    JoinHandle { rx: Some(rx) }
}
/// Returns control to the executor exactly once before completing.
///
/// The first poll wakes the current task's waker and returns `Pending`; every later
/// poll returns `Ready(())`.
pub async fn yield_now() {
    let mut already_yielded = false;
    poll_fn(|cx| {
        if already_yielded {
            return Poll::Ready(());
        }
        already_yielded = true;
        // Ask to be polled again, then step aside for this round.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;
    use crate::block_on;
    use super::*;

    /// A spawned task's output is delivered through its handle.
    #[test]
    fn test_spawn_completes() {
        block_on(VirtualIo::default(), async {
            let output = spawn(async { 42 }).await;
            assert_eq!(output, Ok(42));
        });
    }

    /// Dropping a handle without awaiting it must not make `block_on` fail. This only
    /// exercises the drop path; it does not observe whether or how the task runs.
    #[test]
    fn test_spawn_cancelled() {
        block_on(VirtualIo::default(), async {
            drop(spawn(async { 42 }));
        });
    }

    /// Two tasks spawned before either is awaited each deliver their own output.
    #[test]
    fn test_spawn_runs_concurrently() {
        block_on(VirtualIo::default(), async {
            let handles = [spawn(async { 1 }), spawn(async { 2 })];
            for (handle, expected) in handles.into_iter().zip([1, 2]) {
                assert_eq!(handle.await, Ok(expected));
            }
        });
    }

    /// `yield_now` completes when driven by `block_on`.
    #[test]
    fn test_yield_now() {
        block_on(VirtualIo::default(), async { yield_now().await });
    }
}