1use std::{future::poll_fn, pin::Pin, task::Poll};
4
5use nonmax::NonMaxU32;
6use slab::Slab;
7
8use crate::{
9 context::{MAX_TASK_ID, TaskId, current_tasks, current_wake_sets},
10 sync::oneshot,
11};
12
/// Slab of spawned tasks: each entry is a pinned, boxed, type-erased future
/// whose output is `()`. In `spawn`, a task's slab key becomes the index of its
/// `TaskId::Task` (presumably this is the type behind `current_tasks()` —
/// confirm in `crate::context`).
pub(crate) type Tasks = Slab<Pin<Box<dyn Future<Output = ()>>>>;
14
/// Error produced by awaiting a [`JoinHandle`] when the task's result channel
/// reports an error instead of delivering a value (presumably because the task
/// was dropped before it finished and its sender went away).
///
/// Implements [`std::error::Error`] so it composes with `?` and boxed errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled before it completed")
    }
}

impl std::error::Error for Cancelled {}
18
/// Handle to a task started with [`spawn`].
///
/// Awaiting it yields `Ok(output)`, or `Err(Cancelled)` if the task's result
/// channel reports an error. Dropping the handle does not remove the task from
/// the executor; the task's output is then simply discarded.
pub struct JoinHandle<T> {
    /// Receiving half of the task's result channel. Set to `None` once the
    /// handle has produced its result, so a further poll is detected.
    rx: Option<oneshot::Receiver<T>>,
}
24
25impl<T> Future for JoinHandle<T> {
26 type Output = Result<T, Cancelled>;
27
28 fn poll(
29 mut self: std::pin::Pin<&mut Self>,
30 cx: &mut std::task::Context<'_>,
31 ) -> std::task::Poll<Self::Output> {
32 let rx = self
33 .rx
34 .as_mut()
35 .expect("JoinHandle polled after completion");
36 match rx.poll_recv(cx) {
37 Poll::Ready(result) => {
38 self.rx = None;
39 Poll::Ready(result.map_err(|_| Cancelled))
40 }
41 Poll::Pending => Poll::Pending,
42 }
43 }
44}
45
46pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
48 let (tx, rx) = oneshot::channel();
49 let handle = JoinHandle { rx: Some(rx) };
50
51 let wrapper = async move {
52 _ = tx.send(fut.await);
54 };
55
56 let (tasks, wake_sets) = unsafe { (current_tasks(), current_wake_sets()) };
58 let index = tasks.insert(Box::pin(wrapper));
59 assert!(index <= MAX_TASK_ID as usize);
60 let task_id = TaskId::Task(unsafe { NonMaxU32::new_unchecked(index as u32) });
62
63 wake_sets.staging.insert(task_id);
64
65 handle
66}
67
/// Gives control back to the executor once before completing.
///
/// On the first poll the current task's waker is triggered and `Pending` is
/// returned. The following poll completes. Whether other tasks run in between
/// depends on the executor's scheduling.
pub async fn yield_now() {
    let mut polled_once = false;
    poll_fn(move |cx| {
        if !polled_once {
            // Ask to be polled again right away, then give up this turn.
            polled_once = true;
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        Poll::Ready(())
    })
    .await
}
83
#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use super::*;
    use crate::block_on;

    #[test]
    fn test_spawn_completes() {
        block_on(VirtualIo::default(), async {
            let answer = spawn(async { 42 }).await;
            assert_eq!(answer, Ok(42));
        });
    }

    #[test]
    fn test_spawn_cancelled() {
        block_on(VirtualIo::default(), async {
            // Dropping the handle right after spawning must not panic.
            let unused_handle = spawn(async { 42 });
            drop(unused_handle);
        });
    }

    #[test]
    fn test_spawn_runs_concurrently() {
        block_on(VirtualIo::default(), async {
            // Both tasks are spawned before either handle is awaited.
            let first = spawn(async { 1 });
            let second = spawn(async { 2 });
            assert_eq!(first.await, Ok(1));
            assert_eq!(second.await, Ok(2));
        });
    }

    #[test]
    fn test_yield_now() {
        block_on(VirtualIo::default(), yield_now());
    }
}