use crate::arena::picked::{Arena, ArenaBox};
use crate::common::Function;
use crate::Worker;
use std::mem::{replace, size_of, transmute};
/// A unit of work: a function `f` together with its input (before it runs)
/// or its output (after it was run through `execute_stolen`).
///
/// The struct is `#[repr(C)]` with `execute_stolen` as its first field on
/// purpose: `TypeErased::execute_stolen` reads that field through a
/// `&mut Task<(), ()>` regardless of the real `I`/`O`. The default Rust
/// representation gives no guarantee that a field sits at the same offset in
/// different instantiations of a generic struct; with `repr(C)` and the
/// field placed first, its offset is 0 for every `I` and `O`.
#[repr(C)]
pub struct Task<I, O> {
    /// Monomorphized `Task::<I, O>::execute_stolen`, stored by `new`. It is
    /// the entry point used on an erased task and is also compared by
    /// `type_check` to detect an `unerase` to the wrong `I`/`O`.
    ///
    /// Must remain the first field (see the struct-level comment).
    execute_stolen: fn(&mut TypeErased, &mut Worker),
    /// The function to run; it is called with the executing worker and the input.
    f: Function<I, O>,
    /// Current payload state: pending input, produced output, or nothing.
    io: TaskIO<I, O>,
}
/// Payload state of a `Task`.
///
/// The task's methods move the state out with `replace(.., TaskIO::Empty)`,
/// so `Empty` is what remains after an input or output has been taken.
enum TaskIO<I, O> {
/// The input, held in an arena allocation; this is the state `Task::new` creates.
Input(ArenaBox<I>),
/// Written by `Task::execute_stolen`: the arena box the input was stolen
/// from, plus the boxed result. `Task::take_output` passes the arena box to
/// `Arena::dealloc` and returns the boxed value.
Output(ArenaBox<I>, Box<O>),
/// Nothing to take: the input or output was already moved out.
Empty,
}
impl<I, O> Task<I, O> {
    /// Creates a task that will call `f` with `input`.
    ///
    /// `input` is handed to `arena.alloc` immediately; the task stores the
    /// returned `ArenaBox` until the input is taken or stolen.
    #[inline(always)]
    pub fn new(arena: &mut Arena, f: Function<I, O>, input: I) -> Self {
        let io = TaskIO::Input(arena.alloc(input));
        Self {
            f,
            io,
            execute_stolen: Self::execute_stolen,
        }
    }

    /// Moves the pending input out of the task, leaving it `Empty`.
    ///
    /// # Panics
    ///
    /// Panics if the task is not currently holding an input.
    #[inline(always)]
    pub fn take_input(&mut self, arena: &mut Arena) -> I {
        let TaskIO::Input(boxed_input) = replace(&mut self.io, TaskIO::Empty) else {
            panic!("Task Input Taken Twice");
        };
        arena.take(boxed_input)
    }

    /// Moves the output written by `execute_stolen` out of the task, leaving
    /// it `Empty`, and gives the input's arena box back via `arena.dealloc`.
    ///
    /// # Panics
    ///
    /// Panics if the task is not currently holding an output.
    #[inline(always)]
    pub fn take_output(&mut self, arena: &mut Arena) -> O {
        let TaskIO::Output(spent_input, boxed_output) = replace(&mut self.io, TaskIO::Empty) else {
            panic!("Task Output Taken Twice or Not Written");
        };
        arena.dealloc(spent_input);
        *boxed_output
    }

    /// Consumes the task and runs it on `worker`, returning the result directly.
    ///
    /// # Panics
    ///
    /// Panics if the input has already been taken (see [`Task::take_input`]).
    #[inline(always)]
    pub fn execute(mut self, worker: &mut Worker) -> O {
        let input = self.take_input(&mut worker.arena);
        (self.f)(worker, input)
    }

