graphmind_graph_algorithms/
pathfinding.rs1use super::common::{GraphView, NodeId};
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap, VecDeque};
8
9#[derive(Debug, Clone)]
11pub struct PathResult {
12 pub source: NodeId,
13 pub target: NodeId,
14 pub path: Vec<NodeId>,
15 pub cost: f64,
16}
17
18pub fn bfs(view: &GraphView, source: NodeId, target: NodeId) -> Option<PathResult> {
20 let source_idx = *view.node_to_index.get(&source)?;
21 let target_idx = *view.node_to_index.get(&target)?;
22
23 let mut queue = VecDeque::new();
24 let mut visited = HashMap::new(); queue.push_back(source_idx);
27 visited.insert(source_idx, None);
28
29 while let Some(current_idx) = queue.pop_front() {
30 if current_idx == target_idx {
31 let mut path = Vec::new();
33 let mut curr = Some(target_idx);
34 while let Some(idx) = curr {
35 path.push(view.index_to_node[idx]);
36 if let Some(parent) = visited.get(&idx) {
37 curr = *parent;
38 } else {
39 curr = None;
40 }
41 }
42 path.reverse();
43 return Some(PathResult {
44 source,
45 target,
46 cost: (path.len() - 1) as f64,
47 path,
48 });
49 }
50
51 for &next_idx in view.successors(current_idx) {
52 if let std::collections::hash_map::Entry::Vacant(e) = visited.entry(next_idx) {
53 e.insert(Some(current_idx));
54 queue.push_back(next_idx);
55 }
56 }
57 }
58
59 None
60}
61
/// Entry in the Dijkstra priority queue: a tentative distance to a node index.
///
/// The ordering is reversed so that `BinaryHeap` (a max-heap) pops the entry
/// with the *smallest* `cost` first. Costs are compared with `f64::total_cmp`,
/// so the order is total even if a NaN slips in (NaN ranks below every other
/// cost and pops last). Equal costs are broken by `node_idx`, with the smaller
/// index popping first, so `Ord` and `PartialEq` always agree and pop order is
/// deterministic.
#[derive(Debug, Copy, Clone)]
struct State {
    cost: f64,
    node_idx: usize,
}

