// tempest-rt 0.0.1
//
// TempestDB Deterministic Async Runtime
// Documentation
use std::{
    any::TypeId,
    cell::RefCell,
    collections::BTreeSet,
    task::{Context, RawWaker, RawWakerVTable, Waker},
};

use nonmax::NonMaxU32;
use tempest_io::{Io, OpHandle};

use crate::task::Tasks;

/// Number of high-order bits of an [`OpHandle`] that hold the task ID.
pub(crate) const OP_HANDLE_TASK_BITS: usize = 20;
/// Number of low-order bits of an [`OpHandle`] that hold the operation number.
pub(crate) const OP_HANDLE_OP_NUMBER_BITS: usize = 44;

/// Raw task ID reserved for the main task: the all-ones value of the task-ID field.
pub(crate) const MAIN_TASK_ID: u32 = (1 << OP_HANDLE_TASK_BITS) - 1;
/// Largest raw ID below the one reserved for the main task.
pub(crate) const MAX_TASK_ID: u32 = MAIN_TASK_ID - 1;
/// Largest value that fits in the operation-number field; doubles as that field's bit mask.
pub(crate) const MAX_OP_NUMBER: u64 = (1 << OP_HANDLE_OP_NUMBER_BITS) - 1;

// Compile-time check: the two fields together must fill a 64-bit handle exactly.
const _: () = assert!(OP_HANDLE_TASK_BITS + OP_HANDLE_OP_NUMBER_BITS == 64);

/// Identifies the task a waker or I/O operation belongs to.
///
/// The derived ordering follows declaration order, so [`TaskId::Main`] sorts before any
/// [`TaskId::Task`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum TaskId {
    /// The main task; encoded as [`MAIN_TASK_ID`].
    Main,
    /// Any other task, carrying its raw ID.
    Task(NonMaxU32),
}

impl TaskId {
    /// Encode this ID as the raw integer stored in waker data pointers and op handles.
    ///
    /// [`TaskId::Main`] becomes [`MAIN_TASK_ID`]; [`TaskId::Task`] becomes its wrapped value.
    pub(crate) const fn to_bits(&self) -> u64 {
        match self {
            TaskId::Main => MAIN_TASK_ID as u64,
            TaskId::Task(val) => val.get() as u64,
        }
    }

    /// Decode a raw integer produced by [`TaskId::to_bits`].
    ///
    /// # Panics
    ///
    /// Panics if `bits > MAIN_TASK_ID`.
    pub(crate) fn from_bits(bits: u64) -> Self {
        // Validate at full width *before* narrowing. An `as u32` cast silently discards the
        // upper 32 bits, which would let e.g. `1 << 32` masquerade as task 0.
        assert!(
            bits <= MAIN_TASK_ID as u64,
            "bits exceed maximum valid task ID"
        );

        // Lossless: `bits` was just shown to fit below `MAIN_TASK_ID`, which is a `u32`.
        let raw = bits as u32;
        if raw == MAIN_TASK_ID {
            TaskId::Main
        } else {
            // SAFETY: `raw < MAIN_TASK_ID < u32::MAX`, so the `NonMaxU32` invariant holds.
            TaskId::Task(unsafe { NonMaxU32::new_unchecked(raw) })
        }
    }
}

/// Pack a task ID and an operation number into a single [`OpHandle`].
///
/// The task ID occupies the bits above [`OP_HANDLE_OP_NUMBER_BITS`]; the operation number
/// occupies the bits below.
///
/// # Panics
///
/// Panics if `op_number` is larger than [`MAX_OP_NUMBER`].
pub(crate) fn op_handle(task: TaskId, op_number: u64) -> OpHandle {
    assert!(op_number <= MAX_OP_NUMBER);
    let task_field = task.to_bits() << OP_HANDLE_OP_NUMBER_BITS;
    OpHandle(task_field | op_number)
}

pub(crate) fn parse_op_handle(handle: OpHandle) -> (TaskId, u64) {
    let handle = handle.0;
    let task_id_raw = handle >> OP_HANDLE_OP_NUMBER_BITS;
    let task_id = TaskId::from_bits(task_id_raw);
    let op_number = handle & MAX_OP_NUMBER;
    (task_id, op_number)
}

