1use std::cell::OnceCell;
8use std::collections::HashSet;
9use std::mem;
10
11use itertools::Itertools;
12use portgraph::LinkView;
13use portgraph::algorithms::ConvexChecker;
14use portgraph::boundary::Boundary;
15use portgraph::{Direction, PortView, view::Subgraph};
16use thiserror::Error;
17
18use crate::builder::{Container, FunctionBuilder};
19use crate::core::HugrNode;
20use crate::hugr::internal::{HugrInternals, PortgraphNodeMap};
21use crate::hugr::{HugrMut, HugrView};
22use crate::ops::dataflow::DataflowOpTrait;
23use crate::ops::handle::{ContainerHandle, DataflowOpID};
24use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
25use crate::types::{Signature, Type};
26use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
27
28use super::root_checked::RootCheckable;
29
/// A set of nodes of a HUGR together with the boundary ports through which
/// the set connects to the rest of the graph.
///
/// The checked constructors (`try_new*`, `try_from_nodes*` and
/// `try_new_dataflow_subgraph`) run `validate_subgraph`, which requires the
/// nodes to share a parent and the boundary to be consistent with the node
/// set.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SiblingSubgraph<N = Node> {
    /// The nodes making up the subgraph.
    nodes: Vec<N>,
    /// The input boundary, as a list of partitions of incoming ports.
    ///
    /// Each partition yields one input of [`SiblingSubgraph::signature`];
    /// validation requires all ports of a partition to be linked to the same
    /// outgoing port.
    inputs: Vec<Vec<(N, IncomingPort)>>,
    /// The output boundary; each entry yields one output of
    /// [`SiblingSubgraph::signature`].
    outputs: Vec<(N, OutgoingPort)>,
}
68
/// The input boundary of a [`SiblingSubgraph`]: a list of partitions, each a
/// list of `(node, incoming port)` pairs.
pub type IncomingPorts<N = Node> = Vec<Vec<(N, IncomingPort)>>;
/// The output boundary of a [`SiblingSubgraph`]: a list of
/// `(node, outgoing port)` pairs.
pub type OutgoingPorts<N = Node> = Vec<(N, OutgoingPort)>;
78
79impl<N: HugrNode> SiblingSubgraph<N> {
80 pub fn try_new_dataflow_subgraph<'h, H, Root>(
92 dfg_graph: impl RootCheckable<&'h H, Root>,
93 ) -> Result<Self, InvalidSubgraph<N>>
94 where
95 H: 'h + Clone + HugrView<Node = N>,
96 Root: ContainerHandle<N, ChildrenHandle = DataflowOpID>,
97 {
98 let Ok(dfg_graph) = dfg_graph.try_into_checked() else {
99 return Err(InvalidSubgraph::NonDataflowRegion);
100 };
101 let dfg_graph = dfg_graph.into_hugr();
102
103 let parent = HugrView::entrypoint(&dfg_graph);
104 let nodes = dfg_graph.children(parent).skip(2).collect_vec();
105 let (inputs, outputs) = get_input_output_ports(dfg_graph);
106
107 validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?;
108
109 if nodes.is_empty() {
110 Err(InvalidSubgraph::EmptySubgraph)
111 } else {
112 Ok(Self {
113 nodes,
114 inputs,
115 outputs,
116 })
117 }
118 }
119
120 pub fn try_new(
160 incoming: IncomingPorts<N>,
161 outgoing: OutgoingPorts<N>,
162 hugr: &impl HugrView<Node = N>,
163 ) -> Result<Self, InvalidSubgraph<N>> {
164 let parent = pick_parent(hugr, &incoming, &outgoing)?;
165 let checker = TopoConvexChecker::new(hugr, parent);
166 Self::try_new_with_checker(incoming, outgoing, hugr, &checker)
167 }
168
169 pub fn try_new_with_checker<H: HugrView<Node = N>>(
178 inputs: IncomingPorts<N>,
179 outputs: OutgoingPorts<N>,
180 hugr: &H,
181 checker: &TopoConvexChecker<H>,
182 ) -> Result<Self, InvalidSubgraph<N>> {
183 let (region, node_map) = checker.region_portgraph();
184
185 let boundary = make_boundary::<H>(®ion, node_map, &inputs, &outputs);
187 let subpg = Subgraph::new_subgraph(region, boundary);
188 let nodes = subpg
189 .nodes_iter()
190 .map(|index| node_map.from_portgraph(index))
191 .collect_vec();
192 validate_subgraph(hugr, &nodes, &inputs, &outputs)?;
193
194 if nodes.len() > 1 && !subpg.is_convex_with_checker(checker) {
195 return Err(InvalidSubgraph::NotConvex);
196 }
197
198 Ok(Self {
199 nodes,
200 inputs,
201 outputs,
202 })
203 }
204
205 pub fn try_from_nodes(
221 nodes: impl Into<Vec<N>>,
222 hugr: &impl HugrView<Node = N>,
223 ) -> Result<Self, InvalidSubgraph<N>> {
224 let nodes = nodes.into();
225 let Some(node) = nodes.first() else {
226 return Err(InvalidSubgraph::EmptySubgraph);
227 };
228 let parent = hugr
229 .get_parent(*node)
230 .ok_or(InvalidSubgraph::OrphanNode { orphan: *node })?;
231
232 let checker = TopoConvexChecker::new(hugr, parent);
233 Self::try_from_nodes_with_checker(nodes, hugr, &checker)
234 }
235
236 pub fn try_from_nodes_with_checker<H: HugrView<Node = N>>(
245 nodes: impl Into<Vec<N>>,
246 hugr: &H,
247 checker: &TopoConvexChecker<H>,
248 ) -> Result<Self, InvalidSubgraph<N>> {
249 let nodes = nodes.into();
250
251 if nodes.is_empty() {
252 return Err(InvalidSubgraph::EmptySubgraph);
253 }
254
255 let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
256 let incoming_edges = nodes
257 .iter()
258 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
259 let outgoing_edges = nodes
260 .iter()
261 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
262 let inputs = incoming_edges
263 .filter(|&(n, p)| {
264 if !hugr.is_linked(n, p) {
265 return false;
266 }
267 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
268 !nodes_set.contains(&out_n)
269 })
270 .map(|p| vec![p])
272 .collect_vec();
273 let outputs = outgoing_edges
274 .filter(|&(n, p)| {
275 hugr.linked_ports(n, p)
276 .any(|(n1, _)| !nodes_set.contains(&n1))
277 })
278 .collect_vec();
279 Self::try_new_with_checker(inputs, outputs, hugr, checker)
280 }
281
282 pub fn from_node(node: N, hugr: &impl HugrView<Node = N>) -> Self {
290 let nodes = vec![node];
291 let inputs = hugr
292 .node_inputs(node)
293 .filter(|&p| hugr.is_linked(node, p))
294 .map(|p| vec![(node, p)])
295 .collect_vec();
296 let outputs = hugr
297 .node_outputs(node)
298 .filter_map(|p| {
299 {
301 hugr.is_linked(node, p)
302 || HugrView::get_optype(&hugr, node)
303 .port_kind(p)
304 .is_some_and(|k| k.is_value())
305 }
306 .then_some((node, p))
307 })
308 .collect_vec();
309
310 let state_order_at_input = hugr
311 .get_optype(node)
312 .other_output_port()
313 .is_some_and(|p| hugr.is_linked(node, p));
314 let state_order_at_output = hugr
315 .get_optype(node)
316 .other_input_port()
317 .is_some_and(|p| hugr.is_linked(node, p));
318 if state_order_at_input || state_order_at_output {
319 unimplemented!("Order edges in {node:?} not supported");
320 }
321
322 Self {
323 nodes,
324 inputs,
325 outputs,
326 }
327 }
328
329 #[must_use]
331 pub fn nodes(&self) -> &[N] {
332 &self.nodes
333 }
334
335 #[must_use]
337 pub fn node_count(&self) -> usize {
338 self.nodes.len()
339 }
340
341 #[must_use]
343 pub fn incoming_ports(&self) -> &IncomingPorts<N> {
344 &self.inputs
345 }
346
347 #[must_use]
349 pub fn outgoing_ports(&self) -> &OutgoingPorts<N> {
350 &self.outputs
351 }
352
353 pub fn signature(&self, hugr: &impl HugrView<Node = N>) -> Signature {
355 let input = self
356 .inputs
357 .iter()
358 .map(|part| {
359 let &(n, p) = part.iter().next().expect("is non-empty");
360 let sig = hugr.signature(n).expect("must have dataflow signature");
361 sig.port_type(p).cloned().expect("must be dataflow edge")
362 })
363 .collect_vec();
364 let output = self
365 .outputs
366 .iter()
367 .map(|&(n, p)| {
368 let sig = hugr.signature(n).expect("must have dataflow signature");
369 sig.port_type(p).cloned().expect("must be dataflow edge")
370 })
371 .collect_vec();
372 Signature::new(input, output)
373 }
374
375 pub fn get_parent(&self, hugr: &impl HugrView<Node = N>) -> N {
377 hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
378 }
379
380 pub(crate) fn map_nodes<N2: HugrNode>(
385 &self,
386 node_map: impl Fn(N) -> N2,
387 ) -> SiblingSubgraph<N2> {
388 let nodes = self.nodes.iter().map(|&n| node_map(n)).collect_vec();
389 let inputs = self
390 .inputs
391 .iter()
392 .map(|part| part.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
393 .collect_vec();
394 let outputs = self
395 .outputs
396 .iter()
397 .map(|&(n, p)| (node_map(n), p))
398 .collect_vec();
399 SiblingSubgraph {
400 nodes,
401 inputs,
402 outputs,
403 }
404 }
405
406 pub fn create_simple_replacement(
422 &self,
423 hugr: &impl HugrView<Node = N>,
424 replacement: Hugr,
425 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
426 let rep_root = replacement.entrypoint();
427 let dfg_optype = replacement.get_optype(rep_root);
428 if !OpTag::DataflowParent.is_superset(dfg_optype.tag()) {
429 return Err(InvalidReplacement::InvalidDataflowGraph {
430 node: rep_root,
431 op: Box::new(dfg_optype.clone()),
432 });
433 }
434 let [rep_input, rep_output] = replacement
435 .get_io(rep_root)
436 .expect("DFG root in the replacement does not have input and output nodes.");
437
438 let state_order_at_input = replacement
441 .get_optype(rep_input)
442 .other_output_port()
443 .is_some_and(|p| replacement.is_linked(rep_input, p));
444 let state_order_at_output = replacement
445 .get_optype(rep_output)
446 .other_input_port()
447 .is_some_and(|p| replacement.is_linked(rep_output, p));
448 if state_order_at_input || state_order_at_output {
449 unimplemented!("Found state order edges in replacement graph");
450 }
451
452 SimpleReplacement::try_new(self.clone(), hugr, replacement)
453 }
454
455 pub fn extract_subgraph(
460 &self,
461 hugr: &impl HugrView<Node = N>,
462 name: impl Into<String>,
463 ) -> Hugr {
464 let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
465 let mut extracted = mem::take(builder.hugr_mut());
468 let node_map = extracted.insert_subgraph(extracted.entrypoint(), hugr, self);
469
470 let [inp, out] = extracted.get_io(extracted.entrypoint()).unwrap();
472 let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
473 let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
474 let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
475
476 for (inp_port, repl_ports) in inputs {
477 for (repl_node, repl_port) in repl_ports {
478 connections.push((inp, inp_port, node_map[repl_node], *repl_port));
479 }
480 }
481 for (out_port, (repl_node, repl_port)) in outputs {
482 connections.push((node_map[repl_node], *repl_port, out, out_port));
483 }
484
485 for (src, src_port, dst, dst_port) in connections {
486 extracted.connect(src, src_port, dst, dst_port);
487 }
488
489 extracted
490 }
491
492 pub fn set_outgoing_ports(
501 &mut self,
502 ports: OutgoingPorts<N>,
503 host: &impl HugrView<Node = N>,
504 ) -> Result<(), InvalidOutputPorts<N>> {
505 let old_boundary: HashSet<_> = iter_outgoing(&self.outputs).collect();
506
507 if let Some((node, port)) =
509 iter_outgoing(&ports).find(|(n, p)| !old_boundary.contains(&(*n, *p)))
510 {
511 return Err(InvalidOutputPorts::UnknownOutput { port, node });
512 }
513
514 if !has_unique_linear_ports(host, &ports) {
516 return Err(InvalidOutputPorts::NonUniqueLinear);
517 }
518
519 self.outputs = ports;
520 Ok(())
521 }
522}
523
524fn iter_incoming<N: HugrNode>(
526 inputs: &IncomingPorts<N>,
527) -> impl Iterator<Item = (N, IncomingPort)> + '_ {
528 inputs.iter().flat_map(|part| part.iter().copied())
529}
530
531fn iter_outgoing<N: HugrNode>(
533 outputs: &OutgoingPorts<N>,
534) -> impl Iterator<Item = (N, OutgoingPort)> + '_ {
535 outputs.iter().copied()
536}
537
538fn iter_io<'a, N: HugrNode>(
540 inputs: &'a IncomingPorts<N>,
541 outputs: &'a OutgoingPorts<N>,
542) -> impl Iterator<Item = (N, Port)> + 'a {
543 iter_incoming(inputs)
544 .map(|(n, p)| (n, Port::from(p)))
545 .chain(iter_outgoing(outputs).map(|(n, p)| (n, Port::from(p))))
546}
547
548fn pick_parent<'a, N: HugrNode>(
558 hugr: &impl HugrView<Node = N>,
559 inputs: &'a IncomingPorts<N>,
560 outputs: &'a OutgoingPorts<N>,
561) -> Result<N, InvalidSubgraph<N>> {
562 let Some(node) = iter_incoming(inputs)
564 .map(|(n, _)| n)
565 .chain(iter_outgoing(outputs).map(|(n, _)| n))
566 .next()
567 else {
568 return Err(InvalidSubgraph::EmptySubgraph);
569 };
570 let Some(parent) = hugr.get_parent(node) else {
571 return Err(InvalidSubgraph::OrphanNode { orphan: node });
572 };
573
574 Ok(parent)
575}
576
577fn make_boundary<'a, H: HugrView>(
578 region: &impl LinkView,
579 node_map: &H::RegionPortgraphNodes,
580 inputs: &'a IncomingPorts<H::Node>,
581 outputs: &'a OutgoingPorts<H::Node>,
582) -> Boundary {
583 let to_pg_index = |n: H::Node, p: Port| {
584 region
585 .port_index(node_map.to_portgraph(n), p.pg_offset())
586 .unwrap()
587 };
588 Boundary::new(
589 iter_incoming(inputs).map(|(n, p)| to_pg_index(n, p.into())),
590 iter_outgoing(outputs).map(|(n, p)| to_pg_index(n, p.into())),
591 )
592}
593
/// The portgraph view on which the convexity checker operates: a flat
/// region over the region portgraph of the base HUGR.
type CheckerRegion<'g, Base> =
    portgraph::view::FlatRegion<'g, <Base as HugrInternals>::RegionPortgraph<'g>>;
