hugr_core/hugr/patch/
consts.rs1use std::iter;
4
5use crate::{HugrView, Node, core::HugrNode, hugr::HugrMut};
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::{PatchHugrMut, PatchVerification};
10
11#[derive(Debug, Clone)]
13pub struct RemoveLoadConstant<N = Node>(pub N);
14
15#[derive(Debug, Clone, Error, PartialEq, Eq)]
17#[non_exhaustive]
18pub enum RemoveError<N = Node> {
19 #[error("Node is invalid (either not in HUGR or not correct operation).")]
21 InvalidNode(N),
22 #[error("Node: {0} has non-zero outgoing connections.")]
24 ValueUsed(N),
25}
26
27impl<N: HugrNode> PatchVerification for RemoveLoadConstant<N> {
28 type Error = RemoveError<N>;
29 type Node = N;
30
31 fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
32 let node = self.0;
33
34 if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
35 return Err(RemoveError::InvalidNode(node));
36 }
37 let (p, _) = h
38 .out_value_types(node)
39 .exactly_one()
40 .ok()
41 .expect("LoadConstant has only one output.");
42 if h.linked_inputs(node, p).next().is_some() {
43 return Err(RemoveError::ValueUsed(node));
44 }
45
46 Ok(())
47 }
48
49 fn invalidated_nodes(
50 &self,
51 _: &impl HugrView<Node = Self::Node>,
52 ) -> impl Iterator<Item = Self::Node> {
53 iter::once(self.0)
54 }
55}
56
57impl<N: HugrNode> PatchHugrMut for RemoveLoadConstant<N> {
58 type Outcome = N;
61
62 const UNCHANGED_ON_FAILURE: bool = true;
63 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
64 self.verify(h)?;
65 let node = self.0;
66 let source = h
67 .input_neighbours(node)
68 .exactly_one()
69 .ok()
70 .expect("Validation should check a Const is connected to LoadConstant.");
71 h.remove_node(node);
72
73 Ok(source)
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct RemoveConst<N = Node>(pub N);
80
81impl<N: HugrNode> PatchVerification for RemoveConst<N> {
82 type Node = N;
83 type Error = RemoveError<N>;
84
85 fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
86 let node = self.0;
87
88 if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
89 return Err(RemoveError::InvalidNode(node));
90 }
91
92 if h.output_neighbours(node).next().is_some() {
93 return Err(RemoveError::ValueUsed(node));
94 }
95
96 Ok(())
97 }
98
99 fn invalidated_nodes(
100 &self,
101 _: &impl HugrView<Node = Self::Node>,
102 ) -> impl Iterator<Item = Self::Node> {
103 iter::once(self.0)
104 }
105}
106
107impl<N: HugrNode> PatchHugrMut for RemoveConst<N> {
108 type Outcome = N;
110
111 const UNCHANGED_ON_FAILURE: bool = true;
112
113 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
114 self.verify(h)?;
115 let node = self.0;
116 let parent = h
117 .get_parent(node)
118 .expect("Const node without a parent shouldn't happen.");
119 h.remove_node(node);
120
121 Ok(parent)
122 }
123}
124
125#[cfg(test)]
126mod test {
127 use super::*;
128
129 use crate::{
130 builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
131 extension::prelude::ConstUsize,
132 ops::{Value, handle::NodeHandle},
133 type_row,
134 types::Signature,
135 };
136 #[test]
137 fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
138 let mut build = ModuleBuilder::new();
139 let con_node = build.add_constant(Value::extension(ConstUsize::new(2)));
140
141 let mut dfg_build = build.define_function("main", Signature::new_endo(type_row![]))?;
142 let load_1 = dfg_build.load_const(&con_node);
143 let load_2 = dfg_build.load_const(&con_node);
144 let tup = dfg_build.make_tuple([load_1, load_2])?;
145 dfg_build.finish_sub_container()?;
146
147 let mut h = build.finish_hugr()?;
148 assert_eq!(h.num_nodes(), 8);
150 let tup_node = tup.node();
151 assert_eq!(
153 h.apply_patch(RemoveConst(tup_node)),
154 Err(RemoveError::InvalidNode(tup_node))
155 );
156
157 assert_eq!(
158 h.apply_patch(RemoveLoadConstant(tup_node)),
159 Err(RemoveError::InvalidNode(tup_node))
160 );
161 let load_1_node = load_1.node();
162 let load_2_node = load_2.node();
163 let con_node = con_node.node();
164
165 let remove_1 = RemoveLoadConstant(load_1_node);
166 assert_eq!(
167 remove_1.invalidated_nodes(&h).exactly_one().ok(),
168 Some(load_1_node)
169 );
170
171 let remove_2 = RemoveLoadConstant(load_2_node);
172
173 let remove_con = RemoveConst(con_node);
174 assert_eq!(
175 remove_con.invalidated_nodes(&h).exactly_one().ok(),
176 Some(con_node)
177 );
178
179 assert_eq!(
181 h.apply_patch(remove_1.clone()),
182 Err(RemoveError::ValueUsed(load_1_node))
183 );
184
185 h.remove_node(tup_node);
187
188 let reported_con_node = h.apply_patch(remove_1)?;
190 assert_eq!(reported_con_node, con_node);
191
192 assert_eq!(
194 h.apply_patch(remove_con.clone()),
195 Err(RemoveError::ValueUsed(con_node))
196 );
197
198 let reported_con_node = h.apply_patch(remove_2)?;
200 assert_eq!(reported_con_node, con_node);
201 assert_eq!(h.apply_patch(remove_con)?, h.entrypoint());
203
204 assert_eq!(h.num_nodes(), 4);
205 assert!(h.validate().is_ok());
206 Ok(())
207 }
208}