walky/
mst.rs

//! Compute a minimum spanning tree (MST) of a graph with Prim's algorithm.
2
3use core::panic;
4use std::{
5    cmp::Reverse,
6    ops::{Deref, DerefMut},
7};
8
9use crate::{
10    computation_mode::*,
11    datastructures::{AdjacencyMatrix, Edge, Graph, NAMatrix},
12};
13
14use delegate::delegate;
15use nalgebra::{Dyn, U1};
16use ordered_float::OrderedFloat;
17use priority_queue::PriorityQueue;
18use rayon::prelude::*;
19
20/// Prims algorithm for computing an MST of the given `graph`.
21///
22/// `MODE`: constant parameter, choose one of the values from [`crate::computation_mode`]
23///
24/// See [`prim_with_excluded_node_single_threaded`] for more details.
25pub fn prim<const MODE: usize>(graph: &NAMatrix) -> Graph {
26    match MODE {
27        SEQ_COMPUTATION => prim_with_excluded_node_single_threaded(graph, &[]),
28        PAR_COMPUTATION => prim_with_excluded_node_multi_threaded(graph, &[]),
29        #[cfg(feature = "mpi")]
30        MPI_COMPUTATION => {
31            eprintln!("Warning: defaulting to sequential implementation of prims algorithm");
32            prim::<SEQ_COMPUTATION>(graph)
33        }
34        _ => panic_on_invaid_mode::<MODE>(),
35    }
36}
37
38/// multithreaded version of [`prim_with_excluded_node_single_threaded`].
39///
40/// If you have multiple calls to prims algorithm, use a single threaded version
41/// and make the calls in parallel.
42pub fn prim_with_excluded_node_multi_threaded(
43    graph: &NAMatrix,
44    excluded_vertices: &[usize],
45) -> Graph {
46    prim_with_excluded_node::<MultiThreadedVecWrapper>(graph, excluded_vertices)
47}
48
49/// greedy algorithm:
50/// start at the first vertex in the graph and build an MST step by step.
51///
52/// `excluded_vertices`: option to exclude vertices from the graph and thus the MST computation.
53///     If you do not want to exclude a vertex from the computation, chose
54///     `excluded_vertex = &[])` (the function [`prim`] does this for you).
55///
56/// naive version using only vectors as data structures.
57/// For small enough (might not have to be very small) inputs
58/// this is faster than a priority queue due to
59/// less branching and better auto-vectorization potential.
60/// Asymptotic performance: O(N^2)
61pub fn prim_with_excluded_node_single_threaded(
62    graph: &NAMatrix,
63    excluded_vertices: &[usize],
64) -> Graph {
65    prim_with_excluded_node::<Vec<(Edge, bool)>>(graph, excluded_vertices)
66}
67
68/// improve asymptotic performance (compared to [`prim_with_excluded_node_single_threaded`])
69/// by using a priority queue
70pub fn prim_with_excluded_node_priority_queue(
71    graph: &NAMatrix,
72    excluded_vertices: &[usize],
73) -> Graph {
74    prim_with_excluded_node::<VerticesInPriorityQueue>(graph, excluded_vertices)
75}
76
77/// greedy algorithm:
78/// start at the first vertex in the graph and build an MST step by step.
79///
80/// `excluded_vertices`: option to exclude vertices from the graph and thus the MST computation.
81///     If you do not want to exclude a vertex from the computation, chose
82///     `excluded_vertex = &[])` (the function [`prim`] does this for you).
83fn prim_with_excluded_node<D: FindMinCostEdge>(
84    graph: &NAMatrix,
85    excluded_vertices: &[usize],
86) -> Graph {
87    let num_vertices = graph.dim();
88    let unconnected_node = num_vertices;
89
90    // stores our current MST
91    let mut mst_adj_list: Vec<Vec<Edge>> = vec![Vec::new(); num_vertices + 1];
92
93    // `dist_from_mst[i]` stores the edge with that the vertex i can be connected to the MST
94    // with minimal cost.
95    let mut dist_from_mst = D::from_default_value(
96        // base case: every vertex is "connected" to the unconnected node with cost f64::INFINITY
97        Edge {
98            cost: f64::INFINITY,
99            to: unconnected_node,
100        },
101        num_vertices + 1,
102    );
103
104    // Vertex at index unconnected_node is special: it is not connected to the rest of the graph,
105    // and has distance INFINITY to every other node.
106    // It is used as a base case.
107
108    // start with vertex 0, or with vertex 1 if vertex 0 shall be excluded
109    let start_index = {
110        let mut idx = 0;
111        while excluded_vertices.contains(&idx) {
112            idx += 1;
113        }
114        if idx >= num_vertices {
115            // all vertices are excluded --> empty MST
116            return vec![].into();
117        }
118        idx
119    };
120
121    dist_from_mst.set_cost(
122        start_index,
123        Edge {
124            to: start_index,
125            cost: 0.,
126        },
127    );
128    for &vertex in excluded_vertices {
129        dist_from_mst.set_excluded_vertex(vertex)
130    }
131
132    // iterate over maximally `num_vertices` many iterations (for every vertex one)
133    for _ in 0..=num_vertices {
134        let (next_vertex, next_edge) = dist_from_mst.find_edge_with_minimal_cost();
135
136        // when we reach an unreachable vertex (like index num_vertices),
137        // we are finished
138        if next_edge.cost == f64::INFINITY {
139            break;
140        }
141
142        // add next_vertex to the mst
143        dist_from_mst.mark_vertex_as_used(next_vertex);
144        if next_vertex != start_index {
145            //let connecting_edge = dist_from_mst[next_vertex].clone();
146            let reverse_edge = Edge {
147                to: next_vertex,
148                cost: next_edge.cost,
149            };
150            let connection_from = next_edge.to;
151            let connection_to = next_vertex;
152            mst_adj_list[connection_to].push(next_edge);
153            mst_adj_list[connection_from].push(reverse_edge);
154        }
155
156        // update the minimal connection costs foll all newly adjacent vertices
157        //for edge in graph[next_vertex].iter() {
158        //for (to, &cost) in graph.row(next_vertex).iter().enumerate() {
159        //    dist_from_mst.update_minimal_cost(next_vertex, Edge { to, cost });
160        //}
161        dist_from_mst.update_minimal_cost(next_vertex, graph.row(next_vertex))
162    }
163
164    // remove the last entry (for unreachable_vertex) as it is only relevant for the algorithm
165    mst_adj_list.pop();
166    Graph::from(mst_adj_list)
167}
168
169type NAMatrixRowView<'a> =
170    nalgebra::Matrix<f64, U1, Dyn, nalgebra::ViewStorage<'a, f64, U1, Dyn, U1, Dyn>>;
171
172/// This trait reflects a datastructure,
173/// that holds Edges and can give back the edge with minimal cost,
174/// as well as update the cost of edges.
175trait FindMinCostEdge {
176    fn from_default_value(default_val: Edge, size: usize) -> Self;
177
178    /// Get the index of the vertex that is currently not in the MST
179    /// and has minimal cost to connect to the mst, as well as the
180    /// corresponding connecting edge to the MST.
181    fn find_edge_with_minimal_cost(&self) -> (usize, Edge);
182    /// update the connection cost of `edge_to.to`.
183    /// If `edge_to.cost` is less than the current cost, the cost decreases to
184    /// `edge_to.cost` and `from` gets saved as the connecting vertex.
185    /// If it is higher, the cost does *not* increase.
186    /// If provided with the edge `from --> edge_to.to`,
187    /// the structure will then possibly remember the reverse edge `from <-- edge_to.to`
188    fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView);
189
190    /// sets the cost of connecting from `from` to `edge_to.to` to the value `edge_to.cost`.
191    fn set_cost(&mut self, from: usize, edge_to: Edge);
192
193    /// sets which vertex to exclude/ignore in the computations
194    fn set_excluded_vertex(&mut self, excluded_vertex: usize);
195
196    fn mark_vertex_as_used(&mut self, used_vertex: usize);
197}
198
199#[derive(Clone, Debug, PartialEq)]
200struct VerticesInPriorityQueue {
201    /// stores the vertices that are not currently in the MST,
202    /// can efficiently find the vertex with minimal connection cost to the MST
203    cost_queue: PriorityQueue<usize, Reverse<OrderedFloat<f64>>>,
204    /// implements the following map:
205    /// given a vertex `i`, the minimal cost edge to the
206    /// MST is to the vertex `j == connection_to_mst[i]`
207    connection_to_mst: Vec<usize>,
208    /// `used[i]`: vertex `i` is already part of the MST
209    used: Vec<bool>,
210}
211impl FindMinCostEdge for VerticesInPriorityQueue {
212    fn from_default_value(default_val: Edge, size: usize) -> Self {
213        VerticesInPriorityQueue {
214            cost_queue: PriorityQueue::from(
215                (0..size)
216                    .map(|i| (i, Reverse(OrderedFloat(default_val.cost))))
217                    .collect::<Vec<(usize, Reverse<OrderedFloat<f64>>)>>(),
218            ),
219            connection_to_mst: vec![default_val.to; size],
220            used: vec![false; size],
221        }
222    }
223
224    fn find_edge_with_minimal_cost(&self) -> (usize, Edge) {
225        let base_case = Edge {
226            to: self.connection_to_mst.len(),
227            cost: f64::INFINITY,
228        };
229        let (&next_vertex, &Reverse(OrderedFloat(cost))) = self
230            .cost_queue
231            .peek()
232            .unwrap_or((&base_case.to, &Reverse(OrderedFloat(base_case.cost))));
233        let to = self.connection_to_mst[next_vertex];
234
235        (next_vertex, Edge { to, cost })
236    }
237
238    fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView) {
239        for (to, &cost) in new_neighbours.iter().enumerate() {
240            if self.used[to] {
241                continue;
242            }
243            let Reverse(OrderedFloat(old_cost)) = self.cost_queue
244            .push_increase(to, Reverse(OrderedFloat(cost)))
245            .unwrap_or_else(|| panic!("Every unused unused vertex shall be contained in the queue from the beginning. Missing vertex: {}", to));
246            if cost <= old_cost {
247                self.connection_to_mst[to] = from;
248            }
249        }
250    }
251
252    fn set_cost(&mut self, from: usize, edge_to: Edge) {
253        self.cost_queue
254            .change_priority(&from, Reverse(OrderedFloat(edge_to.cost)));
255
256        self.connection_to_mst[from] = edge_to.to;
257    }
258
259    fn set_excluded_vertex(&mut self, excluded_vertex: usize) {
260        self.mark_vertex_as_used(excluded_vertex);
261    }
262
263    fn mark_vertex_as_used(&mut self, used_vertex: usize) {
264        self.cost_queue.remove(&used_vertex);
265        self.used[used_vertex] = true;
266    }
267}
268
269/// Edge: holds the (currently minimal) connection cost,
270/// and the vertex to which to connect to the MST
271///
272/// bool: true, if the Vertex is in the MST, false if the vertex is not in the MST.
273impl FindMinCostEdge for Vec<(Edge, bool)> {
274    fn from_default_value(default_val: Edge, size: usize) -> Self {
275        vec![(default_val, false); size]
276    }
277
278    fn find_edge_with_minimal_cost(&self) -> (usize, Edge) {
279        let base_case = Edge {
280            to: self.len(),
281            cost: f64::INFINITY,
282        };
283        let (next_vertex, reverse_edge) = self
284            .iter()
285            .enumerate()
286            // skip all used vertices
287            .filter_map(
288                |(i, &(edge, used_in_mst))| if used_in_mst { None } else { Some((i, edge)) },
289            )
290            // find the next vertex via the corresponding edge with minimal cost
291            .min_by(|&(_, edg_i), &(_, edg_j)| {
292                OrderedFloat(edg_i.cost).cmp(&OrderedFloat(edg_j.cost))
293            })
294            // unwrap, or give back the base case
295            .unwrap_or((base_case.to, base_case));
296        (next_vertex, reverse_edge)
297    }
298
299    fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView) {
300        //self[to] = f64::min(self[to], edge.cost);
301        for (to, &cost) in new_neighbours.iter().enumerate() {
302            if cost < self[to].0.cost {
303                self[to].0 = Edge { to: from, cost };
304            }
305        }
306    }
307
308    fn set_cost(&mut self, from: usize, edge_to: Edge) {
309        self[from].0 = edge_to;
310    }
311
312    fn mark_vertex_as_used(&mut self, used_vertex: usize) {
313        self[used_vertex].1 = true;
314    }
315
316    fn set_excluded_vertex(&mut self, excluded_vertex: usize) {
317        self.mark_vertex_as_used(excluded_vertex);
318    }
319}
320
321#[derive(Debug, PartialEq)]
322struct MultiThreadedVecWrapper(Vec<(Edge, bool)>);
323
324impl Deref for MultiThreadedVecWrapper {
325    type Target = Vec<(Edge, bool)>;
326    fn deref(&self) -> &Self::Target {
327        &self.0
328    }
329}
330impl DerefMut for MultiThreadedVecWrapper {
331    fn deref_mut(&mut self) -> &mut Self::Target {
332        &mut self.0
333    }
334}
335
336impl FindMinCostEdge for MultiThreadedVecWrapper {
337    fn from_default_value(default_val: Edge, size: usize) -> Self {
338        MultiThreadedVecWrapper(Vec::from_default_value(default_val, size))
339    }
340    delegate! {
341        to self.0 {
342            fn set_cost(&mut self, from: usize, edge_to: Edge);
343            fn set_excluded_vertex(&mut self, excluded_vertex: usize);
344            fn mark_vertex_as_used(&mut self, used_vertex: usize);
345        }
346    }
347
348    fn update_minimal_cost(&mut self, from: usize, new_neighbours: NAMatrixRowView) {
349        //self[to] = f64::min(self[to], edge.cost);
350        let dim = new_neighbours.shape().1;
351        //for (to, &cost) in new_neighbours.par_iter().enumerate()
352        (0..dim).into_par_iter().for_each(|to| {
353            let neighbour_prt = new_neighbours.as_ptr() as *mut f64;
354            // safety: the data exists, we do not leave the range
355            // of the underlying NAMatrix (we add at most dim*(dim-1),
356            // and the pointer to the row has at most offset dim-1 from the cell at index (0,0).
357            // Therefore we stay within an offset of (dim*dim)-1
358            let cost = unsafe { *neighbour_prt.add(dim * to) };
359            let to_dist_ptr = self.as_ptr() as *mut (Edge, bool);
360            if cost < self[to].0.cost {
361                // safety:
362                //  - no race conditions, since the parallel iterator visits each value of to
363                //    exactly once
364                //  - we do not exeed the length of the vector self.0
365                unsafe {
366                    (*to_dist_ptr.add(to)).0 = Edge { to: from, cost };
367                }
368            }
369        });
370    }
371
372    fn find_edge_with_minimal_cost(&self) -> (usize, Edge) {
373        let base_case = Edge {
374            to: self.0.len(),
375            cost: f64::INFINITY,
376        };
377        let (next_vertex, reverse_edge) = self
378            .0
379            .par_iter()
380            .enumerate()
381            // skip all used vertices
382            .filter_map(
383                |(i, &(edge, used_in_mst))| if used_in_mst { None } else { Some((i, edge)) },
384            )
385            // find the next vertex via the corresponding edge with minimal cost
386            .min_by(|&(_, edg_i), &(_, edg_j)| {
387                OrderedFloat(edg_i.cost).cmp(&OrderedFloat(edg_j.cost))
388            })
389            // unwrap, or give back the base case
390            .unwrap_or((base_case.to, base_case));
391        (next_vertex, reverse_edge)
392    }
393}
394
#[cfg(test)]
mod test {
    use std::assert_eq;

