use crate::node::NodeIndex;
use crate::plugins::algorithm::{
AlgorithmData, AlgorithmResult, GraphAlgorithm, PluginContext, PluginInfo,
};
use crate::vgi::{Capability, GraphType, VgiResult, VirtualGraph};
use std::any::Any;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// Shortest-path plugin built on Dijkstra's algorithm.
///
/// Holds the default source/target node ids used when the plugin context does
/// not override them through its `"source"` / `"target"` config keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DijkstraPlugin {
    /// Default source node id; `execute` falls back to node `0` when unset.
    source: Option<usize>,
    /// Optional target node id; when set, the search stops once it is settled.
    target: Option<usize>,
}

/// `(distances, predecessors)` produced by [`DijkstraPlugin::compute`]:
/// per-node `(node_id, distance)` pairs and per-node
/// `(node_id, predecessor_node_id)` pairs on the shortest-path tree.
pub type DijkstraResult = (Vec<(usize, f64)>, Vec<(usize, Option<usize>)>);
impl DijkstraPlugin {
    /// Creates a plugin with an optional default source and target node id.
    pub fn new(source: Option<usize>, target: Option<usize>) -> Self {
        Self { source, target }
    }

    /// Creates a plugin that computes distances from `source` to every node.
    pub fn from_source(source: usize) -> Self {
        Self::new(Some(source), None)
    }

    /// Creates a plugin that searches from `source` and stops early once
    /// `target` has been settled.
    pub fn from_source_to_target(source: usize, target: usize) -> Self {
        Self::new(Some(source), Some(target))
    }

    /// Runs the shortest-path search from node id `source` over `graph`.
    ///
    /// Returns `(distances, predecessors)`, each with one entry per node in the
    /// order `graph.nodes()` yields them:
    /// - `distances`: `(node_id, distance)`, with `f64::INFINITY` for nodes
    ///   that were not reached.
    /// - `predecessors`: `(node_id, predecessor_node_id)`, with `None` for the
    ///   source and for nodes that were not reached.
    ///
    /// If `self.target` is set, the search stops as soon as the target is
    /// popped from the queue. Distances of nodes not yet settled at that point
    /// are upper bounds (or infinity), not final values.
    ///
    /// Both vectors are empty when the graph has no nodes or when `source` is
    /// not one of its nodes.
    ///
    /// NOTE(review): edge weights are not read here. Every edge counts as
    /// `1.0`, so the result is a hop count. Wire in real weights once
    /// `VirtualGraph` exposes them.
    ///
    /// # Errors
    /// The current implementation never returns `Err`.
    pub fn compute<G>(&self, graph: &G, source: usize) -> VgiResult<DijkstraResult>
    where
        G: VirtualGraph + ?Sized,
    {
        // Every edge currently has unit cost (see the NOTE above).
        const EDGE_WEIGHT: f64 = 1.0;

        if graph.node_count() == 0 {
            return Ok((Vec::new(), Vec::new()));
        }

        // Dense position -> node id, in graph iteration order. The search runs on
        // positions so the per-node arrays stay compact even when ids are sparse.
        let node_ids: Vec<usize> = graph.nodes().map(|node| node.index().index()).collect();
        let n = node_ids.len();

        // Node id -> dense position. Sized by the largest id (not the node count)
        // because ids need not be contiguous.
        let id_capacity = node_ids.iter().max().map_or(0, |&max_id| max_id + 1);
        let mut id_to_pos: Vec<Option<usize>> = vec![None; id_capacity];
        for (pos, &id) in node_ids.iter().enumerate() {
            id_to_pos[id] = Some(pos);
        }
        // Bounds-checked lookup: unknown ids map to `None` instead of panicking.
        let pos_of = |id: usize| id_to_pos.get(id).copied().flatten();

        let source_pos = match pos_of(source) {
            Some(pos) => pos,
            None => return Ok((Vec::new(), Vec::new())),
        };
        // Resolved once up front rather than on every heap pop.
        let target_pos = self.target.and_then(pos_of);

        let mut distances = vec![f64::INFINITY; n];
        let mut predecessors: Vec<Option<usize>> = vec![None; n];
        distances[source_pos] = 0.0;

        let mut heap = BinaryHeap::new();
        heap.push(DijkstraNode {
            pos: source_pos,
            distance: 0.0,
        });

        while let Some(DijkstraNode { pos, distance }) = heap.pop() {
            // Stale entry: a shorter path to `pos` was already recorded.
            if distance > distances[pos] {
                continue;
            }
            // The target's distance is final once it is popped.
            if target_pos == Some(pos) {
                break;
            }
            let node_idx = NodeIndex::new_public(node_ids[pos]);
            for neighbor_idx in graph.neighbors(node_idx) {
                // Neighbors that `graph.nodes()` did not report are ignored.
                let neighbor_pos = match pos_of(neighbor_idx.index()) {
                    Some(p) => p,
                    None => continue,
                };
                let new_dist = distance + EDGE_WEIGHT;
                if new_dist < distances[neighbor_pos] {
                    distances[neighbor_pos] = new_dist;
                    predecessors[neighbor_pos] = Some(pos);
                    heap.push(DijkstraNode {
                        pos: neighbor_pos,
                        distance: new_dist,
                    });
                }
            }
        }

        // Translate positions back to node ids for the caller.
        let distances_result: Vec<(usize, f64)> =
            node_ids.iter().copied().zip(distances).collect();
        let predecessors_result: Vec<(usize, Option<usize>)> = node_ids
            .iter()
            .copied()
            .zip(predecessors.into_iter().map(|pred| pred.map(|p| node_ids[p])))
            .collect();
        Ok((distances_result, predecessors_result))
    }

