use std::collections::{BTreeSet, VecDeque};
use hugr_core::{
Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire,
hugr::patch::simple_replace::BoundaryMode,
};
use itertools::Itertools;
use crate::{CommitId, PatchNode, PersistentHugr, Walker, persistent_hugr::NodeStatus};
/// A wire of a [`PersistentHugr`], stored as the set of per-commit wires it
/// is made of.
///
/// A single persistent wire may span several commits; each element of the set
/// is the portion of the wire that lives inside one commit's HUGR.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PersistentWire {
/// The per-commit wires making up this persistent wire.
wires: BTreeSet<CommitWire>,
}
/// A wire inside the HUGR of a single commit.
///
/// Thin wrapper around a `Wire<PatchNode>`: the first component of the wire's
/// [`PatchNode`] is the id of the commit the wire belongs to, the second is
/// the node within that commit's HUGR.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct CommitWire(Wire<PatchNode>);
impl CommitWire {
    /// Finds the wire attached to `port` of the given node, looking only at
    /// the HUGR of the commit that owns the node.
    ///
    /// The wire is resolved in the commit-local HUGR and then re-tagged with
    /// the commit id so that it can be stored alongside wires of other
    /// commits.
    fn from_connected_port(
        PatchNode(commit_id, node): PatchNode,
        port: impl Into<Port>,
        hugr: &PersistentHugr,
    ) -> Self {
        let local_wire =
            Wire::from_connected_port(node, port, hugr.get_commit(commit_id).commit_hugr());
        let tagged_source = PatchNode(commit_id, local_wire.node());
        Self(Wire::new(tagged_source, local_wire.source()))
    }

    /// Iterates over every `(node, port)` pair attached to this wire in the
    /// HUGR of its commit.
    ///
    /// Commit-local nodes are converted back to [`PatchNode`]s through
    /// [`PersistentHugr::to_persistent_node`]. No filtering on node status is
    /// performed here.
    fn all_connected_ports<'h>(
        &self,
        hugr: &'h PersistentHugr,
    ) -> impl Iterator<Item = (PatchNode, Port)> + use<'h> {
        let PatchNode(owner, local_node) = self.0.node();
        let commit_hugr = hugr.get_commit(owner).commit_hugr();
        // Strip the commit id to query the commit-local HUGR.
        let local_wire = Wire::new(local_node, self.0.source());
        local_wire
            .all_connected_ports(commit_hugr)
            .map(move |(n, p)| (hugr.to_persistent_node(n, owner), p))
    }

    /// The id of the commit this wire lives in.
    fn commit_id(&self) -> CommitId {
        let PatchNode(owner, _) = self.0.node();
        owner
    }

    delegate::delegate! {
        to self.0 {
            /// The source node of the wire, tagged with its commit id.
            fn node(&self) -> PatchNode;
        }
    }
}
impl PersistentHugr {
    /// Returns the persistent wire attached to `port` of `node`.
    ///
    /// In debug builds this panics if `node` is not contained in `self`.
    pub fn get_wire(&self, node: PatchNode, port: impl Into<Port>) -> PersistentWire {
        PersistentWire::from_port(node, port, self)
    }

    /// Returns the unique outgoing port linked to the incoming `port` of
    /// `node`.
    ///
    /// # Panics
    ///
    /// Panics if the wire does not resolve to exactly one outgoing port.
    pub(crate) fn single_outgoing_port(
        &self,
        node: PatchNode,
        port: impl Into<IncomingPort>,
    ) -> (PatchNode, OutgoingPort) {
        let incoming: IncomingPort = port.into();
        self.get_wire(node, incoming)
            .single_outgoing_port(self)
            .expect("found invalid dfg wire")
    }

