1use std::cell::OnceCell;
13use std::collections::HashSet;
14use std::mem;
15
16use itertools::Itertools;
17use portgraph::algorithms::ConvexChecker;
18use portgraph::boundary::Boundary;
19use portgraph::{view::Subgraph, Direction, PortView};
20use thiserror::Error;
21
22use crate::builder::{Container, FunctionBuilder};
23use crate::core::HugrNode;
24use crate::extension::ExtensionSet;
25use crate::hugr::{HugrMut, HugrView, RootTagged};
26use crate::ops::dataflow::DataflowOpTrait;
27use crate::ops::handle::{ContainerHandle, DataflowOpID};
28use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
29use crate::types::{Signature, Type};
30use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
31
/// A subgraph of the children of a single HUGR parent node.
///
/// A sibling subgraph is described by the list of its nodes together with an
/// ordered boundary:
///  - a list of incoming ports, grouped into partitions (one partition per
///    subgraph input), and
///  - a list of outgoing ports (one per subgraph output).
///
/// The order of the boundary lists is the order of the inputs and outputs in
/// [`SiblingSubgraph::signature`].
///
/// The validating constructors reject empty subgraphs, subgraphs whose nodes
/// do not share a parent, and invalid boundaries. [`SiblingSubgraph::try_new`]
/// and its variants additionally check convexity.
#[derive(Clone, Debug)]
pub struct SiblingSubgraph<N = Node> {
    /// The nodes of the subgraph.
    nodes: Vec<N>,
    /// The input ports of the subgraph, grouped by subgraph input.
    ///
    /// Validation requires every port to be unique, to belong to a node in
    /// `nodes`, and every partition to be non-empty.
    inputs: Vec<Vec<(N, IncomingPort)>>,
    /// The output ports of the subgraph, one per subgraph output.
    ///
    /// Validation requires every port to belong to a node in `nodes`.
    outputs: Vec<(N, OutgoingPort)>,
}
73
/// The incoming boundary of a [`SiblingSubgraph`].
///
/// Each inner vector is one partition: the set of incoming ports fed by one
/// subgraph input. [`SiblingSubgraph::signature`] takes each input's type from
/// the first port of its partition.
pub type IncomingPorts<N = Node> = Vec<Vec<(N, IncomingPort)>>;
/// The outgoing boundary of a [`SiblingSubgraph`]: one outgoing port per
/// subgraph output, in signature order.
pub type OutgoingPorts<N = Node> = Vec<(N, OutgoingPort)>;
83
84impl<N: HugrNode> SiblingSubgraph<N> {
85 pub fn try_new_dataflow_subgraph<H, Root>(dfg_graph: &H) -> Result<Self, InvalidSubgraph<N>>
99 where
100 H: Clone + RootTagged<RootHandle = Root, Node = N>,
101 Root: ContainerHandle<ChildrenHandle = DataflowOpID>,
102 {
103 let parent = dfg_graph.root();
104 let nodes = dfg_graph.children(parent).skip(2).collect_vec();
105 let (inputs, outputs) = get_input_output_ports(dfg_graph);
106
107 validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?;
108
109 if nodes.is_empty() {
110 Err(InvalidSubgraph::EmptySubgraph)
111 } else {
112 Ok(Self {
113 nodes,
114 inputs,
115 outputs,
116 })
117 }
118 }
119
120 pub fn try_new(
160 incoming: IncomingPorts<N>,
161 outgoing: OutgoingPorts<N>,
162 hugr: &impl HugrView<Node = N>,
163 ) -> Result<Self, InvalidSubgraph<N>> {
164 let checker = TopoConvexChecker::new(hugr);
165 Self::try_new_with_checker(incoming, outgoing, hugr, &checker)
166 }
167
168 pub fn try_new_with_checker(
177 inputs: IncomingPorts<N>,
178 outputs: OutgoingPorts<N>,
179 hugr: &impl HugrView<Node = N>,
180 checker: &impl ConvexChecker,
181 ) -> Result<Self, InvalidSubgraph<N>> {
182 let pg = hugr.portgraph();
183
184 let subpg = Subgraph::new_subgraph(pg.clone(), make_boundary(hugr, &inputs, &outputs));
186 let nodes = subpg
187 .nodes_iter()
188 .map(|index| hugr.get_node(index))
189 .collect_vec();
190 validate_subgraph(hugr, &nodes, &inputs, &outputs)?;
191
192 if !subpg.is_convex_with_checker(checker) {
193 return Err(InvalidSubgraph::NotConvex);
194 }
195
196 Ok(Self {
197 nodes,
198 inputs,
199 outputs,
200 })
201 }
202
203 pub fn try_from_nodes(
219 nodes: impl Into<Vec<N>>,
220 hugr: &impl HugrView<Node = N>,
221 ) -> Result<Self, InvalidSubgraph<N>> {
222 let checker = TopoConvexChecker::new(hugr);
223 Self::try_from_nodes_with_checker(nodes, hugr, &checker)
224 }
225
226 pub fn try_from_nodes_with_checker<'c, 'h: 'c>(
235 nodes: impl Into<Vec<N>>,
236 hugr: &'h impl HugrView<Node = N>,
237 checker: &impl ConvexChecker,
238 ) -> Result<Self, InvalidSubgraph<N>> {
239 let nodes = nodes.into();
240
241 match nodes.as_slice() {
243 [] => return Err(InvalidSubgraph::EmptySubgraph),
244 [node] => return Ok(Self::from_node(*node, hugr)),
245 _ => {}
246 };
247
248 let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
249 let incoming_edges = nodes
250 .iter()
251 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
252 let outgoing_edges = nodes
253 .iter()
254 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
255 let inputs = incoming_edges
256 .filter(|&(n, p)| {
257 if !hugr.is_linked(n, p) {
258 return false;
259 }
260 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
261 !nodes_set.contains(&out_n)
262 })
263 .map(|p| vec![p])
265 .collect_vec();
266 let outputs = outgoing_edges
267 .filter(|&(n, p)| {
268 hugr.linked_ports(n, p)
269 .any(|(n1, _)| !nodes_set.contains(&n1))
270 })
271 .collect_vec();
272 Self::try_new_with_checker(inputs, outputs, hugr, checker)
273 }
274
275 pub fn from_node(node: N, hugr: &impl HugrView<Node = N>) -> Self {
279 let nodes = vec![node];
280 let inputs = hugr
281 .node_inputs(node)
282 .filter(|&p| hugr.is_linked(node, p))
283 .map(|p| vec![(node, p)])
284 .collect_vec();
285 let outputs = hugr
286 .node_outputs(node)
287 .filter_map(|p| {
288 {
290 hugr.is_linked(node, p)
291 || hugr
292 .get_optype(node)
293 .port_kind(p)
294 .is_some_and(|k| k.is_value())
295 }
296 .then_some((node, p))
297 })
298 .collect_vec();
299
300 Self {
301 nodes,
302 inputs,
303 outputs,
304 }
305 }
306
307 pub fn nodes(&self) -> &[N] {
309 &self.nodes
310 }
311
312 pub fn node_count(&self) -> usize {
314 self.nodes.len()
315 }
316
317 pub fn incoming_ports(&self) -> &IncomingPorts<N> {
319 &self.inputs
320 }
321
322 pub fn outgoing_ports(&self) -> &OutgoingPorts<N> {
324 &self.outputs
325 }
326
327 pub fn signature(&self, hugr: &impl HugrView<Node = N>) -> Signature {
329 let input = self
330 .inputs
331 .iter()
332 .map(|part| {
333 let &(n, p) = part.iter().next().expect("is non-empty");
334 let sig = hugr.signature(n).expect("must have dataflow signature");
335 sig.port_type(p).cloned().expect("must be dataflow edge")
336 })
337 .collect_vec();
338 let output = self
339 .outputs
340 .iter()
341 .map(|&(n, p)| {
342 let sig = hugr.signature(n).expect("must have dataflow signature");
343 sig.port_type(p).cloned().expect("must be dataflow edge")
344 })
345 .collect_vec();
346 Signature::new(input, output).with_extension_delta(ExtensionSet::union_over(
347 self.nodes
348 .iter()
349 .map(|n| hugr.get_optype(*n).extension_delta()),
350 ))
351 }
352
353 pub fn get_parent(&self, hugr: &impl HugrView<Node = N>) -> N {
355 hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
356 }
357
358 pub fn create_simple_replacement(
374 &self,
375 hugr: &impl HugrView<Node = N>,
376 replacement: Hugr,
377 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
378 let rep_root = replacement.root();
379 let dfg_optype = replacement.get_optype(rep_root);
380 if !OpTag::Dfg.is_superset(dfg_optype.tag()) {
381 return Err(InvalidReplacement::InvalidDataflowGraph {
382 node: rep_root,
383 op: dfg_optype.clone(),
384 });
385 }
386 let [rep_input, rep_output] = replacement
387 .get_io(rep_root)
388 .expect("DFG root in the replacement does not have input and output nodes.");
389
390 let current_signature = self.signature(hugr);
391 let new_signature = dfg_optype.dataflow_signature();
392 if new_signature.as_ref().map(|s| &s.input) != Some(¤t_signature.input)
393 || new_signature.as_ref().map(|s| &s.output) != Some(¤t_signature.output)
394 {
395 return Err(InvalidReplacement::InvalidSignature {
396 expected: self.signature(hugr),
397 actual: dfg_optype.dataflow_signature().map(|s| s.into_owned()),
398 });
399 }
400
401 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p));
404 let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p));
405 let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs.partition(|&(n, p)| {
406 replacement
407 .signature(n)
408 .is_some_and(|s| s.port_type(p).is_some())
409 });
410 let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| {
411 replacement
412 .signature(n)
413 .is_some_and(|s| s.port_type(p).is_some())
414 });
415
416 if iter_io(&vec![out_order_ports], &in_order_ports)
417 .any(|(n, p)| is_order_edge(&replacement, n, p))
418 {
419 unimplemented!("Found state order edges in replacement graph");
420 }
421
422 let nu_inp = rep_inputs
423 .into_iter()
424 .zip_eq(&self.inputs)
425 .flat_map(|((rep_source_n, rep_source_p), self_targets)| {
426 replacement
427 .linked_inputs(rep_source_n, rep_source_p)
428 .flat_map(move |rep_target| {
429 self_targets
430 .iter()
431 .map(move |&self_target| (rep_target, self_target))
432 })
433 })
434 .collect();
435 let nu_out = self
436 .outputs
437 .iter()
438 .zip_eq(rep_outputs)
439 .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| {
440 hugr.linked_inputs(self_source_n, self_source_p)
441 .map(move |self_target| (self_target, rep_target_p))
442 })
443 .collect();
444
445 Ok(SimpleReplacement::new(
446 self.clone(),
447 replacement,
448 nu_inp,
449 nu_out,
450 ))
451 }
452}
453
454impl SiblingSubgraph {
455 pub fn extract_subgraph(
460 &self,
461 hugr: &impl HugrView<Node = Node>,
462 name: impl Into<String>,
463 ) -> Hugr {
464 let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
465 let mut extracted = mem::take(builder.hugr_mut());
468 let node_map = extracted.insert_subgraph(extracted.root(), hugr, self);
469
470 let [inp, out] = extracted.get_io(extracted.root()).unwrap();
472 let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
473 let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
474 let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
475
476 for (inp_port, repl_ports) in inputs {
477 for (repl_node, repl_port) in repl_ports {
478 connections.push((inp, inp_port, node_map[repl_node], *repl_port));
479 }
480 }
481 for (out_port, (repl_node, repl_port)) in outputs {
482 connections.push((node_map[repl_node], *repl_port, out, out_port));
483 }
484
485 for (src, src_port, dst, dst_port) in connections {
486 extracted.connect(src, src_port, dst, dst_port);
487 }
488
489 extracted
490 }
491}
492
493fn iter_incoming<N: HugrNode>(
495 inputs: &IncomingPorts<N>,
496) -> impl Iterator<Item = (N, IncomingPort)> + '_ {
497 inputs.iter().flat_map(|part| part.iter().copied())
498}
499
/// Iterate over the ports of an outgoing boundary, in order.
fn iter_outgoing<N: HugrNode>(
    outputs: &OutgoingPorts<N>,
) -> impl Iterator<Item = (N, OutgoingPort)> + '_ {
    outputs.iter().copied()
}
506
507fn iter_io<'a, N: HugrNode>(
509 inputs: &'a IncomingPorts<N>,
510 outputs: &'a OutgoingPorts<N>,
511) -> impl Iterator<Item = (N, Port)> + 'a {
512 iter_incoming(inputs)
513 .map(|(n, p)| (n, Port::from(p)))
514 .chain(iter_outgoing(outputs).map(|(n, p)| (n, Port::from(p))))
515}
516
517fn make_boundary<'a, N: HugrNode>(
518 hugr: &impl HugrView<Node = N>,
519 inputs: &'a IncomingPorts<N>,
520 outputs: &'a OutgoingPorts<N>,
521) -> Boundary {
522 let to_pg_index = |n: N, p: Port| {
523 hugr.portgraph()
524 .port_index(hugr.get_pg_index(n), p.pg_offset())
525 .unwrap()
526 };
527 Boundary::new(
528 iter_incoming(inputs).map(|(n, p)| to_pg_index(n, p.into())),
529 iter_outgoing(outputs).map(|(n, p)| to_pg_index(n, p.into())),
530 )
531}
532
/// A convexity checker for the port graph of a HUGR.
///
/// The underlying portgraph checker is built lazily, on the first convexity
/// query, and then reused. Sharing one checker across several subgraph
/// constructions of the same HUGR avoids rebuilding it each time.
pub struct TopoConvexChecker<'g, Base: 'g + HugrView> {
    /// The HUGR whose port graph is used for convexity queries.
    base: &'g Base,
    /// The portgraph convexity checker, initialised on first use.
    checker: OnceCell<portgraph::algorithms::TopoConvexChecker<Base::Portgraph<'g>>>,
}
541
542impl<'g, Base: HugrView> TopoConvexChecker<'g, Base> {
543 pub fn new(base: &'g Base) -> Self {
545 Self {
546 base,
547 checker: OnceCell::new(),
548 }
549 }
550
551 fn get_checker(&self) -> &portgraph::algorithms::TopoConvexChecker<Base::Portgraph<'g>> {
553 self.checker
554 .get_or_init(|| portgraph::algorithms::TopoConvexChecker::new(self.base.portgraph()))
555 }
556}
557
558impl<Base: HugrView> ConvexChecker for TopoConvexChecker<'_, Base> {
559 fn is_convex(
560 &self,
561 nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
562 inputs: impl IntoIterator<Item = portgraph::PortIndex>,
563 outputs: impl IntoIterator<Item = portgraph::PortIndex>,
564 ) -> bool {
565 let mut nodes = nodes.into_iter().multipeek();
566 if nodes.peek().is_none() || nodes.peek().is_none() {
569 return true;
570 };
571 self.get_checker().is_convex(nodes, inputs, outputs)
572 }
573}
574
575fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(
579 hugr: &H,
580 ports: &[(H::Node, P)],
581) -> Option<Type> {
582 let &(n, p) = ports.first()?;
583 let edge_t = hugr.signature(n)?.port_type(p)?.clone();
584 ports
585 .iter()
586 .all(|&(n, p)| {
587 hugr.signature(n)
588 .is_some_and(|s| s.port_type(p) == Some(&edge_t))
589 })
590 .then_some(edge_t)
591}
592
593fn validate_subgraph<H: HugrView>(
600 hugr: &H,
601 nodes: &[H::Node],
602 inputs: &IncomingPorts<H::Node>,
603 outputs: &OutgoingPorts<H::Node>,
604) -> Result<(), InvalidSubgraph<H::Node>> {
605 let node_set = nodes.iter().copied().collect::<HashSet<_>>();
607
608 if nodes.is_empty() {
610 return Err(InvalidSubgraph::EmptySubgraph);
611 }
612 if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() {
614 let first_node = nodes[0];
615 let first_parent = hugr.get_parent(first_node);
616 let other_node = *nodes
617 .iter()
618 .skip(1)
619 .find(|&&n| hugr.get_parent(n) != first_parent)
620 .unwrap();
621 let other_parent = hugr.get_parent(other_node);
622 return Err(InvalidSubgraph::NoSharedParent {
623 first_node,
624 first_parent,
625 other_node,
626 other_parent,
627 });
628 }
629
630 if iter_io(inputs, outputs).any(|(n, p)| is_order_edge(hugr, n, p)) {
632 unimplemented!("Connected order edges not supported at the boundary")
633 }
634
635 let boundary_ports = iter_io(inputs, outputs).collect_vec();
636 if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
638 Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
639 };
640 if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
642 hugr.linked_ports(n, p)
643 .all(|(n1, _)| node_set.contains(&n1))
644 }) {
645 Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
646 };
647
648 if nodes.iter().any(|&n| {
651 hugr.node_inputs(n).any(|p| {
652 hugr.linked_ports(n, p).any(|(n1, _)| {
653 !node_set.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
654 })
655 })
656 }) {
657 return Err(InvalidSubgraph::NotConvex);
658 }
659 if nodes.iter().any(|&n| {
662 hugr.node_outputs(n).any(|p| {
663 hugr.linked_ports(n, p)
664 .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
665 })
666 }) {
667 return Err(InvalidSubgraph::NotConvex);
668 }
669
670 if !inputs.iter().flatten().all_unique() {
672 return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
673 }
674
675 if inputs.iter().any(|p| p.is_empty()) {
677 return Err(InvalidSubgraphBoundary::EmptyPartition.into());
678 }
679
680 if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
682 let Some(edge_t) = get_edge_type(hugr, ports) else {
683 return true;
684 };
685 let require_copy = ports.len() > 1;
686 require_copy && !edge_t.copyable()
687 }) {
688 Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
689 };
690
691 Ok(())
692}
693
694fn get_input_output_ports<H: HugrView>(
695 hugr: &H,
696) -> (IncomingPorts<H::Node>, OutgoingPorts<H::Node>) {
697 let [inp, out] = hugr.get_io(hugr.root()).expect("invalid DFG");
698 if has_other_edge(hugr, inp, Direction::Outgoing) {
699 unimplemented!("Non-dataflow output not supported at input node")
700 }
701 let dfg_inputs = hugr
702 .get_optype(inp)
703 .as_input()
704 .unwrap()
705 .signature()
706 .output_ports();
707 if has_other_edge(hugr, out, Direction::Incoming) {
708 unimplemented!("Non-dataflow input not supported at output node")
709 }
710 let dfg_outputs = hugr
711 .get_optype(out)
712 .as_output()
713 .unwrap()
714 .signature()
715 .input_ports();
716
717 let inputs = dfg_inputs
720 .into_iter()
721 .map(|p| {
722 hugr.linked_inputs(inp, p)
723 .filter(|&(n, _)| n != out)
724 .collect_vec()
725 })
726 .filter(|v| !v.is_empty())
727 .collect();
728 let outputs = dfg_outputs
731 .into_iter()
732 .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
733 .collect();
734 (inputs, outputs)
735}
736
737fn is_order_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
739 let op = hugr.get_optype(node);
740 op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port)
741}
742
743fn has_other_edge<H: HugrView>(hugr: &H, node: H::Node, dir: Direction) -> bool {
745 let op = hugr.get_optype(node);
746 op.other_port_kind(dir).is_some() && hugr.is_linked(node, op.other_port(dir).unwrap())
747}
748
749#[derive(Debug, Clone, PartialEq, Error)]
751#[non_exhaustive]
752pub enum InvalidReplacement {
753 #[error("The root of the replacement {node} is a {}, but only OpType::DFGs are supported.", op.name())]
755 InvalidDataflowGraph {
756 node: Node,
758 op: OpType,
760 },
761 #[error(
763 "Replacement graph type mismatch. Expected {expected}, got {}.",
764 actual.clone().map_or("none".to_string(), |t| t.to_string()))
765 ]
766 InvalidSignature {
767 expected: Signature,
769 actual: Option<Signature>,
771 },
772 #[error("SiblingSubgraph is not convex.")]
774 NonConvexSubgraph,
775}
776
/// Errors that can occur when constructing a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraph<N: HugrNode = Node> {
    /// The subgraph is not convex, or an edge crosses the subgraph border
    /// without going through a boundary port.
    #[error("The subgraph is not convex.")]
    NotConvex,
    /// Not all nodes of the subgraph have the same parent.
    #[error(
        "Not a sibling subgraph. {first_node} has parent {}, but {other_node} has parent {}.",
        first_parent.map_or("None".to_string(), |n| n.to_string()),
        other_parent.map_or("None".to_string(), |n| n.to_string())
    )]
    NoSharedParent {
        /// The first node of the subgraph.
        first_node: N,
        /// The parent of the first node, if any.
        first_parent: Option<N>,
        /// A node whose parent differs from `first_parent`.
        other_node: N,
        /// The parent of `other_node`, if any.
        other_parent: Option<N>,
    },
    /// The subgraph has no nodes.
    #[error("Empty subgraphs are not supported.")]
    EmptySubgraph,
    /// The boundary of the subgraph is invalid.
    #[error("Invalid boundary port.")]
    InvalidBoundary(#[from] InvalidSubgraphBoundary<N>),
}
807
/// Errors describing an invalid [`SiblingSubgraph`] boundary.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraphBoundary<N: HugrNode = Node> {
    /// A boundary port belongs to a node outside the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
    PortNodeNotInSet(N, Port),
    /// A boundary port is not linked to any node outside the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph.")]
    DisconnectedBoundaryPort(N, Port),
    /// The same incoming port appears more than once in the input boundary.
    #[error("A port in the input boundary is used multiple times.")]
    NonUniqueInput,
    /// An input partition contains no ports.
    #[error("A partition in the input boundary is empty.")]
    EmptyPartition,
    /// The ports of input partition `.0` do not share a single dataflow
    /// type. Also reported when a partition with several ports has a
    /// non-copyable type.
    #[error("The partition {0} in the input boundary has ports with different types.")]
    MismatchedTypes(usize),
}
828
#[cfg(test)]
mod tests {
    use cool_asserts::assert_matches;

