use std::{collections::hash_map::Entry, fmt};
use ahash::AHashMap as HashMap;
#[cfg(feature = "parallel")]
use crate::dispatch::dispatcher::ThreadPoolWrapper;
use crate::{
dispatch::{
batch::BatchControllerSystem,
dispatcher::{SystemId, ThreadLocal},
stage::StagesBuilder,
BatchAccessor, BatchController, Dispatcher,
},
system::{RunNow, System, SystemData},
};
/// Builder for a [`Dispatcher`].
///
/// Systems are registered with `add`/`with` (optionally under a name that
/// later systems can list as a dependency), thread-local systems with
/// `add_thread_local`, and the finished dispatcher is produced by `build`.
#[derive(Default)]
pub struct DispatcherBuilder<'a, 'b> {
/// Next raw value handed out as a `SystemId`; incremented once per `add` call.
current_id: usize,
/// Name -> id for systems registered with a non-empty name; used by `add`
/// to resolve dependency names.
map: HashMap<String, SystemId>,
/// Collects the staged systems and barriers; `build` calls its `build()`.
pub(crate) stages_builder: StagesBuilder<'a>,
/// Boxed systems added through `add_thread_local`, kept apart from the
/// staged systems.
thread_local: ThreadLocal<'b>,
/// Shared, lock-protected thread-pool slot. `add_batch` clones this `Arc`
/// into nested builders so they share the same slot; `build` fills it with
/// a default pool when it is still empty.
/// NOTE(review): `add_pool` stores `Some(pool)` here, so `ThreadPoolWrapper`
/// looks like an `Option<Arc<rayon::ThreadPool>>` — confirm.
#[cfg(feature = "parallel")]
thread_pool: ::std::sync::Arc<::std::sync::RwLock<ThreadPoolWrapper>>,
}
impl<'a, 'b> DispatcherBuilder<'a, 'b> {
    /// Creates an empty builder (delegates to the derived `Default`).
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` if no *named* system has been registered.
    ///
    /// Only systems added with a non-empty name are recorded in the name map,
    /// so a builder holding only unnamed or thread-local systems still
    /// reports `true` here.
    /// NOTE(review): confirm that counting named systems only is intended.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Returns the number of *named* systems registered with this builder.
    ///
    /// Unnamed systems (added with an empty name) and thread-local systems
    /// are not included in this count.
    pub fn num_systems(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if a system has been registered under the name `system`.
    pub fn has_system(&self, system: &str) -> bool {
        self.map.contains_key(system)
    }

    /// Adds a system and returns the builder, for chained construction.
    ///
    /// See [`add`](Self::add) for the meaning of `name` and `dep`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`add`](Self::add).
    #[must_use]
    pub fn with<T>(mut self, system: T, name: &str, dep: &[&str]) -> Self
    where
        T: for<'c> System<'c> + Send + 'a,
    {
        self.add(system, name, dep);
        self
    }

    /// Adds a system to the builder.
    ///
    /// * `name` – identifier that later systems may list in their `dep`.
    ///   An empty name adds the system without making it referenceable.
    /// * `dep` – names of previously added systems this one depends on.
    ///
    /// # Panics
    ///
    /// * If an entry of `dep` does not name an already registered system.
    /// * If `name` is non-empty and is already in use.
    pub fn add<T>(&mut self, system: T, name: &str, dep: &[&str])
    where
        T: for<'c> System<'c> + Send + 'a,
    {
        let id = self.next_id();

        // Resolve dependency names to ids. The new system's own name is only
        // registered below, so a system can never depend on itself.
        let dependencies = dep
            .iter()
            .map(|&dep_name| {
                *self
                    .map
                    .get(dep_name)
                    .unwrap_or_else(|| panic!("No such system registered (\"{}\")", dep_name))
            })
            .collect();

        if !name.is_empty() {
            match self.map.entry(name.to_owned()) {
                Entry::Vacant(e) => {
                    e.insert(id);
                }
                Entry::Occupied(_) => panic!(
                    "Cannot insert multiple systems with the same name (\"{}\")",
                    name
                ),
            }
        }

        self.stages_builder.insert(dependencies, id, system);
    }

    /// Returns `true` if a system has been registered under `name`.
    ///
    /// Equivalent to [`has_system`](Self::has_system).
    pub fn contains(&self, name: &str) -> bool {
        self.map.contains_key(name)
    }

    /// Adds a batch and returns the builder, for chained construction.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`add_batch`](Self::add_batch).
    #[must_use]
    pub fn with_batch<T>(
        mut self,
        controller: T,
        dispatcher_builder: DispatcherBuilder<'a, 'b>,
        name: &str,
        dep: &[&str],
    ) -> Self
    where
        T: for<'c> BatchController<'a, 'b, 'c> + Send + 'a,
        'b: 'a,
    {
        self.add_batch::<T>(controller, dispatcher_builder, name, dep);
        self
    }

    /// Adds a batch to this builder as a single system.
    ///
    /// `dispatcher_builder` is built into a nested [`Dispatcher`]. That
    /// dispatcher, `controller`, and an accessor describing their combined
    /// resource usage are wrapped in a `BatchControllerSystem`, which is then
    /// registered via [`add`](Self::add) under `name` / `dep`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`add`](Self::add) for `name` and `dep`.
    pub fn add_batch<T>(
        &mut self,
        controller: T,
        mut dispatcher_builder: DispatcherBuilder<'a, 'b>,
        name: &str,
        dep: &[&str],
    ) where
        T: for<'c> BatchController<'a, 'b, 'c> + Send + 'a,
        'b: 'a,
    {
        // Share this builder's pool slot with the nested builder so both end
        // up using the same thread pool.
        // NOTE(review): this replaces any pool that was previously given to
        // `dispatcher_builder` through `add_pool`.
        #[cfg(feature = "parallel")]
        {
            dispatcher_builder.thread_pool = self.thread_pool.clone();
        }

        // The batch as a whole touches everything its inner systems touch,
        // plus the controller's own `BatchSystemData`.
        let mut reads = dispatcher_builder.stages_builder.fetch_all_reads();
        reads.extend(<T::BatchSystemData as SystemData>::reads());
        reads.sort();
        reads.dedup();

        let mut writes = dispatcher_builder.stages_builder.fetch_all_writes();
        writes.extend(<T::BatchSystemData as SystemData>::writes());
        writes.sort();
        writes.dedup();

        let accessor = BatchAccessor::new(reads, writes);
        let dispatcher: Dispatcher<'a, 'b> = dispatcher_builder.build();

        // SAFETY: `accessor` is built from the union of every read and write
        // declared by the nested dispatcher's systems and by the controller's
        // `BatchSystemData`, so it covers all resources the batch can access.
        // NOTE(review): this is presumably the invariant that
        // `BatchControllerSystem::create` requires — confirm against its
        // `# Safety` documentation.
        let batch_system =
            unsafe { BatchControllerSystem::<'a, 'b, T>::create(accessor, controller, dispatcher) };
        self.add(batch_system, name, dep);
    }

    /// Adds a thread-local system and returns the builder, for chained
    /// construction.
    #[must_use]
    pub fn with_thread_local<T>(mut self, system: T) -> Self
    where
        T: for<'c> RunNow<'c> + 'b,
    {
        self.add_thread_local(system);
        self
    }

    /// Adds a thread-local system.
    ///
    /// The system is boxed and kept in a list separate from the staged
    /// systems. Unlike [`add`](Self::add), it is not required to be `Send`.
    pub fn add_thread_local<T>(&mut self, system: T)
    where
        T: for<'c> RunNow<'c> + 'b,
    {
        self.thread_local.push(Box::new(system));
    }

    /// Inserts a barrier and returns the builder, for chained construction.
    #[must_use]
    pub fn with_barrier(mut self) -> Self {
        self.add_barrier();
        self
    }

    /// Inserts a barrier into the stage builder.
    ///
    /// NOTE(review): presumably systems added after the barrier only run once
    /// all earlier systems have finished — confirm in `StagesBuilder`.
    pub fn add_barrier(&mut self) {
        self.stages_builder.add_barrier();
    }

    /// Sets the thread pool and returns the builder, for chained construction.
    #[cfg(feature = "parallel")]
    #[must_use]
    pub fn with_pool(mut self, pool: ::std::sync::Arc<::rayon::ThreadPool>) -> Self {
        self.add_pool(pool);
        self
    }

    /// Stores `pool` in this builder's shared pool slot, replacing any pool
    /// already there.
    ///
    /// If no pool is ever set, `build` creates a default one.
    ///
    /// # Panics
    ///
    /// If the pool lock is poisoned.
    #[cfg(feature = "parallel")]
    pub fn add_pool(&mut self, pool: ::std::sync::Arc<::rayon::ThreadPool>) {
        *self.thread_pool.write().unwrap() = Some(pool);
    }

    /// Prints this builder's `Debug` representation (the staged layout) to
    /// stdout.
    pub fn print_par_seq(&self) {
        println!("{:#?}", self);
    }

    /// Consumes the builder and produces the [`Dispatcher`].
    ///
    /// With the `parallel` feature, a default rayon thread pool is created
    /// first if none was provided.
    #[must_use]
    pub fn build(self) -> Dispatcher<'a, 'b> {
        use crate::dispatch::dispatcher::new_dispatcher;

        // Lazily create the default pool; the temporary write guard is
        // released at the end of this statement.
        #[cfg(feature = "parallel")]
        self.thread_pool
            .write()
            .unwrap()
            .get_or_insert_with(Self::create_thread_pool);

        // `new_dispatcher` takes a different argument list per feature, hence
        // the two cfg-gated bindings.
        #[cfg(feature = "parallel")]
        let d = new_dispatcher(
            self.stages_builder.build(),
            self.thread_local,
            self.thread_pool,
        );

        #[cfg(not(feature = "parallel"))]
        let d = new_dispatcher(self.stages_builder.build(), self.thread_local);

        d
    }

    /// Hands out the next `SystemId` and advances the counter.
    fn next_id(&mut self) -> SystemId {
        let id = self.current_id;
        self.current_id += 1;

        SystemId(id)
    }

    /// Builds a rayon thread pool with default settings.
    #[cfg(feature = "parallel")]
    fn create_thread_pool() -> ::std::sync::Arc<::rayon::ThreadPool> {
        use rayon::ThreadPoolBuilder;
        use std::sync::Arc;

        Arc::new(
            ThreadPoolBuilder::new()
                .build()
                .expect("Invalid configuration"),
        )
    }
}
#[cfg(feature = "parallel")]
impl<'b> DispatcherBuilder<'static, 'b> {
    /// Consumes the builder and produces an `AsyncDispatcher` that owns
    /// `world`.
    ///
    /// A default rayon thread pool is created first if none was provided.
    /// Staged systems require the `'static` lifetime here.
    pub fn build_async<R>(
        self,
        world: R,
    ) -> crate::dispatch::async_dispatcher::AsyncDispatcher<'b, R> {
        use crate::dispatch::async_dispatcher::new_async;

        let DispatcherBuilder {
            stages_builder,
            thread_local,
            thread_pool,
            ..
        } = self;

        // Fill the pool slot if it is still empty; the guard is dropped at
        // the end of this scope, before the stages are built.
        {
            let mut pool_slot = thread_pool.write().unwrap();
            pool_slot.get_or_insert_with(Self::create_thread_pool);
        }

        let stages = stages_builder.build();
        new_async(world, stages, thread_local, thread_pool)
    }
}
impl<'a, 'b> fmt::Debug for DispatcherBuilder<'a, 'b> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.stages_builder.write_par_seq(f, &self.map)
}
}