hugr_core/hugr/patch/
insert_cut.rs

1//! Patch for inserting a sub-HUGR as a "cut" across existing edges.
2
3use std::collections::HashMap;
4use std::iter;
5
6use crate::core::HugrNode;
7use crate::hugr::patch::inline_dfg::InlineDFG;
8use crate::hugr::{HugrMut, Node};
9use crate::ops::{DataflowOpTrait, OpTag, OpTrait, OpType};
10
11use crate::{Hugr, HugrView, IncomingPort};
12
13use super::inline_dfg::InlineDFGError;
14use super::{Patch, PatchHugrMut, PatchVerification};
15
16use itertools::Itertools;
17use thiserror::Error;
18
19/// Implementation of the `InsertCut` operation.
20///
21/// The `InsertCut` operation allows inserting a HUGR sub-graph as a "cut" between existing nodes in a dataflow graph.
22/// It effectively intercepts connections between nodes by inserting the new Hugr in between them.
23///
24/// This patch operation works by:
25/// 1. Inserting a new HUGR as a child of the specified parent
26/// 2. Redirecting existing connections through the newly inserted HUGR.
27pub struct InsertCut<N = Node> {
28    /// The parent node to insert the new HUGR under.
29    pub parent: N,
30    /// The targets of the existing edges.
31    pub targets: Vec<(N, IncomingPort)>,
32    /// The HUGR to insert, must have  DFG root.
33    pub insertion: Hugr,
34}
35
36impl<N> InsertCut<N> {
37    /// Create a new [`InsertCut`] specification.
38    pub fn new(parent: N, targets: Vec<(N, IncomingPort)>, insertion: Hugr) -> Self {
39        Self {
40            parent,
41            targets,
42            insertion,
43        }
44    }
45}
46/// Error from an [`InsertCut`] operation.
47#[derive(Debug, Clone, Error, PartialEq)]
48#[non_exhaustive]
49pub enum InsertCutError<N = Node> {
50    /// Invalid parent node.
51    #[error("Parent node is invalid.")]
52    InvalidParentNode,
53    /// Invalid node.
54    #[error("HUGR graph does not contain node: {0}.")]
55    InvalidNode(N),
56
57    /// Replacement HUGR not a DFG.
58    #[error("Parent node is not a DFG, found root optype: {0}.")]
59    ReplaceNotDfg(OpType),
60
61    /// Inline error.
62    #[error("Inlining inserting DFG failed: {0}.")]
63    InlineFailed(#[from] InlineDFGError),
64
65    /// Port connection error.
66    #[error("Incoming port has {0} connections, expected exactly 1.")]
67    InvalidIncomingPort(usize),
68
69    /// Target number mismatch.
70    #[error("Target number mismatch, expected {0}, found {1}.")]
71    TargetNumberMismatch(usize, usize),
72
73    /// Input/Output mismatch.
74    #[error("Replacement DFG must have the same number of inputs and outputs.")]
75    InputOutputMismatch,
76}
77
78impl<N: HugrNode> PatchVerification for InsertCut<N> {
79    type Error = InsertCutError<N>;
80    type Node = N;
81
82    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
83        let insert_root = self.insertion.entrypoint_optype();
84        let Some(dfg) = insert_root.as_dfg() else {
85            return Err(InsertCutError::ReplaceNotDfg(insert_root.clone()));
86        };
87
88        let sig = dfg.signature();
89        if sig.input().len() != sig.output().len() {
90            return Err(InsertCutError::InputOutputMismatch);
91        }
92        if sig.input().len() != self.targets.len() {
93            return Err(InsertCutError::TargetNumberMismatch(
94                sig.input().len(),
95                self.targets.len(),
96            ));
97        }
98        if !h.contains_node(self.parent) {
99            return Err(InsertCutError::InvalidNode(self.parent));
100        }
101        let parent_op = h.get_optype(self.parent);
102        if !OpTag::DataflowParent.is_superset(parent_op.tag()) {
103            return Err(InsertCutError::InvalidParentNode);
104        }
105
106        // Verify that each target node exists and each target port is valid
107        for (node, port) in &self.targets {
108            if !h.contains_node(*node) {
109                return Err(InsertCutError::InvalidNode(*node));
110            }
111
112            let n_links = h.linked_outputs(*node, *port).count();
113            if n_links != 1 {
114                return Err(InsertCutError::InvalidIncomingPort(n_links));
115            }
116        }
117        Ok(())
118    }
119
120    #[inline]
121    fn invalidation_set(&self) -> impl Iterator<Item = N> {
122        iter::once(self.parent)
123            .chain(self.targets.iter().map(|(n, _)| *n))
124            .unique()
125    }
126}
127impl PatchHugrMut for InsertCut<Node> {
128    type Outcome = HashMap<Node, Node>;
129    const UNCHANGED_ON_FAILURE: bool = false;
130
131    fn apply_hugr_mut(
132        self,
133        h: &mut impl HugrMut<Node = Node>,
134    ) -> Result<Self::Outcome, InsertCutError> {
135        let insert_res = h.insert_hugr(self.parent, self.insertion);
136        let inserted_entrypoint = insert_res.inserted_entrypoint;
137        for (i, (target, port)) in self.targets.into_iter().enumerate() {
138            let (src_n, src_p) = h
139                .single_linked_output(target, port)
140                .expect("Incoming value edge has single connection.");
141            h.disconnect(target, port);
142            h.connect(src_n, src_p, inserted_entrypoint, i);
143            h.connect(inserted_entrypoint, i, target, port);
144        }
145        let inline = InlineDFG(inserted_entrypoint.into());
146
147        inline.apply(h)?;
148        Ok(insert_res.node_map)
149    }
150}
151
152#[cfg(test)]
153mod tests {
154    use rstest::rstest;
155
156    use super::*;
157    use crate::{
158        builder::{DFGBuilder, Dataflow, DataflowHugr},
159        extension::prelude::{Noop, bool_t, qb_t},
160        types::Signature,
161    };
162
163    #[rstest]
164    fn test_insert_cut() {
165        let dfg_b = DFGBuilder::new(Signature::new_endo(vec![bool_t(), qb_t()])).unwrap();
166        let inputs = dfg_b.input().outputs();
167        let mut h = dfg_b.finish_hugr_with_outputs(inputs).unwrap();
168        let [i, o] = h.get_io(h.entrypoint()).unwrap();
169
170        let mut dfg_b = DFGBuilder::new(Signature::new_endo(vec![bool_t(), qb_t()])).unwrap();
171        let [b, q] = dfg_b.input().outputs_arr();
172        let noop1 = dfg_b.add_dataflow_op(Noop::new(bool_t()), [b]).unwrap();
173        let noop2 = dfg_b.add_dataflow_op(Noop::new(qb_t()), [q]).unwrap();
174
175        let replacement = dfg_b
176            .finish_hugr_with_outputs([noop1.out_wire(0), noop2.out_wire(0)])
177            .unwrap();
178
179        let targets: Vec<_> = h.all_linked_inputs(i).collect();
180        let inserter = InsertCut::new(h.entrypoint(), targets, replacement);
181        assert_eq!(
182            inserter.invalidation_set().collect::<Vec<Node>>(),
183            vec![h.entrypoint(), o]
184        );
185
186        inserter.verify(&h).unwrap();
187
188        assert_eq!(h.entry_descendants().count(), 3);
189        inserter.apply_hugr_mut(&mut h).unwrap();
190
191        h.validate().unwrap();
192        assert_eq!(h.entry_descendants().count(), 5);
193    }
194}