1use std::cell::OnceCell;
13use std::collections::HashSet;
14use std::mem;
15
16use itertools::Itertools;
17use portgraph::algorithms::ConvexChecker;
18use portgraph::boundary::Boundary;
19use portgraph::{view::Subgraph, Direction, PortView};
20use thiserror::Error;
21
22use crate::builder::{Container, FunctionBuilder};
23use crate::core::HugrNode;
24use crate::extension::ExtensionSet;
25use crate::hugr::{HugrMut, HugrView, RootTagged};
26use crate::ops::dataflow::DataflowOpTrait;
27use crate::ops::handle::{ContainerHandle, DataflowOpID};
28use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
29use crate::types::{Signature, Type};
30use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
31
/// A set of nodes of a HUGR together with an ordered input and output
/// boundary.
///
/// The public constructors in this file run `validate_subgraph`, which
/// rejects empty node sets and node sets whose members do not all share the
/// same parent.
#[derive(Clone, Debug)]
pub struct SiblingSubgraph<N = Node> {
    /// The nodes that make up the subgraph.
    nodes: Vec<N>,
    /// The input boundary. Each entry is a partition of incoming ports; the
    /// signature of the subgraph takes one input type per partition, read
    /// from the first port of the partition.
    inputs: Vec<Vec<(N, IncomingPort)>>,
    /// The output boundary. The signature of the subgraph takes one output
    /// type per entry, in order.
    outputs: Vec<(N, OutgoingPort)>,
}
73
/// The type of the input boundary of a [`SiblingSubgraph`]: a list of
/// partitions, each of which is a list of `(node, incoming port)` pairs.
pub type IncomingPorts<N = Node> = Vec<Vec<(N, IncomingPort)>>;
/// The type of the output boundary of a [`SiblingSubgraph`]: a list of
/// `(node, outgoing port)` pairs.
pub type OutgoingPorts<N = Node> = Vec<(N, OutgoingPort)>;
83
84impl<N: HugrNode> SiblingSubgraph<N> {
85 pub fn try_new_dataflow_subgraph<H, Root>(dfg_graph: &H) -> Result<Self, InvalidSubgraph<N>>
99 where
100 H: Clone + RootTagged<RootHandle = Root, Node = N>,
101 Root: ContainerHandle<ChildrenHandle = DataflowOpID>,
102 {
103 let parent = dfg_graph.root();
104 let nodes = dfg_graph.children(parent).skip(2).collect_vec();
105 let (inputs, outputs) = get_input_output_ports(dfg_graph);
106
107 validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?;
108
109 if nodes.is_empty() {
110 Err(InvalidSubgraph::EmptySubgraph)
111 } else {
112 Ok(Self {
113 nodes,
114 inputs,
115 outputs,
116 })
117 }
118 }
119
120 pub fn try_new(
160 incoming: IncomingPorts<N>,
161 outgoing: OutgoingPorts<N>,
162 hugr: &impl HugrView<Node = N>,
163 ) -> Result<Self, InvalidSubgraph<N>> {
164 let checker = TopoConvexChecker::new(hugr);
165 Self::try_new_with_checker(incoming, outgoing, hugr, &checker)
166 }
167
168 pub fn try_new_with_checker(
177 inputs: IncomingPorts<N>,
178 outputs: OutgoingPorts<N>,
179 hugr: &impl HugrView<Node = N>,
180 checker: &impl ConvexChecker,
181 ) -> Result<Self, InvalidSubgraph<N>> {
182 let pg = hugr.portgraph();
183
184 let subpg = Subgraph::new_subgraph(pg.clone(), make_boundary(hugr, &inputs, &outputs));
186 let nodes = subpg
187 .nodes_iter()
188 .map(|index| hugr.get_node(index))
189 .collect_vec();
190 validate_subgraph(hugr, &nodes, &inputs, &outputs)?;
191
192 if nodes.len() > 1 && !subpg.is_convex_with_checker(checker) {
193 return Err(InvalidSubgraph::NotConvex);
194 }
195
196 Ok(Self {
197 nodes,
198 inputs,
199 outputs,
200 })
201 }
202
203 pub fn try_from_nodes(
219 nodes: impl Into<Vec<N>>,
220 hugr: &impl HugrView<Node = N>,
221 ) -> Result<Self, InvalidSubgraph<N>> {
222 let checker = TopoConvexChecker::new(hugr);
223 Self::try_from_nodes_with_checker(nodes, hugr, &checker)
224 }
225
226 pub fn try_from_nodes_with_checker<'c, 'h: 'c>(
235 nodes: impl Into<Vec<N>>,
236 hugr: &'h impl HugrView<Node = N>,
237 checker: &impl ConvexChecker,
238 ) -> Result<Self, InvalidSubgraph<N>> {
239 let nodes = nodes.into();
240
241 if nodes.is_empty() {
242 return Err(InvalidSubgraph::EmptySubgraph);
243 }
244
245 let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
246 let incoming_edges = nodes
247 .iter()
248 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
249 let outgoing_edges = nodes
250 .iter()
251 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
252 let inputs = incoming_edges
253 .filter(|&(n, p)| {
254 if !hugr.is_linked(n, p) {
255 return false;
256 }
257 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
258 !nodes_set.contains(&out_n)
259 })
260 .map(|p| vec![p])
262 .collect_vec();
263 let outputs = outgoing_edges
264 .filter(|&(n, p)| {
265 hugr.linked_ports(n, p)
266 .any(|(n1, _)| !nodes_set.contains(&n1))
267 })
268 .collect_vec();
269 Self::try_new_with_checker(inputs, outputs, hugr, checker)
270 }
271
272 pub fn from_node(node: N, hugr: &impl HugrView<Node = N>) -> Self {
276 let nodes = vec![node];
277 let inputs = hugr
278 .node_inputs(node)
279 .filter(|&p| hugr.is_linked(node, p))
280 .map(|p| vec![(node, p)])
281 .collect_vec();
282 let outputs = hugr
283 .node_outputs(node)
284 .filter_map(|p| {
285 {
287 hugr.is_linked(node, p)
288 || hugr
289 .get_optype(node)
290 .port_kind(p)
291 .is_some_and(|k| k.is_value())
292 }
293 .then_some((node, p))
294 })
295 .collect_vec();
296
297 Self {
298 nodes,
299 inputs,
300 outputs,
301 }
302 }
303
    /// The nodes of the subgraph, as a slice.
    pub fn nodes(&self) -> &[N] {
        &self.nodes
    }
