use crate::algorithms::centrality::{pagerank, degree_centrality};
use crate::algorithms::community::connected_components;
use crate::algorithms::properties::{is_connected, has_cycle, density, is_dag};
use crate::algorithms::shortest_path::dijkstra;
use crate::algorithms::traversal::bfs;
use crate::export::to_dot;
use crate::graph::Graph;
use crate::graph::traits::{GraphBase, GraphOps, GraphQuery};
use crate::node::{NodeIndex, NodeRef};
use crate::vgi::traits::VirtualGraph;
use std::collections::HashMap;
use std::hash::Hash;
/// A convenience facade over [`Graph`] with simplified signatures.
///
/// Fallible operations of the underlying traits are collapsed into plain or
/// `Option` return values, and node-keyed algorithm results are keyed by the
/// raw `usize` node index instead of [`NodeIndex`]. The `From` impls in this
/// module convert between `SimpleGraph` and the wrapped `Graph`.
pub struct SimpleGraph<T, E> {
// The wrapped graph; every method on `SimpleGraph` delegates to it.
inner: Graph<T, E>,
}
/// Construction, mutation and lookup.
///
/// Each method forwards to the corresponding trait method on the wrapped
/// [`Graph`], converting its result into the simpler return type shown.
impl<T, E> SimpleGraph<T, E>
where
    T: Clone,
    E: Clone,
{
    /// Creates an empty directed graph.
    pub fn directed() -> Self {
        let inner = Graph::<T, E>::directed();
        Self { inner }
    }

    /// Creates an empty undirected graph.
    pub fn undirected() -> Self {
        let inner = Graph::<T, E>::undirected();
        Self { inner }
    }

    /// Inserts a node carrying `data` and returns its index.
    ///
    /// # Panics
    ///
    /// Panics if `GraphOps::add_node` returns an error; the panic message
    /// attributes that failure to node-index overflow.
    pub fn add_node(&mut self, data: T) -> NodeIndex {
        let outcome = GraphOps::add_node(&mut self.inner, data);
        outcome.expect("节点索引溢出:无法添加超过 2^32 个节点")
    }

    /// Inserts an edge between `from` and `to` carrying `data`.
    ///
    /// Returns `None` if the underlying graph rejects the edge; the specific
    /// error from `GraphOps::add_edge` is discarded.
    // NOTE(review): `EdgeIndex` is not in this file's visible `use` list —
    // confirm it resolves (e.g. via a glob or prelude elsewhere).
    pub fn add_edge(&mut self, from: NodeIndex, to: NodeIndex, data: E) -> Option<EdgeIndex> {
        let outcome = GraphOps::add_edge(&mut self.inner, from, to, data);
        outcome.ok()
    }

    /// Number of nodes in the graph.
    pub fn node_count(&self) -> usize {
        GraphBase::node_count(&self.inner)
    }

    /// Number of edges in the graph.
    pub fn edge_count(&self) -> usize {
        GraphBase::edge_count(&self.inner)
    }

    /// Borrows the data of the node at `index`, or `None` if there is none.
    pub fn get_node(&self, index: NodeIndex) -> Option<&T> {
        GraphQuery::get_node(&self.inner, index)
    }

    /// Mutably borrows the data of the node at `index`.
    ///
    /// Any error from `VirtualGraph::get_node_mut` is mapped to `None`.
    pub fn get_node_mut(&mut self, index: NodeIndex) -> Option<&mut T> {
        let lookup = VirtualGraph::get_node_mut(&mut self.inner, index);
        lookup.ok()
    }

    /// Returns `true` when [`get_node`](Self::get_node) finds a node at `index`.
    pub fn contains_node(&self, index: NodeIndex) -> bool {
        self.get_node(index).is_some()
    }

    /// Iterates over the indices of all nodes, in `GraphQuery::nodes` order.
    pub fn node_indices(&self) -> impl Iterator<Item = NodeIndex> + '_ {
        GraphQuery::nodes(&self.inner).map(|node_ref| node_ref.index())
    }

    /// Iterates over the neighbors of `node` as reported by `GraphQuery::neighbors`.
    pub fn neighbors(&self, node: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
        GraphQuery::neighbors(&self.inner, node)
    }
}
/// Algorithmic queries over the wrapped graph.
///
/// Node-keyed results are re-keyed from [`NodeIndex`] to its raw `usize`
/// index (via `NodeIndex::index`) wherever the underlying algorithm reports
/// `NodeIndex` values.
impl<T, E> SimpleGraph<T, E>
where
    T: Clone + Send + Sync,
    E: Clone,
{
    /// Runs `pagerank` with the given `damping` factor for `iterations` rounds.
    ///
    /// Returns a map from raw node index to that node's score.
    pub fn pagerank(&self, damping: f64, iterations: usize) -> HashMap<usize, f64> {
        pagerank(&self.inner, damping, iterations)
            .into_iter()
            .map(|(node, score)| (node.index(), score))
            .collect()
    }

    /// Runs `degree_centrality` and keys each value by its position in the result.
    // NOTE(review): keys are `enumerate` positions, not `NodeIndex::index()`
    // values as in the other methods — confirm they coincide (e.g. for graphs
    // that have had nodes removed).
    pub fn degree_centrality(&self) -> HashMap<usize, f64> {
        degree_centrality(&self.inner)
            .into_iter()
            .enumerate()
            .collect()
    }

    /// Returns each connected component as a list of raw node indices.
    pub fn connected_components(&self) -> Vec<Vec<usize>> {
        connected_components(&self.inner)
            .into_iter()
            .map(|component| component.into_iter().map(|node| node.index()).collect())
            .collect()
    }

    /// Single-source shortest distances from `start`, computed by `dijkstra`.
    ///
    /// Each edge's data is converted to its weight with `Into<f64>`. If
    /// `dijkstra` does not produce a result (e.g. it returns an error), the
    /// failure is swallowed and an empty map is returned.
    pub fn shortest_paths(&self, start: NodeIndex) -> HashMap<usize, f64>
    where
        E: Into<f64> + Copy,
    {
        dijkstra(&self.inner, start, |_, _, edge_data| (*edge_data).into())
            .unwrap_or_default()
            .iter()
            .map(|(&node, &distance)| (node.index(), distance))
            .collect()
    }

    /// Returns raw node indices in the order `bfs` visits them from `start`.
    ///
    /// The visitor always returns `true` (presumably "keep traversing" —
    /// confirm against `bfs`); the visit depth is ignored.
    pub fn bfs_order(&self, start: NodeIndex) -> Vec<usize> {
        let mut visited = Vec::new();
        bfs(&self.inner, start, |node, _| {
            visited.push(node.index());
            true
        });
        visited
    }

    /// Forwards to `is_connected` on the wrapped graph.
    pub fn is_connected(&self) -> bool {
        is_connected(&self.inner)
    }

    /// Forwards to `has_cycle` on the wrapped graph.
    pub fn has_cycle(&self) -> bool {
        has_cycle(&self.inner)
    }

    /// Forwards to `is_dag` on the wrapped graph.
    pub fn is_dag(&self) -> bool {
        is_dag(&self.inner)
    }

    /// Forwards to `density` on the wrapped graph.
    pub fn density(&self) -> f64 {
        density(&self.inner)
    }

    /// Renders the graph in DOT format via `to_dot`.
    ///
    /// Node and edge data must implement `Display`.
    pub fn to_dot(&self) -> String
    where
        T: std::fmt::Display,
        E: std::fmt::Display,
    {
        to_dot(&self.inner)
    }
}
impl<T, E> Default for SimpleGraph<T, E>
where
T: Clone,
E: Clone,
{
fn default() -> Self {
Self::directed()
}
}
impl<T, E> From<Graph<T, E>> for SimpleGraph<T, E> {
fn from(graph: Graph<T, E>) -> Self {
Self { inner: graph }
}
}
impl<T, E> From<SimpleGraph<T, E>> for Graph<T, E> {
fn from(simple: SimpleGraph<T, E>) -> Self {
simple.inner
}
}
impl<T, E> SimpleGraph<T, E>
where
T: Clone + Eq + Hash + Ord,
E: Clone,
{
pub fn find_node_by_data(&self, data: &T) -> Option<NodeIndex> {
GraphQuery::nodes(&self.inner)
.find(|n: &NodeRef<'_, T>| n.data() == data)
.map(|n: NodeRef<'_, T>| n.index())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_graph_creation() {
        let mut graph = SimpleGraph::<String, f64>::directed();
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        // `add_edge` reports failure as `None`; make sure the edge was accepted.
        assert!(graph.add_edge(a, b, 1.0).is_some());
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_simple_graph_pagerank() {
        let mut graph = SimpleGraph::<String, ()>::directed();
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        let c = graph.add_node("C".to_string());
        graph.add_edge(a, b, ());
        graph.add_edge(b, c, ());
        graph.add_edge(c, a, ());
        let ranks = graph.pagerank(0.85, 20);
        assert_eq!(ranks.len(), 3);
        // Every node of a directed 3-cycle is structurally identical, so all
        // ranks must agree — compare every value, not an arbitrary pair.
        let values: Vec<f64> = ranks.values().copied().collect();
        let first = values[0];
        assert!(values.iter().all(|v| (v - first).abs() < 0.1));
    }

    #[test]
    fn test_simple_graph_properties() {
        let mut graph = SimpleGraph::<String, ()>::undirected();
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        let c = graph.add_node("C".to_string());
        graph.add_edge(a, b, ());
        graph.add_edge(b, c, ());
        assert!(graph.is_connected());
        assert!(!graph.has_cycle());
        assert!(graph.density() > 0.0);
    }

    #[test]
    fn test_simple_graph_shortest_path() {
        let mut graph = SimpleGraph::<String, f64>::directed();
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        let c = graph.add_node("C".to_string());
        graph.add_edge(a, b, 1.0);
        graph.add_edge(b, c, 2.0);
        graph.add_edge(a, c, 5.0);
        let distances = graph.shortest_paths(a);
        assert_eq!(distances.get(&a.index()), Some(&0.0));
        assert_eq!(distances.get(&b.index()), Some(&1.0));
        assert_eq!(distances.get(&c.index()), Some(&3.0));
    }

    #[test]
    fn test_simple_graph_lookup() {
        let mut graph = SimpleGraph::<String, ()>::undirected();
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        assert!(graph.contains_node(a));
        assert_eq!(graph.get_node(b).map(String::as_str), Some("B"));
        assert_eq!(
            graph.find_node_by_data(&"B".to_string()).map(|n| n.index()),
            Some(b.index())
        );
        assert!(graph.find_node_by_data(&"Z".to_string()).is_none());
    }

    #[test]
    fn test_simple_graph_export() {
        let mut graph = SimpleGraph::<String, f64>::directed();
        let a = graph.add_node("A".to_string());
        let b = graph.add_node("B".to_string());
        graph.add_edge(a, b, 1.0);
        let dot = graph.to_dot();
        assert!(dot.contains("digraph"));
        assert!(dot.contains("A"));
        assert!(dot.contains("B"));
    }
}