1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
//! Implementation of the `InsertIdentity` operation.

use std::iter;

use crate::hugr::{HugrMut, Node};
use crate::ops::{Noop, OpTag, OpTrait};
use crate::types::EdgeKind;
use crate::{HugrView, IncomingPort};

use super::Rewrite;

use thiserror::Error;

/// Specification of an identity-insertion operation.
///
/// Applying it (see the [`Rewrite`] implementation for this type) adds a
/// [`Noop`] node on the wire that feeds `post_port` of `post_node`.
#[derive(Debug, Clone)]
pub struct IdentityInsertion {
    /// The node following the identity to be inserted.
    pub post_node: Node,
    /// The port following the identity to be inserted.
    pub post_port: IncomingPort,
}

impl IdentityInsertion {
    /// Create a new [`IdentityInsertion`] specification.
    pub fn new(post_node: Node, post_port: IncomingPort) -> Self {
        Self {
            post_node,
            post_port,
        }
    }
}

/// Error from an [`IdentityInsertion`] operation.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum IdentityInsertionError {
    /// Invalid parent node.
    ///
    /// Reported when `post_node` has no parent, or when the parent's operation
    /// tag is not covered by [`OpTag::DataflowParent`].
    #[error("Parent node is invalid.")]
    InvalidParentNode,
    /// Invalid node.
    // NOTE(review): this variant is not constructed anywhere in the visible
    // part of this module — confirm whether it is still needed.
    #[error("Node is invalid.")]
    InvalidNode(),
    /// Invalid port kind.
    ///
    /// Carries the kind that was actually found for `post_port`; only
    /// [`EdgeKind::Value`] is accepted.
    #[error("post_port has invalid kind {0:?}. Must be Value.")]
    InvalidPortKind(Option<EdgeKind>),
}

impl Rewrite for IdentityInsertion {
    type Error = IdentityInsertionError;
    /// The inserted node.
    type ApplyResult = Node;
    const UNCHANGED_ON_FAILURE: bool = true;

    /// Check whether the insertion could be applied to `_h`.
    ///
    /// Not implemented yet; the notes in the body record the intended checks.
    ///
    /// # Panics
    ///
    /// Always panics, because the check is still a stub.
    fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> {
        // Assumptions:
        // 1. Value kind inputs can only have one connection.
        // 2. Node exists.
        // Conditions:
        // 1. post_port is Value kind.
        // 2. post_port is connected to a sibling of post_node.
        // 3. post_port is input.

        unimplemented!()
    }

    /// Insert a [`Noop`] on the wire entering `post_port` of `post_node` and
    /// return the newly added node.
    ///
    /// Every check that can return an error runs before the first mutation of
    /// `h`, so a returned error leaves `h` untouched, as declared by
    /// `UNCHANGED_ON_FAILURE`.
    ///
    /// # Errors
    ///
    /// - [`IdentityInsertionError::InvalidPortKind`] if `post_port` is not of
    ///   [`EdgeKind::Value`] kind.
    /// - [`IdentityInsertionError::InvalidParentNode`] if `post_node` has no
    ///   parent, or the parent's tag is not covered by
    ///   [`OpTag::DataflowParent`].
    ///
    /// # Panics
    ///
    /// Panics if `single_linked_output` finds no output linked to `post_port`.
    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, IdentityInsertionError> {
        let kind = h.get_optype(self.post_node).port_kind(self.post_port);
        let Some(EdgeKind::Value(ty)) = kind else {
            return Err(IdentityInsertionError::InvalidPortKind(kind));
        };

        let (pre_node, pre_port) = h
            .single_linked_output(self.post_node, self.post_port)
            .expect("Value kind input can only have one connection.");

        // Validate the parent BEFORE touching any edge: failing after the
        // disconnect below would hand back a modified HUGR together with an
        // error.
        let parent = h
            .get_parent(self.post_node)
            .ok_or(IdentityInsertionError::InvalidParentNode)?;
        if !OpTag::DataflowParent.is_superset(h.get_optype(parent).tag()) {
            return Err(IdentityInsertionError::InvalidParentNode);
        }

        // From here on nothing returns an error: cut the existing edge and
        // splice the identity node into it.
        h.disconnect(self.post_node, self.post_port);
        let new_node = h.add_node_with_parent(parent, Noop { ty });
        h.connect(pre_node, pre_port, new_node, 0);
        h.connect(new_node, 0, self.post_node, self.post_port);

        Ok(new_node)
    }

    /// The only pre-existing node affected by the rewrite is `post_node`.
    #[inline]
    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        iter::once(self.post_node)
    }
}

#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::super::simple_replace::test::dfg_hugr;
    use super::*;
    use crate::{
        extension::{prelude::QB_T, PRELUDE_REGISTRY},
        Hugr,
    };

    /// Applying the rewrite in front of a neighbour of the root's second I/O
    /// node adds exactly one `Noop` over `QB_T` and leaves the HUGR valid.
    #[rstest]
    fn correct_insertion(dfg_hugr: Hugr) {
        let mut hugr = dfg_hugr;
        assert_eq!(hugr.node_count(), 6);

        // Locate the insertion point: the first input neighbour of the second
        // I/O node of the root, and that neighbour's first input port.
        let io_node = hugr.get_io(hugr.root()).unwrap()[1];
        let target = hugr.input_neighbours(io_node).next().unwrap();
        let target_port = hugr.node_inputs(target).next().unwrap();

        let inserted = hugr
            .apply_rewrite(IdentityInsertion::new(target, target_port))
            .unwrap();

        // Exactly one node was added ...
        assert_eq!(hugr.node_count(), 7);

        // ... and it is a `Noop` over the qubit type.
        let inserted_op: Noop = hugr.get_optype(inserted).clone().try_into().unwrap();
        assert_eq!(inserted_op, Noop { ty: QB_T });

        hugr.update_validate(&PRELUDE_REGISTRY).unwrap();
    }
}