308
    /// The number of nodes in the subgraph.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
313
    /// The input boundary of the subgraph: one partition of incoming ports
    /// per subgraph input.
    pub fn incoming_ports(&self) -> &IncomingPorts<N> {
        &self.inputs
    }
318
    /// The output boundary of the subgraph: one outgoing port per subgraph
    /// output.
    pub fn outgoing_ports(&self) -> &OutgoingPorts<N> {
        &self.outputs
    }
323
324 pub fn signature(&self, hugr: &impl HugrView<Node = N>) -> Signature {
326 let input = self
327 .inputs
328 .iter()
329 .map(|part| {
330 let &(n, p) = part.iter().next().expect("is non-empty");
331 let sig = hugr.signature(n).expect("must have dataflow signature");
332 sig.port_type(p).cloned().expect("must be dataflow edge")
333 })
334 .collect_vec();
335 let output = self
336 .outputs
337 .iter()
338 .map(|&(n, p)| {
339 let sig = hugr.signature(n).expect("must have dataflow signature");
340 sig.port_type(p).cloned().expect("must be dataflow edge")
341 })
342 .collect_vec();
343 Signature::new(input, output).with_extension_delta(ExtensionSet::union_over(
344 self.nodes
345 .iter()
346 .map(|n| hugr.get_optype(*n).extension_delta()),
347 ))
348 }
349
    /// The parent of the first node of the subgraph.
    ///
    /// # Panics
    ///
    /// Panics if the node list is empty or if the first node has no parent
    /// in `hugr`.
    pub fn get_parent(&self, hugr: &impl HugrView<Node = N>) -> N {
        hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
    }
