use rustc_hash::FxHashSet;
use std::{iter::FlatMap, slice::Iter};
use crate::internal::Resource;
use super::{
resource_manager::{LoanManager, ResourceManager},
system::ShouldContinue,
ResourceId,
};
use self::solver::SolverSystem;
/// Interface the scheduler uses to inspect, order and run a single system.
pub trait IsSystem: std::fmt::Debug {
/// Name of the system. Schedules look systems up by this name and use it to
/// break ties between systems that the solver ranks equally.
fn name(&self) -> &str;
/// Resource borrows declared by this system; batching compares them to keep
/// conflicting systems out of the same batch.
fn borrows(&self) -> &[Borrow];
/// Ordering requirements relative to other systems, which are referenced by name.
fn dependencies(&self) -> &[Dependency];
/// Records the barrier this system belongs to. A schedule calls this with
/// its current barrier when the system is added.
fn set_barrier(&mut self, barrier: usize);
/// Returns the barrier of this system, which is fed to the ordering solver.
fn barrier(&self) -> usize;
/// Produces the `Resource` for a run, using the given loan manager.
/// NOTE(review): presumably the returned value is what is later handed to
/// `run` as `data` - confirm against the implementors.
fn prep(&self, loan_mngr: &mut LoanManager<'_>) -> anyhow::Result<Resource>;
/// Runs the system with `data`.
/// NOTE(review): the returned `ShouldContinue` presumably tells the batch
/// whether to keep the system for later runs - confirm in the batch implementation.
fn run(&mut self, data: Resource) -> anyhow::Result<ShouldContinue>;
}
/// A group of systems that a schedule runs together as one step.
pub(crate) trait IsBatch: std::fmt::Debug + Default {
    /// The kind of system stored in this batch.
    type System: IsSystem + Send + Sync;

    /// Additional data handed to [`IsBatch::run`]; the schedule clones it for
    /// every batch.
    type ExtraRunData<'a>: Send + Sync + Clone;

    /// Returns `true` if one of this batch's systems is called `name`.
    fn contains_system(&self, name: &str) -> bool {
        self.systems()
            .iter()
            .any(|candidate| candidate.name() == name)
    }

    /// The systems currently held by this batch.
    fn systems(&self) -> &[Self::System];

    /// Mutable access to the systems currently held by this batch.
    fn systems_mut(&mut self) -> &mut [Self::System];

    /// NOTE(review): presumably removes every system whose name is in
    /// `should_remove` - confirm against the implementors.
    fn trim_systems(&mut self, should_remove: FxHashSet<&str>);

    /// Iterates over every borrow declared by every system in this batch,
    /// system by system.
    fn borrows(&self) -> FlatMap<Iter<Self::System>, &[Borrow], fn(&Self::System) -> &[Borrow]> {
        let members = self.systems();
        members.iter().flat_map(|member| member.borrows())
    }

    /// Adds `system` to this batch.
    fn add_system(&mut self, system: Self::System);

    /// Returns the barrier associated with this batch.
    fn get_barrier(&self) -> usize;

    /// Sets the barrier associated with this batch.
    fn set_barrier(&mut self, barrier: usize);

    /// Moves the systems out of this batch and returns them.
    fn take_systems(&mut self) -> Vec<Self::System>;

    /// Replaces the systems of this batch with `systems`.
    fn set_systems(&mut self, systems: Vec<Self::System>);

