1use derive_more::{Display, Error};
4
5use crate::core::HugrNode;
6use crate::ops::{
7 Case, Conditional, DFG, DataflowOpTrait, Input, OpTrait, OpType, Output, TailLoop,
8};
9use crate::types::Signature;
10use crate::{Direction, HugrView, Node};
11
12use super::{HugrMut, PatchHugrMut, PatchVerification};
13
14#[derive(Clone, Debug, PartialEq)]
17pub struct PeelTailLoop<N = Node>(N);
18
19#[derive(Clone, Debug, Display, Error, PartialEq)]
21#[non_exhaustive]
22pub enum PeelTailLoopError<N = Node> {
23 #[display("Node to peel {node} expected to be a TailLoop but actually {op}")]
25 NotTailLoop {
26 node: N,
28 op: OpType,
30 },
31}
32
33impl<N> PeelTailLoop<N> {
34 pub fn new(node: N) -> Self {
36 Self(node)
37 }
38}
39
40impl<N: HugrNode> PatchVerification for PeelTailLoop<N> {
41 type Error = PeelTailLoopError<N>;
42 type Node = N;
43 fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
44 let opty = h.get_optype(self.0);
45 if !opty.is_tail_loop() {
46 return Err(PeelTailLoopError::NotTailLoop {
47 node: self.0,
48 op: opty.clone(),
49 });
50 }
51 Ok(())
52 }
53
54 fn invalidated_nodes(&self, h: &impl HugrView<Node = N>) -> impl Iterator<Item = N> {
55 h.get_io(self.0)
56 .into_iter()
57 .flat_map(|[_, output]| [self.0, output].into_iter())
58 }
59}
60
61impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
62 type Outcome = ();
63 fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
64 self.verify(h)?; let loop_ty = h.optype_mut(self.0);
66 let signature = loop_ty.dataflow_signature().unwrap().into_owned();
67 let OpType::TailLoop(tl) = std::mem::replace(loop_ty, DFG { signature }.into()) else {
69 panic!("Wasn't a TailLoop ?!")
70 };
71 let sum_rows = Vec::from(tl.control_variants());
72 let rest = tl.rest.clone();
73 let Signature {
74 input: loop_in,
75 output: loop_out,
76 } = tl.signature().into_owned();
77
78 let new_loop = h.add_node_after(self.0, tl); h.copy_descendants(self.0, new_loop, None);
81
82 let [_, dfg_out] = h.get_io(self.0).unwrap();
84 let cond = Conditional {
85 sum_rows,
86 other_inputs: rest,
87 outputs: loop_out.clone(),
88 };
89 let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap());
90 h.replace_op(dfg_out, cond);
92 let cond_n = dfg_out;
93 h.add_ports(cond_n, Direction::Outgoing, loop_out.len() as isize + 1);
94 let dfg_out = h.add_node_before(
95 cond_n,
96 Output {
97 types: loop_out.clone(),
98 },
99 );
100 for p in 0..loop_out.len() {
101 h.connect(cond_n, p, dfg_out, p)
102 }
103
104 let cases = case_in_rows.map(|in_row| {
106 let signature = Signature::new(in_row.clone(), loop_out.clone());
107 let n = h.add_node_with_parent(cond_n, Case { signature });
108 h.add_node_with_parent(n, Input { types: in_row });
109 let types = loop_out.clone();
110 h.add_node_with_parent(n, Output { types });
111 n
112 });
113
114 h.set_parent(new_loop, cases[TailLoop::CONTINUE_TAG]);
115 let [ctn_in, ctn_out] = h.get_io(cases[TailLoop::CONTINUE_TAG]).unwrap();
116 let [brk_in, brk_out] = h.get_io(cases[TailLoop::BREAK_TAG]).unwrap();
117 for p in 0..loop_out.len() {
118 h.connect(brk_in, p, brk_out, p);
119 h.connect(new_loop, p, ctn_out, p)
120 }
121 for p in 0..loop_in.len() {
122 h.connect(ctn_in, p, new_loop, p);
123 }
124 Ok(())
125 }
126
127 const UNCHANGED_ON_FAILURE: bool = true;
130}
131
132#[cfg(test)]
133mod test {
134 use itertools::Itertools;
135
136 use crate::builder::test::simple_dfg_hugr;
137 use crate::builder::{
138 Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
139 };
140 use crate::extension::prelude::{bool_t, usize_t};
141 use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle};
142 use crate::std_extensions::arithmetic::int_types::INT_TYPES;
143 use crate::types::{Signature, Type, TypeRow};
144 use crate::{HugrView, hugr::HugrMut};
145
146 use super::{PeelTailLoop, PeelTailLoopError};
147
148 #[test]
149 fn bad_peel() {
150 let backup = simple_dfg_hugr();
151 let op = backup.entrypoint_optype().clone();
152 assert!(!op.is_tail_loop());
153 let mut h = backup.clone();
154 let r = h.apply_patch(PeelTailLoop::new(h.entrypoint()));
155 assert_eq!(
156 r,
157 Err(PeelTailLoopError::NotTailLoop {
158 node: backup.entrypoint(),
159 op
160 })
161 );
162 assert_eq!(h, backup);
163 }
164
165 #[test]
166 fn peel_loop_incoming_edges() {
167 let i32_t = || INT_TYPES[5].clone();
168 let mut fb = FunctionBuilder::new(
169 "main",
170 Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()),
171 )
172 .unwrap();
173 let helper = fb
174 .module_root_builder()
175 .declare(
176 "helper",
177 Signature::new(
178 vec![bool_t(), usize_t(), i32_t()],
179 vec![Type::new_sum([vec![bool_t(); 2], vec![]]), usize_t()],
180 )
181 .into(),
182 )
183 .unwrap();
184 let [b, u, i] = fb.input_wires_arr();
185 let (tl, call) = {
186 let mut tlb = fb
187 .tail_loop_builder(
188 [(bool_t(), b), (bool_t(), b)],
189 [(usize_t(), u)],
190 TypeRow::new(),
191 )
192 .unwrap();
193 let [b, _, u] = tlb.input_wires_arr();
194 let c = tlb.call(&helper, &[], [b, u, i]).unwrap();
196 let [pred, other] = c.outputs_arr();
197 (tlb.finish_with_outputs(pred, [other]).unwrap(), c.node())
198 };
199 let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap();
200
201 h.apply_patch(PeelTailLoop::new(tl.node())).unwrap();
202 h.validate().unwrap();
203
204 assert_eq!(
205 h.nodes()
206 .filter(|n| h.get_optype(*n).is_tail_loop())
207 .count(),
208 1
209 );
210 use OpTag::*;
211 assert_eq!(tags(&h, call), [FnCall, Dfg, FuncDefn, ModuleRoot]);
212 let [c1, c2] = h
213 .all_linked_inputs(helper.node())
214 .map(|(n, _p)| n)
215 .collect_array()
216 .unwrap();
217 assert!([c1, c2].contains(&call));
218 let other = if call == c1 { c2 } else { c1 };
219 assert_eq!(
220 tags(&h, other),
221 [
222 FnCall,
223 TailLoop,
224 Case,
225 Conditional,
226 Dfg,
227 FuncDefn,
228 ModuleRoot
229 ]
230 );
231 }
232
233 fn tags<H: HugrView>(h: &H, n: H::Node) -> Vec<OpTag> {
234 let mut v = Vec::new();
235 let mut o = Some(n);
236 while let Some(n) = o {
237 v.push(h.get_optype(n).tag());
238 o = h.get_parent(n);
239 }
240 v
241 }
242
243 #[test]
244 fn peel_loop_order_output() {
245 let i16_t = || INT_TYPES[4].clone();
246 let mut fb =
247 FunctionBuilder::new("main", Signature::new(vec![i16_t(), bool_t()], i16_t())).unwrap();
248
249 let [i, b] = fb.input_wires_arr();
250 let tl = {
251 let mut tlb = fb
252 .tail_loop_builder([(i16_t(), i), (bool_t(), b)], [], i16_t().into())
253 .unwrap();
254 let [i, _b] = tlb.input_wires_arr();
255 let [cont] = tlb
259 .add_dataflow_op(
260 Tag::new(
261 TailLoop::BREAK_TAG,
262 tlb.loop_signature().unwrap().control_variants().into(),
263 ),
264 [i],
265 )
266 .unwrap()
267 .outputs_arr();
268 tlb.finish_with_outputs(cont, []).unwrap()
269 };
270 let [i2] = tl.outputs_arr();
271 let dfg = fb
273 .dfg_builder(Signature::new(vec![], i16_t()), [])
274 .unwrap()
275 .finish_with_outputs([i2])
276 .unwrap();
277 let mut h = fb.finish_hugr_with_outputs(dfg.outputs()).unwrap();
278 let tl = tl.node();
279
280 h.apply_patch(PeelTailLoop::new(tl)).unwrap();
281 h.validate().unwrap();
282 let [tl] = h
283 .nodes()
284 .filter(|n| h.get_optype(*n).is_tail_loop())
285 .collect_array()
286 .unwrap();
287 {
288 use OpTag::*;
289 assert_eq!(
290 tags(&h, tl),
291 [TailLoop, Case, Conditional, Dfg, FuncDefn, ModuleRoot]
292 );
293 }
294 let [out_n] = h.output_neighbours(tl).collect_array().unwrap();
295 assert!(h.get_optype(out_n).is_output());
296 assert_eq!(h.get_parent(tl), h.get_parent(out_n));
297 }
298}