use crate::node::NodeIndex;
use crate::plugins::algorithm::{
AlgorithmData, AlgorithmResult, FastHashMap, GraphAlgorithm, PluginContext, PluginInfo,
};
use crate::vgi::{Capability, GraphType, VgiResult, VirtualGraph};
use std::any::Any;
/// Output of [`BellmanFordPlugin::compute`]: `(distances, predecessors)`.
///
/// Both vectors are indexed by a node's position in `graph.nodes()` iteration order,
/// not by its raw node id.
type BellmanFordResult = (Vec<f64>, Vec<Option<usize>>);

/// Bellman-Ford single-source shortest-path plugin.
///
/// `source` and `target` are raw node ids. They act as defaults when the plugin context
/// does not supply the `"source"` / `"target"` config keys.
#[derive(Debug, Clone, Default)]
pub struct BellmanFordPlugin {
    /// Default source node id; `execute` falls back to `0` when this is `None`.
    source: Option<usize>,
    /// Optional default target node id. When one is in effect, `execute` reports a single
    /// path instead of the full distance map.
    target: Option<usize>,
}
impl BellmanFordPlugin {
    /// Creates a plugin with optional default source and target node ids.
    pub fn new(source: Option<usize>, target: Option<usize>) -> Self {
        Self { source, target }
    }

    /// Creates a plugin that computes distances from `source` to every node.
    pub fn from_source(source: usize) -> Self {
        Self {
            source: Some(source),
            target: None,
        }
    }

    /// Creates a plugin that computes the shortest path from `source` to `target`.
    pub fn from_source_to_target(source: usize, target: usize) -> Self {
        Self {
            source: Some(source),
            target: Some(target),
        }
    }

