// calyx_opt/traversal/visitor.rs

//! Implements a visitor for [`ir::Control`] programs.
//! Passes implemented as a [`Visitor`] are invoked directly on an
//! [`ir::Context`] and run over every [`ir::Component`] it contains.
4use super::action::{Action, VisResult};
5use super::{CompTraversal, ConstructVisitor, Named, Order};
6use calyx_ir::{
7    self as ir, Component, Context, Control, LibrarySignatures, StaticControl,
8};
9use calyx_utils::CalyxResult;
10use std::rc::Rc;
11
12/// The visiting interface for a [`ir::Control`](crate::Control) program.
13/// Contains two kinds of functions:
14/// 1. start_<node>: Called when visiting <node> top-down.
15/// 2. finish_<node>: Called when visiting <node> bottow-up.
16///
17/// A pass will usually override one or more function and rely on the default
18/// visitors to automatically visit the children.
19pub trait Visitor {
20    /// Precondition for this pass to run on the program. If this function returns
21    /// None, the pass triggers. Otherwise it aborts and logs the string as the reason.
22    fn precondition(_ctx: &ir::Context) -> Option<String>
23    where
24        Self: Sized,
25    {
26        None
27    }
28
29    #[inline(always)]
30    /// Transform the [`ir::Context`] before visiting the components.
31    fn start_context(&mut self, _ctx: &mut ir::Context) -> VisResult {
32        Ok(Action::Continue)
33    }
34
35    #[inline(always)]
36    /// Transform the [`ir::Context`] after visiting the components.
37    fn finish_context(&mut self, _ctx: &mut ir::Context) -> VisResult {
38        Ok(Action::Continue)
39    }
40
41    /// Define the iteration order in which components should be visited
42    #[inline(always)]
43    fn iteration_order() -> Order
44    where
45        Self: Sized,
46    {
47        Order::No
48    }
49
50    /// Define the traversal over a component.
51    /// Calls [Visitor::start], visits each control node, and finally calls
52    /// [Visitor::finish].
53    fn traverse_component(
54        &mut self,
55        comp: &mut ir::Component,
56        signatures: &LibrarySignatures,
57        components: &[Component],
58    ) -> CalyxResult<()>
59    where
60        Self: Sized,
61    {
62        self.start(comp, signatures, components)?
63            .and_then(|| {
64                // Create a clone of the reference to the Control
65                // program.
66                let control_ref = Rc::clone(&comp.control);
67                if let Control::Empty(_) = &*control_ref.borrow() {
68                    // Don't traverse if the control program is empty.
69                    return Ok(Action::Continue);
70                }
71                // Mutably borrow the control program and traverse.
72                control_ref
73                    .borrow_mut()
74                    .visit(self, comp, signatures, components)?;
75                Ok(Action::Continue)
76            })?
77            .and_then(|| self.finish(comp, signatures, components))?
78            .apply_change(&mut comp.control.borrow_mut());
79        Ok(())
80    }
81
82    /// Run the visitor on a given program [`ir::Context`].
83    /// The function mutably borrows the `control` program in each component and
84    /// traverses it.
85    ///
86    /// After visiting a component, it called [ConstructVisitor::clear_data] to
87    /// reset the struct.
88    ///
89    /// # Panics
90    /// Panics if the pass attempts to use the control program mutably.
91    fn do_pass(&mut self, context: &mut Context) -> CalyxResult<()>
92    where
93        Self: Sized + ConstructVisitor + Named,
94    {
95        if let Some(msg) = Self::precondition(&*context) {
96            log::info!("Skipping `{}': {msg}", Self::name());
97            return Ok(());
98        }
99
100        self.start_context(context)?;
101
102        let signatures = &context.lib;
103        let comps = std::mem::take(&mut context.components);
104
105        // Temporarily take ownership of components from context.
106        let mut po = CompTraversal::new(comps, Self::iteration_order());
107        po.apply_update(|comp, comps| {
108            self.traverse_component(comp, signatures, comps)?;
109            self.clear_data();
110            Ok(())
111        })?;
112        context.components = po.take();
113
114        self.finish_context(context)?;
115
116        Ok(())
117    }
118
119    /// Build a [Default] implementation of this pass and call [Visitor::do_pass]
120    /// using it.
121    #[inline(always)]
122    fn do_pass_default(context: &mut Context) -> CalyxResult<Self>
123    where
124        Self: ConstructVisitor + Sized + Named,
125    {
126        let mut visitor = Self::from(&*context)?;
127        visitor.do_pass(context)?;
128        Ok(visitor)
129    }
130
131    /// Executed before the traversal begins.
132    fn start(
133        &mut self,
134        _comp: &mut Component,
135        _sigs: &LibrarySignatures,
136        _comps: &[ir::Component],
137    ) -> VisResult {
138        Ok(Action::Continue)
139    }
140
141    /// Executed after the traversal ends.
142    /// This method is always invoked regardless of the [Action] returned from
143    /// the children.
144    fn finish(
145        &mut self,
146        _comp: &mut Component,
147        _sigs: &LibrarySignatures,
148        _comps: &[ir::Component],
149    ) -> VisResult {
150        Ok(Action::Continue)
151    }
152
153    /// Executed before visiting the children of a [ir::Seq] node.
154    fn start_seq(
155        &mut self,
156        _s: &mut ir::Seq,
157        _comp: &mut Component,
158        _sigs: &LibrarySignatures,
159        _comps: &[ir::Component],
160    ) -> VisResult {
161        Ok(Action::Continue)
162    }
163
164    /// Executed after visiting the children of a [ir::Seq] node.
165    fn finish_seq(
166        &mut self,
167        _s: &mut ir::Seq,
168        _comp: &mut Component,
169        _sigs: &LibrarySignatures,
170        _comps: &[ir::Component],
171    ) -> VisResult {
172        Ok(Action::Continue)
173    }
174
175    /// Executed before visiting the children of a [ir::Par] node.
176    fn start_par(
177        &mut self,
178        _s: &mut ir::Par,
179        _comp: &mut Component,
180        _sigs: &LibrarySignatures,
181        _comps: &[ir::Component],
182    ) -> VisResult {
183        Ok(Action::Continue)
184    }
185
186    /// Executed after visiting the children of a [ir::Par] node.
187    fn finish_par(
188        &mut self,
189        _s: &mut ir::Par,
190        _comp: &mut Component,
191        _sigs: &LibrarySignatures,
192        _comps: &[ir::Component],
193    ) -> VisResult {
194        Ok(Action::Continue)
195    }
196
197    /// Executed before visiting the children of a [ir::If] node.
198    fn start_if(
199        &mut self,
200        _s: &mut ir::If,
201        _comp: &mut Component,
202        _sigs: &LibrarySignatures,
203        _comps: &[ir::Component],
204    ) -> VisResult {
205        Ok(Action::Continue)
206    }
207
208    /// Executed after visiting the children of a [ir::If] node.
209    fn finish_if(
210        &mut self,
211        _s: &mut ir::If,
212        _comp: &mut Component,
213        _sigs: &LibrarySignatures,
214        _comps: &[ir::Component],
215    ) -> VisResult {
216        Ok(Action::Continue)
217    }
218
219    /// Executed before visiting the children of a [ir::While] node.
220    fn start_while(
221        &mut self,
222        _s: &mut ir::While,
223        _comp: &mut Component,
224        _sigs: &LibrarySignatures,
225        _comps: &[ir::Component],
226    ) -> VisResult {
227        Ok(Action::Continue)
228    }
229
230    /// Executed after visiting the children of a [ir::While] node.
231    fn finish_while(
232        &mut self,
233        _s: &mut ir::While,
234        _comp: &mut Component,
235        _sigs: &LibrarySignatures,
236        _comps: &[ir::Component],
237    ) -> VisResult {
238        Ok(Action::Continue)
239    }
240
241    /// Executed before visiting the children of a [ir::Repeat] node.
242    fn start_repeat(
243        &mut self,
244        _s: &mut ir::Repeat,
245        _comp: &mut Component,
246        _sigs: &LibrarySignatures,
247        _comps: &[ir::Component],
248    ) -> VisResult {
249        Ok(Action::Continue)
250    }
251
252    /// Executed after visiting the children of a [ir::Repeat] node.
253    fn finish_repeat(
254        &mut self,
255        _s: &mut ir::Repeat,
256        _comp: &mut Component,
257        _sigs: &LibrarySignatures,
258        _comps: &[ir::Component],
259    ) -> VisResult {
260        Ok(Action::Continue)
261    }
262
263    /// Executed before visiting the contents of an [ir::StaticControl] node.
264    fn start_static_control(
265        &mut self,
266        _s: &mut ir::StaticControl,
267        _comp: &mut Component,
268        _sigs: &LibrarySignatures,
269        _comps: &[ir::Component],
270    ) -> VisResult {
271        Ok(Action::Continue)
272    }
273
274    /// Executed after visiting the conetnts of an [ir::StaticControl] node.
275    fn finish_static_control(
276        &mut self,
277        _s: &mut ir::StaticControl,
278        _comp: &mut Component,
279        _sigs: &LibrarySignatures,
280        _comps: &[ir::Component],
281    ) -> VisResult {
282        Ok(Action::Continue)
283    }
284
285    /// Executed at an [ir::Enable] node.
286    fn enable(
287        &mut self,
288        _s: &mut ir::Enable,
289        _comp: &mut Component,
290        _sigs: &LibrarySignatures,
291        _comps: &[ir::Component],
292    ) -> VisResult {
293        Ok(Action::Continue)
294    }
295
296    /// Executed at an [ir::StaticEnable] node.
297    fn static_enable(
298        &mut self,
299        _s: &mut ir::StaticEnable,
300        _comp: &mut Component,
301        _sigs: &LibrarySignatures,
302        _comps: &[ir::Component],
303    ) -> VisResult {
304        Ok(Action::Continue)
305    }
306
307    /// Executed before visiting the children of a [ir::StaticIf] node.
308    fn start_static_if(
309        &mut self,
310        _s: &mut ir::StaticIf,
311        _comp: &mut Component,
312        _sigs: &LibrarySignatures,
313        _comps: &[ir::Component],
314    ) -> VisResult {
315        Ok(Action::Continue)
316    }
317
318    /// Executed after visiting the children of a [ir::StaticIf] node.
319    fn finish_static_if(
320        &mut self,
321        _s: &mut ir::StaticIf,
322        _comp: &mut Component,
323        _sigs: &LibrarySignatures,
324        _comps: &[ir::Component],
325    ) -> VisResult {
326        Ok(Action::Continue)
327    }
328
329    /// Executed before visiting the children of a [ir::StaticRepeat] node.
330    fn start_static_repeat(
331        &mut self,
332        _s: &mut ir::StaticRepeat,
333        _comp: &mut Component,
334        _sigs: &LibrarySignatures,
335        _comps: &[ir::Component],
336    ) -> VisResult {
337        Ok(Action::Continue)
338    }
339
340    /// Executed after visiting the children of a [ir::StaticRepeat] node.
341    fn finish_static_repeat(
342        &mut self,
343        _s: &mut ir::StaticRepeat,
344        _comp: &mut Component,
345        _sigs: &LibrarySignatures,
346        _comps: &[ir::Component],
347    ) -> VisResult {
348        Ok(Action::Continue)
349    }
350
351    // Executed before visiting the children of a [ir::StaticSeq] node.
352    fn start_static_seq(
353        &mut self,
354        _s: &mut ir::StaticSeq,
355        _comp: &mut Component,
356        _sigs: &LibrarySignatures,
357        _comps: &[ir::Component],
358    ) -> VisResult {
359        Ok(Action::Continue)
360    }
361
362    // Executed after visiting the children of a [ir::StaticSeq] node.
363    fn finish_static_seq(
364        &mut self,
365        _s: &mut ir::StaticSeq,
366        _comp: &mut Component,
367        _sigs: &LibrarySignatures,
368        _comps: &[ir::Component],
369    ) -> VisResult {
370        Ok(Action::Continue)
371    }
372
373    // Executed before visiting the children of a [ir::StaticPar] node.
374    fn start_static_par(
375        &mut self,
376        _s: &mut ir::StaticPar,
377        _comp: &mut Component,
378        _sigs: &LibrarySignatures,
379        _comps: &[ir::Component],
380    ) -> VisResult {
381        Ok(Action::Continue)
382    }
383
384    // Executed after visiting the children of a [ir::StaticPar] node.
385    fn finish_static_par(
386        &mut self,
387        _s: &mut ir::StaticPar,
388        _comp: &mut Component,
389        _sigs: &LibrarySignatures,
390        _comps: &[ir::Component],
391    ) -> VisResult {
392        Ok(Action::Continue)
393    }
394
395    /// Executed at an [ir::Invoke] node.
396    fn invoke(
397        &mut self,
398        _s: &mut ir::Invoke,
399        _comp: &mut Component,
400        _sigs: &LibrarySignatures,
401        _comps: &[ir::Component],
402    ) -> VisResult {
403        Ok(Action::Continue)
404    }
405
406    /// Executed at a [ir::StaticInvoke] node.
407    fn static_invoke(
408        &mut self,
409        _s: &mut ir::StaticInvoke,
410        _comp: &mut Component,
411        _sigs: &LibrarySignatures,
412        _comps: &[ir::Component],
413    ) -> VisResult {
414        Ok(Action::Continue)
415    }
416
417    /// Executed at an [ir::Empty] node.
418    fn empty(
419        &mut self,
420        _s: &mut ir::Empty,
421        _comp: &mut Component,
422        _sigs: &LibrarySignatures,
423        _comps: &[ir::Component],
424    ) -> VisResult {
425        Ok(Action::Continue)
426    }
427}
428
429/// Describes types that can be visited by things implementing [Visitor].
430/// This performs a recursive walk of the tree.
431///
432/// It calls `Visitor::start_*` on the way down, and `Visitor::finish_*` on
433/// the way up.
434pub trait Visitable {
435    /// Perform the traversal.
436    fn visit(
437        &mut self,
438        visitor: &mut dyn Visitor,
439        component: &mut Component,
440        signatures: &LibrarySignatures,
441        components: &[ir::Component],
442    ) -> VisResult;
443}
444
445impl Visitable for Control {
446    fn visit(
447        &mut self,
448        visitor: &mut dyn Visitor,
449        component: &mut Component,
450        sigs: &LibrarySignatures,
451        comps: &[ir::Component],
452    ) -> VisResult {
453        let res = match self {
454            Control::Seq(ctrl) => visitor
455                .start_seq(ctrl, component, sigs, comps)?
456                .and_then(|| ctrl.stmts.visit(visitor, component, sigs, comps))?
457                .pop()
458                .and_then(|| {
459                    visitor.finish_seq(ctrl, component, sigs, comps)
460                })?,
461            Control::Par(ctrl) => visitor
462                .start_par(ctrl, component, sigs, comps)?
463                .and_then(|| ctrl.stmts.visit(visitor, component, sigs, comps))?
464                .pop()
465                .and_then(|| {
466                    visitor.finish_par(ctrl, component, sigs, comps)
467                })?,
468            Control::If(ctrl) => visitor
469                .start_if(ctrl, component, sigs, comps)?
470                .and_then(|| {
471                    ctrl.tbranch.visit(visitor, component, sigs, comps)
472                })?
473                .and_then(|| {
474                    ctrl.fbranch.visit(visitor, component, sigs, comps)
475                })?
476                .pop()
477                .and_then(|| visitor.finish_if(ctrl, component, sigs, comps))?,
478            Control::While(ctrl) => visitor
479                .start_while(ctrl, component, sigs, comps)?
480                .and_then(|| ctrl.body.visit(visitor, component, sigs, comps))?
481                .pop()
482                .and_then(|| {
483                    visitor.finish_while(ctrl, component, sigs, comps)
484                })?,
485            Control::Repeat(ctrl) => visitor
486                .start_repeat(ctrl, component, sigs, comps)?
487                .and_then(|| ctrl.body.visit(visitor, component, sigs, comps))?
488                .pop()
489                .and_then(|| {
490                    visitor.finish_repeat(ctrl, component, sigs, comps)
491                })?,
492            Control::Enable(ctrl) => {
493                visitor.enable(ctrl, component, sigs, comps)?
494            }
495            Control::Static(sctrl) => visitor
496                .start_static_control(sctrl, component, sigs, comps)?
497                .and_then(|| sctrl.visit(visitor, component, sigs, comps))?
498                .pop()
499                .and_then(|| {
500                    visitor.finish_static_control(sctrl, component, sigs, comps)
501                })?,
502            Control::Empty(ctrl) => {
503                visitor.empty(ctrl, component, sigs, comps)?
504            }
505            Control::Invoke(data) => {
506                visitor.invoke(data, component, sigs, comps)?
507            }
508        };
509        Ok(res.apply_change(self))
510    }
511}
512
513impl Visitable for StaticControl {
514    fn visit(
515        &mut self,
516        visitor: &mut dyn Visitor,
517        component: &mut Component,
518        signatures: &LibrarySignatures,
519        components: &[ir::Component],
520    ) -> VisResult {
521        let res = match self {
522            StaticControl::Empty(ctrl) => {
523                visitor.empty(ctrl, component, signatures, components)?
524            }
525            StaticControl::Enable(ctrl) => visitor
526                .static_enable(ctrl, component, signatures, components)?,
527            StaticControl::Repeat(ctrl) => visitor
528                .start_static_repeat(ctrl, component, signatures, components)?
529                .and_then(|| {
530                    ctrl.body.visit(visitor, component, signatures, components)
531                })?
532                .pop()
533                .and_then(|| {
534                    visitor.finish_static_repeat(
535                        ctrl, component, signatures, components,
536                    )
537                })?,
538            StaticControl::Seq(ctrl) => visitor
539                .start_static_seq(ctrl, component, signatures, components)?
540                .and_then(|| {
541                    ctrl.stmts.visit(visitor, component, signatures, components)
542                })?
543                .pop()
544                .and_then(|| {
545                    visitor.finish_static_seq(
546                        ctrl, component, signatures, components,
547                    )
548                })?,
549            StaticControl::Par(ctrl) => visitor
550                .start_static_par(ctrl, component, signatures, components)?
551                .and_then(|| {
552                    ctrl.stmts.visit(visitor, component, signatures, components)
553                })?
554                .pop()
555                .and_then(|| {
556                    visitor.finish_static_par(
557                        ctrl, component, signatures, components,
558                    )
559                })?,
560            StaticControl::If(sctrl) => visitor
561                .start_static_if(sctrl, component, signatures, components)?
562                .and_then(|| {
563                    sctrl
564                        .tbranch
565                        .visit(visitor, component, signatures, components)
566                })?
567                .and_then(|| {
568                    sctrl
569                        .fbranch
570                        .visit(visitor, component, signatures, components)
571                })?
572                .pop()
573                .and_then(|| {
574                    visitor.finish_static_if(
575                        sctrl, component, signatures, components,
576                    )
577                })?,
578            StaticControl::Invoke(sin) => {
579                visitor.static_invoke(sin, component, signatures, components)?
580            }
581        };
582        Ok(res.apply_static_change(self))
583    }
584}
585
586/// Blanket implementation for Vectors of Visitables
587impl<V: Visitable> Visitable for Vec<V> {
588    fn visit(
589        &mut self,
590        visitor: &mut dyn Visitor,
591        component: &mut Component,
592        sigs: &LibrarySignatures,
593        components: &[ir::Component],
594    ) -> VisResult {
595        for t in self {
596            let res = t.visit(visitor, component, sigs, components)?;
597            match res {
598                Action::Continue
599                | Action::SkipChildren
600                | Action::Change(_)
601                | Action::StaticChange(_) => {
602                    continue;
603                }
604                Action::Stop => return Ok(Action::Stop),
605            };
606        }
607        Ok(Action::Continue)
608    }
609}