use serde::*;
use std::collections::HashMap;
use petgraph::prelude::*;
pub use petgraph::prelude::NodeIndex;
/// An undirected graph with a networkx-like API, backed by petgraph's
/// [`StableUnGraph`].
///
/// Besides the petgraph storage, the struct keeps a lookup table from an
/// unordered node pair to the index of the edge joining them, so edge data
/// can be addressed by its endpoints.
///
/// NOTE(review): `mapping` has tuple keys, which self-describing formats such
/// as JSON cannot encode as map keys — confirm which serde format is intended.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct NxGraph<N: Default, E: Default> {
    /// Canonical `(smaller, larger)` node pair -> index of the edge between them.
    mapping: HashMap<(NodeIndex, NodeIndex), EdgeIndex>,
    /// Underlying petgraph storage holding all node and edge data.
    graph: StableUnGraph<N, E>,
}
/// Canonicalizes an unordered node pair as `(smaller, larger)`, so that
/// `(a, b)` and `(b, a)` map to the same lookup key.
fn node_pair_key(n1: NodeIndex, n2: NodeIndex) -> (NodeIndex, NodeIndex) {
    let low = n1.min(n2);
    let high = n1.max(n2);
    (low, high)
}
impl<N, E> NxGraph<N, E>
where
    N: Default,
    E: Default,
{
    /// Looks up the index of the edge joining `n1` and `n2`, in either order.
    ///
    /// Returns `None` when the pair has no entry in the lookup table.
    fn edge_index_between(&self, n1: NodeIndex, n2: NodeIndex) -> Option<EdgeIndex> {
        // `node_pair_key` already puts the pair in canonical order, so no
        // separate sort is needed here.
        self.mapping.get(&node_pair_key(n1, n2)).copied()
    }

    /// Borrows the data stored on node `n`.
    ///
    /// # Panics
    /// Panics with `"no node"` if `n` is not in the graph.
    fn get_node_data(&self, n: NodeIndex) -> &N {
        self.graph.node_weight(n).expect("no node")
    }

    /// Mutably borrows the data stored on node `n`.
    ///
    /// # Panics
    /// Panics with `"no node"` if `n` is not in the graph.
    fn get_node_data_mut(&mut self, n: NodeIndex) -> &mut N {
        self.graph.node_weight_mut(n).expect("no node")
    }

    /// Borrows the data of the edge between `node1` and `node2` (either order).
    ///
    /// # Panics
    /// Panics with `"no edge index"` if the pair is not connected, or with
    /// `"no edge"` if the recorded edge index is no longer present.
    fn get_edge_data(&self, node1: NodeIndex, node2: NodeIndex) -> &E {
        let edge_index = self.edge_index_between(node1, node2).expect("no edge index");
        self.graph.edge_weight(edge_index).expect("no edge")
    }

    /// Mutably borrows the data of the edge between `node1` and `node2`
    /// (either order).
    ///
    /// # Panics
    /// Same conditions as [`Self::get_edge_data`].
    fn get_edge_data_mut(&mut self, node1: NodeIndex, node2: NodeIndex) -> &mut E {
        let edge_index = self.edge_index_between(node1, node2).expect("no edge index");
        self.graph.edge_weight_mut(edge_index).expect("no edge")
    }
}
impl<N, E> NxGraph<N, E>
where
    N: Default,
    E: Default,
{
    /// Creates an empty graph.
    pub fn new() -> Self {
        Self::default()
    }

    /// Iterates over the nodes adjacent to `n`.
    pub fn neighbors(&self, n: NodeIndex) -> impl Iterator<Item = NodeIndex> + '_ {
        self.graph.neighbors(n)
    }

    /// Iterates over the indices of all nodes currently in the graph.
    pub fn node_indices(&self) -> impl Iterator<Item = NodeIndex> + '_ {
        self.graph.node_indices()
    }

    /// Returns `true` if `n` is a node of the graph.
    pub fn has_node(&self, n: NodeIndex) -> bool {
        self.graph.contains_node(n)
    }

    /// Returns `true` if an edge joins `u` and `v`.
    pub fn has_edge(&self, u: NodeIndex, v: NodeIndex) -> bool {
        self.graph.find_edge(u, v).is_some()
    }

    /// Returns the number of nodes in the graph.
    pub fn number_of_nodes(&self) -> usize {
        self.graph.node_count()
    }

    /// Returns the number of edges in the graph.
    pub fn number_of_edges(&self) -> usize {
        self.graph.edge_count()
    }

    /// Adds a node carrying `data` and returns its index.
    pub fn add_node(&mut self, data: N) -> NodeIndex {
        self.graph.add_node(data)
    }

    /// Adds one node per item of `nodes`, returning their indices in input order.
    pub fn add_nodes_from<M: IntoIterator<Item = N>>(&mut self, nodes: M) -> Vec<NodeIndex> {
        nodes.into_iter().map(|node| self.add_node(node)).collect()
    }

    /// Connects `u` and `v` with an edge carrying `data`.
    ///
    /// If the two nodes are already connected, the existing edge's data is
    /// replaced instead of a parallel edge being added.
    ///
    /// # Panics
    /// petgraph's `update_edge` panics if `u` or `v` is not in the graph.
    pub fn add_edge(&mut self, u: NodeIndex, v: NodeIndex, data: E) {
        let e = self.graph.update_edge(u, v, data);
        self.mapping.insert(node_pair_key(u, v), e);
    }

    /// Adds every `(u, v, data)` triple from `edges`, in order, via [`Self::add_edge`].
    pub fn add_edges_from<M: IntoIterator<Item = (NodeIndex, NodeIndex, E)>>(&mut self, edges: M) {
        for (u, v, d) in edges {
            self.add_edge(u, v, d);
        }
    }

    /// Removes the edge between `node1` and `node2` (either order) and returns
    /// its data, or `None` if the two nodes are not connected.
    pub fn remove_edge(&mut self, node1: NodeIndex, node2: NodeIndex) -> Option<E> {
        let e = self.mapping.remove(&node_pair_key(node1, node2))?;
        self.graph.remove_edge(e)
    }

    /// Removes node `n` together with all of its incident edges and returns
    /// its data, or `None` if `n` is not in the graph.
    pub fn remove_node(&mut self, n: NodeIndex) -> Option<N> {
        // petgraph deletes the incident edges along with the node, so their
        // pair -> edge entries must go too. Otherwise a later lookup could hit
        // a removed edge index, or one that the stable graph has since recycled
        // for an unrelated edge. The guard keeps `neighbors` to present nodes.
        if self.graph.contains_node(n) {
            for m in self.graph.neighbors(n) {
                self.mapping.remove(&node_pair_key(n, m));
            }
        }
        self.graph.remove_node(n)
    }

    /// Removes every node and edge.
    pub fn clear(&mut self) {
        self.graph.clear();
        // Every edge is gone, so every pair -> edge entry is now stale.
        self.mapping.clear();
    }

    /// Removes every edge while keeping all nodes.
    pub fn clear_edges(&mut self) {
        self.graph.clear_edges();
        // Every edge is gone, so every pair -> edge entry is now stale.
        self.mapping.clear();
    }
}
#[cfg(feature = "adhoc")]
impl<N, E> NxGraph<N, E>
where
    N: Default,
    E: Default,
{
    /// Borrows the underlying petgraph graph.
    pub fn raw_graph(&self) -> &StableUnGraph<N, E> {
        &self.graph
    }

    /// Mutably borrows the underlying petgraph graph.
    ///
    /// Edges added or removed through this reference are not reflected in the
    /// node-pair lookup table that `remove_edge` and pair indexing rely on, so
    /// those lookups can go stale. Prefer the `NxGraph` methods for edge changes.
    pub fn raw_graph_mut(&mut self) -> &mut StableUnGraph<N, E> {
        &mut self.graph
    }

    /// Wraps an existing petgraph graph and builds the node-pair lookup table
    /// from its edges.
    ///
    /// If `graph` holds parallel edges between the same pair of nodes, only the
    /// last one visited by `edge_indices` is reachable through pair lookups.
    pub fn from_raw_graph(graph: StableUnGraph<N, E>) -> Self {
        let mapping = graph
            .edge_indices()
            .filter_map(|e| {
                graph
                    .edge_endpoints(e)
                    .map(|(u, v)| (node_pair_key(u, v), e))
            })
            .collect();
        Self { mapping, graph }
    }
}
impl NxGraph<usize, usize> {
    /// Builds a path graph on `n` nodes.
    ///
    /// The nodes carry the data `1..=n`, and each consecutive pair is joined by
    /// an edge with data `0`. `n == 0` yields an empty graph.
    pub fn path_graph(n: usize) -> Self {
        let mut graph = Self::new();
        let indices = graph.add_nodes_from(1..=n);
        indices
            .iter()
            .zip(indices.iter().skip(1))
            .for_each(|(&a, &b)| graph.add_edge(a, b, 0));
        graph
    }
}
#[test]
fn test_path_graph() {
    // A path on five vertices has five nodes joined by four consecutive edges.
    let path = NxGraph::path_graph(5);
    let (node_total, edge_total) = (path.number_of_nodes(), path.number_of_edges());
    assert_eq!(node_total, 5);
    assert_eq!(edge_total, 4);
}
/// `graph[n]` borrows the data of node `n`; panics with `"no node"` if absent.
impl<N: Default, E: Default> std::ops::Index<NodeIndex> for NxGraph<N, E> {
    type Output = N;

