use crate::revisionid::RevisionId;
use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
use pyo3::types::{PyFrozenSet, PyIterator, PyTuple};
use std::collections::HashMap;
use std::hash::Hash;
/// A value that can be used as a node when talking to a Python graph object.
///
/// Implementors must be hashable, comparable and cloneable on the Rust side,
/// and convertible to and from a Python object.
pub trait GraphNode: Eq + Hash + Clone {
/// Converts this node into a Python object bound to the interpreter token `py`.
fn to_pyobject<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>>;
/// Builds a node from a Python object, failing if the object cannot be converted.
fn from_pyobject(obj: &Bound<PyAny>) -> PyResult<Self>;
}
/// Owned handle to a Python graph object.
///
/// The methods on this type forward to Python methods of the same name on the
/// wrapped object.
pub struct Graph(Py<PyAny>);
impl<'py> IntoPyObject<'py> for Graph {
type Target = PyAny;
type Output = Bound<'py, Self::Target>;
type Error = std::convert::Infallible;
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
Ok(self.0.into_bound(py))
}
}
impl<'a, 'py> FromPyObject<'a, 'py> for Graph {
type Error = PyErr;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
Ok(Graph(ob.to_owned().unbind()))
}
}
impl From<Py<PyAny>> for Graph {
fn from(ob: Py<PyAny>) -> Self {
Graph(ob)
}
}
/// Revision ids cross the Python boundary as their raw bytes.
impl GraphNode for RevisionId {
    /// Converts the revision id's bytes into a Python object.
    fn to_pyobject<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let raw = self.as_bytes();
        let py_bytes = raw.into_pyobject(py)?;
        Ok(py_bytes.into_any())
    }

    /// Extracts a byte vector from `obj` and builds a revision id from it.
    fn from_pyobject(obj: &Bound<PyAny>) -> PyResult<Self> {
        Ok(Self::from(obj.extract::<Vec<u8>>()?))
    }
}
/// Rust iterator over a Python iterator whose items convert to `T`.
///
/// `__next__` is called directly on the wrapped object, so that object must
/// itself provide `__next__`.
struct NodeIter<T: GraphNode>(Py<PyAny>, std::marker::PhantomData<T>);

impl<T: GraphNode> Iterator for NodeIter<T> {
    type Item = Result<T, crate::error::Error>;

