use libc;
use num_cpus;
use std::cell::RefCell;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Mutex, RwLock};
use std::thread::{self, JoinHandle};
use std::time::Duration;
use crate::task::{Context, ISPCTaskFn};
/// Interface of a task execution backend driven through raw C-style handles.
///
/// The three methods take and return untyped `libc::c_void` pointers, and the
/// task entry point is an `ISPCTaskFn`; presumably they back ISPC's
/// `ISPCAlloc` / `ISPCLaunch` / `ISPCSync` runtime hooks (NOTE(review): confirm
/// against the code that installs a `TaskSystem`).
///
/// `Parallel` in this file is one implementation; the per-method notes below
/// describe the contract as that implementation uses it.
pub trait TaskSystem {
/// Allocates `size` bytes with alignment `align` on behalf of the task
/// context identified by `*handle_ptr` and returns the allocation.
///
/// In the `Parallel` implementation a null `*handle_ptr` causes a new context
/// to be created and its address to be written back through `handle_ptr`.
///
/// # Safety
/// `handle_ptr` is dereferenced, so it must be valid for reads and writes.
/// NOTE(review): any non-null `*handle_ptr` is presumably required to be a
/// handle previously produced by this same task system.
unsafe fn alloc(
&self,
handle_ptr: *mut *mut libc::c_void,
size: i64,
align: i32,
) -> *mut libc::c_void;
/// Schedules the task function `f` with argument block `data` over a
/// `count0` x `count1` x `count2` launch grid for the context in `*handle_ptr`.
///
/// # Safety
/// `handle_ptr` is dereferenced; the `Parallel` implementation additionally
/// treats `*handle_ptr` as a pointer to a live context, so it must not be null
/// there. NOTE(review): `data` presumably has to stay valid until the
/// matching `sync` returns.
unsafe fn launch(
&self,
handle_ptr: *mut *mut libc::c_void,
f: ISPCTaskFn,
data: *mut libc::c_void,
count0: i32,
count1: i32,
count2: i32,
);
/// Waits for the tasks launched on `handle` to finish.
///
/// The `Parallel` implementation also removes the context from its list
/// before returning, so the handle should not be reused afterwards.
///
/// # Safety
/// `handle` must be a context handle obtained through `alloc`/`launch` on
/// this task system (the `Parallel` implementation dereferences it).
unsafe fn sync(&self, handle: *mut libc::c_void);
}
// Per-thread index handed to task chunks as their thread id. It starts at 0 on
// every thread; `Parallel::worker_thread` overwrites it with the worker's own
// index, and `Parallel::sync` reads it back for whichever thread calls `sync`.
thread_local!(static THREAD_ID: RefCell<usize> = const { RefCell::new(0) });
/// A `TaskSystem` that runs task chunks on a pool of spawned worker threads,
/// with the thread calling `sync` helping to execute chunks as well.
pub struct Parallel {
// Contexts that are currently live: pushed by `alloc`, removed by `sync`,
// and scanned by `get_context` for one whose tasks are not yet done.
context_list: RwLock<Vec<Arc<Context>>>,
// Source of ids for newly created contexts (incremented once per context).
next_context_id: AtomicUsize,
// Join handles of the worker threads; `launch` uses them to unpark workers.
threads: Mutex<Vec<JoinHandle<()>>>,
// Argument passed to `chunks(..)` when splitting a task group for execution.
// The constructor sets it to 8.
chunk_size: usize,
}
impl Parallel {
    /// Creates a task system with an oversubscription factor of `1.0`, i.e.
    /// as many worker threads as `num_cpus::get()` reports.
    pub fn new() -> Arc<Parallel> {
        Self::oversubscribed(1.0)
    }

    /// Creates a task system running `oversubscribe * num_cpus::get()` worker
    /// threads (the product is truncated to an integer).
    ///
    /// Workers are numbered from 1; each one is told the total is the worker
    /// count plus one.
    ///
    /// # Panics
    /// Panics unless `oversubscribe >= 1.0`.
    pub fn oversubscribed(oversubscribe: f32) -> Arc<Parallel> {
        assert!(oversubscribe >= 1.0);
        let pool = Arc::new(Parallel {
            context_list: RwLock::new(Vec::new()),
            next_context_id: AtomicUsize::new(0),
            threads: Mutex::new(Vec::new()),
            chunk_size: 8,
        });
        {
            // Keep the handle list locked for the whole spawn phase.
            let mut handles = pool.threads.lock().unwrap();
            let num_threads = (oversubscribe * num_cpus::get() as f32) as usize;
            let chunk_size = pool.chunk_size;
            for worker_id in 1..=num_threads {
                let task_sys = Arc::clone(&pool);
                let handle = thread::spawn(move || {
                    Parallel::worker_thread(task_sys, worker_id, num_threads + 1, chunk_size)
                });
                handles.push(handle);
            }
        }
        pool
    }

    /// Returns the first context in the list whose current tasks are not all
    /// done, or `None` when every context reports completion.
    fn get_context(&self) -> Option<Arc<Context>> {
        let contexts = self.context_list.read().unwrap();
        contexts
            .iter()
            .find(|ctx| !ctx.current_tasks_done())
            .map(Arc::clone)
    }