354
355 pub fn create_simple_replacement(
371 &self,
372 hugr: &impl HugrView<Node = N>,
373 replacement: Hugr,
374 ) -> Result<SimpleReplacement<N>, InvalidReplacement> {
375 let rep_root = replacement.root();
376 let dfg_optype = replacement.get_optype(rep_root);
377 if !OpTag::Dfg.is_superset(dfg_optype.tag()) {
378 return Err(InvalidReplacement::InvalidDataflowGraph {
379 node: rep_root,
380 op: dfg_optype.clone(),
381 });
382 }
383 let [rep_input, rep_output] = replacement
384 .get_io(rep_root)
385 .expect("DFG root in the replacement does not have input and output nodes.");
386
387 let current_signature = self.signature(hugr);
388 let new_signature = dfg_optype.dataflow_signature();
389 if new_signature.as_ref().map(|s| &s.input) != Some(¤t_signature.input)
390 || new_signature.as_ref().map(|s| &s.output) != Some(¤t_signature.output)
391 {
392 return Err(InvalidReplacement::InvalidSignature {
393 expected: self.signature(hugr),
394 actual: dfg_optype.dataflow_signature().map(|s| s.into_owned()),
395 });
396 }
397
398 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p));
401 let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p));
402 let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs.partition(|&(n, p)| {
403 replacement
404 .signature(n)
405 .is_some_and(|s| s.port_type(p).is_some())
406 });
407 let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| {
408 replacement
409 .signature(n)
410 .is_some_and(|s| s.port_type(p).is_some())
411 });
412
413 if iter_io(&vec![out_order_ports], &in_order_ports)
414 .any(|(n, p)| is_order_edge(&replacement, n, p))
415 {
416 unimplemented!("Found state order edges in replacement graph");
417 }
418
419 let nu_inp = rep_inputs
420 .into_iter()
421 .zip_eq(&self.inputs)
422 .flat_map(|((rep_source_n, rep_source_p), self_targets)| {
423 replacement
424 .linked_inputs(rep_source_n, rep_source_p)
425 .flat_map(move |rep_target| {
426 self_targets
427 .iter()
428 .map(move |&self_target| (rep_target, self_target))
429 })
430 })
431 .collect();
432 let nu_out = self
433 .outputs
434 .iter()
435 .zip_eq(rep_outputs)
436 .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| {
437 hugr.linked_inputs(self_source_n, self_source_p)
438 .map(move |self_target| (self_target, rep_target_p))
439 })
440 .collect();
441
442 Ok(SimpleReplacement::new(
443 self.clone(),
444 replacement,
445 nu_inp,
446 nu_out,
447 ))
448 }
449}
450
451impl SiblingSubgraph {
452 pub fn extract_subgraph(
457 &self,
458 hugr: &impl HugrView<Node = Node>,
459 name: impl Into<String>,
460 ) -> Hugr {
461 let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
462 let mut extracted = mem::take(builder.hugr_mut());
465 let node_map = extracted.insert_subgraph(extracted.root(), hugr, self);
466
467 let [inp, out] = extracted.get_io(extracted.root()).unwrap();
469 let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
470 let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
471 let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
472
473 for (inp_port, repl_ports) in inputs {
474 for (repl_node, repl_port) in repl_ports {
475 connections.push((inp, inp_port, node_map[repl_node], *repl_port));
476 }
477 }
478 for (out_port, (repl_node, repl_port)) in outputs {
479 connections.push((node_map[repl_node], *repl_port, out, out_port));
480 }
481
482 for (src, src_port, dst, dst_port) in connections {
483 extracted.connect(src, src_port, dst, dst_port);
484 }
485
486 extracted
487 }
488}
489
490fn iter_incoming<N: HugrNode>(
492 inputs: &IncomingPorts<N>,
493) -> impl Iterator<Item = (N, IncomingPort)> + '_ {
494 inputs.iter().flat_map(|part| part.iter().copied())
495}
496
497fn iter_outgoing<N: HugrNode>(
499 outputs: &OutgoingPorts<N>,
500) -> impl Iterator<Item = (N, OutgoingPort)> + '_ {
501 outputs.iter().copied()
502}
503
504fn iter_io<'a, N: HugrNode>(
506 inputs: &'a IncomingPorts<N>,
507 outputs: &'a OutgoingPorts<N>,
508) -> impl Iterator<Item = (N, Port)> + 'a {
509 iter_incoming(inputs)
510 .map(|(n, p)| (n, Port::from(p)))
511 .chain(iter_outgoing(outputs).map(|(n, p)| (n, Port::from(p))))
512}
513
514fn make_boundary<'a, N: HugrNode>(
515 hugr: &impl HugrView<Node = N>,
516 inputs: &'a IncomingPorts<N>,
517 outputs: &'a OutgoingPorts<N>,
518) -> Boundary {
519 let to_pg_index = |n: N, p: Port| {
520 hugr.portgraph()
521 .port_index(hugr.get_pg_index(n), p.pg_offset())
522 .unwrap()
523 };
524 Boundary::new(
525 iter_incoming(inputs).map(|(n, p)| to_pg_index(n, p.into())),
526 iter_outgoing(outputs).map(|(n, p)| to_pg_index(n, p.into())),
527 )
528}
529
/// A convexity checker for a HUGR view that wraps
/// `portgraph::algorithms::TopoConvexChecker`.
///
/// The wrapped checker is created on first use and then cached.
pub struct TopoConvexChecker<'g, Base: 'g + HugrView> {
    /// The HUGR view whose portgraph is checked.
    base: &'g Base,
    /// The lazily initialised portgraph checker.
    checker: OnceCell<portgraph::algorithms::TopoConvexChecker<Base::Portgraph<'g>>>,
}
538
impl<'g, Base: HugrView> TopoConvexChecker<'g, Base> {
    /// Creates a checker for `base`. The underlying portgraph checker is not
    /// built here.
    pub fn new(base: &'g Base) -> Self {
        Self {
            base,
            checker: OnceCell::new(),
        }
    }