    /// Pulls the next item from Python; `StopIteration` ends the iteration,
    /// any other Python exception is yielded as an `Err` item.
    fn next(&mut self) -> Option<Self::Item> {
        Python::attach(|py| {
            let outcome = self.0.call_method0(py, "__next__");
            match outcome {
                Err(err) if err.is_instance_of::<PyStopIteration>(py) => None,
                Err(err) => Some(Err(err.into())),
                Ok(value) => Some(T::from_pyobject(value.bind(py)).map_err(Into::into)),
            }
        })
    }
}
struct TopoOrderIter<T: GraphNode>(Py<PyAny>, std::marker::PhantomData<T>);
impl<T: GraphNode> Iterator for TopoOrderIter<T> {
type Item = Result<(usize, T, usize, bool), crate::error::Error>;
fn next(&mut self) -> Option<Self::Item> {
Python::attach(|py| match self.0.call_method0(py, "__next__") {
Ok(item) => {
let tuple = match item.bind(py).cast::<PyTuple>() {
Ok(t) => t,
Err(e) => return Some(Err(PyErr::from(e).into())),
};
if tuple.len() != 4 {
return Some(Err(pyo3::exceptions::PyValueError::new_err(
"Expected 4-tuple from iter_topo_order",
)
.into()));
}
match (
tuple.get_item(0).and_then(|i| i.extract::<usize>()),
tuple.get_item(1).and_then(|i| T::from_pyobject(&i)),
tuple.get_item(2).and_then(|i| i.extract::<usize>()),
tuple.get_item(3).and_then(|i| i.extract::<bool>()),
) {
(Ok(seq), Ok(node), Ok(depth), Ok(eom)) => Some(Ok((seq, node, depth, eom))),
_ => Some(Err(pyo3::exceptions::PyValueError::new_err(
"Failed to extract values from topo_order tuple",
)
.into())),
}
}
Err(e) if e.is_instance_of::<PyStopIteration>(py) => None,
Err(e) => Some(Err(e.into())),
})
}
}
impl Graph {
/// Borrows the wrapped Python object (crate-internal accessor).
pub(crate) fn as_pyobject(&self) -> &Py<PyAny> {
&self.0
}
/// Calls the Python object's `is_ancestor(node1, node2)` and returns the
/// result as a `bool`.
///
/// # Errors
/// Fails if a node cannot be converted, the Python call raises, or the
/// result is not extractable as `bool`.
pub fn is_ancestor<T: GraphNode>(
    &self,
    node1: &T,
    node2: &T,
) -> Result<bool, crate::error::Error> {
    Python::attach(|py| {
        let first = node1.to_pyobject(py)?;
        let second = node2.to_pyobject(py)?;
        let answer = self.0.call_method1(py, "is_ancestor", (first, second))?;
        let flag: bool = answer.extract(py)?;
        Ok(flag)
    })
}
pub fn iter_lefthand_ancestry<T: GraphNode>(
&self,
node: &T,
stop_nodes: Option<&[T]>,
) -> Result<impl Iterator<Item = Result<T, crate::error::Error>>, crate::error::Error> {
Python::attach(|py| {
let stop_py = if let Some(nodes) = stop_nodes {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
Some(py_nodes?)
} else {
None
};
let iter = self.0.call_method1(
py,
"iter_lefthand_ancestry",
(node.to_pyobject(py)?, stop_py),
)?;
Ok(NodeIter(iter, std::marker::PhantomData))
})
}
pub fn find_lca<T: GraphNode>(&self, nodes: &[T]) -> Result<Vec<T>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self.0.call_method1(py, "find_lca", (py_nodes?,))?;
let py_set = result
.cast_bound::<pyo3::types::PySet>(py)
.map_err(PyErr::from)?;
let mut lca_nodes = Vec::new();
for item in py_set {
lca_nodes.push(T::from_pyobject(&item)?)
}
Ok(lca_nodes)
})
}
pub fn heads<T: GraphNode>(&self, nodes: &[T]) -> Result<Vec<T>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self.0.call_method1(py, "heads", (py_nodes?,))?;
let py_set = result
.cast_bound::<pyo3::types::PySet>(py)
.map_err(PyErr::from)?;
let mut head_nodes = Vec::new();
for item in py_set {
head_nodes.push(T::from_pyobject(&item)?)
}
Ok(head_nodes)
})
}
pub fn find_unique_ancestors<T: GraphNode>(
&self,
nodes: &[T],
common_nodes: &[T],
) -> Result<Vec<T>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let py_common: Result<Vec<_>, _> =
common_nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result =
self.0
.call_method1(py, "find_unique_ancestors", (py_nodes?, py_common?))?;
let py_list = result
.cast_bound::<pyo3::types::PyList>(py)
.map_err(PyErr::from)?;
let mut unique_ancestors = Vec::new();
for item in py_list {
unique_ancestors.push(T::from_pyobject(&item)?)
}
Ok(unique_ancestors)
})
}
pub fn find_difference<T: GraphNode>(
&self,
left_nodes: &[T],
right_nodes: &[T],
) -> Result<(Vec<T>, Vec<T>), crate::error::Error> {
Python::attach(|py| {
let py_left: Result<Vec<_>, _> = left_nodes.iter().map(|n| n.to_pyobject(py)).collect();
let py_right: Result<Vec<_>, _> =
right_nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self
.0
.call_method1(py, "find_difference", (py_left?, py_right?))?;
let tuple = result.cast_bound::<PyTuple>(py).map_err(PyErr::from)?;
let left_only = tuple.get_item(0)?;
let right_only = tuple.get_item(1)?;
let mut left_result = Vec::new();
for item in left_only
.cast::<pyo3::types::PySet>()
.map_err(PyErr::from)?
{
left_result.push(T::from_pyobject(&item)?);
}
let mut right_result = Vec::new();
for item in right_only
.cast::<pyo3::types::PySet>()
.map_err(PyErr::from)?
{
right_result.push(T::from_pyobject(&item)?);
}
Ok((left_result, right_result))
})
}
pub fn iter_ancestry<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<impl Iterator<Item = Result<T, crate::error::Error>>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let iter = self.0.call_method1(py, "iter_ancestry", (py_nodes?,))?;
Ok(NodeIter(iter, std::marker::PhantomData))
})
}
pub fn get_parent_map<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<HashMap<T, Vec<T>>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self.0.call_method1(py, "get_parent_map", (py_nodes?,))?;
let py_dict = result
.cast_bound::<pyo3::types::PyDict>(py)
.map_err(PyErr::from)?;
let mut parent_map = HashMap::new();
for (key, value) in py_dict {
let key_node = T::from_pyobject(&key)?;
let mut parents = Vec::new();
for parent in value.cast::<pyo3::types::PyTuple>().map_err(PyErr::from)? {
parents.push(T::from_pyobject(&parent)?);
}
parent_map.insert(key_node, parents);
}
Ok(parent_map)
})
}
/// Calls the Python object's `is_between(candidate, ancestor, descendant)`
/// and returns the result as a `bool`.
///
/// # Errors
/// Fails if a node cannot be converted, the Python call raises, or the
/// result is not extractable as `bool`.
pub fn is_between<T: GraphNode>(
    &self,
    candidate: &T,
    ancestor: &T,
    descendant: &T,
) -> Result<bool, crate::error::Error> {
    Python::attach(|py| {
        let py_candidate = candidate.to_pyobject(py)?;
        let py_ancestor = ancestor.to_pyobject(py)?;
        let py_descendant = descendant.to_pyobject(py)?;
        let answer = self.0.call_method1(
            py,
            "is_between",
            (py_candidate, py_ancestor, py_descendant),
        )?;
        let flag: bool = answer.extract(py)?;
        Ok(flag)
    })
}
pub fn iter_topo_order<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<
impl Iterator<Item = Result<(usize, T, usize, bool), crate::error::Error>>,
crate::error::Error,
> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let iter = self.0.call_method1(py, "iter_topo_order", (py_nodes?,))?;
Ok(TopoOrderIter(iter, std::marker::PhantomData))
})
}
pub fn find_descendants<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<Vec<T>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self.0.call_method1(py, "find_descendants", (py_nodes?,))?;
let py_set = result
.cast_bound::<pyo3::types::PySet>(py)
.map_err(PyErr::from)?;
let mut descendants = Vec::new();
for item in py_set {
descendants.push(T::from_pyobject(&item)?);
}
Ok(descendants)
})
}
pub fn find_distance_to_null<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<HashMap<T, usize>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self
.0
.call_method1(py, "find_distance_to_null", (py_nodes?,))?;
let py_dict = result
.cast_bound::<pyo3::types::PyDict>(py)
.map_err(PyErr::from)?;
let mut distance_map = HashMap::new();
for (key, value) in py_dict {
let key_node = T::from_pyobject(&key)?;
let distance: usize = value.extract()?;
distance_map.insert(key_node, distance);
}
Ok(distance_map)
})
}
pub fn find_unique_lca<T: GraphNode>(
&self,
nodes: &[T],
count: Option<usize>,
) -> Result<Option<T>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = if let Some(c) = count {
self.0.call_method1(py, "find_unique_lca", (py_nodes?, c))?
} else {
self.0.call_method1(py, "find_unique_lca", (py_nodes?,))?
};
if result.is_none(py) {
Ok(None)
} else {
Ok(Some(T::from_pyobject(result.bind(py))?))
}
})
}
pub fn find_merge_order<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<Vec<T>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self.0.call_method1(py, "find_merge_order", (py_nodes?,))?;
let py_list = result
.cast_bound::<pyo3::types::PyList>(py)
.map_err(PyErr::from)?;
let mut merge_order = Vec::new();
for item in py_list {
merge_order.push(T::from_pyobject(&item)?);
}
Ok(merge_order)
})
}
/// Calls the Python object's `find_lefthand_merger(node, tip)`, passing
/// Python `None` for `tip` when it is absent.
///
/// Returns `None` when Python returns `None`; otherwise the result is
/// converted to `T`.
///
/// # Errors
/// Fails if a node cannot be converted, the Python call raises, or the
/// result cannot be converted back to `T`.
pub fn find_lefthand_merger<T: GraphNode>(
    &self,
    node: &T,
    tip: Option<&T>,
) -> Result<Option<T>, crate::error::Error> {
    Python::attach(|py| {
        let py_node = node.to_pyobject(py)?;
        let py_tip = match tip {
            Some(t) => t.to_pyobject(py)?,
            None => py.None().into_bound(py),
        };
        let returned = self
            .0
            .call_method1(py, "find_lefthand_merger", (py_node, py_tip))?;
        if returned.is_none(py) {
            return Ok(None);
        }
        let merger = T::from_pyobject(returned.bind(py))?;
        Ok(Some(merger))
    })
}
pub fn find_lefthand_distances<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<HashMap<T, usize>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self
.0
.call_method1(py, "find_lefthand_distances", (py_nodes?,))?;
let py_dict = result
.cast_bound::<pyo3::types::PyDict>(py)
.map_err(PyErr::from)?;
let mut distance_map = HashMap::new();
for (key, value) in py_dict {
let key_node = T::from_pyobject(&key)?;
let distance: usize = value.extract()?;
distance_map.insert(key_node, distance);
}
Ok(distance_map)
})
}
pub fn get_child_map<T: GraphNode>(
&self,
nodes: &[T],
) -> Result<HashMap<T, Vec<T>>, crate::error::Error> {
Python::attach(|py| {
let py_nodes: Result<Vec<_>, _> = nodes.iter().map(|n| n.to_pyobject(py)).collect();
let result = self.0.call_method1(py, "get_child_map", (py_nodes?,))?;
let py_dict = result
.cast_bound::<pyo3::types::PyDict>(py)
.map_err(PyErr::from)?;
let mut child_map = HashMap::new();
for (key, value) in py_dict {
let key_node = T::from_pyobject(&key)?;
let mut children = Vec::new();
for child in value.cast::<pyo3::types::PyList>().map_err(PyErr::from)? {
children.push(T::from_pyobject(&child)?);
}
child_map.insert(key_node, children);
}
Ok(child_map)
})
}
}
/// A graph node made of an ordered sequence of strings.
///
/// It crosses the Python boundary as a tuple of strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Key(Vec<String>);
impl From<Vec<String>> for Key {
fn from(v: Vec<String>) -> Self {
Key(v)
}
}
impl From<Key> for Vec<String> {
fn from(k: Key) -> Self {
k.0
}
}
impl<'py> IntoPyObject<'py> for Key {
type Target = PyTuple;
type Output = Bound<'py, Self::Target>;
type Error = PyErr;
fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
PyTuple::new(py, self.0)
}
}
/// Extracts a `Key` from a Python tuple whose elements are all strings.
impl<'a, 'py> FromPyObject<'a, 'py> for Key {
    type Error = PyErr;

