1use crate::errors::{GraphError, GraphResult};
6use crate::graph::traits::{GraphBase, GraphQuery};
7use crate::graph::Graph;
8use crate::node::NodeIndex;
9use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap};
11
12pub fn dijkstra<T, E, F>(
31 graph: &Graph<T, E>,
32 source: NodeIndex,
33 mut get_weight: F,
34) -> GraphResult<HashMap<NodeIndex, f64>>
35where
36 F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
37{
38 for edge in graph.edges() {
40 let u = edge.source();
41 let v = edge.target();
42 let weight = get_weight(u, v, edge.data());
43 if weight < 0.0 {
44 return Err(GraphError::NegativeWeight {
45 from: u.index(),
46 to: v.index(),
47 weight,
48 });
49 }
50 }
51
52 struct State {
54 node: NodeIndex,
55 distance: f64,
56 }
57
58 impl PartialEq for State {
59 fn eq(&self, other: &Self) -> bool {
60 self.distance == other.distance
61 }
62 }
63
64 impl Eq for State {}
65
66 impl PartialOrd for State {
67 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
68 Some(self.cmp(other))
69 }
70 }
71
72 impl Ord for State {
73 fn cmp(&self, other: &Self) -> Ordering {
74 other.distance.total_cmp(&self.distance)
75 }
76 }
77
78 let mut distances: HashMap<NodeIndex, f64> = HashMap::new();
79 let mut heap = BinaryHeap::new();
80
81 distances.insert(source, 0.0);
82 heap.push(State {
83 node: source,
84 distance: 0.0,
85 });
86
87 while let Some(State { node, distance }) = heap.pop() {
88 if distance > *distances.get(&node).unwrap_or(&f64::INFINITY) {
90 continue;
91 }
92
93 for neighbor in graph.neighbors(node) {
94 let edge_data = graph.get_edge_by_nodes(node, neighbor)?;
95 let weight = get_weight(node, neighbor, edge_data);
96 let new_distance = distance + weight;
97
98 if new_distance < *distances.get(&neighbor).unwrap_or(&f64::INFINITY) {
99 distances.insert(neighbor, new_distance);
100 heap.push(State {
101 node: neighbor,
102 distance: new_distance,
103 });
104 }
105 }
106 }
107
108 Ok(distances)
109}
110
111pub fn bellman_ford<T, E, F>(
120 graph: &Graph<T, E>,
121 source: NodeIndex,
122 mut get_weight: F,
123) -> Result<HashMap<NodeIndex, f64>, GraphError>
124where
125 F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
126{
127 let mut distances: HashMap<NodeIndex, f64> = HashMap::new();
128
129 for node in graph.nodes() {
131 distances.insert(node.index(), f64::INFINITY);
132 }
133 distances.insert(source, 0.0);
134
135 let n = graph.node_count();
136
137 for _ in 0..n - 1 {
139 for edge in graph.edges() {
140 let u = edge.source();
141 let v = edge.target();
142 let w = get_weight(u, v, edge.data());
143
144 if distances.get(&u) != Some(&f64::INFINITY) {
145 let new_dist = distances[&u] + w;
146 if new_dist < *distances.get(&v).unwrap_or(&f64::INFINITY) {
147 distances.insert(v, new_dist);
148 }
149 }
150 }
151 }
152
153 for edge in graph.edges() {
155 let u = edge.source();
156 let v = edge.target();
157 let w = get_weight(u, v, edge.data());
158
159 if distances.get(&u) != Some(&f64::INFINITY)
160 && distances[&u] + w < *distances.get(&v).unwrap_or(&f64::INFINITY)
161 {
162 return Err(GraphError::NegativeCycle);
163 }
164 }
165
166 Ok(distances)
167}
168
169pub fn astar<T, E, F, H>(
180 graph: &Graph<T, E>,
181 start: NodeIndex,
182 goal: NodeIndex,
183 mut get_weight: F,
184 mut heuristic: H,
185) -> GraphResult<(f64, Vec<NodeIndex>)>
186where
187 F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
188 H: FnMut(NodeIndex) -> f64,
189{
190 #[derive(Debug)]
191 struct State {
192 node: NodeIndex,
193 f_score: f64,
194 }
195
196 impl PartialEq for State {
197 fn eq(&self, other: &Self) -> bool {
198 self.f_score == other.f_score
199 }
200 }
201
202 impl Eq for State {}
203
204 impl PartialOrd for State {
205 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
206 Some(self.cmp(other))
207 }
208 }
209
210 impl Ord for State {
211 fn cmp(&self, other: &Self) -> Ordering {
212 other.f_score.total_cmp(&self.f_score)
213 }
214 }
215
216 let mut g_scores: HashMap<NodeIndex, f64> = HashMap::new();
217 let mut came_from: HashMap<NodeIndex, NodeIndex> = HashMap::new();
218 let mut heap = BinaryHeap::new();
219
220 g_scores.insert(start, 0.0);
221 heap.push(State {
222 node: start,
223 f_score: heuristic(start),
224 });
225
226 while let Some(State { node, .. }) = heap.pop() {
227 if node == goal {
228 let mut path = vec![goal];
230 let mut current = goal;
231 while let Some(&prev) = came_from.get(¤t) {
232 path.push(prev);
233 current = prev;
234 }
235 path.reverse();
236 return Ok((*g_scores.get(&goal).unwrap_or(&0.0), path));
237 }
238
239 let current_g = *g_scores.get(&node).unwrap_or(&f64::INFINITY);
240
241 for neighbor in graph.neighbors(node) {
242 let edge_data = graph.get_edge_by_nodes(node, neighbor)?;
243 let weight = get_weight(node, neighbor, edge_data);
244 let tentative_g = current_g + weight;
245
246 if tentative_g < *g_scores.get(&neighbor).unwrap_or(&f64::INFINITY) {
247 came_from.insert(neighbor, node);
248 g_scores.insert(neighbor, tentative_g);
249 let f_score = tentative_g + heuristic(neighbor);
250 heap.push(State {
251 node: neighbor,
252 f_score,
253 });
254 }
255 }
256 }
257
258 Err(GraphError::NodeNotFound {
259 index: goal.index(),
260 })
261}
262
263pub fn floyd_warshall<T, E, F>(
291 graph: &Graph<T, E>,
292 mut get_weight: F,
293) -> Result<HashMap<(NodeIndex, NodeIndex), f64>, GraphError>
294where
295 F: FnMut(NodeIndex, NodeIndex, &E) -> f64,
296{
297 let node_indices: Vec<NodeIndex> = graph.nodes().map(|n| n.index()).collect();
298 let n = node_indices.len();
299
300 if n == 0 {
301 return Ok(HashMap::new());
302 }
303
304 let index_to_node = &node_indices;
306 let node_to_index: std::collections::HashMap<usize, usize> = node_indices
307 .iter()
308 .enumerate()
309 .map(|(i, ni)| (ni.index(), i))
310 .collect();
311
312 const INF: f64 = f64::INFINITY;
314 let mut dist = vec![vec![INF; n]; n];
315
316 for (i, row) in dist.iter_mut().enumerate().take(n) {
318 row[i] = 0.0;
319 }
320
321 for edge in graph.edges() {
323 let u = edge.source();
324 let v = edge.target();
325 if let (Some(&i), Some(&j)) = (node_to_index.get(&u.index()), node_to_index.get(&v.index()))
326 {
327 let weight = get_weight(u, v, edge.data());
328 dist[i][j] = dist[i][j].min(weight);
329 }
330 }
331
332 for k in 0..n {
334 for i in 0..n {
335 for j in 0..n {
336 if dist[i][k] != INF && dist[k][j] != INF {
337 let new_dist = dist[i][k] + dist[k][j];
338 if new_dist < dist[i][j] {
339 dist[i][j] = new_dist;
340 }
341 }
342 }
343 }
344 }
345
346 #[allow(clippy::needless_range_loop)]
348 for i in 0..n {
349 if dist[i][i] < 0.0 {
350 return Err(GraphError::NegativeCycle);
351 }
352 }
353
354 let mut result = HashMap::with_capacity(n * n);
356 for (i, row) in dist.iter().enumerate().take(n) {
357 for (j, &value) in row.iter().enumerate().take(n) {
358 if value != INF {
359 result.insert((index_to_node[i], index_to_node[j]), value);
360 }
361 }
362 }
363
364 Ok(result)
365}
366
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::builders::GraphBuilder;

    /// Shared fixture: A->B (1), A->C (4), B->C (2), B->D (5), C->D (1).
    /// Shortest distances from A: A=0, B=1, C=3 (via B), D=4 (via B, C).
    fn diamond() -> Graph<&'static str, f64> {
        GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![
                (0, 1, 1.0),
                (0, 2, 4.0),
                (1, 2, 2.0),
                (1, 3, 5.0),
                (2, 3, 1.0),
            ])
            .build()
            .unwrap()
    }

    #[test]
    fn test_dijkstra_basic() {
        let graph = diamond();

        let start = NodeIndex::new(0, 1);
        let distances = dijkstra(&graph, start, |_, _, w| *w).unwrap();

        assert_eq!(distances.get(&start), Some(&0.0));
        assert_eq!(distances.get(&NodeIndex::new(1, 1)), Some(&1.0));
        assert_eq!(distances.get(&NodeIndex::new(2, 1)), Some(&3.0));
        assert_eq!(distances.get(&NodeIndex::new(3, 1)), Some(&4.0));
    }

    #[test]
    fn test_dijkstra_unreachable_nodes_absent() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0)])
            .build()
            .unwrap();

        let start = NodeIndex::new(0, 1);
        let distances = dijkstra(&graph, start, |_, _, w| *w).unwrap();

        assert_eq!(distances.get(&NodeIndex::new(1, 1)), Some(&1.0));
        assert!(!distances.contains_key(&NodeIndex::new(2, 1)));
    }

    #[test]
    fn test_dijkstra_negative_weights() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (0, 2, 3.0)])
            .build()
            .unwrap();

        let start = NodeIndex::new(0, 1);
        let result = dijkstra(&graph, start, |_, _, w| *w);

        assert!(matches!(result, Err(GraphError::NegativeWeight { .. })));
    }

    #[test]
    fn test_bellman_ford_basic() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C", "D"])
            .with_edges(vec![
                (0, 1, 4.0),
                (0, 2, 2.0),
                (1, 2, 1.0),
                (1, 3, 5.0),
                (2, 3, 3.0),
            ])
            .build()
            .unwrap();

        let source = NodeIndex::new(0, 1);
        let distances = bellman_ford(&graph, source, |_, _, w| *w).unwrap();

        assert_eq!(distances.get(&source), Some(&0.0));
        assert_eq!(distances.len(), 4);
    }

    #[test]
    fn test_bellman_ford_negative_weights() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (0, 2, 3.0)])
            .build()
            .unwrap();

        let source = NodeIndex::new(0, 1);
        let distances = bellman_ford(&graph, source, |_, _, w| *w).unwrap();

        assert_eq!(distances.get(&NodeIndex::new(2, 1)), Some(&-1.0));
    }

    #[test]
    fn test_bellman_ford_negative_cycle() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (2, 0, -3.0)])
            .build()
            .unwrap();

        let source = NodeIndex::new(0, 1);
        let result = bellman_ford(&graph, source, |_, _, w| *w);

        assert!(matches!(result, Err(GraphError::NegativeCycle)));
    }

    #[test]
    fn test_astar_basic() {
        let graph = diamond();

        let start = NodeIndex::new(0, 1);
        let goal = NodeIndex::new(3, 1);

        let (distance, path) = astar(&graph, start, goal, |_, _, w| *w, |_| 0.0).unwrap();

        assert_eq!(distance, 4.0);
        assert_eq!(
            path,
            vec![start, NodeIndex::new(1, 1), NodeIndex::new(2, 1), goal]
        );
    }

    #[test]
    fn test_floyd_warshall_basic() {
        let graph = diamond();

        let distances = floyd_warshall(&graph, |_, _, w| *w).unwrap();

        let nodes: Vec<_> = graph.nodes().collect();
        assert_eq!(
            distances.get(&(nodes[0].index(), nodes[3].index())),
            Some(&4.0)
        );
    }

    #[test]
    fn test_floyd_warshall_negative_weights() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (0, 2, 3.0)])
            .build()
            .unwrap();

        let distances = floyd_warshall(&graph, |_, _, w| *w).unwrap();

        let nodes: Vec<_> = graph.nodes().collect();
        assert_eq!(
            distances.get(&(nodes[0].index(), nodes[2].index())),
            Some(&-1.0)
        );
    }

    #[test]
    fn test_floyd_warshall_negative_cycle() {
        let graph = GraphBuilder::directed()
            .with_nodes(vec!["A", "B", "C"])
            .with_edges(vec![(0, 1, 1.0), (1, 2, -2.0), (2, 0, -3.0)])
            .build()
            .unwrap();

        let result = floyd_warshall(&graph, |_, _, w| *w);
        assert!(matches!(result, Err(GraphError::NegativeCycle)));
    }

    #[test]
    fn test_floyd_warshall_empty_graph() {
        let graph: Graph<i32, f64> = GraphBuilder::directed().build().unwrap();
        let distances = floyd_warshall(&graph, |_, _, _: &f64| 1.0).unwrap();
        assert!(distances.is_empty());
    }
}