use crate::{Direction, LinkView, NodeIndex, PortIndex, PortView, SecondaryMap};
use bitvec::prelude::BitVec;
use smallvec::SmallVec;
use std::{collections::VecDeque, fmt::Debug, iter::FusedIterator};
/// Returns an iterator that walks `graph` in topological order.
///
/// The traversal starts from the nodes in `source` and follows links in the
/// given `direction`. No node or port filtering is applied; see
/// [`toposort_filtered`] for the filtered variant.
///
/// `Map` is the secondary map used to remember which ports have already been
/// visited.
pub fn toposort<G, Map>(
    graph: G,
    source: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>,
    direction: Direction,
) -> TopoSort<'static, G, Map>
where
    G: LinkView,
    Map: SecondaryMap<PortIndex<G::PortIndexBase>, bool>,
{
    // Without filters every node and every port takes part in the traversal.
    let no_node_filter = None;
    let no_port_filter = None;
    TopoSort::new(graph, source, direction, no_node_filter, no_port_filter)
}
/// Returns an iterator that walks `graph` in topological order, restricted by
/// the given filters.
///
/// The traversal starts from the nodes in `source` and follows links in the
/// given `direction`.
///
/// * `node_filter` is called with a node and returns `true` if the node takes
///   part in the traversal; nodes for which it returns `false` are ignored.
/// * `port_filter` is called with a node and one of its ports and returns
///   `true` if the port takes part in the traversal; ports for which it
///   returns `false` are ignored.
///
/// `Map` is the secondary map used to remember which ports have already been
/// visited.
pub fn toposort_filtered<'f, G, Map>(
    graph: G,
    source: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>,
    direction: Direction,
    node_filter: impl FnMut(NodeIndex<G::NodeIndexBase>) -> bool + 'f,
    port_filter: impl FnMut(NodeIndex<G::NodeIndexBase>, PortIndex<G::PortIndexBase>) -> bool + 'f,
) -> TopoSort<'f, G, Map>
where
    G: LinkView,
    Map: SecondaryMap<PortIndex<G::PortIndexBase>, bool>,
{
    // The iterator stores its filters as boxed trait objects.
    let boxed_node_filter = Box::new(node_filter);
    let boxed_port_filter = Box::new(port_filter);
    TopoSort::new(
        graph,
        source,
        direction,
        Some(boxed_node_filter),
        Some(boxed_port_filter),
    )
}
/// Iterator over the nodes of a graph in topological order.
///
/// A node is yielded once it has been queued as a candidate; a node reached
/// through a link is queued when all of its ports in the direction opposite
/// to the traversal have been visited, are unlinked, or are ignored by the
/// port filter.
///
/// `Map` stores one visited flag per port and defaults to [`BitVec`].
pub struct TopoSort<'f, G, Map = BitVec>
where
    G: PortView,
{
    /// The graph being traversed.
    graph: G,
    /// Per-port flag, `true` once the port has been visited.
    visited_ports: Map,
    /// Queue of nodes that are ready to be yielded, in yield order.
    candidate_nodes: VecDeque<NodeIndex<G::NodeIndexBase>>,
    /// Direction in which links are followed.
    direction: Direction,
    /// Number of nodes yielded so far.
    nodes_seen: usize,
    /// Optional node predicate; a node is ignored when it returns `false`.
    #[allow(clippy::type_complexity)]
    node_filter: Option<Box<dyn FnMut(NodeIndex<G::NodeIndexBase>) -> bool + 'f>>,
    /// Optional port predicate; a port is ignored when it returns `false`.
    #[allow(clippy::type_complexity)]
    port_filter: Option<
        Box<dyn FnMut(NodeIndex<G::NodeIndexBase>, PortIndex<G::PortIndexBase>) -> bool + 'f>,
    >,
}
impl<'f, Map, G> TopoSort<'f, G, Map>
where
    Map: SecondaryMap<PortIndex<G::PortIndexBase>, bool>,
    G: LinkView,
{
    /// Builds a topological traversal of `graph`.
    ///
    /// * `source`: the nodes the traversal starts from. When a `node_filter`
    ///   is given, source nodes it rejects are dropped.
    /// * `direction`: the direction in which links are followed.
    /// * `node_filter`: optional predicate; nodes for which it returns `false`
    ///   are ignored.
    /// * `port_filter`: optional predicate; ports for which it returns `false`
    ///   are ignored.
    ///
    /// All ports of the accepted source nodes that face against `direction`
    /// start out marked as visited.
    #[allow(clippy::type_complexity)]
    pub fn new(
        graph: G,
        source: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>,
        direction: Direction,
        mut node_filter: Option<Box<dyn FnMut(NodeIndex<G::NodeIndexBase>) -> bool + 'f>>,
        port_filter: Option<
            Box<dyn FnMut(NodeIndex<G::NodeIndexBase>, PortIndex<G::PortIndexBase>) -> bool + 'f>,
        >,
    ) -> Self {
        let mut visited_ports: Map = SecondaryMap::new();

        // Only source nodes accepted by the node filter enter the queue.
        let candidate_nodes: VecDeque<_> = match node_filter.as_mut() {
            Some(accept) => source.into_iter().filter(|&n| accept(n)).collect(),
            None => source.into_iter().collect(),
        };

        // A map whose default is `true` would report every port as visited,
        // so each port is explicitly reset first.
        if visited_ports.default_value() {
            for port in graph.ports_iter() {
                visited_ports.set(port, false);
            }
        }

        // Nothing is waited for on the upstream side of a source node.
        let upstream = direction.reverse();
        for &node in &candidate_nodes {
            for port in graph.ports(node, upstream) {
                visited_ports.set(port, true);
            }
        }

        Self {
            graph,
            visited_ports,
            candidate_nodes,
            direction,
            nodes_seen: 0,
            node_filter,
            port_filter,
        }
    }

    /// Returns the ports of the graph that are not marked as visited.
    pub fn ports_remaining(&self) -> impl Iterator<Item = PortIndex<G::PortIndexBase>> + '_ {
        let visited = &self.visited_ports;
        self.graph
            .ports_iter()
            .filter(move |&port| !*visited.get(port))
    }

    /// Offers additional source nodes to the traversal.
    ///
    /// Nodes rejected by the node filter are skipped. For every other node,
    /// all ports facing against the traversal direction are marked as
    /// visited. The node is queued if at least one of those ports was
    /// previously unvisited. A node without any such ports is queued only if
    /// it is not already in the queue and still has an unvisited port in the
    /// traversal direction.
    pub fn add_sources(&mut self, sources: impl IntoIterator<Item = NodeIndex<G::NodeIndexBase>>) {
        for node in sources {
            if self.ignore_node(node) {
                continue;
            }

            let upstream = self.direction.reverse();

            // Mark every upstream port, remembering whether any was new.
            let mut new_candidate = false;
            for port in self.graph.ports(node, upstream) {
                let was_unvisited = !*self.visited_ports.get(port);
                new_candidate |= was_unvisited;
                self.visited_ports.set(port, true);
            }

            // With no upstream ports the check above cannot detect repeats,
            // so fall back to the queue contents and the downstream ports.
            if self.graph.num_ports(node, upstream) == 0 {
                let already_queued = self.candidate_nodes.contains(&node);
                new_candidate = !already_queued
                    && self
                        .graph
                        .ports(node, self.direction)
                        .any(|p| !*self.visited_ports.get(p));
            }

            if new_candidate {
                self.candidate_nodes.push_back(node);
            }
        }
    }

    /// Decides whether `node` becomes ready when reached through `from_port`.
    ///
    /// Returns `false` if the node is rejected by the node filter, if
    /// `from_port` was already visited, or if some other upstream port is
    /// unvisited, linked and accepted by the port filter. Upstream ports that
    /// are unlinked or rejected by the port filter are marked as visited
    /// along the way.
    fn becomes_ready(
        &mut self,
        node: NodeIndex<G::NodeIndexBase>,
        from_port: impl Into<PortIndex<G::PortIndexBase>>,
    ) -> bool {
        let from_port = from_port.into();
        if self.ignore_node(node) {
            return false;
        }

        // Collected first so the graph is not borrowed while `self` is mutated.
        let upstream_ports: Vec<_> = self.graph.ports(node, self.direction.reverse()).collect();
        for port in upstream_ports {
            let already_visited = *self.visited_ports.get(port);

            if port == from_port {
                // If the entry port was already visited, the node has been
                // handled through it before.
                if already_visited {
                    return false;
                }
                continue;
            }

            if already_visited {
                continue;
            }

            // An unvisited port only blocks the node if it is linked and not
            // ignored; the filter is consulted for linked ports only.
            let skippable =
                self.graph.port_link(port).is_none() || self.ignore_port(node, port);
            if !skippable {
                return false;
            }
            self.visited_ports.set(port, true);
        }
        true
    }

    /// Returns `true` if a node filter is set and it rejects `node`.
    #[inline]
    fn ignore_node(&mut self, node: NodeIndex<G::NodeIndexBase>) -> bool {
        match self.node_filter.as_mut() {
            Some(filter) => !filter(node),
            None => false,
        }
    }

    /// Returns `true` if a port filter is set and it rejects `port` of `node`.
    #[inline]
    fn ignore_port(
        &mut self,
        node: NodeIndex<G::NodeIndexBase>,
        port: PortIndex<G::PortIndexBase>,
    ) -> bool {
        match self.port_filter.as_mut() {
            Some(filter) => !filter(node, port),
            None => false,
        }
    }
}
impl<Map, G> Iterator for TopoSort<'_, G, Map>
where
    Map: SecondaryMap<PortIndex<G::PortIndexBase>, bool>,
    G: LinkView,
{
    type Item = NodeIndex<G::NodeIndexBase>;

    /// Yields the next queued node and queues any linked node that becomes
    /// ready as a result.
    ///
    /// Returns `None` when the candidate queue is empty.
    fn next(&mut self) -> Option<Self::Item> {
        let node = self.candidate_nodes.pop_front()?;

        // Collected first so the graph is not borrowed while `self` is mutated.
        let ports = self.graph.ports(node, self.direction).collect::<Vec<_>>();
        for port in ports {
            self.visited_ports.set(port, true);

            // Ignored ports are marked visited but their links are not followed.
            if self.ignore_port(node, port) {
                continue;
            }

            let linked_ports: SmallVec<[PortIndex<G::PortIndexBase>; 2]> = self
                .graph
                .port_links(port)
                .map(|(_, next_port)| next_port.into())
                .collect();

            for linked_port in linked_ports {
                let target = self
                    .graph
                    .port_node(linked_port)
                    .expect("a linked port must belong to a node of the graph");
                // Readiness is checked before marking the port, since
                // `becomes_ready` requires the entry port to be unvisited.
                if self.becomes_ready(target, linked_port) {
                    self.candidate_nodes.push_back(target);
                }
                self.visited_ports.set(linked_port, true);
            }
        }

        self.nodes_seen += 1;
        Some(node)
    }

    /// Lower bound: the nodes already queued. Upper bound: the nodes of the
    /// graph not yet yielded, but never less than the lower bound.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let queued = self.candidate_nodes.len();
        // A node can be yielded more than once (e.g. duplicated source
        // nodes), so `nodes_seen` may exceed the node count; a plain
        // subtraction would underflow.
        let unseen = self.graph.node_count().saturating_sub(self.nodes_seen);
        (queued, Some(unseen.max(queued)))
    }
}
/// Marks [`TopoSort`] as a fused iterator: `next` returns `None` whenever the
/// candidate queue is empty.
impl<Map, G> FusedIterator for TopoSort<'_, G, Map>
where
    G: LinkView,
    Map: SecondaryMap<PortIndex<G::PortIndexBase>, bool>,
{
}
/// Manual `Debug` implementation; the boxed filter closures are left out of
/// the output.
impl<Map, G> Debug for TopoSort<'_, G, Map>
where
    Map: Debug,
    G: Debug + PortView,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so that the omitted fields are spelled out explicitly.
        let Self {
            graph,
            visited_ports,
            candidate_nodes,
            direction,
            nodes_seen,
            node_filter: _,
            port_filter: _,
        } = self;

        let mut out = f.debug_struct("TopoSort");
        out.field("graph", graph);
        out.field("visited_ports", visited_ports);
        out.field("candidate_nodes", candidate_nodes);
        out.field("direction", direction);
        out.field("nodes_seen", nodes_seen);
        out.finish()
    }
}
#[cfg(test)]
mod test {
    use std::collections::BTreeSet;

