use std::collections::BTreeSet;
use std::fmt;
/// Scheduling phase a pass belongs to.
///
/// The scheduler iterates `Normal` passes until a round reports no change (or the
/// round budget runs out), then gives `Deferred` passes a single round.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PassPhase {
    /// Re-run round after round until a round changes nothing.
    Normal,
    /// Run for one round, with every tag treated as dirty, after the normal phase stops.
    Deferred,
}
/// Static description of a single pass: its name, its phase, and the tags it
/// reacts to and invalidates.
///
/// The scheduler runs a pass in a round only if at least one tag in `depends_on`
/// is dirty. When the pass reports a change, every tag in `invalidates` is
/// marked dirty.
#[derive(Debug, Clone, Copy)]
pub struct PassDescriptor<T: InvalidationTag> {
    /// Name forwarded to the `run_pass` callback when the pass is scheduled.
    pub name: &'static str,
    /// Phase in which this pass is scheduled.
    pub phase: PassPhase,
    /// Tags whose dirtiness makes this pass eligible to run. An empty slice means
    /// the pass is never scheduled, because no tag can match.
    pub depends_on: &'static [T],
    /// Tags marked dirty whenever this pass reports that it changed something.
    pub invalidates: &'static [T],
}
/// A tag naming something that passes can depend on and invalidate.
///
/// The scheduler stores tags in `BTreeSet`s, which is why `Ord` is required. Tags
/// are also copied freely, so `Copy` is required too.
pub trait InvalidationTag
where
    Self: Copy + Eq + Ord + fmt::Debug + 'static,
{
    /// Returns every tag value. The scheduler marks all of them dirty at the
    /// start of a run and again before each deferred round.
    fn all() -> &'static [Self];
}
/// Runs `passes` until a deferred round changes nothing or the round budget is spent.
///
/// Each iteration of the outer loop does three things:
/// 1. It repeats rounds of [`PassPhase::Normal`] passes until a round changes
///    nothing or `max_rounds` changing rounds have been counted.
/// 2. It resets the dirty set to every tag from [`InvalidationTag::all`] and runs a
///    single round of [`PassPhase::Deferred`] passes.
/// 3. It stops if that deferred round changed nothing or the budget is spent.
///    Otherwise, the tags the deferred passes invalidated seed the next normal
///    phase.
///
/// The very first normal round starts with every tag dirty.
///
/// `run_pass(index, name)` is called for each scheduled pass, where `index` is the
/// pass's position in `passes`. Its return value is treated as "this pass changed
/// something".
///
/// Only rounds that report a change count toward `max_rounds`. The deferred round
/// is not gated by the budget: it runs even when `max_rounds` is 0 or already
/// used up, so the counted rounds may end at `max_rounds + 1`.
///
/// Returns `true` if any pass in any round reported a change.
pub fn run_invalidation_loop<T, F>(
    passes: &[PassDescriptor<T>],
    mut run_pass: F,
    max_rounds: usize,
) -> bool
where
    T: InvalidationTag,
    F: FnMut(usize, &str) -> bool,
{
    let all_tags = || T::all().iter().copied().collect::<BTreeSet<T>>();
    let mut dirty = all_tags();
    let mut any_change = false;
    let mut rounds_used = 0;

    loop {
        any_change |= run_phase_until_converged(
            passes,
            PassPhase::Normal,
            &mut dirty,
            &mut run_pass,
            max_rounds,
            &mut rounds_used,
        );

        // Deferred passes see every tag as dirty, whatever the normal phase left behind.
        dirty = all_tags();
        if !run_single_round(passes, PassPhase::Deferred, &mut dirty, &mut run_pass) {
            break;
        }
        any_change = true;
        rounds_used += 1;
        if rounds_used >= max_rounds {
            break;
        }
    }

    any_change
}
/// Repeats single rounds of `phase` until a round changes nothing or the shared
/// round budget is spent.
///
/// `rounds` is shared with the caller. It is incremented once for every round
/// that reports a change, and no new round starts once it has reached
/// `max_rounds`.
///
/// On return, `dirty` holds whatever the last executed round produced. It is
/// empty if the phase converged, and unchanged if the budget was already spent
/// on entry. Returns `true` if any round reported a change.
fn run_phase_until_converged<T, F>(
    passes: &[PassDescriptor<T>],
    phase: PassPhase,
    dirty: &mut BTreeSet<T>,
    run_pass: &mut F,
    max_rounds: usize,
    rounds: &mut usize,
) -> bool
where
    T: InvalidationTag,
    F: FnMut(usize, &str) -> bool,
{
    let mut made_progress = false;
    while *rounds < max_rounds {
        if !run_single_round(passes, phase, dirty, run_pass) {
            // A quiet round means the phase has reached a fixed point.
            break;
        }
        made_progress = true;
        *rounds += 1;
    }
    made_progress
}
/// Runs one round over every pass in `phase`, in slice order.
///
/// A pass runs only if at least one tag in its `depends_on` is dirty. A tag
/// counts as dirty if it was in `*dirty` on entry, or if an earlier pass in this
/// same round invalidated it. Passes with an empty `depends_on` therefore never
/// run.
///
/// `run_pass` receives the pass's index in `passes` and its name. Its return
/// value is treated as "this pass changed something".
///
/// On return, `*dirty` is replaced by exactly the tags invalidated during this
/// round, which is empty when nothing changed. Returns `true` if any executed
/// pass reported a change.
fn run_single_round<T, F>(
    passes: &[PassDescriptor<T>],
    phase: PassPhase,
    dirty: &mut BTreeSet<T>,
    run_pass: &mut F,
) -> bool
where
    T: InvalidationTag,
    F: FnMut(usize, &str) -> bool,
{
    // Move the incoming set out instead of cloning it. `*dirty` is overwritten
    // unconditionally below, so the old contents are never needed again.
    let snapshot = std::mem::take(dirty);
    let mut newly_dirty: BTreeSet<T> = BTreeSet::new();
    let mut round_changed = false;
    for (index, desc) in passes.iter().enumerate() {
        if desc.phase != phase {
            continue;
        }
        // Tags invalidated earlier in this round count too, so a later pass can
        // react to an earlier pass's change without waiting for the next round.
        let relevant = desc
            .depends_on
            .iter()
            .any(|tag| snapshot.contains(tag) || newly_dirty.contains(tag));
        if !relevant {
            continue;
        }
        if run_pass(index, desc.name) {
            round_changed = true;
            newly_dirty.extend(desc.invalidates.iter().copied());
        }
    }
    *dirty = newly_dirty;
    round_changed
}