rk_utils/
topo_sort.rs

1use std::collections::{HashMap, HashSet};
2use std::hash::Hash;
3
/// DepGraph is a dependency graph stored as an adjacency map.
/// Each key is a node and its value is the `HashSet` of nodes it depends on.
pub type DepGraph<T> = HashMap<T, HashSet<T>>;
7
8#[inline]
9fn add_edge<T>(graph: &mut DepGraph<T>, from: T, to: T)
10where
11    T: Eq + Hash + Clone,
12{
13    graph
14        .entry(from)
15        .and_modify(|deps| {
16            deps.insert(to.clone());
17        })
18        .or_insert_with(|| {
19            let mut deps = HashSet::new();
20            deps.insert(to);
21            deps
22        });
23}
24
/// TopoSort holds the working state of a topological sort.
pub struct TopoSort<T> {
    /// For each unresolved node, the dependencies it is still waiting on.
    /// An entry is removed once its set becomes empty.
    depends_on: DepGraph<T>,
    /// Reverse edges: for each dependency, the nodes that depend on it.
    dependents: DepGraph<T>,
    /// Nodes with no outstanding dependencies, waiting to be emitted.
    no_deps: Vec<T>,
}
31
32/// topo_sort sorts the dependencies topologically.
33/// It returns a Vec of the sorted nodes.
34/// If a cyclic reference is detected, it returns an error.
35/// 
36/// # Example
37/// 
38/// ```rust
39/// use rk_utils::topo_sort;
40/// use std::collections::{ HashMap, HashSet };
41/// 
42/// let mut deps = HashMap::new();
43/// deps.insert("b".to_string(), HashSet::from(["a".to_string()]));
44/// deps.insert("c".to_string(), HashSet::from(["b".to_string()]));
45/// 
46/// let sorted = topo_sort(&deps).unwrap();
47/// assert_eq!(sorted, ["a", "b", "c"]);
48/// ```
49pub fn topo_sort<T, Id>(deps: T) -> Result<Vec<Id>, Box<dyn std::error::Error>>
50where
51    Id: Eq + Hash + Clone + std::fmt::Debug,
52    TopoSort<Id>: From<T>,
53{
54    let mut state = TopoSort::from(deps);
55    let mut sorted = vec![];
56
57    while let Some(node) = state.no_deps.pop() {
58        sorted.push(node.clone());
59
60        if let Some(dependents) = state.get_dependents(&node) {
61            for dependent in dependents.clone() {
62                state.resolve(&dependent, &node);
63            }
64        }
65    }
66
67    if state.is_resolved() {
68        Ok(sorted)
69    } else {
70        Err(format!(
71            "Cyclic reference detected for id: {:?}",
72            state.unresolved().collect::<Vec<_>>()
73        )
74        .into())
75    }
76}
77
78impl<T> Default for TopoSort<T> {
79    fn default() -> Self {
80        Self {
81            depends_on: HashMap::new(),
82            dependents: HashMap::new(),
83            no_deps: vec![],
84        }
85    }
86}
87
88impl<T: Eq + Hash + Clone> TopoSort<T> {
89    #[inline]
90    pub fn get_dependents(&self, dependency: &T) -> Option<&HashSet<T>> {
91        self.dependents.get(dependency)
92    }
93
94    #[inline]
95    pub fn is_resolved(&self) -> bool {
96        self.depends_on.is_empty()
97    }
98
99    #[inline]
100    pub fn resolve(&mut self, dependent: &T, dependency: &T) {
101        if let Some(dependencies) = self.depends_on.get_mut(dependent) {
102            dependencies.remove(dependency);
103            if dependencies.is_empty() {
104                self.no_deps.push(dependent.clone());
105
106                self.depends_on.remove(dependent);
107            }
108        }
109    }
110
111    #[inline]
112    pub fn unresolved(&self) -> impl Iterator<Item = &T> {
113        self.depends_on.keys()
114    }
115}
116
117impl From<&DepGraph<String>> for TopoSort<String> {
118    fn from(deps: &DepGraph<String>) -> TopoSort<String> {
119        let mut topo = TopoSort::default();
120        // only track nodes that are being depended on but not appearing in id field of the deps
121        let mut nodes_being_depended_on = HashSet::new();
122
123        for (id, dependencies) in deps.iter() {
124            if dependencies.is_empty() {
125                topo.no_deps.push(id.clone());
126            } else {
127                for dependency_id in dependencies {
128                    nodes_being_depended_on.insert(dependency_id.clone());
129                    add_edge(&mut topo.depends_on, id.clone(), dependency_id.clone());
130                    add_edge(&mut topo.dependents, dependency_id.clone(), id.clone());
131                }
132            }
133        }
134
135        // remove those id that already on topo.depends_on from nodes_being_depended_on
136        for id in topo.depends_on.keys() {
137            nodes_being_depended_on.remove(id);
138        }
139
140        // move the rest of the nodes to topo.no_deps
141        for id in nodes_being_depended_on {
142            topo.no_deps.push(id);
143        }
144
145        topo
146    }
147}