use super::common::{GraphView, NodeId};
use std::collections::{HashMap, VecDeque, BinaryHeap};
use std::cmp::Ordering;
/// A path produced by one of the search functions in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct PathResult {
    /// Node the search started from.
    pub source: NodeId,
    /// Node the search was asked to reach.
    pub target: NodeId,
    /// Nodes along the path, starting at `source` and ending at `target` (both inclusive).
    pub path: Vec<NodeId>,
    /// Total cost of the path: the edge count for the BFS searches, or the summed edge
    /// weights for `dijkstra` (which uses `1.0` per edge when no weights are available).
    pub cost: f64,
}
/// Finds a path with the fewest edges from `source` to `target` using breadth-first search.
///
/// Returns `None` when either node id is not present in `view`, or when `target` cannot be
/// reached from `source`. On success the returned `cost` is the number of edges on the path
/// (so `0.0` when `source == target`).
pub fn bfs(
    view: &GraphView,
    source: NodeId,
    target: NodeId,
) -> Option<PathResult> {
    let start = *view.node_to_index.get(&source)?;
    let goal = *view.node_to_index.get(&target)?;

    // Every discovered index mapped to the index it was first reached from; the start has none.
    let mut came_from: HashMap<usize, Option<usize>> = HashMap::new();
    came_from.insert(start, None);

    let mut frontier = VecDeque::new();
    frontier.push_back(start);

    while let Some(here) = frontier.pop_front() {
        if here == goal {
            // Follow the predecessor chain back to the start, then flip it into forward order.
            let mut path: Vec<NodeId> =
                std::iter::successors(Some(goal), |idx| came_from.get(idx).copied().flatten())
                    .map(|idx| view.index_to_node[idx])
                    .collect();
            path.reverse();
            let hops = path.len() - 1;
            return Some(PathResult {
                source,
                target,
                cost: hops as f64,
                path,
            });
        }

        for &neighbor in view.successors(here) {
            if came_from.contains_key(&neighbor) {
                continue;
            }
            came_from.insert(neighbor, Some(here));
            frontier.push_back(neighbor);
        }
    }

    None
}
/// Entry in Dijkstra's priority queue: a tentative `cost` for reaching `node_idx`.
///
/// `BinaryHeap` is a max-heap, so the ordering on `cost` is reversed to make `pop` return the
/// cheapest entry first. Ties on `cost` are broken by `node_idx`, so that equality and ordering
/// agree as the `Ord` contract requires.
#[derive(Copy, Clone, Debug)]
struct State {
    cost: f64,
    node_idx: usize,
}

