hugr_core/hugr/patch/
insert_identity.rs

1//! Implementation of the `InsertIdentity` operation.
2
3use std::iter;
4
5use crate::core::HugrNode;
6use crate::extension::prelude::Noop;
7use crate::hugr::{HugrMut, Node};
8use crate::ops::{OpTag, OpTrait};
9
10use crate::types::EdgeKind;
11use crate::{HugrView, IncomingPort};
12
13use super::{PatchHugrMut, PatchVerification};
14
15use thiserror::Error;
16
17/// Specification of a identity-insertion operation.
18#[derive(Debug, Clone)]
19pub struct IdentityInsertion<N = Node> {
20    /// The node following the identity to be inserted.
21    pub post_node: N,
22    /// The port following the identity to be inserted.
23    pub post_port: IncomingPort,
24}
25
26impl<N> IdentityInsertion<N> {
27    /// Create a new [`IdentityInsertion`] specification.
28    pub fn new(post_node: N, post_port: IncomingPort) -> Self {
29        Self {
30            post_node,
31            post_port,
32        }
33    }
34}
35
36/// Error from an [`IdentityInsertion`] operation.
37#[derive(Debug, Clone, Error, PartialEq, Eq)]
38#[non_exhaustive]
39pub enum IdentityInsertionError {
40    /// Invalid parent node.
41    #[error("Parent node is invalid.")]
42    InvalidParentNode,
43    /// Invalid node.
44    #[error("Node is invalid.")]
45    InvalidNode(),
46    /// Invalid port kind.
47    #[error("post_port has invalid kind {}. Must be Value.", _0.as_ref().map_or("None".to_string(), ToString::to_string))]
48    InvalidPortKind(Option<EdgeKind>),
49}
50
51impl<N: HugrNode> PatchVerification for IdentityInsertion<N> {
52    type Error = IdentityInsertionError;
53    type Node = N;
54
55    fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
56        /*
57        Assumptions:
58        1. Value kind inputs can only have one connection.
59        2. Node exists.
60        Conditions:
61        1. post_port is Value kind.
62        2. post_port is connected to a sibling of post_node.
63        3. post_port is input.
64         */
65
66        unimplemented!()
67    }
68
69    #[inline]
70    fn invalidated_nodes(
71        &self,
72        _: &impl HugrView<Node = Self::Node>,
73    ) -> impl Iterator<Item = Self::Node> {
74        iter::once(self.post_node)
75    }
76}
77
78impl<N: HugrNode> PatchHugrMut for IdentityInsertion<N> {
79    /// The inserted node.
80    type Outcome = N;
81
82    const UNCHANGED_ON_FAILURE: bool = true;
83
84    fn apply_hugr_mut(
85        self,
86        h: &mut impl HugrMut<Node = N>,
87    ) -> Result<Self::Outcome, IdentityInsertionError> {
88        let kind = h.get_optype(self.post_node).port_kind(self.post_port);
89        let Some(EdgeKind::Value(ty)) = kind else {
90            return Err(IdentityInsertionError::InvalidPortKind(kind));
91        };
92
93        let (pre_node, pre_port) = h
94            .single_linked_output(self.post_node, self.post_port)
95            .expect("Value kind input can only have one connection.");
96
97        h.disconnect(self.post_node, self.post_port);
98        let parent = h
99            .get_parent(self.post_node)
100            .ok_or(IdentityInsertionError::InvalidParentNode)?;
101        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
102            return Err(IdentityInsertionError::InvalidParentNode);
103        }
104        let new_node = h.add_node_with_parent(parent, Noop(ty));
105        h.connect(pre_node, pre_port, new_node, 0);
106
107        h.connect(new_node, 0, self.post_node, self.post_port);
108        Ok(new_node)
109    }
110}
111
112#[cfg(test)]
113mod tests {
114    use rstest::rstest;
115
116    use super::super::simple_replace::test::dfg_hugr;
117    use super::*;
118    use crate::{Hugr, extension::prelude::qb_t};
119
120    #[rstest]
121    fn correct_insertion(dfg_hugr: Hugr) {
122        let mut h = dfg_hugr;
123
124        assert_eq!(h.entry_descendants().count(), 6);
125
126        let final_node = h
127            .input_neighbours(h.get_io(h.entrypoint()).unwrap()[1])
128            .next()
129            .unwrap();
130
131        let final_node_port = h.node_inputs(final_node).next().unwrap();
132
133        let rw = IdentityInsertion::new(final_node, final_node_port);
134
135        let noop_node = h.apply_patch(rw).unwrap();
136
137        assert_eq!(h.entry_descendants().count(), 7);
138
139        let noop: Noop = h.get_optype(noop_node).cast().unwrap();
140
141        assert_eq!(noop, Noop(qb_t()));
142
143        h.validate().unwrap();
144    }
145}