    /// Returns the underlying portgraph checker, building it from the
    /// portgraph of `base` on the first call.
    fn get_checker(&self) -> &portgraph::algorithms::TopoConvexChecker<Base::Portgraph<'g>> {
        self.checker
            .get_or_init(|| portgraph::algorithms::TopoConvexChecker::new(self.base.portgraph()))
    }
}
554
555impl<Base: HugrView> ConvexChecker for TopoConvexChecker<'_, Base> {
556 fn is_convex(
557 &self,
558 nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
559 inputs: impl IntoIterator<Item = portgraph::PortIndex>,
560 outputs: impl IntoIterator<Item = portgraph::PortIndex>,
561 ) -> bool {
562 let mut nodes = nodes.into_iter().multipeek();
563 if nodes.peek().is_none() || nodes.peek().is_none() {
566 return true;
567 };
568 self.get_checker().is_convex(nodes, inputs, outputs)
569 }
570}
571
572fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(
576 hugr: &H,
577 ports: &[(H::Node, P)],
578) -> Option<Type> {
579 let &(n, p) = ports.first()?;
580 let edge_t = hugr.signature(n)?.port_type(p)?.clone();
581 ports
582 .iter()
583 .all(|&(n, p)| {
584 hugr.signature(n)
585 .is_some_and(|s| s.port_type(p) == Some(&edge_t))
586 })
587 .then_some(edge_t)
588}
589
590fn validate_subgraph<H: HugrView>(
597 hugr: &H,
598 nodes: &[H::Node],
599 inputs: &IncomingPorts<H::Node>,
600 outputs: &OutgoingPorts<H::Node>,
601) -> Result<(), InvalidSubgraph<H::Node>> {
602 let node_set = nodes.iter().copied().collect::<HashSet<_>>();
604
605 if nodes.is_empty() {
607 return Err(InvalidSubgraph::EmptySubgraph);
608 }
609 if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() {
611 let first_node = nodes[0];
612 let first_parent = hugr.get_parent(first_node);
613 let other_node = *nodes
614 .iter()
615 .skip(1)
616 .find(|&&n| hugr.get_parent(n) != first_parent)
617 .unwrap();
618 let other_parent = hugr.get_parent(other_node);
619 return Err(InvalidSubgraph::NoSharedParent {
620 first_node,
621 first_parent,
622 other_node,
623 other_parent,
624 });
625 }
626
627 if iter_io(inputs, outputs).any(|(n, p)| is_order_edge(hugr, n, p)) {
629 unimplemented!("Connected order edges not supported at the boundary")
630 }
631
632 let boundary_ports = iter_io(inputs, outputs).collect_vec();
633 if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
635 Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
636 };
637 if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
639 hugr.linked_ports(n, p)
640 .all(|(n1, _)| node_set.contains(&n1))
641 }) {
642 Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
643 };
644
645 if nodes.iter().any(|&n| {
648 hugr.node_inputs(n).any(|p| {
649 hugr.linked_ports(n, p).any(|(n1, _)| {
650 !node_set.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
651 })
652 })
653 }) {
654 return Err(InvalidSubgraph::NotConvex);
655 }
656 if nodes.iter().any(|&n| {
659 hugr.node_outputs(n).any(|p| {
660 hugr.linked_ports(n, p)
661 .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
662 })
663 }) {
664 return Err(InvalidSubgraph::NotConvex);
665 }
666
667 if !inputs.iter().flatten().all_unique() {
669 return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
670 }
671
672 if inputs.iter().any(|p| p.is_empty()) {
674 return Err(InvalidSubgraphBoundary::EmptyPartition.into());
675 }
676
677 if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
679 let Some(edge_t) = get_edge_type(hugr, ports) else {
680 return true;
681 };
682 let require_copy = ports.len() > 1;
683 require_copy && !edge_t.copyable()
684 }) {
685 Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
686 };
687
688 Ok(())
689}
690
691fn get_input_output_ports<H: HugrView>(
692 hugr: &H,
693) -> (IncomingPorts<H::Node>, OutgoingPorts<H::Node>) {
694 let [inp, out] = hugr.get_io(hugr.root()).expect("invalid DFG");
695 if has_other_edge(hugr, inp, Direction::Outgoing) {
696 unimplemented!("Non-dataflow output not supported at input node")
697 }
698 let dfg_inputs = hugr
699 .get_optype(inp)
700 .as_input()
701 .unwrap()
702 .signature()
703 .output_ports();
704 if has_other_edge(hugr, out, Direction::Incoming) {
705 unimplemented!("Non-dataflow input not supported at output node")
706 }
707 let dfg_outputs = hugr
708 .get_optype(out)
709 .as_output()
710 .unwrap()
711 .signature()
712 .input_ports();
713
714 let inputs = dfg_inputs
717 .into_iter()
718 .map(|p| {
719 hugr.linked_inputs(inp, p)
720 .filter(|&(n, _)| n != out)
721 .collect_vec()
722 })
723 .filter(|v| !v.is_empty())
724 .collect();
725 let outputs = dfg_outputs
728 .into_iter()
729 .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
730 .collect();
731 (inputs, outputs)
732}
733
734fn is_order_edge<H: HugrView>(hugr: &H, node: H::Node, port: Port) -> bool {
736 let op = hugr.get_optype(node);
737 op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port)
738}
739
740fn has_other_edge<H: HugrView>(hugr: &H, node: H::Node, dir: Direction) -> bool {
742 let op = hugr.get_optype(node);
743 op.other_port_kind(dir).is_some() && hugr.is_linked(node, op.other_port(dir).unwrap())
744}
745
/// Errors that can occur while building a replacement for a
/// [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
pub enum InvalidReplacement {
    /// The root of the replacement is not a dataflow graph operation.
    #[error("The root of the replacement {node} is a {}, but only OpType::DFGs are supported.", op.name())]
    InvalidDataflowGraph {
        /// The root node of the replacement.
        node: Node,
        /// The operation found at the root.
        op: OpType,
    },
    /// The input or output row of the replacement differs from that of the
    /// subgraph.
    #[error(
        "Replacement graph type mismatch. Expected {expected}, got {}.",
        actual.clone().map_or("none".to_string(), |t| t.to_string()))
    ]
    InvalidSignature {
        /// The signature of the subgraph being replaced.
        expected: Signature,
        /// The dataflow signature of the replacement root, if it has one.
        actual: Option<Signature>,
    },
    /// The subgraph is not convex. Not constructed by the code visible in
    /// this file.
    #[error("SiblingSubgraph is not convex.")]
    NonConvexSubgraph,
}
773
/// Errors that can occur while constructing a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraph<N: HugrNode = Node> {
    /// The subgraph is not convex, or a link leaves the node set through a
    /// port that is not on the boundary.
    #[error("The subgraph is not convex.")]
    NotConvex,
    /// Not all nodes of the subgraph have the same parent.
    #[error(
        "Not a sibling subgraph. {first_node} has parent {}, but {other_node} has parent {}.",
        first_parent.map_or("None".to_string(), |n| n.to_string()),
        other_parent.map_or("None".to_string(), |n| n.to_string())
    )]
    NoSharedParent {
        /// The first node of the subgraph.
        first_node: N,
        /// The parent of `first_node`, if any.
        first_parent: Option<N>,
        /// A node whose parent differs from `first_parent`.
        other_node: N,
        /// The parent of `other_node`, if any.
        other_parent: Option<N>,
    },
    /// The node set is empty.
    #[error("Empty subgraphs are not supported.")]
    EmptySubgraph,
    /// The boundary of the subgraph is invalid; see the wrapped error.
    #[error("Invalid boundary port.")]
    InvalidBoundary(#[from] InvalidSubgraphBoundary<N>),
}
804
/// Errors describing an invalid boundary of a [`SiblingSubgraph`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum InvalidSubgraphBoundary<N: HugrNode = Node> {
    /// A boundary port belongs to a node that is not in the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
    PortNodeNotInSet(N, Port),
    /// A boundary port has no link to a node outside the subgraph.
    #[error("(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph.")]
    DisconnectedBoundaryPort(N, Port),
    /// The same incoming port appears more than once in the input boundary.
    #[error("A port in the input boundary is used multiple times.")]
    NonUniqueInput,
    /// An input partition contains no ports.
    #[error("A partition in the input boundary is empty.")]
    EmptyPartition,
    /// The ports of the input partition at the given index do not share a
    /// type. `validate_subgraph` also reports this variant for a partition
    /// of several ports whose type is not copyable.
    #[error("The partition {0} in the input boundary has ports with different types.")]
    MismatchedTypes(usize),
}
825
826#[cfg(test)]
827mod tests {
828 use cool_asserts::assert_matches;
829
830 use crate::builder::inout_sig;
831 use crate::hugr::Rewrite;
832 use crate::ops::Const;
833 use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
834 use crate::std_extensions::logic::{self, LogicOp};
835 use crate::type_row;
836 use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64};
837 use crate::{
838 builder::{
839 BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
840 ModuleBuilder,
841 },
842 extension::prelude::{bool_t, qb_t},
843 hugr::views::{HierarchyView, SiblingGraph},
844 ops::handle::{DfgID, FuncID, NodeHandle},
845 std_extensions::logic::test::and_op,
846 };
847
848 use super::*;
849
    impl<N: HugrNode> SiblingSubgraph<N> {
        /// Test-only constructor: a subgraph made of every child of the root
        /// of `sibling_graph`, with an empty input and output boundary.
        ///
        /// No validation is performed other than rejecting an empty child
        /// list with [`InvalidSubgraph::EmptySubgraph`].
        fn from_sibling_graph(
            sibling_graph: &impl HugrView<Node = N>,
        ) -> Result<Self, InvalidSubgraph<N>> {
            let root = sibling_graph.root();
            let nodes = sibling_graph.children(root).collect_vec();
            if nodes.is_empty() {
                Err(InvalidSubgraph::EmptySubgraph)
            } else {
                Ok(Self {
                    nodes,
                    inputs: Vec::new(),
                    outputs: Vec::new(),
                })
            }
        }
    }
