hugr_core/hugr/patch/
insert_identity.rs1use std::iter;
4
5use crate::core::HugrNode;
6use crate::extension::prelude::Noop;
7use crate::hugr::{HugrMut, Node};
8use crate::ops::{OpTag, OpTrait};
9
10use crate::types::EdgeKind;
11use crate::{HugrView, IncomingPort};
12
13use super::{PatchHugrMut, PatchVerification};
14
15use thiserror::Error;
16
17#[derive(Debug, Clone)]
19pub struct IdentityInsertion<N = Node> {
20 pub post_node: N,
22 pub post_port: IncomingPort,
24}
25
26impl<N> IdentityInsertion<N> {
27 pub fn new(post_node: N, post_port: IncomingPort) -> Self {
29 Self {
30 post_node,
31 post_port,
32 }
33 }
34}
35
36#[derive(Debug, Clone, Error, PartialEq, Eq)]
38#[non_exhaustive]
39pub enum IdentityInsertionError {
40 #[error("Parent node is invalid.")]
42 InvalidParentNode,
43 #[error("Node is invalid.")]
45 InvalidNode(),
46 #[error("post_port has invalid kind {}. Must be Value.", _0.as_ref().map_or("None".to_string(), ToString::to_string))]
48 InvalidPortKind(Option<EdgeKind>),
49}
50
51impl<N: HugrNode> PatchVerification for IdentityInsertion<N> {
52 type Error = IdentityInsertionError;
53 type Node = N;
54
55 fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
56 unimplemented!()
67 }
68
69 #[inline]
70 fn invalidation_set(&self) -> impl Iterator<Item = N> {
71 iter::once(self.post_node)
72 }
73}
74
75impl<N: HugrNode> PatchHugrMut for IdentityInsertion<N> {
76 type Outcome = N;
78
79 const UNCHANGED_ON_FAILURE: bool = true;
80
81 fn apply_hugr_mut(
82 self,
83 h: &mut impl HugrMut<Node = N>,
84 ) -> Result<Self::Outcome, IdentityInsertionError> {
85 let kind = h.get_optype(self.post_node).port_kind(self.post_port);
86 let Some(EdgeKind::Value(ty)) = kind else {
87 return Err(IdentityInsertionError::InvalidPortKind(kind));
88 };
89
90 let (pre_node, pre_port) = h
91 .single_linked_output(self.post_node, self.post_port)
92 .expect("Value kind input can only have one connection.");
93
94 h.disconnect(self.post_node, self.post_port);
95 let parent = h
96 .get_parent(self.post_node)
97 .ok_or(IdentityInsertionError::InvalidParentNode)?;
98 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
99 return Err(IdentityInsertionError::InvalidParentNode);
100 }
101 let new_node = h.add_node_with_parent(parent, Noop(ty));
102 h.connect(pre_node, pre_port, new_node, 0);
103
104 h.connect(new_node, 0, self.post_node, self.post_port);
105 Ok(new_node)
106 }
107}
108
109#[cfg(test)]
110mod tests {
111 use rstest::rstest;
112
113 use super::super::simple_replace::test::dfg_hugr;
114 use super::*;
115 use crate::{Hugr, extension::prelude::qb_t};
116
117 #[rstest]
118 fn correct_insertion(dfg_hugr: Hugr) {
119 let mut h = dfg_hugr;
120
121 assert_eq!(h.entry_descendants().count(), 6);
122
123 let final_node = h
124 .input_neighbours(h.get_io(h.entrypoint()).unwrap()[1])
125 .next()
126 .unwrap();
127
128 let final_node_port = h.node_inputs(final_node).next().unwrap();
129
130 let rw = IdentityInsertion::new(final_node, final_node_port);
131
132 let noop_node = h.apply_patch(rw).unwrap();
133
134 assert_eq!(h.entry_descendants().count(), 7);
135
136 let noop: Noop = h.get_optype(noop_node).cast().unwrap();
137
138 assert_eq!(noop, Noop(qb_t()));
139
140 h.validate().unwrap();
141 }
142}