1use std::cell::OnceCell;
8use std::collections::HashSet;
9use std::mem;
10
11use itertools::Itertools;
12use portgraph::LinkView;
13use portgraph::PortView;
14use portgraph::algorithms::CreateConvexChecker;
15use portgraph::algorithms::convex::{LineIndex, LineIntervals, Position};
16use portgraph::boundary::Boundary;
17use rustc_hash::FxHashSet;
18use thiserror::Error;
19
20use crate::builder::{Container, FunctionBuilder};
21use crate::core::HugrNode;
22use crate::hugr::internal::{HugrInternals, PortgraphNodeMap};
23use crate::hugr::{HugrMut, HugrView};
24use crate::ops::dataflow::DataflowOpTrait;
25use crate::ops::handle::{ContainerHandle, DataflowOpID};
26use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
27use crate::types::{Signature, Type};
28use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
29
30use super::root_checked::RootCheckable;
31
/// A subgraph of a HUGR whose nodes all share the same parent.
///
/// The subgraph is described by its node set together with its boundary: the
/// incoming ports through which values enter it (grouped into partitions) and
/// the outgoing ports through which values leave it. Non-local static edges
/// used for function calls are stored separately from the value boundary.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct SiblingSubgraph<N = Node> {
    /// The nodes of the subgraph.
    nodes: Vec<N>,
    /// The incoming boundary ports, grouped into partitions.
    ///
    /// Every port in a partition is linked to the same outgoing port. The
    /// partitions correspond, in order, to the inputs of [`Self::signature`].
    inputs: IncomingPorts<N>,
    /// The outgoing boundary ports, in the order of the outputs of
    /// [`Self::signature`].
    outputs: OutgoingPorts<N>,
    /// Static incoming ports of function-call edges, grouped by the outgoing
    /// port (function) they are linked to.
    function_calls: IncomingPorts<N>,
}
89
/// A list of partitions of incoming ports.
///
/// In a [`SiblingSubgraph`] input boundary, all ports in one partition are
/// linked to the same outgoing port.
pub type IncomingPorts<N = Node> = Vec<Vec<(N, IncomingPort)>>;
/// A list of outgoing ports, e.g. the output boundary of a [`SiblingSubgraph`].
pub type OutgoingPorts<N = Node> = Vec<(N, OutgoingPort)>;
99
100impl<N: HugrNode> SiblingSubgraph<N> {
101 pub fn try_new_dataflow_subgraph<'h, H, Root>(
113 dfg_graph: impl RootCheckable<&'h H, Root>,
114 ) -> Result<Self, InvalidSubgraph<N>>
115 where
116 H: 'h + Clone + HugrView<Node = N>,
117 Root: ContainerHandle<N, ChildrenHandle = DataflowOpID>,
118 {
119 let Ok(dfg_graph) = dfg_graph.try_into_checked() else {
120 return Err(InvalidSubgraph::NonDataflowRegion);
121 };
122 let dfg_graph = dfg_graph.into_hugr();
123
124 let parent = HugrView::entrypoint(&dfg_graph);
125 let nodes = dfg_graph.children(parent).skip(2).collect_vec();
126
127 if nodes.is_empty() {
128 return Err(InvalidSubgraph::EmptySubgraph);
129 }
130
131 let (inputs, outputs) = get_input_output_ports(dfg_graph)?;
132 let non_local = get_non_local_edges(&nodes, &dfg_graph);
133 let function_calls = group_into_function_calls(non_local, &dfg_graph)?;
134
135 validate_subgraph(dfg_graph, &nodes, &inputs, &outputs, &function_calls)?;
136
137 Ok(Self {
138 nodes,
139 inputs,
140 outputs,
141 function_calls,
142 })
143 }
144
145 pub fn try_new(
193 inputs: IncomingPorts<N>,
194 outputs: OutgoingPorts<N>,
195 hugr: &impl HugrView<Node = N>,
196 ) -> Result<Self, InvalidSubgraph<N>> {
197 let parent = pick_parent(hugr, &inputs, &outputs)?;
198 let checker = TopoConvexChecker::new(hugr, parent);
199 Self::try_new_with_checker(inputs, outputs, hugr, &checker)
200 }
201
202 pub fn new_unchecked(
219 inputs: IncomingPorts<N>,
220 outputs: OutgoingPorts<N>,
221 function_calls: IncomingPorts<N>,
222 nodes: Vec<N>,
223 ) -> Self {
224 Self {
225 nodes,
226 inputs,
227 outputs,
228 function_calls,
229 }
230 }
231
232 pub fn try_new_with_checker<H: HugrView<Node = N>>(
242 mut inputs: IncomingPorts<N>,
243 outputs: OutgoingPorts<N>,
244 hugr: &H,
245 checker: &TopoConvexChecker<H>,
246 ) -> Result<Self, InvalidSubgraph<N>> {
247 let (subpg, node_map) = make_pg_subgraph(hugr, &inputs, &outputs);
248 let nodes = subpg
249 .nodes_iter()
250 .map(|index| node_map.from_portgraph(index))
251 .collect_vec();
252
253 let function_calls = drain_function_calls(&mut inputs, hugr);
254
255 validate_subgraph(hugr, &nodes, &inputs, &outputs, &function_calls)?;
256
257 if nodes.len() > 1 && !subpg.is_convex_with_checker(checker) {
258 return Err(InvalidSubgraph::NotConvex);
259 }
260
261 Ok(Self {
262 nodes,
263 inputs,
264 outputs,
265 function_calls,
266 })
267 }
268
269 pub fn try_from_nodes(
285 nodes: impl Into<Vec<N>>,
286 hugr: &impl HugrView<Node = N>,
287 ) -> Result<Self, InvalidSubgraph<N>> {
288 let nodes = nodes.into();
289 let Some(node) = nodes.first() else {
290 return Err(InvalidSubgraph::EmptySubgraph);
291 };
292 let parent = hugr
293 .get_parent(*node)
294 .ok_or(InvalidSubgraph::OrphanNode { orphan: *node })?;
295
296 let checker = TopoConvexChecker::new(hugr, parent);
297 Self::try_from_nodes_with_checker(nodes, hugr, &checker)
298 }
299
300 pub fn try_from_nodes_with_checker<H: HugrView<Node = N>>(
310 nodes: impl Into<Vec<N>>,
311 hugr: &H,
312 checker: &TopoConvexChecker<H>,
313 ) -> Result<Self, InvalidSubgraph<N>> {
314 let mut nodes: Vec<N> = nodes.into();
315 let num_nodes = nodes.len();
316
317 if nodes.is_empty() {
318 return Err(InvalidSubgraph::EmptySubgraph);
319 }
320
321 let (inputs, outputs) = get_boundary_from_nodes(hugr, &mut nodes);
322
323 if inputs.is_empty() && outputs.is_empty() {
326 return Ok(Self {
327 nodes,
328 inputs,
329 outputs,
330 function_calls: vec![],
331 });
332 }
333
334 let mut subgraph = Self::try_new_with_checker(inputs, outputs, hugr, checker)?;
335
336 if subgraph.node_count() < num_nodes {
339 subgraph.nodes = nodes;
340 }
341
342 Ok(subgraph)
343 }
344
345 pub fn try_from_nodes_with_intervals(
359 nodes: impl Into<Vec<N>>,
360 intervals: &LineIntervals,
361 line_checker: &LineConvexChecker<impl HugrView<Node = N>>,
362 ) -> Result<Self, InvalidSubgraph<N>> {
363 if !line_checker.get_checker().is_convex_by_intervals(intervals) {
364 return Err(InvalidSubgraph::NotConvex);
365 }
366
367 let nodes: Vec<N> = nodes.into();
368 let hugr = line_checker.hugr();
369
370 if nodes.is_empty() {
371 return Err(InvalidSubgraph::EmptySubgraph);
372 }
373
374 let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
375 let incoming_edges = nodes
376 .iter()
377 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
378 let outgoing_edges = nodes
379 .iter()
380 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
381 let mut inputs = incoming_edges
382 .filter(|&(n, p)| {
383 if !hugr.is_linked(n, p) {
384 return false;
385 }
386 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
387 !nodes_set.contains(&out_n)
388 })
389 .map(|p| vec![p])
391 .collect_vec();
392 let outputs = outgoing_edges
393 .filter(|&(n, p)| {
394 hugr.linked_ports(n, p)
395 .any(|(n1, _)| !nodes_set.contains(&n1))
396 })
397 .collect_vec();
398 let function_calls = drain_function_calls(&mut inputs, hugr);
399
400 Ok(Self {
401 nodes,
402 inputs,
403 outputs,
404 function_calls,
405 })
406 }
407
408 pub fn from_node(node: N, hugr: &impl HugrView<Node = N>) -> Self {
416 let nodes = vec![node];
417 let mut inputs = hugr
418 .node_inputs(node)
419 .filter(|&p| hugr.is_linked(node, p))
420 .map(|p| vec![(node, p)])
421 .collect_vec();
422 let outputs = hugr
423 .node_outputs(node)
424 .filter_map(|p| {
425 {
427 hugr.is_linked(node, p)
428 || HugrView::get_optype(&hugr, node)
429 .port_kind(p)
430 .is_some_and(|k| k.is_value())
431 }
432 .then_some((node, p))
433 })
434 .collect_vec();
435 let function_calls = drain_function_calls(&mut inputs, hugr);
436
437 let state_order_at_input = hugr
438 .get_optype(node)
439 .other_output_port()
440 .is_some_and(|p| hugr.is_linked(node, p));
441 let state_order_at_output = hugr
442 .get_optype(node)
443 .other_input_port()
444 .is_some_and(|p| hugr.is_linked(node, p));
445 if state_order_at_input || state_order_at_output {
446 unimplemented!("Order edges in {node:?} not supported");
447 }
448
449 Self {
450 nodes,
451 inputs,
452 outputs,
453 function_calls,
454 }
455 }
456
457 pub fn validate<'h, H: HugrView<Node = N>>(
468 &self,
469 hugr: &'h H,
470 mode: ValidationMode<'_, 'h, H>,
471 ) -> Result<(), InvalidSubgraph<N>> {
472 let mut exp_nodes = {
473 let (subpg, node_map) = make_pg_subgraph(hugr, &self.inputs, &self.outputs);
474 subpg
475 .nodes_iter()
476 .map(|n| node_map.from_portgraph(n))
477 .collect_vec()
478 };
479 let mut nodes = self.nodes.clone();
480
481 exp_nodes.sort_unstable();
482 nodes.sort_unstable();
483
484 if exp_nodes != nodes {
485 return Err(InvalidSubgraph::InvalidNodeSet);
486 }
487
488 validate_subgraph(
489 hugr,
490 &self.nodes,
491 &self.inputs,
492 &self.outputs,
493 &self.function_calls,
494 )?;
495
496 let checker;
497 let checker_ref = match mode {
498 ValidationMode::WithChecker(c) => Some(c),
499 ValidationMode::CheckConvexity => {
500 checker = ConvexChecker::new(hugr, self.get_parent(hugr));
501 Some(&checker)
502 }
503 ValidationMode::SkipConvexity => None,
504 };
505 if let Some(checker) = checker_ref {
506 let (subpg, _) = make_pg_subgraph(hugr, &self.inputs, &self.outputs);
507 if !subpg.is_convex_with_checker(&checker.init_checker().0) {
508 return Err(InvalidSubgraph::NotConvex);
509 }
510 }
511 Ok(())
512 }
513
514 #[must_use]
516 pub fn nodes(&self) -> &[N] {
517 &self.nodes
518 }
519
520 #[must_use]
522 pub fn node_count(&self) -> usize {
523 self.nodes.len()
524 }
525
526 #[must_use]
528 pub fn incoming_ports(&self) -> &IncomingPorts<N> {
529 &self.inputs
530 }
531
532 #[must_use]
534 pub fn outgoing_ports(&self) -> &OutgoingPorts<N> {
535 &self.outputs
536 }
537
538 #[must_use]
541 pub fn function_calls(&self) -> &IncomingPorts<N> {
542 &self.function_calls
543 }
544
545 pub fn signature(&self, hugr: &impl HugrView<Node = N>) -> Signature {
547 let input = self
548 .inputs
549 .iter()
550 .map(|part| {
551 let &(n, p) = part.iter().next().expect("is non-empty");
552 let sig = hugr.signature(n).expect("must have dataflow signature");
553 sig.port_type(p).cloned().expect("must be dataflow edge")
554 })
555 .collect_vec();
556 let output = self
557 .outputs
558 .iter()
559 .map(|&(n, p)| {
560 let sig = hugr.signature(n).expect("must have dataflow signature");
561 sig.port_type(p).cloned().expect("must be dataflow edge")
562 })
563 .collect_vec();
564 Signature::new(input, output)
565 }
566
567 pub fn get_parent(&self, hugr: &impl HugrView<Node = N>) -> N {
569 hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
570 }
571
572 pub(crate) fn map_nodes<N2: HugrNode>(
577 &self,
578 node_map: impl Fn(N) -> N2,
579 ) -> SiblingSubgraph<N2> {
580 let nodes = self.nodes.iter().map(|&n| node_map(n)).collect_vec();
581 let inputs = self
582 .inputs
583 .iter()
584 .map(|part| part.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
585 .collect_vec();
586 let outputs = self
587 .outputs
588 .iter()
589 .map(|&(n, p)| (node_map(n), p))
590 .collect_vec();
591 let function_calls = self
592 .function_calls
593 .iter()
594 .map(|calls| calls.iter().map(|&(n, p)| (node_map(n), p)).collect_vec())
595 .collect_vec();
596 SiblingSubgraph {
597 nodes,
598 inputs,
599 outputs,
600 function_calls,
601 }
602 }
603
604 pub fn create_simple_replacement(
620 &self,
621 hugr: &impl HugrView<Node = N>,
622 replacement: Hugr,
623 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
624 let rep_root = replacement.entrypoint();
625 let dfg_optype = replacement.get_optype(rep_root);
626 if !OpTag::DataflowParent.is_superset(dfg_optype.tag()) {
627 return Err(InvalidReplacement::InvalidDataflowGraph {
628 node: rep_root,
629 op: Box::new(dfg_optype.clone()),
630 });
631 }
632 let [rep_input, rep_output] = replacement
633 .get_io(rep_root)
634 .expect("DFG root in the replacement does not have input and output nodes.");
635
636 let state_order_at_input = replacement
639 .get_optype(rep_input)
640 .other_output_port()
641 .is_some_and(|p| replacement.is_linked(rep_input, p));
642 let state_order_at_output = replacement
643 .get_optype(rep_output)
644 .other_input_port()
645 .is_some_and(|p| replacement.is_linked(rep_output, p));
646 if state_order_at_input || state_order_at_output {
647 unimplemented!("Found state order edges in replacement graph");
648 }
649
650 SimpleReplacement::try_new(self.clone(), hugr, replacement)
651 }
652
653 pub fn extract_subgraph(
658 &self,
659 hugr: &impl HugrView<Node = N>,
660 name: impl Into<String>,
661 ) -> Hugr {
662 let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
663 let mut extracted = mem::take(builder.hugr_mut());
666 let node_map = extracted.insert_subgraph(extracted.entrypoint(), hugr, self);
667
668 let [inp, out] = extracted.get_io(extracted.entrypoint()).unwrap();
670 let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
671 let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
672 let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
673
674 for (inp_port, repl_ports) in inputs {
675 for (repl_node, repl_port) in repl_ports {
676 connections.push((inp, inp_port, node_map[repl_node], *repl_port));
677 }
678 }
679 for (out_port, (repl_node, repl_port)) in outputs {
680 connections.push((node_map[repl_node], *repl_port, out, out_port));
681 }
682
683 for (src, src_port, dst, dst_port) in connections {
684 extracted.connect(src, src_port, dst, dst_port);
685 }
686
687 extracted
688 }
689
690 pub fn set_outgoing_ports(
700 &mut self,
701 ports: OutgoingPorts<N>,
702 host: &impl HugrView<Node = N>,
703 ) -> Result<(), InvalidOutputPorts<N>> {
704 let old_boundary: HashSet<_> = iter_outgoing(&self.outputs).collect();
705
706 if let Some((node, port)) =
708 iter_outgoing(&ports).find(|(n, p)| !old_boundary.contains(&(*n, *p)))
709 {
710 return Err(InvalidOutputPorts::UnknownOutput { port, node });
711 }
712
713 if !has_unique_linear_ports(host, &ports) {
715 return Err(InvalidOutputPorts::NonUniqueLinear);
716 }
717
718 self.outputs = ports;
719 Ok(())
720 }
721}
722
/// How [`SiblingSubgraph::validate`] checks convexity.
#[derive(Default)]
pub enum ValidationMode<'t, 'h, H: HugrView> {
    /// Check convexity with the given checker.
    WithChecker(&'t TopoConvexChecker<'h, H>),
    /// Build a new checker for the subgraph's parent region and check
    /// convexity with it. This is the default.
    #[default]
    CheckConvexity,
    /// Do not check convexity.
    SkipConvexity,
}
734
735fn make_pg_subgraph<'h, H: HugrView>(
736 hugr: &'h H,
737 inputs: &IncomingPorts<H::Node>,
738 outputs: &OutgoingPorts<H::Node>,
739) -> (
740 portgraph::view::Subgraph<CheckerRegion<'h, H>>,
741 H::RegionPortgraphNodes,
742) {
743 let mut io_nodes = inputs
746 .iter()
747 .flat_map(|inps| inps.iter().map(|(n, _)| *n))
748 .chain(outputs.iter().map(|(n, _)| *n));
749 let hugr_region = io_nodes
750 .next()
751 .and_then(|n| hugr.get_parent(n))
752 .unwrap_or(hugr.entrypoint());
753
754 let (region, node_map) = hugr.region_portgraph(hugr_region);
755
756 let boundary = make_boundary::<H>(®ion, &node_map, inputs, outputs);
759 (
760 portgraph::view::Subgraph::new_subgraph(region, boundary),
761 node_map,
762 )
763}
764
765fn get_boundary_from_nodes<N: HugrNode>(
769 hugr: &impl HugrView<Node = N>,
770 nodes: &mut Vec<N>,
771) -> (IncomingPorts<N>, OutgoingPorts<N>) {
772 let mut nodes_set = FxHashSet::default();
775 nodes.retain(|&n| nodes_set.insert(n));
776
777 let incoming_edges = nodes
778 .iter()
779 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
780 let outgoing_edges = nodes
781 .iter()
782 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
783
784 let inputs = incoming_edges
785 .filter(|&(n, p)| {
786 if !hugr.is_linked(n, p) {
787 return false;
788 }
789 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
790 !nodes_set.contains(&out_n)
791 })
792 .map(|p| vec![p])
794 .collect_vec();
795 let outputs = outgoing_edges
796 .filter(|&(n, p)| {
797 hugr.linked_ports(n, p)
798 .any(|(n1, _)| !nodes_set.contains(&n1))
799 })
800 .collect_vec();
801
802 (inputs, outputs)
803}
804
805fn drain_function_calls<N: HugrNode, H: HugrView<Node = N>>(
807 inputs: &mut IncomingPorts<N>,
808 hugr: &H,
809) -> IncomingPorts<N> {
810 let mut function_calls = Vec::new();
811 inputs.retain_mut(|calls| {
812 let Some(&(n, p)) = calls.first() else {
813 return true;
814 };
815 let op = hugr.get_optype(n);
816 if op.static_input_port() == Some(p)
817 && op
818 .static_input()
819 .expect("static input exists")
820 .is_function()
821 {
822 function_calls.extend(mem::take(calls));
823 false
824 } else {
825 true
826 }
827 });
828
829 group_into_function_calls(function_calls.into_iter().map(|(n, p)| (n, p.into())), hugr)
830 .expect("valid function calls")
831}
832
833fn group_into_function_calls<N: HugrNode>(
838 ports: impl IntoIterator<Item = (N, Port)>,
839 hugr: &impl HugrView<Node = N>,
840) -> Result<Vec<Vec<(N, IncomingPort)>>, InvalidSubgraph<N>> {
841 let incoming_ports: Vec<_> = ports
842 .into_iter()
843 .map(|(n, p)| {
844 let p = p
845 .as_incoming()
846 .map_err(|_| InvalidSubgraph::UnsupportedEdgeKind(n, p))?;
847 let op = hugr.get_optype(n);
848 if Some(p) != op.static_input_port() {
849 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
850 }
851 if !op
852 .static_input()
853 .expect("static input exists")
854 .is_function()
855 {
856 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
857 }
858 Ok::<_, InvalidSubgraph<N>>((n, p))
859 })
860 .try_collect()?;
861 let grouped_non_local = incoming_ports
862 .into_iter()
863 .into_group_map_by(|&(n, p)| hugr.single_linked_output(n, p).expect("valid dfg wire"));
864 Ok(grouped_non_local
865 .into_iter()
866 .sorted_unstable_by(|(n1, _), (n2, _)| n1.cmp(n2))
867 .map(|(_, v)| v)
868 .collect())
869}
870
871fn get_non_local_edges<'a, N: HugrNode>(
875 nodes: &'a [N],
876 hugr: &'a impl HugrView<Node = N>,
877) -> impl Iterator<Item = (N, Port)> + 'a {
878 let parent = hugr.get_parent(nodes[0]);
879 let is_non_local = move |n, p| {
880 hugr.linked_ports(n, p)
881 .any(|(n, _)| hugr.get_parent(n) != parent)
882 };
883 nodes
884 .iter()
885 .flat_map(move |&n| hugr.all_node_ports(n).map(move |p| (n, p)))
886 .filter(move |&(n, p)| is_non_local(n, p))
887}
888
889fn iter_incoming<N: HugrNode>(
891 inputs: &IncomingPorts<N>,
892) -> impl Iterator<Item = (N, IncomingPort)> + '_ {
893 inputs.iter().flat_map(|part| part.iter().copied())
894}
895
896fn iter_outgoing<N: HugrNode>(
898 outputs: &OutgoingPorts<N>,
899) -> impl Iterator<Item = (N, OutgoingPort)> + '_ {
900 outputs.iter().copied()
901}
902
903fn iter_io<'a, N: HugrNode>(
905 inputs: &'a IncomingPorts<N>,
906 outputs: &'a OutgoingPorts<N>,
907) -> impl Iterator<Item = (N, Port)> + 'a {
908 iter_incoming(inputs)
909 .map(|(n, p)| (n, Port::from(p)))
910 .chain(iter_outgoing(outputs).map(|(n, p)| (n, Port::from(p))))
911}
912
913fn pick_parent<'a, N: HugrNode>(
923 hugr: &impl HugrView<Node = N>,
924 inputs: &'a IncomingPorts<N>,
925 outputs: &'a OutgoingPorts<N>,
926) -> Result<N, InvalidSubgraph<N>> {
927 let Some(node) = iter_incoming(inputs)
929 .map(|(n, _)| n)
930 .chain(iter_outgoing(outputs).map(|(n, _)| n))
931 .next()
932 else {
933 return Err(InvalidSubgraph::EmptySubgraph);
934 };
935 let Some(parent) = hugr.get_parent(node) else {
936 return Err(InvalidSubgraph::OrphanNode { orphan: node });
937 };
938
939 Ok(parent)
940}
941
942fn make_boundary<'a, H: HugrView>(
943 region: &impl LinkView<PortOffsetBase = u32>,
944 node_map: &H::RegionPortgraphNodes,
945 inputs: &'a IncomingPorts<H::Node>,
946 outputs: &'a OutgoingPorts<H::Node>,
947) -> Boundary {
948 let to_pg_index = |n: H::Node, p: Port| {
949 region
950 .port_index(node_map.to_portgraph(n), p.pg_offset())
951 .unwrap()
952 };
953 Boundary::new(
954 iter_incoming(inputs).map(|(n, p)| to_pg_index(n, p.into())),
955 iter_outgoing(outputs).map(|(n, p)| to_pg_index(n, p.into())),
956 )
957}
958
/// The flat portgraph view of a HUGR region on which the convexity checkers
/// operate.
type CheckerRegion<'g, Base> =
    portgraph::view::FlatRegion<'g, <Base as HugrInternals>::RegionPortgraph<'g>>;