impl PartialEq for State {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `a == b` holds exactly when `a.cmp(&b) == Equal`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for State {}

impl Ord for State {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` (not the reverse) turns the max-heap into a min-heap on cost.
        // The `Equal` fallback only matters for NaN costs.
        other
            .cost
            .partial_cmp(&self.cost)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.node_idx.cmp(&other.node_idx))
    }
}

impl PartialOrd for State {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Finds a minimum-cost path from `source` to `target` using Dijkstra's algorithm.
///
/// Edge costs come from `view.weights(node)`. When that returns `None`, every outgoing edge of
/// the node costs `1.0`. Edges whose weight is negative are skipped, not rejected.
///
/// Returns `None` when either node id is not present in `view`, or when `target` is unreachable.
pub fn dijkstra(
    view: &GraphView,
    source: NodeId,
    target: NodeId,
) -> Option<PathResult> {
    let start = *view.node_to_index.get(&source)?;
    let goal = *view.node_to_index.get(&target)?;

    // Best cost found so far per index, and the index that produced that cost.
    let mut best: HashMap<usize, f64> = HashMap::new();
    let mut came_from: HashMap<usize, Option<usize>> = HashMap::new();
    let mut frontier = BinaryHeap::new();

    best.insert(start, 0.0);
    frontier.push(State { cost: 0.0, node_idx: start });

    while let Some(entry) = frontier.pop() {
        let here = entry.node_idx;
        let reached_cost = entry.cost;

        // With no negative edges accepted, the goal's first pop carries its final cost.
        if here == goal {
            let mut path: Vec<NodeId> =
                std::iter::successors(Some(goal), |idx| came_from.get(idx).cloned().flatten())
                    .map(|idx| view.index_to_node[idx])
                    .collect();
            path.reverse();
            return Some(PathResult {
                source,
                target,
                path,
                cost: reached_cost,
            });
        }

        // A cheaper entry for this index was already pushed; this one is stale.
        let known = best.get(&here).copied().unwrap_or(f64::INFINITY);
        if reached_cost > known {
            continue;
        }

        let neighbors = view.successors(here);
        let edge_weights = view.weights(here);
        for (pos, &next) in neighbors.iter().enumerate() {
            let weight = match edge_weights {
                Some(w) => w[pos],
                None => 1.0,
            };
            if weight < 0.0 {
                continue;
            }

            let candidate = reached_cost + weight;
            let incumbent = best.get(&next).copied().unwrap_or(f64::INFINITY);
            if candidate < incumbent {
                best.insert(next, candidate);
                came_from.insert(next, Some(here));
                frontier.push(State {
                    cost: candidate,
                    node_idx: next,
                });
            }
        }
    }

    None
}
/// Enumerates every shortest (fewest-edge) path from `source` to `target`.
///
/// A breadth-first search records, for each discovered index, all predecessors that are
/// exactly one level shallower. Those predecessor lists are then walked back from `target` to
/// materialize each path. Every returned [`PathResult`] has `cost` equal to its edge count.
///
/// Returns an empty `Vec` when either node id is missing from `view` or `target` is unreachable.
/// When `source == target`, the single trivial path `[source]` with cost `0.0` is returned.
/// Parallel edges do not produce duplicate paths.
///
/// Note: the number of shortest paths can grow exponentially with graph size, and so does the
/// work and memory needed to materialize them.
pub fn bfs_all_shortest_paths(
    view: &GraphView,
    source: NodeId,
    target: NodeId,
) -> Vec<PathResult> {
    let source_idx = match view.node_to_index.get(&source) {
        Some(&idx) => idx,
        None => return vec![],
    };
    let target_idx = match view.node_to_index.get(&target) {
        Some(&idx) => idx,
        None => return vec![],
    };
    if source_idx == target_idx {
        return vec![PathResult {
            source,
            target,
            path: vec![source],
            cost: 0.0,
        }];
    }

    // BFS depth of every discovered index, and its predecessors one level shallower.
    let mut parents: HashMap<usize, Vec<usize>> = HashMap::new();
    let mut distance: HashMap<usize, usize> = HashMap::new();
    let mut queue = VecDeque::new();
    queue.push_back(source_idx);
    distance.insert(source_idx, 0);
    // Set when the target is first discovered.
    let mut target_distance: Option<usize> = None;

    while let Some(current) = queue.pop_front() {
        let current_dist = distance[&current];
        // Nodes at the target's depth or deeper cannot precede it on a shortest path,
        // so they are not expanded.
        if let Some(td) = target_distance {
            if current_dist >= td {
                continue;
            }
        }
        let next_dist = current_dist + 1;
        for &next_idx in view.successors(current) {
            if let Some(&existing_dist) = distance.get(&next_idx) {
                if existing_dist == next_dist {
                    let preds = parents.entry(next_idx).or_default();
                    // Every push of `current` happens while `current` is being expanded here,
                    // so a repeat from a parallel edge can only be the last entry. Skipping it
                    // avoids emitting duplicate paths.
                    if preds.last() != Some(&current) {
                        preds.push(current);
                    }
                }
            } else {
                distance.insert(next_idx, next_dist);
                parents.insert(next_idx, vec![current]);
                queue.push_back(next_idx);
                if next_idx == target_idx {
                    target_distance = Some(next_dist);
                }
            }
        }
    }

    if !distance.contains_key(&target_idx) {
        return vec![];
    }

    // Depth-first walk over the predecessor lists. Each partial path is stored target-first
    // and reversed once it reaches the source.
    let mut all_paths = Vec::new();
    let mut stack: Vec<(usize, Vec<usize>)> = vec![(target_idx, vec![target_idx])];
    while let Some((node, partial_path)) = stack.pop() {
        if node == source_idx {
            let path: Vec<NodeId> = partial_path
                .iter()
                .rev()
                .map(|&idx| view.index_to_node[idx])
                .collect();
            all_paths.push(PathResult {
                source,
                target,
                cost: (path.len() - 1) as f64,
                path,
            });
            continue;
        }
        if let Some(parent_list) = parents.get(&node) {
            for &parent in parent_list {
                let mut new_path = partial_path.clone();
                new_path.push(parent);
                stack.push((parent, new_path));
            }
        }
    }
    all_paths
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use crate::common::GraphView;

    /// BFS on the chain 1 -> 2 -> 3 returns the whole chain with cost 2.
    #[test]
    fn test_bfs() {
        let index_to_node = vec![1, 2, 3];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);
        let mut outgoing = vec![vec![]; 3];
        outgoing[0].push(1);
        outgoing[1].push(2);
        let view = GraphView::from_adjacency_list(
            3,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 3],
            None,
        );
        let result = bfs(&view, 1, 3).unwrap();
        assert_eq!(result.path, vec![1, 2, 3]);
        assert_eq!(result.cost, 2.0);
    }

    /// BFS from a node to itself is the trivial path, and unknown ids yield `None`.
    #[test]
    fn test_bfs_trivial_and_unknown_nodes() {
        let index_to_node = vec![1, 2];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        let mut outgoing = vec![vec![]; 2];
        outgoing[0].push(1);
        let view = GraphView::from_adjacency_list(
            2,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 2],
            None,
        );
        let result = bfs(&view, 1, 1).unwrap();
        assert_eq!(result.path, vec![1]);
        assert_eq!(result.cost, 0.0);
        assert!(bfs(&view, 1, 99).is_none(), "unknown target id");
        assert!(bfs(&view, 99, 1).is_none(), "unknown source id");
    }

    /// Dijkstra prefers the two-hop route (10 + 5) over the direct edge (50).
    #[test]
    fn test_dijkstra() {
        let index_to_node = vec![1, 2, 3];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);
        let mut outgoing = vec![vec![]; 3];
        let mut weights = vec![vec![]; 3];
        outgoing[0].push(1);
        weights[0].push(10.0);
        outgoing[0].push(2);
        weights[0].push(50.0);
        outgoing[1].push(2);
        weights[1].push(5.0);
        let view = GraphView::from_adjacency_list(
            3,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 3],
            Some(weights),
        );
        let result = dijkstra(&view, 1, 3).unwrap();
        assert_eq!(result.path, vec![1, 2, 3]);
        assert_eq!(result.cost, 15.0);
    }

    /// Dijkstra with no edges: an unreachable target is `None`, a self-query is trivial.
    #[test]
    fn test_dijkstra_unreachable_and_trivial() {
        let index_to_node = vec![1, 2];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        let view = GraphView::from_adjacency_list(
            2,
            index_to_node,
            node_to_index,
            vec![vec![], vec![]],
            vec![vec![], vec![]],
            Some(vec![vec![], vec![]]),
        );
        assert!(dijkstra(&view, 1, 2).is_none());
        let result = dijkstra(&view, 2, 2).unwrap();
        assert_eq!(result.path, vec![2]);
        assert_eq!(result.cost, 0.0);
    }

    /// The diamond 1 -> {2, 3} -> 4 has exactly two shortest paths.
    #[test]
    fn test_bfs_all_shortest_paths() {
        let index_to_node = vec![1, 2, 3, 4];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        node_to_index.insert(3, 2);
        node_to_index.insert(4, 3);
        let mut outgoing = vec![vec![]; 4];
        outgoing[0] = vec![1, 2];
        outgoing[1] = vec![3];
        outgoing[2] = vec![3];
        let view = GraphView::from_adjacency_list(
            4,
            index_to_node,
            node_to_index,
            outgoing,
            vec![vec![]; 4],
            None,
        );
        let results = bfs_all_shortest_paths(&view, 1, 4);
        assert_eq!(results.len(), 2, "Should find 2 shortest paths in diamond graph");
        for r in &results {
            assert_eq!(r.cost, 2.0);
            assert_eq!(r.path.len(), 3);
            assert_eq!(r.path[0], 1);
            assert_eq!(r.path[2], 4);
        }
    }

    /// With no edges between distinct nodes there are no paths at all.
    #[test]
    fn test_bfs_all_shortest_paths_no_path() {
        let index_to_node = vec![1, 2];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        let view = GraphView::from_adjacency_list(
            2,
            index_to_node,
            node_to_index,
            vec![vec![], vec![]],
            vec![vec![], vec![]],
            None,
        );
        let results = bfs_all_shortest_paths(&view, 1, 2);
        assert!(results.is_empty());
    }

    /// Self-queries return the single trivial path; unknown ids return nothing.
    #[test]
    fn test_bfs_all_shortest_paths_trivial_and_unknown() {
        let index_to_node = vec![1, 2];
        let mut node_to_index = HashMap::new();
        node_to_index.insert(1, 0);
        node_to_index.insert(2, 1);
        let view = GraphView::from_adjacency_list(
            2,
            index_to_node,
            node_to_index,
            vec![vec![], vec![]],
            vec![vec![], vec![]],
            None,
        );
        let results = bfs_all_shortest_paths(&view, 1, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].path, vec![1]);
        assert_eq!(results[0].cost, 0.0);
        assert!(bfs_all_shortest_paths(&view, 1, 99).is_empty());
    }
}