1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3
/// Adjacency map from each node to the set of nodes it has an edge to.
pub type DepGraph<Node> = HashMap<Node, HashSet<Node>>;
7
8#[inline]
9fn add_edge<T>(graph: &mut DepGraph<T>, from: T, to: T)
10where
11 T: Eq + Hash + Clone,
12{
13 graph
14 .entry(from)
15 .and_modify(|deps| {
16 deps.insert(to.clone());
17 })
18 .or_insert_with(|| {
19 let mut deps = HashSet::new();
20 deps.insert(to);
21 deps
22 });
23}
24
25pub struct TopoSort<T> {
27 depends_on: DepGraph<T>,
28 dependents: DepGraph<T>,
29 no_deps: Vec<T>,
30}
31
32pub fn topo_sort<T, Id>(deps: T) -> Result<Vec<Id>, Box<dyn std::error::Error>>
50where
51 Id: Eq + Hash + Clone + std::fmt::Debug,
52 TopoSort<Id>: From<T>,
53{
54 let mut state = TopoSort::from(deps);
55 let mut sorted = vec![];
56
57 while let Some(node) = state.no_deps.pop() {
58 sorted.push(node.clone());
59
60 if let Some(dependents) = state.get_dependents(&node) {
61 for dependent in dependents.clone() {
62 state.resolve(&dependent, &node);
63 }
64 }
65 }
66
67 if state.is_resolved() {
68 Ok(sorted)
69 } else {
70 Err(format!(
71 "Cyclic reference detected for id: {:?}",
72 state.unresolved().collect::<Vec<_>>()
73 )
74 .into())
75 }
76}
77
78impl<T> Default for TopoSort<T> {
79 fn default() -> Self {
80 Self {
81 depends_on: HashMap::new(),
82 dependents: HashMap::new(),
83 no_deps: vec![],
84 }
85 }
86}
87
88impl<T: Eq + Hash + Clone> TopoSort<T> {
89 #[inline]
90 pub fn get_dependents(&self, dependency: &T) -> Option<&HashSet<T>> {
91 self.dependents.get(dependency)
92 }
93
94 #[inline]
95 pub fn is_resolved(&self) -> bool {
96 self.depends_on.is_empty()
97 }
98
99 #[inline]
100 pub fn resolve(&mut self, dependent: &T, dependency: &T) {
101 if let Some(dependencies) = self.depends_on.get_mut(dependent) {
102 dependencies.remove(dependency);
103 if dependencies.is_empty() {
104 self.no_deps.push(dependent.clone());
105
106 self.depends_on.remove(dependent);
107 }
108 }
109 }
110
111 #[inline]
112 pub fn unresolved(&self) -> impl Iterator<Item = &T> {
113 self.depends_on.keys()
114 }
115}
116
117impl From<&DepGraph<String>> for TopoSort<String> {
118 fn from(deps: &DepGraph<String>) -> TopoSort<String> {
119 let mut topo = TopoSort::default();
120 let mut nodes_being_depended_on = HashSet::new();
122
123 for (id, dependencies) in deps.iter() {
124 if dependencies.is_empty() {
125 topo.no_deps.push(id.clone());
126 } else {
127 for dependency_id in dependencies {
128 nodes_being_depended_on.insert(dependency_id.clone());
129 add_edge(&mut topo.depends_on, id.clone(), dependency_id.clone());
130 add_edge(&mut topo.dependents, dependency_id.clone(), id.clone());
131 }
132 }
133 }
134
135 for id in topo.depends_on.keys() {
137 nodes_being_depended_on.remove(id);
138 }
139
140 for id in nodes_being_depended_on {
142 topo.no_deps.push(id);
143 }
144
145 topo
146 }
147}