use std::borrow::Borrow;
use rayon::{join, ThreadPool};
use crate::{
dispatch::util::check_intersection,
system::{RunNow, System},
world::{ResourceId, World},
};
/// Placeholder that terminates a `Par` / `Seq` chain.
///
/// `Par::new` and `Seq::new` store it as the initial `tail`, and its `System`
/// implementation has an empty `run`, so executing it does nothing.
pub struct Nil;
/// Builds a `Par` from a comma-separated list of expressions.
///
/// The first expression is handed to `$crate::Par::new`; every following
/// expression is appended with `.with(..)` in the order written.
///
/// A trailing comma is accepted but not required, so `par![a, b]`,
/// `par![a, b,]`, `par![a]` and `par![a,]` are all valid invocations.
#[macro_export]
macro_rules! par {
    ($head:expr $(, $tail:expr)* $(,)?) => {
        {
            $crate::Par::new($head)
                $( .with($tail) )*
        }
    };
}
/// Builds a `Seq` from a comma-separated list of expressions.
///
/// The first expression is handed to `$crate::Seq::new`; every following
/// expression is appended with `.with(..)` in the order written.
///
/// A trailing comma is accepted but not required, so `seq![a, b]`,
/// `seq![a, b,]`, `seq![a]` and `seq![a,]` are all valid invocations.
#[macro_export]
macro_rules! seq {
    ($head:expr $(, $tail:expr)* $(,)?) => {
        {
            $crate::Seq::new($head)
                $( .with($tail) )*
        }
    };
}
/// `Nil` as a system: it requests no system data and its `run` body is empty.
impl<'a> System<'a> for Nil {
    type SystemData = ();

    /// Intentionally does nothing.
    fn run(&mut self, _data: ()) {}
}
/// A pair of runnables; its `RunWithPool::run` implementation hands `head`
/// and `tail` to a rayon `join`.
pub struct Par<H, T> {
    /// First member of the pair.
    head: H,
    /// Second member of the pair; `Nil` on every value returned by
    /// `Par::new` and `Par::with`.
    tail: T,
}
/// Builder methods for a `Par` chain whose outer `tail` is still `Nil`.
impl<H> Par<H, Nil> {
    /// Starts a chain containing only `head`.
    pub fn new(head: H) -> Self {
        Self { head, tail: Nil }
    }

    /// Adds `sys` to the chain, pairing it with everything added so far.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if the resources written by
    /// either side overlap with the resources read or written by the other.
    pub fn with<T>(self, sys: T) -> Par<Par<H, T>, Nil>
    where
        H: for<'a> RunWithPool<'a>,
        T: for<'a> RunWithPool<'a>,
    {
        if cfg!(debug_assertions) {
            // Resource accesses of everything already in the chain.
            let (mut own_reads, mut own_writes) = (Vec::new(), Vec::new());
            self.head.reads(&mut own_reads);
            self.head.writes(&mut own_writes);

            // Resource accesses of the newcomer.
            let (mut new_reads, mut new_writes) = (Vec::new(), Vec::new());
            sys.reads(&mut new_reads);
            sys.writes(&mut new_writes);

            // Write/read, write/write and read/write overlaps are all conflicts;
            // only read/read sharing is allowed.
            let conflicting = check_intersection(own_writes.iter(), new_reads.iter())
                || check_intersection(own_writes.iter(), new_writes.iter())
                || check_intersection(own_reads.iter(), new_writes.iter());
            debug_assert!(
                !conflicting,
                "Tried to add system with conflicting reads / writes"
            );
        }

        let pair = Par {
            head: self.head,
            tail: sys,
        };
        Par {
            head: pair,
            tail: Nil,
        }
    }
}
/// Bundles a runnable structure with the thread pool it is run on.
pub struct ParSeq<P, T> {
    /// The structure to run; the impls require `T: RunWithPool`.
    run: T,
    /// Pool handle; the impls require `P: Borrow<ThreadPool>`.
    pool: P,
}
impl<P, T> ParSeq<P, T>
where
P: Borrow<ThreadPool>,
T: for<'a> RunWithPool<'a>,
{
pub fn new(run: T, pool: P) -> Self {
ParSeq { run, pool }
}
pub fn setup(&mut self, world: &mut World) {
self.run.setup(world);
}
pub fn dispatch(&mut self, world: &World) {
self.run.run(world, self.pool.borrow());
}
}
/// Lets a `ParSeq` be used wherever a `RunNow` is expected.
impl<'a, P, T> RunNow<'a> for ParSeq<P, T>
where
    P: Borrow<ThreadPool>,
    T: for<'b> RunWithPool<'b>,
{
    /// Runs the wrapped structure against `world` using the stored pool.
    fn run_now(&mut self, world: &World) {
        let pool: &ThreadPool = self.pool.borrow();
        self.run.run(world, pool);
    }

    /// Forwards setup to the wrapped structure.
    fn setup(&mut self, world: &mut World) {
        self.run.setup(world);
    }
}
/// Something that can be run against a `World` while having access to a
/// rayon `ThreadPool`.
pub trait RunWithPool<'a> {
    /// Gives `self` mutable access to `world` for setup.
    fn setup(&mut self, world: &mut World);
    /// Runs `self` against `world`; `pool` is available for parallel work.
    fn run(&mut self, world: &'a World, pool: &ThreadPool);
    /// Appends the resource ids read by `self` to `reads`.
    fn reads(&self, reads: &mut Vec<ResourceId>);
    /// Appends the resource ids written by `self` to `writes`.
    fn writes(&self, writes: &mut Vec<ResourceId>);
}
impl<'a, T> RunWithPool<'a> for T
where
T: System<'a>,
{
fn setup(&mut self, world: &mut World) {
T::setup(self, world);
}
fn run(&mut self, world: &'a World, _: &ThreadPool) {
RunNow::run_now(self, world);
}
fn reads(&self, reads: &mut Vec<ResourceId>) {
use crate::system::Accessor;
reads.extend(self.accessor().reads())
}
fn writes(&self, writes: &mut Vec<ResourceId>) {
use crate::system::Accessor;
writes.extend(self.accessor().writes())
}
}
impl<'a, H, T> RunWithPool<'a> for Par<H, T>
where
H: RunWithPool<'a> + Send,
T: RunWithPool<'a> + Send,
{
fn setup(&mut self, world: &mut World) {
self.head.setup(world);
self.tail.setup(world);
}
fn run(&mut self, world: &'a World, pool: &ThreadPool) {
let head = &mut self.head;
let tail = &mut self.tail;
let head = move || head.run(world, pool);
let tail = move || tail.run(world, pool);
if pool.current_thread_index().is_none() {
pool.join(head, tail);
} else {
join(head, tail);
}
}
fn reads(&self, reads: &mut Vec<ResourceId>) {
self.head.reads(reads);
self.tail.reads(reads);
}
fn writes(&self, writes: &mut Vec<ResourceId>) {
self.head.writes(writes);
self.tail.writes(writes);
}
}
/// A pair of runnables; its `RunWithPool::run` implementation runs `head`
/// first and `tail` second.
pub struct Seq<H, T> {
    /// Member that runs first.
    head: H,
    /// Member that runs second; `Nil` on every value returned by `Seq::new`
    /// and `Seq::with`.
    tail: T,
}
/// Builder methods for a `Seq` chain whose outer `tail` is still `Nil`.
impl<H> Seq<H, Nil> {
    /// Starts a chain containing only `head`.
    pub fn new(head: H) -> Self {
        Self { head, tail: Nil }
    }