    use crate::builder::inout_sig;
    use crate::hugr::Rewrite;
    use crate::ops::Const;
    use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
    use crate::std_extensions::logic::{self, LogicOp};
    use crate::type_row;
    use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64};
    use crate::{
        builder::{
            BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::{bool_t, qb_t},
        hugr::views::{HierarchyView, SiblingGraph},
        ops::handle::{DfgID, FuncID, NodeHandle},
        std_extensions::logic::test::and_op,
    };

    use super::*;

    impl<N: HugrNode> SiblingSubgraph<N> {
        /// Test helper: a subgraph made of every child of the view's root,
        /// Input and Output nodes included, with an empty boundary.
        ///
        /// No validation is performed. Returns
        /// [`InvalidSubgraph::EmptySubgraph`] if the root has no children.
        fn from_sibling_graph(
            sibling_graph: &impl HugrView<Node = N>,
        ) -> Result<Self, InvalidSubgraph<N>> {
            let root = sibling_graph.root();
            let nodes = sibling_graph.children(root).collect_vec();
            if nodes.is_empty() {
                Err(InvalidSubgraph::EmptySubgraph)
            } else {
                Ok(Self {
                    nodes,
                    inputs: Vec::new(),
                    outputs: Vec::new(),
                })
            }
        }
    }

    /// A module with one function on three qubits: a CX gate on the first
    /// two, and an Rz rotation by a constant 0.5 on the third.
    fn build_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
                .with_extension_delta(ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    float_types::EXTENSION_ID,
                ]))
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [w0, w1, w2] = dfg.input_wires_arr();
            let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
            let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
            let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
            dfg.finish_with_outputs([w0, w1, w2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// A module with one bool -> bool function made of three chained Not
    /// gates.
    fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![bool_t()])
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?;
            let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?;
            let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?;
            dfg.finish_with_outputs(outs3.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// A module with one bool -> (bool, bool) function: two chained Not gates
    /// where the first Not's output is both a function output and the second
    /// Not's input.
    fn build_multiport_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new(bool_t(), vec![bool_t(), bool_t()])
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [b0] = dfg.input_wires_arr();
            let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr();
            let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr();
            dfg.finish_with_outputs([b1, b2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// A module with one bool -> bool function that feeds its single input
    /// into both operands of an And gate.
    fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(bool_t())
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let in_wire = dfg.input_wires().exactly_one().unwrap();
            let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?;
            dfg.finish_with_outputs(outs.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    #[test]
    fn construct_subgraph() -> Result<(), InvalidSubgraph> {
        let (hugr, func_root) = build_hugr().unwrap();
        let sibling_graph: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let from_root = SiblingSubgraph::from_sibling_graph(&sibling_graph)?;
        let region: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let from_region = SiblingSubgraph::from_sibling_graph(&region)?;
        assert_eq!(
            from_root.get_parent(&sibling_graph),
            from_region.get_parent(&sibling_graph)
        );
        assert_eq!(
            from_root.signature(&sibling_graph),
            from_region.signature(&sibling_graph)
        );
        Ok(())
    }

    #[test]
    fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
        let (mut hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;

        let empty_dfg = {
            let builder =
                DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();

        assert_eq!(rep.subgraph().nodes().len(), 4);

        // Module root, function definition, its Input/Output, and the four
        // subgraph nodes.
        assert_eq!(hugr.node_count(), 8);
        hugr.apply_rewrite(rep).unwrap();
        // The four subgraph nodes have been replaced by nothing.
        assert_eq!(hugr.node_count(), 4);
        Ok(())
    }

    #[test]
    fn test_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, dfg).unwrap();
        let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;
        assert_eq!(
            sub.signature(&func),
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta(
                ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    float_types::EXTENSION_ID,
                ])
            )
        );
        Ok(())
    }

    #[test]
    fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, dfg).unwrap();
        // Empty boundary, so the subgraph signature has no inputs or outputs.
        let sub = SiblingSubgraph::from_sibling_graph(&func)?;

        let empty_dfg = {
            let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        assert_matches!(
            sub.create_simple_replacement(&func, empty_dfg).unwrap_err(),
            InvalidReplacement::InvalidSignature { .. }
        );
        Ok(())
    }

    #[test]
    fn convex_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        assert_eq!(
            SiblingSubgraph::try_new_dataflow_subgraph(&func)
                .unwrap()
                .nodes()
                .len(),
            4
        )
    }

    #[test]
    fn convex_subgraph_2() {
        let (hugr, func_root) = build_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        // Boundary made of the first two function inputs and outputs, i.e.
        // the wires around the CX gate.
        SiblingSubgraph::try_new(
            hugr.node_outputs(inp)
                .take(2)
                .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                .filter(|ps| !ps.is_empty())
                .collect(),
            hugr.node_inputs(out)
                .take(2)
                .filter_map(|p| hugr.single_linked_output(out, p))
                .collect(),
            &func,
        )
        .unwrap();
    }

    #[test]
    fn degen_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, _] = hugr.get_io(func_root).unwrap();
        let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
        // The same edge is used as both input and output boundary.
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![hugr
                    .linked_ports(inp, first_cx_edge)
                    .map(|(n, p)| (n, p.as_incoming().unwrap()))
                    .collect()],
                vec![(inp, first_cx_edge)],
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }

    #[test]
    fn non_convex_subgraph() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
        let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();
        let not1_inp = hugr.node_inputs(not1).next().unwrap();
        let not1_out = hugr.node_outputs(not1).next().unwrap();
        let not3_inp = hugr.node_inputs(not3).next().unwrap();
        let not3_out = hugr.node_outputs(not3).next().unwrap();
        // Selecting not1 and not3 but not the not2 between them is not convex.
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
                vec![(not1, not1_out), (not3, not3_out)],
                &func
            ),
            Err(InvalidSubgraph::NotConvex)
        );
    }

    /// The first Not's output feeds both the second Not and the function
    /// output, so it must appear in the outgoing boundary of a subgraph
    /// containing both Nots.
    #[test]
    fn convex_multiports() {
        let (hugr, func_root) = build_multiport_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr
            .output_neighbours(not1)
            .filter(|&n| n != out)
            .exactly_one()
            .ok()
            .unwrap();

        let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
        assert_eq!(subgraph.nodes(), [not1, not2]);
    }

    #[test]
    fn invalid_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let cx_edges_in = hugr.node_outputs(inp);
        let cx_edges_out = hugr.node_inputs(out);
        // The roles of the Input and Output node ports are swapped.
        assert_matches!(
            SiblingSubgraph::try_new(
                cx_edges_out.map(|p| vec![(out, p)]).collect(),
                cx_edges_in.map(|p| (inp, p)).collect(),
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }

    #[test]
    fn preserve_signature() {
        let (hugr, func_root) = build_hugr_classical().unwrap();
        let func_graph: SiblingGraph<'_, FuncID<true>> =
            SiblingGraph::try_new(&hugr, func_root).unwrap();
        let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
        let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
        assert_eq!(func_defn.signature, func.signature(&func_graph).into());
    }

    #[test]
    fn extract_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func_graph: SiblingGraph<'_, FuncID<true>> =
            SiblingGraph::try_new(&hugr, func_root).unwrap();
        let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
        let extracted = subgraph.extract_subgraph(&hugr, "region");

        extracted.validate().unwrap();
    }

    #[test]
    fn edge_both_output_and_copy() {
        // The Not output is both a DFG output and copied into the And gate.
        let one_bit = vec![bool_t()];
        let two_bit = vec![bool_t(), bool_t()];

        let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap();
        let inw = builder.input_wires().exactly_one().unwrap();
        let outw1 = builder
            .add_dataflow_op(LogicOp::Not, [inw])
            .unwrap()
            .out_wire(0);
        let outw2 = builder
            .add_dataflow_op(and_op(), [inw, outw1])
            .unwrap()
            .outputs();
        let outw = [outw1].into_iter().chain(outw2);
        let h = builder.finish_hugr_with_outputs(outw).unwrap();
        let view = SiblingGraph::<DfgID>::try_new(&h, h.root()).unwrap();
        let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap();
        assert_eq!(subg.nodes().len(), 2);
    }

    #[test]
    fn test_unconnected() {
        // A DFG with one bool input and no outputs.
        let mut b = DFGBuilder::new(
            Signature::new(bool_t(), type_row![])
                .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
        )
        .unwrap();
        let inw = b.input_wires().exactly_one().unwrap();
        let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
        // The Not output is left unconnected.
        let mut h = b.finish_hugr_with_outputs([]).unwrap();

        let subg = SiblingSubgraph::from_node(not_n.node(), &h);

        assert_eq!(subg.nodes().len(), 1);
        let replacement = {
            // A single Not gate with a bool -> bool signature, matching `subg`.
            let mut rep_b = DFGBuilder::new(
                Signature::new_endo(bool_t())
                    .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
            )
            .unwrap();
            let inw = rep_b.input_wires().exactly_one().unwrap();

            let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();

            rep_b.finish_hugr_with_outputs(not_n.outputs()).unwrap()
        };
        let rep = subg.create_simple_replacement(&h, replacement).unwrap();
        rep.apply(&mut h).unwrap();
    }
}