use std::cmp;
use std::collections::BTreeMap;
use std::fs::File;
use std::io::prelude::*;
use std::io::BufReader;
use std::ops::{Index, IndexMut};
use std::str;
use hashbrown::{HashMap, HashSet};
use pyo3::class::PyMappingProtocol;
use pyo3::exceptions::PyIndexError;
use pyo3::gc::{PyGCProtocol, PyVisit};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyList, PyLong, PyString, PyTuple};
use pyo3::PyTraverseError;
use pyo3::Python;
use ndarray::prelude::*;
use numpy::PyReadonlyArray2;
use super::dot_utils::build_dot;
use super::iterators::{EdgeList, NodeIndices, WeightedEdgeList};
use super::{NoEdgeBetweenNodes, NodesRemoved};
use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;
use petgraph::stable_graph::StableUnGraph;
use petgraph::visit::{
GetAdjacencyMatrix, GraphBase, GraphProp, IntoEdgeReferences, IntoEdges,
IntoNeighbors, IntoNeighborsDirected, IntoNodeIdentifiers,
IntoNodeReferences, NodeCompactIndexable, NodeCount, NodeFiltered,
NodeIndexable, Visitable,
};
/// Python-exposed undirected graph whose node and edge payloads are
/// arbitrary Python objects.
#[pyclass(module = "retworkx", subclass, gc)]
#[text_signature = "(/, multigraph=True)"]
pub struct PyGraph {
/// Underlying petgraph stable undirected graph that stores the payloads.
pub graph: StableUnGraph<PyObject, PyObject>,
/// Set to `true` by `remove_node`; exposed through the `NodesRemoved` impl.
pub node_removed: bool,
/// When `false`, the edge-adding methods overwrite the weight of an
/// existing edge between two nodes instead of adding a parallel edge.
pub multigraph: bool,
}
/// Shorthand for petgraph's stable-graph `Edges` iterator specialised to
/// undirected graphs.
pub type Edges<'a, E> =
petgraph::stable_graph::Edges<'a, E, petgraph::Undirected>;
/// Declares the node and edge identifier types `PyGraph` uses with
/// petgraph's visitor traits.
impl GraphBase for PyGraph {
type NodeId = NodeIndex;
type EdgeId = EdgeIndex;
}
/// Exposes the `node_removed` flag through the crate-local `NodesRemoved`
/// trait.
impl<'a> NodesRemoved for &'a PyGraph {
fn nodes_removed(&self) -> bool {
self.node_removed
}
}
/// Forwards the node count of the wrapped graph.
impl NodeCount for PyGraph {
fn node_count(&self) -> usize {
self.graph.node_count()
}
}
/// Marks `PyGraph` as an undirected graph for petgraph's generic code.
impl GraphProp for PyGraph {
type EdgeType = petgraph::Undirected;
fn is_directed(&self) -> bool {
false
}
}
/// Reuses the wrapped graph's visit-map type and delegates its creation and
/// reset to the wrapped graph.
impl petgraph::visit::Visitable for PyGraph {
type Map = <StableUnGraph<PyObject, PyObject> as Visitable>::Map;
fn visit_map(&self) -> Self::Map {
self.graph.visit_map()
}
fn reset_map(&self, map: &mut Self::Map) {
self.graph.reset_map(map)
}
}
/// Declares that both node and edge weights are Python objects.
impl petgraph::visit::Data for PyGraph {
type NodeWeight = PyObject;
type EdgeWeight = PyObject;
}
/// Read-only weight lookup, delegated to the wrapped graph. Both methods
/// return whatever the wrapped graph returns for the given id.
impl petgraph::data::DataMap for PyGraph {
fn node_weight(&self, id: Self::NodeId) -> Option<&Self::NodeWeight> {
self.graph.node_weight(id)
}
fn edge_weight(&self, id: Self::EdgeId) -> Option<&Self::EdgeWeight> {
self.graph.edge_weight(id)
}
}
/// Mutable weight lookup, delegated to the wrapped graph.
impl petgraph::data::DataMapMut for PyGraph {
fn node_weight_mut(
&mut self,
id: Self::NodeId,
) -> Option<&mut Self::NodeWeight> {
self.graph.node_weight_mut(id)
}
fn edge_weight_mut(
&mut self,
id: Self::EdgeId,
) -> Option<&mut Self::EdgeWeight> {
self.graph.edge_weight_mut(id)
}
}
/// Neighbor iteration over a borrowed `PyGraph`, delegated to the wrapped
/// graph.
impl<'a> IntoNeighbors for &'a PyGraph {
type Neighbors = petgraph::stable_graph::Neighbors<'a, PyObject>;
fn neighbors(self, n: NodeIndex) -> Self::Neighbors {
self.graph.neighbors(n)
}
}
/// Direction-aware neighbor iteration over a borrowed `PyGraph`, delegated
/// to the wrapped graph.
impl<'a> IntoNeighborsDirected for &'a PyGraph {
    type NeighborsDirected = petgraph::stable_graph::Neighbors<'a, PyObject>;