    /// Runs the systems of this batch with the requested `parallelism`, the
    /// extra run data and the resource manager that provides their resources.
    fn run(
        &mut self,
        parallelism: u32,
        _: Self::ExtraRunData<'_>,
        resource_manager: &mut ResourceManager,
    ) -> anyhow::Result<()>;
}
pub(crate) trait IsSchedule: std::fmt::Debug {
type System: IsSystem;
type Batch: IsBatch<System = Self::System>;
fn contains_system(&self, name: &str) -> bool {
for batch in self.batches() {
if batch.contains_system(name) {
return true;
}
}
false
}
fn batches_mut(&mut self) -> &mut Vec<Self::Batch>;
fn batches(&self) -> &[Self::Batch];
fn add_batch(&mut self, batch: Self::Batch);
fn is_empty(&self) -> bool {
for batch in self.batches() {
if batch.systems().len() > 0 {
return false;
}
}
true
}
fn set_parallelism(&mut self, threads: u32);
fn get_parallelism(&self) -> u32;
fn add_system(&mut self, mut new_system: Self::System) {
new_system.set_barrier(self.current_barrier());
let batches = std::mem::take(self.batches_mut());
let mut systems = batches
.into_iter()
.flat_map(|mut batch| batch.take_systems())
.collect::<Vec<_>>();
systems.push(new_system);
let solver_systems = systems
.iter()
.map(|sys| SolverSystem {
name: sys.name().to_string(),
dependencies: sys.dependencies().to_vec(),
borrows: sys.borrows().to_vec(),
barrier: sys.barrier(),
})
.collect::<Vec<_>>();
let indices = solver::solve_order(&solver_systems).unwrap();
debug_assert_eq!(indices.len(), systems.len());
let mut indexed_systems = indices
.into_iter()
.zip(systems.into_iter())
.collect::<Vec<_>>();
indexed_systems.sort_by(|a, b| {
if a.0 == b.0 {
a.1.name().cmp(b.1.name())
} else {
a.0.total_cmp(&b.0)
}
});
log::trace!(
"pre-schedule: {:#?}",
indexed_systems
.iter()
.map(|(i, sys)| (i, sys.name()))
.collect::<Vec<_>>()
);
let mut batch = Self::Batch::default();
let mut current_index = indexed_systems.first().map(|(i, _)| *i).unwrap_or(0.0);
for (index, system) in indexed_systems.into_iter() {
let batch_borrows = batch.borrows().cloned().collect::<Vec<_>>();
if index > current_index || borrows_conflict(system.borrows(), &batch_borrows) {
if !batch.systems().is_empty() {
self.add_batch(std::mem::replace(&mut batch, Self::Batch::default()));
}
current_index = index;
}
batch.add_system(system);
}
if !batch.systems().is_empty() {
self.add_batch(batch);
}
}
fn current_barrier(&self) -> usize;
fn add_barrier(&mut self);
fn run(
&mut self,
extra: <Self::Batch as IsBatch>::ExtraRunData<'_>,
resource_manager: &mut ResourceManager,
) -> anyhow::Result<()> {
resource_manager.unify_resources("IsSchedule::run before all")?;
let parallelism = self.get_parallelism();
for batch in self.batches_mut() {
batch.run(parallelism, extra.clone(), resource_manager)?;
resource_manager.unify_resources("IsSchedule::run after one")?;
}
self.batches_mut()
.retain(|batch| !batch.systems().is_empty());
Ok(())
}
fn get_execution_order(&self) -> Vec<&str> {
self.batches()
.iter()
.flat_map(|batch| {
batch
.systems()
.iter()
.map(IsSystem::name)
.chain(vec!["---"])
})
.collect::<Vec<_>>()
}
fn get_schedule_names(&self) -> Vec<Vec<&str>> {
self.batches()
.iter()
.map(|batch| {
batch
.systems()
.iter()
.map(|sys| sys.name())
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
}
}
/// A resource access declared by a system: which resource it touches and
/// whether the access is exclusive.
#[derive(Clone, Debug)]
pub struct Borrow {
/// Identifier of the borrowed resource.
pub id: ResourceId,
/// `true` when the access is exclusive. Two borrows of the same resource
/// conflict if at least one of them is exclusive.
pub is_exclusive: bool,
}
impl Borrow {
pub fn rez_id(&self) -> ResourceId {
self.id.clone()
}
pub fn name(&self) -> &str {
self.id.name
}
pub fn is_exclusive(&self) -> bool {
self.is_exclusive
}
}
/// Returns `true` when some borrow in `borrows_a` and some borrow in
/// `borrows_b` refer to the same resource and at least one of the two is
/// exclusive.
///
/// Two non-exclusive borrows of the same resource do not conflict, and an
/// empty slice never conflicts with anything.
fn borrows_conflict(borrows_a: &[Borrow], borrows_b: &[Borrow]) -> bool {
    borrows_a.iter().any(|borrow_a| {
        borrows_b.iter().any(|borrow_b| {
            // Compare the identifiers in place; cloning them for every pair
            // (as `rez_id` does) is unnecessary work.
            borrow_a.id == borrow_b.id
                && (borrow_a.is_exclusive() || borrow_b.is_exclusive())
        })
    })
}
/// An ordering requirement that a system places on another system, which is
/// identified by name.
#[derive(Clone, PartialEq)]
pub enum Dependency {
/// The declaring system must be ordered after the named system.
After(String),
/// The declaring system must be ordered before the named system.
Before(String),
}
/// Ranks systems by translating barriers and declared dependencies into
/// linear constraints and handing them to the `casuarius` solver.
mod solver {
    use anyhow;
    use casuarius::*;

    /// A solver variable, identified by index.
    ///
    /// When barriers are in use, the lowest indices are the barrier variables
    /// and the system variables follow directly after them; without barriers
    /// every index is a system variable. See `solve_order`.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
    struct Sys(usize);
    casuarius::derive_syntax_for!(Sys);

    /// The scheduling-relevant description of one system.
    pub struct SolverSystem {
        /// Name that other systems' dependencies refer to.
        pub name: String,
        /// Ordering requirements declared by this system.
        pub dependencies: Vec<super::Dependency>,
        /// Resource borrows declared by this system (not used by the solver itself).
        pub borrows: Vec<super::Borrow>,
        /// Barrier the system was added under.
        pub barrier: usize,
    }

    impl SolverSystem {
        /// Returns `true` if this system declares `After(system_b.name)`.
        fn must_run_after(&self, system_b: &SolverSystem) -> bool {
            self.dependencies.iter().any(|dep| {
                matches!(dep, super::Dependency::After(name) if *name == system_b.name)
            })
        }

        /// Returns `true` if this system declares `Before(system_b.name)`.
        fn must_run_before(&self, system_b: &SolverSystem) -> bool {
            self.dependencies.iter().any(|dep| {
                matches!(dep, super::Dependency::Before(name) if *name == system_b.name)
            })
        }
    }

    /// Computes one rank per system; a system with a lower rank has to run
    /// before a system with a higher rank.
    ///
    /// The returned vector is parallel to `systems`: element `i` is the rank
    /// of `systems[i]`, whatever order the input is in.
    ///
    /// Constraints fed to the solver:
    /// * every barrier is `>= 0` and at least one above every lower barrier,
    /// * every system is at least one above the barrier it was added under
    ///   (only when at least one barrier above `0` is in use),
    /// * `Before`/`After` dependencies keep the two systems at least one apart.
    ///
    /// # Errors
    ///
    /// Returns an error when a dependency constraint is rejected by the
    /// solver, e.g. because the dependencies contradict each other.
    ///
    /// # Panics
    ///
    /// Panics if the solver rejects one of the barrier constraints.
    pub fn solve_order(systems: &[SolverSystem]) -> anyhow::Result<Vec<f64>> {
        log::trace!("solving schedule for {} systems", systems.len());
        // Remember each system's position in the caller's slice so the ranks
        // can be reported in input order even though constraints are built in
        // barrier order.
        let mut ordered = systems.iter().enumerate().collect::<Vec<_>>();
        ordered.sort_by(|a, b| a.1.barrier.cmp(&b.1.barrier));
        let max_barrier = ordered.iter().fold(0, |b, (_, sys)| sys.barrier.max(b));
        // One variable per barrier `0..=max_barrier`. The range has to include
        // `max_barrier` itself, otherwise the variable of the last barrier
        // would coincide with the variable of the first system. Without any
        // barrier above `0` no barrier variables are needed at all.
        let barriers = if max_barrier > 0 {
            (0..=max_barrier).map(Sys).collect::<Vec<_>>()
        } else {
            Vec::new()
        };
        // System variables are numbered directly after the barrier variables.
        let first_system_var = barriers.len();
        log::trace!(" {} barriers", barriers.len());
        let mut solver: Solver<Sys> = casuarius::Solver::default();
        let mut constraints = vec![];
        for barrier_a in barriers.iter() {
            solver.add_constraint(barrier_a.is_ge(0.0)).unwrap();
            constraints.push(format!("barrier {} >= 0", barrier_a.0));
            log::trace!(" {}", constraints.last().unwrap());
            for barrier_b in barriers.iter() {
                if barrier_a.0 > barrier_b.0 {
                    solver
                        .add_constraint(barrier_a.is_ge(*barrier_b + 1.0))
                        .unwrap();
                    constraints.push(format!("barrier {} > barrier {}", barrier_a.0, barrier_b.0));
                    log::trace!(" {}", constraints.last().unwrap());
                }
            }
        }
        for (a, (_, system_a)) in ordered.iter().enumerate() {
            let sys_a = Sys(a + first_system_var);
            if !barriers.is_empty() {
                let barrier = Sys(system_a.barrier);
                solver.add_constraint(sys_a.is_ge(barrier + 1.0)).unwrap();
                constraints.push(format!("{} > barrier {}", system_a.name, barrier.0));
                log::trace!(" {}", constraints.last().unwrap());
            }
            for (b, (_, system_b)) in ordered.iter().enumerate() {
                if system_a.name == system_b.name {
                    continue;
                }
                let sys_b = Sys(b + first_system_var);
                if system_a.must_run_before(system_b) {
                    let before_constraint = sys_b.is_ge(sys_a + 1.0);
                    if !solver.has_constraint(&before_constraint) {
                        solver.add_constraint(before_constraint).map_err(|e| {
                            anyhow::anyhow!(
                                "can't make {:?} < {:?}: {:?}\nconstraints: {:#?}",
                                system_a.name,
                                system_b.name,
                                e,
                                constraints
                            )
                        })?;
                        let before_msg = format!("{} > {}", system_b.name, system_a.name);
                        log::trace!(" {}", before_msg);
                        constraints.push(before_msg);
                    }
                } else if system_a.must_run_after(system_b) {
                    let after_constraint = sys_a.is_ge(sys_b + 1.0);
                    if !solver.has_constraint(&after_constraint) {
                        solver.add_constraint(after_constraint).map_err(|e| {
                            anyhow::anyhow!(
                                "can't make {:?} > {:?}: {:?}\nconstraints: {:#?}",
                                system_a.name,
                                system_b.name,
                                e,
                                constraints
                            )
                        })?;
                        let after_msg = format!("{} > {}", system_a.name, system_b.name);
                        log::trace!(" {}", after_msg);
                        constraints.push(after_msg);
                    }
                }
            }
        }
        // Read back the *system* variables (not the barrier variables) and
        // store every value at the position its system has in the input.
        let mut out = vec![0.0; ordered.len()];
        for (position, (input_index, _)) in ordered.iter().enumerate() {
            out[*input_index] = solver.get_value(Sys(position + first_system_var));
        }
        Ok(out)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::system::*;

    /// Installs a trace-level test logger; the result is ignored so that a
    /// logger installed by an earlier test does not make this one fail.
    fn init_trace_logging() {
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Trace)
            .try_init();
    }

    /// Shorthand for `Dependency::After` on the system called `name`.
    fn after(name: &str) -> Dependency {
        Dependency::After(name.to_string())
    }

    /// Shorthand for `Dependency::Before` on the system called `name`.
    fn before(name: &str) -> Dependency {
        Dependency::Before(name.to_string())
    }

    /// `total_cmp` orders `-0.0` before `0.0` although `==` treats them as equal.
    #[test]
    fn negative_zero_sanity() {
        let negative: f64 = -0.0;
        let positive: f64 = 0.0;
        assert_eq!(std::cmp::Ordering::Less, negative.total_cmp(&positive));
        assert_eq!(-0.0f64, 0.0f64);
    }

    /// Dependencies decide the batch order; equally ranked systems share a
    /// batch and appear in name order.
    #[test]
    fn schedule_with_dependencies() {
        init_trace_logging();

        let mut schedule = SyncSchedule::default();
        schedule.add_system(SyncSystem::new("one", |()| ok(), vec![]));
        schedule.add_system(SyncSystem::new("two", |()| ok(), vec![after("one")]));
        schedule.add_system(SyncSystem::new("three", |()| ok(), vec![after("two")]));
        schedule.add_system(SyncSystem::new(
            "three-again",
            |()| ok(),
            vec![after("two"), before("four")],
        ));
        schedule.add_system(SyncSystem::new("four", |()| ok(), vec![after("three")]));

        let expected_without_zero = vec![
            vec!["one"],
            vec!["two"],
            vec!["three", "three-again"],
            vec!["four"],
        ];
        assert_eq!(expected_without_zero, schedule.get_schedule_names());

        // A system added last can still be scheduled first.
        schedule.add_system(SyncSystem::new("zero", |()| ok(), vec![before("one")]));

        let expected_with_zero = vec![
            vec!["zero"],
            vec!["one"],
            vec!["two"],
            vec!["three", "three-again"],
            vec!["four"],
        ];
        assert_eq!(expected_with_zero, schedule.get_schedule_names());
    }

    /// Systems added after a barrier are scheduled behind the systems added
    /// before it.
    #[test]
    fn schedule_with_barrier() {
        init_trace_logging();

        let mut schedule = SyncSchedule::default();
        schedule.add_system(SyncSystem::new("one", |()| ok(), vec![]));
        schedule.add_barrier();
        schedule.add_system(SyncSystem::new("two", |()| ok(), vec![]));
        schedule.add_system(SyncSystem::new("three", |()| ok(), vec![]));
        schedule.add_barrier();
        schedule.add_system(SyncSystem::new("four", |()| ok(), vec![]));

        let expected = vec![vec!["one"], vec!["three", "two"], vec!["four"]];
        assert_eq!(expected, schedule.get_schedule_names());
    }

    /// A system whose run function returns `end()` is gone from the schedule
    /// after one run.
    #[test]
    fn schedule_with_ephemeral() {
        let mut schedule = SyncSchedule::default();
        schedule.add_system(SyncSystem::new("one", |()| end(), vec![]));
        schedule.add_system(SyncSystem::new("two", |()| ok(), vec![]));
        schedule.add_system(SyncSystem::new("three", |()| ok(), vec![]));

        let before_run = vec![vec!["one", "three", "two"]];
        assert_eq!(before_run, schedule.get_schedule_names());

        let mut manager = ResourceManager::default();
        schedule.run((), &mut manager).unwrap();

        let after_run = vec![vec!["three", "two"]];
        assert_eq!(after_run, schedule.get_schedule_names());
    }
}