/// Wake hook used for both `wake` and `wake_by_ref` in the waker vtable.
///
/// Decodes the task ID carried in the waker's data pointer and inserts it into the staging
/// wake set of the current runtime.
///
/// # Safety
///
/// Calls [`current_wake_sets`], which dereferences the raw `wake_sets` pointer of the current
/// context; that pointer must be valid when this runs.
///
/// # Panics
///
/// Panics if `data` does not decode to a valid [`TaskId`] or if no runtime is active on the
/// current thread.
unsafe fn wake_impl(data: *const ()) {
    let woken = TaskId::from_bits(data as u64);
    // SAFETY: the reference is used for this one insert and never leaves the function.
    let wake_sets = unsafe { current_wake_sets() };
    wake_sets.staging.insert(woken);
}

/// Clone hook: the data pointer is just an encoded task ID, so cloning copies it verbatim.
unsafe fn clone_impl(data: *const ()) -> RawWaker {
    RawWaker::new(data, &STATIC_WAKER_VTABLE)
}

/// Drop hook: the data pointer owns nothing, so there is nothing to release.
unsafe fn drop_impl(_data: *const ()) {}

/// Vtable behind the wakers built from it: `clone`, `wake`, `wake_by_ref`, `drop`, in that
/// order. Waking by value and by reference behave the same.
static STATIC_WAKER_VTABLE: RawWakerVTable =
    RawWakerVTable::new(clone_impl, wake_impl, wake_impl, drop_impl);

/// Build the [`Waker`] for `task`.
///
/// No allocation is involved: the task ID is stored directly in the waker's data pointer.
pub(crate) const fn make_waker(task: TaskId) -> Waker {
    // The "pointer" is never dereferenced; it only transports the encoded task ID.
    let data = task.to_bits() as usize as *const ();
    let raw = RawWaker::new(data, &STATIC_WAKER_VTABLE);
    // SAFETY: every hook in `STATIC_WAKER_VTABLE` treats `data` as a plain integer (clone copies
    // it, drop is a no-op, wake decodes it as a task ID), so no resource is tied to it.
    unsafe { Waker::from_raw(raw) }
}

/// The two wake queues of a runtime.
///
/// The runtime drains [`active`] while wakers insert into [`staging`]. After every tick the
/// runtime swaps the two.
///
/// Both are [`BTreeSet`]s rather than [`HashSet`]s because poll order must stay deterministic.
///
/// [`active`]: Self::active
/// [`staging`]: Self::staging
/// [`HashSet`]: std::collections::HashSet
#[derive(Default)]
pub(crate) struct WakeSets {
    /// Tasks being drained by the runtime.
    pub(crate) active: BTreeSet<TaskId>,
    /// Tasks inserted by wakers, waiting for the next swap.
    pub(crate) staging: BTreeSet<TaskId>,
}

impl WakeSets {
    /// Exchange the contents of [`active`] and [`staging`].
    ///
    /// [`active`]: Self::active
    /// [`staging`]: Self::staging
    pub(crate) fn swap(&mut self) {
        let Self { active, staging } = self;
        std::mem::swap(active, staging);
    }
}

/// The context of the current runtime that is executing.
///
/// All pointers are raw. The accessors in this module dereference them without null or validity
/// checks, so they must stay valid while the context is installed in `CURRENT_CONTEXT`.
// TODO: use NonNull<T> pointers here, if we can prove that *mut T is never std::ptr::null()
pub(crate) struct RuntimeContext {
    /// [`TypeId`] of the concrete I/O type behind [`io`](Self::io); `current_io` checks it
    /// before casting the pointer back.
    pub(crate) type_id: TypeId,
    /// Type-erased pointer to the runtime's I/O implementation.
    pub(crate) io: *mut (),
    /// Pointer to the runtime's [`Tasks`].
    pub(crate) tasks: *mut Tasks,
    /// Pointer to the runtime's [`WakeSets`].
    pub(crate) wake_sets: *mut WakeSets,
    /// Pointer to the counter holding the next operation number `get_op_number` hands out.
    pub(crate) next_op: *mut u64,
}

thread_local! {
    /// Per-thread slot holding the [`RuntimeContext`] of the runtime currently executing on this
    /// thread.
    ///
    /// It is set during every execution tick and unset afterwards; it is `None` when the calling
    /// code is not running inside a runtime. The accessor functions in this module panic in
    /// that case.
    pub(crate) static CURRENT_CONTEXT: RefCell<Option<RuntimeContext>> = RefCell::new(None);
}

