hugr_core/hugr/rewrite/insert_identity.rs

//! Implementation of the [`IdentityInsertion`] rewrite, which places an identity (`Noop`) node on a dataflow wire.
2
3use std::iter;
4
5use crate::extension::prelude::Noop;
6use crate::hugr::{HugrMut, Node};
7use crate::ops::{OpTag, OpTrait};
8
9use crate::types::EdgeKind;
10use crate::{HugrView, IncomingPort};
11
12use super::Rewrite;
13
14use thiserror::Error;
15
16/// Specification of a identity-insertion operation.
17#[derive(Debug, Clone)]
18pub struct IdentityInsertion {
19    /// The node following the identity to be inserted.
20    pub post_node: Node,
21    /// The port following the identity to be inserted.
22    pub post_port: IncomingPort,
23}
24
25impl IdentityInsertion {
26    /// Create a new [`IdentityInsertion`] specification.
27    pub fn new(post_node: Node, post_port: IncomingPort) -> Self {
28        Self {
29            post_node,
30            post_port,
31        }
32    }
33}
34
35/// Error from an [`IdentityInsertion`] operation.
36#[derive(Debug, Clone, Error, PartialEq, Eq)]
37#[non_exhaustive]
38pub enum IdentityInsertionError {
39    /// Invalid parent node.
40    #[error("Parent node is invalid.")]
41    InvalidParentNode,
42    /// Invalid node.
43    #[error("Node is invalid.")]
44    InvalidNode(),
45    /// Invalid port kind.
46    #[error("post_port has invalid kind {}. Must be Value.", _0.as_ref().map_or("None".to_string(), ToString::to_string))]
47    InvalidPortKind(Option<EdgeKind>),
48}
49
50impl Rewrite for IdentityInsertion {
51    type Error = IdentityInsertionError;
52    /// The inserted node.
53    type ApplyResult = Node;
54    const UNCHANGED_ON_FAILURE: bool = true;
55    fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
56        /*
57        Assumptions:
58        1. Value kind inputs can only have one connection.
59        2. Node exists.
60        Conditions:
61        1. post_port is Value kind.
62        2. post_port is connected to a sibling of post_node.
63        3. post_port is input.
64         */
65
66        unimplemented!()
67    }
68    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, IdentityInsertionError> {
69        let kind = h.get_optype(self.post_node).port_kind(self.post_port);
70        let Some(EdgeKind::Value(ty)) = kind else {
71            return Err(IdentityInsertionError::InvalidPortKind(kind));
72        };
73
74        let (pre_node, pre_port) = h
75            .single_linked_output(self.post_node, self.post_port)
76            .expect("Value kind input can only have one connection.");
77
78        h.disconnect(self.post_node, self.post_port);
79        let parent = h
80            .get_parent(self.post_node)
81            .ok_or(IdentityInsertionError::InvalidParentNode)?;
82        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
83            return Err(IdentityInsertionError::InvalidParentNode);
84        }
85        let new_node = h.add_node_with_parent(parent, Noop(ty));
86        h.connect(pre_node, pre_port, new_node, 0);
87
88        h.connect(new_node, 0, self.post_node, self.post_port);
89        Ok(new_node)
90    }
91
92    #[inline]
93    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
94        iter::once(self.post_node)
95    }
96}
97
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::super::simple_replace::test::dfg_hugr;
    use super::*;
    use crate::{extension::prelude::qb_t, Hugr};

    /// Inserting an identity in front of the operation that feeds the output
    /// node adds exactly one qubit-typed `Noop` and keeps the Hugr valid.
    #[rstest]
    fn correct_insertion(dfg_hugr: Hugr) {
        let mut hugr = dfg_hugr;
        assert_eq!(hugr.node_count(), 6);

        let output = hugr.get_io(hugr.root()).unwrap()[1];
        let last_op = hugr.input_neighbours(output).next().unwrap();
        let first_input = hugr.node_inputs(last_op).next().unwrap();

        let inserted = hugr
            .apply_rewrite(IdentityInsertion::new(last_op, first_input))
            .unwrap();
        assert_eq!(hugr.node_count(), 7);

        let noop: Noop = hugr.get_optype(inserted).cast().unwrap();
        assert_eq!(noop, Noop(qb_t()));

        hugr.validate().unwrap();
    }
}
131}