Skip to main content

converge_optimization/graph/
dijkstra.rs

1//! Dijkstra's shortest path algorithm
2//!
3//! Finds shortest paths from a source node to all other nodes
4//! in a graph with non-negative edge weights.
5//!
6//! Time complexity: O((V + E) log V) using a binary heap.
7
8use super::{Graph, NodeId, Path};
9use crate::{Cost, Error, Result};
10use petgraph::visit::EdgeRef;
11use std::cmp::Reverse;
12use std::collections::{BinaryHeap, HashMap};
13
14/// Find shortest path between two nodes
15pub fn shortest_path<N, E>(
16    graph: &Graph<N, E>,
17    source: NodeId,
18    target: NodeId,
19    edge_cost: impl Fn(&E) -> Cost,
20) -> Result<Option<Path>> {
21    let distances = dijkstra(graph, source, &edge_cost)?;
22
23    match distances.get(&target) {
24        Some(&cost) => {
25            // Reconstruct path (simplified - just returns cost for now)
26            // Full path reconstruction would need parent tracking
27            Ok(Some(Path {
28                nodes: vec![source, target],
29                cost,
30            }))
31        }
32        None => Ok(None),
33    }
34}
35
36/// Run Dijkstra's algorithm from a source node
37///
38/// Returns a map from node to shortest distance from source.
39pub fn dijkstra<N, E>(
40    graph: &Graph<N, E>,
41    source: NodeId,
42    edge_cost: impl Fn(&E) -> Cost,
43) -> Result<HashMap<NodeId, Cost>> {
44    let mut distances: HashMap<NodeId, Cost> = HashMap::new();
45    let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
46
47    distances.insert(source, 0);
48    heap.push(Reverse((0, source)));
49
50    while let Some(Reverse((cost, node))) = heap.pop() {
51        // Skip if we've found a better path
52        if let Some(&best) = distances.get(&node) {
53            if cost > best {
54                continue;
55            }
56        }
57
58        // Explore neighbors
59        for edge in graph.edges(node) {
60            let edge_weight = edge_cost(edge.weight());
61            if edge_weight < 0 {
62                return Err(Error::invalid_input(
63                    "Dijkstra requires non-negative edge weights",
64                ));
65            }
66
67            let next = edge.target();
68            let next_cost = cost + edge_weight;
69
70            let is_better = distances.get(&next).map_or(true, |&d| next_cost < d);
71
72            if is_better {
73                distances.insert(next, next_cost);
74                heap.push(Reverse((next_cost, next)));
75            }
76        }
77    }
78
79    Ok(distances)
80}
81
82/// Find shortest paths from source to all nodes, returning predecessors
83pub fn dijkstra_with_paths<N, E>(
84    graph: &Graph<N, E>,
85    source: NodeId,
86    edge_cost: impl Fn(&E) -> Cost,
87) -> Result<(HashMap<NodeId, Cost>, HashMap<NodeId, NodeId>)> {
88    let mut distances: HashMap<NodeId, Cost> = HashMap::new();
89    let mut predecessors: HashMap<NodeId, NodeId> = HashMap::new();
90    let mut heap: BinaryHeap<Reverse<(Cost, NodeId)>> = BinaryHeap::new();
91
92    distances.insert(source, 0);
93    heap.push(Reverse((0, source)));
94
95    while let Some(Reverse((cost, node))) = heap.pop() {
96        if let Some(&best) = distances.get(&node) {
97            if cost > best {
98                continue;
99            }
100        }
101
102        for edge in graph.edges(node) {
103            let edge_weight = edge_cost(edge.weight());
104            if edge_weight < 0 {
105                return Err(Error::invalid_input(
106                    "Dijkstra requires non-negative edge weights",
107                ));
108            }
109
110            let next = edge.target();
111            let next_cost = cost + edge_weight;
112
113            let is_better = distances.get(&next).map_or(true, |&d| next_cost < d);
114
115            if is_better {
116                distances.insert(next, next_cost);
117                predecessors.insert(next, node);
118                heap.push(Reverse((next_cost, next)));
119            }
120        }
121    }
122
123    Ok((distances, predecessors))
124}
125
126/// Reconstruct path from predecessors map
127pub fn reconstruct_path(
128    predecessors: &HashMap<NodeId, NodeId>,
129    source: NodeId,
130    target: NodeId,
131    total_cost: Cost,
132) -> Path {
133    let mut path = vec![target];
134    let mut current = target;
135
136    while current != source {
137        if let Some(&pred) = predecessors.get(&current) {
138            path.push(pred);
139            current = pred;
140        } else {
141            // No path exists
142            return Path::empty();
143        }
144    }
145
146    path.reverse();
147    Path {
148        nodes: path,
149        cost: total_cost,
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use petgraph::graph::DiGraph;

    #[test]
    fn test_simple_dijkstra() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());

        graph.add_edge(a, b, 1);
        graph.add_edge(b, c, 2);
        graph.add_edge(a, c, 5);

        let distances = dijkstra(&graph, a, |&w| w).unwrap();

        assert_eq!(distances[&a], 0);
        assert_eq!(distances[&b], 1);
        assert_eq!(distances[&c], 3); // a->b->c = 1+2 < a->c = 5
    }

    #[test]
    fn test_unreachable() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());

        // No edge from a to b
        let distances = dijkstra(&graph, a, |&w| w).unwrap();

        assert_eq!(distances.get(&a), Some(&0));
        assert_eq!(distances.get(&b), None);
    }

    #[test]
    fn test_negative_weight_rejected() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        graph.add_edge(a, b, -1);

        assert!(dijkstra(&graph, a, |&w| w).is_err());
        assert!(dijkstra_with_paths(&graph, a, |&w| w).is_err());
    }

    #[test]
    fn test_predecessors_reconstruct_path() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());

        graph.add_edge(a, b, 1);
        graph.add_edge(b, c, 2);
        graph.add_edge(a, c, 5);

        let (distances, predecessors) = dijkstra_with_paths(&graph, a, |&w| w).unwrap();

        assert_eq!(predecessors.get(&c), Some(&b));
        assert_eq!(predecessors.get(&b), Some(&a));
        assert_eq!(predecessors.get(&a), None);

        let path = reconstruct_path(&predecessors, a, c, distances[&c]);
        assert_eq!(path.nodes, vec![a, b, c]);
        assert_eq!(path.cost, 3);
    }

    #[test]
    fn test_shortest_path_cost_and_unreachable() {
        let mut graph: DiGraph<(), i64> = DiGraph::new();
        let a = graph.add_node(());
        let b = graph.add_node(());
        let c = graph.add_node(());
        let d = graph.add_node(());

        graph.add_edge(a, b, 1);
        graph.add_edge(b, c, 2);
        graph.add_edge(a, c, 5);

        let path = shortest_path(&graph, a, c, |&w| w)
            .unwrap()
            .expect("c is reachable from a");
        assert_eq!(path.cost, 3);
        assert_eq!(path.nodes.first(), Some(&a));
        assert_eq!(path.nodes.last(), Some(&c));

        // d has no incoming edges, so it cannot be reached from a.
        assert!(shortest_path(&graph, a, d, |&w| w).unwrap().is_none());
    }
}