875
    /// Builds a module with one function "test" over three qubits.
    ///
    /// The function applies a CX gate to the first two qubits and an Rz gate
    /// to the third, with the Rz angle loaded from the constant 0.5. Returns
    /// the HUGR and the function node.
    fn build_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
                .with_extension_delta(ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    float_types::EXTENSION_ID,
                ]))
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [w0, w1, w2] = dfg.input_wires_arr();
            let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
            let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
            let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
            dfg.finish_with_outputs([w0, w1, w2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }
902
    /// Builds a module with one function "test" that chains three NOT
    /// operations on a single boolean. Returns the HUGR and the function
    /// node.
    fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![bool_t()])
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?;
            let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?;
            let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?;
            dfg.finish_with_outputs(outs3.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }
924
    /// Builds a module with one function "test" from one boolean to two
    /// booleans.
    ///
    /// The output of the first NOT is used twice: as input of the second NOT
    /// and as the first function output. Returns the HUGR and the function
    /// node.
    fn build_multiport_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new(bool_t(), vec![bool_t(), bool_t()])
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [b0] = dfg.input_wires_arr();
            let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr();
            let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr();
            dfg.finish_with_outputs([b1, b2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }
946
    /// Builds a module with one function "test" on a single boolean, which
    /// feeds the same input wire to both inputs of an AND operation. Returns
    /// the HUGR and the function node.
    fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(bool_t())
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let in_wire = dfg.input_wires().exactly_one().unwrap();
            let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?;
            dfg.finish_with_outputs(outs.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }
967
968 #[test]
969 fn construct_subgraph() -> Result<(), InvalidSubgraph> {
970 let (hugr, func_root) = build_hugr().unwrap();
971 let sibling_graph: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
972 let from_root = SiblingSubgraph::from_sibling_graph(&sibling_graph)?;
973 let region: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
974 let from_region = SiblingSubgraph::from_sibling_graph(®ion)?;
975 assert_eq!(
976 from_root.get_parent(&sibling_graph),
977 from_region.get_parent(&sibling_graph)
978 );
979 assert_eq!(
980 from_root.signature(&sibling_graph),
981 from_region.signature(&sibling_graph)
982 );
983 Ok(())
984 }
985
    /// Replaces the whole body of the test function with an identity DFG on
    /// three qubits and checks the node count before and after.
    #[test]
    fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
        let (mut hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;

        // A DFG that forwards its three inputs unchanged.
        let empty_dfg = {
            let builder =
                DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();

        assert_eq!(rep.subgraph().nodes().len(), 4);

        // Presumably module root, function, input, output and the four ops
        // of the body — verify against `build_hugr`.
        assert_eq!(hugr.node_count(), 8);
        hugr.apply_rewrite(rep).unwrap();
        // The four subgraph nodes are gone after the rewrite.
        assert_eq!(hugr.node_count(), 4);
        Ok(())
    }
1009
    /// The signature of the dataflow subgraph of the test function is the
    /// three-qubit endo-signature with the quantum and float extensions.
    #[test]
    fn test_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, dfg).unwrap();
        let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;
        assert_eq!(
            sub.signature(&func),
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta(
                ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    float_types::EXTENSION_ID,
                ])
            )
        );
        Ok(())
    }
