use std::{
any::TypeId,
cell::RefCell,
collections::BTreeSet,
task::{Context, RawWaker, RawWakerVTable, Waker},
};
use nonmax::NonMaxU32;
use tempest_io::{Io, OpHandle};
use crate::task::Tasks;
/// Width, in bits, of the task-ID field that occupies the high end of a packed
/// op-handle value.
pub(crate) const OP_HANDLE_TASK_BITS: usize = 20;
/// Width, in bits, of the op-number field that occupies the low end of a packed
/// op-handle value. Also the shift that moves the task-ID field into place.
pub(crate) const OP_HANDLE_OP_NUMBER_BITS: usize = 44;
/// Task-ID bit pattern reserved for `TaskId::Main`: every bit of the task-ID
/// field set.
pub(crate) const MAIN_TASK_ID: u32 = (1 << OP_HANDLE_TASK_BITS) - 1;
/// Largest task-ID bit pattern that decodes to `TaskId::Task`: one below the
/// value reserved for the main task.
pub(crate) const MAX_TASK_ID: u32 = MAIN_TASK_ID - 1;
/// Largest op number that fits in the op-number field. Doubles as the bit mask
/// used to extract that field from a packed value.
pub(crate) const MAX_OP_NUMBER: u64 = (1 << OP_HANDLE_OP_NUMBER_BITS) - 1;
// Compile-time check that the two fields together fill exactly 64 bits.
const _: () = assert!(OP_HANDLE_TASK_BITS + OP_HANDLE_OP_NUMBER_BITS == 64);
/// Identifier of a task, in the form that is packed into op handles and waker
/// data pointers.
///
/// The derived ordering follows declaration order, so `Main` compares less
/// than every `Task(_)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) enum TaskId {
/// The distinguished main task; `to_bits` encodes it as `MAIN_TASK_ID`.
Main,
/// A numbered task; `to_bits` encodes it as the wrapped integer itself.
Task(NonMaxU32),
}
impl TaskId {
    /// Encodes this ID as the integer stored in op handles and waker data
    /// pointers.
    ///
    /// `Main` maps to `MAIN_TASK_ID`; `Task(n)` maps to `n` unchanged. No range
    /// check is performed here, so a `Task` whose value exceeds `MAX_TASK_ID`
    /// yields bits that [`TaskId::from_bits`] will not decode back to it.
    pub(crate) const fn to_bits(&self) -> u64 {
        match self {
            TaskId::Main => MAIN_TASK_ID as u64,
            TaskId::Task(val) => val.get() as u64,
        }
    }

