use fnv::FnvHashSet;
use std::hash::Hash;
use crate::network::Network;
/// Iterator over the nodes of a network in topological order.
///
/// Each node is yielded only after all of its fan-in nodes that take part in
/// the traversal. The walk starts from a list of top nodes and follows their
/// transitive fan-in using an explicit stack rather than recursion. Whether
/// primary inputs and constants take part is controlled by
/// [`TopologicalIter::visit_primary_inputs`] and
/// [`TopologicalIter::visit_constants`].
#[must_use = "iterators are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct TopologicalIter<N>
where
    N: Network,
{
    /// The network being traversed.
    network: N,
    /// Pending work: nodes still to be expanded or emitted (initially the top nodes).
    stack: Vec<N::NodeId>,
    /// Nodes that have already been yielded.
    visited: FnvHashSet<N::NodeId>,
    /// Whether primary inputs are part of the traversal (default `true`).
    visit_primary_inputs: bool,
    /// Whether constant nodes are part of the traversal (default `false`).
    visit_constants: bool,
}
impl<N> TopologicalIter<N>
where
    N: Network,
{
    /// Creates an iterator over the transitive fan-in of `top_nodes`.
    ///
    /// Defaults: primary inputs are visited, constants are not. Adjust these
    /// with the builder methods [`Self::visit_primary_inputs`] and
    /// [`Self::visit_constants`].
    pub fn new(network: N, top_nodes: Vec<N::NodeId>) -> Self {
        let visit_primary_inputs = true;
        let visit_constants = false;
        Self {
            visited: Default::default(),
            stack: top_nodes,
            network,
            visit_constants,
            visit_primary_inputs,
        }
    }

    /// Builder: choose whether primary inputs take part in the traversal.
    ///
    /// When disabled, primary-input operands are not followed, so inputs only
    /// show up if they were passed in as top nodes.
    pub fn visit_primary_inputs(self, visit_primary_inputs: bool) -> Self {
        Self {
            visit_primary_inputs,
            ..self
        }
    }

    /// Builder: choose whether constant nodes take part in the traversal.
    ///
    /// When disabled, constant operands are not followed, so constants only
    /// show up if they were passed in as top nodes.
    pub fn visit_constants(self, visit_constants: bool) -> Self {
        Self {
            visit_constants,
            ..self
        }
    }
}
impl<N> Iterator for TopologicalIter<N>
where
    N: Network,
    N::NodeId: Hash + Eq,
{
    type Item = N::NodeId;

    /// Yields the next node whose participating fan-in has all been yielded.
    ///
    /// A popped node is handled in one of two ways:
    /// - It is emitted immediately if it is a "leaf": a constant (when
    ///   `visit_constants`) or a primary input (when `visit_primary_inputs`).
    /// - Otherwise its operands are scanned. Constant and primary-input operands
    ///   are skipped when the corresponding flag is off, and so are
    ///   already-visited operands. If any operand remains, the node is pushed
    ///   back underneath those operands and retried later. If none remains, it
    ///   is emitted.
    ///
    /// Top nodes are always emitted, even when they are constants or primary
    /// inputs and the matching flag is off. The flags only control which fan-in
    /// edges are followed.
    ///
    /// NOTE: there is no cycle detection. A cyclic fan-in would keep pushing
    /// the same nodes and never terminate.
    fn next(&mut self) -> Option<Self::Item> {
        let net = &self.network;
        while let Some(n) = self.stack.pop() {
            // A node may be on the stack several times (once per parent that
            // saw it unvisited). Drop stale copies before touching the network.
            if self.visited.contains(&n) {
                continue;
            }

            let signal = net.get_node_output(&n);
            let is_leaf = (self.visit_constants && net.is_constant(signal))
                || (self.visit_primary_inputs && net.is_input(signal));

            if !is_leaf {
                // Set once the node has been re-queued below its first pending operand.
                let mut deferred = false;
                for i in 0..net.num_node_inputs(&n) {
                    let in_sig = net.get_node_input(&n, i);
                    if !self.visit_constants && net.get_constant_value(in_sig).is_some() {
                        continue;
                    }
                    if !self.visit_primary_inputs && net.is_input(in_sig) {
                        continue;
                    }
                    let node = net.get_source_node(&in_sig);
                    if self.visited.contains(&node) {
                        continue;
                    }
                    if !deferred {
                        // Re-queue `n` beneath its operands so it is retried after them.
                        self.stack.push(n.clone());
                        deferred = true;
                    }
                    self.stack.push(node);
                }
                if deferred {
                    continue;
                }
            }

            self.visited.insert(n.clone());
            return Some(n);
        }
        None
    }
}
#[test]
fn test_topological_iter() {
    use crate::network::*;
    use crate::networks::mig::Mig;

    // Four primary inputs feed three AND gates, which a majority gate combines.
    let mut mig = Mig::new();
    let [a, b, c, d] = mig.create_primary_inputs();
    let anb = mig.create_and(a, b);
    let bnc = mig.create_and(b, c);
    let cnd = mig.create_and(c, d);
    let or = mig.create_maj3(anb, bnc, cnd);
    let out = mig.create_primary_output(or);

    // Default configuration: primary inputs are part of the ordering.
    let with_inputs: Vec<_> = TopologicalIter::new(&mig, vec![out]).collect();
    assert_eq!(with_inputs.last(), Some(&out));
    assert_eq!(with_inputs.len(), 8);
    for node in &[a, b, c, d, anb, bnc, cnd, or] {
        assert!(with_inputs.contains(node));
    }

    // Primary inputs disabled: only the gates remain.
    let gates_only: Vec<_> = TopologicalIter::new(&mig, vec![out])
        .visit_primary_inputs(false)
        .collect();
    assert_eq!(gates_only.last(), Some(&out));
    assert_eq!(gates_only.len(), 4);
    for node in &[anb, bnc, cnd, or] {
        assert!(gates_only.contains(node));
    }
}