use std::collections::BTreeMap;
/// Topologically sorts every node reachable from `node_ids` by following predecessor links.
///
/// `preds_fn(id)` must yield the predecessors of `id`, i.e. every node `p` with an edge
/// `p -> id`. It is called at most once per distinct node, the first time that node is
/// reached. Nodes reached only through `preds_fn` (not listed in `node_ids`) are part of the
/// result too, and repeated entries in `node_ids` are ignored.
///
/// # Returns
///
/// `Ok(order)` containing each visited node exactly once, every node placed after all of its
/// predecessors. Ties are broken by depth-first post-order over `node_ids` and the order in
/// which `preds_fn` yields predecessors.
///
/// # Errors
///
/// `Err(cycle)` for the first cycle encountered. The cycle is listed in edge direction:
/// `cycle[i]` is a predecessor of `cycle[i + 1]`, and the last node is a predecessor of the
/// first. A self-loop yields a one-element vector.
pub fn topo_sort<Id, NodeIds, PredsFn, PredsIter>(
    node_ids: NodeIds,
    mut preds_fn: PredsFn,
) -> Result<Vec<Id>, Vec<Id>>
where
    Id: Copy + Eq + Ord,
    NodeIds: IntoIterator<Item = Id>,
    PredsFn: FnMut(Id) -> PredsIter,
    PredsIter: IntoIterator<Item = Id>,
{
    // Visit state: `false` while a node is on the current DFS path, `true` once the node and
    // all of its predecessors have been emitted. Absent means "not visited yet".
    let mut marked: BTreeMap<Id, bool> = BTreeMap::new();
    let mut order = Vec::new();
    // Explicit DFS path (bottom = root). Each frame pairs a node with its not-yet-visited
    // predecessors. Keeping the path on the heap instead of recursing means very deep
    // predecessor chains cannot overflow the thread's call stack.
    let mut path: Vec<(Id, PredsIter::IntoIter)> = Vec::new();

    for root in node_ids {
        // The path is empty between roots, so an already-marked root is always finished.
        if marked.contains_key(&root) {
            continue;
        }
        marked.insert(root, false);
        path.push((root, preds_fn(root).into_iter()));

        while let Some((node, preds)) = path.last_mut() {
            let node = *node;
            match preds.next() {
                // Every predecessor is finished, so `node` can be emitted.
                None => {
                    path.pop();
                    order.push(node);
                    marked.insert(node, true);
                }
                Some(pred) => match marked.get(&pred).copied() {
                    Some(true) => {}
                    // `pred` is still on the path: the frames from `pred` up to the top form a
                    // cycle. Each frame's node is a predecessor of the frame below it, so
                    // walking them top-down lists the cycle in edge direction.
                    Some(false) => {
                        let start = path
                            .iter()
                            .position(|&(n, _)| n == pred)
                            .expect("an in-progress node is always on the DFS path");
                        return Err(path[start..].iter().rev().map(|&(n, _)| n).collect());
                    }
                    None => {
                        marked.insert(pred, false);
                        path.push((pred, preds_fn(pred).into_iter()));
                    }
                },
            }
        }
    }
    Ok(order)
}
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, BTreeSet};

    use itertools::Itertools;

    use super::*;

    /// Builds a `preds_fn` for [`topo_sort`] from a directed `(src, dst)` edge list: the
    /// predecessors of `v` are the sources of every edge ending at `v`, in list order.
    fn preds_of<T: Copy + PartialEq>(edges: &[(T, T)]) -> impl Fn(T) -> Vec<T> + '_ {
        move |v| {
            edges
                .iter()
                .filter(|&&(_, dst)| dst == v)
                .map(|&(src, _)| src)
                .collect()
        }
    }

    #[test]
    fn test_toposort() {
        let edges = [
            (5, 11),
            (11, 2),
            (11, 9),
            (11, 10),
            (7, 11),
            (7, 8),
            (8, 9),
            (3, 8),
            (3, 10),
        ];
        let ids = [2, 3, 5, 7, 8, 9, 10, 11];
        let sort = topo_sort(ids, preds_of(&edges))
            .unwrap_or_else(|cycle| panic!("Did not expect cycle: {:?}", cycle));
        // Every node is emitted exactly once.
        assert_eq!(sort.len(), ids.len());
        assert_eq!(
            sort.iter().copied().collect::<BTreeSet<_>>(),
            BTreeSet::from_iter(ids)
        );
        // Every edge points forward in the resulting order.
        let position: BTreeMap<_, _> = sort.iter().enumerate().map(|(i, &x)| (x, i)).collect();
        for (src, dst) in edges.iter() {
            assert!(
                position[src] < position[dst],
                "{} must precede {}: {:?}",
                src,
                dst,
                sort
            );
        }
    }

    #[test]
    fn test_toposort_cycle() {
        let edges = [
            ('A', 'B'),
            ('B', 'C'),
            ('C', 'E'),
            ('D', 'B'),
            ('E', 'F'),
            ('E', 'D'),
        ];
        let ids = edges
            .iter()
            .flat_map(|&(a, b)| [a, b])
            .collect::<BTreeSet<_>>();
        // The only cycle is B -> C -> E -> D -> B; any rotation of it is an acceptable report.
        let cycle_rotations = BTreeSet::from_iter([
            ['B', 'C', 'E', 'D'],
            ['C', 'E', 'D', 'B'],
            ['E', 'D', 'B', 'C'],
            ['D', 'B', 'C', 'E'],
        ]);
        for permutation in ids.iter().copied().permutations(ids.len()) {
            let cycle = topo_sort(permutation.iter().copied(), preds_of(&edges))
                .expect_err("graph contains the cycle B -> C -> E -> D -> B");
            assert!(
                cycle_rotations.contains(&*cycle),
                "cycle: {:?}, vertex order: {:?}",
                cycle,
                permutation
            );
        }
    }

    #[test]
    fn test_toposort_self_loop() {
        let edges = [(1, 2), (2, 2), (2, 3)];
        assert_eq!(topo_sort([3], preds_of(&edges)), Err(vec![2]));
    }

    #[test]
    fn test_toposort_includes_unlisted_predecessors() {
        // Only 3 is requested (twice); 1 and 2 are reached through predecessor links.
        let edges = [(1, 2), (2, 3)];
        assert_eq!(topo_sort([3, 3], preds_of(&edges)), Ok(vec![1, 2, 3]));
    }

    #[test]
    fn test_toposort_empty() {
        let edges: [(u32, u32); 0] = [];
        assert_eq!(topo_sort(Vec::new(), preds_of(&edges)), Ok(vec![]));
    }
}