1026
    /// A replacement whose signature (one qubit) does not match the
    /// subgraph signature is rejected with `InvalidSignature`.
    #[test]
    fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, dfg).unwrap();
        let sub = SiblingSubgraph::from_sibling_graph(&func)?;

        // An identity DFG on a single qubit.
        let empty_dfg = {
            let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        assert_matches!(
            sub.create_simple_replacement(&func, empty_dfg).unwrap_err(),
            InvalidReplacement::InvalidSignature { .. }
        );
        Ok(())
    }
1045
    /// The dataflow subgraph of the test function is accepted and has four
    /// nodes.
    #[test]
    fn convex_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        assert_eq!(
            SiblingSubgraph::try_new_dataflow_subgraph(&func)
                .unwrap()
                .nodes()
                .len(),
            4
        )
    }
1058
    /// A subgraph bounded by the first two outputs of the input node and the
    /// first two inputs of the output node is accepted.
    #[test]
    fn convex_subgraph_2() {
        let (hugr, func_root) = build_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        SiblingSubgraph::try_new(
            // Inputs: targets of the first two wires leaving the input node.
            hugr.node_outputs(inp)
                .take(2)
                .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                .filter(|ps| !ps.is_empty())
                .collect(),
            // Outputs: sources of the first two wires entering the output
            // node.
            hugr.node_inputs(out)
                .take(2)
                .filter_map(|p| hugr.single_linked_output(out, p))
                .collect(),
            &func,
        )
        .unwrap();
    }
