use std::collections::{BTreeSet, HashSet};
use std::mem;
use itertools::Itertools;
use petgraph::visit::IntoNodeIdentifiers;
use portgraph::algorithms::convex::{LineIndex, LineIntervals, Position};
use portgraph::{PortView, algorithms::CreateConvexChecker, boundary::Boundary};
use rustc_hash::FxHashSet;
use thiserror::Error;
use crate::builder::{Container, FunctionBuilder};
use crate::core::HugrNode;
use crate::hugr::internal::{HugrInternals, PortgraphNodeMap};
use crate::hugr::{HugrMut, HugrView};
use crate::ops::dataflow::DataflowOpTrait;
use crate::ops::handle::{ContainerHandle, DataflowOpID};
use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
use crate::types::{Signature, Type};
use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
use super::{RootChecked, SchedulingGraph, SynEdgeWrapper};
mod convex;
/// A convexity checker bound to a single region of a HUGR: the children of
/// [`HugrConvexChecker::region_parent`].
///
/// Implementors compute the nodes of the subgraph delimited by a boundary and
/// decide whether that subgraph is convex.
pub trait HugrConvexChecker<N: HugrNode> {
    /// The parent node of the region this checker was built for.
    fn region_parent(&self) -> N;
    /// Compute the nodes of the subgraph delimited by `inputs` and `outputs`.
    ///
    /// Returns the nodes if the boundary (together with `function_calls`) is
    /// valid for them and the subgraph is convex. Otherwise returns an
    /// [`InvalidSubgraph`] error, e.g. [`InvalidSubgraph::NotConvex`].
    fn nodes_if_convex(
        &self,
        hugr: &impl HugrView<Node = N>,
        inputs: &IncomingPorts<N>,
        outputs: &OutgoingPorts<N>,
        function_calls: &IncomingPorts<N>,
    ) -> Result<Vec<N>, InvalidSubgraph<N>>;
}
/// Use any portgraph-based checker (topological or line-based) as a
/// [`HugrConvexChecker`] for the region it was built on.
impl<'a, H, CC> HugrConvexChecker<H::Node> for PortgraphCheckerWithNodes<'a, H, CC>
where
    H: HugrView,
    CC: CreateConvexChecker<CheckerRegion<'a, H>, NodeIndexBase = u32, PortIndexBase = u32>,
{
    fn region_parent(&self) -> H::Node {
        self.region_parent
    }

    /// Build the portgraph subgraph induced by the boundary, validate the
    /// boundary against its nodes, and then ask the wrapped checker whether
    /// the subgraph is convex.
    fn nodes_if_convex(
        &self,
        hugr: &impl HugrView<Node = H::Node>,
        inputs: &IncomingPorts<H::Node>,
        outputs: &OutgoingPorts<H::Node>,
        function_calls: &IncomingPorts<H::Node>,
    ) -> Result<Vec<H::Node>, InvalidSubgraph<H::Node>> {
        let region = self.checker.graph().clone();
        let pg_subgraph = make_pg_subgraph::<H>(region, inputs, outputs, &self.node_map);
        let nodes: Vec<H::Node> = pg_subgraph
            .nodes_iter()
            .map(|pg_node| self.node_map.from_portgraph(pg_node))
            .collect();
        validate_boundary(hugr, &nodes, inputs, outputs, function_calls)?;
        pg_subgraph
            .is_convex_with_checker(self)
            .then_some(nodes)
            .ok_or(InvalidSubgraph::NotConvex)
    }
}
/// A subgraph of a single region of a HUGR, described by its nodes and its
/// boundary ports.
///
/// The nodes of all boundary ports are expected to share one parent.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SiblingSubgraph<N = Node> {
    /// The nodes of the subgraph.
    nodes: Vec<N>,
    /// The input boundary, partitioned. Each inner `Vec` groups the incoming
    /// ports that are linked to the same outgoing port outside the subgraph.
    inputs: IncomingPorts<N>,
    /// The output boundary: outgoing ports linked to nodes outside the subgraph.
    outputs: OutgoingPorts<N>,
    /// Static function inputs of nodes in the subgraph, grouped by the
    /// outgoing port they are linked to.
    function_calls: IncomingPorts<N>,
}
/// A partitioned list of incoming ports; each inner `Vec` is one partition.
pub type IncomingPorts<N = Node> = Vec<Vec<(N, IncomingPort)>>;
/// A list of outgoing ports.
pub type OutgoingPorts<N = Node> = Vec<(N, OutgoingPort)>;
impl<N: HugrNode> SiblingSubgraph<N> {
    /// Create a subgraph covering all non-IO children of a dataflow region.
    ///
    /// The subgraph contains every child of the entrypoint of `dfg_graph`
    /// except the first two, which in a dataflow region are the Input and
    /// Output nodes. The boundary is given by the ports linked to those Input
    /// and Output nodes. Non-local edges of the nodes must be static function
    /// inputs; these are recorded as function calls.
    ///
    /// # Errors
    ///
    /// - [`InvalidSubgraph::EmptySubgraph`] if the region has no other children.
    /// - [`InvalidSubgraph::UnsupportedEdgeKind`] if the Input/Output nodes have
    ///   linked order or static ports, or if a non-local edge is not a
    ///   function-call input.
    /// - Any error raised while validating the resulting boundary.
    pub fn try_new_dataflow_subgraph<'h, H, Root>(
        dfg_graph: RootChecked<&'h H, Root>,
    ) -> Result<Self, InvalidSubgraph<N>>
    where
        H: 'h + Clone + HugrView<Node = N>,
        Root: ContainerHandle<N, ChildrenHandle = DataflowOpID>,
    {
        let dfg_graph = dfg_graph.into_hugr();
        let parent = HugrView::entrypoint(&dfg_graph);
        // Skip the Input and Output nodes.
        let nodes = dfg_graph.children(parent).skip(2).collect_vec();
        if nodes.is_empty() {
            return Err(InvalidSubgraph::EmptySubgraph);
        }
        let (inputs, outputs) = get_input_output_ports(dfg_graph)?;
        let non_local = get_non_local_edges(&nodes, &dfg_graph);
        let function_calls = group_into_function_calls(non_local, &dfg_graph)?;
        validate_boundary(dfg_graph, &nodes, &inputs, &outputs, &function_calls)?;
        Ok(Self {
            nodes,
            inputs,
            outputs,
            function_calls,
        })
    }

    /// Create a subgraph from its boundary ports.
    ///
    /// The region is the parent of the first boundary node. Convexity is
    /// checked with a [`SchedGraphChecker`] built for that region. Building
    /// a checker is not free; use [`SiblingSubgraph::try_new_with_checker`] to
    /// reuse one across calls.
    ///
    /// # Errors
    ///
    /// - [`InvalidSubgraph::EmptySubgraph`] if there are no boundary ports.
    /// - [`InvalidSubgraph::OrphanNode`] if the first boundary node has no parent.
    /// - Any error from [`SiblingSubgraph::try_new_with_checker`].
    pub fn try_new(
        inputs: IncomingPorts<N>,
        outputs: OutgoingPorts<N>,
        hugr: &impl HugrView<Node = N>,
    ) -> Result<Self, InvalidSubgraph<N>> {
        let (node, _) = iter_io(&inputs, &outputs)
            .next()
            .ok_or(InvalidSubgraph::EmptySubgraph)?;
        let parent = hugr
            .get_parent(node)
            .ok_or(InvalidSubgraph::OrphanNode { orphan: node })?;
        let checker = SchedGraphChecker::new(hugr.scheduling_graph(parent));
        Self::try_new_with_checker(inputs, outputs, hugr, &checker)
    }

    /// Create a subgraph from its parts without any validation.
    ///
    /// The caller is responsible for the parts being consistent.
    pub fn new_unchecked(
        inputs: IncomingPorts<N>,
        outputs: OutgoingPorts<N>,
        function_calls: IncomingPorts<N>,
        nodes: Vec<N>,
    ) -> Self {
        Self {
            nodes,
            inputs,
            outputs,
            function_calls,
        }
    }

    /// Create a subgraph from its boundary ports, using `checker` to compute
    /// its nodes and check convexity.
    ///
    /// Input partitions whose first port is a static function input are moved
    /// out of `inputs` and stored as function calls.
    ///
    /// # Errors
    ///
    /// - Errors from the parent check (empty boundary, orphan nodes, or
    ///   boundary nodes with different parents).
    /// - [`InvalidSubgraph::MismatchedCheckerParent`] if `checker` was built
    ///   for a different region.
    /// - Any error returned by [`HugrConvexChecker::nodes_if_convex`].
    pub fn try_new_with_checker<H: HugrView<Node = N>>(
        mut inputs: IncomingPorts<N>,
        outputs: OutgoingPorts<N>,
        hugr: &H,
        checker: &impl HugrConvexChecker<H::Node>,
    ) -> Result<Self, InvalidSubgraph<N>> {
        let subgraph_parent = check_parent(hugr, &inputs, &outputs)?;
        let checker_parent = checker.region_parent();
        if subgraph_parent != checker_parent {
            return Err(InvalidSubgraph::MismatchedCheckerParent {
                checker_parent,
                subgraph_parent,
            });
        }
        let function_calls = drain_function_calls(&mut inputs, hugr);
        let nodes = checker.nodes_if_convex(hugr, &inputs, &outputs, &function_calls)?;
        Ok(Self {
            nodes,
            inputs,
            outputs,
            function_calls,
        })
    }

    /// Create a subgraph from a set of nodes.
    ///
    /// The region is the parent of the first node. Convexity is checked with
    /// a [`SchedGraphChecker`] built for that region.
    ///
    /// # Errors
    ///
    /// - [`InvalidSubgraph::EmptySubgraph`] if `nodes` is empty.
    /// - [`InvalidSubgraph::OrphanNode`] if the first node has no parent.
    /// - Any error from [`SiblingSubgraph::try_from_nodes_with_checker`].
    pub fn try_from_nodes(
        nodes: impl Into<Vec<N>>,
        hugr: &impl HugrView<Node = N>,
    ) -> Result<Self, InvalidSubgraph<N>> {
        let nodes = nodes.into();
        let Some(node) = nodes.first() else {
            return Err(InvalidSubgraph::EmptySubgraph);
        };
        let parent = hugr
            .get_parent(*node)
            .ok_or(InvalidSubgraph::OrphanNode { orphan: *node })?;
        let checker = SchedGraphChecker::new(hugr.scheduling_graph(parent));
        Self::try_from_nodes_with_checker(nodes, hugr, &checker)
    }

    /// Create a subgraph from a set of nodes, using `checker` for convexity.
    ///
    /// Duplicate nodes are removed, keeping the first occurrence of each. The
    /// boundary is derived from the edges between `nodes` and the rest of the
    /// region. If that boundary is empty, the nodes are returned as-is without
    /// consulting `checker`.
    ///
    /// # Errors
    ///
    /// - [`InvalidSubgraph::EmptySubgraph`] if `nodes` is empty.
    /// - Any error from [`SiblingSubgraph::try_new_with_checker`].
    pub fn try_from_nodes_with_checker<H: HugrView<Node = N>>(
        nodes: impl Into<Vec<N>>,
        hugr: &H,
        checker: &impl HugrConvexChecker<H::Node>,
    ) -> Result<Self, InvalidSubgraph<N>> {
        let mut nodes: Vec<N> = nodes.into();
        let num_nodes = nodes.len();
        if nodes.is_empty() {
            return Err(InvalidSubgraph::EmptySubgraph);
        }
        let (inputs, outputs) = get_boundary_from_nodes(hugr, &mut nodes);
        // Without any boundary edges, the node set is the whole subgraph.
        if inputs.is_empty() && outputs.is_empty() {
            return Ok(Self {
                nodes,
                inputs,
                outputs,
                function_calls: vec![],
            });
        }
        let mut subgraph = Self::try_new_with_checker(inputs, outputs, hugr, checker)?;
        // Nodes with no path to the boundary are not found by traversing from
        // the boundary, so restore the caller's (deduplicated) node list.
        if subgraph.node_count() < num_nodes {
            subgraph.nodes = nodes;
        }
        Ok(subgraph)
    }

    /// Create a subgraph from a set of nodes whose convexity is established by
    /// `intervals` of a line convexity checker.
    ///
    /// Duplicate nodes are removed, keeping the first occurrence of each; the
    /// boundary is derived from the edges between `nodes` and the rest of the
    /// region. The caller is responsible for `nodes` and `intervals`
    /// describing the same set: this is not checked.
    ///
    /// # Errors
    ///
    /// - [`InvalidSubgraph::NotConvex`] if `line_checker` reports `intervals`
    ///   as not convex.
    /// - [`InvalidSubgraph::EmptySubgraph`] if `nodes` is empty.
    pub fn try_from_nodes_with_intervals(
        nodes: impl Into<Vec<N>>,
        intervals: &LineIntervals,
        line_checker: &LineConvexChecker<impl HugrView<Node = N>>,
    ) -> Result<Self, InvalidSubgraph<N>> {
        if !line_checker.checker.is_convex_by_intervals(intervals) {
            return Err(InvalidSubgraph::NotConvex);
        }
        let mut nodes: Vec<N> = nodes.into();
        if nodes.is_empty() {
            return Err(InvalidSubgraph::EmptySubgraph);
        }
        let hugr = line_checker.hugr();
        // Shared with `try_from_nodes_with_checker`; this also deduplicates
        // `nodes` so that no input port appears in more than one partition.
        let (mut inputs, outputs) = get_boundary_from_nodes(hugr, &mut nodes);
        let function_calls = drain_function_calls(&mut inputs, hugr);
        Ok(Self {
            nodes,
            inputs,
            outputs,
            function_calls,
        })
    }

    /// Create a subgraph containing the single node `node`.
    ///
    /// Every linked input port becomes its own input partition; static
    /// function inputs are then moved to the function calls. The outputs are
    /// all output ports that are linked or that carry a value.
    ///
    /// # Panics
    ///
    /// If `node` has a linked order ("other") input or output port.
    pub fn from_node(node: N, hugr: &impl HugrView<Node = N>) -> Self {
        let nodes = vec![node];
        let mut inputs = hugr
            .node_inputs(node)
            .filter(|&p| hugr.is_linked(node, p))
            .map(|p| vec![(node, p)])
            .collect_vec();
        let outputs = hugr
            .node_outputs(node)
            .filter_map(|p| {
                // Accept linked outputs or unlinked value outputs.
                {
                    hugr.is_linked(node, p)
                        || HugrView::get_optype(&hugr, node)
                            .port_kind(p)
                            .is_some_and(|k| k.is_value())
                }
                .then_some((node, p))
            })
            .collect_vec();
        let function_calls = drain_function_calls(&mut inputs, hugr);
        let state_order_at_input = hugr
            .get_optype(node)
            .other_output_port()
            .is_some_and(|p| hugr.is_linked(node, p));
        let state_order_at_output = hugr
            .get_optype(node)
            .other_input_port()
            .is_some_and(|p| hugr.is_linked(node, p));
        if state_order_at_input || state_order_at_output {
            unimplemented!("Order edges in {node:?} not supported");
        }
        Self {
            nodes,
            inputs,
            outputs,
            function_calls,
        }
    }

    /// Validate the subgraph according to `mode`.
    ///
    /// # Errors
    ///
    /// See [`SiblingSubgraph::validate_with_checker`].
    #[deprecated(
        note = "Use `validate_with_checker`, `validate_default` or `validate_skip_convexity`",
        since = "0.27.1"
    )]
    #[expect(deprecated)]
    pub fn validate<'h, H: HugrView<Node = N>>(
        &self,
        hugr: &'h H,
        mode: ValidationMode<'_, 'h, H>,
    ) -> Result<(), InvalidSubgraph<N>> {
        match mode {
            ValidationMode::WithChecker(checker) => self.validate_with_checker(hugr, Some(checker)),
            ValidationMode::CheckConvexity => self.validate_default(hugr),
            ValidationMode::SkipConvexity => self.validate_skip_convexity(hugr),
        }
    }

    /// Validate the subgraph, checking convexity with a [`SchedGraphChecker`]
    /// built for the subgraph's parent region.
    ///
    /// # Errors
    ///
    /// See [`SiblingSubgraph::validate_with_checker`].
    pub fn validate_default(
        &self,
        hugr: &impl HugrView<Node = N>,
    ) -> Result<(), InvalidSubgraph<N>> {
        let parent = check_parent(hugr, &self.inputs, &self.outputs)?;
        self.validate_with_checker(
            hugr,
            Some(&SchedGraphChecker::new(hugr.scheduling_graph(parent))),
        )
    }

    /// Validate the subgraph without checking convexity.
    ///
    /// # Errors
    ///
    /// See [`SiblingSubgraph::validate_with_checker`].
    pub fn validate_skip_convexity(
        &self,
        hugr: &impl HugrView<Node = N>,
    ) -> Result<(), InvalidSubgraph<N>> {
        // Uninhabited checker type, used only to name the `None` case.
        enum NoChecker {}
        impl<N: HugrNode> HugrConvexChecker<N> for NoChecker {
            fn region_parent(&self) -> N {
                match *self {}
            }
            fn nodes_if_convex(
                &self,
                _hugr: &impl HugrView<Node = N>,
                _inputs: &IncomingPorts<N>,
                _outputs: &OutgoingPorts<N>,
                _function_calls: &IncomingPorts<N>,
            ) -> Result<Vec<N>, InvalidSubgraph<N>> {
                match *self {}
            }
        }
        let no_checker: Option<&NoChecker> = None;
        self.validate_with_checker(hugr, no_checker)
    }

    /// Validate the subgraph, checking convexity only if `checker` is given.
    ///
    /// Checks that all boundary nodes share a parent and that the boundary is
    /// valid for the nodes it induces. If a checker is given, it must be for
    /// that parent and must consider the subgraph convex. Finally, the node
    /// set induced by the boundary must equal the stored nodes (order is
    /// ignored).
    ///
    /// # Errors
    ///
    /// - Errors from the parent check.
    /// - [`InvalidSubgraph::MismatchedCheckerParent`] on a checker for another region.
    /// - Boundary validation errors, or [`InvalidSubgraph::NotConvex`].
    /// - [`InvalidSubgraph::InvalidNodeSet`] if the node sets differ.
    pub fn validate_with_checker<H: HugrView<Node = N>>(
        &self,
        hugr: &H,
        checker: Option<&impl HugrConvexChecker<N>>,
    ) -> Result<(), InvalidSubgraph<N>> {
        let subgraph_parent = check_parent(hugr, &self.inputs, &self.outputs)?;
        let mut exp_nodes = match checker {
            Some(c) => {
                let checker_parent = c.region_parent();
                if checker_parent != subgraph_parent {
                    return Err(InvalidSubgraph::MismatchedCheckerParent {
                        checker_parent,
                        subgraph_parent,
                    });
                }
                // The checker also validates the boundary.
                c.nodes_if_convex(hugr, &self.inputs, &self.outputs, &self.function_calls)?
            }
            None => {
                let (region, node_map) = hugr
                    .scheduling_graph(subgraph_parent)
                    .portgraph_no_syn_edges();
                let nodes = make_pg_subgraph::<H>(region, &self.inputs, &self.outputs, &node_map)
                    .nodes_iter()
                    .map(|n| node_map.from_portgraph(n))
                    .collect_vec();
                // Skipping convexity must not skip the boundary checks.
                validate_boundary(
                    hugr,
                    &nodes,
                    &self.inputs,
                    &self.outputs,
                    &self.function_calls,
                )?;
                nodes
            }
        };
        let mut nodes = self.nodes.clone();
        exp_nodes.sort_unstable();
        nodes.sort_unstable();
        if exp_nodes != nodes {
            return Err(InvalidSubgraph::InvalidNodeSet);
        }
        Ok(())
    }

    /// The nodes of the subgraph.
    #[must_use]
    pub fn nodes(&self) -> &[N] {
        &self.nodes
    }

    /// The number of nodes in the subgraph.
    #[must_use]
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// The partitioned input boundary.
    #[must_use]
    pub fn incoming_ports(&self) -> &IncomingPorts<N> {
        &self.inputs
    }

    /// The output boundary.
    #[must_use]
    pub fn outgoing_ports(&self) -> &OutgoingPorts<N> {
        &self.outputs
    }

    /// The static function inputs of the subgraph, grouped by source.
    #[must_use]
    pub fn function_calls(&self) -> &IncomingPorts<N> {
        &self.function_calls
    }

    /// The dataflow signature of the subgraph.
    ///
    /// There is one input type per input partition (taken from its first
    /// port) and one output type per output port.
    ///
    /// # Panics
    ///
    /// If a boundary node has no dataflow signature, or a boundary port is not
    /// a dataflow port.
    pub fn signature(&self, hugr: &impl HugrView<Node = N>) -> Signature {
        let input = self
            .inputs
            .iter()
            .map(|part| {
                let &(n, p) = part.first().expect("is non-empty");
                let sig = hugr.signature(n).expect("must have dataflow signature");
                sig.port_type(p).cloned().expect("must be dataflow edge")
            })
            .collect_vec();
        let output = self
            .outputs
            .iter()
            .map(|&(n, p)| {
                let sig = hugr.signature(n).expect("must have dataflow signature");
                sig.port_type(p).cloned().expect("must be dataflow edge")
            })
            .collect_vec();
        Signature::new(input, output)
    }

    /// The parent of the subgraph's nodes.
    ///
    /// # Panics
    ///
    /// If the subgraph has no nodes or its first node has no parent.
    pub fn get_parent(&self, hugr: &impl HugrView<Node = N>) -> N {
        hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
    }

    /// Map every node of the subgraph through `node_map`, keeping the ports.
    pub(crate) fn map_nodes<N2: HugrNode>(
        &self,
        node_map: impl Fn(N) -> N2,
    ) -> SiblingSubgraph<N2> {
        let nodes = self.nodes.iter().map(|&n| node_map(n)).collect_vec();
        let inputs = self
            .inputs
            .iter()
            .map(|part| part.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
            .collect_vec();
        let outputs = self
            .outputs
            .iter()
            .map(|&(n, p)| (node_map(n), p))
            .collect_vec();
        let function_calls = self
            .function_calls
            .iter()
            .map(|calls| calls.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
            .collect_vec();
        SiblingSubgraph {
            nodes,
            inputs,
            outputs,
            function_calls,
        }
    }

    /// Create a [`SimpleReplacement`] that replaces this subgraph in `hugr` by
    /// the dataflow graph rooted at the entrypoint of `replacement`.
    ///
    /// # Errors
    ///
    /// - [`InvalidReplacement::InvalidDataflowGraph`] if the entrypoint of
    ///   `replacement` is not a dataflow parent.
    /// - Any error from [`SimpleReplacement::try_new`].
    ///
    /// # Panics
    ///
    /// If the replacement entrypoint has no Input/Output nodes, or if those
    /// nodes have linked order edges.
    pub fn create_simple_replacement(
        &self,
        hugr: &impl HugrView<Node = N>,
        replacement: Hugr,
    ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
        let rep_root = replacement.entrypoint();
        let dfg_optype = replacement.get_optype(rep_root);
        if !OpTag::DataflowParent.is_superset(dfg_optype.tag()) {
            return Err(InvalidReplacement::InvalidDataflowGraph {
                node: rep_root,
                op: Box::new(dfg_optype.clone()),
            });
        }
        let [rep_input, rep_output] = replacement
            .get_io(rep_root)
            .expect("DFG root in the replacement does not have input and output nodes.");
        let state_order_at_input = replacement
            .get_optype(rep_input)
            .other_output_port()
            .is_some_and(|p| replacement.is_linked(rep_input, p));
        let state_order_at_output = replacement
            .get_optype(rep_output)
            .other_input_port()
            .is_some_and(|p| replacement.is_linked(rep_output, p));
        if state_order_at_input || state_order_at_output {
            unimplemented!("Found state order edges in replacement graph");
        }
        SimpleReplacement::try_new(self.clone(), hugr, replacement)
    }

    /// Copy the subgraph into a new HUGR as the body of a function named
    /// `name` with the subgraph's signature.
    ///
    /// The function's Input node feeds every port of each input partition, and
    /// each output port feeds the corresponding port of the Output node.
    ///
    /// # Panics
    ///
    /// If the function builder cannot be created, or if [`Self::signature`] panics.
    pub fn extract_subgraph(
        &self,
        hugr: &impl HugrView<Node = N>,
        name: impl Into<String>,
    ) -> Hugr {
        let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
        let mut extracted = mem::take(builder.hugr_mut());
        let node_map = extracted.insert_subgraph(extracted.entrypoint(), hugr, self);
        let [inp, out] = extracted.get_io(extracted.entrypoint()).unwrap();
        let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
        let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
        // Collect first: `inputs`/`outputs` borrow `extracted`, which
        // `connect` needs mutably.
        let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
        for (inp_port, repl_ports) in inputs {
            for (repl_node, repl_port) in repl_ports {
                connections.push((inp, inp_port, node_map[repl_node], *repl_port));
            }
        }
        for (out_port, (repl_node, repl_port)) in outputs {
            connections.push((node_map[repl_node], *repl_port, out, out_port));
        }
        for (src, src_port, dst, dst_port) in connections {
            extracted.connect(src, src_port, dst, dst_port);
        }
        extracted
    }

    /// Replace the output boundary with `ports`.
    ///
    /// Every port in `ports` must already be an output of the subgraph; ports
    /// may be reordered, dropped, or (if not linear) repeated.
    ///
    /// # Errors
    ///
    /// - [`InvalidOutputPorts::UnknownOutput`] for a port not in the current outputs.
    /// - [`InvalidOutputPorts::NonUniqueLinear`] if a linear port appears twice.
    pub fn set_outgoing_ports(
        &mut self,
        ports: OutgoingPorts<N>,
        host: &impl HugrView<Node = N>,
    ) -> Result<(), InvalidOutputPorts<N>> {
        let old_boundary: HashSet<_> = iter_outgoing(&self.outputs).collect();
        if let Some((node, port)) =
            iter_outgoing(&ports).find(|(n, p)| !old_boundary.contains(&(*n, *p)))
        {
            return Err(InvalidOutputPorts::UnknownOutput { port, node });
        }
        if !has_unique_linear_ports(host, &ports) {
            return Err(InvalidOutputPorts::NonUniqueLinear);
        }
        self.outputs = ports;
        Ok(())
    }
}
/// How the deprecated [`SiblingSubgraph::validate`] checks convexity.
#[allow(deprecated)]
#[deprecated(
    note = "Call validate_with_checker or validate_default instead",
    since = "0.27.1"
)]
#[derive(Default)]
pub enum ValidationMode<'t, 'h, H: HugrView> {
    /// Check convexity with the given checker.
    WithChecker(&'t TopoConvexChecker<'h, H>),
    /// Check convexity with a checker built for the subgraph's parent region.
    #[default]
    CheckConvexity,
    /// Do not check convexity.
    SkipConvexity,
}
/// Build the portgraph subgraph of `region` delimited by the given HUGR
/// boundary ports.
///
/// # Panics
///
/// If a boundary port has no counterpart in `region`.
fn make_pg_subgraph<'h, H: HugrView>(
    region: CheckerRegion<'h, H>,
    inputs: &IncomingPorts<H::Node>,
    outputs: &OutgoingPorts<H::Node>,
    node_map: &H::RegionPortgraphNodes,
) -> portgraph::view::Subgraph<CheckerRegion<'h, H>> {
    // Scope the borrow of `region` so it can be moved into the subgraph below.
    let boundary = {
        let pg_port = |node: H::Node, port: Port| {
            let pg_node = node_map.to_portgraph(node);
            region.port_index(pg_node, port.pg_offset()).unwrap()
        };
        let incoming = iter_incoming(inputs).map(|(n, p)| pg_port(n, p.into()));
        let outgoing = iter_outgoing(outputs).map(|(n, p)| pg_port(n, p.into()));
        Boundary::new(incoming, outgoing)
    };
    portgraph::view::Subgraph::new_subgraph(region, boundary)
}
/// Deduplicate `nodes` in place (keeping first occurrences) and compute the
/// boundary they induce.
///
/// Every incoming port linked from a node outside the set becomes its own
/// input partition. Every outgoing port linked to at least one node outside
/// the set becomes an output. Both lists follow the order of `nodes`, then
/// the port order within each node.
///
/// # Panics
///
/// If a linked incoming port has no single linked output.
fn get_boundary_from_nodes<N: HugrNode>(
    hugr: &impl HugrView<Node = N>,
    nodes: &mut Vec<N>,
) -> (IncomingPorts<N>, OutgoingPorts<N>) {
    let mut members = FxHashSet::default();
    nodes.retain(|&node| members.insert(node));

    let mut inputs: IncomingPorts<N> = Vec::new();
    let mut outputs: OutgoingPorts<N> = Vec::new();
    for &node in nodes.iter() {
        for port in hugr.node_inputs(node) {
            let fed_from_outside = hugr.is_linked(node, port) && {
                let (src, _) = hugr.single_linked_output(node, port).unwrap();
                !members.contains(&src)
            };
            if fed_from_outside {
                inputs.push(vec![(node, port)]);
            }
        }
        for port in hugr.node_outputs(node) {
            if hugr
                .linked_ports(node, port)
                .any(|(dst, _)| !members.contains(&dst))
            {
                outputs.push((node, port));
            }
        }
    }
    (inputs, outputs)
}
/// Remove from `inputs` every partition whose first port is a static input
/// carrying a function, and return the removed ports regrouped as function
/// calls.
///
/// Partitions are classified by their first port only; all ports of a
/// removed partition are moved. Empty partitions stay in `inputs`.
///
/// # Panics
///
/// If [`group_into_function_calls`] rejects the moved ports. NOTE(review):
/// this presumably cannot happen as long as every port of a removed partition
/// (not just the first) is a function static input — confirm.
fn drain_function_calls<N: HugrNode, H: HugrView<Node = N>>(
    inputs: &mut IncomingPorts<N>,
    hugr: &H,
) -> IncomingPorts<N> {
    let mut function_calls = Vec::new();
    inputs.retain_mut(|calls| {
        let Some(&(n, p)) = calls.first() else {
            return true;
        };
        let op = hugr.get_optype(n);
        if op.static_input_port() == Some(p)
            && op
                .static_input()
                .expect("static input exists")
                .is_function()
        {
            // Move the ports out; the now-empty partition is dropped by `retain_mut`.
            function_calls.extend(mem::take(calls));
            false
        } else {
            true
        }
    });
    group_into_function_calls(function_calls.into_iter().map(|(n, p)| (n, p.into())), hugr)
        .expect("valid function calls")
}
/// Group static function inputs by the outgoing port they are linked to.
///
/// The groups are returned sorted by that outgoing port.
///
/// # Errors
///
/// [`InvalidSubgraph::UnsupportedEdgeKind`] if a port is not an incoming
/// port, is not its node's static input port, or the static input is not a
/// function.
///
/// # Panics
///
/// If `single_linked_output` returns `None` for one of the ports.
fn group_into_function_calls<N: HugrNode>(
    ports: impl IntoIterator<Item = (N, Port)>,
    hugr: &impl HugrView<Node = N>,
) -> Result<Vec<Vec<(N, IncomingPort)>>, InvalidSubgraph<N>> {
    let incoming_ports: Vec<_> = ports
        .into_iter()
        .map(|(n, p)| {
            let p = p
                .as_incoming()
                .map_err(|_| InvalidSubgraph::UnsupportedEdgeKind(n, p))?;
            let op = hugr.get_optype(n);
            if Some(p) != op.static_input_port() {
                return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
            }
            if !op
                .static_input()
                .expect("static input exists")
                .is_function()
            {
                return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
            }
            Ok::<_, InvalidSubgraph<N>>((n, p))
        })
        .try_collect()?;
    // Key each port by its source so that calls to the same function end up together.
    let grouped_non_local = incoming_ports
        .into_iter()
        .into_group_map_by(|&(n, p)| hugr.single_linked_output(n, p).expect("valid dfg wire"));
    // Sort by source for a deterministic order (the group map is unordered).
    Ok(grouped_non_local
        .into_iter()
        .sorted_unstable_by(|(n1, _), (n2, _)| n1.cmp(n2))
        .map(|(_, v)| v)
        .collect())
}
/// Iterate over all ports (both directions) of `nodes` that are linked to a
/// node whose parent differs from the parent of `nodes[0]`.
///
/// # Panics
///
/// If `nodes` is empty.
fn get_non_local_edges<'a, N: HugrNode>(
    nodes: &'a [N],
    hugr: &'a impl HugrView<Node = N>,
) -> impl Iterator<Item = (N, Port)> + 'a {
    let parent = hugr.get_parent(nodes[0]);
    let is_non_local = move |n, p| {
        hugr.linked_ports(n, p)
            .any(|(n, _)| hugr.get_parent(n) != parent)
    };
    nodes
        .iter()
        .flat_map(move |&n| hugr.all_node_ports(n).map(move |p| (n, p)))
        .filter(move |&(n, p)| is_non_local(n, p))
}
/// Iterate over every incoming boundary port, flattening the partitions in
/// order.
fn iter_incoming<N: HugrNode>(
    inputs: &IncomingPorts<N>,
) -> impl Iterator<Item = (N, IncomingPort)> + '_ {
    inputs.iter().flatten().copied()
}
/// Iterate over the outgoing boundary ports, by value.
fn iter_outgoing<N: HugrNode>(
    outputs: &OutgoingPorts<N>,
) -> impl Iterator<Item = (N, OutgoingPort)> + '_ {
    outputs.iter().copied()
}
/// Iterate over all boundary ports as untyped [`Port`]s: first the incoming
/// ones (flattened), then the outgoing ones.
fn iter_io<'a, N: HugrNode>(
    inputs: &'a IncomingPorts<N>,
    outputs: &'a OutgoingPorts<N>,
) -> impl Iterator<Item = (N, Port)> + 'a {
    let incoming = iter_incoming(inputs).map(|(node, port)| (node, Port::from(port)));
    let outgoing = iter_outgoing(outputs).map(|(node, port)| (node, Port::from(port)));
    incoming.chain(outgoing)
}
/// Return the common parent of all boundary nodes.
///
/// # Errors
///
/// - [`InvalidSubgraph::EmptySubgraph`] if there are no boundary ports.
/// - [`InvalidSubgraph::OrphanNode`] for the first boundary node without a parent.
/// - [`InvalidSubgraph::NoSharedParent`] for the first boundary node whose
///   parent differs from that of the first boundary node.
fn check_parent<'a, N: HugrNode>(
    hugr: &impl HugrView<Node = N>,
    inputs: &'a IncomingPorts<N>,
    outputs: &'a OutgoingPorts<N>,
) -> Result<N, InvalidSubgraph<N>> {
    let parent_of = |node: N| {
        hugr.get_parent(node)
            .ok_or(InvalidSubgraph::OrphanNode { orphan: node })
    };
    let mut boundary_nodes = iter_io(inputs, outputs).map(|(node, _)| node);
    let Some(first_node) = boundary_nodes.next() else {
        return Err(InvalidSubgraph::EmptySubgraph);
    };
    let first_parent = parent_of(first_node)?;
    boundary_nodes.try_for_each(|other_node| {
        let other_parent = parent_of(other_node)?;
        if other_parent == first_parent {
            Ok(())
        } else {
            Err(InvalidSubgraph::NoSharedParent {
                first_node,
                first_parent,
                other_node,
                other_parent,
            })
        }
    })?;
    Ok(first_parent)
}
/// The flat region view of `Base`'s region portgraph that the
/// portgraph-based convexity checkers operate on.
type CheckerRegion<'g, Base> =
    portgraph::view::FlatRegion<'g, <Base as HugrInternals>::RegionPortgraph<'g>>;