    fn index(&self, node: NodeIndex) -> &N {
        self.get_node_data(node)
    }
}
/// `graph[n] = …` mutably borrows the data of node `n`; panics with
/// `"no node"` if absent.
impl<N: Default, E: Default> std::ops::IndexMut<NodeIndex> for NxGraph<N, E> {
    fn index_mut(&mut self, node: NodeIndex) -> &mut N {
        self.get_node_data_mut(node)
    }
}
/// `graph[(u, v)]` borrows the data of the edge between `u` and `v`, in either
/// order; panics if the pair is not connected.
impl<N: Default, E: Default> std::ops::Index<(NodeIndex, NodeIndex)> for NxGraph<N, E> {
    type Output = E;

    fn index(&self, (u, v): (NodeIndex, NodeIndex)) -> &E {
        self.get_edge_data(u, v)
    }
}
/// `graph[(u, v)] = …` mutably borrows the data of the edge between `u` and
/// `v`, in either order; panics if the pair is not connected.
impl<N: Default, E: Default> std::ops::IndexMut<(NodeIndex, NodeIndex)> for NxGraph<N, E> {
    fn index_mut(&mut self, (u, v): (NodeIndex, NodeIndex)) -> &mut E {
        self.get_edge_data_mut(u, v)
    }
}
/// Iterator over the nodes of an [`NxGraph`]. It yields each node index with a
/// reference to its data and can also be indexed by `NodeIndex`.
pub struct Nodes<'a, N: Default, E: Default> {
    /// Node indices captured when the iterator was created.
    nodes: std::vec::IntoIter<NodeIndex>,
    /// Graph the node data is read from.
    parent: &'a NxGraph<N, E>,
}
impl<'a, N, E> Nodes<'a, N, E>
where
N: Default,
E: Default,
{
fn new(g: &'a NxGraph<N, E>) -> Self {
let nodes: Vec<_> = g.graph.node_indices().collect();
Self {
parent: g,
nodes: nodes.into_iter(),
}
}
}
impl<'a, N, E> Iterator for Nodes<'a, N, E>
where
    N: Default,
    E: Default,
{
    type Item = (NodeIndex, &'a N);

    /// Yields the next snapshotted node index together with its data.
    fn next(&mut self) -> Option<Self::Item> {
        // Copy the `&'a` graph reference out so the returned borrow carries
        // lifetime `'a` rather than the lifetime of `&mut self`.
        let parent = self.parent;
        self.nodes.next().map(move |cur| (cur, &parent.graph[cur]))
    }

    /// The index snapshot gives the exact number of remaining items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.nodes.size_hint()
    }
}
/// `nodes[n]` reads node data straight from the parent graph; panics if `n`
/// is absent.
impl<N: Default, E: Default> std::ops::Index<NodeIndex> for Nodes<'_, N, E> {
    type Output = N;

    fn index(&self, n: NodeIndex) -> &N {
        let graph = self.parent;
        &graph[n]
    }
}
/// Iterator over the edges of an [`NxGraph`]. It yields `(u, v, &data)` for
/// each edge and can also be indexed by an unordered node pair.
pub struct Edges<'a, N: Default, E: Default> {
    /// Graph the edge endpoints and data are read from.
    parent: &'a NxGraph<N, E>,
    /// Edge indices captured when the iterator was created.
    edges: std::vec::IntoIter<EdgeIndex>,
}
impl<'a, N, E> Edges<'a, N, E>
where
N: Default,
E: Default,
{
fn new(g: &'a NxGraph<N, E>) -> Self {
let edges: Vec<_> = g.graph.edge_indices().collect();
Self {
parent: g,
edges: edges.into_iter(),
}
}
}
impl<'a, N, E> Iterator for Edges<'a, N, E>
where
    N: Default,
    E: Default,
{
    type Item = (NodeIndex, NodeIndex, &'a E);

    /// Yields the endpoints and data of the next snapshotted edge.
    ///
    /// # Panics
    /// Panics with `"no graph endpoints"` if petgraph cannot report the
    /// endpoints of a snapshotted edge index.
    fn next(&mut self) -> Option<Self::Item> {
        // Copy the `&'a` graph reference out so the returned borrow carries
        // lifetime `'a` rather than the lifetime of `&mut self`.
        let parent = self.parent;
        self.edges.next().map(move |cur| {
            let (u, v) = parent
                .graph
                .edge_endpoints(cur)
                .expect("no graph endpoints");
            (u, v, &parent.graph[cur])
        })
    }

    /// The index snapshot gives the exact number of remaining items.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.edges.size_hint()
    }
}
/// `edges[(u, v)]` reads edge data straight from the parent graph, in either
/// order; panics if the pair is not connected.
impl<N: Default, E: Default> std::ops::Index<(NodeIndex, NodeIndex)> for Edges<'_, N, E> {
    type Output = E;

    fn index(&self, (u, v): (NodeIndex, NodeIndex)) -> &E {
        let graph = self.parent;
        &graph[(u, v)]
    }
}
impl<N: Default, E: Default> NxGraph<N, E> {
    /// Returns an iterator over `(index, &data)` for every node. The returned
    /// value can also be indexed by `NodeIndex`.
    pub fn nodes(&self) -> Nodes<'_, N, E> {
        Nodes::new(self)
    }

    /// Returns an iterator over `(u, v, &data)` for every edge. The returned
    /// value can also be indexed by an unordered node pair.
    pub fn edges(&self) -> Edges<'_, N, E> {
        Edges::new(self)
    }
}
#[cfg(test)]
mod test {
use super::*;
// Minimal edge payload, used to check that edge data round-trips through the graph.
#[derive(Clone, Default, Debug, PartialEq)]
struct Edge {
weight: f64,
}
// Minimal node payload, used to check that node data round-trips through the graph.
#[derive(Clone, Default, Debug, PartialEq)]
struct Node {
position: [f64; 3],
}
// End-to-end walk through the public API. Each step relies on the state
// left by the previous steps, so the statements must stay in order.
#[test]
fn test_graph() {
let mut g = NxGraph::new();
let n1 = g.add_node(Node::default());
let n2 = g.add_node(Node::default());
let n3 = g.add_node(Node::default());
// Adding an edge to an already-connected pair overwrites its data instead
// of creating a parallel edge.
g.add_edge(n1, n2, Edge { weight: 1.0 });
assert_eq!(1, g.number_of_edges());
g.add_edge(n1, n2, Edge { weight: 2.0 });
assert_eq!(1, g.number_of_edges());
assert_eq!(g[(n1, n2)].weight, 2.0);
// Adding and then removing an isolated node leaves the counts unchanged.
g.add_edge(n1, n3, Edge::default());
let n4 = g.add_node(Node::default());
let _ = g.remove_node(n4);
assert_eq!(g.number_of_nodes(), 3);
assert_eq!(g.number_of_edges(), 2);
// remove_edge / remove_node hand back the stored data. Removing an edge
// between an unconnected pair yields None.
let node = Node { position: [1.0; 3] };
let n4 = g.add_node(node.clone());
let edge = Edge { weight: 2.2 };
g.add_edge(n1, n4, edge.clone());
let x = g.remove_edge(n2, n4);
assert_eq!(x, None);
let x = g.remove_edge(n1, n4);
assert_eq!(x, Some(edge));
let x = g.remove_node(n4);
assert_eq!(x, Some(node));
// Membership queries, plus removing the (n1, n3) edge.
assert!(g.has_node(n1));
assert!(g.has_edge(n1, n2));
assert!(!g.has_edge(n2, n3));
let _ = g.remove_edge(n1, n3);
assert_eq!(g.number_of_edges(), 1);
assert!(!g.has_edge(n1, n3));
// Writes through IndexMut are visible via the nodes()/edges() views, and
// edge lookup works with the endpoints in either order.
g[n1].position = [1.9; 3];
let nodes = g.nodes();
assert_eq!(nodes[n1].position, [1.9; 3]);
g[(n1, n2)].weight = 0.3;
let edges = g.edges();
assert_eq!(edges[(n1, n2)].weight, 0.3);
assert_eq!(edges[(n2, n1)].weight, 0.3);
// Smoke test of the iterators; output goes to stderr via dbg!.
for (u, node_data) in g.nodes() {
dbg!(u, node_data);
}
for (u, v, edge_data) in g.edges() {
dbg!(u, v, edge_data);
}
for u in g.neighbors(n1) {
dbg!(&g[u]);
}
// clear() empties the graph of both nodes and edges.
g.clear();
assert_eq!(g.number_of_nodes(), 0);
assert_eq!(g.number_of_edges(), 0);
}
}