    use super::*;
    use crate::{Direction, LinkMut, PortMut, PortView};

    type PortGraph = crate::PortGraph<u32, u32, u16>;
    type MultiPortGraph = crate::MultiPortGraph<u32, u32, u16>;

    /// Checks the order produced on a small graph, with and without filters.
    #[test]
    fn small_toposort() {
        let mut graph: PortGraph = PortGraph::new();
        let node_a = graph.add_node(2, 3);
        let node_b = graph.add_node(3, 2);
        let node_c = graph.add_node(3, 2);
        let node_d = graph.add_node(3, 2);
        let node_e = graph.add_node(3, 2);

        // (source node, output offset, target node, input offset)
        let links = [
            (node_a, 0, node_b, 0),
            (node_a, 1, node_b, 1),
            (node_a, 2, node_e, 0),
            (node_b, 0, node_c, 0),
            (node_c, 0, node_d, 0),
        ];
        for (from, from_offset, to, to_offset) in links {
            graph.link_nodes(from, from_offset, to, to_offset).unwrap();
        }

        // Unfiltered traversal from two source nodes.
        let topo: TopoSort<_> = toposort(&graph, [node_a, node_d], Direction::Outgoing);
        let order: Vec<_> = topo.collect();
        assert_eq!(order, [node_a, node_d, node_b, node_e, node_c]);

        // Filtered traversal: two nodes and one output port are excluded.
        let excluded_nodes = [node_d, node_e];
        let excluded_port = graph.output(node_b, 0);
        let topo_filtered: TopoSort<_> = toposort_filtered(
            &graph,
            [node_a, node_d],
            Direction::Outgoing,
            |n| !excluded_nodes.contains(&n),
            |_, p| Some(p) != excluded_port,
        );
        let filtered_order: Vec<_> = topo_filtered.collect();
        assert_eq!(filtered_order, [node_a, node_b]);
    }

    /// Checks that every target of a multiply-linked output port is reached.
    #[test]
    fn test_toposort_multi_port() {
        let mut graph: MultiPortGraph = MultiPortGraph::new();
        let node_a = graph.add_node(0, 1);
        let node_b = graph.add_node(1, 0);
        let node_c = graph.add_node(1, 0);

        // The single output port of `node_a` is linked to both other nodes.
        for target in [node_b, node_c] {
            graph.link_nodes(node_a, 0, target, 0).unwrap();
        }

        let topo: TopoSort<_> = toposort(&graph, [node_a], Direction::Outgoing);
        let visited: BTreeSet<_> = topo.collect();
        assert_eq!(visited, BTreeSet::from_iter([node_a, node_b, node_c]));
    }
}