    /// Body of a worker thread: records `thread` as this thread's id, then
    /// alternates between draining chunks from unfinished contexts and parking
    /// until it is woken again. Never returns.
    fn worker_thread(
        task_sys: Arc<Parallel>,
        thread: usize,
        total_threads: usize,
        chunk_size: usize,
    ) {
        THREAD_ID.with(|id| {
            id.replace(thread);
        });
        // Task chunks take the index and count as i32.
        let thread_index = thread as i32;
        let thread_count = total_threads as i32;
        loop {
            while let Some(ctx) = task_sys.get_context() {
                for group in ctx.iter() {
                    for chunk in group.chunks(chunk_size) {
                        chunk.execute(thread_index, thread_count);
                    }
                }
            }
            // No unfinished context is left: sleep until `launch` unparks us.
            thread::park();
        }
    }
}
impl TaskSystem for Parallel {
    /// Allocates `size` bytes aligned to `align` from the context behind
    /// `*handle_ptr`.
    ///
    /// If `*handle_ptr` is null, a new context is created, added to the live
    /// list, and its address is stored in `*handle_ptr` so later calls can find
    /// it again.
    ///
    /// # Safety
    /// `handle_ptr` must be valid for reads and writes, and a non-null
    /// `*handle_ptr` must point to a context created by this task system that
    /// has not been synced yet.
    ///
    /// # Panics
    /// Panics if a non-null handle does not match any live context.
    unsafe fn alloc(
        &self,
        handle_ptr: *mut *mut libc::c_void,
        size: i64,
        align: i32,
    ) -> *mut libc::c_void {
        // NOTE(review): `size` and `align` are converted with `as`, so negative
        // values wrap to huge unsigned ones - confirm callers never pass them.
        if (*handle_ptr).is_null() {
            // Hold the write lock across creation so the handle is published
            // and the context is listed in one step.
            let mut context_list = self.context_list.write().unwrap();
            let c = Arc::new(Context::new(
                self.next_context_id.fetch_add(1, atomic::Ordering::SeqCst),
            ));
            // The handle is the address of the `Context` inside the `Arc`; the
            // copy kept in `context_list` is what keeps that address alive.
            *handle_ptr = &*c as *const Context as *mut libc::c_void;
            let ptr: *mut libc::c_void = c.alloc(size as usize, align as usize);
            context_list.push(c);
            ptr
        } else {
            let context_list = self.context_list.read().unwrap();
            let handle_ctx = *handle_ptr as *mut Context;
            let ctx = context_list
                .iter()
                .find(|c| (*handle_ctx).id == c.id)
                .expect("task handle does not refer to a live context");
            ctx.alloc(size as usize, align as usize)
        }
    }

    /// Queues `f` with argument block `data` over a `count0` x `count1` x
    /// `count2` grid on the context behind `*handle_ptr`, then unparks every
    /// worker thread so they pick the new work up.
    ///
    /// # Safety
    /// `handle_ptr` must be valid for reads and `*handle_ptr` must point to a
    /// live context created by `alloc` on this task system.
    unsafe fn launch(
        &self,
        handle_ptr: *mut *mut libc::c_void,
        f: ISPCTaskFn,
        data: *mut libc::c_void,
        count0: i32,
        count1: i32,
        count2: i32,
    ) {
        // NOTE(review): this forms an exclusive reference to a `Context` that
        // worker threads may be reading through their `Arc` clones. If
        // `Context::launch` only needs `&self`, this should become `&Context`.
        let context: &mut Context = &mut *(*handle_ptr as *mut Context);
        context.launch((count0, count1, count2), data, f);
        let threads = self.threads.lock().unwrap();
        for t in threads.iter() {
            t.thread().unpark();
        }
    }

    /// Blocks until every task launched on `handle` has finished, executing
    /// pending chunks on the calling thread while it waits, then removes the
    /// context from the live list.
    ///
    /// # Safety
    /// `handle` must point to a live context created by `alloc` on this task
    /// system. The handle must not be used again after this call returns.
    ///
    /// # Panics
    /// Panics if the context is no longer in the live list.
    unsafe fn sync(&self, handle: *mut libc::c_void) {
        // Shared access only: workers reach the same `Context` through their
        // `Arc` clones at the same time, so an exclusive reference would alias.
        let context: &Context = &*(handle as *const Context);
        let thread = THREAD_ID.with(|f| *f.borrow());
        // Report the same thread count the workers report (worker count + 1).
        // Workers are numbered 1..=worker_count, so any smaller count could be
        // less than or equal to the index of a worker that calls `sync`.
        // The lock guard is a temporary and is released before any task runs.
        let total_threads = self.threads.lock().unwrap().len() + 1;

        // First help with the tasks of the context being synced.
        for tg in context.iter() {
            for chunk in tg.chunks(self.chunk_size) {
                chunk.execute(thread as i32, total_threads as i32);
            }
        }
        // Then, while other threads are still finishing our tasks, lend a hand
        // to whichever context still has work.
        while !context.current_tasks_done() {
            while let Some(c) = self.get_context() {
                let mut ran_some = false;
                for tg in c.iter() {
                    for chunk in tg.chunks(self.chunk_size) {
                        ran_some = true;
                        chunk.execute(thread as i32, total_threads as i32);
                    }
                }
                if !ran_some {
                    // Nothing left to grab; back off instead of spinning.
                    thread::sleep(Duration::from_millis(50));
                }
            }
        }

        let mut context_list = self.context_list.write().unwrap();
        let pos = context_list
            .iter()
            .position(|c| context.id == c.id)
            .expect("synced context is missing from the live context list");
        // Removing the entry may free the context; `context` is not used after this.
        context_list.remove(pos);
    }
}