    // Return the associated type this trait declares. The previous signature
    // named `Self::Neighbors`, which is the same concrete type but belongs to
    // the `IntoNeighbors` supertrait.
    fn neighbors_directed(
        self,
        n: NodeIndex,
        d: petgraph::Direction,
    ) -> Self::NeighborsDirected {
        self.graph.neighbors_directed(n, d)
    }
}
/// Iteration over all edge references of a borrowed `PyGraph`, delegated to
/// the wrapped graph.
impl<'a> IntoEdgeReferences for &'a PyGraph {
type EdgeRef = petgraph::stable_graph::EdgeReference<'a, PyObject>;
type EdgeReferences = petgraph::stable_graph::EdgeReferences<'a, PyObject>;
fn edge_references(self) -> Self::EdgeReferences {
self.graph.edge_references()
}
}
/// Iteration over the edges of one node, delegated to the wrapped graph.
impl<'a> IntoEdges for &'a PyGraph {
type Edges = Edges<'a, PyObject>;
fn edges(self, a: Self::NodeId) -> Self::Edges {
self.graph.edges(a)
}
}
/// Iteration over all node identifiers, delegated to the wrapped graph.
impl<'a> IntoNodeIdentifiers for &'a PyGraph {
type NodeIdentifiers = petgraph::stable_graph::NodeIndices<'a, PyObject>;
fn node_identifiers(self) -> Self::NodeIdentifiers {
self.graph.node_identifiers()
}
}
/// Iteration over `(index, weight)` pairs for all nodes, delegated to the
/// wrapped graph.
impl<'a> IntoNodeReferences for &'a PyGraph {
type NodeRef = (NodeIndex, &'a PyObject);
type NodeReferences = petgraph::stable_graph::NodeReferences<'a, PyObject>;
fn node_references(self) -> Self::NodeReferences {
self.graph.node_references()
}
}
/// Conversion between node identifiers and plain `usize` indices, delegated
/// to the wrapped graph.
impl NodeIndexable for PyGraph {
fn node_bound(&self) -> usize {
self.graph.node_bound()
}
fn to_index(&self, ix: NodeIndex) -> usize {
self.graph.to_index(ix)
}
fn from_index(&self, ix: usize) -> Self::NodeId {
self.graph.from_index(ix)
}
}
/// Marker impl opting `PyGraph` in to petgraph's `NodeCompactIndexable`.
// NOTE(review): that trait promises node indices are compact in
// `0..node_count()`, which presumably stops holding once a node has been
// removed; callers likely consult `node_removed` first -- confirm.
impl NodeCompactIndexable for PyGraph {}
/// `graph[node_index]` yields the node's Python payload by indexing the
/// wrapped graph.
impl Index<NodeIndex> for PyGraph {
type Output = PyObject;
fn index(&self, index: NodeIndex) -> &PyObject {
&self.graph[index]
}
}
/// Mutable counterpart of `graph[node_index]`, indexing the wrapped graph.
impl IndexMut<NodeIndex> for PyGraph {
fn index_mut(&mut self, index: NodeIndex) -> &mut PyObject {
&mut self.graph[index]
}
}
/// `graph[edge_index]` yields the edge's Python payload by indexing the
/// wrapped graph.
impl Index<EdgeIndex> for PyGraph {
type Output = PyObject;
fn index(&self, index: EdgeIndex) -> &PyObject {
&self.graph[index]
}
}
/// Mutable counterpart of `graph[edge_index]`, indexing the wrapped graph.
impl IndexMut<EdgeIndex> for PyGraph {
fn index_mut(&mut self, index: EdgeIndex) -> &mut PyObject {
&mut self.graph[index]
}
}
/// Adjacency-matrix construction and lookup, delegated to the wrapped graph
/// and using its matrix type.
impl GetAdjacencyMatrix for PyGraph {
type AdjMatrix =
<StableUnGraph<PyObject, PyObject> as GetAdjacencyMatrix>::AdjMatrix;
fn adjacency_matrix(&self) -> Self::AdjMatrix {
self.graph.adjacency_matrix()
}
fn is_adjacent(
&self,
matrix: &Self::AdjMatrix,
a: NodeIndex,
b: NodeIndex,
) -> bool {
self.graph.is_adjacent(matrix, a, b)
}
}
#[pymethods]
impl PyGraph {
#[new]
#[args(multigraph = "true")]
fn new(multigraph: bool) -> Self {
PyGraph {
graph: StableUnGraph::<PyObject, PyObject>::default(),
node_removed: false,
multigraph,
}
}
/// Build the pickle state: a dict with the keys `"nodes"` (node index ->
/// payload), `"nodes_removed"`, `"multigraph"` and `"edges"` (a list of
/// `(source, target, weight)` tuples).
fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
    let state = PyDict::new(py);
    let nodes = PyDict::new(py);
    state.set_item("nodes", nodes)?;
    state.set_item("nodes_removed", self.node_removed)?;
    state.set_item("multigraph", self.multigraph)?;

