use crate::util::errors::Res;
use petgraph::{
self,
graph::NodeIndex,
visit::{Bfs, EdgeRef, IntoNodeReferences, Walker},
Direction,
};
use std::{
collections::HashMap,
ops::{Index, IndexMut},
};
/// A thin wrapper around a directed [`petgraph::Graph`] that adds helpers for
/// looking nodes up by weight and walking parents, children and sub-trees.
///
/// `T` is the node weight type and `E` the edge weight type (`()` by default).
/// The struct itself carries no trait bounds; the `impl` blocks that need
/// `T: Eq` declare it themselves.
#[derive(Debug, Clone)]
pub struct Graph<T, E = ()> {
    /// The underlying petgraph graph. It is public so callers can fall back to
    /// the full petgraph API when the helpers here are not enough.
    pub inner: petgraph::Graph<T, E>,
}
impl<T: Eq, E> Graph<T, E> {
pub fn new(graph: petgraph::Graph<T, E>) -> Self {
Graph { inner: graph }
}
pub fn find_id(&self, node: &T) -> Option<NodeIndex> {
self.inner
.node_references()
.find(|(_, weight)| *weight == node)
.map(|(index, _)| index)
}
pub fn find_by<F>(&self, f: F) -> Option<&T>
where
F: Fn(&T) -> bool,
{
let node = self.inner.node_references().find(|(_, node)| f(node))?.1;
Some(node)
}
pub fn sub_tree<'a>(
&'a self,
root_id: NodeIndex,
) -> impl Iterator<Item = (NodeIndex, &T)> + 'a {
Bfs::new(&self.inner, root_id)
.iter(&self.inner)
.map(move |node_id| (node_id, &self.inner[node_id]))
}
pub fn children<'a>(
&'a self,
parent_id: NodeIndex,
) -> impl Iterator<Item = (NodeIndex, &T)> + 'a {
self.inner
.neighbors_directed(parent_id, Direction::Outgoing)
.map(move |node_id| (node_id, &self.inner[node_id]))
}
pub fn parents<'a>(
&'a self,
child_id: NodeIndex,
) -> impl Iterator<Item = (NodeIndex, &T)> + 'a {
self.inner
.neighbors_directed(child_id, Direction::Incoming)
.map(move |node_id| (node_id, &self.inner[node_id]))
}
pub fn map<U, V, F, G>(&self, mut f: F, mut g: G) -> Res<Graph<U, V>>
where
U: Eq,
F: FnMut(NodeIndex, &T) -> Res<U>,
G: FnMut(&E) -> Res<V>,
{
let mut tree = petgraph::Graph::new();
let mut node_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
for (idx, weight) in self.inner.node_references() {
let new_idx = tree.add_node(f(idx, weight)?);
node_map.insert(idx, new_idx);
}
for edge in self.inner.edge_references() {
tree.add_edge(
node_map[&edge.source()],
node_map[&edge.target()],
g(edge.weight())?,
);
}
Ok(Graph::new(tree))
}
pub fn filter_map<U, V, F, G>(&mut self, mut f: F, mut g: G) -> Graph<U, V>
where
U: Eq,
F: FnMut(&T) -> Option<U>,
G: FnMut(&E) -> Option<V>,
{
Graph {
inner: self.inner.filter_map(|_, i| f(i), |_, j| g(j)),
}
}
}
/// Read access to node weights with `graph[node_index]` syntax.
impl<T: Eq, E> Index<NodeIndex> for Graph<T, E> {
    type Output = T;

    /// Borrows the weight stored at `index`, delegating to the inner petgraph
    /// graph (which panics if `index` is not a node of the graph).
    fn index(&self, index: NodeIndex) -> &Self::Output {
        Index::index(&self.inner, index)
    }
}
/// Write access to node weights with `graph[node_index] = ...` syntax.
impl<T: Eq, E> IndexMut<NodeIndex> for Graph<T, E> {
    /// Mutably borrows the weight stored at `index`, delegating to the inner
    /// petgraph graph (which panics if `index` is not a node of the graph).
    fn index_mut(&mut self, index: NodeIndex) -> &mut Self::Output {
        IndexMut::index_mut(&mut self.inner, index)
    }
}
/// An empty graph (no nodes, no edges) whose edges carry no weight (`E = ()`).
impl<T: Eq> Default for Graph<T> {
    fn default() -> Self {
        Self::new(petgraph::Graph::new())
    }
}