    /// Walks `predecessors` (as returned by [`compute`](Self::compute)) back
    /// from `target` and returns the node ids of the path `source ..= target`.
    ///
    /// Returns `Some(vec![source])` when `target == source` and the id lies
    /// within the table. Returns `None` when:
    /// - `predecessors` is empty;
    /// - `target` is larger than every listed id;
    /// - the chain ends before reaching `source`;
    /// - the chain references an unknown id;
    /// - the chain contains a cycle.
    pub fn reconstruct_path(
        &self,
        predecessors: &[(usize, Option<usize>)],
        source: usize,
        target: usize,
    ) -> Option<Vec<usize>> {
        // `?` also covers the empty-input case.
        let max_id = predecessors.iter().map(|&(id, _)| id).max()?;
        if target > max_id {
            return None;
        }
        let mut id_to_pred: Vec<Option<usize>> = vec![None; max_id + 1];
        for &(id, pred) in predecessors {
            id_to_pred[id] = pred;
        }

        let mut path = vec![target];
        let mut current = target;
        // A simple path has at most `len + 1` nodes: every node except the final
        // one (the source) must be listed with a predecessor. The bound only trips
        // on malformed (cyclic) input, which would otherwise loop forever.
        for _ in 0..=predecessors.len() {
            if current == source {
                path.reverse();
                return Some(path);
            }
            current = id_to_pred.get(current).copied().flatten()?;
            path.push(current);
        }
        None
    }
}
/// Priority-queue entry for Dijkstra's search: a node position together with
/// the tentative distance it was pushed with.
#[derive(Clone, Copy, Debug)]
struct DijkstraNode {
    /// Dense position of the node (index into the per-run distance arrays).
    pos: usize,
    /// Tentative distance from the source at the time of the push.
    distance: f64,
}

impl PartialEq for DijkstraNode {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for DijkstraNode {}

impl PartialOrd for DijkstraNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DijkstraNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed comparison: `BinaryHeap` is a max-heap, so the smallest
        // distance must compare as the greatest. `total_cmp` gives a lawful
        // total order on f64. Ties are broken by position (smaller first) so
        // that entries for different nodes never compare `Equal`.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| other.pos.cmp(&self.pos))
    }
}
impl GraphAlgorithm for DijkstraPlugin {
    /// Static plugin metadata: name, version, description, capabilities and tags.
    fn info(&self) -> PluginInfo {
        // NOTE(review): `IncrementalUpdate` looks unrelated to a read-only
        // shortest-path search. Also, the "weighted" tag does not match
        // `compute`, which counts every edge as 1.0. Confirm both.
        PluginInfo::new("dijkstra", "1.0.0", "Dijkstra 单源最短路径算法")
            .with_author("God-Graph Team")
            .with_required_capabilities(&[Capability::IncrementalUpdate])
            .with_supported_graph_types(&[GraphType::Directed, GraphType::Undirected])
            .with_tags(&["shortest-path", "weighted", "single-source"])
    }