1079
    /// A boundary that uses both ends of the same edge (the first wire
    /// leaving the input node) is rejected with `DisconnectedBoundaryPort`.
    #[test]
    fn degen_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, _] = hugr.get_io(func_root).unwrap();
        let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
        assert_matches!(
            SiblingSubgraph::try_new(
                // Inputs: the targets of the edge.
                vec![hugr
                    .linked_ports(inp, first_cx_edge)
                    .map(|(n, p)| (n, p.as_incoming().unwrap()))
                    .collect()],
                // Outputs: the source of the same edge.
                vec![(inp, first_cx_edge)],
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }
1101
    /// In a chain of three NOTs, the subgraph made of the first and third
    /// NOT is rejected as non-convex.
    #[test]
    fn non_convex_subgraph() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        // Walk the chain from the input node.
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
        let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();
        let not1_inp = hugr.node_inputs(not1).next().unwrap();
        let not1_out = hugr.node_outputs(not1).next().unwrap();
        let not3_inp = hugr.node_inputs(not3).next().unwrap();
        let not3_out = hugr.node_outputs(not3).next().unwrap();
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
                vec![(not1, not1_out), (not3, not3_out)],
                &func
            ),
            Err(InvalidSubgraph::NotConvex)
        );
    }
1123
    /// A subgraph built from the two NOT nodes of the multiport HUGR is
    /// accepted, even though the output of the first NOT is also used
    /// outside the subgraph.
    #[test]
    fn convex_multiports() {
        let (hugr, func_root) = build_multiport_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        // The first NOT also feeds the output node; skip that neighbour.
        let not2 = hugr
            .output_neighbours(not1)
            .filter(|&n| n != out)
            .exactly_one()
            .ok()
            .unwrap();

        let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
        assert_eq!(subgraph.nodes(), [not1, not2]);
    }