    /// Entry point stored in the `execute_stolen` field by `new`.
    ///
    /// Reinterprets `task_erased` as a `Task<I, O>`, steals the input out of
    /// its arena box, runs `f` on `worker`, and parks the boxed result next
    /// to the arena box in `TaskIO::Output` for a later `take_output`.
    ///
    /// # Panics
    ///
    /// Panics if the type check fails or the input is no longer present.
    fn execute_stolen(task_erased: &mut TypeErased, worker: &mut Worker) {
        // SAFETY: relies on `task_erased` having been produced by erasing a
        // `Task<I, O>` with these same `I`/`O`; this function is only reached
        // through the pointer such a task stores. The `type_check` call below
        // is a sanity check on that assumption.
        let this = unsafe { transmute::<&mut TypeErased, &mut Self>(task_erased) };
        this.type_check::<I, O>();
        let TaskIO::Input(mut input_slot) = replace(&mut this.io, TaskIO::Empty) else {
            panic!("Task Input Taken Already While Stealing");
        };
        let input: I = input_slot.steal();
        let output = Box::new((this.f)(worker, input));
        // The arena box is kept so `take_output` can deallocate it later.
        this.io = TaskIO::Output(input_slot, output);
    }

    /// Converts the task into its non-generic, opaque form.
    pub fn erase(self) -> TypeErased {
        // SAFETY: a by-value transmute only compiles when both types have the
        // same size; `TypeErased` is plain storage that is turned back into a
        // task with `TypeErased::unerase`.
        unsafe { transmute::<Self, TypeErased>(self) }
    }

    /// Verifies that this task was created as a `Task<X, Y>` by comparing the
    /// stored `execute_stolen` pointer with that instantiation's function.
    ///
    /// # Panics
    ///
    /// Panics if the two function addresses differ.
    #[inline(always)]
    fn type_check<X, Y>(&self) {
        let expected: fn(&mut TypeErased, &mut Worker) = Task::<X, Y>::execute_stolen;
        if std::ptr::fn_addr_eq(self.execute_stolen, expected) {
            return;
        }
        panic!("Task unerase()'d to incorrect base type")
    }
}
/// Opaque, non-generic storage for a [`Task`] whose `I`/`O` have been erased.
///
/// Values are created with `Task::erase` and turned back into a task with
/// `unerase`; both are by-value transmutes, so this type must have exactly
/// the size of the task (`Task<u8, u8>` is used as the representative
/// instantiation for the byte count).
///
/// The bytes are also reinterpreted *in place* as `&mut Task<..>`, so the
/// storage has to be aligned like a task as well. A lone `[u8; N]` has
/// alignment 1, which would allow misaligned task references. Every `Task`
/// contains a function pointer, so the zero-sized `_align` field raises the
/// alignment to pointer alignment without changing the size or the auto
/// traits of this type.
#[repr(C)]
pub struct TypeErased {
    /// Zero-sized; exists only to force pointer alignment.
    _align: [usize; 0],
    /// Raw bytes of the erased task.
    _fill: [u8; size_of::<Task<u8, u8>>()],
}
impl TypeErased {
    /// Recovers the concrete `Task<I, O>` from its erased form.
    ///
    /// The reinterpreted value is kept inside `ManuallyDrop` until the type
    /// check has passed. If the check panics, the value was erased from a
    /// different `Task<_, _>`; running `Task<I, O>`'s destructor on it while
    /// unwinding would drop its payload with the wrong types, so the value is
    /// leaked instead.
    ///
    /// # Panics
    ///
    /// Panics if the erased task was not created as a `Task<I, O>`.
    #[inline(always)]
    pub fn unerase<I, O>(self) -> Task<I, O> {
        // SAFETY: a by-value transmute only compiles when the sizes match.
        // Whether the bytes really are a `Task<I, O>` is verified by
        // `type_check` below, before the value is handed out or dropped.
        let task = std::mem::ManuallyDrop::new(unsafe { transmute::<Self, Task<I, O>>(self) });
        task.type_check::<I, O>();
        std::mem::ManuallyDrop::into_inner(task)
    }

    /// Runs the erased task on `worker` through the function pointer stored
    /// in the task's `execute_stolen` field.
    #[inline(always)]
    pub fn execute_stolen(&mut self, worker: &mut Worker) {
        // The real `I`/`O` are unknown here, so the field is read through a
        // `Task<(), ()>` view. NOTE(review): this assumes `execute_stolen`
        // sits at the same offset in every `Task<I, O>` - confirm against the
        // layout of `Task`.
        let f = unsafe { transmute::<&mut Self, &mut Task<(), ()>>(self) }.execute_stolen;
        f(self, worker);
    }
}