    use nalgebra::DMatrix;

    use super::*;

    /// Shorthand for building an [`Edge`].
    fn edge(to: usize, cost: f64) -> Edge {
        Edge { to, cost }
    }

    /// Complete graph on four vertices that is shared by several tests:
    ///
    /// 0 ----- 1
    /// |\     /|
    /// | \   / |
    /// |  \ /  |
    /// |   X   |
    /// |  / \  |
    /// | /   \ |
    /// |/     \|
    /// 3 ----- 2
    ///
    /// Note: the edge between 1 and 2 is listed with cost 5.0 at vertex 1,
    /// but with cost 1.1 at vertex 2.
    fn four_vertex_graph() -> Graph {
        Graph::from(vec![
            //vertex 0
            vec![edge(1, 1.0), edge(2, 0.1), edge(3, 2.0)],
            //vertex 1
            vec![edge(0, 1.0), edge(2, 5.0), edge(3, 0.1)],
            //vertex 2
            vec![edge(0, 0.1), edge(1, 1.1), edge(3, 0.1)],
            //vertex 3
            vec![edge(0, 2.0), edge(1, 0.1), edge(2, 0.1)],
        ])
    }

    /// Queue based structure for five vertices,
    /// every vertex initially connected to `to` with infinite cost.
    fn five_vertex_queue(to: usize) -> VerticesInPriorityQueue {
        VerticesInPriorityQueue::from_default_value(edge(to, f64::INFINITY), 5)
    }

