hugr_core/hugr/patch/
insert_cut.rs1use std::collections::HashMap;
4use std::iter;
5
6use crate::core::HugrNode;
7use crate::hugr::patch::inline_dfg::InlineDFG;
8use crate::hugr::{HugrMut, Node};
9use crate::ops::{DataflowOpTrait, OpTag, OpTrait, OpType};
10
11use crate::{Hugr, HugrView, IncomingPort};
12
13use super::inline_dfg::InlineDFGError;
14use super::{Patch, PatchHugrMut, PatchVerification};
15
16use itertools::Itertools;
17use thiserror::Error;
18
19pub struct InsertCut<N = Node> {
28 pub parent: N,
30 pub targets: Vec<(N, IncomingPort)>,
32 pub insertion: Hugr,
34}
35
36impl<N> InsertCut<N> {
37 pub fn new(parent: N, targets: Vec<(N, IncomingPort)>, insertion: Hugr) -> Self {
39 Self {
40 parent,
41 targets,
42 insertion,
43 }
44 }
45}
46#[derive(Debug, Clone, Error, PartialEq)]
48#[non_exhaustive]
49pub enum InsertCutError<N = Node> {
50 #[error("Parent node is invalid.")]
52 InvalidParentNode,
53 #[error("HUGR graph does not contain node: {0}.")]
55 InvalidNode(N),
56
57 #[error("Parent node is not a DFG, found root optype: {0}.")]
59 ReplaceNotDfg(OpType),
60
61 #[error("Inlining inserting DFG failed: {0}.")]
63 InlineFailed(#[from] InlineDFGError),
64
65 #[error("Incoming port has {0} connections, expected exactly 1.")]
67 InvalidIncomingPort(usize),
68
69 #[error("Target number mismatch, expected {0}, found {1}.")]
71 TargetNumberMismatch(usize, usize),
72
73 #[error("Replacement DFG must have the same number of inputs and outputs.")]
75 InputOutputMismatch,
76}
77
78impl<N: HugrNode> PatchVerification for InsertCut<N> {
79 type Error = InsertCutError<N>;
80 type Node = N;
81
82 fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
83 let insert_root = self.insertion.entrypoint_optype();
84 let Some(dfg) = insert_root.as_dfg() else {
85 return Err(InsertCutError::ReplaceNotDfg(insert_root.clone()));
86 };
87
88 let sig = dfg.signature();
89 if sig.input().len() != sig.output().len() {
90 return Err(InsertCutError::InputOutputMismatch);
91 }
92 if sig.input().len() != self.targets.len() {
93 return Err(InsertCutError::TargetNumberMismatch(
94 sig.input().len(),
95 self.targets.len(),
96 ));
97 }
98 if !h.contains_node(self.parent) {
99 return Err(InsertCutError::InvalidNode(self.parent));
100 }
101 let parent_op = h.get_optype(self.parent);
102 if !OpTag::DataflowParent.is_superset(parent_op.tag()) {
103 return Err(InsertCutError::InvalidParentNode);
104 }
105
106 for (node, port) in &self.targets {
108 if !h.contains_node(*node) {
109 return Err(InsertCutError::InvalidNode(*node));
110 }
111
112 let n_links = h.linked_outputs(*node, *port).count();
113 if n_links != 1 {
114 return Err(InsertCutError::InvalidIncomingPort(n_links));
115 }
116 }
117 Ok(())
118 }
119
120 #[inline]
121 fn invalidated_nodes(
122 &self,
123 _: &impl HugrView<Node = Self::Node>,
124 ) -> impl Iterator<Item = Self::Node> {
125 iter::once(self.parent)
126 .chain(self.targets.iter().map(|(n, _)| *n))
127 .unique()
128 }
129}
130impl PatchHugrMut for InsertCut<Node> {
131 type Outcome = HashMap<Node, Node>;
132 const UNCHANGED_ON_FAILURE: bool = false;
133
134 fn apply_hugr_mut(
135 self,
136 h: &mut impl HugrMut<Node = Node>,
137 ) -> Result<Self::Outcome, InsertCutError> {
138 let insert_res = h.insert_hugr(self.parent, self.insertion);
139 let inserted_entrypoint = insert_res.inserted_entrypoint;
140 for (i, (target, port)) in self.targets.into_iter().enumerate() {
141 let (src_n, src_p) = h
142 .single_linked_output(target, port)
143 .expect("Incoming value edge has single connection.");
144 h.disconnect(target, port);
145 h.connect(src_n, src_p, inserted_entrypoint, i);
146 h.connect(inserted_entrypoint, i, target, port);
147 }
148 let inline = InlineDFG(inserted_entrypoint.into());
149
150 inline.apply(h)?;
151 Ok(insert_res.node_map)
152 }
153}
154
155#[cfg(test)]
156mod tests {
157 use rstest::rstest;
158
159 use super::*;
160 use crate::{
161 builder::{DFGBuilder, Dataflow, DataflowHugr},
162 extension::prelude::{Noop, bool_t, qb_t},
163 types::Signature,
164 };
165
166 #[rstest]
167 fn test_insert_cut() {
168 let dfg_b = DFGBuilder::new(Signature::new_endo(vec![bool_t(), qb_t()])).unwrap();
169 let inputs = dfg_b.input().outputs();
170 let mut h = dfg_b.finish_hugr_with_outputs(inputs).unwrap();
171 let [i, o] = h.get_io(h.entrypoint()).unwrap();
172
173 let mut dfg_b = DFGBuilder::new(Signature::new_endo(vec![bool_t(), qb_t()])).unwrap();
174 let [b, q] = dfg_b.input().outputs_arr();
175 let noop1 = dfg_b.add_dataflow_op(Noop::new(bool_t()), [b]).unwrap();
176 let noop2 = dfg_b.add_dataflow_op(Noop::new(qb_t()), [q]).unwrap();
177
178 let replacement = dfg_b
179 .finish_hugr_with_outputs([noop1.out_wire(0), noop2.out_wire(0)])
180 .unwrap();
181
182 let targets: Vec<_> = h.all_linked_inputs(i).collect();
183 let inserter = InsertCut::new(h.entrypoint(), targets, replacement);
184 assert_eq!(
185 inserter.invalidated_nodes(&h).collect::<Vec<Node>>(),
186 vec![h.entrypoint(), o]
187 );
188
189 inserter.verify(&h).unwrap();
190
191 assert_eq!(h.entry_descendants().count(), 3);
192 inserter.apply_hugr_mut(&mut h).unwrap();
193
194 h.validate().unwrap();
195 assert_eq!(h.entry_descendants().count(), 5);
196 }
197}