mod job;
mod queues;
mod random;
mod thread;
use crate::job::JobHandlePrivate;
use crate::thread::JobThread;
use crate::thread::ThreadDataWrapper;
use std::marker::PhantomData;
use std::sync::Arc;
/// Typed handle to a job created through [`JobSystem`].
///
/// Wraps the untyped internal handle while remembering the closure type `T`
/// it was created with, so handles for different job types cannot be mixed up
/// at compile time.
pub struct JobHandle<T>
where
    T: Sized + FnMut() + Send,
{
    // Untyped handle into the job system's internal job storage.
    h: JobHandlePrivate,
    // Zero-sized marker tying this handle to the closure type `T`.
    p: PhantomData<T>,
}
type ThreadDataList = Vec<ThreadDataWrapper>;
/// A work-stealing-style job system: jobs are created, optionally parented or
/// chained, then run on a pool of worker threads plus the calling thread.
pub struct JobSystem {
    // Shared per-thread state; index 0 is the calling ("main") thread,
    // indices 1.. belong to the worker threads — TODO confirm against `thread` module.
    thread_data: std::sync::Arc<ThreadDataList>,
    // Worker threads; `threads[i]` uses `thread_data[i + 1]`.
    threads: Vec<JobThread>,
}
/// Errors returned by [`JobSystem`] operations.
#[derive(Debug)]
pub enum Error {
    /// An OS worker thread could not be started.
    ThreadCreate,
    /// `job_capacity` passed to [`JobSystem::new`] was not a power of two.
    CapacityNotPowerOf2,
    /// The job closure is larger than the fixed inline job storage
    /// (`job::JOB_STORAGE_SIZE`).
    StorageSizeExceeded,
    /// The job queue had no free slot when the job was submitted.
    QueueFull,
    /// The parent job already has the maximum number of chained jobs.
    ChainCountExceeded,
}
impl JobSystem {
pub fn new(thread_count: usize, job_capacity: usize) -> Result<Self, Error> {
if !job_capacity.is_power_of_two() {
return Err(Error::CapacityNotPowerOf2);
}
let actual_thread_count = thread_count.max(1) + 1;
let mut job_sys = Self {
thread_data: std::sync::Arc::new(ThreadDataList::new()),
threads: vec![],
};
let mut data_vec: ThreadDataList = Vec::with_capacity(actual_thread_count);
data_vec.resize_with(actual_thread_count, || ThreadDataWrapper::new(job_capacity));
job_sys.thread_data = std::sync::Arc::new(data_vec);
job_sys
.threads
.resize_with(actual_thread_count - 1, JobThread::new);
for index in 1..actual_thread_count {
let thread = &mut job_sys.threads[index - 1];
thread.set_data(job_sys.thread_data.clone(), index);
if thread.start().is_err() {
return Err(Error::ThreadCreate);
}
}
Ok(job_sys)
}
fn shutdown(&mut self) {
for thread in &mut self.threads {
thread.finish().unwrap()
}
self.threads.clear();
self.thread_data = Arc::new(vec![]);
}
pub fn create<T>(&mut self, job: T) -> Result<JobHandle<T>, Error>
where
T: Sized + FnMut() + Send + Sync,
{
debug_assert!(std::mem::size_of::<T>() <= job::JOB_STORAGE_SIZE);
if std::mem::size_of::<T>() > job::JOB_STORAGE_SIZE {
return Err(Error::StorageSizeExceeded);
}
let mut handle = thread::alloc_job(&self.thread_data);
handle.store(job);
Ok(JobHandle {
h: handle,
p: PhantomData,
})
}
pub fn create_with_parent<T, Y>(
&mut self,
parent: &mut JobHandle<Y>,
job: T,
) -> Result<JobHandle<T>, Error>
where
T: Sized + FnMut() + Send,
Y: Sized + FnMut() + Send,
{
debug_assert!(std::mem::size_of::<T>() <= job::JOB_STORAGE_SIZE);
if std::mem::size_of::<T>() > job::JOB_STORAGE_SIZE {
return Err(Error::StorageSizeExceeded);
}
let mut handle = thread::alloc_job(&self.thread_data);
handle.store(job);
handle.set_parent_job(parent.h);
Ok(JobHandle {
h: handle,
p: PhantomData,
})
}
pub fn chain<T, Y>(&self, parent: &mut JobHandle<T>, child: &JobHandle<Y>) -> Result<(), Error>
where
T: Sized + FnMut() + Send,
Y: Sized + FnMut() + Send,
{
match parent.h.chain_job(child.h) {
Err(_) => Err(Error::ChainCountExceeded),
_ => Ok(()),
}
}
pub fn run<T>(&mut self, handle: &JobHandle<T>) -> Result<(), Error>
where
T: Sized + FnMut() + Send,
{
match thread::start_job(&self.thread_data, handle.h) {
true => Ok(()),
false => Err(Error::QueueFull),
}
}
pub fn wait<T>(&mut self, handle: &JobHandle<T>)
where
T: Sized + FnMut() + Send,
{
while !handle.h.is_finished() {
if let Some(job) = thread::get_job(&self.thread_data) {
debug_assert!(job.is_valid());
thread::run_job(&self.thread_data, job);
}
}
}
pub fn is_finished<T>(&mut self, handle: &JobHandle<T>) -> bool
where
T: Sized + FnMut() + Send,
{
handle.h.is_finished()
}
pub fn for_each<'env, T, Y>(&mut self, cb: T, slice: &'env mut [Y]) -> Result<(), Error>
where
T: Fn(&mut [Y], usize, usize) + 'env + Send + Sync,
Y: Send,
{
let mut parent_job = self.create(move || {})?;
let divisor = self.threads.len();
let group_size = (slice.len() / divisor).max(divisor);
let mut offset = 0_usize;
let mut remaining = slice.len();
while remaining != 0 {
let range = group_size.min(remaining);
let slice_ptr = unsafe { slice.as_mut_ptr().add(offset) };
let work_slice = unsafe { std::slice::from_raw_parts_mut(slice_ptr, range) };
let callback = &cb;
let child_job = self.create_with_parent(&mut parent_job, move || {
callback(work_slice, offset, offset + work_slice.len());
})?;
self.run(&child_job)?;
remaining -= range;
offset += range;
}
self.run(&parent_job)?;
self.wait(&parent_job);
Ok(())
}
pub fn for_each_with_result<'env, T, Y, Z>(
&mut self,
cb: T,
slice: &'env mut [Y],
) -> Result<Vec<Z>, Error>
where
T: Fn(&mut [Y], usize, usize) -> Z + 'env + Send + Sync,
Z: Sized + Default + Send,
Y: Send,
{
let mut parent_job = self.create(|| {})?;
let divisor = self.threads.len();
let group_size = (slice.len() / divisor).max(divisor);
let group_count = slice.len() % divisor;
let vec_size = if group_count == 0 {
divisor
} else {
divisor + 1
};
let mut offset = 0_usize;
let mut remaining = slice.len();
let mut group_index = 0_usize;
let mut result_vec = Vec::<Z>::with_capacity(vec_size);
for _ in 0..vec_size {
result_vec.push(Z::default());
}
while remaining != 0 {
let range = group_size.min(remaining);
let slice_ptr = unsafe { slice.as_mut_ptr().add(offset) };
let result_ref = unsafe { &mut *result_vec.as_mut_ptr().add(group_index) };
let work_slice = unsafe { std::slice::from_raw_parts_mut(slice_ptr, range) };
let callback = &cb;
let child_job = self.create_with_parent(&mut parent_job, move || {
*result_ref = callback(work_slice, offset, offset + work_slice.len());
})?;
self.run(&child_job)?;
remaining -= range;
offset += range;
group_index += 1;
}
self.run(&parent_job)?;
self.wait(&parent_job);
Ok(result_vec)
}
}
impl Drop for JobSystem {
    /// Ensures worker threads are joined when the system goes out of scope.
    fn drop(&mut self) {
        self.shutdown();
    }
}
#[cfg(test)]
mod tests {
    use crate::JobSystem;
    use std::cell::RefCell;

