1use std::cell::OnceCell;
8use std::collections::HashSet;
9use std::mem;
10
11use itertools::Itertools;
12use portgraph::LinkView;
13use portgraph::PortView;
14use portgraph::algorithms::CreateConvexChecker;
15use portgraph::algorithms::convex::{LineIndex, LineIntervals, Position};
16use portgraph::boundary::Boundary;
17use rustc_hash::FxHashSet;
18use thiserror::Error;
19
20use crate::builder::{Container, FunctionBuilder};
21use crate::core::HugrNode;
22use crate::hugr::internal::{HugrInternals, PortgraphNodeMap};
23use crate::hugr::{HugrMut, HugrView};
24use crate::ops::dataflow::DataflowOpTrait;
25use crate::ops::handle::{ContainerHandle, DataflowOpID};
26use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
27use crate::types::{Signature, Type};
28use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
29
30use super::root_checked::RootCheckable;
31
/// A subgraph of a HUGR whose nodes all share the same parent.
///
/// The subgraph is described by its node set together with its boundary:
/// - `inputs`: partitions of incoming ports, each partition grouping the
///   ports that are fed by the same outgoing port outside the subgraph;
/// - `outputs`: the outgoing ports through which values leave the subgraph,
///   in order;
/// - `function_calls`: static incoming ports linked to functions, grouped by
///   the outgoing port they are linked to.
///
/// The checked constructors verify that the subgraph is non-empty and that its
/// boundary is consistent with its nodes; `new_unchecked` performs no checks.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SiblingSubgraph<N = Node> {
    /// The nodes of the subgraph.
    nodes: Vec<N>,
    /// The input boundary, as partitions of incoming ports that share the same
    /// linked outgoing port.
    inputs: IncomingPorts<N>,
    /// The output boundary, in order.
    outputs: OutgoingPorts<N>,
    /// Static function inputs of nodes in the subgraph, grouped by the
    /// outgoing port (i.e. the function) they are linked to.
    function_calls: IncomingPorts<N>,
}
89
/// The input boundary of a subgraph, as a list of partitions.
///
/// Each partition groups incoming ports that are all linked to the same
/// outgoing port outside the subgraph (enforced by subgraph validation).
pub type IncomingPorts<N = Node> = Vec<Vec<(N, IncomingPort)>>;
/// The output boundary of a subgraph: outgoing ports, in order.
pub type OutgoingPorts<N = Node> = Vec<(N, OutgoingPort)>;
99
100impl<N: HugrNode> SiblingSubgraph<N> {
101 pub fn try_new_dataflow_subgraph<'h, H, Root>(
113 dfg_graph: impl RootCheckable<&'h H, Root>,
114 ) -> Result<Self, InvalidSubgraph<N>>
115 where
116 H: 'h + Clone + HugrView<Node = N>,
117 Root: ContainerHandle<N, ChildrenHandle = DataflowOpID>,
118 {
119 let Ok(dfg_graph) = dfg_graph.try_into_checked() else {
120 return Err(InvalidSubgraph::NonDataflowRegion);
121 };
122 let dfg_graph = dfg_graph.into_hugr();
123
124 let parent = HugrView::entrypoint(&dfg_graph);
125 let nodes = dfg_graph.children(parent).skip(2).collect_vec();
126
127 if nodes.is_empty() {
128 return Err(InvalidSubgraph::EmptySubgraph);
129 }
130
131 let (inputs, outputs) = get_input_output_ports(dfg_graph)?;
132 let non_local = get_non_local_edges(&nodes, &dfg_graph);
133 let function_calls = group_into_function_calls(non_local, &dfg_graph)?;
134
135 validate_subgraph(dfg_graph, &nodes, &inputs, &outputs, &function_calls)?;
136
137 Ok(Self {
138 nodes,
139 inputs,
140 outputs,
141 function_calls,
142 })
143 }
144
145 pub fn try_new(
193 inputs: IncomingPorts<N>,
194 outputs: OutgoingPorts<N>,
195 hugr: &impl HugrView<Node = N>,
196 ) -> Result<Self, InvalidSubgraph<N>> {
197 let parent = pick_parent(hugr, &inputs, &outputs)?;
198 let checker = TopoConvexChecker::new(hugr, parent);
199 Self::try_new_with_checker(inputs, outputs, hugr, &checker)
200 }
201
202 pub fn new_unchecked(
219 inputs: IncomingPorts<N>,
220 outputs: OutgoingPorts<N>,
221 function_calls: IncomingPorts<N>,
222 nodes: Vec<N>,
223 ) -> Self {
224 Self {
225 nodes,
226 inputs,
227 outputs,
228 function_calls,
229 }
230 }
231
232 pub fn try_new_with_checker<H: HugrView<Node = N>>(
242 mut inputs: IncomingPorts<N>,
243 outputs: OutgoingPorts<N>,
244 hugr: &H,
245 checker: &TopoConvexChecker<H>,
246 ) -> Result<Self, InvalidSubgraph<N>> {
247 let (subpg, node_map) = make_pg_subgraph(hugr, &inputs, &outputs);
248 let nodes = subpg
249 .nodes_iter()
250 .map(|index| node_map.from_portgraph(index))
251 .collect_vec();
252
253 let function_calls = drain_function_calls(&mut inputs, hugr);
254
255 validate_subgraph(hugr, &nodes, &inputs, &outputs, &function_calls)?;
256
257 if nodes.len() > 1 && !subpg.is_convex_with_checker(checker) {
258 return Err(InvalidSubgraph::NotConvex);
259 }
260
261 Ok(Self {
262 nodes,
263 inputs,
264 outputs,
265 function_calls,
266 })
267 }
268
269 pub fn try_from_nodes(
285 nodes: impl Into<Vec<N>>,
286 hugr: &impl HugrView<Node = N>,
287 ) -> Result<Self, InvalidSubgraph<N>> {
288 let nodes = nodes.into();
289 let Some(node) = nodes.first() else {
290 return Err(InvalidSubgraph::EmptySubgraph);
291 };
292 let parent = hugr
293 .get_parent(*node)
294 .ok_or(InvalidSubgraph::OrphanNode { orphan: *node })?;
295
296 let checker = TopoConvexChecker::new(hugr, parent);
297 Self::try_from_nodes_with_checker(nodes, hugr, &checker)
298 }
299
300 pub fn try_from_nodes_with_checker<H: HugrView<Node = N>>(
310 nodes: impl Into<Vec<N>>,
311 hugr: &H,
312 checker: &TopoConvexChecker<H>,
313 ) -> Result<Self, InvalidSubgraph<N>> {
314 let mut nodes: Vec<N> = nodes.into();
315 let num_nodes = nodes.len();
316
317 if nodes.is_empty() {
318 return Err(InvalidSubgraph::EmptySubgraph);
319 }
320
321 let (inputs, outputs) = get_boundary_from_nodes(hugr, &mut nodes);
322
323 if inputs.is_empty() && outputs.is_empty() {
326 return Ok(Self {
327 nodes,
328 inputs,
329 outputs,
330 function_calls: vec![],
331 });
332 }
333
334 let mut subgraph = Self::try_new_with_checker(inputs, outputs, hugr, checker)?;
335
336 if subgraph.node_count() < num_nodes {
339 subgraph.nodes = nodes;
340 }
341
342 Ok(subgraph)
343 }
344
345 pub fn try_from_nodes_with_intervals(
359 nodes: impl Into<Vec<N>>,
360 intervals: &LineIntervals,
361 line_checker: &LineConvexChecker<impl HugrView<Node = N>>,
362 ) -> Result<Self, InvalidSubgraph<N>> {
363 if !line_checker.get_checker().is_convex_by_intervals(intervals) {
364 return Err(InvalidSubgraph::NotConvex);
365 }
366
367 let nodes: Vec<N> = nodes.into();
368 let hugr = line_checker.hugr();
369
370 if nodes.is_empty() {
371 return Err(InvalidSubgraph::EmptySubgraph);
372 }
373
374 let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
375 let incoming_edges = nodes
376 .iter()
377 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
378 let outgoing_edges = nodes
379 .iter()
380 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
381 let mut inputs = incoming_edges
382 .filter(|&(n, p)| {
383 if !hugr.is_linked(n, p) {
384 return false;
385 }
386 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
387 !nodes_set.contains(&out_n)
388 })
389 .map(|p| vec![p])
391 .collect_vec();
392 let outputs = outgoing_edges
393 .filter(|&(n, p)| {
394 hugr.linked_ports(n, p)
395 .any(|(n1, _)| !nodes_set.contains(&n1))
396 })
397 .collect_vec();
398 let function_calls = drain_function_calls(&mut inputs, hugr);
399
400 Ok(Self {
401 nodes,
402 inputs,
403 outputs,
404 function_calls,
405 })
406 }
407
408 pub fn from_node(node: N, hugr: &impl HugrView<Node = N>) -> Self {
416 let nodes = vec![node];
417 let mut inputs = hugr
418 .node_inputs(node)
419 .filter(|&p| hugr.is_linked(node, p))
420 .map(|p| vec![(node, p)])
421 .collect_vec();
422 let outputs = hugr
423 .node_outputs(node)
424 .filter_map(|p| {
425 {
427 hugr.is_linked(node, p)
428 || HugrView::get_optype(&hugr, node)
429 .port_kind(p)
430 .is_some_and(|k| k.is_value())
431 }
432 .then_some((node, p))
433 })
434 .collect_vec();
435 let function_calls = drain_function_calls(&mut inputs, hugr);
436
437 let state_order_at_input = hugr
438 .get_optype(node)
439 .other_output_port()
440 .is_some_and(|p| hugr.is_linked(node, p));
441 let state_order_at_output = hugr
442 .get_optype(node)
443 .other_input_port()
444 .is_some_and(|p| hugr.is_linked(node, p));
445 if state_order_at_input || state_order_at_output {
446 unimplemented!("Order edges in {node:?} not supported");
447 }
448
449 Self {
450 nodes,
451 inputs,
452 outputs,
453 function_calls,
454 }
455 }
456
457 pub fn validate<'h, H: HugrView<Node = N>>(
468 &self,
469 hugr: &'h H,
470 mode: ValidationMode<'_, 'h, H>,
471 ) -> Result<(), InvalidSubgraph<N>> {
472 let mut exp_nodes = {
473 let (subpg, node_map) = make_pg_subgraph(hugr, &self.inputs, &self.outputs);
474 subpg
475 .nodes_iter()
476 .map(|n| node_map.from_portgraph(n))
477 .collect_vec()
478 };
479 let mut nodes = self.nodes.clone();
480
481 exp_nodes.sort_unstable();
482 nodes.sort_unstable();
483
484 if exp_nodes != nodes {
485 return Err(InvalidSubgraph::InvalidNodeSet);
486 }
487
488 validate_subgraph(
489 hugr,
490 &self.nodes,
491 &self.inputs,
492 &self.outputs,
493 &self.function_calls,
494 )?;
495
496 let checker;
497 let checker_ref = match mode {
498 ValidationMode::WithChecker(c) => Some(c),
499 ValidationMode::CheckConvexity => {
500 checker = ConvexChecker::new(hugr, self.get_parent(hugr));
501 Some(&checker)
502 }
503 ValidationMode::SkipConvexity => None,
504 };
505 if let Some(checker) = checker_ref {
506 let (subpg, _) = make_pg_subgraph(hugr, &self.inputs, &self.outputs);
507 if !subpg.is_convex_with_checker(&checker.init_checker().0) {
508 return Err(InvalidSubgraph::NotConvex);
509 }
510 }
511 Ok(())
512 }
513
514 #[must_use]
516 pub fn nodes(&self) -> &[N] {
517 &self.nodes
518 }
519
520 #[must_use]
522 pub fn node_count(&self) -> usize {
523 self.nodes.len()
524 }
525
526 #[must_use]
528 pub fn incoming_ports(&self) -> &IncomingPorts<N> {
529 &self.inputs
530 }
531
532 #[must_use]
534 pub fn outgoing_ports(&self) -> &OutgoingPorts<N> {
535 &self.outputs
536 }
537
538 #[must_use]
541 pub fn function_calls(&self) -> &IncomingPorts<N> {
542 &self.function_calls
543 }
544
545 pub fn signature(&self, hugr: &impl HugrView<Node = N>) -> Signature {
547 let input = self
548 .inputs
549 .iter()
550 .map(|part| {
551 let &(n, p) = part.iter().next().expect("is non-empty");
552 let sig = hugr.signature(n).expect("must have dataflow signature");
553 sig.port_type(p).cloned().expect("must be dataflow edge")
554 })
555 .collect_vec();
556 let output = self
557 .outputs
558 .iter()
559 .map(|&(n, p)| {
560 let sig = hugr.signature(n).expect("must have dataflow signature");
561 sig.port_type(p).cloned().expect("must be dataflow edge")
562 })
563 .collect_vec();
564 Signature::new(input, output)
565 }
566
567 pub fn get_parent(&self, hugr: &impl HugrView<Node = N>) -> N {
569 hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
570 }
571
572 pub(crate) fn map_nodes<N2: HugrNode>(
577 &self,
578 node_map: impl Fn(N) -> N2,
579 ) -> SiblingSubgraph<N2> {
580 let nodes = self.nodes.iter().map(|&n| node_map(n)).collect_vec();
581 let inputs = self
582 .inputs
583 .iter()
584 .map(|part| part.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
585 .collect_vec();
586 let outputs = self
587 .outputs
588 .iter()
589 .map(|&(n, p)| (node_map(n), p))
590 .collect_vec();
591 let function_calls = self
592 .function_calls
593 .iter()
594 .map(|calls| calls.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
595 .collect_vec();
596 SiblingSubgraph {
597 nodes,
598 inputs,
599 outputs,
600 function_calls,
601 }
602 }
603
604 pub fn create_simple_replacement(
620 &self,
621 hugr: &impl HugrView<Node = N>,
622 replacement: Hugr,
623 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
624 let rep_root = replacement.entrypoint();
625 let dfg_optype = replacement.get_optype(rep_root);
626 if !OpTag::DataflowParent.is_superset(dfg_optype.tag()) {
627 return Err(InvalidReplacement::InvalidDataflowGraph {
628 node: rep_root,
629 op: Box::new(dfg_optype.clone()),
630 });
631 }
632 let [rep_input, rep_output] = replacement
633 .get_io(rep_root)
634 .expect("DFG root in the replacement does not have input and output nodes.");
635
636 let state_order_at_input = replacement
639 .get_optype(rep_input)
640 .other_output_port()
641 .is_some_and(|p| replacement.is_linked(rep_input, p));
642 let state_order_at_output = replacement
643 .get_optype(rep_output)
644 .other_input_port()
645 .is_some_and(|p| replacement.is_linked(rep_output, p));
646 if state_order_at_input || state_order_at_output {
647 unimplemented!("Found state order edges in replacement graph");
648 }
649
650 SimpleReplacement::try_new(self.clone(), hugr, replacement)
651 }
652
653 pub fn extract_subgraph(
658 &self,
659 hugr: &impl HugrView<Node = N>,
660 name: impl Into<String>,
661 ) -> Hugr {
662 let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
663 let mut extracted = mem::take(builder.hugr_mut());
666 let node_map = extracted.insert_subgraph(extracted.entrypoint(), hugr, self);
667
668 let [inp, out] = extracted.get_io(extracted.entrypoint()).unwrap();
670 let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
671 let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
672 let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
673
674 for (inp_port, repl_ports) in inputs {
675 for (repl_node, repl_port) in repl_ports {
676 connections.push((inp, inp_port, node_map[repl_node], *repl_port));
677 }
678 }
679 for (out_port, (repl_node, repl_port)) in outputs {
680 connections.push((node_map[repl_node], *repl_port, out, out_port));
681 }
682
683 for (src, src_port, dst, dst_port) in connections {
684 extracted.connect(src, src_port, dst, dst_port);
685 }
686
687 extracted
688 }
689
690 pub fn set_outgoing_ports(
700 &mut self,
701 ports: OutgoingPorts<N>,
702 host: &impl HugrView<Node = N>,
703 ) -> Result<(), InvalidOutputPorts<N>> {
704 let old_boundary: HashSet<_> = iter_outgoing(&self.outputs).collect();
705
706 if let Some((node, port)) =
708 iter_outgoing(&ports).find(|(n, p)| !old_boundary.contains(&(*n, *p)))
709 {
710 return Err(InvalidOutputPorts::UnknownOutput { port, node });
711 }
712
713 if !has_unique_linear_ports(host, &ports) {
715 return Err(InvalidOutputPorts::NonUniqueLinear);
716 }
717
718 self.outputs = ports;
719 Ok(())
720 }
721}
722
/// How [`SiblingSubgraph::validate`] handles the convexity check.
#[derive(Default)]
pub enum ValidationMode<'t, 'h, H: HugrView> {
    /// Check convexity with the provided checker.
    WithChecker(&'t TopoConvexChecker<'h, H>),
    /// Check convexity with a new checker built for the subgraph's parent region.
    #[default]
    CheckConvexity,
    /// Do not check convexity.
    SkipConvexity,
}
734
735fn make_pg_subgraph<'h, H: HugrView>(
736 hugr: &'h H,
737 inputs: &IncomingPorts<H::Node>,
738 outputs: &OutgoingPorts<H::Node>,
739) -> (
740 portgraph::view::Subgraph<CheckerRegion<'h, H>>,
741 H::RegionPortgraphNodes,
742) {
743 let (region, node_map) = hugr.region_portgraph(hugr.entrypoint());
744
745 let boundary = make_boundary::<H>(®ion, &node_map, inputs, outputs);
748 (
749 portgraph::view::Subgraph::new_subgraph(region, boundary),
750 node_map,
751 )
752}
753
754fn get_boundary_from_nodes<N: HugrNode>(
758 hugr: &impl HugrView<Node = N>,
759 nodes: &mut Vec<N>,
760) -> (IncomingPorts<N>, OutgoingPorts<N>) {
761 let mut nodes_set = FxHashSet::default();
764 nodes.retain(|&n| nodes_set.insert(n));
765
766 let incoming_edges = nodes
767 .iter()
768 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
769 let outgoing_edges = nodes
770 .iter()
771 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
772
773 let inputs = incoming_edges
774 .filter(|&(n, p)| {
775 if !hugr.is_linked(n, p) {
776 return false;
777 }
778 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
779 !nodes_set.contains(&out_n)
780 })
781 .map(|p| vec![p])
783 .collect_vec();
784 let outputs = outgoing_edges
785 .filter(|&(n, p)| {
786 hugr.linked_ports(n, p)
787 .any(|(n1, _)| !nodes_set.contains(&n1))
788 })
789 .collect_vec();
790
791 (inputs, outputs)
792}
793
/// Move every input partition that feeds a static function input out of
/// `inputs`, and return those ports grouped by the function they link to.
///
/// A partition is classified by its first port only. If that port is the
/// node's static input port and the static input is a function, the whole
/// partition is moved. Empty partitions are left in `inputs`.
///
/// # Panics
///
/// Panics if the moved ports are rejected by [`group_into_function_calls`].
fn drain_function_calls<N: HugrNode, H: HugrView<Node = N>>(
    inputs: &mut IncomingPorts<N>,
    hugr: &H,
) -> IncomingPorts<N> {
    let mut function_calls = Vec::new();
    inputs.retain_mut(|calls| {
        // Empty partitions cannot be classified; keep them so validation can
        // report them.
        let Some(&(n, p)) = calls.first() else {
            return true;
        };
        let op = hugr.get_optype(n);
        if op.static_input_port() == Some(p)
            && op
                .static_input()
                .expect("static input exists")
                .is_function()
        {
            function_calls.extend(mem::take(calls));
            false
        } else {
            true
        }
    });

    group_into_function_calls(function_calls.into_iter().map(|(n, p)| (n, p.into())), hugr)
        .expect("valid function calls")
}
821
/// Check that each port is a static incoming port carrying a function, and
/// group the ports by the outgoing port they are linked to.
///
/// Groups are returned sorted by that linked outgoing port, so the result does
/// not depend on hash-map iteration order.
///
/// # Errors
///
/// Returns [`InvalidSubgraph::UnsupportedEdgeKind`] for a port that is not an
/// incoming port, is not its node's static input port, or whose static input
/// type is not a function.
///
/// # Panics
///
/// Panics if a port has no single linked output.
fn group_into_function_calls<N: HugrNode>(
    ports: impl IntoIterator<Item = (N, Port)>,
    hugr: &impl HugrView<Node = N>,
) -> Result<Vec<Vec<(N, IncomingPort)>>, InvalidSubgraph<N>> {
    let incoming_ports: Vec<_> = ports
        .into_iter()
        .map(|(n, p)| {
            let p = p
                .as_incoming()
                .map_err(|_| InvalidSubgraph::UnsupportedEdgeKind(n, p))?;
            let op = hugr.get_optype(n);
            if Some(p) != op.static_input_port() {
                return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
            }
            if !op
                .static_input()
                .expect("static input exists")
                .is_function()
            {
                return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
            }
            Ok::<_, InvalidSubgraph<N>>((n, p))
        })
        .try_collect()?;
    // Group by the linked outgoing port, i.e. by the called function.
    let grouped_non_local = incoming_ports
        .into_iter()
        .into_group_map_by(|&(n, p)| hugr.single_linked_output(n, p).expect("valid dfg wire"));
    Ok(grouped_non_local
        .into_iter()
        .sorted_unstable_by(|(n1, _), (n2, _)| n1.cmp(n2))
        .map(|(_, v)| v)
        .collect())
}
859
860fn get_non_local_edges<'a, N: HugrNode>(
864 nodes: &'a [N],
865 hugr: &'a impl HugrView<Node = N>,
866) -> impl Iterator<Item = (N, Port)> + 'a {
867 let parent = hugr.get_parent(nodes[0]);
868 let is_non_local = move |n, p| {
869 hugr.linked_ports(n, p)
870 .any(|(n, _)| hugr.get_parent(n) != parent)
871 };
872 nodes
873 .iter()
874 .flat_map(move |&n| hugr.all_node_ports(n).map(move |p| (n, p)))
875 .filter(move |&(n, p)| is_non_local(n, p))
876}
877
878fn iter_incoming<N: HugrNode>(
880 inputs: &IncomingPorts<N>,
881) -> impl Iterator<Item = (N, IncomingPort)> + '_ {
882 inputs.iter().flat_map(|part| part.iter().copied())
883}
884
885fn iter_outgoing<N: HugrNode>(
887 outputs: &OutgoingPorts<N>,
888) -> impl Iterator<Item = (N, OutgoingPort)> + '_ {
889 outputs.iter().copied()
890}
891
892fn iter_io<'a, N: HugrNode>(
894 inputs: &'a IncomingPorts<N>,
895 outputs: &'a OutgoingPorts<N>,
896) -> impl Iterator<Item = (N, Port)> + 'a {
897 iter_incoming(inputs)
898 .map(|(n, p)| (n, Port::from(p)))
899 .chain(iter_outgoing(outputs).map(|(n, p)| (n, Port::from(p))))
900}
901
902fn pick_parent<'a, N: HugrNode>(
912 hugr: &impl HugrView<Node = N>,
913 inputs: &'a IncomingPorts<N>,
914 outputs: &'a OutgoingPorts<N>,
915) -> Result<N, InvalidSubgraph<N>> {
916 let Some(node) = iter_incoming(inputs)
918 .map(|(n, _)| n)
919 .chain(iter_outgoing(outputs).map(|(n, _)| n))
920 .next()
921 else {
922 return Err(InvalidSubgraph::EmptySubgraph);
923 };
924 let Some(parent) = hugr.get_parent(node) else {
925 return Err(InvalidSubgraph::OrphanNode { orphan: node });
926 };
927
928 Ok(parent)
929}
930
/// Convert HUGR boundary ports into a portgraph [`Boundary`] over `region`.
///
/// Incoming ports become the boundary inputs and outgoing ports the boundary
/// outputs, in iteration order.
///
/// # Panics
///
/// Panics if a boundary port has no port index in `region`.
fn make_boundary<'a, H: HugrView>(
    region: &impl LinkView<PortOffsetBase = u32>,
    node_map: &H::RegionPortgraphNodes,
    inputs: &'a IncomingPorts<H::Node>,
    outputs: &'a OutgoingPorts<H::Node>,
) -> Boundary {
    // Translate a HUGR (node, port) pair into the region's port index.
    let to_pg_index = |n: H::Node, p: Port| {
        region
            .port_index(node_map.to_portgraph(n), p.pg_offset())
            .unwrap()
    };
    Boundary::new(
        iter_incoming(inputs).map(|(n, p)| to_pg_index(n, p.into())),
        iter_outgoing(outputs).map(|(n, p)| to_pg_index(n, p.into())),
    )
}
947
/// The portgraph view of a single HUGR region (the children of one parent),
/// as used by the convexity checkers.
type CheckerRegion<'g, Base> =
    portgraph::view::FlatRegion<'g, <Base as HugrInternals>::RegionPortgraph<'g>>;