    /// Decodes the integer produced by [`TaskId::to_bits`].
    ///
    /// The comparison is done on the full 64-bit value *before* narrowing, so
    /// an input with stray high bits is rejected instead of being truncated
    /// into a value that happens to look like a valid task ID.
    ///
    /// # Panics
    /// Panics if `bits` is greater than `MAIN_TASK_ID`.
    pub(crate) fn from_bits(bits: u64) -> Self {
        const MAIN_BITS: u64 = MAIN_TASK_ID as u64;
        match bits {
            MAIN_BITS => TaskId::Main,
            task if task < MAIN_BITS => {
                // SAFETY: `task < MAIN_TASK_ID`, and `MAIN_TASK_ID` is a `u32`
                // strictly below `u32::MAX`, so the narrowing cast is lossless
                // and the result can never be the forbidden value `u32::MAX`.
                TaskId::Task(unsafe { NonMaxU32::new_unchecked(task as u32) })
            }
            _ => panic!("bits exceed maximum valid task ID"),
        }
    }
}
/// Packs a task ID and an op number into a single [`OpHandle`].
///
/// The task bits are shifted into the high field and the op number fills the
/// low `OP_HANDLE_OP_NUMBER_BITS` bits.
///
/// # Panics
/// Panics if `op_number` is greater than `MAX_OP_NUMBER`.
pub(crate) fn op_handle(task: TaskId, op_number: u64) -> OpHandle {
    assert!(op_number <= MAX_OP_NUMBER);
    let task_field = task.to_bits() << OP_HANDLE_OP_NUMBER_BITS;
    OpHandle(task_field | op_number)
}
/// Splits an [`OpHandle`] back into the task ID and op number it was packed
/// from.
///
/// # Panics
/// Panics if the high field does not decode to a valid task ID (see
/// `TaskId::from_bits`).
pub(crate) fn parse_op_handle(handle: OpHandle) -> (TaskId, u64) {
    let raw = handle.0;
    (
        // High field: everything above the op-number bits.
        TaskId::from_bits(raw >> OP_HANDLE_OP_NUMBER_BITS),
        // Low field: `MAX_OP_NUMBER` is also the mask for the op-number bits.
        raw & MAX_OP_NUMBER,
    )
}
/// Shared `wake` / `wake_by_ref` entry of the static waker vtable.
///
/// Interprets the waker's data pointer as task-ID bits and inserts the decoded
/// ID into the staging wake set of the runtime context installed on the
/// current thread.
///
/// # Panics
/// Panics if the bits are not a valid task ID, or if no runtime context is
/// installed on this thread.
///
/// # Safety
/// Inherits the contract of `current_wake_sets`.
unsafe fn wake_impl(data: *const ()) {
    // The pointer is never dereferenced; it only carries the integer ID.
    let task = TaskId::from_bits(data as u64);
    // SAFETY: forwarded to the caller.
    // NOTE(review): the exact requirements of `current_wake_sets` are not
    // documented here; confirm them at the site that installs the context.
    let wake_sets = unsafe { current_wake_sets() };
    wake_sets.staging.insert(task);
}
/// Vtable shared by every waker built by `make_waker`.
///
/// The data pointer is only an integer task ID, so cloning copies the pointer
/// value and dropping does nothing. Both `wake` and `wake_by_ref` go through
/// `wake_impl`.
static STATIC_WAKER_VTABLE: RawWakerVTable = {
    /// Clone: reuse the same data word with the same vtable.
    fn clone_raw(data: *const ()) -> RawWaker {
        RawWaker::new(data, &STATIC_WAKER_VTABLE)
    }

    /// Drop: nothing is owned through the data word, so nothing to release.
    fn drop_raw(_data: *const ()) {}

    RawWakerVTable::new(clone_raw, wake_impl, wake_impl, drop_raw)
};
/// Builds a [`Waker`] whose data pointer carries `task`'s bit encoding and
/// whose behavior comes from `STATIC_WAKER_VTABLE`.
pub(crate) const fn make_waker(task: TaskId) -> Waker {
    // The "pointer" is just the task ID smuggled through a pointer-sized word.
    let data = task.to_bits() as usize as *const ();
    let raw = RawWaker::new(data, &STATIC_WAKER_VTABLE);
    // SAFETY: none of the vtable entries dereference `data`. Clone copies the
    // word, wake decodes it as an integer, and drop is a no-op.
    unsafe { Waker::from_raw(raw) }
}
/// Pair of task-ID sets used to record wake-ups; `swap` exchanges the two.
#[derive(Default)]
pub(crate) struct WakeSets {
/// The set exchanged with `staging` by `swap`.
/// NOTE(review): presumably the set the runtime drains when polling; the
/// consumer is not visible in this file.
pub(crate) active: BTreeSet<TaskId>,
/// The set that `wake_impl` inserts newly woken task IDs into.
pub(crate) staging: BTreeSet<TaskId>,
}
impl WakeSets {
    /// Exchanges the contents of the `active` and `staging` sets in place.
    pub(crate) fn swap(&mut self) {
        let Self { active, staging } = self;
        std::mem::swap(active, staging);
    }
}
/// Type-erased raw pointers to the pieces of a runtime, stored in the
/// `CURRENT_CONTEXT` thread-local so the free functions in this file can reach
/// them.
///
/// NOTE(review): the accessors dereference these pointers without further
/// checks, so whoever installs the context must keep every pointee alive and
/// otherwise unaliased while it is installed; confirm at the install site.
pub(crate) struct RuntimeContext {
/// `TypeId` of the concrete I/O type behind `io`; `current_io` asserts it
/// matches the requested type before casting.
pub(crate) type_id: TypeId,
/// Type-erased pointer to the I/O driver; `current_io` casts it back to `*mut I`.
pub(crate) io: *mut (),
/// Pointer to the task collection returned by `current_tasks`.
pub(crate) tasks: *mut Tasks,
/// Pointer to the wake sets returned by `current_wake_sets`.
pub(crate) wake_sets: *mut WakeSets,
/// Pointer to the counter that `get_op_number` reads and increments.
pub(crate) next_op: *mut u64,
}
thread_local! {
    /// Runtime context installed on the current thread, or `None` when no
    /// runtime is active.
    ///
    /// The initial value is a constant expression, so it is declared with a
    /// `const` initializer; this lets the thread-local skip the lazy
    /// initialization check on every access.
    pub(crate) static CURRENT_CONTEXT: RefCell<Option<RuntimeContext>> =
        const { RefCell::new(None) };
}
/// Returns the I/O driver of the runtime installed on the current thread,
/// downcast to the concrete type `I`.
///
/// # Panics
/// Panics if no runtime context is installed on this thread, or if the
/// installed context's `type_id` is not that of `I`.
///
/// # Safety
/// The returned reference is produced from a raw pointer with an unbounded
/// lifetime.
/// NOTE(review): the caller presumably must ensure the pointee is still alive
/// and that no other reference to it is in use; confirm against the code that
/// installs the context.
pub(crate) unsafe fn current_io<I: Io>() -> &'static mut I {
    CURRENT_CONTEXT.with(|slot| {
        let guard = slot.borrow();
        let Some(runtime) = guard.as_ref() else {
            panic!("no active runtime on current thread")
        };
        // Refuse to reinterpret the erased pointer as the wrong type.
        assert_eq!(runtime.type_id, TypeId::of::<I>());
        let io = runtime.io.cast::<I>();
        // SAFETY: the type was checked above; validity and exclusivity of the
        // pointee are the caller's responsibility per this function's contract.
        unsafe { &mut *io }
    })
}
/// Returns the task collection of the runtime installed on the current thread.
///
/// # Panics
/// Panics if no runtime context is installed on this thread.
///
/// # Safety
/// The returned reference is produced from a raw pointer with an unbounded
/// lifetime.
/// NOTE(review): the caller presumably must ensure the pointee is still alive
/// and that no other reference to it is in use; confirm against the code that
/// installs the context.
pub(crate) unsafe fn current_tasks() -> &'static mut Tasks {
    CURRENT_CONTEXT.with(|slot| {
        let guard = slot.borrow();
        let Some(runtime) = guard.as_ref() else {
            panic!("no active runtime on current thread")
        };
        let tasks = runtime.tasks;
        // SAFETY: validity and exclusivity of the pointee are the caller's
        // responsibility per this function's contract.
        unsafe { &mut *tasks }
    })
}
/// Returns the wake sets of the runtime installed on the current thread.
///
/// # Panics
/// Panics if no runtime context is installed on this thread.
///
/// # Safety
/// The returned reference is produced from a raw pointer with an unbounded
/// lifetime.
/// NOTE(review): the caller presumably must ensure the pointee is still alive
/// and that no other reference to it is in use; confirm against the code that
/// installs the context.
pub(crate) unsafe fn current_wake_sets() -> &'static mut WakeSets {
    CURRENT_CONTEXT.with(|slot| {
        let guard = slot.borrow();
        let Some(runtime) = guard.as_ref() else {
            panic!("no active runtime on current thread")
        };
        let wake_sets = runtime.wake_sets;
        // SAFETY: validity and exclusivity of the pointee are the caller's
        // responsibility per this function's contract.
        unsafe { &mut *wake_sets }
    })
}
/// Hands out the next op number from the current runtime's counter and
/// advances the counter by one.
///
/// # Panics
/// Panics if no runtime context is installed on this thread, or if the counter
/// has already reached `MAX_OP_NUMBER`.
pub(crate) fn get_op_number() -> u64 {
    CURRENT_CONTEXT.with(|slot| {
        let guard = slot.borrow();
        let Some(runtime) = guard.as_ref() else {
            panic!("no active runtime on current thread")
        };
        let counter = runtime.next_op;
        // SAFETY: NOTE(review): relies on the installed context keeping
        // `next_op` pointing at a live `u64` with no other access in progress;
        // that guarantee is not visible in this file.
        let issued = unsafe { *counter };
        // Checked before the increment, so the stored value never exceeds
        // `MAX_OP_NUMBER`.
        assert!(issued < MAX_OP_NUMBER);
        // SAFETY: same pointer and same assumption as the read above.
        unsafe { *counter = issued + 1 };
        issued
    })
}
/// Recovers the task ID carried in the data pointer of `cx`'s waker.
///
/// # Panics
/// Panics if the data pointer's value is not a valid task-ID encoding.
///
/// # Safety
/// NOTE(review): nothing here checks where the waker came from. Presumably the
/// caller must pass a context whose waker was built by `make_waker`; otherwise
/// the data pointer is an unrelated address.
pub(crate) unsafe fn get_task_id(cx: &mut Context<'_>) -> TaskId {
    // The pointer is only read as an integer; it is never dereferenced.
    let bits = cx.waker().data() as usize as u64;
    TaskId::from_bits(bits)
}
/// Allocates a fresh op number and packs it with the task ID found in `cx`'s
/// waker into an [`OpHandle`].
///
/// # Panics
/// Panics under the same conditions as `get_op_number`, `get_task_id` and
/// `op_handle`.
///
/// # Safety
/// Same contract as `get_task_id`.
pub(crate) unsafe fn get_op_handle(cx: &mut Context<'_>) -> OpHandle {
    // The op number is drawn first, so the counter advances even if decoding
    // the task ID panics afterwards.
    let number = get_op_number();
    // SAFETY: forwarded to the caller via this function's own contract.
    op_handle(unsafe { get_task_id(cx) }, number)
}