use crate::types::NodeKind;
use rustc_hash::{FxHashMap, FxHashSet};
use std::collections::VecDeque;
/// Iterator over the `NodeKind` keys of a node map.
///
/// Wraps the key iterator of a `HashMap<NodeKind, Arc<dyn Node>>` and yields
/// each key by reference for the lifetime `'a` of the borrowed map.
pub struct NodesIter<'a> {
/// Underlying key iterator; `next` and `size_hint` are forwarded to it.
inner: std::collections::hash_map::Keys<'a, NodeKind, std::sync::Arc<dyn crate::node::Node>>,
}
impl<'a> NodesIter<'a> {
pub(super) fn new(
inner: std::collections::hash_map::Keys<
'a,
NodeKind,
std::sync::Arc<dyn crate::node::Node>,
>,
) -> Self {
Self { inner }
}
}
impl<'a> Iterator for NodesIter<'a> {
    type Item = &'a NodeKind;

    /// Yields the next node key, or `None` once the wrapped key iterator is
    /// exhausted.
    fn next(&mut self) -> Option<&'a NodeKind> {
        let keys = &mut self.inner;
        keys.next()
    }

    /// Reports the bounds of the wrapped key iterator unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let keys = &self.inner;
        keys.size_hint()
    }
}
// `NodesIter::size_hint` forwards the wrapped `hash_map::Keys` hint unchanged,
// and `Keys` is itself an `ExactSizeIterator`, so the default `len()` is exact.
impl<'a> ExactSizeIterator for NodesIter<'a> {}
/// Iterator over every `(from, to)` edge of an adjacency map.
///
/// The map stores, for each source node, a `Vec` of its target nodes; this
/// iterator flattens that structure into individual edge pairs.
pub struct EdgesIter<'a> {
/// Remaining `(source, targets)` entries of the adjacency map.
outer: std::collections::hash_map::Iter<'a, NodeKind, Vec<NodeKind>>,
/// Source node whose targets are currently being walked; `None` only when
/// the map had no entries at construction time.
current_from: Option<&'a NodeKind>,
/// Targets of `current_from` that have not been yielded yet.
current_targets: std::slice::Iter<'a, NodeKind>,
}
impl<'a> EdgesIter<'a> {
    /// Creates an edge iterator over `edges`, positioned on the first map
    /// entry (if any).
    ///
    /// For an empty map there is no current source and the target cursor is
    /// an empty slice iterator, so the first call to `next` returns `None`.
    pub(super) fn new(edges: &'a FxHashMap<NodeKind, Vec<NodeKind>>) -> Self {
        let mut outer = edges.iter();
        // Pull the first entry eagerly so `next` always has a target cursor.
        let first = outer.next();
        let current_from = first.map(|(from, _)| from);
        let current_targets = match first {
            Some((_, targets)) => targets.iter(),
            None => [].iter(),
        };
        Self {
            outer,
            current_from,
            current_targets,
        }
    }
}
impl<'a> Iterator for EdgesIter<'a> {
    type Item = (&'a NodeKind, &'a NodeKind);

    /// Yields the next `(from, to)` pair.
    ///
    /// Targets of the current source are drained first; when they run out the
    /// iterator advances to the next map entry, skipping sources whose target
    /// list is empty. Returns `None` once the map entries are exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.current_targets.next() {
                // A target can only be produced after a source has been
                // recorded, so `current_from` is always `Some` here.
                Some(to) => return Some((self.current_from.unwrap(), to)),
                None => {
                    // Move on to the next source, or stop when none is left.
                    let (from, targets) = self.outer.next()?;
                    self.current_from = Some(from);
                    self.current_targets = targets.iter();
                }
            }
        }
    }
}
/// Tie-break order for nodes that become ready at the same time.
///
/// `Start` sorts before every other node, `End` sorts after every other node,
/// and `Custom` nodes are ordered by their payload. Identical sentinels
/// compare as `Equal`, which makes this a proper total order as required by
/// `slice::sort_by`.
fn compare_ready_nodes(a: &NodeKind, b: &NodeKind) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    match (a, b) {
        (NodeKind::Start, NodeKind::Start) | (NodeKind::End, NodeKind::End) => Ordering::Equal,
        (NodeKind::Start, _) | (_, NodeKind::End) => Ordering::Less,
        (_, NodeKind::Start) | (NodeKind::End, _) => Ordering::Greater,
        (NodeKind::Custom(a_name), NodeKind::Custom(b_name)) => a_name.cmp(b_name),
    }
}

