tempest-rt 0.0.1

TempestDB Deterministic Async Runtime
Documentation
//! Async task primitives: spawning concurrent tasks and yielding to the scheduler.

use std::{future::poll_fn, pin::Pin, task::Poll};

use nonmax::NonMaxU32;
use slab::Slab;

use crate::{
    context::{MAX_TASK_ID, TaskId, current_tasks, current_wake_sets},
    sync::oneshot,
};

/// Storage for spawned tasks: each slot holds a task as a pinned, boxed, type-erased future
/// producing `()`. A task's slab key is used as its [`TaskId::Task`] id (see [`spawn`]).
pub(crate) type Tasks = Slab<Pin<Box<dyn Future<Output = ()>>>>;

/// Error returned by awaiting a [`JoinHandle`] when the task's result channel closed without
/// delivering the task's output.
///
/// NOTE(review): this presumably means the task's future was dropped before it completed
/// (e.g. discarded by the runtime). Dropping the `JoinHandle` itself does not cancel the task;
/// it only detaches it, so the task's output is discarded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("task was cancelled before producing its output")
    }
}

impl std::error::Error for Cancelled {}

/// Handle to a spawned task. Awaiting it yields `Ok(output)` once the task finishes, or
/// `Err(`[`Cancelled`]`)` if the task's result channel reports an error (presumably because the
/// task's future was dropped before it completed).
///
/// Dropping the handle detaches the task rather than cancelling it; the task's output is then
/// discarded. The handle must not be polled again after it has returned `Poll::Ready`: doing so
/// panics with "JoinHandle polled after completion".
pub struct JoinHandle<T> {
    /// Receiving half of the task's result channel; set to `None` once the result has been read.
    rx: Option<oneshot::Receiver<T>>,
}

impl<T> Future for JoinHandle<T> {
    type Output = Result<T, Cancelled>;

    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let rx = self
            .rx
            .as_mut()
            .expect("JoinHandle polled after completion");
        match rx.poll_recv(cx) {
            Poll::Ready(result) => {
                self.rx = None;
                Poll::Ready(result.map_err(|_| Cancelled))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Spawns `fut` as a concurrent task, returning a [`JoinHandle`] to collect its result.
///
/// The task is stored in the current runtime's task slab and its id is added to the staging
/// wake set. It is not polled here; presumably the runtime polls it on a later tick. Dropping
/// the returned handle detaches the task and discards its output.
///
/// # Panics
///
/// Panics if the next free task slot would exceed [`MAX_TASK_ID`]. The check runs before the
/// task is inserted, so a panic leaves the task slab unchanged.
pub fn spawn<T: 'static>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T> {
    let (tx, rx) = oneshot::channel();
    let handle = JoinHandle { rx: Some(rx) };

    let wrapper = async move {
        // A failed send presumably means the receiving JoinHandle was dropped. Tasks do not
        // have to be joined, so the output is simply discarded in that case.
        _ = tx.send(fut.await);
    };

    // SAFETY: the references do not escape this function. Between here and their last use
    // below we only insert into the slab and the wake set; nothing polls a task or re-enters
    // the runtime. NOTE(review): confirm this matches the contract documented on
    // `current_tasks` / `current_wake_sets`.
    let (tasks, wake_sets) = unsafe { (current_tasks(), current_wake_sets()) };

    // Validate the slot *before* inserting. Otherwise an overflow panic would leave a task in
    // the slab that is never staged and therefore never polled.
    let index = tasks.vacant_key();
    assert!(
        index <= MAX_TASK_ID as usize,
        "too many live tasks: slab index {index} exceeds MAX_TASK_ID"
    );
    // Checked conversion: avoids relying on MAX_TASK_ID < u32::MAX for soundness.
    let id = u32::try_from(index)
        .ok()
        .and_then(NonMaxU32::new)
        .expect("task index must fit in a NonMaxU32");

    let inserted = tasks.insert(Box::pin(wrapper));
    debug_assert_eq!(inserted, index, "Slab::insert must use the vacant key");

    wake_sets.staging.insert(TaskId::Task(id));

    handle
}

/// Yields control back to the runtime once, allowing other tasks and I/O completions to be
/// processed before the calling task continues.
///
/// The first poll wakes the current task immediately and returns `Pending`. Every later poll
/// returns `Ready(())`.
pub async fn yield_now() {
    let mut first_poll = true;
    poll_fn(move |cx| {
        if std::mem::take(&mut first_poll) {
            // Reschedule ourselves right away so the runtime polls this task again.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        Poll::Ready(())
    })
    .await
}

#[cfg(test)]
mod tests {
    use tempest_io::VirtualIo;

    use super::*;
    use crate::block_on;

    /// A spawned task's output is delivered through its handle.
    #[test]
    fn test_spawn_completes() {
        block_on(VirtualIo::default(), async {
            let answer = spawn(async { 42 }).await;
            assert_eq!(answer, Ok(42));
        });
    }

    /// Dropping a handle detaches the task instead of failing.
    #[test]
    fn test_spawn_cancelled() {
        block_on(VirtualIo::default(), async {
            // Detach immediately; the task is expected to keep running with its output
            // discarded.
            drop(spawn(async { 42 }));
        });
    }

    /// Several tasks can be outstanding at once, and each handle yields its own task's output.
    #[test]
    fn test_spawn_runs_concurrently() {
        block_on(VirtualIo::default(), async {
            let (first, second) = (spawn(async { 1 }), spawn(async { 2 }));
            assert_eq!(first.await, Ok(1));
            assert_eq!(second.await, Ok(2));
        });
    }

    /// Yielding once inside the runtime completes.
    #[test]
    fn test_yield_now() {
        block_on(VirtualIo::default(), async { yield_now().await });
    }
}