competitive_hpp/
dijkstra.rs1use crate::total_ord::Total;
2use num::traits::{Bounded, Num, PrimInt};
3use std::cmp::Reverse;
4use std::collections::BinaryHeap;
5
/// Single-source shortest-path result together with the graph it was computed on.
///
/// Values are built with `Dijkstra::new`. The numeric requirements on `T` and `F`
/// are stated on the `impl` block that needs them rather than on the type itself,
/// so the struct can be named and derived without dragging those bounds along.
#[derive(Clone, Debug)]
pub struct Dijkstra<T, F> {
    /// `dist[v]` is the smallest total cost found from the start vertex to `v`.
    /// Vertices that were never reached keep the initial value `F::max_value()`.
    pub dist: Vec<F>,
    /// `adjacency_list[v]` lists the outgoing edges of `v` as `(target, cost)` pairs.
    pub adjacency_list: Vec<Vec<(usize, F)>>,
    /// Number of vertices, stored exactly as it was passed to `new`.
    n: T,
}
30
31impl<T, F> Dijkstra<T, F>
32where
33 T: PrimInt,
34 F: Num + Bounded + Clone + Copy + PartialOrd,
35{
36 pub fn new(n: T, edges: &[(usize, usize, F)], start: usize) -> Self {
37 let inf = F::max_value();
38
39 let mut dist: Vec<F> = vec![inf; n.to_usize().unwrap()];
40 let adjacency_list = Self::create_adjacency_list(n, &edges);
41
42 let mut heap: BinaryHeap<Total<Reverse<(F, usize)>>> = BinaryHeap::new();
44
45 dist[start] = F::zero();
46 heap.push(Total(Reverse((F::zero(), start))));
47
48 while !heap.is_empty() {
49 let Total(Reverse((d, v))) = heap.pop().unwrap();
50
51 if dist[v] < d {
52 continue;
53 }
54
55 for &(u, cost) in adjacency_list[v].iter() {
56 if dist[u] > dist[v] + cost {
57 dist[u] = dist[v] + cost;
58 heap.push(Total(Reverse((dist[u], u))));
59 }
60 }
61 }
62
63 Dijkstra {
64 dist,
65 adjacency_list,
66 n,
67 }
68 }
69
70 fn create_adjacency_list(n: T, edges: &[(usize, usize, F)]) -> Vec<Vec<(usize, F)>> {
71 let mut adjacency_list: Vec<Vec<(usize, F)>> = vec![vec![]; n.to_usize().unwrap()];
72
73 for &(from, to, cost) in edges {
74 adjacency_list[from].push((to, cost))
75 }
76
77 adjacency_list
78 }
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer costs: checks every distance from two different start vertices.
    #[test]
    fn dijkstra_test() {
        // (from, to, cost) — directed edges.
        let edges = vec![
            (0, 1, 1),
            (0, 2, 6),
            (1, 3, 2),
            (2, 3, 2),
            (0, 3, 2),
            (3, 0, 4),
        ];

        let dijkstra = Dijkstra::new(4, &edges, 0);

        assert_eq!(dijkstra.dist[0], 0);
        assert_eq!(dijkstra.dist[1], 1);
        assert_eq!(dijkstra.dist[2], 6);
        assert_eq!(dijkstra.dist[3], 2);

        // From vertex 1 the only way out is 1 -> 3 -> 0, then on to 2.
        let dijkstra_another = Dijkstra::new(4, &edges, 1);

        assert_eq!(dijkstra_another.dist[0], 6);
        assert_eq!(dijkstra_another.dist[1], 0);
        assert_eq!(dijkstra_another.dist[2], 12);
        assert_eq!(dijkstra_another.dist[3], 2);
    }

    /// Floating-point costs: the direct edge 0 -> 2 (4.3) beats 0 -> 1 -> 2 (7.7).
    #[test]
    fn float_dijkstra_test() {
        let float_edges = vec![(0, 1, 1.5f64), (1, 2, 6.2f64), (0, 2, 4.3f64)];

        let float_dijkstra = Dijkstra::new(3, &float_edges, 0);

        // Compare the absolute difference: without `.abs()` any result that is
        // smaller than the expected value would pass regardless of how wrong it is.
        assert!((float_dijkstra.dist[0] - 0f64).abs() < std::f64::EPSILON);
        assert!((float_dijkstra.dist[1] - 1.5f64).abs() < std::f64::EPSILON);
        assert!((float_dijkstra.dist[2] - 4.3f64).abs() < std::f64::EPSILON);
    }
}