596
/// A convexity checker for one region of a HUGR.
///
/// Wraps portgraph's `TopoConvexChecker`, which is only constructed the
/// first time it is needed and is then cached in a `OnceCell`.
pub struct TopoConvexChecker<'g, Base: 'g + HugrView> {
    /// The HUGR the checked region belongs to.
    base: &'g Base,
    /// The node whose region portgraph is checked.
    region_parent: Base::Node,
    /// The lazily initialised portgraph checker, stored together with the
    /// node map returned alongside the region portgraph.
    checker: OnceCell<(
        portgraph::algorithms::TopoConvexChecker<CheckerRegion<'g, Base>>,
        Base::RegionPortgraphNodes,
    )>,
}
612
613impl<'g, Base: HugrView> TopoConvexChecker<'g, Base> {
614 pub fn new(base: &'g Base, region_parent: Base::Node) -> Self {
616 Self {
617 base,
618 region_parent,
619 checker: OnceCell::new(),
620 }
621 }
622
623 fn init_checker(
625 &self,
626 ) -> &(
627 portgraph::algorithms::TopoConvexChecker<CheckerRegion<'g, Base>>,
628 Base::RegionPortgraphNodes,
629 ) {
630 self.checker.get_or_init(|| {
631 let (region, node_map) = self.base.region_portgraph(self.region_parent);
632 let checker = portgraph::algorithms::TopoConvexChecker::new(region);
633 (checker, node_map)
634 })
635 }
636
637 fn get_checker(
639 &self,
640 ) -> &portgraph::algorithms::TopoConvexChecker<
641 portgraph::view::FlatRegion<'g, Base::RegionPortgraph<'g>>,
642 > {
643 &self.init_checker().0
644 }
645
646 fn region_portgraph(&self) -> (CheckerRegion<'g, Base>, &Base::RegionPortgraphNodes) {
648 let (checker, node_map) = self.init_checker();
649 (checker.graph(), node_map)
650 }
651
652 #[expect(dead_code)]
654 fn get_node_map(&self) -> &Base::RegionPortgraphNodes {
655 &self.init_checker().1
656 }
657}
658
659impl<Base: HugrView> ConvexChecker for TopoConvexChecker<'_, Base> {
660 fn is_convex(
661 &self,
662 nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
663 inputs: impl IntoIterator<Item = portgraph::PortIndex>,
664 outputs: impl IntoIterator<Item = portgraph::PortIndex>,
665 ) -> bool {
666 let mut nodes = nodes.into_iter().multipeek();
667 if nodes.peek().is_none() || nodes.peek().is_none() {
670 return true;
671 }
672 self.get_checker().is_convex(nodes, inputs, outputs)
673 }
674}
675
676fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(
680 hugr: &H,
681 ports: &[(H::Node, P)],
682) -> Option<Type> {
683 let &(n, p) = ports.first()?;
684 let edge_t = hugr.signature(n)?.port_type(p)?.clone();
685 ports
686 .iter()
687 .all(|&(n, p)| {
688 hugr.signature(n)
689 .is_some_and(|s| s.port_type(p) == Some(&edge_t))
690 })
691 .then_some(edge_t)
692}
693
694fn validate_subgraph<H: HugrView>(
701 hugr: &H,
702 nodes: &[H::Node],
703 inputs: &IncomingPorts<H::Node>,
704 outputs: &OutgoingPorts<H::Node>,
705) -> Result<(), InvalidSubgraph<H::Node>> {
706 let node_set = nodes.iter().copied().collect::<HashSet<_>>();
708
709 if nodes.is_empty() {
711 return Err(InvalidSubgraph::EmptySubgraph);
712 }
713 if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() {
715 let first_node = nodes[0];
716 let first_parent = hugr
717 .get_parent(first_node)
718 .ok_or(InvalidSubgraph::OrphanNode { orphan: first_node })?;
719 let other_node = *nodes
720 .iter()
721 .skip(1)
722 .find(|&&n| hugr.get_parent(n) != Some(first_parent))
723 .unwrap();
724 let other_parent = hugr
725 .get_parent(other_node)
726 .ok_or(InvalidSubgraph::OrphanNode { orphan: other_node })?;
727 return Err(InvalidSubgraph::NoSharedParent {
728 first_node,
729 first_parent,
730 other_node,
731 other_parent,
732 });
733 }
734
735 if iter_io(inputs, outputs).any(|(n, p)| is_order_edge(hugr, n, p)) {
737 unimplemented!("Connected order edges not supported at the boundary")
738 }
739
740 let boundary_ports = iter_io(inputs, outputs).collect_vec();
741 if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
743 Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
744 }
745 if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
747 hugr.linked_ports(n, p)
748 .all(|(n1, _)| node_set.contains(&n1))
749 }) {
750 Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
751 }
752
753 if nodes.iter().any(|&n| {
756 hugr.node_inputs(n).any(|p| {
757 hugr.linked_ports(n, p).any(|(n1, _)| {
758 !node_set.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
759 })
760 })
761 }) {
762 return Err(InvalidSubgraph::NotConvex);
763 }
764 if nodes.iter().any(|&n| {
767 hugr.node_outputs(n).any(|p| {
768 hugr.linked_ports(n, p)
769 .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
770 })
771 }) {
772 return Err(InvalidSubgraph::NotConvex);
773 }
774
775 if !inputs.iter().flatten().all_unique() {
777 return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
778 }
779
780 for inp in inputs {
784 let &(in_node, in_port) = inp.first().ok_or(InvalidSubgraphBoundary::EmptyPartition)?;
785 let exp_output_node_port = hugr
786 .single_linked_output(in_node, in_port)
787 .expect("valid dfg wire");
788 if let Some(output_node_port) = inp
789 .iter()
790 .map(|&(in_node, in_port)| {
791 hugr.single_linked_output(in_node, in_port)
792 .expect("valid dfg wire")
793 })
794 .find(|&p| p != exp_output_node_port)
795 {
796 return Err(InvalidSubgraphBoundary::MismatchedOutputPort(
797 (in_node, in_port),
798 exp_output_node_port,
799 output_node_port,
800 )
801 .into());
802 }
803 }
804
805 if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
807 let Some(edge_t) = get_edge_type(hugr, ports) else {
808 return true;
809 };
810 let require_copy = ports.len() > 1;
811 require_copy && !edge_t.copyable()
812 }) {
813 Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
814 }
815
816 Ok(())
817}
818
819fn get_input_output_ports<H: HugrView>(
820 hugr: &H,
821) -> (IncomingPorts<H::Node>, OutgoingPorts<H::Node>) {
822 let [inp, out] = hugr
823 .get_io(HugrView::entrypoint(&hugr))
824 .expect("invalid DFG");
825 if has_other_edge(hugr, inp, Direction::Outgoing) {
826 unimplemented!("Non-dataflow output not supported at input node")
827 }
828 let dfg_inputs = HugrView::get_optype(&hugr, inp)
829 .as_input()
830 .unwrap()
831 .signature()
832 .output_ports();
833 if has_other_edge(hugr, out, Direction::Incoming) {
834 unimplemented!("Non-dataflow input not supported at output node")
835 }
836 let dfg_outputs = HugrView::get_optype(&hugr, out)
837 .as_output()
838 .unwrap()
839 .signature()
840 .input_ports();
841
842 let inputs = dfg_inputs
845 .into_iter()
846 .map(|p| {
847 hugr.linked_inputs(inp, p)
848 .filter(|&(n, _)| n != out)
849 .collect_vec()
850 })
851 .filter(|v| !v.is_empty())
852 .collect();
853 let outputs = dfg_outputs
856 .into_iter()
857 .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
858 .collect();
859 (inputs, outputs)
860}
861
862fn is_order_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
864 let op = hugr.get_optype(node);
865 op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port)
866}
867
868fn has_other_edge<H: HugrView>(hugr: &H, node: H::Node, dir: Direction) -> bool {
870 let op = hugr.get_optype(node);
871 op.other_port_kind(dir).is_some() && hugr.is_linked(node, op.other_port(dir).unwrap())
872}
873
/// Errors that can occur while building a replacement for a subgraph.
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum InvalidReplacement {
    /// The entrypoint of the replacement is not a dataflow parent.
    #[error("The root of the replacement {node} is a {}, but only dataflow parents are supported.", op.name())]
    InvalidDataflowGraph {
        /// The entrypoint node of the replacement.
        node: Node,
        /// The operation found at that node.
        op: Box<OpType>,
    },
    /// The signature of the replacement does not match the expected one.
    #[error(
        "Replacement graph type mismatch. Expected {expected}, got {}.",
        actual.clone().map_or("none".to_string(), |t| t.to_string()))
    ]
    InvalidSignature {
        /// The signature that was expected.
        expected: Box<Signature>,
        /// The signature that was found, if there was one.
        actual: Option<Box<Signature>>,
    },
    /// The subgraph to be replaced is not convex.
    #[error("SiblingSubgraph is not convex.")]
    NonConvexSubgraph,
}
901
/// Errors that can occur while constructing a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraph<N: HugrNode = Node> {
    /// The subgraph is not convex, or an edge leaves the node set without
    /// going through the boundary.
    #[error("The subgraph is not convex.")]
    NotConvex,
    /// Two nodes of the subgraph have different parents.
    #[error(
        "Not a sibling subgraph. {first_node} has parent {first_parent}, but {other_node} has parent {other_parent}."
    )]
    NoSharedParent {
        /// The first node of the subgraph.
        first_node: N,
        /// The parent of `first_node`.
        first_parent: N,
        /// A node whose parent differs from `first_parent`.
        other_node: N,
        /// The parent of `other_node`.
        other_parent: N,
    },
    /// A node of the subgraph has no parent.
    #[error("Not a sibling subgraph. {orphan} has no parent")]
    OrphanNode {
        /// The node without a parent.
        orphan: N,
    },
    /// The subgraph has no nodes, or no boundary from which to find them.
    #[error("Empty subgraphs are not supported.")]
    EmptySubgraph,
    /// The boundary is inconsistent with the node set.
    #[error("Invalid boundary port.")]
    InvalidBoundary(#[from] InvalidSubgraphBoundary<N>),
    /// The given graph failed the dataflow-region root check.
    #[error("SiblingSubgraphs may only be defined on dataflow regions.")]
    NonDataflowRegion,
}
939
/// Errors describing an invalid boundary of a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraphBoundary<N: HugrNode = Node> {
    /// A boundary port belongs to a node that is not in the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
    PortNodeNotInSet(N, Port),
    /// A boundary port is not linked to any node outside the subgraph.
    #[error(
        "(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph."
    )]
    DisconnectedBoundaryPort(N, Port),
    /// The same incoming port is listed more than once in the input boundary.
    #[error("A port in the input boundary is used multiple times.")]
    NonUniqueInput,
    /// A partition of the input boundary contains no ports.
    #[error("A partition in the input boundary is empty.")]
    EmptyPartition,
    /// The ports of one input partition are linked to different outgoing
    /// ports. Carries the first port of the partition, the outgoing port it
    /// is linked to, and the differing outgoing port that was found.
    #[error("expected port {0:?} to be linked to {1:?}, but is linked to {2:?}.")]
    MismatchedOutputPort((N, IncomingPort), (N, OutgoingPort), (N, OutgoingPort)),
    /// The partition at the given index has ports without a common type, or
    /// has several ports of a type that is not copyable.
    #[error("The partition {0} in the input boundary has ports with different types.")]
    MismatchedTypes(usize),
}
966
/// Errors returned by [`SiblingSubgraph::set_outgoing_ports`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[error("Invalid output ports: {0:?}")]
pub enum InvalidOutputPorts<N: HugrNode = Node> {
    /// A requested port is not part of the current output boundary.
    #[error("{port} in {node} was not part of the original boundary.")]
    UnknownOutput {
        /// The unknown outgoing port.
        port: OutgoingPort,
        /// The node the port belongs to.
        node: N,
    },
    /// A port of linear kind was listed more than once.
    #[error("Linear ports must appear exactly once.")]
    NonUniqueLinear,
}
983
984fn has_unique_linear_ports<H: HugrView>(host: &H, ports: &OutgoingPorts<H::Node>) -> bool {
986 let linear_ports: Vec<_> = ports
987 .iter()
988 .filter(|&&(n, p)| {
989 host.get_optype(n)
990 .port_kind(p)
991 .is_some_and(|pk| pk.is_linear())
992 })
993 .collect();
994 let unique_ports: HashSet<_> = linear_ports.iter().collect();
995 linear_ports.len() == unique_ports.len()
996}
997
998#[cfg(test)]
999mod tests {
1000 use cool_asserts::assert_matches;
1001
1002 use crate::builder::inout_sig;
1003 use crate::hugr::Patch;
1004 use crate::ops::Const;
1005 use crate::std_extensions::arithmetic::float_types::ConstF64;
1006 use crate::std_extensions::logic::LogicOp;
1007 use crate::type_row;
1008 use crate::utils::test_quantum_extension::{cx_gate, rz_f64};
1009 use crate::{
1010 builder::{
1011 BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
1012 ModuleBuilder,
1013 },
1014 extension::prelude::{bool_t, qb_t},
1015 ops::handle::{DfgID, FuncID, NodeHandle},
1016 std_extensions::logic::test::and_op,
1017 };
1018
1019 use super::*;
1020
1021 impl<N: HugrNode> SiblingSubgraph<N> {
1022 fn from_sibling_graph(
1027 hugr: &impl HugrView<Node = N>,
1028 parent: N,
1029 ) -> Result<Self, InvalidSubgraph<N>> {
1030 let nodes = hugr.children(parent).collect_vec();
1031 if nodes.is_empty() {
1032 Err(InvalidSubgraph::EmptySubgraph)
1033 } else {
1034 Ok(Self {
1035 nodes,
1036 inputs: Vec::new(),
1037 outputs: Vec::new(),
1038 })
1039 }
1040 }
1041 }
1042
1043 fn build_hugr() -> Result<(Hugr, Node), BuildError> {
1046 let mut mod_builder = ModuleBuilder::new();
1047 let func = mod_builder.declare(
1048 "test",
1049 Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).into(),
1050 )?;
1051 let func_id = {
1052 let mut dfg = mod_builder.define_declaration(&func)?;
1053 let [w0, w1, w2] = dfg.input_wires_arr();
1054 let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
1055 let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
1056 let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
1057 dfg.finish_with_outputs([w0, w1, w2])?
1058 };
1059 let hugr = mod_builder
1060 .finish_hugr()
1061 .map_err(|e| -> BuildError { e.into() })?;
1062 Ok((hugr, func_id.node()))
1063 }
1064
1065 fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
1067 let mut mod_builder = ModuleBuilder::new();
1068 let func = mod_builder.declare("test", Signature::new_endo(vec![bool_t()]).into())?;
1069 let func_id = {
1070 let mut dfg = mod_builder.define_declaration(&func)?;
1071 let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?;
1072 let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?;
1073 let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?;
1074 dfg.finish_with_outputs(outs3.outputs())?
1075 };
1076 let hugr = mod_builder
1077 .finish_hugr()
1078 .map_err(|e| -> BuildError { e.into() })?;
1079 Ok((hugr, func_id.node()))
1080 }
1081
1082 fn build_multiport_hugr() -> Result<(Hugr, Node), BuildError> {
1084 let mut mod_builder = ModuleBuilder::new();
1085 let func = mod_builder.declare(
1086 "test",
1087 Signature::new(bool_t(), vec![bool_t(), bool_t()]).into(),
1088 )?;
1089 let func_id = {
1090 let mut dfg = mod_builder.define_declaration(&func)?;
1091 let [b0] = dfg.input_wires_arr();
1092 let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr();
1093 let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr();
1094 dfg.finish_with_outputs([b1, b2])?
1095 };
1096 let hugr = mod_builder
1097 .finish_hugr()
1098 .map_err(|e| -> BuildError { e.into() })?;
1099 Ok((hugr, func_id.node()))
1100 }
1101
1102 fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
1104 let mut mod_builder = ModuleBuilder::new();
1105 let func = mod_builder.declare("test", Signature::new_endo(bool_t()).into())?;
1106 let func_id = {
1107 let mut dfg = mod_builder.define_declaration(&func)?;
1108 let in_wire = dfg.input_wires().exactly_one().unwrap();
1109 let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?;
1110 dfg.finish_with_outputs(outs.outputs())?
1111 };
1112 let hugr = mod_builder
1113 .finish_hugr()
1114 .map_err(|e| -> BuildError { e.into() })?;
1115 Ok((hugr, func_id.node()))
1116 }
1117
/// Replaces the whole body of the test function with an identity DFG and
/// checks the node counts before and after applying the patch.
#[test]
fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
    let (mut hugr, func_root) = build_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);
    let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;

    // A DFG that wires its three qubit inputs straight to its outputs.
    let identity_dfg = {
        let builder =
            DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap();
        let wires = builder.input_wires();
        builder.finish_hugr_with_outputs(wires).unwrap()
    };

    let rep = sub.create_simple_replacement(&func, identity_dfg).unwrap();

    assert_eq!(rep.subgraph().nodes().len(), 4);

    assert_eq!(hugr.num_nodes(), 8);
    hugr.apply_patch(rep).unwrap();
    assert_eq!(hugr.num_nodes(), 4);
    Ok(())
}
1141
/// The subgraph spanning the whole test function has the function's
/// three-qubit signature.
#[test]
fn test_signature() -> Result<(), InvalidSubgraph> {
    let (hugr, func_root) = build_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);
    let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;

    let expected = Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]);
    assert_eq!(sub.signature(&func), expected);
    Ok(())
}
1153
/// A replacement whose signature does not match the subgraph's is rejected
/// with `InvalidSignature`.
#[test]
fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> {
    let (hugr, func_root) = build_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);
    let sub = SiblingSubgraph::from_sibling_graph(&hugr, func_root)?;

    // A single-qubit identity DFG.
    let one_qubit_dfg = {
        let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
        let wires = builder.input_wires();
        builder.finish_hugr_with_outputs(wires).unwrap()
    };

    let err = sub
        .create_simple_replacement(&func, one_qubit_dfg)
        .unwrap_err();
    assert_matches!(err, InvalidReplacement::InvalidSignature { .. });
    Ok(())
}
1172
/// The subgraph spanning the whole test function contains four nodes.
#[test]
fn convex_subgraph() {
    let (hugr, func_root) = build_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);

    let subgraph = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func).unwrap();
    assert_eq!(subgraph.nodes().len(), 4);
}
1185
/// A subgraph bounded by the first two wires of the input node and the
/// first two wires of the output node can be constructed.
#[test]
fn convex_subgraph_2() {
    let (hugr, func_root) = build_hugr().unwrap();
    let [inp, out] = hugr.get_io(func_root).unwrap();
    let func = hugr.with_entrypoint(func_root);

    let incoming: IncomingPorts = hugr
        .node_outputs(inp)
        .take(2)
        .map(|port| hugr.linked_inputs(inp, port).collect_vec())
        .filter(|targets| !targets.is_empty())
        .collect();
    let outgoing: OutgoingPorts = hugr
        .node_inputs(out)
        .take(2)
        .filter_map(|port| hugr.single_linked_output(out, port))
        .collect();

    SiblingSubgraph::try_new(incoming, outgoing, &func).unwrap();
}
1206
/// A boundary whose input and output are the two ends of the same wire is
/// rejected as a disconnected boundary port.
#[test]
fn degen_boundary() {
    let (hugr, func_root) = build_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);
    let [inp, _] = hugr.get_io(func_root).unwrap();
    let first_cx_edge = hugr.node_outputs(inp).next().unwrap();

    let wire_targets: Vec<(Node, IncomingPort)> = hugr
        .linked_ports(inp, first_cx_edge)
        .map(|(node, port)| (node, port.as_incoming().unwrap()))
        .collect();
    let result = SiblingSubgraph::try_new(vec![wire_targets], vec![(inp, first_cx_edge)], &func);

    assert_matches!(
        result,
        Err(InvalidSubgraph::InvalidBoundary(
            InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
        ))
    );
}
1229
/// Selecting the first and third NOT of a chain of three, without the one
/// in between, is rejected as non-convex.
#[test]
fn non_convex_subgraph() {
    let (hugr, func_root) = build_3not_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);
    let [inp, _out] = hugr.get_io(func_root).unwrap();

    // Follows the unique outgoing neighbour of a node.
    let sole_successor = |node: Node| hugr.output_neighbours(node).exactly_one().ok().unwrap();
    let not1 = sole_successor(inp);
    let not2 = sole_successor(not1);
    let not3 = sole_successor(not2);

    let not1_inp = hugr.node_inputs(not1).next().unwrap();
    let not1_out = hugr.node_outputs(not1).next().unwrap();
    let not3_inp = hugr.node_inputs(not3).next().unwrap();
    let not3_out = hugr.node_outputs(not3).next().unwrap();

    let result = SiblingSubgraph::try_new(
        vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
        vec![(not1, not1_out), (not3, not3_out)],
        &func,
    );
    assert_matches!(result, Err(InvalidSubgraph::NotConvex));
}
1251
/// Two chained NOTs form a valid subgraph even though the first one's
/// output is also used outside the subgraph.
#[test]
fn convex_multiports() {
    let (hugr, func_root) = build_multiport_hugr().unwrap();
    let [inp, out] = hugr.get_io(func_root).unwrap();

    let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
    let not2 = hugr
        .output_neighbours(not1)
        .filter(|&node| node != out)
        .exactly_one()
        .ok()
        .unwrap();

    let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
    assert_eq!(subgraph.nodes(), [not1, not2]);
}
1269
/// Using the output node's inputs as the input boundary and the input
/// node's outputs as the output boundary is rejected as a disconnected
/// boundary port.
#[test]
fn invalid_boundary() {
    let (hugr, func_root) = build_hugr().unwrap();
    let func = hugr.with_entrypoint(func_root);
    let [inp, out] = hugr.get_io(func_root).unwrap();

    let incoming: IncomingPorts = hugr
        .node_inputs(out)
        .map(|port| vec![(out, port)])
        .collect();
    let outgoing: OutgoingPorts = hugr.node_outputs(inp).map(|port| (inp, port)).collect();

    assert_matches!(
        SiblingSubgraph::try_new(incoming, outgoing, &func),
        Err(InvalidSubgraph::InvalidBoundary(
            InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
        ))
    );
}
1289
/// The subgraph spanning a function whose input is copied to two ports has
/// the same signature as the function definition.
#[test]
fn preserve_signature() {
    let (hugr, func_root) = build_hugr_classical().unwrap();
    let func_graph = hugr.with_entrypoint(func_root);
    let subgraph =
        SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();

    let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
    assert_eq!(
        func_defn.signature(),
        &subgraph.signature(&func_graph).into()
    );
}
1299
/// Extracting the whole function body yields a HUGR that validates.
#[test]
fn extract_subgraph() {
    let (hugr, func_root) = build_hugr().unwrap();
    let func_graph = hugr.with_entrypoint(func_root);
    let subgraph =
        SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();

    let extracted = subgraph.extract_subgraph(&hugr, "region");
    extracted.validate().unwrap();
}
1310
/// A wire that is both a region output and an input to another node of
/// the region does not prevent building the subgraph.
#[test]
fn edge_both_output_and_copy() {
    let signature = inout_sig(vec![bool_t()], vec![bool_t(), bool_t()]);
    let mut builder = DFGBuilder::new(signature).unwrap();

    let input = builder.input_wires().exactly_one().unwrap();
    let negated = builder
        .add_dataflow_op(LogicOp::Not, [input])
        .unwrap()
        .out_wire(0);
    let and_outputs = builder
        .add_dataflow_op(and_op(), [input, negated])
        .unwrap()
        .outputs();

    // `negated` feeds both the AND and the first region output.
    let region_outputs = [negated].into_iter().chain(and_outputs);
    let h = builder.finish_hugr_with_outputs(region_outputs).unwrap();

    let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&h).unwrap();
    assert_eq!(subg.nodes().len(), 2);
}
1332
/// A node whose output is not linked can still be replaced through a
/// single-node subgraph.
#[test]
fn test_unconnected() {
    // A DFG with one boolean input and no outputs; the NOT's result is unused.
    let mut host_builder = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap();
    let host_input = host_builder.input_wires().exactly_one().unwrap();
    let host_not = host_builder
        .add_dataflow_op(LogicOp::Not, [host_input])
        .unwrap();
    let mut h = host_builder.finish_hugr_with_outputs([]).unwrap();

    let subg = SiblingSubgraph::from_node(host_not.node(), &h);
    assert_eq!(subg.nodes().len(), 1);

    // The replacement is a single NOT from one boolean to one boolean.
    let replacement = {
        let mut rep_builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap();
        let rep_input = rep_builder.input_wires().exactly_one().unwrap();
        let rep_not = rep_builder
            .add_dataflow_op(LogicOp::Not, [rep_input])
            .unwrap();
        rep_builder
            .finish_hugr_with_outputs(rep_not.outputs())
            .unwrap()
    };

    let rep = subg.create_simple_replacement(&h, replacement).unwrap();
    rep.apply(&mut h).unwrap();
}
1357
/// `from_node` keeps an unlinked value output in the boundary, whereas
/// `try_from_nodes` only keeps ports linked to nodes outside the set.
#[test]
fn single_node_subgraph() {
    // A DFG with one boolean input and no outputs; the NOT's result is unused.
    let mut builder = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap();
    let input = builder.input_wires().exactly_one().unwrap();
    let not_op = builder.add_dataflow_op(LogicOp::Not, [input]).unwrap();
    let h = builder.finish_hugr_with_outputs([]).unwrap();

    let from_node = SiblingSubgraph::from_node(not_op.node(), &h);
    assert_eq!(from_node.nodes().len(), 1);
    assert_eq!(
        from_node.signature(&h).io(),
        Signature::new(vec![bool_t()], vec![bool_t()]).io()
    );

    let from_nodes = SiblingSubgraph::try_from_nodes([not_op.node()], &h).unwrap();
    assert_eq!(from_nodes.nodes().len(), 1);
    assert_eq!(
        from_nodes.signature(&h).io(),
        Signature::new(vec![bool_t()], vec![]).io()
    );
}
1386
/// A boolean output may be repeated in the output boundary, while a port
/// outside the original boundary is rejected and leaves it unchanged.
#[test]
fn test_set_outgoing_ports() {
    let (hugr, func_root) = build_3not_hugr().unwrap();
    let [inp, out] = hugr.get_io(func_root).unwrap();
    let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
    let not1_out = hugr.node_outputs(not1).next().unwrap();

    let mut subgraph = SiblingSubgraph::from_node(not1, &hugr);
    assert_eq!(subgraph.outgoing_ports().len(), 1);

    // Repeating the same output port is accepted.
    let duplicated = vec![(not1, not1_out), (not1, not1_out)];
    assert!(subgraph.set_outgoing_ports(duplicated, &hugr).is_ok());
    assert_eq!(subgraph.outgoing_ports().len(), 2);

    // A port that was never part of the boundary is rejected.
    let with_unknown = vec![(not1, not1_out), (out, 2.into())];
    assert!(matches!(
        subgraph.set_outgoing_ports(with_unknown, &hugr),
        Err(InvalidOutputPorts::UnknownOutput { .. })
    ));

    // The failed call left the boundary untouched.
    assert_eq!(subgraph.outgoing_ports().len(), 2);
}
1417
/// Repeating the Rz node's first output port in the output boundary is
/// rejected with `NonUniqueLinear` and leaves the boundary unchanged.
#[test]
fn test_set_outgoing_ports_linear() {
    let (hugr, func_root) = build_hugr().unwrap();
    let [inp, _out] = hugr.get_io(func_root).unwrap();
    let rz = hugr.output_neighbours(inp).nth(2).unwrap();
    let rz_out = hugr.node_outputs(rz).next().unwrap();

    let mut subgraph = SiblingSubgraph::from_node(rz, &hugr);
    assert_eq!(subgraph.outgoing_ports().len(), 1);

    let duplicated = vec![(rz, rz_out), (rz, rz_out)];
    assert!(matches!(
        subgraph.set_outgoing_ports(duplicated, &hugr),
        Err(InvalidOutputPorts::NonUniqueLinear)
    ));

    // The failed call left the boundary untouched.
    assert_eq!(subgraph.outgoing_ports().len(), 1);
}
1442}