use crate::fx::{FxHashMap, FxHashSet};
use std::cell::Cell;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use std::hash;
use std::marker::PhantomData;
mod graphviz;
#[cfg(test)]
mod tests;
/// An obligation that can be stored in an `ObligationForest`.
///
/// Besides being cloneable and debug-printable, an obligation must be able to
/// produce a hashable key, which the forest uses to index its caches.
pub trait ForestObligation: Clone + Debug {
/// Key type under which this obligation is recorded in the forest's caches.
type CacheKey: Clone + hash::Hash + Eq + Debug;
/// Returns the cache key for this obligation.
fn as_cache_key(&self) -> Self::CacheKey;
}
/// Callback interface driven by `ObligationForest::process_obligations`: it
/// makes progress on individual obligations and is told about cycles.
pub trait ObligationProcessor {
/// The kind of obligation this processor handles.
type Obligation: ForestObligation;
/// The error value produced when an obligation fails.
type Error: Debug;
/// Attempts to make progress on `obligation`.
///
/// The obligation is passed mutably, so the processor may modify it in
/// place. The returned `ProcessResult` says whether nothing changed, the
/// obligation was replaced by a list of child obligations, or it failed.
fn process_obligation(
&mut self,
obligation: &mut Self::Obligation,
) -> ProcessResult<Self::Obligation, Self::Error>;
/// Called when the forest detects a cycle among successful nodes.
///
/// `cycle` yields the obligations that make up the cycle. `_marker` is a
/// zero-sized `PhantomData` mentioning the lifetime `'c` and the obligation
/// type; it carries no data.
fn process_backedge<'c, I>(&mut self, cycle: I, _marker: PhantomData<&'c Self::Obligation>)
where
I: Clone + Iterator<Item = &'c Self::Obligation>;
}
/// The result of processing a single obligation.
#[derive(Debug)]
pub enum ProcessResult<O, E> {
/// No progress was made; the node stays pending.
Unchanged,
/// The obligation succeeded, yielding these child obligations (possibly
/// none), which the forest registers with the processed node as parent.
Changed(Vec<O>),
/// The obligation failed with the given error.
Error(E),
}
/// Identifier shared by all nodes of one tree: a node registered without a
/// parent receives a fresh id, a node registered with a parent copies the
/// parent's id.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct ObligationTreeId(usize);
/// Concrete type of the iterator that hands out fresh `ObligationTreeId`s:
/// an open-ended `usize` range mapped through a function pointer that wraps
/// each number in an `ObligationTreeId`.
type ObligationTreeIdGenerator =
::std::iter::Map<::std::ops::RangeFrom<usize>, fn(usize) -> ObligationTreeId>;
/// A collection of obligations together with the dependency links between
/// them, processed incrementally via `process_obligations`.
pub struct ObligationForest<O: ForestObligation> {
/// All nodes currently in the forest. Other fields and the nodes themselves
/// refer to entries of this vector by index, so indices are rewritten
/// whenever `compress` removes nodes.
nodes: Vec<Node<O>>,
/// Cache keys of obligations that completed and were removed by `compress`.
/// Registering an obligation whose key is in this set is a no-op.
done_cache: FxHashSet<O::CacheKey>,
/// Maps the cache key of a registered obligation to its index in `nodes`.
active_cache: FxHashMap<O::CacheKey, usize>,
/// Scratch vector reused across calls to `compress`; empty between calls.
node_rewrites: Vec<usize>,
/// Source of fresh tree ids for obligations registered without a parent.
obligation_tree_id_generator: ObligationTreeIdGenerator,
/// For each tree id, the cache keys of obligations of that tree that were
/// removed in the `Error` state.
error_cache: FxHashMap<ObligationTreeId, FxHashSet<O::CacheKey>>,
}
/// A single obligation in the forest plus its bookkeeping.
#[derive(Debug)]
struct Node<O> {
/// The obligation this node represents.
obligation: O,
/// Current processing state. Stored in a `Cell` so it can be updated
/// through a shared reference to the forest.
state: Cell<NodeState>,
/// Indices (into the forest's `nodes`) of the nodes that depend on this
/// one. When `has_parent` is true, the first entry is the parent.
dependents: Vec<usize>,
/// Whether `dependents[0]` is this node's parent.
has_parent: bool,
/// Id of the tree this node belongs to.
obligation_tree_id: ObligationTreeId,
}
impl<O> Node<O> {
    /// Creates a node in the `Pending` state for `obligation`.
    ///
    /// When `parent` is `Some(index)`, that index becomes the sole initial
    /// dependent and `has_parent` is set; otherwise the node starts with no
    /// dependents.
    fn new(parent: Option<usize>, obligation: O, obligation_tree_id: ObligationTreeId) -> Node<O> {
        let has_parent = parent.is_some();
        // An `Option` iterates over zero or one items, which is exactly the
        // initial dependents list.
        let dependents: Vec<usize> = parent.into_iter().collect();
        let state = Cell::new(NodeState::Pending);
        Node { obligation, state, dependents, has_parent, obligation_tree_id }
    }
}
/// Processing state of a node.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum NodeState {
/// Initial state of every new node; only pending nodes are handed to the
/// processor.
Pending,
/// The processor reported `Changed` for this node. `compress` treats
/// encountering this state as unreachable, so it must be resolved to
/// `Waiting` or `Done` beforehand.
Success,
/// A successful node that is (transitively) a dependent of a node that is
/// still pending; kept by `compress`.
Waiting,
/// A successful node that has been visited by cycle detection; removed by
/// `compress` and recorded in the done cache.
Done,
/// The node failed or was marked as failed by `error_at`; removed by
/// `compress` and recorded in the error cache.
Error,
}
/// Result of one call to `ObligationForest::process_obligations`.
#[derive(Debug)]
pub struct Outcome<O, E> {
/// Obligations that completed and were removed during this call. `Some`
/// only when the caller passed `DoCompleted::Yes`, otherwise `None`.
pub completed: Option<Vec<O>>,
/// One entry for every obligation the processor reported as failed.
pub errors: Vec<Error<O, E>>,
/// `true` when no processed obligation reported `Changed` or `Error`
/// during this call.
pub stalled: bool,
}
/// Selects whether completed obligations should be collected and returned.
#[derive(PartialEq)]
pub enum DoCompleted {
/// Do not collect completed obligations; the result is `None`.
No,
/// Collect completed obligations; the result is `Some(list)`.
Yes,
}
/// A failed obligation together with the chain that led to it.
#[derive(Debug, PartialEq, Eq)]
pub struct Error<O, E> {
/// The error value associated with the failure.
pub error: E,
/// The failed obligation followed by its chain of parents, innermost
/// first.
pub backtrace: Vec<O>,
}
impl<O: ForestObligation> ObligationForest<O> {
pub fn new() -> ObligationForest<O> {
ObligationForest {
nodes: vec![],
done_cache: Default::default(),
active_cache: Default::default(),
node_rewrites: vec![],
obligation_tree_id_generator: (0..).map(ObligationTreeId),
error_cache: Default::default(),
}
}
pub fn len(&self) -> usize {
self.nodes.len()
}
/// Registers `obligation` as a root obligation, i.e. one without a parent.
///
/// The outcome of the underlying registration is deliberately discarded.
pub fn register_obligation(&mut self, obligation: O) {
    match self.register_obligation_at(obligation, None) {
        // Both outcomes are intentionally ignored here.
        Ok(()) | Err(()) => {}
    }
}
/// Registers `obligation`, optionally as a child of the node at index
/// `parent`.
///
/// * If the obligation's cache key is already in the done cache, nothing is
///   added and `Ok(())` is returned.
/// * If a node with the same key is already active, `parent` (if any) is
///   appended to that node's dependents unless it is already listed. The
///   result is `Err(())` when that node is in the `Error` state and `Ok(())`
///   otherwise.
/// * Otherwise a new node is appended. The exception is an obligation that has
///   a parent and whose key is recorded in the error cache of the parent's
///   tree: then nothing is added and `Err(())` is returned.
///
/// # Panics
///
/// Panics if a new node would be created and `parent` is an index outside
/// `self.nodes`.
fn register_obligation_at(&mut self, obligation: O, parent: Option<usize>) -> Result<(), ()> {
    // Compute the key once. It is moved into the entry lookup below and, for
    // a vacant entry, read back via `VacantEntry::key` instead of being
    // recomputed.
    let cache_key = obligation.as_cache_key();
    if self.done_cache.contains(&cache_key) {
        debug!("register_obligation_at: ignoring already done obligation: {:?}", obligation);
        return Ok(());
    }

    match self.active_cache.entry(cache_key) {
        Entry::Occupied(o) => {
            let node = &mut self.nodes[*o.get()];
            if let Some(parent_index) = parent {
                // The node already exists, so `parent` is recorded as an
                // ordinary dependent, and only if it is not listed yet.
                if !node.dependents.contains(&parent_index) {
                    node.dependents.push(parent_index);
                }
            }
            if let NodeState::Error = node.state.get() { Err(()) } else { Ok(()) }
        }
        Entry::Vacant(v) => {
            // Children join their parent's tree; roots start a new one.
            let obligation_tree_id = match parent {
                Some(parent_index) => self.nodes[parent_index].obligation_tree_id,
                None => self.obligation_tree_id_generator.next().unwrap(),
            };

            let already_failed = parent.is_some()
                && self
                    .error_cache
                    .get(&obligation_tree_id)
                    .map_or(false, |errors| errors.contains(v.key()));

            if already_failed {
                Err(())
            } else {
                let new_index = self.nodes.len();
                v.insert(new_index);
                self.nodes.push(Node::new(parent, obligation, obligation_tree_id));
                Ok(())
            }
        }
    }
}
pub fn to_errors<E: Clone>(&mut self, error: E) -> Vec<Error<O, E>> {
let errors = self
.nodes
.iter()
.enumerate()
.filter(|(_index, node)| node.state.get() == NodeState::Pending)
.map(|(index, _node)| Error { error: error.clone(), backtrace: self.error_at(index) })
.collect();
let successful_obligations = self.compress(DoCompleted::Yes);
assert!(successful_obligations.unwrap().is_empty());
errors
}
/// Applies `f` to the obligation of every pending node, in node order, and
/// returns the collected results.
pub fn map_pending_obligations<P, F>(&self, f: F) -> Vec<P>
where
    F: Fn(&O) -> P,
{
    let mut mapped = Vec::new();
    for node in &self.nodes {
        if node.state.get() == NodeState::Pending {
            mapped.push(f(&node.obligation));
        }
    }
    mapped
}
/// Records the cache key of the node at `index` in the error cache entry of
/// that node's tree, creating the entry if it does not exist yet.
fn insert_into_error_cache(&mut self, index: usize) {
    let failed = &self.nodes[index];
    let tree_errors = self.error_cache.entry(failed.obligation_tree_id).or_default();
    tree_errors.insert(failed.obligation.as_cache_key());
}
/// Performs one processing pass over the forest.
///
/// Every pending node is handed to `processor`. Nodes appended during the pass
/// (children of nodes that changed) are visited in the same pass. If at least
/// one node changed or failed, successful nodes are resolved, cycles are
/// reported to the processor and finished nodes are removed.
///
/// The returned `Outcome` carries the errors reported by the processor, the
/// `stalled` flag, and the completed obligations when `do_completed` is
/// `DoCompleted::Yes`. In the stalled case that list is empty.
pub fn process_obligations<P>(
    &mut self,
    processor: &mut P,
    do_completed: DoCompleted,
) -> Outcome<O, P::Error>
where
    P: ObligationProcessor<Obligation = O>,
{
    let mut errors = Vec::new();
    let mut made_progress = false;

    // `self.nodes` can grow while we iterate, so the length is re-read on
    // every iteration instead of being captured up front.
    let mut index = 0;
    while index < self.nodes.len() {
        let node = &mut self.nodes[index];
        if node.state.get() == NodeState::Pending {
            match processor.process_obligation(&mut node.obligation) {
                ProcessResult::Unchanged => {}
                ProcessResult::Changed(children) => {
                    made_progress = true;
                    node.state.set(NodeState::Success);
                    for child in children {
                        // A child that is already known to have failed
                        // drags this node down with it.
                        if self.register_obligation_at(child, Some(index)).is_err() {
                            self.error_at(index);
                        }
                    }
                }
                ProcessResult::Error(err) => {
                    made_progress = true;
                    errors.push(Error { error: err, backtrace: self.error_at(index) });
                }
            }
        }
        index += 1;
    }

    let stalled = !made_progress;
    if stalled {
        // Nothing changed, so marking, cycle detection and compression
        // would all be no-ops.
        let completed = match do_completed {
            DoCompleted::Yes => Some(Vec::new()),
            DoCompleted::No => None,
        };
        return Outcome { completed, errors, stalled };
    }

    self.mark_successes();
    self.process_cycles(processor);
    let completed = self.compress(do_completed);
    Outcome { completed, errors, stalled }
}
fn error_at(&self, mut index: usize) -> Vec<O> {
let mut error_stack: Vec<usize> = vec![];
let mut trace = vec![];
loop {
let node = &self.nodes[index];
node.state.set(NodeState::Error);
trace.push(node.obligation.clone());
if node.has_parent {
error_stack.extend(node.dependents.iter().skip(1));
index = node.dependents[0];
} else {
error_stack.extend(node.dependents.iter());
break;
}
}
while let Some(index) = error_stack.pop() {
let node = &self.nodes[index];
if node.state.get() != NodeState::Error {
node.state.set(NodeState::Error);
error_stack.extend(node.dependents.iter());
}
}
trace
}
/// Recomputes which successful nodes are still waiting on pending work.
///
/// First every `Waiting` node is reset to `Success`; then, for each pending
/// node, its `Success` dependents are (transitively) turned back into
/// `Waiting`.
fn mark_successes(&self) {
    self.nodes
        .iter()
        .filter(|node| node.state.get() == NodeState::Waiting)
        .for_each(|node| node.state.set(NodeState::Success));

    let pending = self.nodes.iter().filter(|node| node.state.get() == NodeState::Pending);
    for node in pending {
        self.inlined_mark_dependents_as_waiting(node);
    }
}
#[inline(always)]
fn inlined_mark_dependents_as_waiting(&self, node: &Node<O>) {
for &index in node.dependents.iter() {
let node = &self.nodes[index];
let state = node.state.get();
if state == NodeState::Success {
node.state.set(NodeState::Waiting);
self.uninlined_mark_dependents_as_waiting(node);
} else {
debug_assert!(state == NodeState::Waiting || state == NodeState::Error)
}
}
}
/// Never-inlined wrapper around `inlined_mark_dependents_as_waiting`; it is
/// the entry point that function uses when it recurses into a dependent.
#[inline(never)]
fn uninlined_mark_dependents_as_waiting(&self, node: &Node<O>) {
self.inlined_mark_dependents_as_waiting(node)
}
/// Starts cycle detection from every node that is in the `Success` state when
/// it is visited, reporting any cycles to `processor`.
///
/// The traversal stack is shared across all starting nodes and is asserted (in
/// debug builds) to be empty at the end.
fn process_cycles<P>(&self, processor: &mut P)
where
    P: ObligationProcessor<Obligation = O>,
{
    let mut stack = Vec::new();

    // The filter is evaluated lazily, so nodes already turned `Done` by an
    // earlier traversal are skipped without a function call.
    let successful = self
        .nodes
        .iter()
        .enumerate()
        .filter(|(_index, node)| node.state.get() == NodeState::Success);
    for (index, _node) in successful {
        self.find_cycles_from_node(&mut stack, processor, index);
    }

    debug_assert!(stack.is_empty());
}
/// Depth-first walk over dependents, starting at `index`, that reports cycles.
///
/// Nodes not in the `Success` state are ignored. If `index` is already on
/// `stack`, the obligations from its last occurrence to the top of the stack
/// are passed to `processor.process_backedge`. Otherwise the node is pushed,
/// its dependents are visited, it is popped again and finally marked `Done`.
fn find_cycles_from_node<P>(&self, stack: &mut Vec<usize>, processor: &mut P, index: usize)
where
    P: ObligationProcessor<Obligation = O>,
{
    let node = &self.nodes[index];
    if node.state.get() != NodeState::Success {
        return;
    }

    if let Some(cycle_start) = stack.iter().rposition(|&seen| seen == index) {
        // `index` is already being visited: everything from that stack
        // position upwards forms a cycle.
        processor.process_backedge(
            stack[cycle_start..].iter().map(GetObligation(&self.nodes)),
            PhantomData,
        );
        return;
    }

    stack.push(index);
    for &dep_index in &node.dependents {
        self.find_cycles_from_node(stack, processor, dep_index);
    }
    stack.pop();
    node.state.set(NodeState::Done);
}
/// Removes all `Done` and `Error` nodes from the forest, keeping the relative
/// order of the remaining (`Pending` / `Waiting`) nodes, and fixes up every
/// stored node index afterwards.
///
/// * `Done` nodes have their cache key moved from the active cache to the done
///   cache and, when `do_completed` is `DoCompleted::Yes`, a clone of their
///   obligation is collected for the return value.
/// * `Error` nodes are dropped from the active cache and recorded in the error
///   cache of their tree.
///
/// Returns `Some(completed)` when `do_completed` is `DoCompleted::Yes` and
/// `None` otherwise.
///
/// # Panics
///
/// Panics if any node is still in the `Success` state.
#[inline(never)]
fn compress(&mut self, do_completed: DoCompleted) -> Option<Vec<O>> {
    let orig_nodes_len = self.nodes.len();
    // Borrow the scratch vector kept on `self` so its allocation is reused
    // from call to call; it is handed back (emptied) at the end.
    let mut node_rewrites: Vec<_> = std::mem::take(&mut self.node_rewrites);
    debug_assert!(node_rewrites.is_empty());
    // `node_rewrites[old_index]` will hold the node's new index, or
    // `orig_nodes_len` as the marker for "removed".
    node_rewrites.extend(0..orig_nodes_len);
    let mut dead_nodes = 0;
    let mut removed_done_obligations: Vec<O> = vec![];

    // Shift surviving nodes towards the front, over the dead ones.
    //
    // Loop invariant:
    //   self.nodes[0..index - dead_nodes]      are the surviving nodes so far,
    //   self.nodes[index - dead_nodes..index]  are all dead,
    //   self.nodes[index..]                    are not yet examined.
    for index in 0..orig_nodes_len {
        let node = &self.nodes[index];
        match node.state.get() {
            NodeState::Pending | NodeState::Waiting => {
                if dead_nodes > 0 {
                    self.nodes.swap(index, index - dead_nodes);
                    node_rewrites[index] -= dead_nodes;
                }
            }
            NodeState::Done => {
                // The processor gets `&mut` access to obligations, so the
                // current key is not guaranteed to be the one the node was
                // registered under; the removal may therefore find nothing.
                // Either way the key ends up in the done cache.
                let cache_key = node.obligation.as_cache_key();
                if let Some((predicate, _)) = self.active_cache.remove_entry(&cache_key) {
                    self.done_cache.insert(predicate);
                } else {
                    self.done_cache.insert(cache_key);
                }
                if do_completed == DoCompleted::Yes {
                    removed_done_obligations.push(node.obligation.clone());
                }
                node_rewrites[index] = orig_nodes_len;
                dead_nodes += 1;
            }
            NodeState::Error => {
                self.active_cache.remove(&node.obligation.as_cache_key());
                self.insert_into_error_cache(index);
                node_rewrites[index] = orig_nodes_len;
                dead_nodes += 1;
            }
            NodeState::Success => unreachable!(),
        }
    }

    if dead_nodes > 0 {
        // All dead nodes now sit at the tail; cut them off and rewrite the
        // indices stored elsewhere.
        self.nodes.truncate(orig_nodes_len - dead_nodes);
        self.apply_rewrites(&node_rewrites);
    }

    node_rewrites.clear();
    self.node_rewrites = node_rewrites;

    if do_completed == DoCompleted::Yes { Some(removed_done_obligations) } else { None }
}
/// Rewrites all stored node indices after nodes have been removed.
///
/// `node_rewrites[old]` gives the new index of the node formerly at `old`; a
/// value of `node_rewrites.len()` or more means that node was removed.
/// Dependents pointing at removed nodes are dropped (clearing `has_parent`
/// when the parent slot is the one dropped), and active-cache entries pointing
/// at removed nodes are discarded.
fn apply_rewrites(&mut self, node_rewrites: &[usize]) {
    let removed_marker = node_rewrites.len();

    for node in self.nodes.iter_mut() {
        let mut pos = 0;
        while pos < node.dependents.len() {
            let target = node_rewrites[node.dependents[pos]];
            if target < removed_marker {
                node.dependents[pos] = target;
                pos += 1;
                continue;
            }
            // `swap_remove` moves the last dependent into `pos`, so `pos`
            // is deliberately not advanced and the slot is re-examined.
            node.dependents.swap_remove(pos);
            if pos == 0 && node.has_parent {
                // The entry just dropped was the parent.
                node.has_parent = false;
            }
        }
    }

    self.active_cache.retain(|_key, slot| {
        let target = node_rewrites[*slot];
        let keep = target < removed_marker;
        if keep {
            *slot = target;
        }
        keep
    });
}
}
/// A cloneable, nameable stand-in for a closure that maps a node index to the
/// obligation stored at that index in the wrapped node slice.
#[derive(Clone)]
struct GetObligation<'a, O>(&'a [Node<O>]);
impl<'a, 'b, O> FnOnce<(&'b usize,)> for GetObligation<'a, O> {
type Output = &'a O;
extern "rust-call" fn call_once(self, args: (&'b usize,)) -> &'a O {
&self.0[*args.0].obligation
}
}
impl<'a, 'b, O> FnMut<(&'b usize,)> for GetObligation<'a, O> {
extern "rust-call" fn call_mut(&mut self, args: (&'b usize,)) -> &'a O {
&self.0[*args.0].obligation
}
}