Skip to main content

ac_lib/graph/
dijkstra.rs

1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3
/// A weighted, directed edge in an adjacency-list graph.
///
/// `graph[u]` holds the edges leaving node `u`; each edge leads to `node` with weight `cost`.
/// Small plain data, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Edge {
    /// Index of the destination node.
    pub node: usize,
    /// Weight of the edge (non-negative, as required by Dijkstra's algorithm).
    pub cost: usize,
}
9
/// Priority-queue entry for Dijkstra's algorithm: a tentative distance `cost` to node `position`.
///
/// [`BinaryHeap`] is a max-heap, so the ordering on `cost` is reversed to make the
/// cheapest entry pop first.
#[derive(Debug, PartialEq, Eq)]
struct State {
    cost: usize,
    position: usize,
}

impl Ord for State {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse on `cost` for min-heap behaviour. Ties are broken on `position` so that
        // `Ord` agrees with the derived `Eq`: `cmp` is `Equal` only when both fields match.
        other
            .cost
            .cmp(&self.cost)
            .then_with(|| self.position.cmp(&other.position))
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
35
36pub fn dijkstra(graph: &[Vec<Edge>], start: usize) -> Vec<usize> {
37    let n = graph.len();
38    let mut dist = vec![usize::MAX; n];
39    let mut heap = BinaryHeap::new();
40
41    dist[start] = 0;
42    heap.push(State {
43        cost: 0,
44        position: start,
45    });
46
47    while let Some(State { cost, position }) = heap.pop() {
48        if cost > dist[position] {
49            continue;
50        }
51
52        for edge in &graph[position] {
53            let next_cost = cost + edge.cost;
54
55            if next_cost < dist[edge.node] {
56                dist[edge.node] = next_cost;
57                heap.push(State {
58                    cost: next_cost,
59                    position: edge.node,
60                });
61            }
62        }
63    }
64
65    dist
66}
67
68pub fn dijkstra_with_path(graph: &[Vec<Edge>], start: usize) -> (Vec<usize>, Vec<Option<usize>>) {
69    let n = graph.len();
70    let mut dist = vec![usize::MAX; n];
71    let mut parent = vec![None; n];
72    let mut heap = BinaryHeap::new();
73
74    dist[start] = 0;
75    heap.push(State {
76        cost: 0,
77        position: start,
78    });
79
80    while let Some(State { cost, position }) = heap.pop() {
81        if cost > dist[position] {
82            continue;
83        }
84
85        for edge in &graph[position] {
86            let next_cost = cost + edge.cost;
87
88            if next_cost < dist[edge.node] {
89                dist[edge.node] = next_cost;
90                parent[edge.node] = Some(position);
91                heap.push(State {
92                    cost: next_cost,
93                    position: edge.node,
94                });
95            }
96        }
97    }
98
99    (dist, parent)
100}