hugr_core/hugr/rewrite/
insert_identity.rs1use std::iter;
4
5use crate::extension::prelude::Noop;
6use crate::hugr::{HugrMut, Node};
7use crate::ops::{OpTag, OpTrait};
8
9use crate::types::EdgeKind;
10use crate::{HugrView, IncomingPort};
11
12use super::Rewrite;
13
14use thiserror::Error;
15
16#[derive(Debug, Clone)]
18pub struct IdentityInsertion {
19 pub post_node: Node,
21 pub post_port: IncomingPort,
23}
24
25impl IdentityInsertion {
26 pub fn new(post_node: Node, post_port: IncomingPort) -> Self {
28 Self {
29 post_node,
30 post_port,
31 }
32 }
33}
34
35#[derive(Debug, Clone, Error, PartialEq, Eq)]
37#[non_exhaustive]
38pub enum IdentityInsertionError {
39 #[error("Parent node is invalid.")]
41 InvalidParentNode,
42 #[error("Node is invalid.")]
44 InvalidNode(),
45 #[error("post_port has invalid kind {}. Must be Value.", _0.as_ref().map_or("None".to_string(), ToString::to_string))]
47 InvalidPortKind(Option<EdgeKind>),
48}
49
50impl Rewrite for IdentityInsertion {
51 type Error = IdentityInsertionError;
52 type ApplyResult = Node;
54 const UNCHANGED_ON_FAILURE: bool = true;
55 fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
56 unimplemented!()
67 }
68 fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, IdentityInsertionError> {
69 let kind = h.get_optype(self.post_node).port_kind(self.post_port);
70 let Some(EdgeKind::Value(ty)) = kind else {
71 return Err(IdentityInsertionError::InvalidPortKind(kind));
72 };
73
74 let (pre_node, pre_port) = h
75 .single_linked_output(self.post_node, self.post_port)
76 .expect("Value kind input can only have one connection.");
77
78 h.disconnect(self.post_node, self.post_port);
79 let parent = h
80 .get_parent(self.post_node)
81 .ok_or(IdentityInsertionError::InvalidParentNode)?;
82 if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
83 return Err(IdentityInsertionError::InvalidParentNode);
84 }
85 let new_node = h.add_node_with_parent(parent, Noop(ty));
86 h.connect(pre_node, pre_port, new_node, 0);
87
88 h.connect(new_node, 0, self.post_node, self.post_port);
89 Ok(new_node)
90 }
91
92 #[inline]
93 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
94 iter::once(self.post_node)
95 }
96}
97
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::super::simple_replace::test::dfg_hugr;
    use super::*;
    use crate::{extension::prelude::qb_t, Hugr};

    /// Inserting an identity in front of the first input of the node that
    /// feeds the DFG's output node adds exactly one qubit `Noop` and keeps the
    /// Hugr valid.
    #[rstest]
    fn correct_insertion(dfg_hugr: Hugr) {
        let mut hugr = dfg_hugr;
        assert_eq!(hugr.node_count(), 6);

        let output = hugr.get_io(hugr.root()).unwrap()[1];
        let target = hugr.input_neighbours(output).next().unwrap();
        let target_port = hugr.node_inputs(target).next().unwrap();

        let rewrite = IdentityInsertion::new(target, target_port);
        let inserted = hugr.apply_rewrite(rewrite).unwrap();
        assert_eq!(hugr.node_count(), 7);

        let noop: Noop = hugr.get_optype(inserted).cast().unwrap();
        assert_eq!(noop, Noop(qb_t()));

        hugr.validate().unwrap();
    }
}