hugr_core/hugr/rewrite/
consts.rs

//! Rewrite operations involving Const and LoadConstant operations.
2
3use std::iter;
4
5use crate::{hugr::HugrMut, HugrView, Node};
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::Rewrite;
10
11/// Remove a [`crate::ops::LoadConstant`] node with no consumers.
12#[derive(Debug, Clone)]
13pub struct RemoveLoadConstant<N = Node>(pub N);
14
15/// Error from an [`RemoveConst`] or [`RemoveLoadConstant`] operation.
16#[derive(Debug, Clone, Error, PartialEq, Eq)]
17#[non_exhaustive]
18pub enum RemoveError {
19    /// Invalid node.
20    #[error("Node is invalid (either not in HUGR or not correct operation).")]
21    InvalidNode(Node),
22    /// Node in use.
23    #[error("Node: {0} has non-zero outgoing connections.")]
24    ValueUsed(Node),
25}
26
27impl Rewrite for RemoveLoadConstant {
28    type Error = RemoveError;
29
30    // The Const node the LoadConstant was connected to.
31    type ApplyResult = Node;
32
33    const UNCHANGED_ON_FAILURE: bool = true;
34
35    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
36        let node = self.0;
37
38        if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
39            return Err(RemoveError::InvalidNode(node));
40        }
41        let (p, _) = h
42            .out_value_types(node)
43            .exactly_one()
44            .ok()
45            .expect("LoadConstant has only one output.");
46        if h.linked_inputs(node, p).next().is_some() {
47            return Err(RemoveError::ValueUsed(node));
48        }
49
50        Ok(())
51    }
52
53    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
54        self.verify(h)?;
55        let node = self.0;
56        let source = h
57            .input_neighbours(node)
58            .exactly_one()
59            .ok()
60            .expect("Validation should check a Const is connected to LoadConstant.");
61        h.remove_node(node);
62
63        Ok(source)
64    }
65
66    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
67        iter::once(self.0)
68    }
69}
70
71/// Remove a [`crate::ops::Const`] node with no outputs.
72#[derive(Debug, Clone)]
73pub struct RemoveConst(pub Node);
74
75impl Rewrite for RemoveConst {
76    type Error = RemoveError;
77
78    // The parent of the Const node.
79    type ApplyResult = Node;
80
81    const UNCHANGED_ON_FAILURE: bool = true;
82
83    fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
84        let node = self.0;
85
86        if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
87            return Err(RemoveError::InvalidNode(node));
88        }
89
90        if h.output_neighbours(node).next().is_some() {
91            return Err(RemoveError::ValueUsed(node));
92        }
93
94        Ok(())
95    }
96
97    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
98        self.verify(h)?;
99        let node = self.0;
100        let parent = h
101            .get_parent(node)
102            .expect("Const node without a parent shouldn't happen.");
103        h.remove_node(node);
104
105        Ok(parent)
106    }
107
108    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
109        iter::once(self.0)
110    }
111}
112
#[cfg(test)]
mod test {
    use super::*;

    use crate::{
        builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
        extension::prelude::{ConstUsize, PRELUDE_ID},
        ops::{handle::NodeHandle, Value},
        type_row,
        types::Signature,
    };

    /// Builds a module with one constant loaded twice into a tuple, then
    /// checks that the loads and the constant can only be removed once
    /// nothing uses them.
    #[test]
    fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
        let mut module = ModuleBuilder::new();
        let const_id = module.add_constant(Value::extension(ConstUsize::new(2)));

        let signature = Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID.clone());
        let mut func = module.define_function("main", signature)?;
        let first_load = func.load_const(&const_id);
        let second_load = func.load_const(&const_id);
        let tuple = func.make_tuple([first_load, second_load])?;
        func.finish_sub_container()?;

        let mut h = module.finish_hugr()?;
        // Eight nodes: Module, Function, Input, Output, Const, 2 x LoadConstant, MakeTuple.
        assert_eq!(h.node_count(), 8);

        let tuple_node = tuple.node();
        let first_load_node = first_load.node();
        let second_load_node = second_load.node();
        let const_node = const_id.node();

        // Neither rewrite accepts a node of the wrong operation type.
        assert_eq!(
            h.apply_rewrite(RemoveConst(tuple_node)),
            Err(RemoveError::InvalidNode(tuple_node))
        );
        assert_eq!(
            h.apply_rewrite(RemoveLoadConstant(tuple_node)),
            Err(RemoveError::InvalidNode(tuple_node))
        );

        let remove_first_load = RemoveLoadConstant(first_load_node);
        let remove_second_load = RemoveLoadConstant(second_load_node);
        let remove_const = RemoveConst(const_node);

        // Each rewrite invalidates exactly the node it removes.
        assert_eq!(
            remove_first_load.invalidation_set().exactly_one().ok(),
            Some(first_load_node)
        );
        assert_eq!(
            remove_const.invalidation_set().exactly_one().ok(),
            Some(const_node)
        );

        // A load that still feeds the tuple cannot be removed.
        assert_eq!(
            h.apply_rewrite(remove_first_load.clone()),
            Err(RemoveError::ValueUsed(first_load_node))
        );

        // Drop the only consumer of both loads.
        h.remove_node(tuple_node);

        // The first load can now go; the rewrite reports the constant it read.
        let reported_const = h.apply_rewrite(remove_first_load)?;
        assert_eq!(reported_const, const_node);

        // The constant is still read by the second load.
        assert_eq!(
            h.apply_rewrite(remove_const.clone()),
            Err(RemoveError::ValueUsed(const_node))
        );

        // Remove the second load as well.
        let reported_const = h.apply_rewrite(remove_second_load)?;
        assert_eq!(reported_const, const_node);

        // The constant is now unused; removing it reports its parent, the root.
        let const_parent = h.apply_rewrite(remove_const)?;
        assert_eq!(const_parent, h.root());

        assert_eq!(h.node_count(), 4);
        assert!(h.validate().is_ok());
        Ok(())
    }
}
200}