hugr_core/hugr/patch/
consts.rs

1//! Rewrite operations involving Const and `LoadConst` operations
2
3use std::iter;
4
5use crate::{HugrView, Node, core::HugrNode, hugr::HugrMut};
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::{PatchHugrMut, PatchVerification};
10
11/// Remove a [`crate::ops::LoadConstant`] node with no consumers.
12#[derive(Debug, Clone)]
13pub struct RemoveLoadConstant<N = Node>(pub N);
14
15/// Error from an [`RemoveConst`] or [`RemoveLoadConstant`] operation.
16#[derive(Debug, Clone, Error, PartialEq, Eq)]
17#[non_exhaustive]
18pub enum RemoveError<N = Node> {
19    /// Invalid node.
20    #[error("Node is invalid (either not in HUGR or not correct operation).")]
21    InvalidNode(N),
22    /// Node in use.
23    #[error("Node: {0} has non-zero outgoing connections.")]
24    ValueUsed(N),
25}
26
27impl<N: HugrNode> PatchVerification for RemoveLoadConstant<N> {
28    type Error = RemoveError<N>;
29    type Node = N;
30
31    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
32        let node = self.0;
33
34        if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
35            return Err(RemoveError::InvalidNode(node));
36        }
37        let (p, _) = h
38            .out_value_types(node)
39            .exactly_one()
40            .ok()
41            .expect("LoadConstant has only one output.");
42        if h.linked_inputs(node, p).next().is_some() {
43            return Err(RemoveError::ValueUsed(node));
44        }
45
46        Ok(())
47    }
48
49    fn invalidated_nodes(
50        &self,
51        _: &impl HugrView<Node = Self::Node>,
52    ) -> impl Iterator<Item = Self::Node> {
53        iter::once(self.0)
54    }
55}
56
57impl<N: HugrNode> PatchHugrMut for RemoveLoadConstant<N> {
58    /// The [`Const`](crate::ops::Const) node the [`LoadConstant`](crate::ops::LoadConstant) was
59    /// connected to.
60    type Outcome = N;
61
62    const UNCHANGED_ON_FAILURE: bool = true;
63    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
64        self.verify(h)?;
65        let node = self.0;
66        let source = h
67            .input_neighbours(node)
68            .exactly_one()
69            .ok()
70            .expect("Validation should check a Const is connected to LoadConstant.");
71        h.remove_node(node);
72
73        Ok(source)
74    }
75}
76
77/// Remove a [`crate::ops::Const`] node with no outputs.
78#[derive(Debug, Clone)]
79pub struct RemoveConst<N = Node>(pub N);
80
81impl<N: HugrNode> PatchVerification for RemoveConst<N> {
82    type Node = N;
83    type Error = RemoveError<N>;
84
85    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
86        let node = self.0;
87
88        if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
89            return Err(RemoveError::InvalidNode(node));
90        }
91
92        if h.output_neighbours(node).next().is_some() {
93            return Err(RemoveError::ValueUsed(node));
94        }
95
96        Ok(())
97    }
98
99    fn invalidated_nodes(
100        &self,
101        _: &impl HugrView<Node = Self::Node>,
102    ) -> impl Iterator<Item = Self::Node> {
103        iter::once(self.0)
104    }
105}
106
107impl<N: HugrNode> PatchHugrMut for RemoveConst<N> {
108    // The parent of the Const node.
109    type Outcome = N;
110
111    const UNCHANGED_ON_FAILURE: bool = true;
112
113    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
114        self.verify(h)?;
115        let node = self.0;
116        let parent = h
117            .get_parent(node)
118            .expect("Const node without a parent shouldn't happen.");
119        h.remove_node(node);
120
121        Ok(parent)
122    }
123}
124
125#[cfg(test)]
126mod test {
127    use super::*;
128
129    use crate::{
130        builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
131        extension::prelude::ConstUsize,
132        ops::{Value, handle::NodeHandle},
133        type_row,
134        types::Signature,
135    };
136    #[test]
137    fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
138        let mut build = ModuleBuilder::new();
139        let con_node = build.add_constant(Value::extension(ConstUsize::new(2)));
140
141        let mut dfg_build = build.define_function("main", Signature::new_endo(type_row![]))?;
142        let load_1 = dfg_build.load_const(&con_node);
143        let load_2 = dfg_build.load_const(&con_node);
144        let tup = dfg_build.make_tuple([load_1, load_2])?;
145        dfg_build.finish_sub_container()?;
146
147        let mut h = build.finish_hugr()?;
148        // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple
149        assert_eq!(h.num_nodes(), 8);
150        let tup_node = tup.node();
151        // can't remove invalid node
152        assert_eq!(
153            h.apply_patch(RemoveConst(tup_node)),
154            Err(RemoveError::InvalidNode(tup_node))
155        );
156
157        assert_eq!(
158            h.apply_patch(RemoveLoadConstant(tup_node)),
159            Err(RemoveError::InvalidNode(tup_node))
160        );
161        let load_1_node = load_1.node();
162        let load_2_node = load_2.node();
163        let con_node = con_node.node();
164
165        let remove_1 = RemoveLoadConstant(load_1_node);
166        assert_eq!(
167            remove_1.invalidated_nodes(&h).exactly_one().ok(),
168            Some(load_1_node)
169        );
170
171        let remove_2 = RemoveLoadConstant(load_2_node);
172
173        let remove_con = RemoveConst(con_node);
174        assert_eq!(
175            remove_con.invalidated_nodes(&h).exactly_one().ok(),
176            Some(con_node)
177        );
178
179        // can't remove nodes in use
180        assert_eq!(
181            h.apply_patch(remove_1.clone()),
182            Err(RemoveError::ValueUsed(load_1_node))
183        );
184
185        // remove the use
186        h.remove_node(tup_node);
187
188        // remove first load
189        let reported_con_node = h.apply_patch(remove_1)?;
190        assert_eq!(reported_con_node, con_node);
191
192        // still can't remove const, in use by second load
193        assert_eq!(
194            h.apply_patch(remove_con.clone()),
195            Err(RemoveError::ValueUsed(con_node))
196        );
197
198        // remove second use
199        let reported_con_node = h.apply_patch(remove_2)?;
200        assert_eq!(reported_con_node, con_node);
201        // remove const
202        assert_eq!(h.apply_patch(remove_con)?, h.entrypoint());
203
204        assert_eq!(h.num_nodes(), 4);
205        assert!(h.validate().is_ok());
206        Ok(())
207    }
208}