extern crate std;
use std::cell::Cell;
use std::cell::RefCell;
/// Workgroup barrier for code executed by `dispatch_workgroup_threads`.
///
/// When a coroutine yielder is registered on the current thread, the calling
/// coroutine is suspended and control returns to the scheduler loop in
/// `dispatch_workgroup_threads`, which resumes it on its next pass over the
/// workgroup's coroutines. When no yielder is registered (the thread-local
/// pointer is null) the call is a no-op.
pub fn barrier_wait() {
    COROUTINE_YIELDER.with(|cell| {
        let ptr = cell.get();
        if ptr.is_null() {
            return;
        }
        // SAFETY: a non-null pointer is only ever stored by a coroutine body in
        // `dispatch_workgroup_threads` and refers to that coroutine's own
        // yielder, which stays alive for as long as the coroutine body (and
        // therefore this call) is executing. The re-publication below keeps the
        // slot pointing at the *running* coroutine's yielder.
        unsafe { (*ptr).suspend(()) };
        // While this coroutine was suspended, sibling coroutines on the same
        // thread overwrote the slot with their own yielder (on start) or with
        // null (on completion). Restore ours so the next barrier in this
        // coroutine suspends the right coroutine instead of using a sibling's
        // yielder or silently skipping the barrier.
        cell.set(ptr);
    });
}
/// No-op: takes ownership of the shared `Barrier` handle and releases it
/// without storing it anywhere.
///
/// NOTE(review): `barrier_wait` in this file never consults a stored barrier,
/// so this looks like a leftover entry point kept for callers of an earlier
/// API — confirm before removing.
pub fn set_barrier(_barrier: std::sync::Arc<std::sync::Barrier>) {
    // Nothing is recorded; the handle is simply released.
    drop(_barrier);
}
/// No-op counterpart of `set_barrier`: there is no stored barrier to clear.
pub fn clear_barrier() {
    // Intentionally empty.
}
/// Invokes `f` once for every workgroup index in `0..num_workgroups`.
///
/// With the `cpu-parallel` feature enabled the indices are distributed over
/// rayon's parallel iterator; otherwise they are visited sequentially, in
/// ascending order, on the calling thread.
///
/// # Panics
///
/// Panics if `num_workgroups` does not fit in a `u32`. Workgroup indices are
/// passed to `f` as `u32`, and a plain `as u32` cast would silently wrap the
/// count and run the wrong number of workgroups.
pub fn dispatch_workgroups(num_workgroups: usize, f: impl Fn(u32) + Sync + Send) {
    assert!(
        num_workgroups <= u32::MAX as usize,
        "dispatch_workgroups: num_workgroups ({}) does not fit in u32",
        num_workgroups
    );
    // Lossless: the range was checked just above.
    let workgroup_count = num_workgroups as u32;

    #[cfg(feature = "cpu-parallel")]
    {
        use rayon::prelude::*;
        (0..workgroup_count).into_par_iter().for_each(f);
    }
    #[cfg(not(feature = "cpu-parallel"))]
    {
        for i in 0..workgroup_count {
            f(i);
        }
    }
}
/// Size passed to `DefaultStack::new` whenever `take_stacks` has to allocate a
/// fresh coroutine stack (64 * 1024, i.e. 64 KiB assuming the argument is a
/// byte count — confirm against the corosensei docs).
const COROUTINE_STACK_SIZE: usize = 64 * 1024;
thread_local! {
// Raw pointer to a coroutine yielder registered on this thread, or null when
// none is registered. Written by the coroutine body created in
// `dispatch_workgroup_threads` (set on entry, reset to null once the kernel
// closure returns) and read by `barrier_wait`.
static COROUTINE_YIELDER: Cell<*mut corosensei::Yielder<(), ()>> = const { Cell::new(std::ptr::null_mut()) };
// Per-thread cache of coroutine stacks: `take_stacks` removes stacks from the
// tail of this vector and `return_stacks` appends stacks back onto it.
static STACK_POOL: RefCell<Vec<corosensei::stack::DefaultStack>> = RefCell::new(Vec::new());
}
/// Hands out exactly `count` coroutine stacks.
///
/// As many stacks as possible are taken from the tail of this thread's
/// `STACK_POOL`; any shortfall is covered by allocating new stacks of
/// `COROUTINE_STACK_SIZE`.
///
/// # Panics
///
/// Panics with "failed to allocate coroutine stack" if `DefaultStack::new`
/// reports an error.
fn take_stacks(count: usize) -> Vec<corosensei::stack::DefaultStack> {
    STACK_POOL.with(|cell| {
        let mut pooled = cell.borrow_mut();
        // Everything from `keep` onwards is handed out; when the pool holds
        // fewer than `count` stacks this is 0 and the whole pool is taken.
        let keep = pooled.len().saturating_sub(count);
        let mut taken = pooled.split_off(keep);
        // `taken.len() <= count` here, so this only ever grows the vector.
        taken.resize_with(count, || {
            corosensei::stack::DefaultStack::new(COROUTINE_STACK_SIZE)
                .expect("failed to allocate coroutine stack")
        });
        taken
    })
}
/// Appends the given stacks to this thread's `STACK_POOL` so that a later
/// `take_stacks` call can reuse them instead of allocating.
fn return_stacks(stacks: impl IntoIterator<Item = corosensei::stack::DefaultStack>) {
    STACK_POOL.with(|cell| {
        let mut pooled = cell.borrow_mut();
        pooled.extend(stacks);
    });
}
/// Runs `f(tid)` for every `tid` in `0..num_threads` as cooperatively scheduled
/// coroutines on the calling thread.
///
/// Each logical thread gets its own coroutine on a stack obtained from
/// `take_stacks`. The coroutines are resumed round-robin in index order; a
/// coroutine that calls `barrier_wait` suspends and is resumed again on the
/// next pass. The function returns once a full pass completes without any
/// coroutine suspending. Stacks of finished coroutines are handed back through
/// `return_stacks`.
///
/// The thread index is passed to `f` as `tid as u32`.
///
/// # Panics
///
/// Panics if `take_stacks` fails to allocate a coroutine stack.
pub fn dispatch_workgroup_threads(num_threads: usize, f: impl Fn(u32) + Sync) {
    use corosensei::{Coroutine, CoroutineResult};

    type ThreadCoroutine = Coroutine<(), (), (), corosensei::stack::DefaultStack>;

    // SAFETY: the `'static` lifetime is a lie told only to satisfy the
    // coroutine closure's bounds. Every coroutine capturing `kernel` lives in
    // `slots`, a local declared after `f`, so all of them are dropped before
    // `f` is and the reference never outlives its referent.
    let kernel: &'static (dyn Fn(u32) + Sync) =
        unsafe { core::mem::transmute(&f as &(dyn Fn(u32) + Sync)) };

    let stacks = take_stacks(num_threads);
    let mut slots: Vec<Option<ThreadCoroutine>> = Vec::with_capacity(stacks.len());
    for (tid, stack) in stacks.into_iter().enumerate() {
        let coroutine = Coroutine::with_stack(stack, move |yielder, ()| {
            // Publish this coroutine's yielder so `barrier_wait` can suspend it.
            COROUTINE_YIELDER.with(|cell| {
                cell.set(yielder as *const _ as *mut _);
            });
            kernel(tid as u32);
            // The kernel has returned; unregister the yielder.
            COROUTINE_YIELDER.with(|cell| cell.set(std::ptr::null_mut()));
        });
        slots.push(Some(coroutine));
    }

    let mut recovered = Vec::with_capacity(num_threads);
    loop {
        // Set when at least one coroutine suspended during this pass, which
        // means another pass is required.
        let mut any_suspended = false;
        for slot in slots.iter_mut() {
            // Empty slots belong to coroutines that already finished.
            if let Some(coroutine) = slot.as_mut() {
                match coroutine.resume(()) {
                    CoroutineResult::Yield(()) => any_suspended = true,
                    CoroutineResult::Return(()) => {
                        // Finished: reclaim its stack and leave the slot empty.
                        if let Some(finished) = slot.take() {
                            recovered.push(finished.into_stack());
                        }
                    }
                }
            }
        }
        if !any_suspended {
            break;
        }
    }
    return_stacks(recovered);
}