    /// Appends `sys` so that it runs after everything added so far.
    pub fn with<T>(self, sys: T) -> Seq<Seq<H, T>, Nil> {
        let pair = Seq {
            head: self.head,
            tail: sys,
        };
        Seq {
            head: pair,
            tail: Nil,
        }
    }
}
/// Running a `Seq` runs `head` to completion and then `tail`.
impl<'a, H, T> RunWithPool<'a> for Seq<H, T>
where
    H: RunWithPool<'a>,
    T: RunWithPool<'a>,
{
    /// Sets up `head`, then `tail`.
    fn setup(&mut self, world: &mut World) {
        H::setup(&mut self.head, world);
        T::setup(&mut self.tail, world);
    }

    /// Runs `head`, then `tail`, passing the same world and pool to both.
    fn run(&mut self, world: &'a World, pool: &ThreadPool) {
        H::run(&mut self.head, world, pool);
        T::run(&mut self.tail, world, pool);
    }

    /// Collects the reads of `head`, then those of `tail`.
    fn reads(&self, reads: &mut Vec<ResourceId>) {
        H::reads(&self.head, reads);
        T::reads(&self.tail, reads);
    }

    /// Collects the writes of `head`, then those of `tail`.
    fn writes(&self, writes: &mut Vec<ResourceId>) {
        H::writes(&self.head, writes);
        T::writes(&self.tail, writes);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// System that increments a shared counter once per run.
    struct Bump(Arc<AtomicUsize>);

    impl<'a> System<'a> for Bump {
        type SystemData = ();

        fn run(&mut self, _: Self::SystemData) {
            self.0.fetch_add(1, Ordering::AcqRel);
        }
    }

    /// Builds a rayon thread pool with the builder's default settings.
    fn new_tp() -> ThreadPool {
        use rayon::ThreadPoolBuilder;

        let builder = ThreadPoolBuilder::new();
        builder.build().unwrap()
    }

    /// Nested `join` calls inside `pool.join` must complete.
    #[test]
    fn nested_joins() {
        let pool = new_tp();

        pool.join(
            || join(|| join(|| join(|| (), || ()), || ()), || ()),
            || (),
        );
    }

    /// Every system in a `Par` chain runs exactly once per `run`.
    #[test]
    fn build_par() {
        let pool = new_tp();
        let counter = Arc::new(AtomicUsize::new(0));
        let system = || Bump(Arc::clone(&counter));

        Par::new(system())
            .with(system())
            .with(system())
            .run(&World::empty(), &pool);
        assert_eq!(counter.load(Ordering::Acquire), 3);

        par![system(), system(),].run(&World::empty(), &pool);
        assert_eq!(counter.load(Ordering::Acquire), 5);
    }

    /// Every system in a `Seq` chain runs exactly once per `run`.
    #[test]
    fn build_seq() {
        let pool = new_tp();
        let counter = Arc::new(AtomicUsize::new(0));
        let system = || Bump(Arc::clone(&counter));

        Seq::new(system())
            .with(system())
            .with(system())
            .run(&World::empty(), &pool);
        assert_eq!(counter.load(Ordering::Acquire), 3);
    }
}