    // `nodes` is already stored in `state`; filling it afterwards is
    // visible through the same Python dict object.
    for index in self.graph.node_indices() {
        nodes.set_item(index.index(), &self.graph[index])?;
    }

    let mut edges: Vec<PyObject> = Vec::with_capacity(self.graph.edge_count());
    for edge in self.graph.edge_indices() {
        let (source, target) = self.graph.edge_endpoints(edge).unwrap();
        let weight = self.graph.edge_weight(edge);
        edges.push((source.index(), target.index(), weight).to_object(py));
    }
    let edges: PyObject = PyList::new(py, edges).into();
    state.set_item("edges", edges)?;

    Ok(state.into())
}
/// Restore the graph from a state dict with the keys `"nodes"`, `"edges"`,
/// `"nodes_removed"` and `"multigraph"` (the layout `__getstate__` emits).
///
/// Existing graph contents are discarded first. A missing key makes the
/// corresponding `unwrap()` panic; a value of an unexpected Python type is
/// returned as an error by the `?` following each downcast.
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
// Start from an empty graph; previous nodes and edges are dropped.
self.graph = StableUnGraph::<PyObject, PyObject>::default();
let dict_state = state.cast_as::<PyDict>(py)?;
let nodes_dict =
dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list =
dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let nodes_removed_raw = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?;
self.node_removed = nodes_removed_raw.extract()?;
let multigraph_raw = dict_state
.get_item("multigraph")
.unwrap()
.downcast::<PyBool>()?;
self.multigraph = multigraph_raw.extract()?;
// Collect the saved node indices so the highest one can be determined.
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
node_indices.push(tmp_index.extract()?);
}
// No nodes were saved: leave the graph empty (saved edges are not read).
if node_indices.is_empty() {
return Ok(());
}
let max_index: usize = *node_indices.iter().max().unwrap();
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
// Add one node per position until `max_index` is covered. Positions that
// are absent from the saved dict get a `None` placeholder so that each saved
// node is added at the position matching its key; the placeholders are
// removed again right after the loop.
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
// Re-create every edge from its `(source, target, weight)` tuple.
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0).downcast::<PyLong>()?;
let parent: usize = raw_p_index.extract()?;
let p_index = NodeIndex::new(parent);
let raw_c_index = edge.get_item(1).downcast::<PyLong>()?;
let child: usize = raw_c_index.extract()?;
let c_index = NodeIndex::new(child);
let edge_data = edge.get_item(2);
self.graph.add_edge(p_index, c_index, edge_data.into());
}
Ok(())
}
/// Whether the graph was configured as a multigraph (read-only property).
#[getter]
fn multigraph(&self) -> bool {
self.multigraph
}
/// Return the payload of every edge, in edge-index order.
#[text_signature = "(self)"]
pub fn edges(&self) -> Vec<&PyObject> {
    let mut weights = Vec::with_capacity(self.graph.edge_count());
    for edge in self.graph.edge_indices() {
        weights.push(&self.graph[edge]);
    }
    weights
}
/// Return the payload of every node, in node-index order.
#[text_signature = "(self)"]
pub fn nodes(&self) -> Vec<&PyObject> {
    let mut weights = Vec::with_capacity(self.graph.node_count());
    for node in self.graph.node_indices() {
        weights.push(&self.graph[node]);
    }
    weights
}
/// Return the integer indices of all nodes currently in the graph.
#[text_signature = "(self)"]
pub fn node_indexes(&self) -> NodeIndices {
    let nodes = self
        .graph
        .node_indices()
        .map(|node| node.index())
        .collect();
    NodeIndices { nodes }
}
/// Return `True` if at least one edge connects `node_a` and `node_b`.
#[text_signature = "(self, node_a, node_b, /)"]
pub fn has_edge(&self, node_a: usize, node_b: usize) -> bool {
    self.graph
        .find_edge(NodeIndex::new(node_a), NodeIndex::new(node_b))
        .is_some()
}
/// Return the payload of the edge that `find_edge` reports between
/// `node_a` and `node_b`.
///
/// Raises `NoEdgeBetweenNodes` when no such edge exists.
#[text_signature = "(self, node_a, node_b, /)"]
pub fn get_edge_data(
    &self,
    node_a: usize,
    node_b: usize,
) -> PyResult<&PyObject> {
    let edge_index = self
        .graph
        .find_edge(NodeIndex::new(node_a), NodeIndex::new(node_b))
        .ok_or_else(|| {
            NoEdgeBetweenNodes::new_err("No edge found between nodes")
        })?;
    Ok(self.graph.edge_weight(edge_index).unwrap())
}
/// Replace the payload of the edge that `find_edge` reports between
/// `source` and `target` with `edge`.
///
/// Raises `NoEdgeBetweenNodes` when no such edge exists.
// The text signature previously lacked the comma before `/`, which is not a
// valid Python signature string.
#[text_signature = "(self, source, target, edge, /)"]
pub fn update_edge(
    &mut self,
    source: usize,
    target: usize,
    edge: PyObject,
) -> PyResult<()> {
    let edge_index = self
        .graph
        .find_edge(NodeIndex::new(source), NodeIndex::new(target))
        .ok_or_else(|| {
            NoEdgeBetweenNodes::new_err("No edge found between nodes")
        })?;
    // `find_edge` just returned this index, so the weight is present.
    let data = self.graph.edge_weight_mut(edge_index).unwrap();
    *data = edge;
    Ok(())
}
/// Replace the payload of the edge with index `edge_index` with `edge`.
///
/// Raises `IndexError` when no edge has that index.
// The text signature previously advertised `(self, source, target, edge /)`,
// which does not match this method's parameters.
#[text_signature = "(self, edge_index, edge, /)"]
pub fn update_edge_by_index(
    &mut self,
    edge_index: usize,
    edge: PyObject,
) -> PyResult<()> {
    let data = self
        .graph
        .edge_weight_mut(EdgeIndex::new(edge_index))
        .ok_or_else(|| PyIndexError::new_err("No edge found for index"))?;
    *data = edge;
    Ok(())
}
/// Return the payload stored on `node`.
///
/// Raises `IndexError` when the graph has no node with that index.
#[text_signature = "(self, node, /)"]
pub fn get_node_data(&self, node: usize) -> PyResult<&PyObject> {
    self.graph
        .node_weight(NodeIndex::new(node))
        .ok_or_else(|| PyIndexError::new_err("No node found for index"))
}
/// Return the payloads of every edge of `node_a` whose target is `node_b`.
///
/// Raises `NoEdgeBetweenNodes` when there is no such edge.
#[text_signature = "(self, node_a, node_b, /)"]
pub fn get_all_edge_data(
    &self,
    node_a: usize,
    node_b: usize,
) -> PyResult<Vec<&PyObject>> {
    let wanted = NodeIndex::new(node_b);
    let mut weights: Vec<&PyObject> = Vec::new();
    for edge in self.graph.edges(NodeIndex::new(node_a)) {
        if edge.target() == wanted {
            weights.push(edge.weight());
        }
    }
    if weights.is_empty() {
        return Err(NoEdgeBetweenNodes::new_err(
            "No edge found between nodes",
        ));
    }
    Ok(weights)
}
/// Return every edge as a `(source, target)` pair of node indices.
#[text_signature = "(self)"]
pub fn edge_list(&self) -> EdgeList {
    let edges = self
        .edge_references()
        .map(|edge| {
            let source = edge.source().index();
            let target = edge.target().index();
            (source, target)
        })
        .collect();
    EdgeList { edges }
}
/// Return every edge as a `(source, target, weight)` triple, where the
/// weight is a new reference to the edge's payload.
#[text_signature = "(self)"]
pub fn weighted_edge_list(&self, py: Python) -> WeightedEdgeList {
    let edges = self
        .edge_references()
        .map(|edge| {
            let source = edge.source().index();
            let target = edge.target().index();
            (source, target, edge.weight().clone_ref(py))
        })
        .collect();
    WeightedEdgeList { edges }
}
/// Remove `node` from the graph and record that a removal happened.
///
/// The `node_removed` flag is set even if no node had that index; the
/// result of the underlying removal is not inspected.
#[text_signature = "(self, node, /)"]
pub fn remove_node(&mut self, node: usize) -> PyResult<()> {
    self.graph.remove_node(NodeIndex::new(node));
    self.node_removed = true;
    Ok(())
}
/// Add an edge between `node_a` and `node_b` carrying `edge`, returning the
/// edge index.
///
/// When `multigraph` is `false` and an edge between the two nodes already
/// exists, its payload is overwritten and its index is returned instead.
#[text_signature = "(self, node_a, node_b, edge, /)"]
pub fn add_edge(
    &mut self,
    node_a: usize,
    node_b: usize,
    edge: PyObject,
) -> PyResult<usize> {
    let source = NodeIndex::new(node_a);
    let target = NodeIndex::new(node_b);
    // Only look for an existing edge when parallel edges are disallowed.
    let existing = if self.multigraph {
        None
    } else {
        self.graph.find_edge(source, target)
    };
    let index = match existing {
        Some(index) => {
            *self.graph.edge_weight_mut(index).unwrap() = edge;
            index
        }
        None => self.graph.add_edge(source, target, edge),
    };
    Ok(index.index())
}
/// Add every `(node_a, node_b, weight)` triple in `obj_list` and return
/// the resulting edge indices in input order.
///
/// When `multigraph` is `false`, a triple whose endpoints are already
/// connected overwrites that edge's payload and reports its index.
#[text_signature = "(self, obj_list, /)"]
pub fn add_edges_from(
    &mut self,
    obj_list: Vec<(usize, usize, PyObject)>,
) -> PyResult<Vec<usize>> {
    let mut indices: Vec<usize> = Vec::with_capacity(obj_list.len());
    for (node_a, node_b, weight) in obj_list {
        let source = NodeIndex::new(node_a);
        let target = NodeIndex::new(node_b);
        let existing = if self.multigraph {
            None
        } else {
            self.graph.find_edge(source, target)
        };
        let index = match existing {
            Some(index) => {
                *self.graph.edge_weight_mut(index).unwrap() = weight;
                index
            }
            None => self.graph.add_edge(source, target, weight),
        };
        indices.push(index.index());
    }
    Ok(indices)
}
/// Add every `(node_a, node_b)` pair in `obj_list` with `None` as payload
/// and return the resulting edge indices in input order.
///
/// When `multigraph` is `false`, a pair that is already connected has its
/// existing edge's payload reset to `None` and reports that edge's index.
#[text_signature = "(self, obj_list, /)"]
pub fn add_edges_from_no_data(
    &mut self,
    py: Python,
    obj_list: Vec<(usize, usize)>,
) -> PyResult<Vec<usize>> {
    let mut indices: Vec<usize> = Vec::with_capacity(obj_list.len());
    for (node_a, node_b) in obj_list {
        let source = NodeIndex::new(node_a);
        let target = NodeIndex::new(node_b);
        let existing = if self.multigraph {
            None
        } else {
            self.graph.find_edge(source, target)
        };
        let index = match existing {
            Some(index) => {
                *self.graph.edge_weight_mut(index).unwrap() = py.None();
                index
            }
            None => self.graph.add_edge(source, target, py.None()),
        };
        indices.push(index.index());
    }
    Ok(indices)
}
/// Add an edge with a `None` payload for every `(source, target)` pair,
/// first adding `None`-payload nodes until the node count exceeds the
/// larger of the two indices.
///
/// When `multigraph` is `false`, a pair that is already connected has its
/// existing edge's payload reset to `None` instead.
#[text_signature = "(self, edge_list, /)"]
pub fn extend_from_edge_list(
    &mut self,
    py: Python,
    edge_list: Vec<(usize, usize)>,
) {
    for (source, target) in edge_list {
        let highest = cmp::max(source, target);
        while self.node_count() <= highest {
            self.graph.add_node(py.None());
        }
        let source_index = NodeIndex::new(source);
        let target_index = NodeIndex::new(target);
        let existing = if self.multigraph {
            None
        } else {
            self.graph.find_edge(source_index, target_index)
        };
        match existing {
            Some(index) => {
                *self.graph.edge_weight_mut(index).unwrap() = py.None();
            }
            None => {
                self.graph.add_edge(source_index, target_index, py.None());
            }
        }
    }
}
/// Add an edge for every `(source, target, weight)` triple, first adding
/// `None`-payload nodes until the node count exceeds the larger of the two
/// indices.
///
/// When `multigraph` is `false`, a triple whose endpoints are already
/// connected overwrites that edge's payload instead.
// The text signature previously misspelled the parameter as `edge_lsit`.
#[text_signature = "(self, edge_list, /)"]
pub fn extend_from_weighted_edge_list(
    &mut self,
    py: Python,
    edge_list: Vec<(usize, usize, PyObject)>,
) {
    for (source, target, weight) in edge_list {
        let max_index = cmp::max(source, target);
        while max_index >= self.node_count() {
            self.graph.add_node(py.None());
        }
        let source_index = NodeIndex::new(source);
        let target_index = NodeIndex::new(target);
        // Only look for an existing edge when parallel edges are disallowed.
        let existing = if self.multigraph {
            None
        } else {
            self.graph.find_edge(source_index, target_index)
        };
        match existing {
            Some(index) => {
                *self.graph.edge_weight_mut(index).unwrap() = weight;
            }
            None => {
                self.graph.add_edge(source_index, target_index, weight);
            }
        }
    }
}
/// Remove the edge that `find_edge` reports between `node_a` and `node_b`.
///
/// Raises `NoEdgeBetweenNodes` when no such edge exists.
#[text_signature = "(self, node_a, node_b, /)"]
pub fn remove_edge(
    &mut self,
    node_a: usize,
    node_b: usize,
) -> PyResult<()> {
    let found = self
        .graph
        .find_edge(NodeIndex::new(node_a), NodeIndex::new(node_b));
    if let Some(edge_index) = found {
        self.graph.remove_edge(edge_index);
        Ok(())
    } else {
        Err(NoEdgeBetweenNodes::new_err("No edge found between nodes"))
    }
}
/// Remove the edge with index `edge`.
///
/// The result of the underlying removal is not inspected, so an index
/// without an edge is not reported as an error.
#[text_signature = "(self, edge, /)"]
pub fn remove_edge_from_index(&mut self, edge: usize) -> PyResult<()> {
    self.graph.remove_edge(EdgeIndex::new(edge));
    Ok(())
}
/// Remove one edge for each `(node_a, node_b)` pair in `index_list`.
///
/// Raises `NoEdgeBetweenNodes` at the first pair that has no edge; edges
/// removed for earlier pairs stay removed.
#[text_signature = "(self, index_list, /)"]
pub fn remove_edges_from(
    &mut self,
    index_list: Vec<(usize, usize)>,
) -> PyResult<()> {
    for (node_a, node_b) in index_list {
        let source = NodeIndex::new(node_a);
        let target = NodeIndex::new(node_b);
        let edge_index =
            self.graph.find_edge(source, target).ok_or_else(|| {
                NoEdgeBetweenNodes::new_err("No edge found between nodes")
            })?;
        self.graph.remove_edge(edge_index);
    }
    Ok(())
}
/// Add a node carrying `obj` and return its index.
#[text_signature = "(self, obj, /)"]
pub fn add_node(&mut self, obj: PyObject) -> PyResult<usize> {
    Ok(self.graph.add_node(obj).index())
}
/// Add one node per object in `obj_list` and return the new node indices
/// in input order.
#[text_signature = "(self, obj_list, /)"]
pub fn add_nodes_from(&mut self, obj_list: Vec<PyObject>) -> NodeIndices {
    let mut nodes: Vec<usize> = Vec::with_capacity(obj_list.len());
    for obj in obj_list {
        nodes.push(self.graph.add_node(obj).index());
    }
    NodeIndices { nodes }
}
/// Remove every node whose index appears in `index_list`.
///
/// Indices without a node are not reported as errors; the result of each
/// underlying removal is not inspected.
#[text_signature = "(self, index_list, /)"]
pub fn remove_nodes_from(
    &mut self,
    index_list: Vec<usize>,
) -> PyResult<()> {
    for node in index_list {
        self.graph.remove_node(NodeIndex::new(node));
        // Mirror `remove_node`: record the removal so code reading the
        // flag (e.g. through `NodesRemoved`) sees that a node was removed.
        self.node_removed = true;
    }
    Ok(())
}
/// Return a map from each neighbor index of `node` to the payload of the
/// edge that `find_edge` reports between `node` and that neighbor.
// Takes `&self`: the body only reads the graph, so an exclusive borrow of
// the Python object is not needed.
#[text_signature = "(self, node, /)"]
pub fn adj(&self, node: usize) -> PyResult<HashMap<usize, &PyObject>> {
    let index = NodeIndex::new(node);
    let mut out_map: HashMap<usize, &PyObject> = HashMap::new();
    for neighbor in self.graph.neighbors(index) {
        // `neighbor` was produced by `neighbors(index)`, so a connecting
        // edge is expected to exist.
        let edge = self.graph.find_edge(index, neighbor).unwrap();
        let edge_w = self.graph.edge_weight(edge).unwrap();
        out_map.insert(neighbor.index(), edge_w);
    }
    Ok(out_map)
}
/// Return the distinct neighbor indices of `node`.
///
/// Duplicates are removed by passing the indices through a hash set, so
/// the order of the result is the set's iteration order.
#[text_signature = "(self, node, /)"]
pub fn neighbors(&self, node: usize) -> NodeIndices {
    let mut unique: HashSet<usize> = self
        .graph
        .neighbors(NodeIndex::new(node))
        .map(|neighbor| neighbor.index())
        .collect();
    let nodes = unique.drain().collect();
    NodeIndices { nodes }
}
/// Return the number of edges the wrapped graph yields for `node`.
#[text_signature = "(self, node, /)"]
pub fn degree(&self, node: usize) -> usize {
    self.graph.edges(NodeIndex::new(node)).count()
}
/// Render the graph with `build_dot`.
///
/// With `filename` set, the output is written to a newly created file at
/// that path and `None` is returned; otherwise the output is returned as a
/// Python string. `node_attr`, `edge_attr` and `graph_attr` are passed
/// through to `build_dot` unchanged.
#[text_signature = "(self, /, node_attr=None, edge_attr=None, graph_attr=None, filename=None)"]
pub fn to_dot(
    &self,
    py: Python,
    node_attr: Option<PyObject>,
    edge_attr: Option<PyObject>,
    graph_attr: Option<BTreeMap<String, String>>,
    filename: Option<String>,
) -> PyResult<Option<PyObject>> {
    match filename {
        Some(path) => {
            let mut file = File::create(path)?;
            build_dot(py, self, &mut file, graph_attr, node_attr, edge_attr)?;
            Ok(None)
        }
        None => {
            let mut buffer = Vec::<u8>::new();
            build_dot(
                py,
                self,
                &mut buffer,
                graph_attr,
                node_attr,
                edge_attr,
            )?;
            let text = str::from_utf8(&buffer)?;
            Ok(Some(PyString::new(py, text).to_object(py)))
        }
    }
}
#[staticmethod]
#[text_signature = "(path, /, comment=None, deliminator=None)"]
pub fn read_edge_list(
py: Python,
path: &str,
comment: Option<String>,
deliminator: Option<String>,
) -> PyResult<PyGraph> {
let file = File::open(path)?;
let buf_reader = BufReader::new(file);
let mut out_graph = StableUnGraph::<PyObject, PyObject>::default();
for line_raw in buf_reader.lines() {
let line = line_raw?;
let skip = match &comment {
Some(comm) => line.trim().starts_with(comm),
None => line.trim().is_empty(),
};
if skip {
continue;
}
let line_no_comments = match &comment {
Some(comm) => line
.find(comm)
.map(|idx| &line[..idx])
.unwrap_or(&line)
.trim()
.to_string(),
None => line,
};
let pieces: Vec<&str> = match &deliminator {
Some(del) => line_no_comments.split(del).collect(),
None => line_no_comments.split_whitespace().collect(),
};
let src = pieces[0].parse::<usize>()?;
let target = pieces[1].parse::<usize>()?;
let max_index = cmp::max(src, target);
while max_index >= out_graph.node_count() {
out_graph.add_node(py.None());
}
let weight = if pieces.len() > 2 {
let weight_str = match &deliminator {
Some(del) => pieces[2..].join(del),
None => pieces[2..].join(&' '.to_string()),
};
PyString::new(py, &weight_str).into()
} else {
py.None()
};
out_graph.add_edge(
NodeIndex::new(src),
NodeIndex::new(target),
weight,
);
}
Ok(PyGraph {
graph: out_graph,
node_removed: false,
multigraph: true,
})
}
/// Build a multigraph from a 2-D `f64` adjacency matrix.
///
/// One node is created per row, with the row number as its payload. Only
/// entries on or above the diagonal (column >= row) are read, and an edge
/// carrying the entry's value is added for each entry greater than `0.0`.
#[staticmethod]
#[text_signature = "(matrix, /)"]
pub fn from_adjacency_matrix<'p>(
    py: Python<'p>,
    matrix: PyReadonlyArray2<'p, f64>,
) -> PyGraph {
    let array = matrix.as_array();
    let row_count = array.shape()[0];
    let mut out_graph = StableUnGraph::<PyObject, PyObject>::default();
    for node in 0..row_count {
        out_graph.add_node(node.to_object(py));
    }
    for (index, row) in array.axis_iter(Axis(0)).enumerate() {
        let source_index = NodeIndex::new(index);
        // Start at the diagonal so each unordered pair is read once.
        for target_index in index..row.len() {
            if row[[target_index]] > 0.0 {
                out_graph.add_edge(
                    source_index,
                    NodeIndex::new(target_index),
                    row[[target_index]].to_object(py),
                );
            }
        }
    }
    PyGraph {
        graph: out_graph,
        node_removed: false,
        multigraph: true,
    }
}
#[text_signature = "(self, other, node_map, /, node_map_func=None, edge_map_func=None)"]
pub fn compose(
&mut self,
py: Python,
other: &PyGraph,
node_map: HashMap<usize, (usize, PyObject)>,
node_map_func: Option<PyObject>,
edge_map_func: Option<PyObject>,
) -> PyResult<PyObject> {
let mut new_node_map: HashMap<NodeIndex, NodeIndex> =
HashMap::with_capacity(other.node_count());
for node in other.graph.node_indices() {
let new_index = self.graph.add_node(weight_transform_callable(
py,
&node_map_func,
&other.graph[node],
)?);
new_node_map.insert(node, new_index);
}
for edge in other.graph.edge_references() {
let new_p_index = new_node_map.get(&edge.source()).unwrap();
let new_c_index = new_node_map.get(&edge.target()).unwrap();
let weight =
weight_transform_callable(py, &edge_map_func, edge.weight())?;
self.graph.add_edge(*new_p_index, *new_c_index, weight);
}
for (this_index, (index, weight)) in node_map.iter() {
let new_index = new_node_map.get(&NodeIndex::new(*index)).unwrap();
self.graph.add_edge(
NodeIndex::new(*this_index),
*new_index,
weight.clone_ref(py),
);
}
let out_dict = PyDict::new(py);
for (orig_node, new_node) in new_node_map.iter() {
out_dict.set_item(orig_node.index(), new_node.index())?;
}
Ok(out_dict.into())
}
/// Return a new graph made of the nodes listed in `nodes` and the edges
/// that a `NodeFiltered` view over those nodes yields.
///
/// Payloads are new references to the same Python objects. Nodes in the
/// result get fresh indices; `multigraph` is copied from this graph and
/// `node_removed` starts as `false`.
#[text_signature = "(self, nodes, /)"]
pub fn subgraph(&self, py: Python, nodes: Vec<usize>) -> PyGraph {
    let keep: HashSet<usize> = nodes.iter().cloned().collect();
    let is_kept = |node: NodeIndex| -> bool { keep.contains(&node.index()) };
    let filtered = NodeFiltered(self, is_kept);

    let mut out_graph = StableUnGraph::<PyObject, PyObject>::default();
    // Old node index -> index of its copy in `out_graph`.
    let mut index_map: HashMap<NodeIndex, NodeIndex> =
        HashMap::with_capacity(nodes.len());
    for (old_index, weight) in filtered.node_references() {
        let new_index = out_graph.add_node(weight.clone_ref(py));
        index_map.insert(old_index, new_index);
    }
    for edge in filtered.edge_references() {
        let new_source = *index_map.get(&edge.source()).unwrap();
        let new_target = *index_map.get(&edge.target()).unwrap();
        let weight = edge.weight().clone_ref(py);
        out_graph.add_edge(new_source, new_target, weight);
    }
    PyGraph {
        graph: out_graph,
        node_removed: false,
        multigraph: self.multigraph,
    }
}
}
// Python mapping protocol: `len(graph)`, `graph[i]`, `graph[i] = value` and
// `del graph[i]`, all keyed by node index.
#[pyproto]
impl PyMappingProtocol for PyGraph {
    // Number of nodes in the graph.
    fn __len__(&self) -> PyResult<usize> {
        Ok(self.graph.node_count())
    }