/// Borrow the [`Io`] implementation of the runtime executing on this thread.
///
/// `I` must be the concrete [`Io`] type the runtime was initialized with.
///
/// # Safety
///
/// The returned reference must not be stored or moved out of the calling scope. The `'static`
/// lifetime is only a placeholder, because this constraint cannot be expressed as a lifetime.
///
/// # Panics
///
/// Panics if called outside of an active runtime or if `I` does not match the runtime's I/O type.
pub(crate) unsafe fn current_io<I: Io>() -> &'static mut I {
    let io_ptr: *mut I = CURRENT_CONTEXT.with(|slot| {
        let guard = slot.borrow();
        let ctx = guard
            .as_ref()
            .expect("no active runtime on current thread");
        // Only undo the type erasure when `I` is the type that was registered.
        assert_eq!(ctx.type_id, TypeId::of::<I>());
        ctx.io.cast::<I>()
    });
    // SAFETY: the type check above matched `I`; pointer validity and the scope restriction are
    // the caller's obligations per the `# Safety` section.
    unsafe { &mut *io_ptr }
}

/// Borrow the [`Tasks`] of the runtime executing on this thread.
///
/// # Safety
///
/// The returned reference must not be stored or moved out of the calling scope. The `'static`
/// lifetime is only a placeholder, because this constraint cannot be expressed as a lifetime.
///
/// # Panics
///
/// Panics if called outside of an active runtime.
pub(crate) unsafe fn current_tasks() -> &'static mut Tasks {
    let tasks_ptr: *mut Tasks = CURRENT_CONTEXT.with(|slot| {
        slot.borrow()
            .as_ref()
            .expect("no active runtime on current thread")
            .tasks
    });
    // SAFETY: pointer validity and the scope restriction are the caller's obligations per the
    // `# Safety` section.
    unsafe { &mut *tasks_ptr }
}

/// Borrow the [`WakeSets`] of the runtime executing on this thread.
///
/// # Safety
///
/// The returned reference must not be stored or moved out of the calling scope. The `'static`
/// lifetime is only a placeholder, because this constraint cannot be expressed as a lifetime.
///
/// # Panics
///
/// Panics if called outside of an active runtime.
pub(crate) unsafe fn current_wake_sets() -> &'static mut WakeSets {
    let sets_ptr: *mut WakeSets = CURRENT_CONTEXT.with(|slot| {
        slot.borrow()
            .as_ref()
            .expect("no active runtime on current thread")
            .wake_sets
    });
    // SAFETY: pointer validity and the scope restriction are the caller's obligations per the
    // `# Safety` section.
    unsafe { &mut *sets_ptr }
}

/// Hand out the next I/O operation number from the current executor context.
///
/// Returns the current value of the context's counter and advances the counter by one.
///
/// # Panics
///
/// Panics if called outside of an active runtime, or if the counter has reached
/// [`MAX_OP_NUMBER`].
pub(crate) fn get_op_number() -> u64 {
    let counter: *mut u64 = CURRENT_CONTEXT.with(|slot| {
        slot.borrow()
            .as_ref()
            .expect("no active runtime on current thread")
            .next_op
    });

    // SAFETY: relies on `next_op` pointing at a live `u64` while the context is installed;
    // this function cannot verify that itself.
    let op_number = unsafe { counter.read() };
    assert!(op_number < MAX_OP_NUMBER);
    // SAFETY: same pointer as the read above.
    unsafe { counter.write(op_number + 1) };
    op_number
}

/// Decode the [`TaskId`] stored in the data pointer of the context's [`Waker`].
///
/// # Safety
///
/// Caller must ensure that the context comes from this runtime.
///
/// # Panics
///
/// Panics if the waker's data does not decode to a valid task ID (see [`TaskId::from_bits`]).
// TODO: we could get the task ID not from the waker but from the thread local, but that's one
// additional pointer dereference. However, using the thread local would be a bit safer I think.
// => is the safety worth it?
pub(crate) unsafe fn get_task_id(cx: &mut Context<'_>) -> TaskId {
    let waker_data = cx.waker().data();
    TaskId::from_bits(waker_data as usize as u64)
}

/// Create a fresh [`OpHandle`] that can be linked back to the task owning the context's
/// [`Waker`].
///
/// # Safety
///
/// Caller must ensure that the context comes from this runtime.
///
/// # Panics
///
/// Panics if called outside of an active runtime, or if the waker's data does not decode to a
/// valid task ID.
pub(crate) unsafe fn get_op_handle(cx: &mut Context<'_>) -> OpHandle {
    // The operation number is taken before the waker is decoded.
    let op_number = get_op_number();
    // SAFETY: the caller's guarantee about `cx` is forwarded unchanged.
    op_handle(unsafe { get_task_id(cx) }, op_number)
}