    #[test]
    fn easy_prim() {
        let graph = Graph::from(vec![vec![edge(1, 1.0)], vec![edge(0, 1.0)]]);

        let mst = prim::<SEQ_COMPUTATION>(&(&graph).into());
        assert_eq!(graph, mst);
    }

    /// graph: see [`four_vertex_graph`]
    ///
    /// MST:
    /// 0       1
    ///  \     /
    ///   \   /
    ///    \ /
    ///     X
    ///    / \
    ///   /   \
    ///  /     \
    /// 3 ----- 2
    #[test]
    fn four_vertices_mst_prim() {
        let graph = four_vertex_graph();

        let expected = Graph::from(vec![
            //vertex 0
            vec![edge(2, 0.1)],
            //vertex 1
            vec![edge(3, 0.1)],
            //vertex 2
            vec![edge(0, 0.1), edge(3, 0.1)],
            //vertex 3
            vec![edge(2, 0.1), edge(1, 0.1)],
        ]);

        assert_eq!(expected, prim::<SEQ_COMPUTATION>(&(&graph).into()));
    }

    /// graph: see [`four_vertex_graph`]
    ///
    /// exclude vertex 0 from MST computation
    ///
    /// MST:
    ///         1
    ///        /
    ///       /
    ///      /
    ///     /
    ///    /
    ///   /
    ///  /
    /// 3 ----- 2
    #[test]
    fn exclude_one_vertex_from_mst() {
        let graph = four_vertex_graph();

        let expected = Graph::from(vec![
            //vertex 0 not in the MST
            vec![],
            //vertex 1
            vec![edge(3, 0.1)],
            //vertex 2
            vec![edge(3, 0.1)],
            //vertex 3
            vec![edge(1, 0.1), edge(2, 0.1)],
        ]);

        assert_eq!(
            expected,
            prim_with_excluded_node_multi_threaded(&(&graph).into(), &[0])
        );
    }