1141
    /// A boundary with inputs and outputs swapped (inputs taken from the
    /// output node, outputs from the input node) is rejected with
    /// `DisconnectedBoundaryPort`.
    #[test]
    fn invalid_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let cx_edges_in = hugr.node_outputs(inp);
        let cx_edges_out = hugr.node_inputs(out);
        assert_matches!(
            SiblingSubgraph::try_new(
                cx_edges_out.map(|p| vec![(out, p)]).collect(),
                cx_edges_in.map(|p| (inp, p)).collect(),
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }
1161
    /// The signature of the dataflow subgraph equals the signature of the
    /// enclosing function definition, even when one function input is used
    /// twice.
    #[test]
    fn preserve_signature() {
        let (hugr, func_root) = build_hugr_classical().unwrap();
        let func_graph: SiblingGraph<'_, FuncID<true>> =
            SiblingGraph::try_new(&hugr, func_root).unwrap();
        let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
        let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
        assert_eq!(func_defn.signature, func.signature(&func_graph).into());
    }
1171
    /// Extracting the dataflow subgraph of the test function yields a HUGR
    /// that passes validation.
    #[test]
    fn extract_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func_graph: SiblingGraph<'_, FuncID<true>> =
            SiblingGraph::try_new(&hugr, func_root).unwrap();
        let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
        let extracted = subgraph.extract_subgraph(&hugr, "region");

        extracted.validate().unwrap();
    }
1182
    /// A wire that is both a graph output and an input of another node in
    /// the subgraph: the NOT output feeds the AND and the first DFG output.
    /// The dataflow subgraph is accepted and has two nodes.
    #[test]
    fn edge_both_output_and_copy() {
        let one_bit = vec![bool_t()];
        let two_bit = vec![bool_t(), bool_t()];

        let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap();
        let inw = builder.input_wires().exactly_one().unwrap();
        let outw1 = builder
            .add_dataflow_op(LogicOp::Not, [inw])
            .unwrap()
            .out_wire(0);
        let outw2 = builder
            .add_dataflow_op(and_op(), [inw, outw1])
            .unwrap()
            .outputs();
        let outw = [outw1].into_iter().chain(outw2);
        let h = builder.finish_hugr_with_outputs(outw).unwrap();
        let view = SiblingGraph::<DfgID>::try_new(&h, h.root()).unwrap();
        let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap();
        assert_eq!(subg.nodes().len(), 2);
    }
1205
    /// A single NOT whose output is not connected can still be replaced by
    /// a DFG with a boolean output.
    #[test]
    fn test_unconnected() {
        // A DFG with one boolean input and no outputs; the NOT output is
        // left dangling.
        let mut b = DFGBuilder::new(
            Signature::new(bool_t(), type_row![])
                .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
        )
        .unwrap();
        let inw = b.input_wires().exactly_one().unwrap();
        let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
        let mut h = b.finish_hugr_with_outputs([]).unwrap();

        let subg = SiblingSubgraph::from_node(not_n.node(), &h);

        assert_eq!(subg.nodes().len(), 1);
        // The replacement is a DFG containing a single NOT.
        let replacement = {
            let mut rep_b = DFGBuilder::new(
                Signature::new_endo(bool_t())
                    .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
            )
            .unwrap();
            let inw = rep_b.input_wires().exactly_one().unwrap();

            let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();

            rep_b.finish_hugr_with_outputs(not_n.outputs()).unwrap()
        };
        let rep = subg.create_simple_replacement(&h, replacement).unwrap();
        rep.apply(&mut h).unwrap();
    }
1239
    /// Compares the two ways of building a single-node subgraph for a NOT
    /// whose output is not connected.
    #[test]
    fn single_node_subgraph() {
        // A DFG with one boolean input and no outputs; the NOT output is
        // left dangling.
        let mut b = DFGBuilder::new(
            Signature::new(bool_t(), type_row![])
                .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
        )
        .unwrap();
        let inw = b.input_wires().exactly_one().unwrap();
        let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
        let h = b.finish_hugr_with_outputs([]).unwrap();

        // `from_node` keeps the unconnected output in the boundary.
        let subg = SiblingSubgraph::from_node(not_n.node(), &h);
        assert_eq!(subg.nodes().len(), 1);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new(vec![bool_t()], vec![bool_t()]).io()
        );

        // `try_from_nodes` drops the unconnected output from the boundary.
        let subg = SiblingSubgraph::try_from_nodes([not_n.node()], &h).unwrap();
        assert_eq!(subg.nodes().len(), 1);
        assert_eq!(
            subg.signature(&h).io(),
            Signature::new(vec![bool_t()], vec![]).io()
        );
    }
1272}