1use std::cell::OnceCell;
13use std::collections::HashSet;
14use std::mem;
15
16use itertools::Itertools;
17use portgraph::algorithms::ConvexChecker;
18use portgraph::boundary::Boundary;
19use portgraph::{view::Subgraph, Direction, PortView};
20use thiserror::Error;
21
22use crate::builder::{Container, FunctionBuilder};
23use crate::extension::ExtensionSet;
24use crate::hugr::{HugrMut, HugrView, RootTagged};
25use crate::ops::dataflow::DataflowOpTrait;
26use crate::ops::handle::{ContainerHandle, DataflowOpID};
27use crate::ops::{NamedOp, OpTag, OpTrait, OpType};
28use crate::types::{Signature, Type};
29use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement};
30
31#[derive(Clone, Debug)]
58pub struct SiblingSubgraph {
59 nodes: Vec<Node>,
61 inputs: Vec<Vec<(Node, IncomingPort)>>,
66 outputs: Vec<(Node, OutgoingPort)>,
71}
72
73pub type IncomingPorts = Vec<Vec<(Node, IncomingPort)>>;
80pub type OutgoingPorts = Vec<(Node, OutgoingPort)>;
82
83impl SiblingSubgraph {
84 pub fn try_new_dataflow_subgraph<H, Root>(dfg_graph: &H) -> Result<Self, InvalidSubgraph>
98 where
99 H: Clone + RootTagged<RootHandle = Root>,
100 Root: ContainerHandle<ChildrenHandle = DataflowOpID>,
101 {
102 let parent = dfg_graph.root();
103 let nodes = dfg_graph.children(parent).skip(2).collect_vec();
104 let (inputs, outputs) = get_input_output_ports(dfg_graph);
105
106 validate_subgraph(dfg_graph, &nodes, &inputs, &outputs)?;
107
108 if nodes.is_empty() {
109 Err(InvalidSubgraph::EmptySubgraph)
110 } else {
111 Ok(Self {
112 nodes,
113 inputs,
114 outputs,
115 })
116 }
117 }
118
119 pub fn try_new(
159 incoming: IncomingPorts,
160 outgoing: OutgoingPorts,
161 hugr: &impl HugrView,
162 ) -> Result<Self, InvalidSubgraph> {
163 let checker = TopoConvexChecker::new(hugr);
164 Self::try_new_with_checker(incoming, outgoing, hugr, &checker)
165 }
166
167 pub fn try_new_with_checker(
176 inputs: IncomingPorts,
177 outputs: OutgoingPorts,
178 hugr: &impl HugrView,
179 checker: &impl ConvexChecker,
180 ) -> Result<Self, InvalidSubgraph> {
181 let pg = hugr.portgraph();
182
183 let subpg = Subgraph::new_subgraph(pg.clone(), make_boundary(hugr, &inputs, &outputs));
185 let nodes = subpg.nodes_iter().map_into().collect_vec();
186 validate_subgraph(hugr, &nodes, &inputs, &outputs)?;
187
188 if !subpg.is_convex_with_checker(checker) {
189 return Err(InvalidSubgraph::NotConvex);
190 }
191
192 Ok(Self {
193 nodes,
194 inputs,
195 outputs,
196 })
197 }
198
199 pub fn try_from_nodes(
215 nodes: impl Into<Vec<Node>>,
216 hugr: &impl HugrView,
217 ) -> Result<Self, InvalidSubgraph> {
218 let checker = TopoConvexChecker::new(hugr);
219 Self::try_from_nodes_with_checker(nodes, hugr, &checker)
220 }
221
222 pub fn try_from_nodes_with_checker<'c, 'h: 'c, H: HugrView>(
231 nodes: impl Into<Vec<Node>>,
232 hugr: &'h H,
233 checker: &impl ConvexChecker,
234 ) -> Result<Self, InvalidSubgraph> {
235 let nodes = nodes.into();
236
237 match nodes.as_slice() {
239 [] => return Err(InvalidSubgraph::EmptySubgraph),
240 [node] => return Ok(Self::from_node(*node, hugr)),
241 _ => {}
242 };
243
244 let nodes_set = nodes.iter().copied().collect::<HashSet<_>>();
245 let incoming_edges = nodes
246 .iter()
247 .flat_map(|&n| hugr.node_inputs(n).map(move |p| (n, p)));
248 let outgoing_edges = nodes
249 .iter()
250 .flat_map(|&n| hugr.node_outputs(n).map(move |p| (n, p)));
251 let inputs = incoming_edges
252 .filter(|&(n, p)| {
253 if !hugr.is_linked(n, p) {
254 return false;
255 }
256 let (out_n, _) = hugr.single_linked_output(n, p).unwrap();
257 !nodes_set.contains(&out_n)
258 })
259 .map(|p| vec![p])
261 .collect_vec();
262 let outputs = outgoing_edges
263 .filter(|&(n, p)| {
264 hugr.linked_ports(n, p)
265 .any(|(n1, _)| !nodes_set.contains(&n1))
266 })
267 .collect_vec();
268 Self::try_new_with_checker(inputs, outputs, hugr, checker)
269 }
270
271 pub fn from_node(node: Node, hugr: &impl HugrView) -> Self {
275 let nodes = vec![node];
276 let inputs = hugr
277 .node_inputs(node)
278 .filter(|&p| hugr.is_linked(node, p))
279 .map(|p| vec![(node, p)])
280 .collect_vec();
281 let outputs = hugr
282 .node_outputs(node)
283 .filter_map(|p| {
284 {
286 hugr.is_linked(node, p)
287 || hugr
288 .get_optype(node)
289 .port_kind(p)
290 .is_some_and(|k| k.is_value())
291 }
292 .then_some((node, p))
293 })
294 .collect_vec();
295
296 Self {
297 nodes,
298 inputs,
299 outputs,
300 }
301 }
302
303 pub fn nodes(&self) -> &[Node] {
305 &self.nodes
306 }
307
308 pub fn node_count(&self) -> usize {
310 self.nodes.len()
311 }
312
313 pub fn incoming_ports(&self) -> &IncomingPorts {
315 &self.inputs
316 }
317
318 pub fn outgoing_ports(&self) -> &OutgoingPorts {
320 &self.outputs
321 }
322
323 pub fn signature(&self, hugr: &impl HugrView) -> Signature {
325 let input = self
326 .inputs
327 .iter()
328 .map(|part| {
329 let &(n, p) = part.iter().next().expect("is non-empty");
330 let sig = hugr.signature(n).expect("must have dataflow signature");
331 sig.port_type(p).cloned().expect("must be dataflow edge")
332 })
333 .collect_vec();
334 let output = self
335 .outputs
336 .iter()
337 .map(|&(n, p)| {
338 let sig = hugr.signature(n).expect("must have dataflow signature");
339 sig.port_type(p).cloned().expect("must be dataflow edge")
340 })
341 .collect_vec();
342 Signature::new(input, output).with_extension_delta(ExtensionSet::union_over(
343 self.nodes
344 .iter()
345 .map(|n| hugr.get_optype(*n).extension_delta()),
346 ))
347 }
348
349 pub fn get_parent(&self, hugr: &impl HugrView) -> Node {
351 hugr.get_parent(self.nodes[0]).expect("invalid subgraph")
352 }
353
354 pub fn create_simple_replacement(
370 &self,
371 hugr: &impl HugrView,
372 replacement: Hugr,
373 ) -> Result<SimpleReplacement, InvalidReplacement> {
374 let rep_root = replacement.root();
375 let dfg_optype = replacement.get_optype(rep_root);
376 if !OpTag::Dfg.is_superset(dfg_optype.tag()) {
377 return Err(InvalidReplacement::InvalidDataflowGraph {
378 node: rep_root,
379 op: dfg_optype.clone(),
380 });
381 }
382 let [rep_input, rep_output] = replacement
383 .get_io(rep_root)
384 .expect("DFG root in the replacement does not have input and output nodes.");
385
386 let current_signature = self.signature(hugr);
387 let new_signature = dfg_optype.dataflow_signature();
388 if new_signature.as_ref().map(|s| &s.input) != Some(¤t_signature.input)
389 || new_signature.as_ref().map(|s| &s.output) != Some(¤t_signature.output)
390 {
391 return Err(InvalidReplacement::InvalidSignature {
392 expected: self.signature(hugr),
393 actual: dfg_optype.dataflow_signature().map(|s| s.into_owned()),
394 });
395 }
396
397 let rep_inputs = replacement.node_outputs(rep_input).map(|p| (rep_input, p));
400 let rep_outputs = replacement.node_inputs(rep_output).map(|p| (rep_output, p));
401 let (rep_inputs, in_order_ports): (Vec<_>, Vec<_>) = rep_inputs.partition(|&(n, p)| {
402 replacement
403 .signature(n)
404 .is_some_and(|s| s.port_type(p).is_some())
405 });
406 let (rep_outputs, out_order_ports): (Vec<_>, Vec<_>) = rep_outputs.partition(|&(n, p)| {
407 replacement
408 .signature(n)
409 .is_some_and(|s| s.port_type(p).is_some())
410 });
411
412 if iter_io(&vec![out_order_ports], &in_order_ports)
413 .any(|(n, p)| is_order_edge(&replacement, n, p))
414 {
415 unimplemented!("Found state order edges in replacement graph");
416 }
417
418 let nu_inp = rep_inputs
419 .into_iter()
420 .zip_eq(&self.inputs)
421 .flat_map(|((rep_source_n, rep_source_p), self_targets)| {
422 replacement
423 .linked_inputs(rep_source_n, rep_source_p)
424 .flat_map(move |rep_target| {
425 self_targets
426 .iter()
427 .map(move |&self_target| (rep_target, self_target))
428 })
429 })
430 .collect();
431 let nu_out = self
432 .outputs
433 .iter()
434 .zip_eq(rep_outputs)
435 .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| {
436 hugr.linked_inputs(self_source_n, self_source_p)
437 .map(move |self_target| (self_target, rep_target_p))
438 })
439 .collect();
440
441 Ok(SimpleReplacement::new(
442 self.clone(),
443 replacement,
444 nu_inp,
445 nu_out,
446 ))
447 }
448
449 pub fn extract_subgraph(&self, hugr: &impl HugrView, name: impl Into<String>) -> Hugr {
454 let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap();
455 let mut extracted = mem::take(builder.hugr_mut());
458 let node_map = extracted.insert_subgraph(extracted.root(), hugr, self);
459
460 let [inp, out] = extracted.get_io(extracted.root()).unwrap();
462 let inputs = extracted.node_outputs(inp).zip(self.inputs.iter());
463 let outputs = extracted.node_inputs(out).zip(self.outputs.iter());
464 let mut connections = Vec::with_capacity(inputs.size_hint().0 + outputs.size_hint().0);
465
466 for (inp_port, repl_ports) in inputs {
467 for (repl_node, repl_port) in repl_ports {
468 connections.push((inp, inp_port, node_map[repl_node], *repl_port));
469 }
470 }
471 for (out_port, (repl_node, repl_port)) in outputs {
472 connections.push((node_map[repl_node], *repl_port, out, out_port));
473 }
474
475 for (src, src_port, dst, dst_port) in connections {
476 extracted.connect(src, src_port, dst, dst_port);
477 }
478
479 extracted
480 }
481}
482
483fn iter_incoming(inputs: &IncomingPorts) -> impl Iterator<Item = (Node, IncomingPort)> + '_ {
485 inputs.iter().flat_map(|part| part.iter().copied())
486}
487
488fn iter_outgoing(outputs: &OutgoingPorts) -> impl Iterator<Item = (Node, OutgoingPort)> + '_ {
490 outputs.iter().copied()
491}
492
493fn iter_io<'a>(
495 inputs: &'a IncomingPorts,
496 outputs: &'a OutgoingPorts,
497) -> impl Iterator<Item = (Node, Port)> + 'a {
498 iter_incoming(inputs)
499 .map(|(n, p)| (n, Port::from(p)))
500 .chain(iter_outgoing(outputs).map(|(n, p)| (n, Port::from(p))))
501}
502
503fn make_boundary<'a>(
504 hugr: &impl HugrView,
505 inputs: &'a IncomingPorts,
506 outputs: &'a OutgoingPorts,
507) -> Boundary {
508 let to_pg_index = |n: Node, p: Port| {
509 hugr.portgraph()
510 .port_index(n.pg_index(), p.pg_offset())
511 .unwrap()
512 };
513 Boundary::new(
514 iter_incoming(inputs).map(|(n, p)| to_pg_index(n, p.into())),
515 iter_outgoing(outputs).map(|(n, p)| to_pg_index(n, p.into())),
516 )
517}
518
519pub struct TopoConvexChecker<'g, Base: 'g + HugrView> {
524 base: &'g Base,
525 checker: OnceCell<portgraph::algorithms::TopoConvexChecker<Base::Portgraph<'g>>>,
526}
527
528impl<'g, Base: HugrView> TopoConvexChecker<'g, Base> {
529 pub fn new(base: &'g Base) -> Self {
531 Self {
532 base,
533 checker: OnceCell::new(),
534 }
535 }
536
537 fn get_checker(&self) -> &portgraph::algorithms::TopoConvexChecker<Base::Portgraph<'g>> {
539 self.checker
540 .get_or_init(|| portgraph::algorithms::TopoConvexChecker::new(self.base.portgraph()))
541 }
542}
543
544impl<Base: HugrView> ConvexChecker for TopoConvexChecker<'_, Base> {
545 fn is_convex(
546 &self,
547 nodes: impl IntoIterator<Item = portgraph::NodeIndex>,
548 inputs: impl IntoIterator<Item = portgraph::PortIndex>,
549 outputs: impl IntoIterator<Item = portgraph::PortIndex>,
550 ) -> bool {
551 let mut nodes = nodes.into_iter().multipeek();
552 if nodes.peek().is_none() || nodes.peek().is_none() {
555 return true;
556 };
557 self.get_checker().is_convex(nodes, inputs, outputs)
558 }
559}
560
561fn get_edge_type<H: HugrView, P: Into<Port> + Copy>(hugr: &H, ports: &[(Node, P)]) -> Option<Type> {
565 let &(n, p) = ports.first()?;
566 let edge_t = hugr.signature(n)?.port_type(p)?.clone();
567 ports
568 .iter()
569 .all(|&(n, p)| {
570 hugr.signature(n)
571 .is_some_and(|s| s.port_type(p) == Some(&edge_t))
572 })
573 .then_some(edge_t)
574}
575
576fn validate_subgraph<H: HugrView>(
583 hugr: &H,
584 nodes: &[Node],
585 inputs: &IncomingPorts,
586 outputs: &OutgoingPorts,
587) -> Result<(), InvalidSubgraph> {
588 let node_set = nodes.iter().copied().collect::<HashSet<_>>();
590
591 if nodes.is_empty() {
593 return Err(InvalidSubgraph::EmptySubgraph);
594 }
595 if !nodes.iter().map(|&n| hugr.get_parent(n)).all_equal() {
597 let first_node = nodes[0];
598 let first_parent = hugr.get_parent(first_node);
599 let other_node = *nodes
600 .iter()
601 .skip(1)
602 .find(|&&n| hugr.get_parent(n) != first_parent)
603 .unwrap();
604 let other_parent = hugr.get_parent(other_node);
605 return Err(InvalidSubgraph::NoSharedParent {
606 first_node,
607 first_parent,
608 other_node,
609 other_parent,
610 });
611 }
612
613 if iter_io(inputs, outputs).any(|(n, p)| is_order_edge(hugr, n, p)) {
615 unimplemented!("Connected order edges not supported at the boundary")
616 }
617
618 let boundary_ports = iter_io(inputs, outputs).collect_vec();
619 if let Some(&(n, p)) = boundary_ports.iter().find(|(n, _)| !node_set.contains(n)) {
621 Err(InvalidSubgraphBoundary::PortNodeNotInSet(n, p))?;
622 };
623 if let Some(&(n, p)) = boundary_ports.iter().find(|&&(n, p)| {
625 hugr.linked_ports(n, p)
626 .all(|(n1, _)| node_set.contains(&n1))
627 }) {
628 Err(InvalidSubgraphBoundary::DisconnectedBoundaryPort(n, p))?;
629 };
630
631 if nodes.iter().any(|&n| {
634 hugr.node_inputs(n).any(|p| {
635 hugr.linked_ports(n, p).any(|(n1, _)| {
636 !node_set.contains(&n1) && !inputs.iter().any(|nps| nps.contains(&(n, p)))
637 })
638 })
639 }) {
640 return Err(InvalidSubgraph::NotConvex);
641 }
642 if nodes.iter().any(|&n| {
645 hugr.node_outputs(n).any(|p| {
646 hugr.linked_ports(n, p)
647 .any(|(n1, _)| !node_set.contains(&n1) && !outputs.contains(&(n, p)))
648 })
649 }) {
650 return Err(InvalidSubgraph::NotConvex);
651 }
652
653 if !inputs.iter().flatten().all_unique() {
655 return Err(InvalidSubgraphBoundary::NonUniqueInput.into());
656 }
657
658 if inputs.iter().any(|p| p.is_empty()) {
660 return Err(InvalidSubgraphBoundary::EmptyPartition.into());
661 }
662
663 if let Some((i, _)) = inputs.iter().enumerate().find(|(_, ports)| {
665 let Some(edge_t) = get_edge_type(hugr, ports) else {
666 return true;
667 };
668 let require_copy = ports.len() > 1;
669 require_copy && !edge_t.copyable()
670 }) {
671 Err(InvalidSubgraphBoundary::MismatchedTypes(i))?;
672 };
673
674 Ok(())
675}
676
677fn get_input_output_ports<H: HugrView>(hugr: &H) -> (IncomingPorts, OutgoingPorts) {
678 let [inp, out] = hugr.get_io(hugr.root()).expect("invalid DFG");
679 if has_other_edge(hugr, inp, Direction::Outgoing) {
680 unimplemented!("Non-dataflow output not supported at input node")
681 }
682 let dfg_inputs = hugr
683 .get_optype(inp)
684 .as_input()
685 .unwrap()
686 .signature()
687 .output_ports();
688 if has_other_edge(hugr, out, Direction::Incoming) {
689 unimplemented!("Non-dataflow input not supported at output node")
690 }
691 let dfg_outputs = hugr
692 .get_optype(out)
693 .as_output()
694 .unwrap()
695 .signature()
696 .input_ports();
697
698 let inputs = dfg_inputs
701 .into_iter()
702 .map(|p| {
703 hugr.linked_inputs(inp, p)
704 .filter(|&(n, _)| n != out)
705 .collect_vec()
706 })
707 .filter(|v| !v.is_empty())
708 .collect();
709 let outputs = dfg_outputs
712 .into_iter()
713 .filter_map(|p| hugr.linked_outputs(out, p).find(|&(n, _)| n != inp))
714 .collect();
715 (inputs, outputs)
716}
717
718fn is_order_edge<H: HugrView>(hugr: &H, node: Node, port: Port) -> bool {
720 let op = hugr.get_optype(node);
721 op.other_port(port.direction()) == Some(port) && hugr.is_linked(node, port)
722}
723
724fn has_other_edge<H: HugrView>(hugr: &H, node: Node, dir: Direction) -> bool {
726 let op = hugr.get_optype(node);
727 op.other_port_kind(dir).is_some() && hugr.is_linked(node, op.other_port(dir).unwrap())
728}
729
730#[derive(Debug, Clone, PartialEq, Error)]
732#[non_exhaustive]
733pub enum InvalidReplacement {
734 #[error("The root of the replacement {node} is a {}, but only OpType::DFGs are supported.", op.name())]
736 InvalidDataflowGraph {
737 node: Node,
739 op: OpType,
741 },
742 #[error(
744 "Replacement graph type mismatch. Expected {expected}, got {}.",
745 actual.clone().map_or("none".to_string(), |t| t.to_string()))
746 ]
747 InvalidSignature {
748 expected: Signature,
750 actual: Option<Signature>,
752 },
753 #[error("SiblingSubgraph is not convex.")]
755 NonConvexSubgraph,
756}
757
758#[derive(Debug, Clone, PartialEq, Eq, Error)]
760#[non_exhaustive]
761pub enum InvalidSubgraph {
762 #[error("The subgraph is not convex.")]
764 NotConvex,
765 #[error(
767 "Not a sibling subgraph. {first_node} has parent {}, but {other_node} has parent {}.",
768 first_parent.map_or("None".to_string(), |n| n.to_string()),
769 other_parent.map_or("None".to_string(), |n| n.to_string())
770 )]
771 NoSharedParent {
772 first_node: Node,
774 first_parent: Option<Node>,
776 other_node: Node,
778 other_parent: Option<Node>,
780 },
781 #[error("Empty subgraphs are not supported.")]
783 EmptySubgraph,
784 #[error("Invalid boundary port.")]
786 InvalidBoundary(#[from] InvalidSubgraphBoundary),
787}
788
789#[derive(Debug, Clone, PartialEq, Eq, Error)]
791#[non_exhaustive]
792pub enum InvalidSubgraphBoundary {
793 #[error("(node {0}, port {1}) is in the boundary, but node {0} is not in the set.")]
795 PortNodeNotInSet(Node, Port),
796 #[error("(node {0}, port {1}) is in the boundary, but the port is not connected to a node outside the subgraph.")]
798 DisconnectedBoundaryPort(Node, Port),
799 #[error("A port in the input boundary is used multiple times.")]
801 NonUniqueInput,
802 #[error("A partition in the input boundary is empty.")]
804 EmptyPartition,
805 #[error("The partition {0} in the input boundary has ports with different types.")]
807 MismatchedTypes(usize),
808}
809
#[cfg(test)]
mod tests {
    use cool_asserts::assert_matches;