    // Payload of node `idx`; raises IndexError when there is no such node.
    fn __getitem__(&'p self, idx: usize) -> PyResult<&'p PyObject> {
        match self.graph.node_weight(NodeIndex::new(idx)) {
            Some(data) => Ok(data),
            None => Err(PyIndexError::new_err("No node found for index")),
        }
    }

    // Replace the payload of node `idx`; raises IndexError when there is no
    // such node.
    fn __setitem__(&'p mut self, idx: usize, value: PyObject) -> PyResult<()> {
        let data = match self.graph.node_weight_mut(NodeIndex::new(idx)) {
            Some(node_data) => node_data,
            None => {
                return Err(PyIndexError::new_err("No node found for index"))
            }
        };
        *data = value;
        Ok(())
    }

    // Remove node `idx`; raises IndexError when there is no such node.
    fn __delitem__(&'p mut self, idx: usize) -> PyResult<()> {
        match self.graph.remove_node(NodeIndex::new(idx)) {
            Some(_) => {
                // Mirror `remove_node`: record that a node was removed.
                // Previously `del graph[i]` left the flag untouched.
                self.node_removed = true;
                Ok(())
            }
            None => Err(PyIndexError::new_err("No node found for index")),
        }
    }
}
// Garbage-collector integration: report every Python object the graph holds
// and drop them all on request.
#[pyproto]
impl PyGCProtocol for PyGraph {
    // Visit each node payload, then each edge payload.
    fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
        for node in self.graph.node_indices() {
            let weight = self.graph.node_weight(node).unwrap();
            visit.call(weight)?;
        }
        for edge in self.graph.edge_indices() {
            let weight = self.graph.edge_weight(edge).unwrap();
            visit.call(weight)?;
        }
        Ok(())
    }

    // Replace the graph with an empty one and reset the removal flag.
    fn __clear__(&mut self) {
        self.node_removed = false;
        self.graph = StableUnGraph::<PyObject, PyObject>::default();
    }
}
/// Map `value` through the optional Python callable `map_fn`.
///
/// With a callable, returns the result of calling it with `value` as the
/// only positional argument (propagating any error it raises); without
/// one, returns a new reference to `value`.
fn weight_transform_callable(
    py: Python,
    map_fn: &Option<PyObject>,
    value: &PyObject,
) -> PyResult<PyObject> {
    if let Some(callable) = map_fn {
        let mapped = callable.call1(py, (value,))?;
        return Ok(mapped.to_object(py));
    }
    Ok(value.clone_ref(py))
}