/// A convexity checker for one region of a HUGR, backed by portgraph's
/// `TopoConvexChecker`.
///
/// Its internal state is built lazily on first use, so reusing one checker for
/// many queries on the same region avoids rebuilding it.
pub type TopoConvexChecker<'g, Base> =
    ConvexChecker<'g, Base, portgraph::algorithms::TopoConvexChecker<CheckerRegion<'g, Base>>>;

/// A convexity checker for one region of a HUGR, backed by portgraph's
/// `LineConvexChecker`, which also supports line-interval queries.
pub type LineConvexChecker<'g, Base> =
    ConvexChecker<'g, Base, portgraph::algorithms::LineConvexChecker<CheckerRegion<'g, Base>>>;
969
/// A lazily-initialised convexity checker for one region of a HUGR.
///
/// `Checker` is the portgraph checker implementation. It is built on first use
/// from the portgraph region below `region_parent`, and cached together with
/// the map between HUGR nodes and portgraph indices.
pub struct ConvexChecker<'g, Base: HugrView, Checker> {
    /// The HUGR containing the region.
    base: &'g Base,
    /// The parent node whose children form the checked region.
    region_parent: Base::Node,
    /// The checker and its node map, created on first use.
    checker: OnceCell<(Checker, Base::RegionPortgraphNodes)>,
}
987
988impl<'g, Base: HugrView, Checker: Clone> Clone for ConvexChecker<'g, Base, Checker> {
989 fn clone(&self) -> Self {
990 Self {
991 base: self.base,
992 region_parent: self.region_parent,
993 checker: self.checker.clone(),
994 }
995 }
996}
997
998impl<'g, Base: HugrView, Checker> ConvexChecker<'g, Base, Checker> {
999 pub fn new(base: &'g Base, region_parent: Base::Node) -> Self {
1001 Self {
1002 base,
1003 region_parent,
1004 checker: OnceCell::new(),
1005 }
1006 }
1007
1008 #[inline(always)]
1010 pub fn from_entrypoint(base: &'g Base) -> Self {
1011 let region_parent = base.entrypoint();
1012 Self::new(base, region_parent)
1013 }
1014
1015 pub fn hugr(&self) -> &'g Base {
1017 self.base
1018 }
1019}
1020
impl<'g, Base, Checker> ConvexChecker<'g, Base, Checker>
where
    Base: HugrView,
    Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
{
    /// Return the cached checker and node map. On the first call they are
    /// built from the portgraph region of `region_parent`.
    fn init_checker(&self) -> &(Checker, Base::RegionPortgraphNodes) {
        self.checker.get_or_init(|| {
            let (region, node_map) = self.base.region_portgraph(self.region_parent);
            let checker = Checker::new_convex_checker(region);
            (checker, node_map)
        })
    }

    /// The map between HUGR nodes and portgraph indices of the region.
    #[expect(dead_code)]
    fn get_node_map(&self) -> &Base::RegionPortgraphNodes {
        &self.init_checker().1
    }

    /// The underlying portgraph checker, built on first use.
    fn get_checker(&self) -> &Checker {
        &self.init_checker().0
    }
}
1046
impl<'g, Base, Checker> portgraph::algorithms::ConvexChecker for ConvexChecker<'g, Base, Checker>
where
    Base: HugrView,
    Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
{
    /// Check whether the portgraph subgraph with the given `nodes` and boundary
    /// ports `inputs`/`outputs` is convex.
    ///
    /// Sets of fewer than two nodes are reported as convex without building
    /// the underlying checker.
    fn is_convex(
        &self,
        nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
        inputs: impl IntoIterator<Item = portgraph::PortIndex>,
        outputs: impl IntoIterator<Item = portgraph::PortIndex>,
    ) -> bool {
        let mut nodes = nodes.into_iter().multipeek();
        // `MultiPeek::peek` advances a separate peek cursor without consuming:
        // the first call looks at the first node and the second call at the
        // second node. The condition therefore holds iff there are < 2 nodes.
        if nodes.peek().is_none() || nodes.peek().is_none() {
            return true;
        }
        // Peeked elements were not consumed, so every node is still yielded.
        self.get_checker().is_convex(nodes, inputs, outputs)
    }
}
1067
impl<'g, Base: HugrView> LineConvexChecker<'g, Base> {
    /// Return the line intervals occupied by `nodes`, as computed by the
    /// underlying portgraph `LineConvexChecker`.
    ///
    /// NOTE(review): presumably `None` when the nodes cannot be described by
    /// intervals (e.g. they are not convex); confirm against portgraph.
    pub fn get_intervals_from_nodes(
        &self,
        nodes: impl IntoIterator<Item = Base::Node>,
    ) -> Option<LineIntervals> {
        let (checker, node_map) = self.init_checker();
        let nodes = nodes
            .into_iter()
            .map(|n| node_map.to_portgraph(n))
            .collect_vec();
        checker.get_intervals_from_nodes(nodes)
    }

