prepona/algo/mst/
kruskal.rs1use std::cell::RefCell;
2use std::collections::HashSet;
3use std::rc::Rc;
4
5use crate::graph::{subgraph::Subgraph, Edge, UndirectedEdge};
6use crate::provide;
7
/// Kruskal's minimum-spanning-tree algorithm for undirected graphs.
///
/// Create the algorithm state with [`Kruskal::init`] and run it with
/// [`Kruskal::execute`], which consumes the state.
#[derive(Debug)]
pub struct Kruskal {
    /// Component table indexed by virtual (contiguous) vertex id: `sets[v]` is the set of
    /// virtual ids currently in the same component as `v`. All members of one component
    /// point at the same shared `Rc`, so `Rc` identity identifies the component.
    sets: Vec<Rc<RefCell<HashSet<usize>>>>,
}
64
65impl Kruskal {
66    pub fn init<G, W: Ord, E: Edge<W>>(graph: &G) -> Self
68    where
69        G: provide::Vertices + provide::Edges<W, E> + provide::Graph<W, E, UndirectedEdge>,
70    {
71        let vertex_count = graph.vertex_count();
72
73        let mut sets = vec![];
74        sets.resize_with(vertex_count, || Rc::new(RefCell::new(HashSet::new())));
75
76        for virt_id in 0..vertex_count {
77            sets[virt_id].borrow_mut().insert(virt_id);
78        }
79
80        Kruskal { sets }
81    }
82
83    pub fn execute<'a, G, W: Ord, E: Edge<W>>(
91        mut self,
92        graph: &'a G,
93    ) -> Subgraph<W, E, UndirectedEdge, G>
94    where
95        G: provide::Edges<W, E>
96            + provide::Neighbors
97            + provide::Vertices
98            + provide::Graph<W, E, UndirectedEdge>,
99    {
100        let mut mst = Vec::<(usize, usize, usize)>::new();
101
102        let id_map = graph.continuos_id_map();
103
104        let mut edges = graph.edges();
105
106        edges.sort_by(|(_, _, e1), (_, _, e2)| e1.get_weight().cmp(e2.get_weight()));
107
108        for (v_real_id, u_real_id, edge) in edges {
109            let v_virt_id = id_map.virt_id_of(v_real_id);
110            let u_virt_id = id_map.virt_id_of(u_real_id);
111
112            if !self.sets[v_virt_id]
113                .borrow()
114                .eq(&*self.sets[u_virt_id].borrow())
115            {
116                mst.push((v_real_id, u_real_id, edge.get_id()));
117
118                let union_set = self.sets[v_virt_id]
119                    .borrow()
120                    .union(&*self.sets[u_virt_id].borrow())
121                    .copied()
122                    .collect::<HashSet<usize>>();
123
124                let sharable_set = Rc::new(RefCell::new(union_set));
125
126                for member in sharable_set.borrow().iter() {
127                    self.sets[*member] = sharable_set.clone();
128                }
129            }
130        }
131
132        let vertices = mst
133            .iter()
134            .flat_map(|(src_id, dst_id, _)| vec![*src_id, *dst_id])
135            .collect::<HashSet<usize>>();
136
137        Subgraph::init(graph, mst, vertices)
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::MatGraph;
    use crate::provide::*;
    use crate::storage::Mat;

    #[test]
    fn empty_graph() {
        let graph = MatGraph::init(Mat::<usize>::init());

        let mst = Kruskal::init(&graph).execute(&graph);

        assert_eq!(mst.vertex_count(), 0);
        assert_eq!(mst.edges_count(), 0);
    }

    #[test]
    fn trivial_undirected_graph() {
        // Weighted undirected edges of the test graph:
        //   a-b: 1   a-c: 3   a-f: 3
        //   b-c: 5   b-d: 1
        //   d-c: 2   d-e: 4
        //   e-c: 1   e-f: 5
        // The unique MST is {a-b, b-d, e-c, d-c, a-f} with total weight 8.
        let mut graph = MatGraph::init(Mat::<usize>::init());
        let a = graph.add_vertex();
        let b = graph.add_vertex();
        let c = graph.add_vertex();
        let d = graph.add_vertex();
        let e = graph.add_vertex();
        let f = graph.add_vertex();

        let ab = graph.add_edge_unchecked(a, b, 1.into());
        graph.add_edge_unchecked(a, c, 3.into());
        let af = graph.add_edge_unchecked(a, f, 3.into());

        graph.add_edge_unchecked(b, c, 5.into());
        let bd = graph.add_edge_unchecked(b, d, 1.into());

        let dc = graph.add_edge_unchecked(d, c, 2.into());
        graph.add_edge_unchecked(d, e, 4.into());

        let ec = graph.add_edge_unchecked(e, c, 1.into());
        graph.add_edge_unchecked(e, f, 5.into());

        let mst = Kruskal::init(&graph).execute(&graph);

        assert_eq!(mst.vertex_count(), 6);
        assert_eq!(mst.edges_count(), 5);
        for (name, edge_id) in vec![("ab", ab), ("af", af), ("bd", bd), ("dc", dc), ("ec", ec)] {
            assert!(
                mst.edge(edge_id).is_ok(),
                "expected edge {} to be part of the MST",
                name
            );
        }
    }
}