    /// Iterates over all incoming ports linked to `out_port` of `out_node`.
    pub(crate) fn all_incoming_ports(
        &self,
        out_node: PatchNode,
        out_port: OutgoingPort,
    ) -> impl Iterator<Item = (PatchNode, IncomingPort)> {
        self.get_wire(out_node, out_port)
            .into_all_ports(self, Direction::Incoming)
            .map(|(target, port)| {
                // Only incoming ports were requested, so the conversion is
                // expected to succeed.
                let incoming = port.as_incoming().unwrap();
                (target, incoming)
            })
    }
}
impl PersistentWire {
fn from_port(node: PatchNode, port: impl Into<Port>, per_hugr: &PersistentHugr) -> Self {
debug_assert!(per_hugr.contains_node(node), "node not in hugr");
let mut commit_wires =
BTreeSet::from_iter([CommitWire::from_connected_port(node, port, per_hugr)]);
let mut queue = VecDeque::from_iter(commit_wires.iter().copied());
while let Some(wire) = queue.pop_front() {
let commit_id = wire.commit_id();
let commit = per_hugr.get_commit(commit_id);
let commit_hugr = commit.commit_hugr();
let all_ports = wire.all_connected_ports(per_hugr);
for (per_node @ PatchNode(_, node), port) in all_ports {
match per_hugr.node_status(per_node) {
NodeStatus::Deleted(deleted_by) => {
for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) {
for (child_node, child_port) in commit.linked_child_ports(
opp_node,
opp_port,
per_hugr.get_commit(deleted_by),
BoundaryMode::IncludeIO,
) {
debug_assert_eq!(child_node.owner(), deleted_by);
let w = CommitWire::from_connected_port(
child_node, child_port, per_hugr,
);
if commit_wires.insert(w) {
queue.push_back(w);
}
}
}
}
NodeStatus::ReplacementIO => {
for (opp_node, opp_port) in commit_hugr.linked_ports(node, port) {
for (parent_node, parent_port) in
commit.linked_parent_ports(opp_node, opp_port)
{
let w = CommitWire::from_connected_port(
parent_node,
parent_port,
per_hugr,
);
if commit_wires.insert(w) {
queue.push_back(w);
}
}
}
}
NodeStatus::Valid => {}
}
}
}
Self {
wires: commit_wires,
}
}
pub fn all_ports(
&self,
hugr: &PersistentHugr,
dir: impl Into<Option<Direction>>,
) -> impl Iterator<Item = (PatchNode, Port)> {
all_ports_impl(self.wires.iter().copied(), dir.into(), hugr)
}
pub fn owners(&self) -> impl Iterator<Item = CommitId> {
self.wires.iter().map(|w| w.node().owner()).unique()
}
pub fn into_all_ports(
self,
hugr: &PersistentHugr,
dir: impl Into<Option<Direction>>,
) -> impl Iterator<Item = (PatchNode, Port)> {
all_ports_impl(self.wires.into_iter(), dir.into(), hugr)
}
pub fn single_outgoing_port(&self, hugr: &PersistentHugr) -> Option<(PatchNode, OutgoingPort)> {
single_outgoing(self.all_ports(hugr, Direction::Outgoing))
}
pub fn all_incoming_ports(
&self,
hugr: &PersistentHugr,
) -> impl Iterator<Item = (PatchNode, IncomingPort)> {
self.all_ports(hugr, Direction::Incoming)
.map(|(node, port)| (node, port.as_incoming().unwrap()))
}
}
impl Walker<'_> {
    /// Iterates over the ports of `wire` whose nodes are not pinned in this
    /// walker, optionally restricted to direction `dir`.
    pub(crate) fn wire_unpinned_ports(
        &self,
        wire: &PersistentWire,
        dir: impl Into<Option<Direction>>,
    ) -> impl Iterator<Item = (PatchNode, Port)> {
        wire.all_ports(self.as_hugr_view(), dir)
            .filter(|&(node, _)| !self.is_pinned(node))
    }

    /// Iterates over the ports of `wire` whose nodes are pinned in this
    /// walker, optionally restricted to direction `dir`.
    pub fn wire_pinned_ports(
        &self,
        wire: &PersistentWire,
        dir: impl Into<Option<Direction>>,
    ) -> impl Iterator<Item = (PatchNode, Port)> {
        wire.all_ports(self.as_hugr_view(), dir)
            .filter(|&(node, _)| self.is_pinned(node))
    }

    /// Returns the pinned outgoing port of `wire`, or `None` unless exactly
    /// one outgoing port of the wire is pinned.
    pub fn wire_pinned_outport(&self, wire: &PersistentWire) -> Option<(PatchNode, OutgoingPort)> {
        let pinned_outgoing = self.wire_pinned_ports(wire, Direction::Outgoing);
        single_outgoing(pinned_outgoing)
    }

    /// Iterates over the pinned incoming ports of `wire`.
    pub fn wire_pinned_inports(
        &self,
        wire: &PersistentWire,
    ) -> impl Iterator<Item = (PatchNode, IncomingPort)> {
        let pinned_incoming = self.wire_pinned_ports(wire, Direction::Incoming);
        pinned_incoming.map(|(node, port)| {
            let in_port = port.as_incoming().expect("incoming port");
            (node, in_port)
        })
    }

    /// Whether every port of `wire` (in direction `dir`, if given) belongs to
    /// a pinned node, i.e. the wire has no unpinned ports left.
    pub fn is_complete(&self, wire: &PersistentWire, dir: impl Into<Option<Direction>>) -> bool {
        let mut unpinned = self.wire_unpinned_ports(wire, dir);
        unpinned.next().is_none()
    }
}
/// Shared implementation of the port iterators of [`PersistentWire`].
///
/// Expands every commit-local wire into its attached ports, then keeps only
/// the ports that match `dir` (when given) and whose node has status
/// [`NodeStatus::Valid`] in `per_hugr`.
fn all_ports_impl(
    wires: impl Iterator<Item = CommitWire>,
    dir: Option<Direction>,
    per_hugr: &PersistentHugr,
) -> impl Iterator<Item = (PatchNode, Port)> {
    wires
        .flat_map(move |wire| wire.all_connected_ports(per_hugr))
        .filter(move |(node, port)| {
            // Check the direction first so that the node status is only
            // queried for ports that could be returned.
            let direction_matches = match dir {
                Some(wanted) => port.direction() == wanted,
                None => true,
            };
            direction_matches && per_hugr.node_status(*node) == NodeStatus::Valid
        })
}
/// Extracts the only element of `iter` as an outgoing port.
///
/// Returns `None` if `iter` does not yield exactly one element, or if that
/// element's port is not an outgoing port.
fn single_outgoing<N>(iter: impl Iterator<Item = (N, Port)>) -> Option<(N, OutgoingPort)> {
    match iter.exactly_one() {
        Ok((node, port)) => port.as_outgoing().ok().map(|out_port| (node, out_port)),
        Err(_) => None,
    }
}
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use crate::{
        PatchNode, PersistentHugr,
        tests::{TestStateSpace, test_state_space},
    };
    use hugr_core::{HugrView, OutgoingPort};
    use itertools::Itertools;
    use rstest::rstest;

    /// A wire starting in the fourth commit must be made of commit-local
    /// wires owned by the third commit, the fourth commit and the base.
    #[rstest]
    fn test_all_ports(test_state_space: TestStateSpace) {
        let [_, _, cm3, cm4] = test_state_space.commits();
        let hugr = PersistentHugr::try_new([cm3.clone(), cm4.clone()]).unwrap();

        // Second replacement IO node of cm4 (presumably the output node) and
        // its single input neighbour within cm4's HUGR.
        let replacement_out = cm4.replacement().unwrap().get_replacement_io()[1];
        let neighbour = cm4
            .commit_hugr()
            .input_neighbours(replacement_out)
            .exactly_one()
            .ok()
            .unwrap();
        let cm4_not = PatchNode(cm4.id(), neighbour);

        let wire = hugr.get_wire(cm4_not, OutgoingPort::from(0));

        let owners: BTreeSet<_> = wire.wires.iter().map(|cw| cw.0.node().0).collect();
        let expected = BTreeSet::from([cm3.id(), cm4.id(), hugr.base()]);
        assert_eq!(owners, expected);
    }
}