hugr_core/hugr/patch/
insert_cut.rs

1//! Patch for inserting a sub-HUGR as a "cut" across existing edges.
2
3use std::collections::HashMap;
4use std::iter;
5
6use crate::core::HugrNode;
7use crate::hugr::patch::inline_dfg::InlineDFG;
8use crate::hugr::{HugrMut, Node};
9use crate::ops::{DataflowOpTrait, OpTag, OpTrait, OpType};
10
11use crate::{Hugr, HugrView, IncomingPort};
12
13use super::inline_dfg::InlineDFGError;
14use super::{Patch, PatchHugrMut, PatchVerification};
15
16use itertools::Itertools;
17use thiserror::Error;
18
19/// Implementation of the `InsertCut` operation.
20///
21/// The `InsertCut` operation allows inserting a HUGR sub-graph as a "cut" between existing nodes in a dataflow graph.
22/// It effectively intercepts connections between nodes by inserting the new Hugr in between them.
23///
24/// This patch operation works by:
25/// 1. Inserting a new HUGR as a child of the specified parent
26/// 2. Redirecting existing connections through the newly inserted HUGR.
27pub struct InsertCut<N = Node> {
28    /// The parent node to insert the new HUGR under.
29    pub parent: N,
30    /// The targets of the existing edges.
31    pub targets: Vec<(N, IncomingPort)>,
32    /// The HUGR to insert, must have  DFG root.
33    pub insertion: Hugr,
34}
35
36impl<N> InsertCut<N> {
37    /// Create a new [`InsertCut`] specification.
38    pub fn new(parent: N, targets: Vec<(N, IncomingPort)>, insertion: Hugr) -> Self {
39        Self {
40            parent,
41            targets,
42            insertion,
43        }
44    }
45}
46/// Error from an [`InsertCut`] operation.
47#[derive(Debug, Clone, Error, PartialEq)]
48#[non_exhaustive]
49pub enum InsertCutError<N = Node> {
50    /// Invalid parent node.
51    #[error("Parent node is invalid.")]
52    InvalidParentNode,
53    /// Invalid node.
54    #[error("HUGR graph does not contain node: {0}.")]
55    InvalidNode(N),
56
57    /// Replacement HUGR not a DFG.
58    #[error("Parent node is not a DFG, found root optype: {0}.")]
59    ReplaceNotDfg(OpType),
60
61    /// Inline error.
62    #[error("Inlining inserting DFG failed: {0}.")]
63    InlineFailed(#[from] InlineDFGError),
64
65    /// Port connection error.
66    #[error("Incoming port has {0} connections, expected exactly 1.")]
67    InvalidIncomingPort(usize),
68
69    /// Target number mismatch.
70    #[error("Target number mismatch, expected {0}, found {1}.")]
71    TargetNumberMismatch(usize, usize),
72
73    /// Input/Output mismatch.
74    #[error("Replacement DFG must have the same number of inputs and outputs.")]
75    InputOutputMismatch,
76}
77
78impl<N: HugrNode> PatchVerification for InsertCut<N> {
79    type Error = InsertCutError<N>;
80    type Node = N;
81
82    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
83        let insert_root = self.insertion.entrypoint_optype();
84        let Some(dfg) = insert_root.as_dfg() else {
85            return Err(InsertCutError::ReplaceNotDfg(insert_root.clone()));
86        };
87
88        let sig = dfg.signature();
89        if sig.input().len() != sig.output().len() {
90            return Err(InsertCutError::InputOutputMismatch);
91        }
92        if sig.input().len() != self.targets.len() {
93            return Err(InsertCutError::TargetNumberMismatch(
94                sig.input().len(),
95                self.targets.len(),
96            ));
97        }
98        if !h.contains_node(self.parent) {
99            return Err(InsertCutError::InvalidNode(self.parent));
100        }
101        let parent_op = h.get_optype(self.parent);
102        if !OpTag::DataflowParent.is_superset(parent_op.tag()) {
103            return Err(InsertCutError::InvalidParentNode);
104        }
105
106        // Verify that each target node exists and each target port is valid
107        for (node, port) in &self.targets {
108            if !h.contains_node(*node) {
109                return Err(InsertCutError::InvalidNode(*node));
110            }
111
112            let n_links = h.linked_outputs(*node, *port).count();
113            if n_links != 1 {
114                return Err(InsertCutError::InvalidIncomingPort(n_links));
115            }
116        }
117        Ok(())
118    }
119
120    #[inline]
121    fn invalidated_nodes(
122        &self,
123        _: &impl HugrView<Node = Self::Node>,
124    ) -> impl Iterator<Item = Self::Node> {
125        iter::once(self.parent)
126            .chain(self.targets.iter().map(|(n, _)| *n))
127            .unique()
128    }
129}
130impl PatchHugrMut for InsertCut<Node> {
131    type Outcome = HashMap<Node, Node>;
132    const UNCHANGED_ON_FAILURE: bool = false;
133
134    fn apply_hugr_mut(
135        self,
136        h: &mut impl HugrMut<Node = Node>,
137    ) -> Result<Self::Outcome, InsertCutError> {
138        let insert_res = h.insert_hugr(self.parent, self.insertion);
139        let inserted_entrypoint = insert_res.inserted_entrypoint;
140        for (i, (target, port)) in self.targets.into_iter().enumerate() {
141            let (src_n, src_p) = h
142                .single_linked_output(target, port)
143                .expect("Incoming value edge has single connection.");
144            h.disconnect(target, port);
145            h.connect(src_n, src_p, inserted_entrypoint, i);
146            h.connect(inserted_entrypoint, i, target, port);
147        }
148        let inline = InlineDFG(inserted_entrypoint.into());
149
150        inline.apply(h)?;
151        Ok(insert_res.node_map)
152    }
153}
154
155#[cfg(test)]
156mod tests {
157    use rstest::rstest;
158
159    use super::*;
160    use crate::{
161        builder::{DFGBuilder, Dataflow, DataflowHugr},
162        extension::prelude::{Noop, bool_t, qb_t},
163        types::Signature,
164    };
165
166    #[rstest]
167    fn test_insert_cut() {
168        let dfg_b = DFGBuilder::new(Signature::new_endo(vec![bool_t(), qb_t()])).unwrap();
169        let inputs = dfg_b.input().outputs();
170        let mut h = dfg_b.finish_hugr_with_outputs(inputs).unwrap();
171        let [i, o] = h.get_io(h.entrypoint()).unwrap();
172
173        let mut dfg_b = DFGBuilder::new(Signature::new_endo(vec![bool_t(), qb_t()])).unwrap();
174        let [b, q] = dfg_b.input().outputs_arr();
175        let noop1 = dfg_b.add_dataflow_op(Noop::new(bool_t()), [b]).unwrap();
176        let noop2 = dfg_b.add_dataflow_op(Noop::new(qb_t()), [q]).unwrap();
177
178        let replacement = dfg_b
179            .finish_hugr_with_outputs([noop1.out_wire(0), noop2.out_wire(0)])
180            .unwrap();
181
182        let targets: Vec<_> = h.all_linked_inputs(i).collect();
183        let inserter = InsertCut::new(h.entrypoint(), targets, replacement);
184        assert_eq!(
185            inserter.invalidated_nodes(&h).collect::<Vec<Node>>(),
186            vec![h.entrypoint(), o]
187        );
188
189        inserter.verify(&h).unwrap();
190
191        assert_eq!(h.entry_descendants().count(), 3);
192        inserter.apply_hugr_mut(&mut h).unwrap();
193
194        h.validate().unwrap();
195        assert_eq!(h.entry_descendants().count(), 5);
196    }
197}