Skip to main content

oxihuman_core/
topological_sort.rs

1// Copyright (C) 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Topological sort (Kahn's algorithm) for DAGs.
5
6use std::collections::HashMap;
7
/// A directed graph over `u32` node identifiers, stored as a node list plus an edge list.
///
/// `nodes` holds explicitly registered nodes ([`topo_add_node`] keeps it free of
/// duplicates). `edges` holds `(from, to)` pairs; edge endpoints are not required to
/// appear in `nodes`.
#[allow(dead_code)]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TopoGraph {
    /// Registered node identifiers, in insertion order.
    pub nodes: Vec<u32>,
    /// Directed edges as `(from, to)` pairs, in insertion order; duplicates are kept.
    pub edges: Vec<(u32, u32)>,
}
14
/// Outcome of a topological sort over a [`TopoGraph`].
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TopoResult {
    /// Nodes in topological order. When some nodes cannot be placed, this holds only
    /// the nodes that were placed before progress stopped.
    pub order: Vec<u32>,
    /// Set by the sort when it could not place every node, which indicates a cycle.
    pub has_cycle: bool,
}
21
22#[allow(dead_code)]
23pub fn new_topo_graph() -> TopoGraph {
24    TopoGraph {
25        nodes: Vec::new(),
26        edges: Vec::new(),
27    }
28}
29
30#[allow(dead_code)]
31pub fn topo_add_node(g: &mut TopoGraph, node: u32) {
32    if !g.nodes.contains(&node) {
33        g.nodes.push(node);
34    }
35}
36
37#[allow(dead_code)]
38pub fn topo_add_edge(g: &mut TopoGraph, from: u32, to: u32) {
39    g.edges.push((from, to));
40}
41
42#[allow(dead_code)]
43pub fn topo_sort(g: &TopoGraph) -> TopoResult {
44    let mut in_degree: HashMap<u32, usize> = HashMap::new();
45    for &n in &g.nodes {
46        in_degree.entry(n).or_insert(0);
47    }
48    for &(_, to) in &g.edges {
49        *in_degree.entry(to).or_insert(0) += 1;
50    }
51
52    let mut queue: std::collections::VecDeque<u32> = in_degree
53        .iter()
54        .filter(|(_, &d)| d == 0)
55        .map(|(&n, _)| n)
56        .collect();
57    queue.make_contiguous().sort();
58
59    let mut order = Vec::new();
60    while let Some(node) = queue.pop_front() {
61        order.push(node);
62        let mut neighbors: Vec<u32> = g
63            .edges
64            .iter()
65            .filter(|&&(from, _)| from == node)
66            .map(|&(_, to)| to)
67            .collect();
68        neighbors.sort();
69        for to in neighbors {
70            let d = in_degree.entry(to).or_insert(0);
71            *d -= 1;
72            if *d == 0 {
73                queue.push_back(to);
74            }
75        }
76    }
77
78    let has_cycle = order.len() < g.nodes.len();
79    TopoResult { order, has_cycle }
80}
81
82#[allow(dead_code)]
83pub fn topo_node_count(g: &TopoGraph) -> usize {
84    g.nodes.len()
85}
86
87#[allow(dead_code)]
88pub fn topo_edge_count(g: &TopoGraph) -> usize {
89    g.edges.len()
90}
91
92#[allow(dead_code)]
93pub fn topo_has_cycle(g: &TopoGraph) -> bool {
94    topo_sort(g).has_cycle
95}
96
97#[allow(dead_code)]
98pub fn topo_remove_node(g: &mut TopoGraph, node: u32) {
99    g.nodes.retain(|&n| n != node);
100    g.edges.retain(|&(f, t)| f != node && t != node);
101}
102
103#[allow(dead_code)]
104pub fn topo_clear(g: &mut TopoGraph) {
105    g.nodes.clear();
106    g.edges.clear();
107}
108
/// Topologically sorts the nodes `0..n` using Kahn's algorithm.
///
/// Edges whose target is `>= n` are ignored. An edge whose source is `>= n` but whose
/// target is in range still counts toward the target's in-degree. That count is never
/// released, so the target can never be placed and the result is `None`.
///
/// Nodes that start with no incoming edges are processed in ascending index order.
/// Each dequeued node releases its successors in the order their edges appear in
/// `edges`.
///
/// Returns `Some(order)` holding all `n` nodes, or `None` if not every node could be
/// placed (a cycle, or the out-of-range-source case above).
///
/// Successor lists are built in one pass over `edges`, so the work is proportional to
/// `n + edges.len()` instead of rescanning every edge for each dequeued node.
#[allow(dead_code)]
pub fn topo_sort_dag(n: usize, edges: &[(usize, usize)]) -> Option<Vec<usize>> {
    let mut in_deg = vec![0usize; n];
    let mut succ: Vec<Vec<usize>> = vec![Vec::new(); n];
    for &(from, to) in edges {
        if to >= n {
            continue;
        }
        in_deg[to] += 1;
        // Sources outside 0..n are never dequeued, so they need no successor list.
        if from < n {
            succ[from].push(to);
        }
    }

    let mut queue: std::collections::VecDeque<usize> =
        (0..n).filter(|&i| in_deg[i] == 0).collect();
    let mut order = Vec::with_capacity(n);
    while let Some(node) = queue.pop_front() {
        order.push(node);
        for &to in &succ[node] {
            in_deg[to] -= 1;
            if in_deg[to] == 0 {
                queue.push_back(to);
            }
        }
    }

    if order.len() == n {
        Some(order)
    } else {
        None
    }
}
136
137#[allow(dead_code)]
138pub fn topo_has_cycle_dag(n: usize, edges: &[(usize, usize)]) -> bool {
139    topo_sort_dag(n, edges).is_none()
140}
141
142#[allow(dead_code)]
143pub fn topo_layer_count(n: usize, edges: &[(usize, usize)]) -> usize {
144    let mut dist = vec![0usize; n];
145    let order = match topo_sort_dag(n, edges) {
146        Some(o) => o,
147        None => return 0,
148    };
149    for &node in &order {
150        for &(from, to) in edges {
151            if from == node && to < n {
152                dist[to] = dist[to].max(dist[node] + 1);
153            }
154        }
155    }
156    dist.iter().max().copied().unwrap_or(0) + 1
157}
158
/// Returns, in ascending order, the nodes in `0..n` that have no incoming edge.
///
/// Edges whose target is `>= n` are ignored. The source of an edge is not checked, so
/// an edge from an out-of-range node still counts as incoming for its target.
#[allow(dead_code)]
pub fn topo_sources(n: usize, edges: &[(usize, usize)]) -> Vec<usize> {
    let mut has_incoming = vec![false; n];
    edges
        .iter()
        .filter(|&&(_, to)| to < n)
        .for_each(|&(_, to)| has_incoming[to] = true);
    has_incoming
        .iter()
        .enumerate()
        .filter(|(_, &blocked)| !blocked)
        .map(|(idx, _)| idx)
        .collect()
}
169
#[cfg(test)]
mod tests {
    // Unit tests covering both the `u32` graph API (`TopoGraph` / `topo_sort`) and the
    // index-based `usize` helpers (`topo_sort_dag`, `topo_layer_count`, `topo_sources`).
    use super::*;