961
/// A [`ConvexChecker`] backed by portgraph's `TopoConvexChecker` for a
/// region of a HUGR.
pub type TopoConvexChecker<'g, Base> =
    ConvexChecker<'g, Base, portgraph::algorithms::TopoConvexChecker<CheckerRegion<'g, Base>>>;
970
/// A [`ConvexChecker`] backed by portgraph's `LineConvexChecker` for a
/// region of a HUGR. It additionally exposes line-interval based queries.
pub type LineConvexChecker<'g, Base> =
    ConvexChecker<'g, Base, portgraph::algorithms::LineConvexChecker<CheckerRegion<'g, Base>>>;
980
/// A convexity checker for one region of a HUGR.
///
/// The underlying portgraph checker (`Checker`) and the map between HUGR and
/// portgraph nodes are built lazily on first use and then cached.
pub struct ConvexChecker<'g, Base: HugrView, Checker> {
    /// The HUGR containing the region.
    base: &'g Base,
    /// The parent node of the region being checked.
    region_parent: Base::Node,
    /// The lazily initialised checker and node map for the region.
    checker: OnceCell<(Checker, Base::RegionPortgraphNodes)>,
}
998
999impl<'g, Base: HugrView, Checker: Clone> Clone for ConvexChecker<'g, Base, Checker> {
1000 fn clone(&self) -> Self {
1001 Self {
1002 base: self.base,
1003 region_parent: self.region_parent,
1004 checker: self.checker.clone(),
1005 }
1006 }
1007}
1008
1009impl<'g, Base: HugrView, Checker> ConvexChecker<'g, Base, Checker> {
1010 pub fn new(base: &'g Base, region_parent: Base::Node) -> Self {
1012 Self {
1013 base,
1014 region_parent,
1015 checker: OnceCell::new(),
1016 }
1017 }
1018
1019 #[inline(always)]
1021 pub fn from_entrypoint(base: &'g Base) -> Self {
1022 let region_parent = base.entrypoint();
1023 Self::new(base, region_parent)
1024 }
1025
1026 pub fn hugr(&self) -> &'g Base {
1028 self.base
1029 }
1030}
1031
1032impl<'g, Base, Checker> ConvexChecker<'g, Base, Checker>
1033where
1034 Base: HugrView,
1035 Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
1036{
1037 fn init_checker(&self) -> &(Checker, Base::RegionPortgraphNodes) {
1039 self.checker.get_or_init(|| {
1040 let (region, node_map) = self.base.region_portgraph(self.region_parent);
1041 let checker = Checker::new_convex_checker(region);
1042 (checker, node_map)
1043 })
1044 }
1045
1046 #[expect(dead_code)]
1048 fn get_node_map(&self) -> &Base::RegionPortgraphNodes {
1049 &self.init_checker().1
1050 }
1051
1052 fn get_checker(&self) -> &Checker {
1054 &self.init_checker().0
1055 }
1056}
1057
1058impl<'g, Base, Checker> portgraph::algorithms::ConvexChecker for ConvexChecker<'g, Base, Checker>
1059where
1060 Base: HugrView,
1061 Checker: CreateConvexChecker<CheckerRegion<'g, Base>>,
1062{
1063 fn is_convex(
1064 &self,
1065 nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
1066 inputs: impl IntoIterator<Item = portgraph::PortIndex>,
1067 outputs: impl IntoIterator<Item = portgraph::PortIndex>,
1068 ) -> bool {
1069 let mut nodes = nodes.into_iter().multipeek();
1070 if nodes.peek().is_none() || nodes.peek().is_none() {
1073 return true;
1074 }
1075 self.get_checker().is_convex(nodes, inputs, outputs)
1076 }
1077}
1078
1079impl<'g, Base: HugrView> LineConvexChecker<'g, Base> {
1080 pub fn get_intervals_from_nodes(
1082 &self,
1083 nodes: impl IntoIterator<Item = Base::Node>,
1084 ) -> Option<LineIntervals> {
1085 let (checker, node_map) = self.init_checker();
1086 let nodes = nodes
1087 .into_iter()
1088 .map(|n| node_map.to_portgraph(n))
1089 .collect_vec();
1090 checker.get_intervals_from_nodes(nodes)
1091 }
1092
1093 pub fn get_intervals_from_boundary_ports(
1099 &self,
1100 ports: impl IntoIterator<Item = (Base::Node, Port)>,
1101 ) -> Option<LineIntervals> {
1102 let (checker, node_map) = self.init_checker();
1103 let ports = ports
1104 .into_iter()
1105 .map(|(n, p)| {
1106 let node = node_map.to_portgraph(n);
1107 checker
1108 .graph()
1109 .port_index(node, p.pg_offset())
1110 .expect("valid port")
1111 })
1112 .collect_vec();
1113 checker.get_intervals_from_boundary_ports(ports)
1114 }
1115
1116 pub fn nodes_in_intervals<'a>(
1118 &'a self,
1119 intervals: &'a LineIntervals,
1120 ) -> impl Iterator<Item = Base::Node> + 'a {
1121 let (checker, node_map) = self.init_checker();
1122 checker
1123 .nodes_in_intervals(intervals)
1124 .map(|pg_node| node_map.from_portgraph(pg_node))
1125 }
1126
1127 pub fn lines_at_port(&self, node: Base::Node, port: impl Into<Port>) -> &[LineIndex] {
1129 let (checker, node_map) = self.init_checker();
1130 let port = checker
1131 .graph()
1132 .port_index(node_map.to_portgraph(node), port.into().pg_offset())
1133 .expect("valid port");
1134 checker.lines_at_port(port)
1135 }
1136
1137 pub fn try_extend_intervals(&self, intervals: &mut LineIntervals, node: Base::Node) -> bool {
1145 let (checker, node_map) = self.init_checker();
1146 let node = node_map.to_portgraph(node);
1147 checker.try_extend_intervals(intervals, node)
1148 }
1149
1150 pub fn get_position(&self, node: Base::Node) -> Position {
1152 let (checker, node_map) = self.init_checker();
1153 let node = node_map.to_portgraph(node);
1154 checker.get_position(node)
1155 }
1156}
1157
1158fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(
1162 hugr: &H,
1163 ports: &[(H::Node, P)],
1164) -> Option<Type> {
1165 let &(n, p) = ports.first()?;
1166 let edge_t = hugr.signature(n)?.port_type(p)?.clone();
1167 ports
1168 .iter()
1169 .all(|&(n, p)| {
1170 hugr.signature(n)
1171 .is_some_and(|s| s.port_type(p) == Some(&edge_t))
1172 })
1173 .then_some(edge_t)
1174}
1175
1176fn validate_subgraph<H: HugrView>(
1183 hugr: &H,
1184 nodes: &[H::Node],
1185 inputs: &IncomingPorts<H::Node>,
1186 outputs: &OutgoingPorts<H::Node>,
1187 function_calls: &IncomingPorts<H::Node>,
1188) -> Result<(), InvalidSubgraph<H::Node>> {
1189 let node_set = nodes.iter().copied().collect::<HashSet<_>>();
1191
1192 if nodes.is_empty() {
1194 return Err(InvalidSubgraph::EmptySubgraph);
1195 }
1196 if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() {
1198 let first_node = nodes[0];
1199 let first_parent = hugr
1200 .get_parent(first_node)
1201 .ok_or(InvalidSubgraph::OrphanNode { orphan: first_node })?;
1202 let other_node = *nodes
1203 .iter()
1204 .skip(1)
1205 .find(|&&n| hugr.get_parent(n) != Some(first_parent))
1206 .unwrap();
1207 let other_parent = hugr
1208 .get_parent(other_node)
1209 .ok_or(InvalidSubgraph::OrphanNode { orphan: other_node })?;
1210 return Err(InvalidSubgraph::NoSharedParent {
1211 first_node,
1212 first_parent,
1213 other_node,
1214 other_parent,
1215 });
1216 }
1217
1218 if let Some((n, p)) = iter_io(inputs, outputs).find(|&(n, p)| is_non_value_edge(hugr, n, p)) {
1220 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p));
1221 }
1222
1223 let boundary_ports = iter_io(inputs, outputs).collect_vec();
1224 if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
1226 Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
1227 }
1228 if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
1230 hugr.linked_ports(n, p)
1231 .all(|(n1, _)| node_set.contains(&n1))
1232 }) {
1233 Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
1234 }
1235
1236 let mut must_be_inputs = nodes
1239 .iter()
1240 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)))
1241 .filter(|&(n, p)| {
1242 hugr.linked_ports(n, p)
1243 .any(|(n1, _)| !node_set.contains(&n1))
1244 });
1245 if !must_be_inputs.all(|(n, p)| {
1246 let mut all_inputs = inputs.iter().chain(function_calls);
1247 all_inputs.any(|nps| nps.contains(&(n, p)))
1248 }) {
1249 return Err(InvalidSubgraph::NotConvex);
1250 }
1251 if nodes.iter().any(|&n| {
1254 hugr.node_outputs(n).any(|p| {
1255 hugr.linked_ports(n, p)
1256 .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
1257 })
1258 }) {
1259 return Err(InvalidSubgraph::NotConvex);
1260 }
1261
1262 if !inputs.iter().flatten().all_unique() {
1264 return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
1265 }
1266
1267 for inp in inputs {
1271 let &(in_node, in_port) = inp.first().ok_or(InvalidSubgraphBoundary::EmptyPartition)?;
1272 let exp_output_node_port = hugr
1273 .single_linked_output(in_node, in_port)
1274 .expect("valid dfg wire");
1275 if let Some(output_node_port) = inp
1276 .iter()
1277 .map(|&(in_node, in_port)| {
1278 hugr.single_linked_output(in_node, in_port)
1279 .expect("valid dfg wire")
1280 })
1281 .find(|&p| p != exp_output_node_port)
1282 {
1283 return Err(InvalidSubgraphBoundary::MismatchedOutputPort(
1284 (in_node, in_port),
1285 exp_output_node_port,
1286 output_node_port,
1287 )
1288 .into());
1289 }
1290 }
1291
1292 if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
1295 let Some(edge_t) = get_edge_type(hugr, ports) else {
1296 return true;
1297 };
1298 let require_copy = ports.len() > 1;
1299 require_copy && !edge_t.copyable()
1300 }) {
1301 Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
1302 }
1303
1304 for calls in function_calls {
1307 if !calls
1308 .iter()
1309 .map(|&(n, p)| hugr.single_linked_output(n, p))
1310 .all_equal()
1311 {
1312 let (n, p) = calls[0];
1313 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
1314 }
1315 for &(n, p) in calls {
1316 let op = hugr.get_optype(n);
1317 if op.static_input_port() != Some(p) {
1318 return Err(InvalidSubgraph::UnsupportedEdgeKind(n, p.into()));
1319 }
1320 }
1321 }
1322
1323 Ok(())
1324}
1325
1326#[allow(clippy::type_complexity)]
1327fn get_input_output_ports<H: HugrView>(
1328 hugr: &H,
1329) -> Result<(IncomingPorts<H::Node>, OutgoingPorts<H::Node>), InvalidSubgraph<H::Node>> {
1330 let [inp, out] = hugr
1331 .get_io(HugrView::entrypoint(&hugr))
1332 .expect("invalid DFG");
1333 if let Some(p) = hugr
1334 .node_outputs(inp)
1335 .find(|&p| is_non_value_edge(hugr, inp, p.into()))
1336 {
1337 return Err(InvalidSubgraph::UnsupportedEdgeKind(inp, p.into()));
1338 }
1339 if let Some(p) = hugr
1340 .node_inputs(out)
1341 .find(|&p| is_non_value_edge(hugr, out, p.into()))
1342 {
1343 return Err(InvalidSubgraph::UnsupportedEdgeKind(out, p.into()));
1344 }
1345
1346 let dfg_inputs = HugrView::get_optype(&hugr, inp)
1347 .as_input()
1348 .unwrap()
1349 .signature()
1350 .output_ports();
1351 let dfg_outputs = HugrView::get_optype(&hugr, out)
1352 .as_output()
1353 .unwrap()
1354 .signature()
1355 .input_ports();
1356
1357 let inputs = dfg_inputs
1360 .into_iter()
1361 .map(|p| {
1362 hugr.linked_inputs(inp, p)
1363 .filter(|&(n, _)| n != out)
1364 .collect_vec()
1365 })
1366 .filter(|v| !v.is_empty())
1367 .collect();
1368 let outputs = dfg_outputs
1371 .into_iter()
1372 .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
1373 .collect();
1374 Ok((inputs, outputs))
1375}
1376
1377fn is_non_value_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
1379 let op = hugr.get_optype(node);
1380 let is_other = op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port);
1381 let is_static = op.static_port(port.direction()) == Some(port) && hugr.is_linked(node, port);
1382 is_other || is_static
1383}
1384
1385impl<'a, 'c, G: HugrView, Checker: Clone> From<&'a ConvexChecker<'c, G, Checker>>
1386 for std::borrow::Cow<'a, ConvexChecker<'c, G, Checker>>
1387{
1388 fn from(value: &'a ConvexChecker<'c, G, Checker>) -> Self {
1389 Self::Borrowed(value)
1390 }
1391}
1392
/// Errors that can occur when building a [`SimpleReplacement`] from a
/// [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum InvalidReplacement {
    /// The entrypoint of the replacement is not a dataflow parent.
    #[error("The root of the replacement {node} is a {}, but only dataflow parents are supported.", op.name())]
    InvalidDataflowGraph {
        /// The entrypoint node of the replacement.
        node: Node,
        /// The operation at the entrypoint.
        op: Box<OpType>,
    },
    /// The replacement's signature differs from the expected one.
    #[error(
        "Replacement graph type mismatch. Expected {expected}, got {}.",
        actual.clone().map_or("none".to_string(), |t| t.to_string()))
    ]
    InvalidSignature {
        /// The expected signature.
        expected: Box<Signature>,
        /// The replacement's signature, if it has one.
        actual: Option<Box<Signature>>,
    },
    /// The subgraph being replaced is not convex.
    #[error("SiblingSubgraph is not convex.")]
    NonConvexSubgraph,
}
1420
/// Errors that can occur when constructing or validating a
/// [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraph<N: HugrNode = Node> {
    /// The subgraph is not convex, or a port linked outside it is missing
    /// from the boundary.
    #[error("The subgraph is not convex.")]
    NotConvex,
    /// Two nodes of the subgraph have different parents.
    #[error(
        "Not a sibling subgraph. {first_node} has parent {first_parent}, but {other_node} has parent {other_parent}."
    )]
    NoSharedParent {
        /// The first node of the subgraph.
        first_node: N,
        /// The parent of the first node.
        first_parent: N,
        /// A node whose parent differs from `first_parent`.
        other_node: N,
        /// The parent of `other_node`.
        other_parent: N,
    },
    /// A node of the subgraph has no parent.
    #[error("Not a sibling subgraph. {orphan} has no parent")]
    OrphanNode {
        /// The node without a parent.
        orphan: N,
    },
    /// The subgraph has no nodes (or no boundary to derive them from).
    #[error("Empty subgraphs are not supported.")]
    EmptySubgraph,
    /// The boundary of the subgraph is invalid.
    #[error("Invalid boundary port.")]
    InvalidBoundary(#[from] InvalidSubgraphBoundary<N>),
    /// The given graph is not a dataflow region.
    #[error("SiblingSubgraphs may only be defined on dataflow regions.")]
    NonDataflowRegion,
    /// The stored nodes differ from the nodes enclosed by the boundary.
    #[error("The subgraphs induced by the nodes and the boundary do not match.")]
    InvalidNodeSet,
    /// A port carries an edge kind the subgraph cannot represent, such as a
    /// linked order edge on the boundary or a non-function static edge.
    #[error("Unsupported edge kind at ({_0}, {_1:?}).")]
    UnsupportedEdgeKind(N, Port),
}
1464
/// Errors describing an invalid [`SiblingSubgraph`] boundary.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraphBoundary<N: HugrNode = Node> {
    /// A boundary port belongs to a node outside the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
    PortNodeNotInSet(N, Port),
    /// A boundary port is only linked to nodes inside the subgraph.
    #[error(
        "(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph."
    )]
    DisconnectedBoundaryPort(N, Port),
    /// An incoming port appears more than once in the input boundary.
    #[error("A port in the input boundary is used multiple times.")]
    NonUniqueInput,
    /// An input partition has no ports.
    #[error("A partition in the input boundary is empty.")]
    EmptyPartition,
    /// A port of an input partition is not linked to the same outgoing port
    /// as the rest of the partition. Fields: the offending incoming port, the
    /// expected outgoing port, and the outgoing port it is actually linked to.
    #[error("expected port {0:?} to be linked to {1:?}, but is linked to {2:?}.")]
    MismatchedOutputPort((N, IncomingPort), (N, OutgoingPort), (N, OutgoingPort)),
    /// The input partition at this index has ports of different types, or
    /// shares a non-copyable value between several ports.
    #[error("The partition {0} in the input boundary has ports with different types.")]
    MismatchedTypes(usize),
}
1491
1492#[derive(Debug, Clone, PartialEq, Eq, Error)]
1494#[error("Invalid output ports: {0:?}")]
1495pub enum InvalidOutputPorts<N: HugrNode = Node> {
1496 #[error("{port} in {node} was not part of the original boundary.")]
1498 UnknownOutput {
1499 port: OutgoingPort,
1501 node: N,
1503 },
1504 #[error("Linear ports must appear exactly once.")]
1506 NonUniqueLinear,
1507}
1508
1509fn has_unique_linear_ports<H: HugrView>(host: &H, ports: &OutgoingPorts<H::Node>) -> bool {
1511 let linear_ports: Vec<_> = ports
1512 .iter()
1513 .filter(|&&(n, p)| {
1514 host.get_optype(n)
1515 .port_kind(p)
1516 .is_some_and(|pk| pk.is_linear())
1517 })
1518 .collect();
1519 let unique_ports: HashSet<_> = linear_ports.iter().collect();
1520 linear_ports.len() == unique_ports.len()
1521}
1522
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use cool_asserts::assert_matches;
    use rstest::{fixture, rstest};

    use crate::builder::{endo_sig, inout_sig};
    use crate::extension::prelude::{MakeTuple, UnpackTuple};
    use crate::hugr::Patch;
    use crate::ops::Const;
    use crate::ops::handle::DataflowParentID;
    use crate::std_extensions::arithmetic::float_types::ConstF64;
    use crate::std_extensions::logic::LogicOp;
    use crate::type_row;
    use crate::utils::test_quantum_extension::{cx_gate, rz_f64};
    use crate::{
        builder::{
            BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::{bool_t, qb_t},
        ops::handle::{DfgID, FuncID, NodeHandle},
        std_extensions::logic::test::and_op,
    };

    use super::*;

    impl<N: HugrNode> SiblingSubgraph<N> {
        /// Test-only constructor: a subgraph made of *all* children of `parent`.
        ///
        /// Unlike `try_new_dataflow_subgraph`, no children are skipped. The
        /// input, output and function-call boundaries are left empty, and no
        /// validation is done beyond rejecting a childless `parent` with
        /// [`InvalidSubgraph::EmptySubgraph`].
        fn from_sibling_graph(
            hugr: &impl HugrView<Node = N>,
            parent: N,
        ) -> Result<Self, InvalidSubgraph<N>> {
            let nodes = hugr.children(parent).collect_vec();
            if nodes.is_empty() {
                Err(InvalidSubgraph::EmptySubgraph)
            } else {
                Ok(Self {
                    nodes,
                    inputs: Vec::new(),
                    outputs: Vec::new(),
                    function_calls: Vec::new(),
                })
            }
        }
    }

    /// Builds a module with one function `test : (qb, qb, qb) -> (qb, qb, qb)`.
    ///
    /// The body applies a CX gate to the first two qubits and an Rz gate to
    /// the third. The Rz angle comes from a loaded `ConstF64(0.5)`.
    /// Returns the HUGR and the function node.
    fn build_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [w0, w1, w2] = dfg.input_wires_arr();
            let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
            let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
            let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
            dfg.finish_with_outputs([w0, w1, w2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// Builds a module with one function `test : bool -> bool` whose body is
    /// a chain of three `Not` operations.
    fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare("test", Signature::new_endo(vec![bool_t()]).into())?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?;
            let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?;
            let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?;
            dfg.finish_with_outputs(outs3.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// Builds a module with one function `test : bool -> (bool, bool)`.
    ///
    /// It computes `b1 = not(b0)` and `b2 = not(b1)` and returns `(b1, b2)`.
    /// The first `Not`'s output therefore fans out to both the Output node
    /// and the second `Not`.
    fn build_multiport_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new(bool_t(), vec![bool_t(), bool_t()]).into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [b0] = dfg.input_wires_arr();
            let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr();
            let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr();
            dfg.finish_with_outputs([b1, b2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// Builds a module with one function `test : bool -> bool`.
    ///
    /// The single input wire is copied into both inputs of an `and` op.
    fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare("test", Signature::new_endo(bool_t()).into())?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let in_wire = dfg.input_wires().exactly_one().unwrap();
            let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?;
            dfg.finish_with_outputs(outs.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// Replaces a whole function body with an identity DFG and checks the
    /// node counts before and after applying the patch.
    #[test]
    fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
        let (mut hugr, func_root) = build_hugr().unwrap();
        let func = hugr.with_entrypoint(func_root);
        let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;
        assert!(sub.validate(&func, Default::default()).is_ok());

        // Identity on three qubits: wires its inputs straight to its outputs.
        let empty_dfg = {
            let builder =
                DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();

        // The replaced nodes: CX, Const, LoadConst and Rz.
        assert_eq!(rep.subgraph().nodes().len(), 4);

        // Module root + function + its Input/Output + the 4 body nodes.
        assert_eq!(hugr.num_nodes(), 8);
        hugr.apply_patch(rep).unwrap();
        // Only the module root, the function and its Input/Output remain.
        assert_eq!(hugr.num_nodes(), 4);
        Ok(())
    }

    /// A `Const` node and its `LoadConstant` form a valid subgraph on their own.
    #[test]
    fn construct_load_const_subgraph() -> Result<(), InvalidSubgraph> {
        let (hugr, func_root) = build_hugr().unwrap();

        let const_node = hugr
            .children(func_root)
            .find(|&n| hugr.get_optype(n).is_const())
            .unwrap();
        let load_const_node = hugr
            .children(func_root)
            .find(|&n| hugr.get_optype(n).is_load_constant())
            .unwrap();
        let nodes: BTreeSet<_> = BTreeSet::from_iter([const_node, load_const_node]);

        let sub = SiblingSubgraph::try_from_nodes(vec![const_node, load_const_node], &hugr)?;

        // Compare as sets: the subgraph's node order is not asserted here.
        let subgraph_nodes: BTreeSet<_> = sub.nodes().iter().copied().collect();
        assert_eq!(subgraph_nodes, nodes);

        Ok(())
    }

    /// The signature of a whole-function subgraph equals the function's
    /// three-qubit endomorphism.
    #[test]
    fn test_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func = hugr.with_entrypoint(dfg);
        let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)?;
        assert!(sub.validate(&func, Default::default()).is_ok());
        assert_eq!(
            sub.signature(&func),
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
        );
        Ok(())
    }

    /// A replacement whose signature differs from the subgraph's is rejected
    /// with `InvalidSignature`.
    #[test]
    fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func = hugr.with_entrypoint(dfg);
        // Empty boundaries, so the subgraph's signature has no ports.
        let sub = SiblingSubgraph::from_sibling_graph(&hugr, dfg)?;

        // One-qubit identity: does not match the subgraph's signature.
        let empty_dfg = {
            let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        assert_matches!(
            sub.create_simple_replacement(&func, empty_dfg).unwrap_err(),
            InvalidReplacement::InvalidSignature { .. }
        );
        Ok(())
    }

    /// The whole function body (excluding Input/Output) is a convex subgraph
    /// of 4 nodes.
    #[test]
    fn convex_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func = hugr.with_entrypoint(func_root);
        assert_eq!(
            SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func)
                .unwrap()
                .nodes()
                .len(),
            4
        );
    }

    /// A boundary built from the first two Input wires and the first two
    /// Output wires is accepted.
    ///
    /// In `build_hugr` these wires are exactly the CX gate's inputs and
    /// outputs.
    #[test]
    fn convex_subgraph_2() {
        let (hugr, func_root) = build_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let func = hugr.with_entrypoint(func_root);
        SiblingSubgraph::try_new(
            hugr.node_outputs(inp)
                .take(2)
                .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                .filter(|ps| !ps.is_empty())
                .collect(),
            hugr.node_inputs(out)
                .take(2)
                .filter_map(|p| hugr.single_linked_output(out, p))
                .collect(),
            &func,
        )
        .unwrap();
    }

    /// Using an outgoing port of the Input node (a node outside the
    /// subgraph) as an output-boundary port must be rejected.
    #[test]
    fn degen_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func = hugr.with_entrypoint(func_root);
        let [inp, _] = hugr.get_io(func_root).unwrap();
        let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![
                    hugr.linked_ports(inp, first_cx_edge)
                        .map(|(n, p)| (n, p.as_incoming().unwrap()))
                        .collect()
                ],
                vec![(inp, first_cx_edge)],
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }

    /// Selecting the first and third `Not` of a chain, without the middle
    /// one, is not convex.
    #[test]
    fn non_convex_subgraph() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let func = hugr.with_entrypoint(func_root);
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
        let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();
        let not1_inp = hugr.node_inputs(not1).next().unwrap();
        let not1_out = hugr.node_outputs(not1).next().unwrap();
        let not3_inp = hugr.node_inputs(not3).next().unwrap();
        let not3_out = hugr.node_outputs(not3).next().unwrap();
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
                vec![(not1, not1_out), (not3, not3_out)],
                &func
            ),
            Err(InvalidSubgraph::NotConvex)
        );
    }

    /// Two chained `Not`s form a convex subgraph even though the first
    /// `Not`'s output also leaves the subgraph (towards the Output node).
    #[test]
    fn convex_multiports() {
        let (hugr, func_root) = build_multiport_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr
            .output_neighbours(not1)
            .filter(|&n| n != out)
            .exactly_one()
            .ok()
            .unwrap();

        let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
        assert_eq!(subgraph.nodes(), [not1, not2]);
    }

    /// Swapping the boundaries (Output-node ports as inputs, Input-node ports
    /// as outputs) is rejected with `DisconnectedBoundaryPort`.
    #[test]
    fn invalid_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func = hugr.with_entrypoint(func_root);
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let cx_edges_in = hugr.node_outputs(inp);
        let cx_edges_out = hugr.node_inputs(out);
        assert_matches!(
            SiblingSubgraph::try_new(
                cx_edges_out.map(|p| vec![(out, p)]).collect(),
                cx_edges_in.map(|p| (inp, p)).collect(),
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }

    /// The signature of a whole-function subgraph matches the `FuncDefn`'s
    /// declared signature.
    ///
    /// The function copies its single input wire, so the input is used
    /// twice.
    #[test]
    fn preserve_signature() {
        let (hugr, func_root) = build_hugr_classical().unwrap();
        let func_graph = hugr.with_entrypoint(func_root);
        let func =
            SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();
        let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
        assert_eq!(func_defn.signature(), &func.signature(&func_graph).into());
    }

    /// Extracting a whole-function subgraph into its own HUGR yields a valid
    /// HUGR.
    #[test]
    fn extract_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func_graph = hugr.with_entrypoint(func_root);
        let subgraph =
            SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID<true>>(&func_graph).unwrap();
        let extracted = subgraph.extract_subgraph(&hugr, "region");

        extracted.validate().unwrap();
    }

    /// A wire that is both a region output and consumed inside the region
    /// (`outw1` feeds the Output node and the `and`) is handled correctly.
    #[test]
    fn edge_both_output_and_copy() {
        let one_bit = vec![bool_t()];
        let two_bit = vec![bool_t(), bool_t()];

        let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap();
        let inw = builder.input_wires().exactly_one().unwrap();
        let outw1 = builder
            .add_dataflow_op(LogicOp::Not, [inw])
            .unwrap()
            .out_wire(0);
        let outw2 = builder
            .add_dataflow_op(and_op(), [inw, outw1])
            .unwrap()
            .outputs();
        let outw = [outw1].into_iter().chain(outw2);
        let h = builder.finish_hugr_with_outputs(outw).unwrap();
        let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&h).unwrap();
        assert_eq!(subg.nodes().len(), 2);
    }

    /// A single-node subgraph whose output is unconnected can still be
    /// replaced by a `bool -> bool` DFG.
    #[test]
    fn test_unconnected() {
        let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap();
        let inw = b.input_wires().exactly_one().unwrap();
        let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
        // The `Not`'s output is left dangling.
        let mut h = b.finish_hugr_with_outputs([]).unwrap();

        let subg = SiblingSubgraph::from_node(not_n.node(), &h);

        assert_eq!(subg.nodes().len(), 1);
        let replacement = {
            let mut rep_b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap();
            let inw = rep_b.input_wires().exactly_one().unwrap();

            let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();

            rep_b.finish_hugr_with_outputs(not_n.outputs()).unwrap()
        };
        let rep = subg.create_simple_replacement(&h, replacement).unwrap();
        rep.apply(&mut h).unwrap();
    }

    /// `from_node` and `try_from_nodes` differ on unconnected ports.
    ///
    /// `from_node` keeps the dangling output in the signature;
    /// `try_from_nodes` omits it.
    #[test]
    fn single_node_subgraph() {
        let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap();
        let inw = b.input_wires().exactly_one().unwrap();
        let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
        let h = b.finish_hugr_with_outputs([]).unwrap();

        // `from_node`: bool -> bool, including the unconnected output.
        let subg = SiblingSubgraph::from_node(not_n.node(), &h);
        assert_eq!(subg.nodes().len(), 1);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new(vec![bool_t()], vec![bool_t()]).io()
        );

        // `try_from_nodes`: bool -> (), the unconnected output is dropped.
        let subg = SiblingSubgraph::try_from_nodes([not_n.node()], &h).unwrap();
        assert_eq!(subg.nodes().len(), 1);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new(vec![bool_t()], vec![]).io()
        );
    }

    /// The same `from_node` / `try_from_nodes` contrast for a node with no
    /// inputs and an unconnected output (an empty `MakeTuple`).
    #[test]
    fn singleton_disconnected_subgraph() {
        let op = MakeTuple::new(type_row![]);

        let mut b = DFGBuilder::new(Signature::new_endo(type_row![])).unwrap();
        let _mk_tuple_1 = b.add_dataflow_op(op.clone(), []).unwrap();
        let mk_tuple_2 = b.add_dataflow_op(op.clone(), []).unwrap();
        let _mk_tuple_3 = b.add_dataflow_op(op, []).unwrap();
        let h = b.finish_hugr_with_outputs([]).unwrap();

        // `from_node`: () -> (tuple), including the unconnected output.
        let subg = SiblingSubgraph::from_node(mk_tuple_2.node(), &h);
        assert_eq!(subg.nodes().len(), 1);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new(type_row![], vec![Type::new_tuple(type_row![])]).io()
        );

        // `try_from_nodes`: () -> (), the unconnected output is dropped.
        let subg = SiblingSubgraph::try_from_nodes([mk_tuple_2.node()], &h).unwrap();
        assert_eq!(subg.nodes().len(), 1);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new_endo(type_row![]).io()
        );
    }

    /// A subgraph mixing internally connected nodes (`mk_tuple_1 ->
    /// untuple_1`) with a node wired to the Output (`mk_tuple_2`).
    ///
    /// Only the edge that leaves the subgraph appears in the signature.
    #[test]
    fn partially_connected_subgraph() {
        let tuple_op = MakeTuple::new(type_row![]);
        let untuple_op = UnpackTuple::new(type_row![]);
        let tuple_t = Type::new_tuple(type_row![]);

        let mut b = DFGBuilder::new(Signature::new(type_row![], vec![tuple_t.clone()])).unwrap();
        let mk_tuple_1 = b.add_dataflow_op(tuple_op.clone(), []).unwrap();
        let untuple_1 = b
            .add_dataflow_op(untuple_op.clone(), [mk_tuple_1.out_wire(0)])
            .unwrap();
        let mk_tuple_2 = b.add_dataflow_op(tuple_op.clone(), []).unwrap();
        let _mk_tuple_3 = b.add_dataflow_op(tuple_op, []).unwrap();
        let h = b
            .finish_hugr_with_outputs([mk_tuple_2.out_wire(0)])
            .unwrap();

        let subgraph_nodes = [mk_tuple_1.node(), mk_tuple_2.node(), untuple_1.node()];

        let subg = SiblingSubgraph::try_from_nodes(subgraph_nodes, &h).unwrap();
        assert_eq!(subg.nodes().len(), 3);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new(type_row![], vec![tuple_t]).io()
        );
    }

    /// Behaviour of `set_outgoing_ports` with a non-linear (bool) output.
    ///
    /// Duplicating a non-linear port is allowed. An unknown port is rejected
    /// and leaves the outputs unchanged.
    #[test]
    fn test_set_outgoing_ports() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not1_out = hugr.node_outputs(not1).next().unwrap();

        let mut subgraph = SiblingSubgraph::from_node(not1, &hugr);

        assert_eq!(subgraph.outgoing_ports().len(), 1);

        // Repeating a copyable port is accepted.
        let new_outputs = vec![(not1, not1_out), (not1, not1_out)];
        assert!(subgraph.set_outgoing_ports(new_outputs, &hugr).is_ok());

        assert_eq!(subgraph.outgoing_ports().len(), 2);

        // `(out, 2)` is not part of the subgraph's original output boundary.
        let invalid_outputs = vec![(not1, not1_out), (out, 2.into())];
        assert!(matches!(
            subgraph.set_outgoing_ports(invalid_outputs, &hugr),
            Err(InvalidOutputPorts::UnknownOutput { .. })
        ));

        // A failed update leaves the outputs untouched.
        assert_eq!(subgraph.outgoing_ports().len(), 2);
    }

    /// Duplicating a linear (qubit) output port is rejected with
    /// `NonUniqueLinear`, and the outputs are left unchanged.
    #[test]
    fn test_set_outgoing_ports_linear() {
        let (hugr, func_root) = build_hugr().unwrap();
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        // Expected to be the Rz gate, which is fed by the Input's third wire.
        // NOTE(review): this relies on the ordering of `output_neighbours`.
        let rz = hugr.output_neighbours(inp).nth(2).unwrap();
        let rz_out = hugr.node_outputs(rz).next().unwrap();

        let mut subgraph = SiblingSubgraph::from_node(rz, &hugr);

        assert_eq!(subgraph.outgoing_ports().len(), 1);

        let new_outputs = vec![(rz, rz_out), (rz, rz_out)];
        assert!(matches!(
            subgraph.set_outgoing_ports(new_outputs, &hugr),
            Err(InvalidOutputPorts::NonUniqueLinear)
        ));

        assert_eq!(subgraph.outgoing_ports().len(), 1);
    }

    /// Building a subgraph from precomputed line intervals gives the same
    /// result as `try_from_nodes`.
    ///
    /// This holds whether the intervals come from nodes or from boundary
    /// ports.
    #[test]
    fn test_try_from_nodes_with_intervals() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let line_checker = LineConvexChecker::new(&hugr, func_root);
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();

        let intervals = line_checker.get_intervals_from_nodes([not1, not2]).unwrap();
        let subgraph =
            SiblingSubgraph::try_from_nodes_with_intervals([not1, not2], &intervals, &line_checker)
                .unwrap();
        let exp_subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();

        assert_eq!(subgraph, exp_subgraph);
        assert_eq!(
            line_checker.nodes_in_intervals(&intervals).collect_vec(),
            [not1, not2]
        );

        // Same subgraph, with intervals derived from its boundary ports instead.
        let intervals2 = line_checker
            .get_intervals_from_boundary_ports([
                (not1, IncomingPort::from(0).into()),
                (not2, OutgoingPort::from(0).into()),
            ])
            .unwrap();
        let subgraph2 = SiblingSubgraph::try_from_nodes_with_intervals(
            [not1, not2],
            &intervals2,
            &line_checker,
        )
        .unwrap();
        assert_eq!(subgraph2, exp_subgraph);
    }

    /// Exercises `validate` under each `ValidationMode` on unchecked subgraphs.
    #[test]
    fn test_validate() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let func = hugr.with_entrypoint(func_root);
        let checker = TopoConvexChecker::new(&func, func_root);
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
        let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();

        // {not1, not2}: valid and convex in every mode.
        let sub = SiblingSubgraph::new_unchecked(
            vec![vec![(not1, 0.into())]],
            vec![(not2, 0.into())],
            vec![],
            vec![not1, not2],
        );
        assert_eq!(sub.validate(&func, ValidationMode::SkipConvexity), Ok(()));
        assert_eq!(sub.validate(&func, ValidationMode::CheckConvexity), Ok(()));
        assert_eq!(
            sub.validate(&func, ValidationMode::WithChecker(&checker)),
            Ok(())
        );

        // {not1, not3} without not2: the boundary is consistent but the set is
        // not convex, which is only detected when convexity is checked.
        let sub = SiblingSubgraph::new_unchecked(
            vec![vec![(not1, 0.into())], vec![(not3, 0.into())]],
            vec![(not1, 0.into()), (not3, 0.into())],
            vec![],
            vec![not1, not3],
        );
        assert_eq!(sub.validate(&func, ValidationMode::SkipConvexity), Ok(()));
        assert_eq!(
            sub.validate(&func, ValidationMode::CheckConvexity),
            Err(InvalidSubgraph::NotConvex)
        );
        assert_eq!(
            sub.validate(&func, ValidationMode::WithChecker(&checker)),
            Err(InvalidSubgraph::NotConvex)
        );

        // not3's input is missing from the input boundary, so the node set and
        // the boundary disagree, even when convexity is skipped.
        let sub = SiblingSubgraph::new_unchecked(
            vec![vec![(not1, 0.into())]],
            vec![(not1, 0.into()), (not3, 0.into())],
            vec![],
            vec![not1, not3],
        );
        assert_eq!(
            sub.validate(&func, ValidationMode::SkipConvexity),
            Err(InvalidSubgraph::InvalidNodeSet)
        );
    }

    /// Fixture: a module with a declared function `test : bool -> bool` and a
    /// `main` function.
    ///
    /// `main` applies `Not` followed by two `Call`s to `test`. The entrypoint
    /// is set to `main`'s definition.
    #[fixture]
    pub(crate) fn hugr_call_subgraph() -> Hugr {
        let mut builder = ModuleBuilder::new();
        let decl_node = builder.declare("test", endo_sig(bool_t()).into()).unwrap();
        let mut main = builder.define_function("main", endo_sig(bool_t())).unwrap();
        let [bool] = main.input_wires_arr();

        let [bool] = main
            .add_dataflow_op(LogicOp::Not, [bool])
            .unwrap()
            .outputs_arr();

        // Two calls to the same declaration.
        let [bool] = main.call(&decl_node, &[], [bool]).unwrap().outputs_arr();
        let [bool] = main.call(&decl_node, &[], [bool]).unwrap().outputs_arr();

        let main_def = main.finish_with_outputs([bool]).unwrap();

        let mut hugr = builder.finish_hugr().unwrap();
        hugr.set_entrypoint(main_def.node());
        hugr
    }

    /// Both calls to the same function are grouped into a single
    /// function-call partition when the whole region is taken.
    #[rstest]
    fn test_call_subgraph_from_dfg(hugr_call_subgraph: Hugr) {
        let subg =
            SiblingSubgraph::try_new_dataflow_subgraph::<_, DataflowParentID>(&hugr_call_subgraph)
                .unwrap();

        assert_eq!(subg.function_calls.len(), 1);
        assert_eq!(subg.function_calls[0].len(), 2);
    }

    /// Function-call partitions built by `try_from_nodes` contain exactly the
    /// calls that are in the node set.
    #[rstest]
    fn test_call_subgraph_from_nodes(hugr_call_subgraph: Hugr) {
        let call_nodes = hugr_call_subgraph
            .children(hugr_call_subgraph.entrypoint())
            .filter(|&n| hugr_call_subgraph.get_optype(n).is_call())
            .collect_vec();

        let subg =
            SiblingSubgraph::try_from_nodes(call_nodes.clone(), &hugr_call_subgraph).unwrap();
        assert_eq!(subg.function_calls.len(), 1);
        assert_eq!(subg.function_calls[0].len(), 2);

        // With only the first call, the partition holds a single call.
        let subg =
            SiblingSubgraph::try_from_nodes(call_nodes[0..1].to_owned(), &hugr_call_subgraph)
                .unwrap();
        assert_eq!(subg.function_calls.len(), 1);
        assert_eq!(subg.function_calls[0].len(), 1);
    }

    /// `try_new` with an explicit boundary: the calls' port-1 inputs are
    /// grouped into one input partition.
    ///
    /// These inputs are presumably the static function-reference edges.
    /// The result is still exposed as a single function-call partition.
    #[rstest]
    fn test_call_subgraph_from_boundary(hugr_call_subgraph: Hugr) {
        let call_nodes = hugr_call_subgraph
            .children(hugr_call_subgraph.entrypoint())
            .filter(|&n| hugr_call_subgraph.get_optype(n).is_call())
            .collect_vec();
        let not_node = hugr_call_subgraph
            .children(hugr_call_subgraph.entrypoint())
            .filter(|&n| hugr_call_subgraph.get_optype(n) == &LogicOp::Not.into())
            .exactly_one()
            .ok()
            .unwrap();

        let subg = SiblingSubgraph::try_new(
            vec![
                vec![(not_node, IncomingPort::from(0))],
                call_nodes
                    .iter()
                    .map(|&n| (n, IncomingPort::from(1)))
                    .collect_vec(),
            ],
            vec![(call_nodes[1], OutgoingPort::from(0))],
            &hugr_call_subgraph,
        )
        .unwrap();

        assert_eq!(subg.function_calls.len(), 1);
        assert_eq!(subg.function_calls[0].len(), 2);
    }
}