/// A [`PortgraphCheckerWithNodes`] that uses portgraph's topological
/// convexity checker.
#[deprecated(
    note = "Use SchedGraphChecker or LineConvexChecker instead",
    since = "0.27.1"
)]
pub type TopoConvexChecker<'g, Base> = PortgraphCheckerWithNodes<
    'g,
    Base,
    portgraph::algorithms::TopoConvexChecker<CheckerRegion<'g, Base>>,
>;
/// A [`PortgraphCheckerWithNodes`] that uses portgraph's line-based
/// convexity checker.
pub type LineConvexChecker<'g, Base> = PortgraphCheckerWithNodes<
    'g,
    Base,
    portgraph::algorithms::LineConvexChecker<CheckerRegion<'g, Base>>,
>;
/// A portgraph convexity checker for one region of a HUGR, together with the
/// mapping between HUGR nodes and the portgraph node indices of that region.
#[derive(Clone)]
pub struct PortgraphCheckerWithNodes<'g, Base: HugrView, Checker> {
    /// The HUGR the checker was built from.
    base: &'g Base,
    /// The parent node of the region being checked.
    region_parent: Base::Node,
    /// The underlying portgraph checker.
    checker: Checker,
    /// Maps between `Base` nodes and portgraph node indices in the region.
    node_map: Base::RegionPortgraphNodes,
}
/// Former name of [`PortgraphCheckerWithNodes`].
#[deprecated(note = "Use PortgraphCheckerWithNodes instead", since = "0.27.1")]
pub type ConvexChecker<'g, Base, Checker> = PortgraphCheckerWithNodes<'g, Base, Checker>;
impl<'g, Base, Checker> PortgraphCheckerWithNodes<'g, Base, Checker>
where
    Base: HugrView,
    Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
{
    /// Build a checker for the region whose parent is `region_parent`.
    ///
    /// The checker works on the region's scheduling graph without its
    /// synthetic edges.
    pub fn new(base: &'g Base, region_parent: Base::Node) -> Self {
        let (region, node_map) = base
            .scheduling_graph(region_parent)
            .portgraph_no_syn_edges();
        Self {
            checker: Checker::new_convex_checker(region),
            base,
            region_parent,
            node_map,
        }
    }

    /// Build a checker for the region under the entrypoint of `base`.
    #[inline(always)]
    pub fn from_entrypoint(base: &'g Base) -> Self {
        Self::new(base, base.entrypoint())
    }

    /// The HUGR this checker was built from.
    pub fn hugr(&self) -> &'g Base {
        self.base
    }
}
/// Lets a [`PortgraphCheckerWithNodes`] be used directly as a portgraph
/// convexity checker, delegating to the wrapped checker.
impl<'g, Base, Checker> portgraph::algorithms::ConvexChecker
    for PortgraphCheckerWithNodes<'g, Base, Checker>
