hugr_core/hugr/patch/
consts.rs

1//! Rewrite operations involving Const and `LoadConst` operations
2
3use std::iter;
4
5use crate::{HugrView, Node, core::HugrNode, hugr::HugrMut};
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::{PatchHugrMut, PatchVerification};
10
11/// Remove a [`crate::ops::LoadConstant`] node with no consumers.
12#[derive(Debug, Clone)]
13pub struct RemoveLoadConstant<N = Node>(pub N);
14
15/// Error from an [`RemoveConst`] or [`RemoveLoadConstant`] operation.
16#[derive(Debug, Clone, Error, PartialEq, Eq)]
17#[non_exhaustive]
18pub enum RemoveError<N = Node> {
19    /// Invalid node.
20    #[error("Node is invalid (either not in HUGR or not correct operation).")]
21    InvalidNode(N),
22    /// Node in use.
23    #[error("Node: {0} has non-zero outgoing connections.")]
24    ValueUsed(N),
25}
26
27impl<N: HugrNode> PatchVerification for RemoveLoadConstant<N> {
28    type Error = RemoveError<N>;
29    type Node = N;
30
31    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
32        let node = self.0;
33
34        if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) {
35            return Err(RemoveError::InvalidNode(node));
36        }
37        let (p, _) = h
38            .out_value_types(node)
39            .exactly_one()
40            .ok()
41            .expect("LoadConstant has only one output.");
42        if h.linked_inputs(node, p).next().is_some() {
43            return Err(RemoveError::ValueUsed(node));
44        }
45
46        Ok(())
47    }
48
49    fn invalidation_set(&self) -> impl Iterator<Item = N> {
50        iter::once(self.0)
51    }
52}
53
54impl<N: HugrNode> PatchHugrMut for RemoveLoadConstant<N> {
55    /// The [`Const`](crate::ops::Const) node the [`LoadConstant`](crate::ops::LoadConstant) was
56    /// connected to.
57    type Outcome = N;
58
59    const UNCHANGED_ON_FAILURE: bool = true;
60    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
61        self.verify(h)?;
62        let node = self.0;
63        let source = h
64            .input_neighbours(node)
65            .exactly_one()
66            .ok()
67            .expect("Validation should check a Const is connected to LoadConstant.");
68        h.remove_node(node);
69
70        Ok(source)
71    }
72}
73
74/// Remove a [`crate::ops::Const`] node with no outputs.
75#[derive(Debug, Clone)]
76pub struct RemoveConst<N = Node>(pub N);
77
78impl<N: HugrNode> PatchVerification for RemoveConst<N> {
79    type Node = N;
80    type Error = RemoveError<N>;
81
82    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
83        let node = self.0;
84
85        if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) {
86            return Err(RemoveError::InvalidNode(node));
87        }
88
89        if h.output_neighbours(node).next().is_some() {
90            return Err(RemoveError::ValueUsed(node));
91        }
92
93        Ok(())
94    }
95
96    fn invalidation_set(&self) -> impl Iterator<Item = N> {
97        iter::once(self.0)
98    }
99}
100
101impl<N: HugrNode> PatchHugrMut for RemoveConst<N> {
102    // The parent of the Const node.
103    type Outcome = N;
104
105    const UNCHANGED_ON_FAILURE: bool = true;
106
107    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<Self::Outcome, Self::Error> {
108        self.verify(h)?;
109        let node = self.0;
110        let parent = h
111            .get_parent(node)
112            .expect("Const node without a parent shouldn't happen.");
113        h.remove_node(node);
114
115        Ok(parent)
116    }
117}
118
119#[cfg(test)]
120mod test {
121    use super::*;
122
123    use crate::{
124        builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer},
125        extension::prelude::ConstUsize,
126        ops::{Value, handle::NodeHandle},
127        type_row,
128        types::Signature,
129    };
130    #[test]
131    fn test_const_remove() -> Result<(), Box<dyn std::error::Error>> {
132        let mut build = ModuleBuilder::new();
133        let con_node = build.add_constant(Value::extension(ConstUsize::new(2)));
134
135        let mut dfg_build = build.define_function("main", Signature::new_endo(type_row![]))?;
136        let load_1 = dfg_build.load_const(&con_node);
137        let load_2 = dfg_build.load_const(&con_node);
138        let tup = dfg_build.make_tuple([load_1, load_2])?;
139        dfg_build.finish_sub_container()?;
140
141        let mut h = build.finish_hugr()?;
142        // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple
143        assert_eq!(h.num_nodes(), 8);
144        let tup_node = tup.node();
145        // can't remove invalid node
146        assert_eq!(
147            h.apply_patch(RemoveConst(tup_node)),
148            Err(RemoveError::InvalidNode(tup_node))
149        );
150
151        assert_eq!(
152            h.apply_patch(RemoveLoadConstant(tup_node)),
153            Err(RemoveError::InvalidNode(tup_node))
154        );
155        let load_1_node = load_1.node();
156        let load_2_node = load_2.node();
157        let con_node = con_node.node();
158
159        let remove_1 = RemoveLoadConstant(load_1_node);
160        assert_eq!(
161            remove_1.invalidation_set().exactly_one().ok(),
162            Some(load_1_node)
163        );
164
165        let remove_2 = RemoveLoadConstant(load_2_node);
166
167        let remove_con = RemoveConst(con_node);
168        assert_eq!(
169            remove_con.invalidation_set().exactly_one().ok(),
170            Some(con_node)
171        );
172
173        // can't remove nodes in use
174        assert_eq!(
175            h.apply_patch(remove_1.clone()),
176            Err(RemoveError::ValueUsed(load_1_node))
177        );
178
179        // remove the use
180        h.remove_node(tup_node);
181
182        // remove first load
183        let reported_con_node = h.apply_patch(remove_1)?;
184        assert_eq!(reported_con_node, con_node);
185
186        // still can't remove const, in use by second load
187        assert_eq!(
188            h.apply_patch(remove_con.clone()),
189            Err(RemoveError::ValueUsed(con_node))
190        );
191
192        // remove second use
193        let reported_con_node = h.apply_patch(remove_2)?;
194        assert_eq!(reported_con_node, con_node);
195        // remove const
196        assert_eq!(h.apply_patch(remove_con)?, h.entrypoint());
197
198        assert_eq!(h.num_nodes(), 4);
199        assert!(h.validate().is_ok());
200        Ok(())
201    }
202}