use std::collections::{HashMap, HashSet};
use crate::util::{ArrayMap};
use super::{Dataflow, Node, Out, flood, Cold, CFT};
/// The outcome of a guard failing on a hot path: the guard's cold branches,
/// each expanded into its own [`HotPathTree`], plus the outputs that must be
/// kept alive at the guard so those branches can run.
#[derive(Debug, PartialEq, Eq)]
pub struct GuardFailure<L> where L: Clone {
    /// The failing guard's cold branches, each expanded into a hot path tree.
    pub cold: Cold<HotPathTree<L>>,
    /// Outputs the cold branches use that are produced outside of them.
    pub keep_alives: HashSet<Out>,
}
/// A hot path through a control-flow tree, together with what happens when
/// each guard along it fails.
#[derive(Debug, PartialEq, Eq)]
pub struct HotPathTree<L> where L: Clone {
    /// The node at which the hot path exits.
    pub exit: Node,
    /// The control-flow tree's leaf value at the end of the hot path.
    pub leaf: L,
    /// Guard failures along the hot path, keyed by the guard node.
    pub children: HashMap<Node, GuardFailure<L>>,
}
impl<L: Clone> HotPathTree<L> {
pub fn new(
exit: Node,
leaf: L,
children: impl IntoIterator<Item=GuardFailure<L>>,
) -> Self {
let children = HashMap::from_iter(children.into_iter().map(
|gf| (gf.cold.guard, gf)
));
HotPathTree {exit, leaf, children}
}
}
/// Mutable state for the recursive walk performed by `keep_alive_sets`.
struct KeepAlive<'a> {
/// The dataflow graph whose nodes are marked during the walk.
dataflow: &'a Dataflow,
/// Per-node mark. 0 means unmarked (`walk` resets a node to 0 when it is done
/// with it), 1 is the entry node, and any other value is the `coldness` of
/// the `walk` call whose `flood` marked the node.
marks: ArrayMap<Node, usize>,
}
impl<'a> KeepAlive<'a> {
    /// Creates a walker over `dataflow` in which only the entry node is marked.
    ///
    /// The entry node gets mark 1, which is below every coldness `walk` is
    /// called with (the outermost path starts at 2), so keep-alives produced
    /// by the entry node are always propagated outward.
    fn new(dataflow: &'a Dataflow) -> Self {
        let mut marks = dataflow.node_map();
        marks[dataflow.entry_node()] = 1;
        Self {dataflow, marks}
    }