where
    Base: HugrView,
    Checker: CreateConvexChecker<CheckerRegion<'g, Base>, NodeIndexBase = u32, PortIndexBase = u32>,
{
    type NodeIndexBase = u32;
    type PortIndexBase = u32;

    /// Returns `true` for fewer than two nodes without consulting the wrapped
    /// checker; otherwise delegates to it.
    fn is_convex(
        &self,
        nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
        inputs: impl IntoIterator<Item = portgraph::PortIndex>,
        outputs: impl IntoIterator<Item = portgraph::PortIndex>,
    ) -> bool {
        let mut nodes = nodes.into_iter().multipeek();
        // With `multipeek`, each `peek` advances the peek cursor: the first
        // call inspects the first node and the second call the second node.
        // This therefore tests "fewer than two nodes" without consuming
        // anything from `nodes`.
        if nodes.peek().is_none() || nodes.peek().is_none() {
            return true;
        }
        self.checker.is_convex(nodes, inputs, outputs)
    }
}
/// HUGR-node wrappers around the interval API of the line-based checker.
impl<'g, Base: HugrView> LineConvexChecker<'g, Base> {
    /// Compute the line intervals spanned by `nodes`.
    ///
    /// The nodes are mapped to portgraph indices and passed to the underlying
    /// line checker, whose result is returned as-is. NOTE(review): when `None`
    /// is returned is defined by portgraph's `get_intervals_from_nodes`.
    pub fn get_intervals_from_nodes(
        &self,
        nodes: impl IntoIterator<Item = Base::Node>,
    ) -> Option<LineIntervals> {
        let nodes = nodes
            .into_iter()
            .map(|n| self.node_map.to_portgraph(n))
            .collect_vec();
        self.checker.get_intervals_from_nodes(nodes)
    }