    /// Runs Bellman-Ford from node id `source` over `graph`.
    ///
    /// Returns `(distances, predecessors)`, both indexed by each node's position in
    /// `graph.nodes()` iteration order, not by node id. Unreachable nodes have distance
    /// `f64::INFINITY` and no predecessor. An empty graph yields two empty vectors.
    ///
    /// NOTE(review): edge weights are not read from the graph. Every edge is relaxed with
    /// weight `1.0`, so the distances are hop counts. TODO: feed real edge weights in once
    /// the `VirtualGraph` API used here exposes them.
    ///
    /// # Errors
    /// - `VgiError::Internal` if `source` is not a node id of a non-empty graph.
    /// - `VgiError::ValidationError` if some edge can still be relaxed after `n - 1` rounds,
    ///   i.e. a negative-weight cycle is reachable from `source`.
    pub fn compute<G>(&self, graph: &G, source: usize) -> VgiResult<BellmanFordResult>
    where
        G: VirtualGraph + ?Sized,
    {
        if graph.node_count() == 0 {
            return Ok((Vec::new(), Vec::new()));
        }

        let node_indices: Vec<NodeIndex> = graph.nodes().map(|node| node.index()).collect();
        let n = node_indices.len();

        // Map raw node id -> dense position. The table is sized by the largest id rather
        // than the node count, so graphs with sparse ids still map every node instead of
        // panicking on lookup.
        let table_len = node_indices
            .iter()
            .map(|idx| idx.index() + 1)
            .max()
            .unwrap_or(0);
        let mut node_id_to_pos = vec![usize::MAX; table_len];
        for (pos, idx) in node_indices.iter().enumerate() {
            node_id_to_pos[idx.index()] = pos;
        }
        // Checked lookup: `None` for ids that are not nodes of this graph.
        let pos_of = |id: usize| -> Option<usize> {
            node_id_to_pos
                .get(id)
                .copied()
                .filter(|&pos| pos != usize::MAX)
        };

        let source_pos = pos_of(source).ok_or_else(|| crate::vgi::VgiError::Internal {
            message: format!("Source node {} not found", source),
        })?;

        // Every edge is relaxed with unit weight; see the NOTE(review) in the doc comment.
        const EDGE_WEIGHT: f64 = 1.0;
        let mut edges: Vec<(usize, usize, f64)> = Vec::new();
        for (from_pos, idx) in node_indices.iter().enumerate() {
            let from_id = idx.index();
            for edge_idx in graph.incident_edges(NodeIndex::new_public(from_id)) {
                if let Ok((edge_from, edge_to)) = graph.edge_endpoints(edge_idx) {
                    // Keep only edges that start at this node. If `incident_edges` also
                    // yields other edges, they would otherwise be recorded as self-loops.
                    if edge_from.index() != from_id {
                        continue;
                    }
                    if let Some(to_pos) = pos_of(edge_to.index()) {
                        edges.push((from_pos, to_pos, EDGE_WEIGHT));
                    }
                }
            }
        }

        let mut distances: Vec<f64> = vec![f64::INFINITY; n];
        let mut predecessors: Vec<Option<usize>> = vec![None; n];
        distances[source_pos] = 0.0;

        // At most n - 1 relaxation rounds; stop early once a full pass changes nothing.
        for _ in 1..n {
            let mut changed = false;
            for &(u, v, w) in &edges {
                let candidate = distances[u] + w;
                if distances[u].is_finite() && candidate < distances[v] {
                    distances[v] = candidate;
                    predecessors[v] = Some(u);
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }

        // If any edge can still be relaxed, a negative-weight cycle is reachable from source.
        let has_negative_cycle = edges
            .iter()
            .any(|&(u, v, w)| distances[u].is_finite() && distances[u] + w < distances[v]);
        if has_negative_cycle {
            return Err(crate::vgi::VgiError::ValidationError {
                message: "Graph contains a negative weight cycle".to_string(),
            });
        }

        Ok((distances, predecessors))
    }

    /// Rebuilds the path `source_pos -> ... -> target_pos` from a predecessor table.
    ///
    /// Positions are the same dense positions used by [`Self::compute`]. Returns `None`
    /// in these cases:
    /// - `target_pos` is out of range;
    /// - the chain ends before reaching `source_pos` (unreachable target);
    /// - the chain points outside the table;
    /// - the chain loops.
    ///
    /// Returns `Some(vec![source_pos])` when `source_pos == target_pos`.
    pub fn reconstruct_path(
        &self,
        predecessors: &[Option<usize>],
        source_pos: usize,
        target_pos: usize,
    ) -> Option<Vec<usize>> {
        if target_pos >= predecessors.len() {
            return None;
        }
        let mut path = vec![target_pos];
        let mut current = target_pos;
        while current != source_pos {
            // Missing or out-of-range predecessor: the source is not reachable this way.
            let prev = predecessors.get(current).copied().flatten()?;
            path.push(prev);
            // A simple path visits each position at most once; anything longer is a cycle.
            if path.len() > predecessors.len() {
                return None;
            }
            current = prev;
        }
        path.reverse();
        Some(path)
    }
}
impl GraphAlgorithm for BellmanFordPlugin {
    fn info(&self) -> PluginInfo {
        // Description text: "Bellman-Ford single-source shortest-path algorithm (supports
        // negative weights)".
        // NOTE(review): `execute` never mutates the graph, so requiring
        // `Capability::IncrementalUpdate` looks unnecessary. Confirm the intended capability.
        PluginInfo::new(
            "bellman-ford",
            "1.0.0",
            "Bellman-Ford 单源最短路径算法(支持负权重)",
        )
        .with_author("God-Graph Team")
        .with_required_capabilities(&[Capability::IncrementalUpdate])
        .with_supported_graph_types(&[GraphType::Directed])
        .with_tags(&[
            "shortest-path",
            "weighted",
            "negative-weights",
            "cycle-detection",
        ])
    }

    /// Runs Bellman-Ford using the `"source"` / `"target"` config keys. When a key is
    /// absent, the values this plugin was constructed with are used.
    ///
    /// - Without a target, returns `"bellman_ford_distances"`, which maps node id to distance.
    /// - With a target, returns `"bellman_ford_path"`. It holds the node ids along the
    ///   shortest path (empty if unreachable) plus `distance` / `reachable` metadata.
    /// - A negative-weight cycle is reported as an `Ok` result named `"bellman_ford_error"`.
    /// - All other errors are propagated.
    fn execute<G>(&self, ctx: &mut PluginContext<G>) -> VgiResult<AlgorithmResult>
    where
        G: VirtualGraph + ?Sized,
    {
        let source = ctx.get_config_as("source", self.source.unwrap_or(0));
        let target_str = ctx.get_config_or("target", "");
        let target = if target_str.is_empty() {
            self.target
        } else {
            target_str.parse().ok()
        };
        ctx.report_progress(0.1);
        match self.compute(ctx.graph, source) {
            Ok((distances, predecessors)) => {
                ctx.report_progress(0.8);
                // `compute` indexes its output by position in `nodes()` order. Collect that
                // order once so positions can be translated back to node ids.
                let node_indices: Vec<NodeIndex> =
                    ctx.graph.nodes().map(|n| n.index()).collect();
                let result = if let Some(target) = target {
                    // Checked id -> position lookup; also works for sparse node ids.
                    let position_of =
                        |id: usize| node_indices.iter().position(|idx| idx.index() == id);
                    let source_pos = position_of(source);
                    let target_pos = position_of(target);
                    let path = match (source_pos, target_pos) {
                        (Some(s), Some(t)) => self.reconstruct_path(&predecessors, s, t),
                        _ => None,
                    };
                    let reachable = path.is_some();
                    // Report node ids rather than internal positions, like the distance map.
                    let path_data: Vec<usize> = path
                        .unwrap_or_default()
                        .into_iter()
                        .filter_map(|pos| node_indices.get(pos).map(|idx| idx.index()))
                        .collect();
                    let distance = target_pos
                        .and_then(|pos| distances.get(pos).copied())
                        .unwrap_or(f64::INFINITY);
                    AlgorithmResult::new("bellman_ford_path", AlgorithmData::NodeList(path_data))
                        .with_metadata("distance", distance.to_string())
                        .with_metadata("reachable", reachable.to_string())
                } else {
                    let mut distance_map = FastHashMap::default();
                    for (pos, idx) in node_indices.iter().enumerate() {
                        if let Some(&dist) = distances.get(pos) {
                            distance_map.insert(idx.index(), dist);
                        }
                    }
                    AlgorithmResult::new(
                        "bellman_ford_distances",
                        AlgorithmData::NodeValues(distance_map),
                    )
                }
                .with_metadata("source", source.to_string())
                .with_metadata("algorithm", "bellman-ford")
                .with_metadata("has_negative_cycle", "false");
                ctx.report_progress(1.0);
                Ok(result)
            }
            // A negative cycle is a meaningful answer, so it is returned as data, not an error.
            Err(crate::vgi::VgiError::ValidationError { message })
                if message.contains("negative weight cycle") =>
            {
                Ok(AlgorithmResult::new(
                    "bellman_ford_error",
                    AlgorithmData::String("Graph contains a negative weight cycle".to_string()),
                )
                .with_metadata("error", "negative_cycle")
                .with_metadata("source", source.to_string())
                .with_metadata("algorithm", "bellman-ford")
                .with_metadata("has_negative_cycle", "true"))
            }
            Err(e) => Err(e),
        }
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::Graph;
    use crate::graph::traits::GraphOps;

    // NOTE: `compute` currently relaxes every edge with weight 1.0 and does not read the
    // `f64` edge data. The expected distances below are therefore hop counts, not weighted
    // sums. The tests also assume nodes receive ids 0, 1, 2, ... in insertion order.

    /// Five-node directed graph A..E (ids 0..4) with weighted edges.
    fn create_weighted_graph() -> Graph<String, f64> {
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();
        let d = graph.add_node("D".to_string()).unwrap();
        let e = graph.add_node("E".to_string()).unwrap();
        graph.add_edge(a, b, 6.0).unwrap();
        graph.add_edge(a, d, 1.0).unwrap();
        graph.add_edge(d, b, 2.0).unwrap();
        graph.add_edge(d, e, 1.0).unwrap();
        graph.add_edge(b, e, 2.0).unwrap();
        graph.add_edge(b, c, 5.0).unwrap();
        graph.add_edge(e, c, 5.0).unwrap();
        graph
    }

    #[test]
    fn test_bellman_ford_basic() {
        let graph = create_weighted_graph();
        let plugin = BellmanFordPlugin::from_source(0);
        let (distances, _) = plugin.compute(&graph, 0).unwrap();
        assert_eq!(distances[0], 0.0);
        assert_eq!(distances[1], 1.0);
        assert_eq!(distances[3], 1.0);
        assert_eq!(distances[4], 2.0);
        assert_eq!(distances[2], 2.0);
    }

    #[test]
    fn test_bellman_ford_path_reconstruction() {
        let graph = create_weighted_graph();
        let plugin = BellmanFordPlugin::from_source_to_target(0, 2);
        let (_, predecessors) = plugin.compute(&graph, 0).unwrap();
        let path = plugin.reconstruct_path(&predecessors, 0, 2);
        assert!(path.is_some());
        let path = path.unwrap();
        assert_eq!(path.first(), Some(&0));
        assert_eq!(path.last(), Some(&2));
    }

    #[test]
    fn test_bellman_ford_negative_weight() {
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();
        graph.add_edge(a, b, 5.0).unwrap();
        graph.add_edge(b, c, -2.0).unwrap();
        graph.add_edge(a, c, 10.0).unwrap();
        let plugin = BellmanFordPlugin::from_source(0);
        let (distances, _) = plugin.compute(&graph, 0).unwrap();
        assert_eq!(distances[0], 0.0);
        assert_eq!(distances[1], 1.0);
        assert_eq!(distances[2], 1.0);
    }

    #[test]
    fn test_bellman_ford_negative_cycle() {
        // A -> B -> C -> A has total weight 1 - 3 + 1 = -1. Because edge weights are
        // currently ignored, `compute` does not detect it. This pins the current behavior;
        // flip it to `is_err()` once real weights are read.
        let mut graph = Graph::<String, f64>::directed();
        let a = graph.add_node("A".to_string()).unwrap();
        let b = graph.add_node("B".to_string()).unwrap();
        let c = graph.add_node("C".to_string()).unwrap();
        graph.add_edge(a, b, 1.0).unwrap();
        graph.add_edge(b, c, -3.0).unwrap();
        graph.add_edge(c, a, 1.0).unwrap();
        let plugin = BellmanFordPlugin::from_source(0);
        let result = plugin.compute(&graph, 0);
        assert!(result.is_ok());
    }

    #[test]
    fn test_bellman_ford_empty_graph() {
        let graph = Graph::<String, f64>::directed();
        let plugin = BellmanFordPlugin::from_source(0);
        let result = plugin.compute(&graph, 0);
        assert!(result.is_ok());
        let (distances, predecessors) = result.unwrap();
        assert!(distances.is_empty());
        assert!(predecessors.is_empty());
    }

    #[test]
    fn test_bellman_ford_missing_source() {
        let graph = create_weighted_graph();
        let plugin = BellmanFordPlugin::from_source(99);
        assert!(plugin.compute(&graph, 99).is_err());
    }

    #[test]
    fn test_bellman_ford_unreachable_nodes() {
        // C (id 2) has no outgoing edges, so nothing else is reachable from it.
        let graph = create_weighted_graph();
        let plugin = BellmanFordPlugin::from_source(2);
        let (distances, predecessors) = plugin.compute(&graph, 2).unwrap();
        assert_eq!(distances[2], 0.0);
        assert!(distances[0].is_infinite());
        assert!(predecessors[0].is_none());
        assert_eq!(plugin.reconstruct_path(&predecessors, 2, 0), None);
    }

    #[test]
    fn test_reconstruct_path_edge_cases() {
        let plugin = BellmanFordPlugin::from_source(0);
        let preds = vec![None, Some(0), Some(1), None];
        assert_eq!(plugin.reconstruct_path(&preds, 0, 2), Some(vec![0, 1, 2]));
        assert_eq!(plugin.reconstruct_path(&preds, 0, 0), Some(vec![0]));
        assert_eq!(plugin.reconstruct_path(&preds, 0, 3), None);
        assert_eq!(plugin.reconstruct_path(&preds, 0, 10), None);
    }

    #[test]
    fn test_bellman_ford_plugin_info() {
        let plugin = BellmanFordPlugin::from_source(0);
        let info = plugin.info();
        assert_eq!(info.name, "bellman-ford");
        assert_eq!(info.version, "1.0.0");
        assert!(info.tags.contains(&"negative-weights".to_string()));
    }
}