use std::collections::BTreeSet;
use hugr_core::Node;
use hugr_core::hugr::patch::simple_replace::BoundaryMode;
use hugr_core::ops::handle::DataflowParentID;
use itertools::{Either, Itertools};
use thiserror::Error;
use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootChecked};
use crate::{Commit, PersistentReplacement, PinnedSubgraph};
use crate::PersistentWire;
use super::{CommitStateSpace, InvalidCommit, PatchNode, PersistentHugr, state_space::CommitId};
/// A walker over a [`CommitStateSpace`] that incrementally selects commits
/// and pins nodes.
///
/// The walker keeps a set of selected commits (viewed as a single
/// [`PersistentHugr`]) together with a set of pinned nodes. Selecting a new
/// commit through the walker's methods is rolled back if it would remove any
/// pinned node.
#[derive(Debug, Clone)]
pub struct Walker<'a> {
    /// The state space that commits are selected from.
    state_space: &'a CommitStateSpace,
    /// The currently selected commits, viewed as one persistent HUGR.
    selected_commits: PersistentHugr,
    /// Nodes that must stay present in `selected_commits`.
    pinned_nodes: BTreeSet<PatchNode>,
}
impl<'a> Walker<'a> {
pub fn new(state_space: &'a CommitStateSpace) -> Self {
let base = state_space.base_commit().expect("non-empty state space");
let selected_commits: PersistentHugr = PersistentHugr::from_commit(base);
Self {
state_space,
selected_commits,
pinned_nodes: BTreeSet::new(),
}
}
pub fn from_pinned_node(node: PatchNode, state_space: &'a CommitStateSpace) -> Self {
let mut walker = Self::new(state_space);
walker
.try_pin_node(node)
.expect("node is valid and not deleted");
walker
}
pub fn try_pin_node(&mut self, node: PatchNode) -> Result<bool, PinNodeError> {
let commit_id = node.0;
if self.selected_commits.contains_id(commit_id) {
if !self.selected_commits.contains_node(node) {
return Err(PinNodeError::AlreadyDeleted(node));
}
} else {
let commit = self
.state_space
.try_upgrade(commit_id)
.ok_or(PinNodeError::UnknownCommitId(commit_id))?;
self.try_select_commit(commit)?;
}
Ok(self.pinned_nodes.insert(node))
}
pub fn try_select_commit(&mut self, commit: Commit) -> Result<CommitId, PinNodeError> {
let backup = self.selected_commits.clone();
let commit_id = self.selected_commits.try_add_commit(commit)?;
if let Some(&pinned_node) = self
.pinned_nodes
.iter()
.find(|&&n| !self.selected_commits.contains_node(n))
{
self.selected_commits = backup;
return Err(PinNodeError::AlreadyPinned(pinned_node));
}
Ok(commit_id)
}
pub fn expand<'b>(
&'b self,
wire: &'b PersistentWire,
dir: impl Into<Option<Direction>>,
) -> impl Iterator<Item = Walker<'a>> + 'b {
let dir = dir.into();
if self.is_complete(wire, dir) {
return Either::Left(std::iter::once(self.clone()));
}
let unpinned_ports = self.wire_unpinned_ports(wire, dir);
let pinnable_nodes = unpinned_ports
.flat_map(|(node, port)| self.equivalent_descendant_ports(node, port))
.map(|(n, _, commits)| (n, commits))
.unique();
let new_walkers = pinnable_nodes.filter_map(|(pinnable_node, new_commits)| {
let contains_new_commit = || {
new_commits
.iter()
.any(|&cm| !self.selected_commits.contains_id(cm))
};
debug_assert!(
!self.is_pinned(pinnable_node) || contains_new_commit(),
"trying to pin already pinned node and no new commit is selected"
);
let new_commits = new_commits
.iter()
.map(|&id| self.state_space.try_upgrade(id))
.collect::<Option<Vec<_>>>()?;
let new_selected_commits = {
let mut phugr = self.selected_commits.clone();
phugr.try_add_commits(new_commits).ok()?;
phugr
};
if self
.pinned_nodes
.iter()
.any(|&pnode| !new_selected_commits.contains_node(pnode))
{
return None;
}
let mut new_walker = Walker {
state_space: self.state_space,
selected_commits: new_selected_commits,
pinned_nodes: self.pinned_nodes.clone(),
};
new_walker.try_pin_node(pinnable_node).ok()?;
Some(new_walker)
});
Either::Right(new_walkers)
}
pub fn try_create_commit(
&self,
subgraph: impl Into<PinnedSubgraph>,
mut repl: RootChecked<Hugr, DataflowParentID>,
map_boundary: impl Fn(PatchNode, Port) -> Port,
) -> Result<Commit<'a>, InvalidCommit> {
let pinned_subgraph = subgraph.into();
let subgraph = pinned_subgraph.to_sibling_subgraph(self.as_hugr_view())?;
let selected_commits = pinned_subgraph
.selected_commits()
.map(|id| self.selected_commits.get_commit(id).clone());
let repl = {
let new_inputs = subgraph
.incoming_ports()
.iter()
.flatten() .map(|&(n, p)| {
map_boundary(n, p.into())
.as_outgoing()
.expect("unexpected port direction returned by map_boundary")
.index()
})
.collect_vec();
let new_outputs = subgraph
.outgoing_ports()
.iter()
.map(|&(n, p)| {
map_boundary(n, p.into())
.as_incoming()
.expect("unexpected port direction returned by map_boundary")
.index()
})
.collect_vec();
repl.map_function_type(&new_inputs, &new_outputs)?;
PersistentReplacement::try_new(subgraph, self.as_hugr_view(), repl.into_hugr())?
};
Commit::try_new(repl, selected_commits, self.state_space)
}
pub fn get_wire(&self, node: PatchNode, port: impl Into<Port>) -> PersistentWire {
assert!(self.is_pinned(node), "node must be pinned");
self.selected_commits.get_wire(node, port)
}
pub fn into_persistent_hugr(self) -> PersistentHugr {
self.selected_commits
}
pub fn as_hugr_view(&self) -> &PersistentHugr {
&self.selected_commits
}
pub fn is_pinned(&self, node: PatchNode) -> bool {
self.pinned_nodes.contains(&node)
}
pub fn pinned_nodes(&self) -> impl Iterator<Item = PatchNode> + '_ {
self.pinned_nodes.iter().copied()
}
fn equivalent_descendant_ports(
&self,
node: PatchNode,
port: Port,
) -> Vec<(PatchNode, Port, BTreeSet<CommitId>)> {
let mut all_ports = vec![(node, port, BTreeSet::new())];
let mut index = 0;
while index < all_ports.len() {
let (node, port, empty_commits) = all_ports[index].clone();
let Some(commit) = self.state_space.try_upgrade(node.owner()) else {
continue;
};
index += 1;
for (child, (opp_node, opp_port)) in
commit.children_at_boundary_port(node.1, port, self.state_space)
{
for (node, port) in
commit.linked_child_ports(opp_node, opp_port, &child, BoundaryMode::SnapToHost)
{
let mut empty_commits = empty_commits.clone();
if node.owner() != child.id() {
empty_commits.insert(child.id());
}
all_ports.push((node, port, empty_commits));
}
}
}
all_ports
}
}
#[cfg(test)]
impl Walker<'_> {
    /// Whether `self` and `other` share the same state space and the same
    /// pinned nodes, and select the same set of commit IDs.
    ///
    /// The state spaces are compared by pointer identity, as the name says.
    /// This is O(1) and does not rely on `CommitStateSpace` implementing
    /// `PartialEq`.
    fn component_wise_ptr_eq(&self, other: &Self) -> bool {
        std::ptr::eq(self.state_space, other.state_space)
            && self.pinned_nodes == other.pinned_nodes
            && BTreeSet::from_iter(self.selected_commits.all_commit_ids())
                == BTreeSet::from_iter(other.selected_commits.all_commit_ids())
    }

    /// Whether expanding `wire` in direction `dir` yields exactly one walker
    /// that is component-wise equal to `self`.
    fn no_more_expansion(&self, wire: &PersistentWire, dir: impl Into<Option<Direction>>) -> bool {
        // `collect_array` returns `None` unless there is exactly one walker.
        let Some([new_walker]) = self.expand(wire, dir).collect_array() else {
            return false;
        };
        new_walker.component_wise_ptr_eq(self)
    }
}
impl<'a> Commit<'a> {
    /// Find the children of this commit at whose boundary `(node, port)` lies.
    ///
    /// For every child commit that deletes `node`, this yields one
    /// `(child, (opposite_node, opposite_port))` pair for each port linked to
    /// `(node, port)` whose node the child does *not* delete. Children that
    /// keep `node` contribute nothing.
    fn children_at_boundary_port(
        &self,
        node: Node,
        port: Port,
        state_space: &'a CommitStateSpace,
    ) -> impl Iterator<Item = (Commit<'a>, (Node, Port))> + '_ {
        let linked_ports = self.commit_hugr().linked_ports(node, port).collect_vec();
        self.children(state_space).flat_map(move |child| {
            let removed: BTreeSet<_> = child.deleted_parent_nodes().collect();
            let mut boundary = Vec::new();
            if removed.contains(&self.to_patch_node(node)) {
                for &(opp_node, opp_port) in &linked_ports {
                    if !removed.contains(&self.to_patch_node(opp_node)) {
                        boundary.push((child.clone(), (opp_node, opp_port)));
                    }
                }
            }
            boundary
        })
    }
}
/// Errors raised by a [`Walker`] when pinning a node or selecting a commit.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum PinNodeError {
    /// Wraps an [`InvalidCommit`] error, e.g. when a commit could not be added
    /// to the selected commits.
    #[error("cannot add commit to pin node: {0}")]
    InvalidNewCommit(InvalidCommit),
    /// The node's owning commit is selected, but the node itself is no longer
    /// part of the selected commits.
    #[error("cannot pin deleted node: {0}")]
    AlreadyDeleted(PatchNode),
    /// Selecting a commit would remove this already pinned node.
    #[error("cannot delete already pinned node: {0}")]
    AlreadyPinned(PatchNode),
    /// The commit ID could not be resolved in the state space.
    #[error("unknown commit ID: {0:?}")]
    UnknownCommitId(CommitId),
}
impl From<InvalidCommit> for PinNodeError {
fn from(value: InvalidCommit) -> Self {
PinNodeError::InvalidNewCommit(value)
}
}
impl<'w> hugr_core::hugr::views::NodesIter for Walker<'w> {
type Node = PatchNode;
fn nodes(&self) -> impl Iterator<Item = Self::Node> + '_ {
<PersistentHugr as HugrView>::nodes(self.as_hugr_view())
}
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use hugr_core::{
        Direction, HugrView, IncomingPort, OutgoingPort,
        builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig},
        extension::prelude::bool_t,
        std_extensions::logic::LogicOp,
    };
    use itertools::Itertools;
    use rstest::rstest;

    use super::*;
    use crate::{
        PersistentHugr, Walker,
        state_space::CommitId,
        tests::{TestStateSpace, persistent_hugr_empty_child, test_state_space},
    };

    /// Expanding the first input wire of the base AND node yields two walkers,
    /// each pinning a NOT node from either the base commit or `commit1`.
    #[rstest]
    fn test_walker_base_or_child_expansion(test_state_space: TestStateSpace) {
        let [commit1, _commit2, _commit3, _commit4] = test_state_space.commits();
        let state_space = commit1.state_space();
        let base_commit = commit1.base_commit();
        let base_and_node = {
            let base_hugr = base_commit.commit_hugr();
            let and_node = base_hugr
                .nodes()
                .find(|&n| base_hugr.get_optype(n) == &LogicOp::And.into())
                .unwrap();
            base_commit.to_patch_node(and_node)
        };
        let walker = Walker::from_pinned_node(base_and_node, &state_space);
        assert!(walker.is_pinned(base_and_node));
        let in0 = walker.get_wire(base_and_node, IncomingPort::from(0));
        assert!(walker.no_more_expansion(&in0, Direction::Incoming));
        let out_walkers = walker.expand(&in0, Direction::Outgoing).collect_vec();
        assert_eq!(out_walkers.len(), 2);
        for new_walker in out_walkers {
            let in0 = new_walker.get_wire(base_and_node, IncomingPort::from(0));
            assert!(new_walker.is_complete(&in0, None));
            assert!(new_walker.no_more_expansion(&in0, None));
            let (not_node, _) = in0.single_outgoing_port(new_walker.as_hugr_view()).unwrap();
            assert!(new_walker.is_pinned(base_and_node));
            assert!(new_walker.is_pinned(not_node));
            assert!([commit1.id(), base_commit.id()].contains(&not_node.0));
            assert_eq!(
                new_walker.as_hugr_view().get_optype(not_node),
                &LogicOp::Not.into()
            );
            let persistent_hugr = new_walker.into_persistent_hugr();
            let hugr = persistent_hugr.get_commit(not_node.owner()).commit_hugr();
            assert_eq!(hugr.get_optype(not_node.1), &LogicOp::Not.into());
        }
    }

    /// Expanding the output wire of the second NOT node of `commit4` yields
    /// one walker for each of the three expected commit selections.
    #[rstest]
    fn test_walker_disjoint_nephew_expansion(test_state_space: TestStateSpace) {
        let [commit1, commit2, commit3, commit4] = test_state_space.commits();
        let base_commit = commit1.base_commit();
        let state_space = commit4.state_space();
        let not4_node = {
            let repl4 = commit4.replacement().unwrap();
            let hugr4 = commit4.commit_hugr();
            let [_, output] = repl4.get_replacement_io();
            let (second_not_node, _) = hugr4.single_linked_output(output, 0).unwrap();
            commit4.to_patch_node(second_not_node)
        };
        let walker = Walker::from_pinned_node(not4_node, &state_space);
        assert!(walker.is_pinned(not4_node));
        let not4_out = walker.get_wire(not4_node, OutgoingPort::from(0));
        assert!(walker.no_more_expansion(&not4_out, Direction::Outgoing));
        let mut exp_options = BTreeSet::from_iter([
            BTreeSet::from_iter([base_commit.id(), commit4.id()]),
            BTreeSet::from_iter([base_commit.id(), commit3.id(), commit4.id()]),
            BTreeSet::from_iter([base_commit.id(), commit1.id(), commit2.id(), commit4.id()]),
        ]);
        for new_walker in walker.expand(&not4_out, None) {
            let commit_ids = new_walker
                .as_hugr_view()
                .all_commit_ids()
                .collect::<BTreeSet<_>>();
            assert!(
                exp_options.remove(&commit_ids),
                "{commit_ids:?} not an expected set of commit IDs (or duplicate)"
            );
            let not4_out = new_walker.get_wire(not4_node, OutgoingPort::from(0));
            assert!(new_walker.is_complete(&not4_out, None));
            assert!(new_walker.no_more_expansion(&not4_out, None));
            let (next_node, _) = not4_out
                .all_incoming_ports(new_walker.as_hugr_view())
                .exactly_one()
                .ok()
                .unwrap();
            assert!(new_walker.is_pinned(not4_node));
            assert!(new_walker.is_pinned(next_node));
            let persistent_hugr = new_walker.into_persistent_hugr();
            let expected_optype = match next_node.0 {
                commit_id if commit_id == base_commit.id() => LogicOp::And,
                commit_id if [commit2.id(), commit3.id()].contains(&commit_id) => LogicOp::Xor,
                _ => panic!("neighbour of not4 must be in base, commit2 or commit3"),
            };
            assert_eq!(
                persistent_hugr.get_optype(next_node),
                &expected_optype.into()
            );
        }
        assert!(
            exp_options.is_empty(),
            "missing expected options: {exp_options:?}"
        );
    }

    /// Wire endpoints resolve to nodes of the expected commits.
    #[rstest]
    fn test_get_wire_endpoints(test_state_space: TestStateSpace) {
        let [commit1, commit2, _commit3, commit4] = test_state_space.commits();
        let base_commit = commit1.base_commit();
        let base_and_node = {
            let base_hugr = base_commit.commit_hugr();
            let and_node = base_hugr
                .nodes()
                .find(|&n| base_hugr.get_optype(n) == &LogicOp::And.into())
                .unwrap();
            base_commit.to_patch_node(and_node)
        };
        let hugr = PersistentHugr::try_new([commit4.clone()]).unwrap();
        let (second_not_node, out_port) =
            hugr.single_outgoing_port(base_and_node, IncomingPort::from(1));
        assert_eq!(second_not_node.0, commit4.id());
        assert_eq!(out_port, OutgoingPort::from(0));
        let hugr =
            PersistentHugr::try_new([commit1.clone(), commit2.clone(), commit4.clone()]).unwrap();
        let (new_and_node, in_port) = hugr
            .all_incoming_ports(second_not_node, out_port)
            .exactly_one()
            .ok()
            .unwrap();
        assert_eq!(new_and_node.0, commit2.id());
        assert_eq!(in_port, 1.into());
    }

    /// Expansion can traverse an empty replacement commit.
    #[rstest]
    fn test_walk_over_empty_repls(
        persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]),
    ) {
        let (hugr, [base_commit, empty_commit], [not0, not1, not2]) = persistent_hugr_empty_child;
        let state_space = hugr.state_space();
        let walker = Walker::from_pinned_node(not0, state_space);
        let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0));
        let expanded_wires = walker
            .expand(&not0_outwire, Direction::Incoming)
            .collect_vec();
        assert_eq!(expanded_wires.len(), 2);
        let connected_inports: BTreeSet<_> = expanded_wires
            .iter()
            .map(|new_walker| {
                let wire = new_walker.get_wire(not0, OutgoingPort::from(0));
                wire.all_incoming_ports(new_walker.as_hugr_view())
                    .exactly_one()
                    .ok()
                    .unwrap()
            })
            .collect();
        assert_eq!(
            connected_inports,
            BTreeSet::from_iter([(not1, IncomingPort::from(0)), (not2, IncomingPort::from(0))])
        );
        let traversed_commits: BTreeSet<BTreeSet<_>> = expanded_wires
            .iter()
            .map(|new_walker| {
                let wire = new_walker.get_wire(not0, OutgoingPort::from(0));
                wire.owners().collect()
            })
            .collect();
        assert_eq!(
            traversed_commits,
            BTreeSet::from_iter([
                BTreeSet::from_iter([base_commit]),
                BTreeSet::from_iter([base_commit, empty_commit])
            ])
        );
    }

    /// A commit created over a wire crossing an empty replacement has both
    /// the base and the empty commit as parents.
    #[rstest]
    fn test_create_commit_over_empty(
        persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]),
    ) {
        let (mut hugr, [base_commit, empty_commit], [not0, _not1, not2]) =
            persistent_hugr_empty_child;
        let state_space = hugr.state_space().clone();
        let mut walker = Walker {
            state_space: &state_space,
            selected_commits: hugr.clone(),
            pinned_nodes: BTreeSet::from_iter([not0]),
        };
        let wire = walker.get_wire(not0, OutgoingPort::from(0));
        walker = walker.expand(&wire, None).exactly_one().ok().unwrap();
        let wire = walker.get_wire(not0, OutgoingPort::from(0));
        assert!(walker.is_complete(&wire, None));
        let empty_hugr = {
            let dfg_builder = DFGBuilder::new(endo_sig([bool_t()])).unwrap();
            let inputs = dfg_builder.input_wires();
            dfg_builder.finish_hugr_with_outputs(inputs).unwrap()
        };
        let commit = walker
            .try_create_commit(
                PinnedSubgraph::try_from_pinned(std::iter::empty(), [wire], &walker).unwrap(),
                RootChecked::try_new(empty_hugr).expect("Root should be DFG."),
                |node, port| {
                    assert_eq!(port.index(), 0);
                    assert!([not0, not2].contains(&node));
                    match port.direction() {
                        Direction::Incoming => OutgoingPort::from(0).into(),
                        Direction::Outgoing => IncomingPort::from(0).into(),
                    }
                },
            )
            .unwrap();
        let commit_id = hugr.try_add_commit(commit.clone()).unwrap();
        assert_eq!(
            hugr.parent_commits(commit_id).collect::<BTreeSet<_>>(),
            BTreeSet::from_iter([base_commit, empty_commit])
        );
        let res_hugr: PersistentHugr = PersistentHugr::from_commit(commit);
        assert!(res_hugr.validate().is_ok());
        assert_eq!(res_hugr.num_nodes(), 1 + 1 + 2 + 1 + 2);
    }

    /// Expansion still yields both options when a second node is pinned.
    #[rstest]
    fn test_walk_over_two_pinned_nodes(
        persistent_hugr_empty_child: (PersistentHugr, [CommitId; 2], [PatchNode; 3]),
    ) {
        let (hugr, [base_commit, empty_commit], [not0, _not1, not2]) = persistent_hugr_empty_child;
        let mut walker = Walker::from_pinned_node(not0, hugr.state_space());
        assert!(walker.try_pin_node(not2).unwrap());
        let not0_outwire = walker.get_wire(not0, OutgoingPort::from(0));
        let expanded_walkers = walker.expand(&not0_outwire, Direction::Incoming);
        let expanded_wires: BTreeSet<BTreeSet<_>> = expanded_walkers
            .map(|new_walker| {
                new_walker
                    .get_wire(not0, OutgoingPort::from(0))
                    .owners()
                    .collect()
            })
            .collect();
        assert_eq!(
            expanded_wires,
            BTreeSet::from_iter([
                BTreeSet::from_iter([base_commit]),
                BTreeSet::from_iter([base_commit, empty_commit])
            ])
        );
    }
}