    const THREAD_COUNT: usize = 4;
    const JOB_CAPACITY: usize = 1024;

    #[test]
    fn start_stop() {
        // Construction alone must spin up and (on drop) join all workers.
        assert!(JobSystem::new(THREAD_COUNT, JOB_CAPACITY).is_ok());
    }

    #[test]
    fn launch_jobs() {
        // Stress test: many batches of empty jobs through a small (128) queue.
        let mut job_sys = JobSystem::new(THREAD_COUNT, 128).expect("Failed to init job system");
        for _ in 0..8192_u32 {
            const JOB_COUNT: usize = 100;
            let mut jobs = Vec::with_capacity(JOB_COUNT);
            for _ in 0..JOB_COUNT {
                let handle = job_sys.create(|| {}).unwrap();
                assert!(job_sys.run(&handle).is_ok());
                jobs.push(handle);
            }
            for job in jobs {
                job_sys.wait(&job);
            }
        }
    }

    #[test]
    fn launch_jobs_with_ref() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let mut job_sys =
            JobSystem::new(THREAD_COUNT, JOB_CAPACITY).expect("Failed to init job system");
        const JOB_COUNT: usize = 100;
        let mut jobs = Vec::with_capacity(JOB_COUNT);
        let val = std::sync::Arc::new(AtomicU32::new(0));
        for _ in 0..JOB_COUNT {
            let val_copy = val.clone();
            let handle = job_sys
                .create(move || {
                    val_copy.fetch_add(10, Ordering::Release);
                })
                .unwrap();
            assert!(job_sys.run(&handle).is_ok());
            jobs.push(handle);
        }
        for job in jobs {
            job_sys.wait(&job);
        }
        // Every job must have run exactly once.
        assert_eq!(val.load(Ordering::Acquire), 10 * JOB_COUNT as u32);
        job_sys.shutdown();
    }

    #[test]
    fn launch_jobs_chained() {
        let mut job_sys =
            JobSystem::new(THREAD_COUNT, JOB_CAPACITY).expect("Failed to init job system");
        const JOB_COUNT: usize = 20;
        // RefCell lets us mutably borrow the previous handle while holding
        // a shared borrow of the current one.
        let mut jobs = Vec::with_capacity(JOB_COUNT);
        for i in 0..JOB_COUNT {
            let handle = job_sys
                .create(move || {
                    println!("Chained {:?}: Job {:02}", std::thread::current().id(), i);
                })
                .unwrap();
            jobs.push(RefCell::new(handle));
            if i > 0 {
                let cur_handle = &jobs[i];
                let prev_handle = &jobs[i - 1];
                job_sys
                    .chain(&mut prev_handle.borrow_mut(), &cur_handle.borrow())
                    .expect("Failed to chain");
            }
        }
        // Only the head is submitted; the chain must carry execution to the tail.
        assert!(job_sys.run(&jobs.first().unwrap().borrow_mut()).is_ok());
        job_sys.wait(&jobs.last().unwrap().borrow_mut());
        job_sys.shutdown();
    }

    #[test]
    fn parallel_for() {
        let mut job_sys =
            JobSystem::new(THREAD_COUNT, JOB_CAPACITY).expect("Failed to init job system");
        let mut array = [0_u32; 100];
        // Each chunk writes its element's global index into the element.
        let r = job_sys.for_each(
            |slice: &mut [u32], start, _end| {
                for (i, v) in slice.iter_mut().enumerate() {
                    *v = (start + i) as u32;
                }
            },
            &mut array,
        );
        assert!(r.is_ok());
        for (i, v) in array.iter().enumerate() {
            assert_eq!(*v as usize, i);
        }
    }

    #[test]
    fn launch_with_parent() {
        let mut job_sys =
            JobSystem::new(THREAD_COUNT, JOB_CAPACITY).expect("Failed to init job system");
        const JOB_COUNT: usize = 20;
        let mut parent = job_sys.create(|| {}).unwrap();
        let mut jobs = Vec::with_capacity(JOB_COUNT);
        // JOB_COUNT - 1 children, all parented to the same job.
        for _ in 1..JOB_COUNT {
            let handle = job_sys.create_with_parent(&mut parent, move || {}).unwrap();
            jobs.push(handle);
        }
        assert!(job_sys.run(&parent).is_ok());
        for job in &jobs {
            assert!(job_sys.run(job).is_ok());
        }
        // Waiting on the parent implies all children have completed.
        job_sys.wait(&parent);
        job_sys.shutdown();
    }

    #[test]
    fn parallel_for_with_result() {
        let mut job_sys =
            JobSystem::new(THREAD_COUNT, JOB_CAPACITY).expect("Failed to init job system");
        let mut array = [0_u32; 100];
        // Each chunk reports how many elements it processed; the chunk counts
        // must sum to the slice length.
        let r = job_sys.for_each_with_result(
            |slice: &mut [u32], start, end| -> u32 {
                for (i, v) in slice.iter_mut().enumerate() {
                    *v = (start + i) as u32;
                }
                (end - start) as u32
            },
            &mut array,
        );
        assert!(r.is_ok());
        let total: u32 = r.unwrap().iter().sum();
        assert_eq!(total, 100_u32);
    }
}