/// Returns the nodes of `edges` in a deterministic topological order.
///
/// `edges` maps each source node to the list of nodes it points to. The sort
/// repeatedly emits nodes whose in-degree has dropped to zero, processing them
/// in FIFO order. Nodes that become ready together (the initial roots, or the
/// successors released by one node) are ordered with `Start` first, `Custom`
/// nodes by name, and `End` last, so the output does not depend on hash-map
/// iteration order.
///
/// Nodes that lie on a cycle (or are only reachable through one) never reach
/// in-degree zero and are therefore omitted, so the result is shorter than the
/// number of distinct nodes in that case.
pub(super) fn topological_sort(edges: &FxHashMap<NodeKind, Vec<NodeKind>>) -> Vec<NodeKind> {
    let mut in_degree: FxHashMap<NodeKind, usize> = FxHashMap::default();
    // Distinct nodes seen as a source or a target; used to pre-size the output.
    let mut all_nodes: FxHashSet<NodeKind> = FxHashSet::default();
    for (from, tos) in edges {
        all_nodes.insert(from.clone());
        // Make sure pure sources appear in the table with in-degree zero.
        in_degree.entry(from.clone()).or_insert(0);
        for to in tos {
            all_nodes.insert(to.clone());
            *in_degree.entry(to.clone()).or_insert(0) += 1;
        }
    }

    // Seed the queue with every root, in tie-break order.
    let mut roots: Vec<NodeKind> = in_degree
        .iter()
        .filter(|(_, degree)| **degree == 0)
        .map(|(node, _)| node.clone())
        .collect();
    roots.sort_by(compare_ready_nodes);
    let mut queue: VecDeque<NodeKind> = VecDeque::from(roots);

    let mut result: Vec<NodeKind> = Vec::with_capacity(all_nodes.len());
    while let Some(node) = queue.pop_front() {
        if let Some(neighbors) = edges.get(&node) {
            let mut newly_ready: Vec<NodeKind> = Vec::new();
            for neighbor in neighbors {
                if let Some(degree) = in_degree.get_mut(neighbor) {
                    *degree = degree.saturating_sub(1);
                    if *degree == 0 {
                        newly_ready.push(neighbor.clone());
                    }
                }
            }
            // Order the successors released by this node before enqueueing.
            newly_ready.sort_by(compare_ready_nodes);
            queue.extend(newly_ready);
        }
        // The node is no longer needed for lookups, so move it into the output.
        result.push(node);
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NodeKind::Custom` node from a string literal.
    fn custom(name: &'static str) -> NodeKind {
        NodeKind::Custom(name.into())
    }

    /// Returns the index of `wanted` in `order`, panicking if it is absent.
    fn index_of(order: &[NodeKind], wanted: &NodeKind) -> usize {
        order.iter().position(|node| node == wanted).unwrap()
    }

    /// Start -> A -> B -> End must come out in exactly that relative order.
    #[test]
    fn test_topological_sort_linear() {
        let mut edges: FxHashMap<NodeKind, Vec<NodeKind>> = FxHashMap::default();
        edges.insert(NodeKind::Start, vec![custom("A")]);
        edges.insert(custom("A"), vec![custom("B")]);
        edges.insert(custom("B"), vec![NodeKind::End]);

        let sorted = topological_sort(&edges);

        assert_eq!(sorted.first(), Some(&NodeKind::Start));
        assert_eq!(sorted.last(), Some(&NodeKind::End));
        assert!(index_of(&sorted, &custom("A")) < index_of(&sorted, &custom("B")));
    }

    /// Start fans out to A and B, which both feed C; siblings are ordered by name.
    #[test]
    fn test_topological_sort_diamond() {
        let mut edges: FxHashMap<NodeKind, Vec<NodeKind>> = FxHashMap::default();
        edges.insert(NodeKind::Start, vec![custom("A"), custom("B")]);
        edges.insert(custom("A"), vec![custom("C")]);
        edges.insert(custom("B"), vec![custom("C")]);
        edges.insert(custom("C"), vec![NodeKind::End]);

        let sorted = topological_sort(&edges);

        assert_eq!(sorted.first(), Some(&NodeKind::Start));
        assert_eq!(sorted.last(), Some(&NodeKind::End));

        let a_index = index_of(&sorted, &custom("A"));
        let b_index = index_of(&sorted, &custom("B"));
        let c_index = index_of(&sorted, &custom("C"));
        assert!(a_index < c_index);
        assert!(b_index < c_index);
        assert!(a_index < b_index);
    }

    /// Two runs over the same graph must produce identical output.
    #[test]
    fn test_topological_sort_deterministic() {
        let mut edges: FxHashMap<NodeKind, Vec<NodeKind>> = FxHashMap::default();
        edges.insert(
            NodeKind::Start,
            vec![custom("X"), custom("Y"), custom("Z")],
        );
        for name in ["X", "Y", "Z"] {
            edges.insert(custom(name), vec![NodeKind::End]);
        }

        let first_run = topological_sort(&edges);
        let second_run = topological_sort(&edges);

        assert_eq!(first_run, second_run);
    }
}