    #[test]
    fn prim_all_versions_agree() {
        let graph = four_vertex_graph();
        let excluded_vertex = &[0];

        let res_st = prim_with_excluded_node_single_threaded(&(&graph).into(), excluded_vertex);
        let res_mt = prim_with_excluded_node_multi_threaded(&(&graph).into(), excluded_vertex);
        let res_prio = prim_with_excluded_node_priority_queue(&(&graph).into(), excluded_vertex);

        assert_eq!(
            res_st, res_mt,
            "single_threaded should agree with multi_threaded"
        );
        assert_eq!(
            res_st, res_prio,
            "single_threaded should agree with priority queue version"
        );
    }

    #[test]
    fn test_vertices_in_priority_queue_from_default_value() {
        let vert = five_vertex_queue(3);

        let mut queue = PriorityQueue::new();
        for i in 0..5 {
            queue.push(i, Reverse(OrderedFloat(f64::INFINITY)));
        }

        assert_eq!(vert.cost_queue, queue);
        assert_eq!(vert.cost_queue.into_vec(), vec![0, 1, 2, 3, 4]);
        assert_eq!(vert.connection_to_mst, vec![3; 5])
    }

    #[test]
    fn test_vertices_in_priority_queue_increase_priority() {
        let mut vert = five_vertex_queue(4);

        let res = vert.cost_queue.push_increase(0, Reverse(OrderedFloat(1.0)));
        assert_eq!(res, Some(Reverse(OrderedFloat(f64::INFINITY))));
    }

    #[test]
    fn test_vertices_in_priority_queue_update_priority_does_not_panic() {
        let mut vert = five_vertex_queue(4);
        let mat = DMatrix::from_row_slice(1, 5, &[1.0; 5]);

        vert.update_minimal_cost(0, mat.row(0));
    }

    #[test]
    fn test_vertices_in_priority_queue_update_priority_works() {
        let mut vert = five_vertex_queue(4);
        let mat = DMatrix::from_row_slice(1, 5, &[0.0, 1.0, 0.0, 0.0, 0.0]);

        vert.update_minimal_cost(0, mat.row(0));

        assert_eq!(vert.connection_to_mst[1], 0);
        assert_eq!(
            vert.cost_queue.get_priority(&1),
            Some(&Reverse(OrderedFloat(1.0f64)))
        );
    }
}
659}