    #[test]
    fn test_new_empty() {
        let g = new_topo_graph();
        assert_eq!(topo_node_count(&g), 0);
        assert_eq!(topo_edge_count(&g), 0);
    }

    #[test]
    fn test_add_node() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 1);
        topo_add_node(&mut g, 1); // duplicate
        assert_eq!(topo_node_count(&g), 1);
    }

    #[test]
    fn test_simple_sort() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 1);
        topo_add_node(&mut g, 2);
        topo_add_node(&mut g, 3);
        topo_add_edge(&mut g, 1, 2);
        topo_add_edge(&mut g, 2, 3);
        let res = topo_sort(&g);
        assert!(!res.has_cycle);
        assert_eq!(res.order, vec![1, 2, 3]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 1);
        topo_add_node(&mut g, 2);
        topo_add_edge(&mut g, 1, 2);
        topo_add_edge(&mut g, 2, 1);
        assert!(topo_has_cycle(&g));
    }

    #[test]
    fn test_no_cycle() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 10);
        topo_add_node(&mut g, 20);
        topo_add_edge(&mut g, 10, 20);
        assert!(!topo_has_cycle(&g));
    }

    #[test]
    fn test_remove_node() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 1);
        topo_add_node(&mut g, 2);
        topo_add_edge(&mut g, 1, 2);
        topo_remove_node(&mut g, 1);
        assert_eq!(topo_node_count(&g), 1);
        assert_eq!(topo_edge_count(&g), 0);
    }

    #[test]
    fn test_clear() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 1);
        topo_add_edge(&mut g, 1, 2);
        topo_clear(&mut g);
        assert_eq!(topo_node_count(&g), 0);
        assert_eq!(topo_edge_count(&g), 0);
    }

    #[test]
    fn test_edge_count() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 1);
        topo_add_node(&mut g, 2);
        topo_add_edge(&mut g, 1, 2);
        topo_add_edge(&mut g, 2, 1);
        assert_eq!(topo_edge_count(&g), 2);
    }

    #[test]
    fn test_diamond_dag() {
        let mut g = new_topo_graph();
        for n in [1, 2, 3, 4] {
            topo_add_node(&mut g, n);
        }
        topo_add_edge(&mut g, 1, 2);
        topo_add_edge(&mut g, 1, 3);
        topo_add_edge(&mut g, 2, 4);
        topo_add_edge(&mut g, 3, 4);
        let res = topo_sort(&g);
        assert!(!res.has_cycle);
        assert_eq!(res.order[0], 1);
        assert_eq!(*res.order.last().expect("should succeed"), 4);
    }

    #[test]
    fn test_single_node() {
        let mut g = new_topo_graph();
        topo_add_node(&mut g, 5);
        let res = topo_sort(&g);
        assert!(!res.has_cycle);
        assert_eq!(res.order, vec![5]);
    }

    #[test]
    fn test_sort_dag_chain_and_cycle() {
        assert_eq!(topo_sort_dag(3, &[(0, 1), (1, 2)]), Some(vec![0, 1, 2]));
        assert_eq!(topo_sort_dag(2, &[(0, 1), (1, 0)]), None);
        assert!(topo_has_cycle_dag(2, &[(0, 1), (1, 0)]));
        assert!(!topo_has_cycle_dag(3, &[(0, 1), (1, 2)]));
    }

    #[test]
    fn test_sort_dag_ignores_out_of_range_targets() {
        assert_eq!(topo_sort_dag(2, &[(0, 5)]), Some(vec![0, 1]));
    }

    #[test]
    fn test_layer_count() {
        assert_eq!(topo_layer_count(3, &[]), 1);
        assert_eq!(topo_layer_count(4, &[(0, 1), (0, 2), (1, 3), (2, 3)]), 3);
        assert_eq!(topo_layer_count(2, &[(0, 1), (1, 0)]), 0);
    }

    #[test]
    fn test_sources() {
        assert_eq!(topo_sources(4, &[(0, 1), (2, 3)]), vec![0, 2]);
        assert!(topo_sources(2, &[(0, 1), (1, 0)]).is_empty());
    }
}