impl PartialEq for State {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `==` and `Ordering::Equal` can never
        // disagree, as `Ord` requires. A derived impl would treat NaN as unequal
        // to itself while `cmp` called two differently-indexed states Equal.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for State {}

impl Ord for State {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` before `self` on purpose: reversing the order turns the
        // standard max-heap into a min-heap on cost.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.node_idx.cmp(&self.node_idx))
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
86
87pub fn dijkstra(view: &GraphView, source: NodeId, target: NodeId) -> Option<PathResult> {
91 let source_idx = *view.node_to_index.get(&source)?;
92 let target_idx = *view.node_to_index.get(&target)?;
93
94 let mut dist = HashMap::new();
95 let mut parent = HashMap::new();
96 let mut heap = BinaryHeap::new();
97
98 dist.insert(source_idx, 0.0);
99 heap.push(State {
100 cost: 0.0,
101 node_idx: source_idx,
102 });
103
104 while let Some(State { cost, node_idx }) = heap.pop() {
105 if node_idx == target_idx {
106 let mut path = Vec::new();
108 let mut curr = Some(target_idx);
109 while let Some(idx) = curr {
110 path.push(view.index_to_node[idx]);
111 curr = parent.get(&idx).cloned().flatten();
112 }
113 path.reverse();
114 return Some(PathResult {
115 source,
116 target,
117 path,
118 cost,
119 });
120 }
121
122 if cost > *dist.get(&node_idx).unwrap_or(&f64::INFINITY) {
123 continue;
124 }
125
126 let edges = view.successors(node_idx);
127 let weights = view.weights(node_idx);
128
129 for (i, &next_idx) in edges.iter().enumerate() {
130 let weight = if let Some(w) = weights { w[i] } else { 1.0 };
131
132 if weight < 0.0 {
133 continue;
134 }
135
136 let next_cost = cost + weight;
137
138 if next_cost < *dist.get(&next_idx).unwrap_or(&f64::INFINITY) {
139 dist.insert(next_idx, next_cost);
140 parent.insert(next_idx, Some(node_idx));
141 heap.push(State {
142 cost: next_cost,
143 node_idx: next_idx,
144 });
145 }
146 }
147 }
148
149 None
150}
151
152pub fn bfs_all_shortest_paths(view: &GraphView, source: NodeId, target: NodeId) -> Vec<PathResult> {
154 let source_idx = match view.node_to_index.get(&source) {
155 Some(&idx) => idx,
156 None => return vec![],
157 };
158 let target_idx = match view.node_to_index.get(&target) {
159 Some(&idx) => idx,
160 None => return vec![],
161 };
162
163 if source_idx == target_idx {
164 return vec![PathResult {
165 source,
166 target,
167 path: vec![source],
168 cost: 0.0,
169 }];
170 }
171
172 let mut parents: HashMap<usize, Vec<usize>> = HashMap::new();
174 let mut distance: HashMap<usize, usize> = HashMap::new();
175 let mut queue = VecDeque::new();
176
177 queue.push_back(source_idx);
178 distance.insert(source_idx, 0);
179
180 let mut target_distance: Option<usize> = None;
181
182 while let Some(current) = queue.pop_front() {
183 let current_dist = distance[¤t];
184
185 if let Some(td) = target_distance {
187 if current_dist >= td {
188 continue;
189 }
190 }
191
192 for &next_idx in view.successors(current) {
193 let next_dist = current_dist + 1;
194
195 if let Some(&existing_dist) = distance.get(&next_idx) {
196 if next_dist == existing_dist {
197 parents.entry(next_idx).or_default().push(current);
199 }
200 } else {
202 distance.insert(next_idx, next_dist);
204 parents.insert(next_idx, vec![current]);
205 queue.push_back(next_idx);
206
207 if next_idx == target_idx {
208 target_distance = Some(next_dist);
209 }
210 }
211 }
212 }
213
214 if !distance.contains_key(&target_idx) {
216 return vec![];
217 }
218
219 let mut all_paths = Vec::new();
221 let mut stack: Vec<(usize, Vec<usize>)> = vec![(target_idx, vec![target_idx])];
222
223 while let Some((node, partial_path)) = stack.pop() {
224 if node == source_idx {
225 let path: Vec<NodeId> = partial_path
226 .iter()
227 .rev()
228 .map(|&idx| view.index_to_node[idx])
229 .collect();
230 all_paths.push(PathResult {
231 source,
232 target,
233 cost: (path.len() - 1) as f64,
234 path,
235 });
236 continue;
237 }
238
239 if let Some(parent_list) = parents.get(&node) {
240 for &parent in parent_list {
241 let mut new_path = partial_path.clone();
242 new_path.push(parent);
243 stack.push((parent, new_path));
244 }
245 }
246 }
247
248 all_paths
249}
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::GraphView;
    use std::collections::HashMap;

    #[test]
    fn test_bfs() {
        let index_to_node = vec![1, 2, 3];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);

        let mut outgoing = vec![vec![]; 3];
        outgoing[0].push(1);
        outgoing[1].push(2);

        let view = GraphView::from_adjacency_list(
            3,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 3],
            None,
        );

        let result = bfs(&view, 1, 3).unwrap();
        assert_eq!(result.path, vec![1, 2, 3]);
        assert_eq!(result.cost, 2.0);
    }

    #[test]
    fn test_dijkstra() {
        let index_to_node = vec![1, 2, 3];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);

        let mut outgoing = vec![vec![]; 3];
        let mut weights = vec![vec![]; 3];

        // Direct edge 1 -> 3 (50.0) is dearer than 1 -> 2 -> 3 (10.0 + 5.0).
        outgoing[0].push(1);
        weights[0].push(10.0);
        outgoing[0].push(2);
        weights[0].push(50.0);
        outgoing[1].push(2);
        weights[1].push(5.0);

        let view = GraphView::from_adjacency_list(
            3,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 3],
            Some(weights),
        );

        let result = dijkstra(&view, 1, 3).unwrap();
        assert_eq!(result.path, vec![1, 2, 3]);
        assert_eq!(result.cost, 15.0);
    }

    #[test]
    fn test_dijkstra_skips_negative_weights() {
        let index_to_node = vec![1, 2, 3];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);

        let mut outgoing = vec![vec![]; 3];
        let mut weights = vec![vec![]; 3];

        // A negative shortcut 1 -> 3 must be ignored, not taken.
        outgoing[0].push(2);
        weights[0].push(-100.0);
        outgoing[0].push(1);
        weights[0].push(1.0);
        outgoing[1].push(2);
        weights[1].push(1.0);

        let view = GraphView::from_adjacency_list(
            3,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 3],
            Some(weights),
        );

        let result = dijkstra(&view, 1, 3).unwrap();
        assert_eq!(result.path, vec![1, 2, 3]);
        assert_eq!(result.cost, 2.0);
    }

    #[test]
    fn test_bfs_all_shortest_paths() {
        let index_to_node = vec![1, 2, 3, 4];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);
        node_to_index.insert(4, 3);

        // Diamond: 1 -> {2, 3} -> 4.
        let mut outgoing = vec![vec![]; 4];
        outgoing[0] = vec![1, 2];
        outgoing[1] = vec![3];
        outgoing[2] = vec![3];

        let view = GraphView::from_adjacency_list(
            4,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 4],
            None,
        );

        let results = bfs_all_shortest_paths(&view, 1, 4);
        assert_eq!(
            results.len(),
            2,
            "Should find 2 shortest paths in diamond graph"
        );
        for r in &results {
            assert_eq!(r.cost, 2.0);
            assert_eq!(r.path.len(), 3);
            assert_eq!(r.path[0], 1);
            assert_eq!(r.path[2], 4);
        }
    }

    #[test]
    fn test_bfs_all_shortest_paths_no_path() {
        let index_to_node = vec![1, 2];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);

        let view = GraphView::from_adjacency_list(
            2,
            index_to_node,
            node_to_index,
            vec![vec![], vec![]],
            vec![vec![], vec![]],
            None,
        );

        let results = bfs_all_shortest_paths(&view, 1, 2);
        assert!(results.is_empty());
    }

    #[test]
    fn test_same_source_and_target() {
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);

        let mut outgoing = vec![vec![]; 2];
        outgoing[0].push(1);

        let view = GraphView::from_adjacency_list(
            2,
            vec![1, 2],
            node_to_index,
            outgoing,
            vec![vec![]; 2],
            None,
        );

        let result = bfs(&view, 1, 1).unwrap();
        assert_eq!(result.path, vec![1]);
        assert_eq!(result.cost, 0.0);

        let result = dijkstra(&view, 1, 1).unwrap();
        assert_eq!(result.path, vec![1]);
        assert_eq!(result.cost, 0.0);

        let results = bfs_all_shortest_paths(&view, 1, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].path, vec![1]);
        assert_eq!(results[0].cost, 0.0);
    }

    #[test]
    fn test_unreachable_and_unknown_nodes() {
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);

        let mut outgoing = vec![vec![]; 2];
        outgoing[0].push(1);

        let view = GraphView::from_adjacency_list(
            2,
            vec![1, 2],
            node_to_index,
            outgoing,
            vec![vec![]; 2],
            None,
        );

        // Edges are directed: 2 cannot reach 1.
        assert!(bfs(&view, 2, 1).is_none());
        assert!(dijkstra(&view, 2, 1).is_none());
        assert!(bfs_all_shortest_paths(&view, 2, 1).is_empty());

        // Ids missing from the view mean "no path", not a panic.
        assert!(bfs(&view, 1, 99).is_none());
        assert!(dijkstra(&view, 99, 1).is_none());
        assert!(bfs_all_shortest_paths(&view, 1, 99).is_empty());
    }
}