hugr_core/hugr/patch/
peel_loop.rs

1//! Rewrite to peel one iteration of a [TailLoop], creating a [DFG] containing a copy of
2//! the loop body, and a [Conditional] containing the original `TailLoop` node.
3use derive_more::{Display, Error};
4
5use crate::core::HugrNode;
6use crate::ops::{
7    Case, Conditional, DFG, DataflowOpTrait, Input, OpTrait, OpType, Output, TailLoop,
8};
9use crate::types::Signature;
10use crate::{Direction, HugrView, Node};
11
12use super::{HugrMut, PatchHugrMut, PatchVerification};
13
14/// Rewrite that peels one iteration of a [TailLoop] by turning the
15/// iteration test into a [Conditional].
16#[derive(Clone, Debug, PartialEq)]
17pub struct PeelTailLoop<N = Node>(N);
18
19/// Error in performing [`PeelTailLoop`] rewrite.
20#[derive(Clone, Debug, Display, Error, PartialEq)]
21#[non_exhaustive]
22pub enum PeelTailLoopError<N = Node> {
23    /// The specified Node was not a [`TailLoop`]
24    #[display("Node to peel {node} expected to be a TailLoop but actually {op}")]
25    NotTailLoop {
26        /// The node requested to peel
27        node: N,
28        /// The actual (non-tail-loop) operation
29        op: OpType,
30    },
31}
32
33impl<N> PeelTailLoop<N> {
34    /// Create a new instance that will peel the specified [TailLoop] node
35    pub fn new(node: N) -> Self {
36        Self(node)
37    }
38}
39
40impl<N: HugrNode> PatchVerification for PeelTailLoop<N> {
41    type Error = PeelTailLoopError<N>;
42    type Node = N;
43    fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
44        let opty = h.get_optype(self.0);
45        if !opty.is_tail_loop() {
46            return Err(PeelTailLoopError::NotTailLoop {
47                node: self.0,
48                op: opty.clone(),
49            });
50        }
51        Ok(())
52    }
53
54    fn invalidated_nodes(&self, h: &impl HugrView<Node = N>) -> impl Iterator<Item = N> {
55        h.get_io(self.0)
56            .into_iter()
57            .flat_map(|[_, output]| [self.0, output].into_iter())
58    }
59}
60
61impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
62    type Outcome = ();
63    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
64        self.verify(h)?; // Now we know we have a TailLoop!
65        let loop_ty = h.optype_mut(self.0);
66        let signature = loop_ty.dataflow_signature().unwrap().into_owned();
67        // Replace the TailLoop with a DFG - this maintains all external connections
68        let OpType::TailLoop(tl) = std::mem::replace(loop_ty, DFG { signature }.into()) else {
69            panic!("Wasn't a TailLoop ?!")
70        };
71        let sum_rows = Vec::from(tl.control_variants());
72        let rest = tl.rest.clone();
73        let Signature {
74            input: loop_in,
75            output: loop_out,
76        } = tl.signature().into_owned();
77
78        // Copy the DFG (ex-TailLoop) children into a new TailLoop *before* we add any more
79        let new_loop = h.add_node_after(self.0, tl); // Temporary parent
80        h.copy_descendants(self.0, new_loop, None);
81
82        // Add conditional inside DFG.
83        let [_, dfg_out] = h.get_io(self.0).unwrap();
84        let cond = Conditional {
85            sum_rows,
86            other_inputs: rest,
87            outputs: loop_out.clone(),
88        };
89        let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap());
90        // This preserves all edges from the end of the loop body to the conditional:
91        h.replace_op(dfg_out, cond);
92        let cond_n = dfg_out;
93        h.add_ports(cond_n, Direction::Outgoing, loop_out.len() as isize + 1);
94        let dfg_out = h.add_node_before(
95            cond_n,
96            Output {
97                types: loop_out.clone(),
98            },
99        );
100        for p in 0..loop_out.len() {
101            h.connect(cond_n, p, dfg_out, p)
102        }
103
104        // Now wire up the internals of the Conditional
105        let cases = case_in_rows.map(|in_row| {
106            let signature = Signature::new(in_row.clone(), loop_out.clone());
107            let n = h.add_node_with_parent(cond_n, Case { signature });
108            h.add_node_with_parent(n, Input { types: in_row });
109            let types = loop_out.clone();
110            h.add_node_with_parent(n, Output { types });
111            n
112        });
113
114        h.set_parent(new_loop, cases[TailLoop::CONTINUE_TAG]);
115        let [ctn_in, ctn_out] = h.get_io(cases[TailLoop::CONTINUE_TAG]).unwrap();
116        let [brk_in, brk_out] = h.get_io(cases[TailLoop::BREAK_TAG]).unwrap();
117        for p in 0..loop_out.len() {
118            h.connect(brk_in, p, brk_out, p);
119            h.connect(new_loop, p, ctn_out, p)
120        }
121        for p in 0..loop_in.len() {
122            h.connect(ctn_in, p, new_loop, p);
123        }
124        Ok(())
125    }
126
127    /// Failure only occurs if the node is not a [TailLoop].
128    /// (Any later failure means an invalid Hugr and `panic`.)
129    const UNCHANGED_ON_FAILURE: bool = true;
130}
131
132#[cfg(test)]
133mod test {
134    use itertools::Itertools;
135
136    use crate::builder::test::simple_dfg_hugr;
137    use crate::builder::{
138        Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
139    };
140    use crate::extension::prelude::{bool_t, usize_t};
141    use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle};
142    use crate::std_extensions::arithmetic::int_types::INT_TYPES;
143    use crate::types::{Signature, Type, TypeRow};
144    use crate::{HugrView, hugr::HugrMut};
145
146    use super::{PeelTailLoop, PeelTailLoopError};
147
148    #[test]
149    fn bad_peel() {
150        let backup = simple_dfg_hugr();
151        let op = backup.entrypoint_optype().clone();
152        assert!(!op.is_tail_loop());
153        let mut h = backup.clone();
154        let r = h.apply_patch(PeelTailLoop::new(h.entrypoint()));
155        assert_eq!(
156            r,
157            Err(PeelTailLoopError::NotTailLoop {
158                node: backup.entrypoint(),
159                op
160            })
161        );
162        assert_eq!(h, backup);
163    }
164
165    #[test]
166    fn peel_loop_incoming_edges() {
167        let i32_t = || INT_TYPES[5].clone();
168        let mut fb = FunctionBuilder::new(
169            "main",
170            Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()),
171        )
172        .unwrap();
173        let helper = fb
174            .module_root_builder()
175            .declare(
176                "helper",
177                Signature::new(
178                    vec![bool_t(), usize_t(), i32_t()],
179                    vec![Type::new_sum([vec![bool_t(); 2], vec![]]), usize_t()],
180                )
181                .into(),
182            )
183            .unwrap();
184        let [b, u, i] = fb.input_wires_arr();
185        let (tl, call) = {
186            let mut tlb = fb
187                .tail_loop_builder(
188                    [(bool_t(), b), (bool_t(), b)],
189                    [(usize_t(), u)],
190                    TypeRow::new(),
191                )
192                .unwrap();
193            let [b, _, u] = tlb.input_wires_arr();
194            // Static edge from FuncDecl, and 'ext' edge from function Input:
195            let c = tlb.call(&helper, &[], [b, u, i]).unwrap();
196            let [pred, other] = c.outputs_arr();
197            (tlb.finish_with_outputs(pred, [other]).unwrap(), c.node())
198        };
199        let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap();
200
201        h.apply_patch(PeelTailLoop::new(tl.node())).unwrap();
202        h.validate().unwrap();
203
204        assert_eq!(
205            h.nodes()
206                .filter(|n| h.get_optype(*n).is_tail_loop())
207                .count(),
208            1
209        );
210        use OpTag::*;
211        assert_eq!(tags(&h, call), [FnCall, Dfg, FuncDefn, ModuleRoot]);
212        let [c1, c2] = h
213            .all_linked_inputs(helper.node())
214            .map(|(n, _p)| n)
215            .collect_array()
216            .unwrap();
217        assert!([c1, c2].contains(&call));
218        let other = if call == c1 { c2 } else { c1 };
219        assert_eq!(
220            tags(&h, other),
221            [
222                FnCall,
223                TailLoop,
224                Case,
225                Conditional,
226                Dfg,
227                FuncDefn,
228                ModuleRoot
229            ]
230        );
231    }
232
233    fn tags<H: HugrView>(h: &H, n: H::Node) -> Vec<OpTag> {
234        let mut v = Vec::new();
235        let mut o = Some(n);
236        while let Some(n) = o {
237            v.push(h.get_optype(n).tag());
238            o = h.get_parent(n);
239        }
240        v
241    }
242
243    #[test]
244    fn peel_loop_order_output() {
245        let i16_t = || INT_TYPES[4].clone();
246        let mut fb =
247            FunctionBuilder::new("main", Signature::new(vec![i16_t(), bool_t()], i16_t())).unwrap();
248
249        let [i, b] = fb.input_wires_arr();
250        let tl = {
251            let mut tlb = fb
252                .tail_loop_builder([(i16_t(), i), (bool_t(), b)], [], i16_t().into())
253                .unwrap();
254            let [i, _b] = tlb.input_wires_arr();
255            // This loop only goes round once. However, we do not expect this to affect
256            // peeling: *dataflow analysis* can tell us that the conditional will always
257            // take one Case (that does not contain the TailLoop), we do not do that here.
258            let [cont] = tlb
259                .add_dataflow_op(
260                    Tag::new(
261                        TailLoop::BREAK_TAG,
262                        tlb.loop_signature().unwrap().control_variants().into(),
263                    ),
264                    [i],
265                )
266                .unwrap()
267                .outputs_arr();
268            tlb.finish_with_outputs(cont, []).unwrap()
269        };
270        let [i2] = tl.outputs_arr();
271        // Create a DFG (no inputs, one output) that reads the result of the TailLoop via an 'ext` edge
272        let dfg = fb
273            .dfg_builder(Signature::new(vec![], i16_t()), [])
274            .unwrap()
275            .finish_with_outputs([i2])
276            .unwrap();
277        let mut h = fb.finish_hugr_with_outputs(dfg.outputs()).unwrap();
278        let tl = tl.node();
279
280        h.apply_patch(PeelTailLoop::new(tl)).unwrap();
281        h.validate().unwrap();
282        let [tl] = h
283            .nodes()
284            .filter(|n| h.get_optype(*n).is_tail_loop())
285            .collect_array()
286            .unwrap();
287        {
288            use OpTag::*;
289            assert_eq!(
290                tags(&h, tl),
291                [TailLoop, Case, Conditional, Dfg, FuncDefn, ModuleRoot]
292            );
293        }
294        let [out_n] = h.output_neighbours(tl).collect_array().unwrap();
295        assert!(h.get_optype(out_n).is_output());
296        assert_eq!(h.get_parent(tl), h.get_parent(out_n));
297    }
298}