    /// Fails if `ob` is not a tuple or any element is not extractable as `String`.
    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
        let tuple = ob.cast::<PyTuple>()?;
        let mut parts = Vec::with_capacity(tuple.len());
        for element in tuple.iter() {
            let text: String = element.extract()?;
            parts.push(text);
        }
        Ok(Self(parts))
    }
}
/// Keys cross the Python boundary as tuples of strings.
impl GraphNode for Key {
    /// Builds a Python tuple from the key's strings without consuming the key.
    fn to_pyobject<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
        let tuple = PyTuple::new(py, self.0.iter())?;
        Ok(tuple.into_any())
    }

    /// Delegates to the `FromPyObject` implementation for `Key`.
    fn from_pyobject(obj: &Bound<PyAny>) -> PyResult<Self> {
        obj.extract()
    }
}
/// Owned handle to a Python object exposing a `heads` method.
pub struct KnownGraph(Py<PyAny>);
impl KnownGraph {
    /// Wraps an existing Python object; no type check is performed.
    pub fn new(py_obj: Py<PyAny>) -> Self {
        Self(py_obj)
    }

    /// Calls the Python object's `heads` with a `frozenset` of `nodes` and
    /// collects the result into a vector of nodes.
    ///
    /// The Python result may be any iterable (for example a `set`,
    /// `frozenset` or `list`), not only an iterator: it is iterated with
    /// Python's `iter()` protocol. Order follows the iterable's own
    /// iteration order.
    ///
    /// # Errors
    /// Fails if a node cannot be converted, the Python call raises, the
    /// result is not iterable, iteration raises, or an element cannot be
    /// converted back to `T`.
    pub fn heads<T: GraphNode>(&self, nodes: Vec<T>) -> Result<Vec<T>, crate::error::Error> {
        Python::attach(|py| {
            let nodes_py: Vec<_> = nodes
                .into_iter()
                .map(|n| n.to_pyobject(py))
                .collect::<Result<Vec<_>, _>>()?;
            let nodes_frozenset = PyFrozenSet::new(py, &nodes_py)?;
            let result = self.0.call_method1(py, "heads", (nodes_frozenset,))?;
            let mut heads = Vec::new();
            // `try_iter` obtains an iterator via `iter(result)`. Casting to
            // `PyIterator` instead would reject containers such as sets,
            // which are iterable but are not iterators themselves.
            for head_py in result.bind(py).try_iter()? {
                heads.push(T::from_pyobject(&head_py?)?);
            }
            Ok(heads)
        })
    }
}
impl Clone for KnownGraph {
fn clone(&self) -> Self {
Python::attach(|py| KnownGraph(self.0.clone_ref(py)))
}
}
#[cfg(test)]
mod tests;