    use crate::builder::inout_sig;
    use crate::hugr::Rewrite;
    use crate::ops::Const;
    use crate::std_extensions::arithmetic::float_types::{self, ConstF64};
    use crate::std_extensions::logic::{self, LogicOp};
    use crate::type_row;
    use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64};
    use crate::{
        builder::{
            BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
            ModuleBuilder,
        },
        extension::prelude::{bool_t, qb_t},
        hugr::views::{HierarchyView, SiblingGraph},
        ops::handle::{DfgID, FuncID, NodeHandle},
        std_extensions::logic::test::and_op,
    };

    use super::*;

    impl SiblingSubgraph {
        /// Test helper: a subgraph containing every child of `sibling_graph`'s
        /// root (Input and Output included) with an empty boundary.
        ///
        /// The result is not validated.
        fn from_sibling_graph(sibling_graph: &impl HugrView) -> Result<Self, InvalidSubgraph> {
            let root = sibling_graph.root();
            let nodes = sibling_graph.children(root).collect_vec();
            if nodes.is_empty() {
                Err(InvalidSubgraph::EmptySubgraph)
            } else {
                Ok(Self {
                    nodes,
                    inputs: Vec::new(),
                    outputs: Vec::new(),
                })
            }
        }
    }

    /// A module with a function on three qubits: a CX on the first two, and an
    /// Rz on the third with a constant angle of 0.5.
    fn build_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])
                .with_extension_delta(ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    float_types::EXTENSION_ID,
                ]))
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [w0, w1, w2] = dfg.input_wires_arr();
            let [w0, w1] = dfg.add_dataflow_op(cx_gate(), [w0, w1])?.outputs_arr();
            let c = dfg.add_load_const(Const::new(ConstF64::new(0.5).into()));
            let [w2] = dfg.add_dataflow_op(rz_f64(), [w2, c])?.outputs_arr();
            dfg.finish_with_outputs([w0, w1, w2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// A module with a function applying three `Not`s in sequence to a bool.
    fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(vec![bool_t()])
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?;
            let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?;
            let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?;
            dfg.finish_with_outputs(outs3.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// A module with a function `bool -> (bool, bool)` computing `b1 = not b0`
    /// and `b2 = not b1`, and returning `(b1, b2)`; `b1` has two consumers.
    fn build_multiport_hugr() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new(bool_t(), vec![bool_t(), bool_t()])
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let [b0] = dfg.input_wires_arr();
            let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr();
            let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr();
            dfg.finish_with_outputs([b1, b2])?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    /// A module with a function `bool -> bool` computing the `and` of its
    /// input with itself (the input wire is copied).
    fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> {
        let mut mod_builder = ModuleBuilder::new();
        let func = mod_builder.declare(
            "test",
            Signature::new_endo(bool_t())
                .with_extension_delta(logic::EXTENSION_ID)
                .into(),
        )?;
        let func_id = {
            let mut dfg = mod_builder.define_declaration(&func)?;
            let in_wire = dfg.input_wires().exactly_one().unwrap();
            let outs = dfg.add_dataflow_op(and_op(), [in_wire, in_wire])?;
            dfg.finish_with_outputs(outs.outputs())?
        };
        let hugr = mod_builder
            .finish_hugr()
            .map_err(|e| -> BuildError { e.into() })?;
        Ok((hugr, func_id.node()))
    }

    #[test]
    fn construct_subgraph() -> Result<(), InvalidSubgraph> {
        let (hugr, func_root) = build_hugr().unwrap();
        let sibling_graph: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let from_root = SiblingSubgraph::from_sibling_graph(&sibling_graph)?;
        let region: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let from_region = SiblingSubgraph::from_sibling_graph(&region)?;
        assert_eq!(
            from_root.get_parent(&sibling_graph),
            from_region.get_parent(&sibling_graph)
        );
        assert_eq!(
            from_root.signature(&sibling_graph),
            from_region.signature(&sibling_graph)
        );
        Ok(())
    }

    #[test]
    fn construct_simple_replacement() -> Result<(), InvalidSubgraph> {
        let (mut hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;

        let empty_dfg = {
            let builder =
                DFGBuilder::new(Signature::new_endo(vec![qb_t(), qb_t(), qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        let rep = sub.create_simple_replacement(&func, empty_dfg).unwrap();

        assert_eq!(rep.subgraph().nodes().len(), 4);

        // Module root, function definition, its Input and Output nodes, and
        // the four nodes of the subgraph.
        assert_eq!(hugr.node_count(), 8);
        hugr.apply_rewrite(rep).unwrap();
        // The subgraph's four nodes were replaced by an empty dataflow graph.
        assert_eq!(hugr.node_count(), 4);
        Ok(())
    }

    #[test]
    fn test_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, dfg).unwrap();
        let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?;
        assert_eq!(
            sub.signature(&func),
            Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta(
                ExtensionSet::from_iter([
                    test_quantum_extension::EXTENSION_ID,
                    float_types::EXTENSION_ID,
                ])
            )
        );
        Ok(())
    }

    #[test]
    fn construct_simple_replacement_invalid_signature() -> Result<(), InvalidSubgraph> {
        let (hugr, dfg) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, dfg).unwrap();
        let sub = SiblingSubgraph::from_sibling_graph(&func)?;

        let empty_dfg = {
            let builder = DFGBuilder::new(Signature::new_endo(vec![qb_t()])).unwrap();
            let inputs = builder.input_wires();
            builder.finish_hugr_with_outputs(inputs).unwrap()
        };

        assert_matches!(
            sub.create_simple_replacement(&func, empty_dfg).unwrap_err(),
            InvalidReplacement::InvalidSignature { .. }
        );
        Ok(())
    }

    #[test]
    fn convex_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        assert_eq!(
            SiblingSubgraph::try_new_dataflow_subgraph(&func)
                .unwrap()
                .nodes()
                .len(),
            4
        )
    }

    #[test]
    fn convex_subgraph_2() {
        let (hugr, func_root) = build_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        // The boundary covers only the first two wires (the CX gate).
        SiblingSubgraph::try_new(
            hugr.node_outputs(inp)
                .take(2)
                .map(|p| hugr.linked_inputs(inp, p).collect_vec())
                .filter(|ps| !ps.is_empty())
                .collect(),
            hugr.node_inputs(out)
                .take(2)
                .filter_map(|p| hugr.single_linked_output(out, p))
                .collect(),
            &func,
        )
        .unwrap();
    }

    #[test]
    fn degen_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, _] = hugr.get_io(func_root).unwrap();
        let first_cx_edge = hugr.node_outputs(inp).next().unwrap();
        // An output boundary port on the Input node itself must be rejected.
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![hugr
                    .linked_ports(inp, first_cx_edge)
                    .map(|(n, p)| (n, p.as_incoming().unwrap()))
                    .collect()],
                vec![(inp, first_cx_edge)],
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }

    #[test]
    fn non_convex_subgraph() {
        let (hugr, func_root) = build_3not_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, _out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr.output_neighbours(not1).exactly_one().ok().unwrap();
        let not3 = hugr.output_neighbours(not2).exactly_one().ok().unwrap();
        let not1_inp = hugr.node_inputs(not1).next().unwrap();
        let not1_out = hugr.node_outputs(not1).next().unwrap();
        let not3_inp = hugr.node_inputs(not3).next().unwrap();
        let not3_out = hugr.node_outputs(not3).next().unwrap();
        // {not1, not3} without not2 between them is not convex.
        assert_matches!(
            SiblingSubgraph::try_new(
                vec![vec![(not1, not1_inp)], vec![(not3, not3_inp)]],
                vec![(not1, not1_out), (not3, not3_out)],
                &func
            ),
            Err(InvalidSubgraph::NotConvex)
        );
    }

    #[test]
    fn convex_multiports() {
        let (hugr, func_root) = build_multiport_hugr().unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let not1 = hugr.output_neighbours(inp).exactly_one().ok().unwrap();
        let not2 = hugr
            .output_neighbours(not1)
            .filter(|&n| n != out)
            .exactly_one()
            .ok()
            .unwrap();

        // not1's output feeds both not2 and the Output node.
        let subgraph = SiblingSubgraph::try_from_nodes([not1, not2], &hugr).unwrap();
        assert_eq!(subgraph.nodes(), [not1, not2]);
    }

    #[test]
    fn invalid_boundary() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func: SiblingGraph<'_> = SiblingGraph::try_new(&hugr, func_root).unwrap();
        let [inp, out] = hugr.get_io(func_root).unwrap();
        let cx_edges_in = hugr.node_outputs(inp);
        let cx_edges_out = hugr.node_inputs(out);
        // The boundary is given backwards: the Output node's ports as inputs
        // and the Input node's ports as outputs.
        assert_matches!(
            SiblingSubgraph::try_new(
                cx_edges_out.map(|p| vec![(out, p)]).collect(),
                cx_edges_in.map(|p| (inp, p)).collect(),
                &func,
            ),
            Err(InvalidSubgraph::InvalidBoundary(
                InvalidSubgraphBoundary::DisconnectedBoundaryPort(_, _)
            ))
        );
    }

    #[test]
    fn preserve_signature() {
        let (hugr, func_root) = build_hugr_classical().unwrap();
        let func_graph: SiblingGraph<'_, FuncID<true>> =
            SiblingGraph::try_new(&hugr, func_root).unwrap();
        let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
        let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap();
        assert_eq!(func_defn.signature, func.signature(&func_graph).into());
    }

    #[test]
    fn extract_subgraph() {
        let (hugr, func_root) = build_hugr().unwrap();
        let func_graph: SiblingGraph<'_, FuncID<true>> =
            SiblingGraph::try_new(&hugr, func_root).unwrap();
        let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap();
        let extracted = subgraph.extract_subgraph(&hugr, "region");

        extracted.validate().unwrap();
    }

    #[test]
    fn edge_both_output_and_copy() {
        let one_bit = vec![bool_t()];
        let two_bit = vec![bool_t(), bool_t()];

        // `outw1` is both a DFG output and an input to the `and` op.
        let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap();
        let inw = builder.input_wires().exactly_one().unwrap();
        let outw1 = builder
            .add_dataflow_op(LogicOp::Not, [inw])
            .unwrap()
            .out_wire(0);
        let outw2 = builder
            .add_dataflow_op(and_op(), [inw, outw1])
            .unwrap()
            .outputs();
        let outw = [outw1].into_iter().chain(outw2);
        let h = builder.finish_hugr_with_outputs(outw).unwrap();
        let view = SiblingGraph::<DfgID>::try_new(&h, h.root()).unwrap();
        let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap();
        assert_eq!(subg.nodes().len(), 2);
    }

    #[test]
    fn test_unconnected() {
        // A DFG whose `Not` output is left unconnected.
        let mut b = DFGBuilder::new(
            Signature::new(bool_t(), type_row![])
                .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
        )
        .unwrap();
        let inw = b.input_wires().exactly_one().unwrap();
        let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();
        let mut h = b.finish_hugr_with_outputs([]).unwrap();

        let subg = SiblingSubgraph::from_node(not_n.node(), &h);

        assert_eq!(subg.nodes().len(), 1);
        let replacement = {
            let mut rep_b = DFGBuilder::new(
                Signature::new_endo(bool_t())
                    .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID),
            )
            .unwrap();
            let inw = rep_b.input_wires().exactly_one().unwrap();

            let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap();

            rep_b.finish_hugr_with_outputs(not_n.outputs()).unwrap()
        };
        let rep = subg.create_simple_replacement(&h, replacement).unwrap();
        rep.apply(&mut h).unwrap();
    }
}