    /// Builds the [`HotPathTree`] for `cft`, walking its hot path at the given
    /// `coldness` (one more than the enclosing path's).
    ///
    /// Each cold branch is walked recursively at `coldness + 1`. The outputs
    /// it needs become that branch's `keep_alives`. Any keep-alive whose
    /// producing node carries a mark below `coldness` (an enclosing path or
    /// the entry node) is also added to `inputs`, so the caller keeps it alive
    /// too. Before returning, every node returned by `flood` is reset to mark
    /// 0, so sibling branches see the same marks this call started with.
    ///
    /// # Panics
    ///
    /// Panics if a keep-alive's producing node is unmarked, or if a node
    /// returned by `flood` no longer carries mark `coldness` after the cold
    /// branches have been walked.
    fn walk<L: Clone>(
        &mut self,
        cft: &'a CFT<L>,
        inputs: &mut HashSet<Out>,
        coldness: usize,
    ) -> HotPathTree<L> {
        let (colds, exit, leaf) = cft.hot_path();
        // NOTE(review): `flood` presumably marks the nodes `exit` needs with
        // `coldness`, records outputs of lower-marked nodes in `inputs`, and
        // returns the nodes it marked. Confirm against its definition.
        let flooded = flood(&self.dataflow, &mut self.marks, coldness, inputs, exit);
        let mut children = Vec::new();
        for cold in colds.into_iter() {
            let mut keep_alives = HashSet::new();
            let cold = cold.map(|&branch| self.walk(branch, &mut keep_alives, coldness + 1));
            // A keep-alive marked with this coldness is computed on this hot
            // path. One marked lower must be supplied to this path as well.
            for &out in &keep_alives {
                let (producer, _) = self.dataflow.out(out);
                let mark = self.marks[producer];
                assert_ne!(mark, 0);
                if mark < coldness {
                    inputs.insert(out);
                }
            }
            children.push(GuardFailure {cold, keep_alives});
        }
        // Release this path's marks so sibling branches are unaffected.
        for &node in &*flooded {
            assert_eq!(self.marks[node], coldness);
            self.marks[node] = 0;
        }
        HotPathTree::new(exit, leaf.clone(), children)
    }
}
/// Computes the [`HotPathTree`] of `cft` over `dataflow`.
///
/// At each guard failure the tree records `keep_alives`: the outputs the cold
/// branch uses that are produced outside that branch.
pub fn keep_alive_sets<L: Clone>(dataflow: &Dataflow, cft: &CFT<L>) -> HotPathTree<L> {
    // The entry node carries mark 1, so the outermost hot path uses 2.
    const ROOT_COLDNESS: usize = 2;
    // Nothing encloses the outermost path, so its own inputs are discarded.
    let mut root_inputs = HashSet::new();
    KeepAlive::new(dataflow).walk(cft, &mut root_inputs, ROOT_COLDNESS)
}
#[cfg(test)]
mod tests {
use super::*;
use super::super::{CFT, Dataflow, Op};
// Test-only constructor so expected trees can be written as nested literals.
impl<L: Clone> GuardFailure<L> {
/// Builds a failure at `guard` with the given `hot_index`, keep-alive
/// outputs, and cold-branch trees.
pub fn new(
guard: Node,
hot_index: usize,
keep_alives: impl IntoIterator<Item=Out>,
colds: impl Into<Box<[HotPathTree<L>]>>,
) -> Self {
let colds = colds.into();
let keep_alives = HashSet::from_iter(keep_alives);
GuardFailure {cold: Cold {guard, hot_index, colds}, keep_alives}
}
}
/// Two levels of guards: `guard1` at the root, then `guard2` or `guard3`
/// below it, ending in four `Convention` leaves. Each leaf depends on the two
/// guards on its path and is given one entry output (`p`, `q`, `r`, `s`).
#[test]
fn binary_tree() {
#[derive(Debug, Clone, PartialEq)]
struct Leaf;
// The entry node's outputs are used below as `a`..`s`.
let mut dataflow = Dataflow::new(7);
let ins: Box<[_]> = dataflow.outs(dataflow.entry_node()).collect();
let a = ins[0];
let b = ins[1];
let c = ins[2];
let p = ins[3];
let q = ins[4];
let r = ins[5];
let s = ins[6];
let guard1 = dataflow.add_node(Op::Guard, &[], &[a], 0);
let guard2 = dataflow.add_node(Op::Guard, &[], &[b], 0);
let guard3 = dataflow.add_node(Op::Guard, &[], &[c], 0);
let hot_hot = dataflow.add_node(Op::Convention, &[guard1, guard2], &[p], 0);
let hot_cold = dataflow.add_node(Op::Convention, &[guard1, guard2], &[q], 0);
let cold_hot = dataflow.add_node(Op::Convention, &[guard1, guard3], &[r], 0);
let cold_cold = dataflow.add_node(Op::Convention, &[guard1, guard3], &[s], 0);
let merge4 = CFT::Merge {exit: hot_hot, leaf: Leaf};
let merge5 = CFT::Merge {exit: hot_cold, leaf: Leaf};
let merge6 = CFT::Merge {exit: cold_hot, leaf: Leaf};
let merge7 = CFT::Merge {exit: cold_cold, leaf: Leaf};
let switch2 = CFT::switch(guard2, [merge4], merge5, 0);
let switch3 = CFT::switch(guard3, [merge6], merge7, 0);
let switch1 = CFT::switch(guard1, [switch2], switch3, 0);
// Expected: the hot path exits at `hot_hot`.
// - If `guard2` fails, its cold path (exiting at `hot_cold`) keeps `q` alive.
// - If `guard1` fails, its cold subtree keeps `c`, `r` and `s` alive. `s` is
//   also the keep-alive of `guard3`'s own failure nested inside it.
let expected = HotPathTree::new(hot_hot, Leaf, [
GuardFailure::new(guard1, 0, [c, r, s], [
HotPathTree::new(cold_hot, Leaf, [
GuardFailure::new(guard3, 0, [s], [
HotPathTree::new(cold_cold, Leaf, []),
]),
]),
]),
GuardFailure::new(guard2, 0, [q], [
HotPathTree::new(hot_cold, Leaf, []),
]),
]);
let observed = keep_alive_sets(&dataflow, &switch1);
assert_eq!(observed, expected);
}
}