    /// Return the line intervals delimited by the given boundary ports, as
    /// computed by the underlying portgraph `LineConvexChecker`.
    ///
    /// # Panics
    ///
    /// Panics if a `(node, port)` pair is not a port of the checker's region.
    pub fn get_intervals_from_boundary_ports(
        &self,
        ports: impl IntoIterator<Item = (Base::Node, Port)>,
    ) -> Option<LineIntervals> {
        let (checker, node_map) = self.init_checker();
        let ports = ports
            .into_iter()
            .map(|(n, p)| {
                let node = node_map.to_portgraph(n);
                checker
                    .graph()
                    .port_index(node, p.pg_offset())
                    .expect("valid port")
            })
            .collect_vec();
        checker.get_intervals_from_boundary_ports(ports)
    }

    /// Iterate over the HUGR nodes that lie within `intervals`.
    pub fn nodes_in_intervals<'a>(
        &'a self,
        intervals: &'a LineIntervals,
    ) -> impl Iterator<Item = Base::Node> + 'a {
        let (checker, node_map) = self.init_checker();
        checker
            .nodes_in_intervals(intervals)
            .map(|pg_node| node_map.from_portgraph(pg_node))
    }

    /// Return the lines passing through `port` of `node`.
    ///
    /// # Panics
    ///
    /// Panics if the port does not exist in the checker's region.
    pub fn lines_at_port(&self, node: Base::Node, port: impl Into<Port>) -> &[LineIndex] {
        let (checker, node_map) = self.init_checker();
        let port = checker
            .graph()
            .port_index(node_map.to_portgraph(node), port.into().pg_offset())
            .expect("valid port");
        checker.lines_at_port(port)
    }

    /// Try to extend `intervals` to include `node`, returning whether this
    /// succeeded.
    ///
    /// NOTE(review): whether `intervals` is left unchanged on failure is
    /// decided by portgraph's `try_extend_intervals`; confirm there.
    pub fn try_extend_intervals(&self, intervals: &mut LineIntervals, node: Base::Node) -> bool {
        let (checker, node_map) = self.init_checker();
        let node = node_map.to_portgraph(node);
        checker.try_extend_intervals(intervals, node)
    }

    /// Return the position assigned to `node` by the line checker.
    pub fn get_position(&self, node: Base::Node) -> Position {
        let (checker, node_map) = self.init_checker();
        let node = node_map.to_portgraph(node);
        checker.get_position(node)
    }
}
1146
1147fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(
1151 hugr: &H,
1152 ports: &[(H::Node, P)],
1153) -> Option<Type> {
1154 let &(n, p) = ports.first()?;
1155 let edge_t = hugr.signature(n)?.port_type(p)?.clone();
1156 ports
1157 .iter()
1158 .all(|&(n, p)| {
1159 hugr.signature(n)
1160 .is_some_and(|s| s.port_type(p) == Some(&edge_t))
1161 })
1162 .then_some(edge_t)
1163}
1164
1165fn validate_subgraph<H: HugrView>(
1172 hugr: &H,
1173 nodes: &[H::Node],
1174 inputs: &IncomingPorts<H::Node>,
1175 outputs: &OutgoingPorts<H::Node>,
1176 function_calls: &IncomingPorts<H::Node>,
1177) -> Result<(), InvalidSubgraph<H::Node>> {
1178 let node_set = nodes.iter().copied().collect::<HashSet<_>>();
1180
1181 if nodes.is_empty() {
1183 return Err(InvalidSubgraph::EmptySubgraph);
1184 }
1185 if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() {
1187 let first_node = nodes[0];
1188 let first_parent = hugr
1189 .get_parent(first_node)
1190 .ok_or(InvalidSubgraph::OrphanNode { orphan: first_node })?;
1191 let other_node = *nodes
1192 .iter()
1193 .skip(1)
1194 .find(|&&n| hugr.get_parent(n) != Some(first_parent))
1195 .unwrap();
1196 let other_parent = hugr
1197 .get_parent(other_node)
1198 .ok_or(InvalidSubgraph::OrphanNode { orphan: other_node })?;
1199 return Err(InvalidSubgraph::NoSharedParent {
1200 first_node,
1201 first_parent,
1202 other_node,
1203 other_parent,
1204 });
1205 }
1206
1207 if let Some((n, p)) = iter_io(inputs, outputs).find(|&(n, p)| is_non_value_edge(hugr, n, p)) {
1209 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p));
1210 }
1211
1212 let boundary_ports = iter_io(inputs, outputs).collect_vec();
1213 if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
1215 Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
1216 }
1217 if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
1219 hugr.linked_ports(n, p)
1220 .all(|(n1, _)| node_set.contains(&n1))
1221 }) {
1222 Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
1223 }
1224
1225 let mut must_be_inputs = nodes
1228 .iter()
1229 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)))
1230 .filter(|&(n, p)| {
1231 hugr.linked_ports(n, p)
1232 .any(|(n1, _)| !node_set.contains(&n1))
1233 });
1234 if !must_be_inputs.all(|(n, p)| {
1235 let mut all_inputs = inputs.iter().chain(function_calls);
1236 all_inputs.any(|nps| nps.contains(&(n, p)))
1237 }) {
1238 return Err(InvalidSubgraph::NotConvex);
1239 }
1240 if nodes.iter().any(|&n| {
1243 hugr.node_outputs(n).any(|p| {
1244 hugr.linked_ports(n, p)
1245 .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
1246 })
1247 }) {
1248 return Err(InvalidSubgraph::NotConvex);
1249 }
1250
1251 if !inputs.iter().flatten().all_unique() {
1253 return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
1254 }
1255
1256 for inp in inputs {
1260 let &(in_node, in_port) = inp.first().ok_or(InvalidSubgraphBoundary::EmptyPartition)?;
1261 let exp_output_node_port = hugr
1262 .single_linked_output(in_node, in_port)
1263 .expect("valid dfg wire");
1264 if let Some(output_node_port) = inp
1265 .iter()
1266 .map(|&(in_node, in_port)| {
1267 hugr.single_linked_output(in_node, in_port)
1268 .expect("valid dfg wire")
1269 })
1270 .find(|&p| p != exp_output_node_port)
1271 {
1272 return Err(InvalidSubgraphBoundary::MismatchedOutputPort(
1273 (in_node, in_port),
1274 exp_output_node_port,
1275 output_node_port,
1276 )
1277 .into());
1278 }
1279 }
1280
1281 if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
1284 let Some(edge_t) = get_edge_type(hugr, ports) else {
1285 return true;
1286 };
1287 let require_copy = ports.len() > 1;
1288 require_copy && !edge_t.copyable()
1289 }) {
1290 Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
1291 }
1292
1293 for calls in function_calls {
1296 if !calls
1297 .iter()
1298 .map(|&(n, p)| hugr.single_linked_output(n, p))
1299 .all_equal()
1300 {
1301 let (n, p) = calls[0];
1302 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
1303 }
1304 for &(n, p) in calls {
1305 let op = hugr.get_optype(n);
1306 if op.static_input_port() != Some(p) {
1307 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
1308 }
1309 }
1310 }
1311
1312 Ok(())
1313}
1314
/// Compute the boundary of the dataflow region at `hugr`'s entrypoint from the
/// ports connected to its `Input` and `Output` nodes.
///
/// Inputs: each `Input` value port yields one partition. It holds the incoming
/// ports that port feeds, excluding those on the `Output` node. Partitions
/// that end up empty are dropped.
///
/// Outputs: for each `Output` value port, the first linked outgoing port not
/// on the `Input` node is used. Ports linked only to the `Input` node are
/// dropped.
///
/// # Errors
///
/// Returns [`InvalidSubgraph::UnsupportedEdgeKind`] if the `Input` or `Output`
/// node has a linked order or static port.
///
/// # Panics
///
/// Panics if the entrypoint has no `Input`/`Output` children, or these are not
/// `Input`/`Output` operations.
// The nested-`Vec` tuple inside a `Result` trips clippy's complexity lint.
#[allow(clippy::type_complexity)]
fn get_input_output_ports<H: HugrView>(
    hugr: &H,
) -> Result<(IncomingPorts<H::Node>, OutgoingPorts<H::Node>), InvalidSubgraph<H::Node>> {
    let [inp, out] = hugr
        .get_io(HugrView::entrypoint(&hugr))
        .expect("invalid DFG");
    if let Some(p) = hugr
        .node_outputs(inp)
        .find(|&p| is_non_value_edge(hugr, inp, p.into()))
    {
        return Err(InvalidSubgraph::UnsupportedEdgeKind(inp, p.into()));
    }
    if let Some(p) = hugr
        .node_inputs(out)
        .find(|&p| is_non_value_edge(hugr, out, p.into()))
    {
        return Err(InvalidSubgraph::UnsupportedEdgeKind(out, p.into()));
    }

    // Only the value ports of the `Input`/`Output` signatures are considered.
    let dfg_inputs = HugrView::get_optype(&hugr, inp)
        .as_input()
        .unwrap()
        .signature()
        .output_ports();
    let dfg_outputs = HugrView::get_optype(&hugr, out)
        .as_output()
        .unwrap()
        .signature()
        .input_ports();

    // Wires going straight from `Input` to `Output` are not part of the subgraph.
    let inputs = dfg_inputs
        .into_iter()
        .map(|p| {
            hugr.linked_inputs(inp, p)
                .filter(|&(n, _)| n != out)
                .collect_vec()
        })
        .filter(|v| !v.is_empty())
        .collect();
    let outputs = dfg_outputs
        .into_iter()
        .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
        .collect();
    Ok((inputs, outputs))
}
1365
1366fn is_non_value_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
1368 let op = hugr.get_optype(node);
1369 let is_other = op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port);
1370 let is_static = op.static_port(port.direction()) == Some(port) && hugr.is_linked(node, port);
1371 is_other || is_static
1372}
1373
1374impl<'a, 'c, G: HugrView, Checker: Clone> From<&'a ConvexChecker<'c, G, Checker>>
1375 for std::borrow::Cow<'a, ConvexChecker<'c, G, Checker>>
1376{
1377 fn from(value: &'a ConvexChecker<'c, G, Checker>) -> Self {
1378 Self::Borrowed(value)
1379 }
1380}
1381
/// Errors that can occur while building a [`SimpleReplacement`] from a
/// [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum InvalidReplacement {
    /// The replacement's entrypoint is not a dataflow parent.
    #[error("The root of the replacement {node} is a {}, but only dataflow parents are supported.", op.name())]
    InvalidDataflowGraph {
        /// The replacement's entrypoint node.
        node: Node,
        /// The operation at that node.
        op: Box<OpType>,
    },
    /// The replacement's signature does not match the expected one.
    #[error(
        "Replacement graph type mismatch. Expected {expected}, got {}.",
        actual.clone().map_or("none".to_string(), |t| t.to_string()))
    ]
    InvalidSignature {
        /// The expected signature.
        expected: Box<Signature>,
        /// The replacement's signature, if it has one.
        actual: Option<Box<Signature>>,
    },
    /// The subgraph to be replaced is not convex.
    #[error("SiblingSubgraph is not convex.")]
    NonConvexSubgraph,
}
1409
/// Errors that can occur while constructing or validating a
/// [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraph<N: HugrNode = Node> {
    /// The subgraph is not convex, or an edge crosses its boundary through a
    /// port that is not part of the boundary.
    #[error("The subgraph is not convex.")]
    NotConvex,
    /// Not all nodes share the same parent.
    #[error(
        "Not a sibling subgraph. {first_node} has parent {first_parent}, but {other_node} has parent {other_parent}."
    )]
    NoSharedParent {
        /// The first node of the subgraph.
        first_node: N,
        /// The parent of `first_node`.
        first_parent: N,
        /// A node whose parent differs from `first_parent`.
        other_node: N,
        /// The parent of `other_node`.
        other_parent: N,
    },
    /// A node has no parent.
    #[error("Not a sibling subgraph. {orphan} has no parent")]
    OrphanNode {
        /// The node without a parent.
        orphan: N,
    },
    /// The subgraph has no nodes (or the boundary has no ports).
    #[error("Empty subgraphs are not supported.")]
    EmptySubgraph,
    /// The boundary is malformed.
    #[error("Invalid boundary port.")]
    InvalidBoundary(#[from] InvalidSubgraphBoundary<N>),
    /// The region is not a dataflow region.
    #[error("SiblingSubgraphs may only be defined on dataflow regions.")]
    NonDataflowRegion,
    /// The stored nodes differ from the nodes enclosed by the boundary.
    #[error("The subgraphs induced by the nodes and the boundary do not match.")]
    InvalidNodeSet,
    /// A boundary or non-local edge is not a supported kind of edge.
    #[error("Unsupported edge kind at ({_0}, {_1:?}).")]
    UnsupportedEdgeKind(N, Port),
}
1453
/// Errors describing a malformed [`SiblingSubgraph`] boundary.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraphBoundary<N: HugrNode = Node> {
    /// A boundary port belongs to a node outside the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
    PortNodeNotInSet(N, Port),
    /// A boundary port is not linked to any node outside the subgraph.
    #[error(
        "(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph."
    )]
    DisconnectedBoundaryPort(N, Port),
    /// An input port appears more than once in the input boundary.
    #[error("A port in the input boundary is used multiple times.")]
    NonUniqueInput,
    /// An input partition has no ports.
    #[error("A partition in the input boundary is empty.")]
    EmptyPartition,
    /// A port of an input partition is linked to a different outgoing port
    /// than the rest of its partition.
    ///
    /// Fields: the offending incoming port, the expected linked port, and the
    /// port it is actually linked to.
    #[error("expected port {0:?} to be linked to {1:?}, but is linked to {2:?}.")]
    MismatchedOutputPort((N, IncomingPort), (N, OutgoingPort), (N, OutgoingPort)),
    /// The ports of the input partition at this index do not share one type,
    /// or share a non-copyable type across several ports.
    #[error("The partition {0} in the input boundary has ports with different types.")]
    MismatchedTypes(usize),
}
1480
1481#[derive(Debug, Clone, PartialEq, Eq, Error)]
1483#[error("Invalid output ports: {0:?}")]
1484pub enum InvalidOutputPorts<N: HugrNode = Node> {
1485 #[error("{port} in {node} was not part of the original boundary.")]
1487 UnknownOutput {
1488 port: OutgoingPort,
1490 node: N,
1492 },
1493 #[error("Linear ports must appear exactly once.")]
1495 NonUniqueLinear,
1496}
1497
1498fn has_unique_linear_ports<H: HugrView>(host: &H, ports: &OutgoingPorts<H::Node>) -> bool {
1500 let linear_ports: Vec<_> = ports
1501 .iter()
1502 .filter(|&&(n, p)| {
1503 host.get_optype(n)
1504 .port_kind(p)
1505 .is_some_and(|pk| pk.is_linear())
1506 })
1507 .collect();
1508 let unique_ports: HashSet<_> = linear_ports.iter().collect();
1509 linear_ports.len() == unique_ports.len()
1510}
1511
1512#[cfg(test)]
1513mod tests {
1514 use cool_asserts::assert_matches;
1515 use rstest::{fixture, rstest};
1516
1517 use crate::builder::{endo_sig, inout_sig};
1518 use crate::extension::prelude::{MakeTuple, UnpackTuple};
1519 use crate::hugr::Patch;
1520 use crate::ops::Const;
1521 use crate::ops::handle::DataflowParentID;
1522 use crate::std_extensions::arithmetic::float_types::ConstF64;
1523 use crate::std_extensions::logic::LogicOp;
1524 use crate::type_row;
1525 use crate::utils::test_quantum_extension::{cx_gate, rz_f64};
1526 use crate::{
1527 builder::{
1528 BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
1529 ModuleBuilder,
1530 },
1531 extension::prelude::{bool_t, qb_t},
1532 ops::handle::{DfgID, FuncID, NodeHandle},
1533 std_extensions::logic::test::and_op,
1534 };
1535
1536 use super::*;
1537
1538 impl<N: HugrNode> SiblingSubgraph<N> {
1539 fn from_sibling_graph(
1544 hugr: &impl HugrView<Node = N>,
1545 parent: N,
1546 ) -> Result<Self, InvalidSubgraph<N>> {
1547 let nodes = hugr.children(parent).collect_vec();
1548 if nodes.is_empty() {
1549 Err(InvalidSubgraph::EmptySubgraph)
1550 } else {
1551 Ok(Self {
1552 nodes,
1553 inputs: Vec::new(),
1554 outputs: Vec::new(),
1555 function_calls: Vec::new(),
1556 })
1557 }
1558 }
1559 }
1560
1561 fn build_hugr() -> Result<(Hugr, Node), BuildError> {
1565 let mut mod_builder = ModuleBuilder::new();
1566 let func = mod_builder.declare(
1567 "test",
1568 Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).into(),
1569 )?;
1570 let func_id = {
1571 let mut dfg = mod_builder.define_declaration(&func)?;
1572 let [w0, w1, w2] = dfg.input_wires_arr();
1573 let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
1574 let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
1575 let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
1576 dfg.finish_with_outputs([w0, w1, w2])?
1577 };
1578 let hugr = mod_builder
1579 .finish_hugr()
1580 .map_err(|e| -> BuildError { e.into() })?;
1581 Ok((hugr, func_id.node()))
1582 }
1583
1584 fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
1586 let mut mod_builder = ModuleBuilder::new();
1587 let func = mod_builder.declare("test", Signature::new_endo(vec![bool_t()]).into())?;
1588 let func_id = {
1589 let mut dfg = mod_builder.define_declaration(&func)?;
1590 let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?;
1591 let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?;
1592 let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?;
1593 dfg.finish_with_outputs(outs3.outputs())?
1594 };
1595 let hugr = mod_builder
1596 .finish_hugr()
1597 .map_err(|e| -> BuildError { e.into() })?;
1598 Ok((hugr, func_id.node()))
1599 }
1600
1601 fn build_multiport_hugr() -> Result<(Hugr, Node), BuildError> {
1603 let mut mod_builder = ModuleBuilder::new();
1604 let func = mod_builder.declare(
1605 "test",
1606 Signature::new(bool_t(), vec![bool_t(), bool_t()]).into(),
1607 )?;
1608 let func_id = {
1609 let mut dfg = mod_builder.define_declaration(&func)?;
1610 let [b0] = dfg.input_wires_arr();
1611 let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr();
1612 let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr();
1613 dfg.finish_with_outputs([b1, b2])?
1614 };
1615 let hugr = mod_builder
1616 .finish_hugr()
1617 .map_err(|e| -> BuildError { e.into() })?;
1618 Ok((hugr, func_id.node()))
1619 }
1620
1621 fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
1623 let mut mod_builder = ModuleBuilder::new();
1624 let func = mod_builder.declare("test", Signature::new_endo(bool_t()).into())?;
1625 let func_id = {
1626 let mut dfg = mod_builder.define_declaration(&func)?;
1627 let in_wire = dfg.input_wires().exactly_one().unwrap();
1628 let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?;
1629 dfg.finish_with_outputs(outs.outputs())?
1630 };
1631 let hugr = mod_builder
1632 .finish_hugr()
1633 .map_err(|e| -> BuildError { e.into() })?;
1634 Ok((hugr, func_id.node()))
1635 }
1636
1637 #[test]
1638 fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
1639 let (mut hugr, func_root) = build_hugr().unwrap();
1640 let func = hugr.with_entrypoint(func_root);
1641 let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;
1642 assert!(sub.validate(&func, Default::default()).is_ok());
1643
1644 let empty_dfg = {
1645 let builder =
1646 DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap();
1647 let inputs = builder.input_wires();
1648 builder.finish_hugr_with_outputs(inputs).unwrap()
1649 };
1650
1651 let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();
1652
1653 assert_eq!(rep.subgraph().nodes().len(), 4);
1654
1655 assert_eq!(hugr.num_nodes(), 8); hugr.apply_patch(rep).unwrap();
1657 assert_eq!(hugr.num_nodes(), 4); Ok(())
1660 }
1661
1662 #[test]
1663 fn test_signature() -> Result<(), InvalidSubgraph> {
1664 let (hugr, dfg) = build_hugr().unwrap();
1665 let func = hugr.with_entrypoint(dfg);
1666 let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;
1667 assert!(sub.validate(&func, Default::default()).is_ok());
1668 assert_eq!(
1669 sub.signature(&func),
1670 Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
1671 );
1672 Ok(())
1673 }
1674
1675 #[test]
1676 fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> {
1677 let (hugr, dfg) = build_hugr().unwrap();
1678 let func = hugr.with_entrypoint(dfg);
1679 let sub = SiblingSubgraph::from_sibling_graph(&hugr, dfg)?;
1680
1681 let empty_dfg = {
1682 let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
1683 let inputs = builder.input_wires();
1684 builder.finish_hugr_with_outputs(inputs).unwrap()
1685 };
1686
1687 assert_matches!(
1688 sub.create_simple_replacement(&func, empty_dfg).unwrap_err(),
1689 InvalidReplacement::InvalidSignature { .. }
1690 );
1691 Ok(())
1692 }
1693
1694 #[test]
1695 fn convex_subgraph() {
1696 let (hugr, func_root) = build_hugr().unwrap();
1697 let func = hugr.with_entrypoint(func_root);
1698 assert_eq!(
1699 SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)
1700 .unwrap()
1701 .nodes()
1702 .len(),
1703 4
1704 );
1705 }
1706
1707 #[test]
1708 fn convex_subgraph_2() {
1709 let (hugr, func_root) = build_hugr().unwrap();
1710 let [inp, out] = hugr.get_io(func_root).unwrap();
1711 let func = hugr.with_entrypoint(func_root);
1712 SiblingSubgraph::try_new(
1714 hugr.node_outputs(inp)
1715 .take(2)
1716 .map(|p| hugr.linked_inputs(inp, p).collect_vec())
1717 .filter(|ps| !ps.is_empty())
1718 .collect(),
1719 hugr.node_inputs(out)
1720 .take(2)
1721 .filter_map(|p| hugr.single_linked_output(out, p))
1722 .collect(),
1723 &func,
1724 )
1725 .unwrap();
1726 }
1727
1728 #[test]
1729 fn degen_boundary() {
1730 let (hugr, func_root) = build_hugr().unwrap();
1731 let func = hugr.with_entrypoint(func_root);
1732 let [inp, _] = hugr.get_io(func_root).unwrap();
1733 let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
1734 assert_matches!(
1736 SiblingSubgraph::try_new(
1737 vec![
1738 hugr.linked_ports(inp, first_cx_edge)
1739 .map(|(n, p)| (n, p.as_incoming().unwrap()))
1740 .collect()
1741 ],
1742 vec![(inp, first_cx_edge)],
1743 &func,
1744 ),
1745 Err(InvalidSubgraph::InvalidBoundary(
1746 InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
1747 ))
1748 );
1749 }
1750
1751 #[test]
1752 fn non_convex_subgraph() {
1753 let (hugr, func_root) = build_3not_hugr().unwrap();
1754 let func = hugr.with_entrypoint(func_root);
1755 let [inp, _out] = hugr.get_io(func_root).unwrap();
1756 let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
1757 let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
1758 let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();
1759 let not1_inp = hugr.node_inputs(not1).next().unwrap();
1760 let not1_out = hugr.node_outputs(not1).next().unwrap();
1761 let not3_inp = hugr.node_inputs(not3).next().unwrap();
1762 let not3_out = hugr.node_outputs(not3).next().unwrap();
1763 assert_matches!(
1764 SiblingSubgraph::try_new(
1765 vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
1766 vec![(not1, not1_out), (not3, not3_out)],
1767 &func
1768 ),
1769 Err(InvalidSubgraph::NotConvex)
1770 );
1771 }
1772
1773 #[test]
1776 fn convex_multiports() {
1777 let (hugr, func_root) = build_multiport_hugr().unwrap();
1778 let [inp, out] = hugr.get_io(func_root).unwrap();
1779 let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
1780 let not2 = hugr
1781 .output_neighbours(not1)
1782 .filter(|&n| n != out)
1783 .exactly_one()
1784 .ok()
1785 .unwrap();
1786
1787 let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
1788 assert_eq!(subgraph.nodes(), [not1, not2]);
1789 }
1790
1791 #[test]
1792 fn invalid_boundary() {
1793 let (hugr, func_root) = build_hugr().unwrap();
1794 let func = hugr.with_entrypoint(func_root);
1795 let [inp, out] = hugr.get_io(func_root).unwrap();
1796 let cx_edges_in = hugr.node_outputs(inp);
1797 let cx_edges_out = hugr.node_inputs(out);
1798 assert_matches!(
1800 SiblingSubgraph::try_new(
1801 cx_edges_out.map(|p| vec![(out, p)]).collect(),
1802 cx_edges_in.map(|p| (inp, p)).collect(),
1803 &func,
1804 ),
1805 Err(InvalidSubgraph::InvalidBoundary(
1806 InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
1807 ))
1808 );
1809 }
1810
1811 #[test]
1812 fn preserve_signature() {
1813 let (hugr, func_root) = build_hugr_classical().unwrap();
1814 let func_graph = hugr.with_entrypoint(func_root);
1815 let func =
1816 SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();
1817 let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
1818 assert_eq!(func_defn.signature(), &func.signature(&func_graph).into());
1819 }
1820
1821 #[test]
1822 fn extract_subgraph() {
1823 let (hugr, func_root) = build_hugr().unwrap();
1824 let func_graph = hugr.with_entrypoint(func_root);
1825 let subgraph =
1826 SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();
1827 let extracted = subgraph.extract_subgraph(&hugr, "region");
1828
1829 extracted.validate().unwrap();
1830 }
1831
1832 #[test]
1833 fn edge_both_output_and_copy() {
1834 let one_bit = vec![bool_t()];
1836 let two_bit = vec![bool_t(), bool_t()];
1837
1838 let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap();
1839 let inw = builder.input_wires().exactly_one().unwrap();
1840 let outw1 = builder
1841 .add_dataflow_op(LogicOp::Not, [inw])
1842 .unwrap()
1843 .out_wire(0);
1844 let outw2 = builder
1845 .add_dataflow_op(and_op(), [inw, outw1])
1846 .unwrap()
1847 .outputs();
1848 let outw = [outw1].into_iter().chain(outw2);
1849 let h = builder.finish_hugr_with_outputs(outw).unwrap();
1850 let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&h).unwrap();
1851 assert_eq!(subg.nodes().len(), 2);
1852 }
1853
1854 #[test]
1855 fn test_unconnected() {
1856 let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap();
1858 let inw = b.input_wires().exactly_one().unwrap();
1859 let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
1860 let mut h = b.finish_hugr_with_outputs([]).unwrap();
1862
1863 let subg = SiblingSubgraph::from_node(not_n.node(), &h);
1864
1865 assert_eq!(subg.nodes().len(), 1);
1866 let replacement = {
1868 let mut rep_b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap();
1869 let inw = rep_b.input_wires().exactly_one().unwrap();
1870
1871 let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
1872
1873 rep_b.finish_hugr_with_outputs(not_n.outputs()).unwrap()
1874 };
1875 let rep = subg.create_simple_replacement(&h, replacement).unwrap();
1876 rep.apply(&mut h).unwrap();
1877 }
1878
1879 #[test]
1882 fn single_node_subgraph() {
1883 let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap();
1885 let inw = b.input_wires().exactly_one().unwrap();
1886 let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
1887 let h = b.finish_hugr_with_outputs([]).unwrap();
1889
1890 let subg = SiblingSubgraph::from_node(not_n.node(), &h);
1893 assert_eq!(subg.nodes().len(), 1);
1894 assert_eq!(
1895 subg.signature(&h).io(),
1896 Signature::new(vec![bool_t()], vec![bool_t()]).io()
1897 );
1898
1899 let subg = SiblingSubgraph::try_from_nodes([not_n.node()], &h).unwrap();
1903 assert_eq!(subg.nodes().len(), 1);
1904 assert_eq!(
1905 subg.signature(&h).io(),
1906 Signature::new(vec![bool_t()], vec![]).io()
1907 );
1908 }
1909
1910 #[test]
1913 fn singleton_disconnected_subgraph() {
1914 let op = MakeTuple::new(type_row![]);
1916
1917 let mut b = DFGBuilder::new(Signature::new_endo(type_row![])).unwrap();
1918 let _mk_tuple_1 = b.add_dataflow_op(op.clone(), []).unwrap();
1919 let mk_tuple_2 = b.add_dataflow_op(op.clone(), []).unwrap();
1920 let _mk_tuple_3 = b.add_dataflow_op(op, []).unwrap();
1921 let h = b.finish_hugr_with_outputs([]).unwrap();
1923
1924 let subg = SiblingSubgraph::from_node(mk_tuple_2.node(), &h);
1927 assert_eq!(subg.nodes().len(), 1);
1928 assert_eq!(
1929 subg.signature(&h).io(),
1930 Signature::new(type_row![], vec![Type::new_tuple(type_row![])]).io()
1931 );
1932
1933 let subg = SiblingSubgraph::try_from_nodes([mk_tuple_2.node()], &h).unwrap();
1937 assert_eq!(subg.nodes().len(), 1);
1938 assert_eq!(
1939 subg.signature(&h).io(),
1940 Signature::new_endo(type_row![]).io()
1941 );
1942 }
1943
1944 #[test]
1946 fn partially_connected_subgraph() {
1947 let tuple_op = MakeTuple::new(type_row![]);
1949 let untuple_op = UnpackTuple::new(type_row![]);
1950 let tuple_t = Type::new_tuple(type_row![]);
1951
1952 let mut b = DFGBuilder::new(Signature::new(type_row![], vec![tuple_t.clone()])).unwrap();
1953 let mk_tuple_1 = b.add_dataflow_op(tuple_op.clone(), []).unwrap();
1954 let untuple_1 = b
1955 .add_dataflow_op(untuple_op.clone(), [mk_tuple_1.out_wire(0)])
1956 .unwrap();
1957 let mk_tuple_2 = b.add_dataflow_op(tuple_op.clone(), []).unwrap();
1958 let _mk_tuple_3 = b.add_dataflow_op(tuple_op, []).unwrap();
1959 let h = b
1961 .finish_hugr_with_outputs([mk_tuple_2.out_wire(0)])
1962 .unwrap();
1963
1964 let subgraph_nodes = [mk_tuple_1.node(), mk_tuple_2.node(), untuple_1.node()];
1965
1966 let subg = SiblingSubgraph::try_from_nodes(subgraph_nodes, &h).unwrap();
1968 assert_eq!(subg.nodes().len(), 3);
1969 assert_eq!(
1970 subg.signature(&h).io(),
1971 Signature::new(type_row![], vec![tuple_t]).io()
1972 );
1973 }
1974
1975 #[test]
1976 fn test_set_outgoing_ports() {
1977 let (hugr, func_root) = build_3not_hugr().unwrap();
1978 let [inp, out] = hugr.get_io(func_root).unwrap();
1979 let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
1980 let not1_out = hugr.node_outputs(not1).next().unwrap();
1981
1982 let mut subgraph = SiblingSubgraph::from_node(not1, &hugr);
1984
1985 assert_eq!(subgraph.outgoing_ports().len(), 1);
1987
1988 let new_outputs = vec![(not1, not1_out), (not1, not1_out)];
1990 assert!(subgraph.set_outgoing_ports(new_outputs, &hugr).is_ok());
1991
1992 assert_eq!(subgraph.outgoing_ports().len(), 2);
1994
1995 let invalid_outputs = vec![(not1, not1_out), (out, 2.into())];
1997 assert!(matches!(
1998 subgraph.set_outgoing_ports(invalid_outputs, &hugr),
1999 Err(InvalidOutputPorts::UnknownOutput { .. })
2000 ));
2001
2002 assert_eq!(subgraph.outgoing_ports().len(), 2);
2004 }
2005
2006 #[test]
2007 fn test_set_outgoing_ports_linear() {
2008 let (hugr, func_root) = build_hugr().unwrap();
2009 let [inp, _out] = hugr.get_io(func_root).unwrap();
2010 let rz = hugr.output_neighbours(inp).nth(2).unwrap();
2011 let rz_out = hugr.node_outputs(rz).next().unwrap();
2012
2013 let mut subgraph = SiblingSubgraph::from_node(rz, &hugr);
2015
2016 assert_eq!(subgraph.outgoing_ports().len(), 1);
2018
2019 let new_outputs = vec![(rz, rz_out), (rz, rz_out)];
2022 assert!(matches!(
2023 subgraph.set_outgoing_ports(new_outputs, &hugr),
2024 Err(InvalidOutputPorts::NonUniqueLinear)
2025 ));
2026
2027 assert_eq!(subgraph.outgoing_ports().len(), 1);
2029 }
2030
2031 #[test]
2032 fn test_try_from_nodes_with_intervals() {
2033 let (hugr, func_root) = build_3not_hugr().unwrap();
2034 let line_checker = LineConvexChecker::new(&hugr, func_root);
2035 let [inp, _out] = hugr.get_io(func_root).unwrap();
2036 let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
2037 let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
2038
2039 let intervals = line_checker.get_intervals_from_nodes([not1, not2]).unwrap();
2040 let subgraph =
2041 SiblingSubgraph::try_from_nodes_with_intervals([not1, not2], &intervals, &line_checker)
2042 .unwrap();
2043 let exp_subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
2044
2045 assert_eq!(subgraph, exp_subgraph);
2046 assert_eq!(
2047 line_checker.nodes_in_intervals(&intervals).collect_vec(),
2048 [not1, not2]
2049 );
2050
2051 let intervals2 = line_checker
2052 .get_intervals_from_boundary_ports([
2053 (not1, IncomingPort::from(0).into()),
2054 (not2, OutgoingPort::from(0).into()),
2055 ])
2056 .unwrap();
2057 let subgraph2 = SiblingSubgraph::try_from_nodes_with_intervals(
2058 [not1, not2],
2059 &intervals2,
2060 &line_checker,
2061 )
2062 .unwrap();
2063 assert_eq!(subgraph2, exp_subgraph);
2064 }
2065
2066 #[test]
2067 fn test_validate() {
2068 let (hugr, func_root) = build_3not_hugr().unwrap();
2069 let func = hugr.with_entrypoint(func_root);
2070 let checker = TopoConvexChecker::new(&func, func_root);
2071 let [inp, _out] = hugr.get_io(func_root).unwrap();
2072 let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
2073 let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
2074 let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();
2075
2076 let sub = SiblingSubgraph::new_unchecked(
2078 vec![vec![(not1, 0.into())]],
2079 vec![(not2, 0.into())],
2080 vec![],
2081 vec![not1, not2],
2082 );
2083 assert_eq!(sub.validate(&func, ValidationMode::SkipConvexity), Ok(()));
2084 assert_eq!(sub.validate(&func, ValidationMode::CheckConvexity), Ok(()));
2085 assert_eq!(
2086 sub.validate(&func, ValidationMode::WithChecker(&checker)),
2087 Ok(())
2088 );
2089
2090 let sub = SiblingSubgraph::new_unchecked(
2092 vec![vec![(not1, 0.into())], vec![(not3, 0.into())]],
2093 vec![(not1, 0.into()), (not3, 0.into())],
2094 vec![],
2095 vec![not1, not3],
2096 );
2097 assert_eq!(sub.validate(&func, ValidationMode::SkipConvexity), Ok(()));
2098 assert_eq!(
2099 sub.validate(&func, ValidationMode::CheckConvexity),
2100 Err(InvalidSubgraph::NotConvex)
2101 );
2102 assert_eq!(
2103 sub.validate(&func, ValidationMode::WithChecker(&checker)),
2104 Err(InvalidSubgraph::NotConvex)
2105 );
2106
2107 let sub = SiblingSubgraph::new_unchecked(
2109 vec![vec![(not1, 0.into())]],
2110 vec![(not1, 0.into()), (not3, 0.into())],
2111 vec![],
2112 vec![not1, not3],
2113 );
2114 assert_eq!(
2115 sub.validate(&func, ValidationMode::SkipConvexity),
2116 Err(InvalidSubgraph::InvalidNodeSet)
2117 );
2118 }
2119
2120 #[fixture]
2121 pub(crate) fn hugr_call_subgraph() -> Hugr {
2122 let mut builder = ModuleBuilder::new();
2123 let decl_node = builder.declare("test", endo_sig(bool_t()).into()).unwrap();
2124 let mut main = builder.define_function("main", endo_sig(bool_t())).unwrap();
2125 let [bool] = main.input_wires_arr();
2126
2127 let [bool] = main
2128 .add_dataflow_op(LogicOp::Not, [bool])
2129 .unwrap()
2130 .outputs_arr();
2131
2132 let [bool] = main.call(&decl_node, &[], [bool]).unwrap().outputs_arr();
2134 let [bool] = main.call(&decl_node, &[], [bool]).unwrap().outputs_arr();
2135
2136 let main_def = main.finish_with_outputs([bool]).unwrap();
2137
2138 let mut hugr = builder.finish_hugr().unwrap();
2139 hugr.set_entrypoint(main_def.node());
2140 hugr
2141 }
2142
2143 #[rstest]
2144 fn test_call_subgraph_from_dfg(hugr_call_subgraph: Hugr) {
2145 let subg =
2146 SiblingSubgraph::try_new_dataflow_subgraph::<_, DataflowParentID>(&hugr_call_subgraph)
2147 .unwrap();
2148
2149 assert_eq!(subg.function_calls.len(), 1);
2150 assert_eq!(subg.function_calls[0].len(), 2);
2151 }
2152
2153 #[rstest]
2154 fn test_call_subgraph_from_nodes(hugr_call_subgraph: Hugr) {
2155 let call_nodes = hugr_call_subgraph
2156 .children(hugr_call_subgraph.entrypoint())
2157 .filter(|&n| hugr_call_subgraph.get_optype(n).is_call())
2158 .collect_vec();
2159
2160 let subg =
2161 SiblingSubgraph::try_from_nodes(call_nodes.clone(), &hugr_call_subgraph).unwrap();
2162 assert_eq!(subg.function_calls.len(), 1);
2163 assert_eq!(subg.function_calls[0].len(), 2);
2164
2165 let subg =
2166 SiblingSubgraph::try_from_nodes(call_nodes[0..1].to_owned(), &hugr_call_subgraph)
2167 .unwrap();
2168 assert_eq!(subg.function_calls.len(), 1);
2169 assert_eq!(subg.function_calls[0].len(), 1);
2170 }
2171
2172 #[rstest]
2173 fn test_call_subgraph_from_boundary(hugr_call_subgraph: Hugr) {
2174 let call_nodes = hugr_call_subgraph
2175 .children(hugr_call_subgraph.entrypoint())
2176 .filter(|&n| hugr_call_subgraph.get_optype(n).is_call())
2177 .collect_vec();
2178 let not_node = hugr_call_subgraph
2179 .children(hugr_call_subgraph.entrypoint())
2180 .filter(|&n| hugr_call_subgraph.get_optype(n) == &LogicOp::Not.into())
2181 .exactly_one()
2182 .ok()
2183 .unwrap();
2184
2185 let subg = SiblingSubgraph::try_new(
2186 vec![
2187 vec![(not_node, IncomingPort::from(0))],
2188 call_nodes
2189 .iter()
2190 .map(|&n| (n, IncomingPort::from(1)))
2191 .collect_vec(),
2192 ],
2193 vec![(call_nodes[1], OutgoingPort::from(0))],
2194 &hugr_call_subgraph,
2195 )
2196 .unwrap();
2197
2198 assert_eq!(subg.function_calls.len(), 1);
2199 assert_eq!(subg.function_calls[0].len(), 2);
2200 }
2201}