    /// Runs the search on `ctx.graph`.
    ///
    /// The `"source"` / `"target"` config keys override the values the plugin
    /// was built with.
    /// - With a target, the result is `"dijkstra_path"`: the node path, plus
    ///   `distance` and `reachable` metadata.
    /// - Without a target, the result is `"dijkstra_distances"` with the full
    ///   distance list.
    ///
    /// # Errors
    /// Propagates any error returned by [`DijkstraPlugin::compute`].
    fn execute<G>(&self, ctx: &mut PluginContext<G>) -> VgiResult<AlgorithmResult>
    where
        G: VirtualGraph + ?Sized,
    {
        let source = ctx.get_config_as("source", self.source.unwrap_or(0));
        let target_str = ctx.get_config_or("target", "");
        // An empty string means "not configured". A value that fails to parse
        // yields no target, so the full distance list is returned instead.
        let target = if target_str.is_empty() {
            self.target
        } else {
            target_str.parse().ok()
        };
        let plugin = DijkstraPlugin::new(Some(source), target);
        ctx.report_progress(0.1);
        let (distances, predecessors) = plugin.compute(ctx.graph, source)?;
        ctx.report_progress(0.8);
        let result = if let Some(target) = target {
            let path = plugin.reconstruct_path(&predecessors, source, target);
            // Read reachability before `path` is consumed, so no clone is needed.
            let reachable = path.is_some();
            let target_dist = distances
                .iter()
                .find(|(id, _)| *id == target)
                .map(|(_, d)| *d)
                .unwrap_or(f64::INFINITY);
            AlgorithmResult::new(
                "dijkstra_path",
                AlgorithmData::NodeList(path.unwrap_or_default()),
            )
            .with_metadata("distance", target_dist.to_string())
            .with_metadata("reachable", reachable.to_string())
        } else {
            AlgorithmResult::new(
                "dijkstra_distances",
                AlgorithmData::NodeValues(distances.into_iter().collect()),
            )
        }
        .with_metadata("source", source.to_string())
        .with_metadata("algorithm", "dijkstra");
        ctx.report_progress(1.0);
        Ok(result)
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::graph::traits::GraphOps;

    /// Builds the five-node directed fixture A..E (node ids 0..4).
    ///
    /// The edges carry weights, but `compute` counts every edge as 1.0, so the
    /// expected distances in the tests below are hop counts.
    fn create_weighted_graph() -> Graph<String, f64> {
        let mut graph = Graph::<String, f64>::directed();
        let nodes: Vec<_> = ["A", "B", "C", "D", "E"]
            .iter()
            .map(|label| graph.add_node(label.to_string()).unwrap())
            .collect();
        // (from, to, weight), indexing into `nodes`; same insertion order as A..E.
        let edges = [
            (0, 1, 4.0),
            (0, 2, 2.0),
            (1, 2, 1.0),
            (1, 3, 5.0),
            (2, 3, 8.0),
            (2, 4, 10.0),
            (3, 4, 2.0),
        ];
        for &(from, to, weight) in &edges {
            graph.add_edge(nodes[from], nodes[to], weight).unwrap();
        }
        graph
    }

    /// Looks up the distance recorded for `node`, if it appears in the result.
    fn distance_of(distances: &[(usize, f64)], node: usize) -> Option<f64> {
        distances
            .iter()
            .find(|&&(id, _)| id == node)
            .map(|&(_, dist)| dist)
    }

    #[test]
    fn test_dijkstra_basic() {
        let graph = create_weighted_graph();
        let plugin = DijkstraPlugin::from_source(0);
        let (distances, _) = plugin.compute(&graph, 0).unwrap();
        let expected = [(0, 0.0), (1, 1.0), (2, 1.0), (3, 2.0), (4, 2.0)];
        for &(node, dist) in &expected {
            assert_eq!(distance_of(&distances, node), Some(dist));
        }
    }

    #[test]
    fn test_dijkstra_path_reconstruction() {
        let graph = create_weighted_graph();
        let plugin = DijkstraPlugin::from_source_to_target(0, 4);
        let (_, predecessors) = plugin.compute(&graph, 0).unwrap();
        let path = plugin
            .reconstruct_path(&predecessors, 0, 4)
            .expect("node 4 should be reachable from node 0");
        assert_eq!(path.first(), Some(&0));
        assert_eq!(path.last(), Some(&4));
    }

    #[test]
    fn test_dijkstra_empty_graph() {
        let graph = Graph::<String, f64>::directed();
        let plugin = DijkstraPlugin::from_source(0);
        let (distances, predecessors) = plugin
            .compute(&graph, 0)
            .expect("an empty graph should not produce an error");
        assert!(distances.is_empty());
        assert!(predecessors.is_empty());
    }

    #[test]
    fn test_dijkstra_disconnected() {
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let _isolated = graph.add_node("C".to_string()).unwrap();
        graph.add_edge(a, b, 1.0).unwrap();
        let plugin = DijkstraPlugin::from_source(0);
        let (distances, _) = plugin.compute(&graph, 0).unwrap();
        assert_eq!(distance_of(&distances, 0), Some(0.0));
        assert_eq!(distance_of(&distances, 1), Some(1.0));
        assert_eq!(distance_of(&distances, 2), Some(f64::INFINITY));
    }

    #[test]
    fn test_dijkstra_plugin_info() {
        let info = DijkstraPlugin::from_source(0).info();
        assert_eq!(info.name, "dijkstra");
        assert_eq!(info.version, "1.0.0");
        assert!(info.tags.contains(&"shortest-path".to_string()));
    }
}