hugr_core/hugr/rewrite/
consts.rs1use std::iter;
4
5use crate::{hugr::HugrMut, HugrView, Node};
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::Rewrite;
10
11#[derive(Debug, Clone)]
13pub struct RemoveLoadConstant<N = Node>(pub N);
14
15#[derive(Debug, Clone, Error, PartialEq, Eq)]
17#[non_exhaustive]
18pub enum RemoveError {
19 #[error("Node is invalid (either not in HUGR or not correct operation).")]
21 InvalidNode(Node),
22 #[error("Node: {0} has non-zero outgoing connections.")]
24 ValueUsed(Node),
25}
26
27impl Rewrite for RemoveLoadConstant {
28 type Error = RemoveError;
29
30 type ApplyResult = Node;
32
33 const UNCHANGED_ON_FAILURE: bool = true;
34
35 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
36 let node = self.0;
37
38 if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
39 return Err(RemoveError::InvalidNode(node));
40 }
41 let (p, _) = h
42 .out_value_types(node)
43 .exactly_one()
44 .ok()
45 .expect("LoadConstant has only one output.");
46 if h.linked_inputs(node, p).next().is_some() {
47 return Err(RemoveError::ValueUsed(node));
48 }
49
50 Ok(())
51 }
52
53 fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
54 self.verify(h)?;
55 let node = self.0;
56 let source = h
57 .input_neighbours(node)
58 .exactly_one()
59 .ok()
60 .expect("Validation should check a Const is connected to LoadConstant.");
61 h.remove_node(node);
62
63 Ok(source)
64 }
65
66 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
67 iter::once(self.0)
68 }
69}
70
71#[derive(Debug, Clone)]
73pub struct RemoveConst(pub Node);
74
75impl Rewrite for RemoveConst {
76 type Error = RemoveError;
77
78 type ApplyResult = Node;
80
81 const UNCHANGED_ON_FAILURE: bool = true;
82
83 fn verify(&self, h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
84 let node = self.0;
85
86 if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
87 return Err(RemoveError::InvalidNode(node));
88 }
89
90 if h.output_neighbours(node).next().is_some() {
91 return Err(RemoveError::ValueUsed(node));
92 }
93
94 Ok(())
95 }
96
97 fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
98 self.verify(h)?;
99 let node = self.0;
100 let parent = h
101 .get_parent(node)
102 .expect("Const node without a parent shouldn't happen.");
103 h.remove_node(node);
104
105 Ok(parent)
106 }
107
108 fn invalidation_set(&self) -> impl Iterator<Item = Node> {
109 iter::once(self.0)
110 }
111}
112
#[cfg(test)]
mod test {
    use super::*;

    use crate::extension::prelude::PRELUDE_ID;
    use crate::{
        builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
        extension::prelude::ConstUsize,
        ops::{handle::NodeHandle, Value},
        type_row,
        types::Signature,
    };

    /// Builds a module where one `Const` is loaded twice and the two loads feed a
    /// tuple. It then checks that each removal is refused while its value is still
    /// used and succeeds once the consumers are gone.
    #[test]
    fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
        let mut module = ModuleBuilder::new();
        let const_handle = module.add_constant(Value::extension(ConstUsize::new(2)));

        let mut func = module.define_function(
            "main",
            Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID.clone()),
        )?;
        let first_load = func.load_const(&const_handle);
        let second_load = func.load_const(&const_handle);
        let tuple = func.make_tuple([first_load, second_load])?;
        func.finish_sub_container()?;

        let mut hugr = module.finish_hugr()?;
        // Module root, Const, FuncDefn, Input, Output, two LoadConstants, MakeTuple.
        assert_eq!(hugr.node_count(), 8);

        // Both rewrites reject a node of the wrong operation type.
        let tuple_node = tuple.node();
        assert_eq!(
            hugr.apply_rewrite(RemoveConst(tuple_node)),
            Err(RemoveError::InvalidNode(tuple_node))
        );
        assert_eq!(
            hugr.apply_rewrite(RemoveLoadConstant(tuple_node)),
            Err(RemoveError::InvalidNode(tuple_node))
        );

        let (first_load_node, second_load_node) = (first_load.node(), second_load.node());
        let const_node = const_handle.node();

        let drop_first_load = RemoveLoadConstant(first_load_node);
        assert_eq!(
            drop_first_load.invalidation_set().exactly_one().ok(),
            Some(first_load_node)
        );
        let drop_second_load = RemoveLoadConstant(second_load_node);
        let drop_const = RemoveConst(const_node);
        assert_eq!(
            drop_const.invalidation_set().exactly_one().ok(),
            Some(const_node)
        );

        // The first load still feeds the tuple, so it cannot be removed yet.
        assert_eq!(
            hugr.apply_rewrite(drop_first_load.clone()),
            Err(RemoveError::ValueUsed(first_load_node))
        );

        // Once its consumer is gone, the load is removable and reports its Const.
        hugr.remove_node(tuple_node);
        assert_eq!(hugr.apply_rewrite(drop_first_load)?, const_node);

        // The second load still reads the Const, so the Const must stay for now.
        assert_eq!(
            hugr.apply_rewrite(drop_const.clone()),
            Err(RemoveError::ValueUsed(const_node))
        );

        assert_eq!(hugr.apply_rewrite(drop_second_load)?, const_node);
        // With no loads left the Const can go; its parent is the module root.
        assert_eq!(hugr.apply_rewrite(drop_const)?, hugr.root());

        // Remaining: module root plus the function's FuncDefn, Input and Output.
        assert_eq!(hugr.node_count(), 4);
        assert!(hugr.validate().is_ok());
        Ok(())
    }
}