    /// Compute the line intervals delimited by the given boundary ports.
    ///
    /// # Panics
    ///
    /// If a `(node, port)` pair has no corresponding port in the region's portgraph.
    pub fn get_intervals_from_boundary_ports(
        &self,
        ports: impl IntoIterator<Item = (Base::Node, Port)>,
    ) -> Option<LineIntervals> {
        let ports = ports
            .into_iter()
            .map(|(n, p)| {
                let node = self.node_map.to_portgraph(n);
                self.checker
                    .graph()
                    .port_index(node, p.pg_offset())
                    .expect("valid port")
            })
            .collect_vec();
        self.checker.get_intervals_from_boundary_ports(ports)
    }

    /// Iterate over the nodes within `intervals`, as HUGR nodes.
    pub fn nodes_in_intervals<'a>(
        &'a self,
        intervals: &'a LineIntervals,
    ) -> impl Iterator<Item = Base::Node> + 'a {
        self.checker
            .nodes_in_intervals(intervals)
            .map(|pg_node| self.node_map.from_portgraph(pg_node))
    }

    /// The lines passing through `port` of `node`.
    ///
    /// # Panics
    ///
    /// If the port has no corresponding port in the region's portgraph.
    pub fn lines_at_port(&self, node: Base::Node, port: impl Into<Port>) -> &[LineIndex] {
        let port = self
            .checker
            .graph()
            .port_index(self.node_map.to_portgraph(node), port.into().pg_offset())
            .expect("valid port");
        self.checker.lines_at_port(port)
    }

    /// Try to extend `intervals` so that they include `node`.
    ///
    /// Returns the underlying checker's result (presumably `true` if the
    /// extension succeeded; see portgraph's `try_extend_intervals`).
    pub fn try_extend_intervals(&self, intervals: &mut LineIntervals, node: Base::Node) -> bool {
        let node = self.node_map.to_portgraph(node);
        self.checker.try_extend_intervals(intervals, node)
    }

    /// The position of `node`, as reported by the underlying line checker.
    pub fn get_position(&self, node: Base::Node) -> Position {
        let node = self.node_map.to_portgraph(node);
        self.checker.get_position(node)
    }
}
/// The common dataflow type of `ports`.
///
/// Returns `None` if `ports` is empty, if any node lacks a dataflow
/// signature, or if the ports do not all have the same type.
fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(
    hugr: &H,
    ports: &[(H::Node, P)],
) -> Option<Type> {
    let &(first_node, first_port) = ports.first()?;
    let edge_t = hugr.signature(first_node)?.port_type(first_port)?.clone();
    let all_same = ports.iter().all(|&(node, port)| {
        hugr.signature(node)
            .is_some_and(|sig| sig.port_type(port) == Some(&edge_t))
    });
    if all_same { Some(edge_t) } else { None }
}
/// Check that `inputs`, `outputs` and `function_calls` form a valid boundary
/// for the node set `nodes`.
///
/// # Errors
///
/// - [`InvalidSubgraph::EmptySubgraph`] if `nodes` is empty.
/// - [`InvalidSubgraph::UnsupportedEdgeKind`] if a boundary port is a linked
///   order or static port, or a function-call group is inconsistent.
/// - [`InvalidSubgraphBoundary`] errors for boundary ports outside the node
///   set, boundary ports not linked outside, repeated input ports, partitions
///   that are empty, have different sources, or have non-uniform or
///   non-copyable types.
/// - [`InvalidSubgraph::NotConvex`] if an edge crossing the node set is
///   missing from the boundary.
///
/// # Panics
///
/// If an input port is not linked to exactly one outgoing port.
fn validate_boundary<H: HugrView>(
    hugr: &H,
    nodes: &[H::Node],
    inputs: &IncomingPorts<H::Node>,
    outputs: &OutgoingPorts<H::Node>,
    function_calls: &IncomingPorts<H::Node>,
) -> Result<(), InvalidSubgraph<H::Node>> {
    if nodes.is_empty() {
        return Err(InvalidSubgraph::EmptySubgraph);
    }
    let node_set = nodes.iter().copied().collect::<HashSet<_>>();

    // Only value edges may cross the boundary.
    if let Some((n, p)) = iter_io(inputs, outputs).find(|&(n, p)| is_non_value_edge(hugr, n, p)) {
        return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p));
    }

    let boundary_ports = iter_io(inputs, outputs).collect_vec();
    // Every boundary port must belong to a node of the subgraph...
    if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
        Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
    }
    // ...and must be linked to at least one node outside it.
    if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
        hugr.linked_ports(n, p)
            .all(|(n1, _)| node_set.contains(&n1))
    }) {
        Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
    }

    // Every incoming port linked from outside must be an input or a function call.
    let mut must_be_inputs = nodes
        .iter()
        .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)))
        .filter(|&(n, p)| {
            hugr.linked_ports(n, p)
                .any(|(n1, _)| !node_set.contains(&n1))
        });
    if !must_be_inputs.all(|(n, p)| {
        let mut all_inputs = inputs.iter().chain(function_calls);
        all_inputs.any(|nps| nps.contains(&(n, p)))
    }) {
        return Err(InvalidSubgraph::NotConvex);
    }

    // Every outgoing port linked to the outside must be an output.
    if nodes.iter().any(|&n| {
        hugr.node_outputs(n).any(|p| {
            hugr.linked_ports(n, p)
                .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
        })
    }) {
        return Err(InvalidSubgraph::NotConvex);
    }

    if !inputs.iter().flatten().all_unique() {
        return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
    }

    // All ports of a partition must be fed by the same outgoing port.
    for inp in inputs {
        let &(first_node, first_port) =
            inp.first().ok_or(InvalidSubgraphBoundary::EmptyPartition)?;
        let exp_output_node_port = hugr
            .single_linked_output(first_node, first_port)
            .expect("valid dfg wire");
        if let Some((mismatched_port, output_node_port)) = inp
            .iter()
            .map(|&(in_node, in_port)| {
                let linked = hugr
                    .single_linked_output(in_node, in_port)
                    .expect("valid dfg wire");
                ((in_node, in_port), linked)
            })
            .find(|&(_, linked)| linked != exp_output_node_port)
        {
            // Report the port that is actually linked elsewhere, not the
            // first port (which is linked to the expected source by construction).
            return Err(InvalidSubgraphBoundary::MismatchedOutputPort(
                mismatched_port,
                exp_output_node_port,
                output_node_port,
            )
            .into());
        }
    }

    // Each partition must have a single type; partitions with more than one
    // port duplicate the value, so that type must be copyable.
    if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
        let Some(edge_t) = get_edge_type(hugr, ports) else {
            return true;
        };
        let require_copy = ports.len() > 1;
        require_copy && !edge_t.copyable()
    }) {
        Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
    }

    for calls in function_calls {
        // `all_equal` is true for empty groups, so `calls[0]` is only
        // reached when the group has at least two ports.
        if !calls
            .iter()
            .map(|&(n, p)| hugr.single_linked_output(n, p))
            .all_equal()
        {
            let (n, p) = calls[0];
            return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
        }
        for &(n, p) in calls {
            let op = hugr.get_optype(n);
            if op.static_input_port() != Some(p) {
                return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
            }
        }
    }
    Ok(())
}
/// Compute the boundary of the subgraph made of all non-IO children of the
/// entrypoint of `hugr`.
///
/// For each output port of the Input node (in signature order), the input
/// partition is the list of incoming ports it feeds, excluding the Output
/// node. Ports that feed nothing else yield no partition. For each input
/// port of the Output node, the output is the linked outgoing port unless it
/// belongs to the Input node.
///
/// # Errors
///
/// [`InvalidSubgraph::UnsupportedEdgeKind`] if the Input or Output node has
/// a linked order or static port.
///
/// # Panics
///
/// If the entrypoint has no Input/Output nodes, or they are not `Input` and
/// `Output` operations.
#[allow(clippy::type_complexity)]
fn get_input_output_ports<H: HugrView>(
    hugr: &H,
) -> Result<(IncomingPorts<H::Node>, OutgoingPorts<H::Node>), InvalidSubgraph<H::Node>> {
    let [inp, out] = hugr
        .get_io(HugrView::entrypoint(&hugr))
        .expect("invalid DFG");
    if let Some(p) = hugr
        .node_outputs(inp)
        .find(|&p| is_non_value_edge(hugr, inp, p.into()))
    {
        return Err(InvalidSubgraph::UnsupportedEdgeKind(inp, p.into()));
    }
    if let Some(p) = hugr
        .node_inputs(out)
        .find(|&p| is_non_value_edge(hugr, out, p.into()))
    {
        return Err(InvalidSubgraph::UnsupportedEdgeKind(out, p.into()));
    }
    let dfg_inputs = HugrView::get_optype(&hugr, inp)
        .as_input()
        .unwrap()
        .signature()
        .output_ports();
    let dfg_outputs = HugrView::get_optype(&hugr, out)
        .as_output()
        .unwrap()
        .signature()
        .input_ports();
    // Wires going straight from Input to Output touch no node of the subgraph.
    let inputs = dfg_inputs
        .into_iter()
        .map(|p| {
            hugr.linked_inputs(inp, p)
                .filter(|&(n, _)| n != out)
                .collect_vec()
        })
        .filter(|v| !v.is_empty())
        .collect();
    let outputs = dfg_outputs
        .into_iter()
        .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
        .collect();
    Ok((inputs, outputs))
}
/// Whether `port` of `node` is a linked order ("other") port or a linked
/// static port, i.e. a linked port that does not carry a value.
fn is_non_value_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
    let op = hugr.get_optype(node);
    let dir = port.direction();
    let is_order_or_static = op.other_port(dir) == Some(port) || op.static_port(dir) == Some(port);
    is_order_or_static && hugr.is_linked(node, port)
}
/// Errors that can occur when building a [`SimpleReplacement`] from a
/// [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum InvalidReplacement {
    /// The entrypoint of the replacement is not a dataflow parent.
    #[error("The root of the replacement {node} is a {}, but only dataflow parents are supported.", op.name())]
    InvalidDataflowGraph {
        /// The entrypoint node of the replacement.
        node: Node,
        /// Its operation.
        op: Box<OpType>,
    },
    /// The signature of the replacement does not match the expected one.
    #[error(
        "Replacement graph type mismatch. Expected {expected}, got {}.",
        actual.clone().map_or("none".to_string(), |t| t.to_string()))
    ]
    InvalidSignature {
        /// The expected signature.
        expected: Box<Signature>,
        /// The replacement's signature, if it has one.
        actual: Option<Box<Signature>>,
    },
    /// The subgraph to replace is not convex.
    #[error("SiblingSubgraph is not convex.")]
    NonConvexSubgraph,
}
/// Errors that can occur when constructing or validating a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraph<N: HugrNode = Node> {
    /// The subgraph is not convex, or an edge crossing it is missing from the boundary.
    #[error("The subgraph is not convex.")]
    NotConvex,
    /// Two boundary nodes have different parents.
    #[error(
        "Not a sibling subgraph. {first_node} has parent {first_parent}, but {other_node} has parent {other_parent}."
    )]
    NoSharedParent {
        /// The first boundary node.
        first_node: N,
        /// Its parent.
        first_parent: N,
        /// A boundary node with a different parent.
        other_node: N,
        /// Its parent.
        other_parent: N,
    },
    /// A node of the subgraph has no parent.
    #[error("Not a sibling subgraph. {orphan} has no parent")]
    OrphanNode {
        /// The node without a parent.
        orphan: N,
    },
    /// The subgraph (or its boundary) is empty.
    #[error("Empty subgraphs are not supported.")]
    EmptySubgraph,
    /// The boundary is invalid.
    #[error("Invalid boundary port.")]
    InvalidBoundary(#[from] InvalidSubgraphBoundary<N>),
    /// The region is not a dataflow region.
    #[error("SiblingSubgraphs may only be defined on dataflow regions.")]
    NonDataflowRegion,
    /// The stored nodes differ from the nodes induced by the boundary.
    #[error("The subgraphs induced by the nodes and the boundary do not match.")]
    InvalidNodeSet,
    /// A port carries an edge kind that is not supported on the boundary.
    #[error("Unsupported edge kind at ({_0}, {_1:?}).")]
    UnsupportedEdgeKind(N, Port),
    /// The convexity checker was built for a different region than the subgraph's.
    #[error(
        "ConvexChecker's region parent {checker_parent} did not match the subgraph parent {subgraph_parent}."
    )]
    #[allow(missing_docs)]
    MismatchedCheckerParent {
        checker_parent: N,
        subgraph_parent: N,
    },
}
/// Errors describing an invalid boundary of a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraphBoundary<N: HugrNode = Node> {
    /// A boundary port belongs to a node outside the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
    PortNodeNotInSet(N, Port),
    /// A boundary port is not linked to any node outside the subgraph.
    #[error(
        "(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph."
    )]
    DisconnectedBoundaryPort(N, Port),
    /// An input port appears more than once in the input boundary.
    #[error("A port in the input boundary is used multiple times.")]
    NonUniqueInput,
    /// An input partition has no ports.
    #[error("A partition in the input boundary is empty.")]
    EmptyPartition,
    /// The ports of an input partition are linked to different outgoing ports.
    ///
    /// Fields: a port of the partition, the expected outgoing port, and the
    /// outgoing port actually found.
    #[error("expected port {0:?} to be linked to {1:?}, but is linked to {2:?}.")]
    MismatchedOutputPort((N, IncomingPort), (N, OutgoingPort), (N, OutgoingPort)),
    /// The input partition at this index has ports of different types, or has
    /// several ports of a non-copyable type.
    #[error("The partition {0} in the input boundary has ports with different types.")]
    MismatchedTypes(usize),
}
/// Errors returned by [`SiblingSubgraph::set_outgoing_ports`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("Invalid output ports: {0:?}")]
pub enum InvalidOutputPorts<N: HugrNode = Node> {
    /// A port is not among the subgraph's current outputs.
    #[error("{port} in {node} was not part of the original boundary.")]
    UnknownOutput {
        /// The unknown port.
        port: OutgoingPort,
        /// Its node.
        node: N,
    },
    /// A linear port appears more than once.
    #[error("Linear ports must appear exactly once.")]
    NonUniqueLinear,
}
fn has_unique_linear_ports<H: HugrView>(host: &H, ports: &OutgoingPorts<H::Node>) -> bool {
let linear_ports: Vec<_> = ports
.iter()
.filter(|&&(n, p)| {
host.get_optype(n)
.port_kind(p)
.is_some_and(|pk| pk.is_linear())
})
.collect();
let unique_ports: HashSet<_> = linear_ports.iter().collect();
linear_ports.len() == unique_ports.len()
}
/// A [`HugrConvexChecker`] built from the [`SchedulingGraph`] of a region.
pub struct SchedGraphChecker<'h, H: HugrView + 'h> {
    /// Maps between `H` nodes and portgraph node indices in the region.
    node_map: H::RegionPortgraphNodes,
    /// The parent node of the region.
    region_parent: H::Node,
    /// Topological convexity checker over the region's flat portgraph,
    /// wrapped in `SynEdgeWrapper` (presumably so that the scheduling graph's
    /// synthetic edges are taken into account — confirm against its definition).
    checker: convex::TopoConvexChecker<
        SynEdgeWrapper<portgraph::view::FlatRegion<'h, H::RegionPortgraph<'h>>>,
    >,
}
impl<'h, H: HugrView> SchedGraphChecker<'h, H> {
    /// Build a checker from the scheduling graph of a region, taking over its
    /// node map and region parent.
    pub fn new(graph: SchedulingGraph<'h, H>) -> Self {
        let SchedulingGraph {
            graph: sched_region,
            node_map,
            region_parent,
        } = graph;
        Self {
            checker: convex::TopoConvexChecker::new(sched_region),
            node_map,
            region_parent,
        }
    }
}
impl<H: HugrView> HugrConvexChecker<H::Node> for SchedGraphChecker<'_, H> {
    fn region_parent(&self) -> H::Node {
        self.region_parent
    }

    /// Compute the subgraph induced by the boundary in the checker's region
    /// view, validate the boundary, and check convexity of the node set.
    ///
    /// Subgraphs with at most one node are accepted once the boundary is
    /// valid.
    fn nodes_if_convex(
        &self,
        hugr: &impl HugrView<Node = H::Node>,
        inputs: &IncomingPorts<H::Node>,
        outputs: &OutgoingPorts<H::Node>,
        function_calls: &IncomingPorts<H::Node>,
    ) -> Result<Vec<H::Node>, InvalidSubgraph<H::Node>> {
        let node_indices = make_pg_subgraph::<H>(
            self.checker.graph().region_view.clone(),
            inputs,
            outputs,
            &self.node_map,
        )
        .node_identifiers()
        .collect_vec();
        let nodes = node_indices
            .iter()
            .map(|&pg_node| self.node_map.from_portgraph(pg_node))
            .collect_vec();
        validate_boundary(hugr, &nodes, inputs, outputs, function_calls)?;
        if nodes.len() <= 1 {
            return Ok(nodes);
        }
        // Reject boundaries where an input port is fed directly by one of the
        // subgraph's own output ports.
        let post_outputs: BTreeSet<_> = outputs
            .iter()
            .flat_map(|(n, p)| hugr.linked_inputs(*n, *p))
            .collect();
        if inputs.iter().flatten().any(|p| post_outputs.contains(p)) {
            return Err(InvalidSubgraph::NotConvex);
        }
        if self.checker.is_node_convex(node_indices) {
            Ok(nodes)
        } else {
            Err(InvalidSubgraph::NotConvex)
        }
    }
}
#[cfg(test)]
mod test_traits_impld {
    // Smoke test: `convex::TopoConvexChecker` can be built over the petgraph
    // view of a scheduling graph and queried. The result of `is_node_convex`
    // is deliberately not asserted.
    use crate::{Hugr, HugrView, builder::test::simple_dfg_hugr};
    use portgraph::NodeIndex;
    use rstest::rstest;
    #[rstest]
    fn test(simple_dfg_hugr: Hugr) {
        let sg = simple_dfg_hugr.scheduling_graph(simple_dfg_hugr.module_root());
        super::convex::TopoConvexChecker::new(sg.petgraph())
            .is_node_convex([NodeIndex::new(0), NodeIndex::new(2)]);
    }
}
#[cfg(test)]
mod tests;