calyx_opt/passes/
top_down_compile_control.rs

1use super::math_utilities::get_bit_width_from;
2use crate::passes;
3use crate::traversal::{
4    Action, ConstructVisitor, Named, ParseVal, PassOpt, VisResult, Visitor,
5};
6use calyx_ir::{self as ir, GetAttributes, LibrarySignatures, Printer, RRC};
7use calyx_ir::{build_assignments, guard, structure};
8use calyx_utils::CalyxResult;
9use calyx_utils::Error;
10use ir::Nothing;
11use itertools::Itertools;
12use petgraph::graph::DiGraph;
13use std::collections::HashMap;
14use std::io::Write;
15use std::rc::Rc;
16
17const NODE_ID: ir::Attribute =
18    ir::Attribute::Internal(ir::InternalAttr::NODE_ID);
19
20/// Computes the exit edges of a given [ir::Control] program.
21///
22/// ## Example
23/// In the following Calyx program:
24/// ```
25/// while comb_reg.out {
26///   seq {
27///     @NODE_ID(4) incr;
28///     @NODE_ID(5) cond0;
29///   }
30/// }
31/// ```
32/// The exit edge is is `[(5, cond0[done])]` indicating that the state 5 exits when the guard
33/// `cond0[done]` is true.
34///
35/// Multiple exit points are created when conditions are used:
36/// ```
37/// while comb_reg.out {
38///   @NODE_ID(7) incr;
39///   if comb_reg2.out {
40///     @NODE_ID(8) tru;
41///   } else {
42///     @NODE_ID(9) fal;
43///   }
44/// }
45/// ```
46/// The exit set is `[(8, tru[done] & !comb_reg.out), (9, fal & !comb_reg.out)]`.
47fn control_exits(con: &ir::Control, exits: &mut Vec<PredEdge>) {
48    match con {
49        ir::Control::Empty(_) => {}
50        ir::Control::Enable(ir::Enable { group, attributes }) => {
51            let cur_state = attributes.get(NODE_ID).unwrap();
52            exits.push((cur_state, guard!(group["done"])))
53        }
54        ir::Control::Seq(ir::Seq { stmts, .. }) => {
55            if let Some(stmt) = stmts.last() { control_exits(stmt, exits) }
56        }
57        ir::Control::If(ir::If {
58            tbranch, fbranch, ..
59        }) => {
60            control_exits(
61                tbranch, exits,
62            );
63            control_exits(
64                fbranch, exits,
65            )
66        }
67        ir::Control::While(ir::While { body, port, .. }) => {
68            let mut loop_exits = vec![];
69            control_exits(body, &mut loop_exits);
70            // Loop exits only happen when the loop guard is false
71            exits.extend(loop_exits.into_iter().map(|(s, g)| {
72                (s, g & !ir::Guard::from(port.clone()))
73            }));
74        },
75        ir::Control::Repeat(_) => unreachable!("`repeat` statements should have been compiled away. Run `{}` before this pass.", passes::CompileRepeat::name()),
76        ir::Control::Invoke(_) => unreachable!("`invoke` statements should have been compiled away. Run `{}` before this pass.", passes::CompileInvoke::name()),
77        ir::Control::Par(_) => unreachable!(),
78        ir::Control::Static(_) => unreachable!(" static control should have been compiled away. Run the static compilation passes before this pass")
79    }
80}
81
82/// Adds the @NODE_ID attribute to [ir::Enable] and [ir::Par].
83/// Each [ir::Enable] gets a unique label within the context of a child of
84/// a [ir::Par] node.
85/// Furthermore, if an if/while/seq statement is labeled with a `new_fsm` attribute,
86/// then it will get its own unique label. Within that if/while/seq, each enable
87/// will get its own unique label within the context of that if/while/seq (see
88/// example for clarification).
89///
90/// ## Example:
91/// ```
92/// seq { A; B; par { C; D; }; E; @new_fsm seq {F; G; H}}
93/// ```
94/// gets the labels:
95/// ```
96/// seq {
97///   @NODE_ID(1) A; @NODE_ID(2) B;
98///   @NODE_ID(3) par {
99///     @NODE_ID(0) C;
100///     @NODE_ID(0) D;
101///   }
102///   @NODE_ID(4) E;
103///   @NODE_ID(5) seq{
104///     @NODE_ID(0) F;
105///     @NODE_ID(1) G;
106///     @NODE_ID(2) H;
107///   }
108/// }
109/// ```
110///
111/// These identifiers are used by the compilation methods [calculate_states_recur]
112/// and [control_exits].
113fn compute_unique_ids(con: &mut ir::Control, cur_state: u64) -> u64 {
114    match con {
115        ir::Control::Enable(ir::Enable { attributes, .. }) => {
116            attributes.insert(NODE_ID, cur_state);
117            cur_state + 1
118        }
119        ir::Control::Par(ir::Par { stmts, attributes }) => {
120            attributes.insert(NODE_ID, cur_state);
121            stmts.iter_mut().for_each(|stmt| {
122                compute_unique_ids(stmt, 0);
123            });
124            cur_state + 1
125        }
126        ir::Control::Seq(ir::Seq { stmts, attributes }) => {
127            let new_fsm = attributes.has(ir::BoolAttr::NewFSM);
128            // if new_fsm is true, then insert attribute at the seq, and then
129            // start over counting states from 0
130            let mut cur = if new_fsm{
131                attributes.insert(NODE_ID, cur_state);
132                0
133            } else {
134                cur_state
135            };
136            stmts.iter_mut().for_each(|stmt| {
137                cur = compute_unique_ids(stmt, cur);
138            });
139            // If new_fsm is true then we want to return cur_state + 1, since this
140            // seq should really only take up 1 "state" on the "outer" fsm
141            if new_fsm{
142                cur_state + 1
143            } else {
144                cur
145            }
146        }
147        ir::Control::If(ir::If {
148            tbranch, fbranch, attributes, ..
149        }) => {
150            let new_fsm = attributes.has(ir::BoolAttr::NewFSM);
151            // if new_fsm is true, then we want to add an attribute to this
152            // control statement
153            if new_fsm {
154                attributes.insert(NODE_ID, cur_state);
155            }
156            // If the program starts with a branch then branches can't get
157            // the initial state.
158            // Also, if new_fsm is true, we want to start with state 1 as well:
159            // we can't start at 0 for the reason mentioned above
160            let cur = if new_fsm || cur_state == 0 {
161                1
162            } else {
163                cur_state
164            };
165            let tru_nxt = compute_unique_ids(
166                tbranch, cur
167            );
168            let false_nxt = compute_unique_ids(
169                fbranch, tru_nxt
170            );
171            // If new_fsm is true then we want to return cur_state + 1, since this
172            // if stmt should really only take up 1 "state" on the "outer" fsm
173            if new_fsm {
174                cur_state + 1
175            } else {
176                false_nxt
177            }
178        }
179        ir::Control::While(ir::While { body, attributes, .. }) => {
180            let new_fsm = attributes.has(ir::BoolAttr::NewFSM);
181            // if new_fsm is true, then we want to add an attribute to this
182            // control statement
183            if new_fsm{
184                attributes.insert(NODE_ID, cur_state);
185            }
186            // If the program starts with a branch then branches can't get
187            // the initial state.
188            // Also, if new_fsm is true, we want to start with state 1 as well:
189            // we can't start at 0 for the reason mentioned above
190            let cur = if new_fsm || cur_state == 0 {
191                1
192            } else {
193                cur_state
194            };
195            let body_nxt = compute_unique_ids(body, cur);
196            // If new_fsm is true then we want to return cur_state + 1, since this
197            // while loop should really only take up 1 "state" on the "outer" fsm
198            if new_fsm{
199                cur_state + 1
200            } else {
201                body_nxt
202            }
203        }
204        ir::Control::Empty(_) => cur_state,
205        ir::Control::Repeat(_) => unreachable!("`repeat` statements should have been compiled away. Run `{}` before this pass.", passes::CompileRepeat::name()),
206        ir::Control::Invoke(_) => unreachable!("`invoke` statements should have been compiled away. Run `{}` before this pass.", passes::CompileInvoke::name()),
207        ir::Control::Static(_) => unreachable!("static control should have been compiled away. Run the static compilation passes before this pass")
208    }
209}
210
211/// Represents the dyanmic execution schedule of a control program.
212struct Schedule<'b, 'a: 'b> {
213    /// Assigments that should be enabled in a given state.
214    pub enables: HashMap<u64, Vec<ir::Assignment<Nothing>>>,
215    /// Transition from one state to another when the guard is true.
216    pub transitions: Vec<(u64, u64, ir::Guard<Nothing>)>,
217    /// The component builder. The reference has a shorter lifetime than the builder itself
218    /// to allow multiple schedules to use the same builder.
219    pub builder: &'b mut ir::Builder<'a>,
220}
221
222impl<'b, 'a> From<&'b mut ir::Builder<'a>> for Schedule<'b, 'a> {
223    fn from(builder: &'b mut ir::Builder<'a>) -> Self {
224        Schedule {
225            enables: HashMap::new(),
226            transitions: Vec::new(),
227            builder,
228        }
229    }
230}
231
232impl<'b, 'a> Schedule<'b, 'a> {
233    /// Validate that all states are reachable in the transition graph.
234    fn validate(&self) {
235        let graph = DiGraph::<(), u32>::from_edges(
236            self.transitions
237                .iter()
238                .map(|(s, e, _)| (*s as u32, *e as u32)),
239        );
240
241        debug_assert!(
242            petgraph::algo::connected_components(&graph) == 1,
243            "State transition graph has unreachable states (graph has more than one connected component).");
244    }
245
246    /// Return the max state in the transition graph
247    fn last_state(&self) -> u64 {
248        self.transitions
249            .iter()
250            .max_by_key(|(_, s, _)| s)
251            .expect("Schedule::transition is empty!")
252            .1
253    }
254
255    /// Print out the current schedule
256    fn display(&self, group: String) {
257        let out = &mut std::io::stdout();
258        writeln!(out, "======== {} =========", group).unwrap();
259        self.enables
260            .iter()
261            .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2))
262            .for_each(|(state, assigns)| {
263                writeln!(out, "{}:", state).unwrap();
264                assigns.iter().for_each(|assign| {
265                    Printer::write_assignment(assign, 2, out).unwrap();
266                    writeln!(out).unwrap();
267                })
268            });
269        writeln!(out, "{}:\n  <end>", self.last_state()).unwrap();
270        writeln!(out, "transitions:").unwrap();
271        self.transitions
272            .iter()
273            .sorted_by(|(k1, _, _), (k2, _, _)| k1.cmp(k2))
274            .for_each(|(i, f, g)| {
275                writeln!(out, "  ({}, {}): {}", i, f, Printer::guard_str(g))
276                    .unwrap();
277            });
278    }
279
280    /// Implement a given [Schedule] and return the name of the [ir::Group] that
281    /// implements it.
282    fn realize_schedule(self, dump_fsm: bool) -> RRC<ir::Group> {
283        self.validate();
284
285        let group = self.builder.add_group("tdcc");
286        if dump_fsm {
287            self.display(format!(
288                "{}:{}",
289                self.builder.component.name,
290                group.borrow().name()
291            ));
292        }
293
294        let final_state = self.last_state();
295        let fsm_size = get_bit_width_from(
296            final_state + 1, /* represent 0..final_state */
297        );
298        structure!(self.builder;
299            let fsm = prim std_reg(fsm_size);
300            let signal_on = constant(1, 1);
301            let last_state = constant(final_state, fsm_size);
302            let first_state = constant(0, fsm_size);
303        );
304
305        // Enable assignments
306        group.borrow_mut().assignments.extend(
307            self.enables
308                .into_iter()
309                .sorted_by(|(k1, _), (k2, _)| k1.cmp(k2))
310                .flat_map(|(state, mut assigns)| {
311                    let state_const =
312                        self.builder.add_constant(state, fsm_size);
313                    let state_guard = guard!(fsm["out"] == state_const["out"]);
314                    assigns.iter_mut().for_each(|asgn| {
315                        asgn.guard.update(|g| g.and(state_guard.clone()))
316                    });
317                    assigns
318                }),
319        );
320
321        // Transition assignments
322        group.borrow_mut().assignments.extend(
323            self.transitions.into_iter().flat_map(|(s, e, guard)| {
324                structure!(self.builder;
325                    let end_const = constant(e, fsm_size);
326                    let start_const = constant(s, fsm_size);
327                );
328                let ec_borrow = end_const.borrow();
329                let trans_guard =
330                    guard!((fsm["out"] == start_const["out"]) & guard);
331
332                vec![
333                    self.builder.build_assignment(
334                        fsm.borrow().get("in"),
335                        ec_borrow.get("out"),
336                        trans_guard.clone(),
337                    ),
338                    self.builder.build_assignment(
339                        fsm.borrow().get("write_en"),
340                        signal_on.borrow().get("out"),
341                        trans_guard,
342                    ),
343                ]
344            }),
345        );
346
347        // Done condition for group
348        let last_guard = guard!(fsm["out"] == last_state["out"]);
349        let done_assign = self.builder.build_assignment(
350            group.borrow().get("done"),
351            signal_on.borrow().get("out"),
352            last_guard.clone(),
353        );
354        group.borrow_mut().assignments.push(done_assign);
355
356        // Cleanup: Add a transition from last state to the first state.
357        let reset_fsm = build_assignments!(self.builder;
358            fsm["in"] = last_guard ? first_state["out"];
359            fsm["write_en"] = last_guard ? signal_on["out"];
360        );
361        self.builder
362            .component
363            .continuous_assignments
364            .extend(reset_fsm);
365
366        group
367    }
368}
369
370/// Represents an edge from a predeccesor to the current control node.
371/// The `u64` represents the FSM state of the predeccesor and the guard needs
372/// to be true for the predeccesor to transition to the current state.
373type PredEdge = (u64, ir::Guard<Nothing>);
374
375impl Schedule<'_, '_> {
376    /// Recursively build an dynamic finite state machine represented by a [Schedule].
377    /// Does the following, given an [ir::Control]:
378    ///     1. If needed, add transitions from predeccesors to the current state.
379    ///     2. Enable the groups in the current state
380    ///     3. Calculate [PredEdge] implied by this state
381    ///     4. Return [PredEdge] and the next state.
382    /// Another note: the functions calc_seq_recur, calc_while_recur, and calc_if_recur
383    /// are functions that `calculate_states_recur` uses for when con is a seq, while,
384    /// and if respectively. The reason why they are defined as separate functions is because we
385    /// need to call `calculate_seq_recur` (for example) directly when we are in `finish_seq`
386    /// since `finish_seq` only gives us access to a `& mut seq` type, not a `& Control`
387    /// type.
388    fn calculate_states_recur(
389        // Current schedule.
390        &mut self,
391        con: &ir::Control,
392        // The set of previous states that want to transition into cur_state
393        preds: Vec<PredEdge>,
394        // True if early_transitions are allowed
395        early_transitions: bool,
396    ) -> CalyxResult<Vec<PredEdge>> {
397        match con {
398        // See explanation of FSM states generated in [ir::TopDownCompileControl].
399        ir::Control::Enable(ir::Enable { group, attributes }) => {
400            let cur_state = attributes.get(NODE_ID).unwrap_or_else(|| panic!("Group `{}` does not have node_id information", group.borrow().name()));
401            // If there is exactly one previous transition state with a `true`
402            // guard, then merge this state into previous state.
403            // This happens when the first control statement is an enable not
404            // inside a branch.
405            let (cur_state, prev_states) = if preds.len() == 1 && preds[0].1.is_true() {
406                (preds[0].0, vec![])
407            } else {
408                (cur_state, preds)
409            };
410
411            let not_done = !guard!(group["done"]);
412            let signal_on = self.builder.add_constant(1, 1);
413
414            // Activate this group in the current state
415            let en_go = build_assignments!(self.builder;
416                group["go"] = not_done ? signal_on["out"];
417            );
418            self
419                .enables
420                .entry(cur_state)
421                .or_default()
422                .extend(en_go);
423
424            // Activate group in the cycle when previous state signals done.
425            // NOTE: We explicilty do not add `not_done` to the guard.
426            // See explanation in [ir::TopDownCompileControl] to understand
427            // why.
428            if early_transitions {
429                for (st, g) in &prev_states {
430                    let early_go = build_assignments!(self.builder;
431                        group["go"] = g ? signal_on["out"];
432                    );
433                    self.enables.entry(*st).or_default().extend(early_go);
434                }
435            }
436
437            let transitions = prev_states
438                .into_iter()
439                .map(|(st, guard)| (st, cur_state, guard));
440            self.transitions.extend(transitions);
441
442            let done_cond = guard!(group["done"]);
443            Ok(vec![(cur_state, done_cond)])
444        }
445        ir::Control::Seq(seq) => {
446            self.calc_seq_recur(seq, preds, early_transitions)
447        }
448        ir::Control::If(if_stmt) => {
449            self.calc_if_recur(if_stmt, preds, early_transitions)
450        }
451        ir::Control::While(while_stmt) => {
452            self.calc_while_recur(while_stmt, preds, early_transitions)
453        }
454        ir::Control::Par(_) => unreachable!(),
455        ir::Control::Repeat(_) => unreachable!("`repeat` statements should have been compiled away. Run `{}` before this pass.", passes::CompileRepeat::name()),
456        ir::Control::Invoke(_) => unreachable!("`invoke` statements should have been compiled away. Run `{}` before this pass.", passes::CompileInvoke::name()),
457        ir::Control::Empty(_) => unreachable!("`calculate_states_recur` should not see an `empty` control."),
458        ir::Control::Static(_) => unreachable!("static control should have been compiled away. Run the static compilation passes before this pass")
459    }
460    }
461
462    /// Builds a finite state machine for `seq` represented by a [Schedule].
463    /// At a high level, it iterates through each stmt in the seq's control, using the
464    /// previous stmt's [PredEdge] as the `preds` for the current stmt, and returns
465    /// the [PredEdge] implied by the last stmt in `seq`'s control.
466    fn calc_seq_recur(
467        &mut self,
468        seq: &ir::Seq,
469        // The set of previous states that want to transition into cur_state
470        preds: Vec<PredEdge>,
471        // True if early_transitions are allowed
472        early_transitions: bool,
473    ) -> CalyxResult<Vec<PredEdge>> {
474        let mut prev = preds;
475        for stmt in &seq.stmts {
476            prev =
477                self.calculate_states_recur(stmt, prev, early_transitions)?;
478        }
479        Ok(prev)
480    }
481
482    /// Builds a finite state machine for `if_stmt` represented by a [Schedule].
483    /// First generates the transitions into the true branch + the transitions that exist
484    /// inside the true branch. Then generates the transitions into the false branch + the transitions
485    /// that exist inside the false branch. Then calculates the transitions needed to
486    /// exit the if statmement (which include edges from both the true and false branches).
487    fn calc_if_recur(
488        &mut self,
489        if_stmt: &ir::If,
490        // The set of previous states that want to transition into cur_state
491        preds: Vec<PredEdge>,
492        // True if early_transitions are allowed
493        early_transitions: bool,
494    ) -> CalyxResult<Vec<PredEdge>> {
495        if if_stmt.cond.is_some() {
496            return Err(Error::malformed_structure(format!("{}: Found group `{}` in with position of if. This should have compiled away.", TopDownCompileControl::name(), if_stmt.cond.as_ref().unwrap().borrow().name())));
497        }
498        let port_guard: ir::Guard<Nothing> = Rc::clone(&if_stmt.port).into();
499        // Previous states transitioning into true branch need the conditional
500        // to be true.
501        let tru_transitions = preds
502            .clone()
503            .into_iter()
504            .map(|(s, g)| (s, g & port_guard.clone()))
505            .collect();
506        let tru_prev = self.calculate_states_recur(
507            &if_stmt.tbranch,
508            tru_transitions,
509            early_transitions,
510        )?;
511        // Previous states transitioning into false branch need the conditional
512        // to be false.
513        let fal_transitions = preds
514            .into_iter()
515            .map(|(s, g)| (s, g & !port_guard.clone()))
516            .collect();
517
518        let fal_prev = if let ir::Control::Empty(..) = *if_stmt.fbranch {
519            // If the false branch is empty, then all the prevs to this node will become prevs
520            // to the next node.
521            fal_transitions
522        } else {
523            self.calculate_states_recur(
524                &if_stmt.fbranch,
525                fal_transitions,
526                early_transitions,
527            )?
528        };
529
530        let prevs = tru_prev.into_iter().chain(fal_prev).collect();
531        Ok(prevs)
532    }
533
534    /// Builds a finite state machine for `while_stmt` represented by a [Schedule].
535    /// It first generates the backwards edges (i.e., edges from the end of the while
536    /// body back to the beginning of the while body), then generates the forwards
537    /// edges in the body, then generates the edges that exit the while loop.
538    fn calc_while_recur(
539        &mut self,
540        while_stmt: &ir::While,
541        // The set of previous states that want to transition into cur_state
542        preds: Vec<PredEdge>,
543        // True if early_transitions are allowed
544        early_transitions: bool,
545    ) -> CalyxResult<Vec<PredEdge>> {
546        if while_stmt.cond.is_some() {
547            return Err(Error::malformed_structure(format!("{}: Found group `{}` in with position of if. This should have compiled away.", TopDownCompileControl::name(), while_stmt.cond.as_ref().unwrap().borrow().name())));
548        }
549
550        let port_guard: ir::Guard<Nothing> = Rc::clone(&while_stmt.port).into();
551
552        // Step 1: Generate the backward edges by computing the exit nodes.
553        let mut exits = vec![];
554        control_exits(&while_stmt.body, &mut exits);
555
556        // Step 2: Generate the forward edges normally.
557        // Previous transitions into the body require the condition to be
558        // true.
559        let transitions: Vec<PredEdge> = preds
560            .clone()
561            .into_iter()
562            .chain(exits)
563            .map(|(s, g)| (s, g & port_guard.clone()))
564            .collect();
565        let prevs = self.calculate_states_recur(
566            &while_stmt.body,
567            transitions,
568            early_transitions,
569        )?;
570
571        // Step 3: The final out edges from the while come from:
572        //   - Before the body when the condition is false
573        //   - Inside the body when the condition is false
574        let not_port_guard = !port_guard;
575        let all_prevs = preds
576            .into_iter()
577            .chain(prevs)
578            .map(|(st, guard)| (st, guard & not_port_guard.clone()))
579            .collect();
580
581        Ok(all_prevs)
582    }
583
584    /// Creates a Schedule that represents `seq`, mainly relying on `calc_seq_recur()`.
585    fn calculate_states_seq(
586        &mut self,
587        seq: &ir::Seq,
588        early_transitions: bool,
589    ) -> CalyxResult<()> {
590        let first_state = (0, ir::Guard::True);
591        // We create an empty first state in case the control program starts with
592        // a branch (if, while).
593        // If the program doesn't branch, then the initial state is merged into
594        // the first group.
595        let prev =
596            self.calc_seq_recur(seq, vec![first_state], early_transitions)?;
597        self.add_nxt_transition(prev);
598        Ok(())
599    }
600
601    /// Creates a Schedule that represents `if`, mainly relying on `calc_if_recur()`.
602    fn calculate_states_if(
603        &mut self,
604        if_stmt: &ir::If,
605        early_transitions: bool,
606    ) -> CalyxResult<()> {
607        let first_state = (0, ir::Guard::True);
608        // We create an empty first state in case the control program starts with
609        // a branch (if, while).
610        // If the program doesn't branch, then the initial state is merged into
611        // the first group.
612        let prev =
613            self.calc_if_recur(if_stmt, vec![first_state], early_transitions)?;
614        self.add_nxt_transition(prev);
615        Ok(())
616    }
617
618    /// Creates a Schedule that represents `while`, mainly relying on `calc_while_recur()`.
619    fn calculate_states_while(
620        &mut self,
621        while_stmt: &ir::While,
622        early_transitions: bool,
623    ) -> CalyxResult<()> {
624        let first_state = (0, ir::Guard::True);
625        // We create an empty first state in case the control program starts with
626        // a branch (if, while).
627        // If the program doesn't branch, then the initial state is merged into
628        // the first group.
629        let prev = self.calc_while_recur(
630            while_stmt,
631            vec![first_state],
632            early_transitions,
633        )?;
634        self.add_nxt_transition(prev);
635        Ok(())
636    }
637
638    /// Given predecessors prev, creates a new "next" state and transitions from
639    /// each state in prev to the next state.
640    /// In other words, it just adds an "end" state to [Schedule] and the
641    /// appropriate transitions to that "end" state.
642    fn add_nxt_transition(&mut self, prev: Vec<PredEdge>) {
643        let nxt = prev
644            .iter()
645            .max_by(|(st1, _), (st2, _)| st1.cmp(st2))
646            .unwrap()
647            .0
648            + 1;
649        let transitions = prev.into_iter().map(|(st, guard)| (st, nxt, guard));
650        self.transitions.extend(transitions);
651    }
652
653    /// Note: the functions calculate_states_seq, calculate_states_while, and calculate_states_if
654    /// are functions that basically do what `calculate_states` would do if `calculate_states` knew (for certain)
655    /// that its input parameter would be a seq/while/if.
656    /// The reason why we need to define these as separate functions is because `finish_seq`
657    /// (for example) we only gives us access to a `& mut seq` type, not a `& Control`
658    /// type.
659    fn calculate_states(
660        &mut self,
661        con: &ir::Control,
662        early_transitions: bool,
663    ) -> CalyxResult<()> {
664        let first_state = (0, ir::Guard::True);
665        // We create an empty first state in case the control program starts with
666        // a branch (if, while).
667        // If the program doesn't branch, then the initial state is merged into
668        // the first group.
669        let prev = self.calculate_states_recur(
670            con,
671            vec![first_state],
672            early_transitions,
673        )?;
674        self.add_nxt_transition(prev);
675        Ok(())
676    }
677}
678
679/// **Core lowering pass.**
680/// Compiles away the control programs in components into purely structural code using an
681/// finite-state machine (FSM).
682///
683/// Lowering operates in two steps:
684/// 1. Compile all [ir::Par] control sub-programs into a single [ir::Enable] of a group that runs
685///    all children to completion.
686/// 2. Compile the top-level control program into a single [ir::Enable].
687///
688/// ## Compiling non-`par` programs
689/// At very high-level, the pass assigns an FSM state to each [ir::Enable] in the program and
690/// generates transitions to the state to activate the groups contained within the [ir::Enable].
691///
692/// The compilation process calculates all predeccesors of the [ir::Enable] while walking over the
693/// control program. A predeccesor is any enable statement that can directly "jump" to the current
694/// [ir::Enable]. The compilation process computes all such predeccesors and the guards that need
695/// to be true for the predeccesor to jump into this enable statement.
696///
697/// ```
698/// cond0;
699/// while lt.out {
700///   if gt.out { true } else { false }
701/// }
702/// next;
703/// ```
704/// The predeccesor sets are:
705/// ```
706/// cond0 -> []
707/// true -> [(cond0, lt.out & gt.out); (true; lt.out & gt.out); (false, lt.out & !gt.out)]
708/// false -> [(cond0, lt.out & !gt.out); (true; lt.out & gt.out); (false, lt.out & !gt.out)]
709/// next -> [(cond0, !lt.out); (true, !lt.out); (false, !lt.out)]
710/// ```
711///
712/// ### Compiling [ir::Enable]
713/// The process first takes all edges from predeccesors and transitions to the state for this
714/// enable and enables the group in this state:
715/// ```text
716/// let cur_state; // state of this enable
717/// for (state, guard) in predeccesors:
718///   transitions.insert(state, cur_state, guard)
719/// enables.insert(cur_state, group)
720/// ```
721///
722/// While this process will generate a functioning FSM, the FSM takes unnecessary cycles for FSM
723/// transitions.
724///
725/// For example:
726/// ```
727/// seq { one; two; }
728/// ```
729/// The FSM generated will look like this (where `f` is the FSM register):
730/// ```
731/// f.in = one[done] ? 1;
732/// f.in = two[done] ? 2;
733/// one[go] = !one[done] & f.out == 0;
734/// two[go] = !two[done] & f.out == 1;
735/// ```
736///
737/// The cycle-level timing for this FSM will look like:
738///     - cycle 0: (`f.out` == 0), enable one
739///     - cycle t: (`f.out` == 0), (`one[done]` == 1), disable one
740///     - cycle t+1: (`f.out` == 1), enable two
741///     - cycle t+l: (`f.out` == 1), (`two[done]` == 1), disable two
742///     - cycle t+l+1: finish
743///
744/// The transition t -> t+1 represents one where group one is done but group two hasn't started
745/// executing.
746///
747/// To address this specific problem, there is an additional enable added to run all groups within
748/// an enable *while the FSM is transitioning*.
749/// The final transition will look like this:
750/// ```
751/// f.in = one[done] ? 1;
752/// f.in = two[done] ? 2;
753/// one[go] = !one[done] & f.out == 0;
754/// two[go] = (!two[done] & f.out == 1) || (one[done] & f.out == 0);
755/// ```
756///
757/// Note that `!two[done]` isn't present in the second disjunct because all groups are guaranteed
758/// to run for at least one cycle and the second disjunct will only be true for one cycle before
759/// the first disjunct becomes true.
760///
761/// ## Compiling `par` programs
/// We have to generate a new FSM-based controller for each child of a `par` node so that each
/// child can independently make progress.
764/// If we tie the children to one top-level FSM, their transitions would become interdependent and
765/// reduce available concurrency.
766///
767/// ## Compilation guarantee
768/// At the end of this pass, the control program will have no more than one
769/// group enable in it.
pub struct TopDownCompileControl {
    /// Experimental: enable early transitions for group enables
    /// (set by the `early-transitions` pass option).
    early_transitions: bool,
    /// Print the FSM implementing each schedule to STDOUT
    /// (set by the `dump-fsm` pass option).
    dump_fsm: bool,
}
776
777impl ConstructVisitor for TopDownCompileControl {
778    fn from(ctx: &ir::Context) -> CalyxResult<Self>
779    where
780        Self: Sized + Named,
781    {
782        let opts = Self::get_opts(ctx);
783
784        Ok(TopDownCompileControl {
785            dump_fsm: opts[&"dump-fsm"].bool(),
786            early_transitions: opts[&"early-transitions"].bool(),
787        })
788    }
789
790    fn clear_data(&mut self) {
791        /* All data can be transferred between components */
792    }
793}
794
795impl Named for TopDownCompileControl {
796    fn name() -> &'static str {
797        "tdcc"
798    }
799
800    fn description() -> &'static str {
801        "Top-down compilation for removing control constructs"
802    }
803
804    fn opts() -> Vec<PassOpt> {
805        vec![
806            PassOpt::new(
807                "dump-fsm",
808                "Print out the state machine implementing the schedule",
809                ParseVal::Bool(false),
810                PassOpt::parse_bool,
811            ),
812            PassOpt::new(
813                "early-transitions",
814                "Experimental: Enable early transitions for group enables",
815                ParseVal::Bool(false),
816                PassOpt::parse_bool,
817            ),
818        ]
819    }
820}
821
822impl Visitor for TopDownCompileControl {
823    fn start(
824        &mut self,
825        comp: &mut ir::Component,
826        _sigs: &LibrarySignatures,
827        _comps: &[ir::Component],
828    ) -> VisResult {
829        // Do not try to compile an enable
830        if matches!(
831            *comp.control.borrow(),
832            ir::Control::Enable(..) | ir::Control::Empty(..)
833        ) {
834            return Ok(Action::Stop);
835        }
836
837        let mut con = comp.control.borrow_mut();
838        compute_unique_ids(&mut con, 0);
839        // IRPrinter::write_control(&con, 0, &mut std::io::stderr());
840        Ok(Action::Continue)
841    }
842
843    fn finish_seq(
844        &mut self,
845        s: &mut ir::Seq,
846        comp: &mut ir::Component,
847        sigs: &LibrarySignatures,
848        _comps: &[ir::Component],
849    ) -> VisResult {
850        // only compile using new fsm if has new_fsm attribute
851        if !s.attributes.has(ir::BoolAttr::NewFSM) {
852            return Ok(Action::Continue);
853        }
854        let mut builder = ir::Builder::new(comp, sigs);
855        let mut sch = Schedule::from(&mut builder);
856        sch.calculate_states_seq(s, self.early_transitions)?;
857        // Compile schedule and return the group.
858        let seq_group = sch.realize_schedule(self.dump_fsm);
859
860        // Add NODE_ID to compiled group.
861        let mut en = ir::Control::enable(seq_group);
862        let node_id = s.attributes.get(NODE_ID).unwrap();
863        en.get_mut_attributes().insert(NODE_ID, node_id);
864
865        Ok(Action::change(en))
866    }
867
868    fn finish_if(
869        &mut self,
870        i: &mut ir::If,
871        comp: &mut ir::Component,
872        sigs: &LibrarySignatures,
873        _comps: &[ir::Component],
874    ) -> VisResult {
875        // only compile using new fsm if has new_fsm attribute
876        if !i.attributes.has(ir::BoolAttr::NewFSM) {
877            return Ok(Action::Continue);
878        }
879        let mut builder = ir::Builder::new(comp, sigs);
880        let mut sch = Schedule::from(&mut builder);
881
882        // Compile schedule and return the group.
883        sch.calculate_states_if(i, self.early_transitions)?;
884        let if_group = sch.realize_schedule(self.dump_fsm);
885
886        // Add NODE_ID to compiled group.
887        let mut en = ir::Control::enable(if_group);
888        let node_id = i.attributes.get(NODE_ID).unwrap();
889        en.get_mut_attributes().insert(NODE_ID, node_id);
890
891        Ok(Action::change(en))
892    }
893
894    fn finish_while(
895        &mut self,
896        w: &mut ir::While,
897        comp: &mut ir::Component,
898        sigs: &LibrarySignatures,
899        _comps: &[ir::Component],
900    ) -> VisResult {
901        // only compile using new fsm if has attribute
902        if !w.attributes.has(ir::BoolAttr::NewFSM) {
903            return Ok(Action::Continue);
904        }
905        let mut builder = ir::Builder::new(comp, sigs);
906        let mut sch = Schedule::from(&mut builder);
907        sch.calculate_states_while(w, self.early_transitions)?;
908
909        // Compile schedule and return the group.
910        let if_group = sch.realize_schedule(self.dump_fsm);
911
912        // Add NODE_ID to compiled group.
913        let mut en = ir::Control::enable(if_group);
914        let node_id = w.attributes.get(NODE_ID).unwrap();
915        en.get_mut_attributes().insert(NODE_ID, node_id);
916
917        Ok(Action::change(en))
918    }
919
920    /// Compile each child in `par` block separately so each child can make
921    /// progress indepdendently.
922    fn finish_par(
923        &mut self,
924        s: &mut ir::Par,
925        comp: &mut ir::Component,
926        sigs: &LibrarySignatures,
927        _comps: &[ir::Component],
928    ) -> VisResult {
929        let mut builder = ir::Builder::new(comp, sigs);
930
931        // Compilation group
932        let par_group = builder.add_group("par");
933        structure!(builder;
934            let signal_on = constant(1, 1);
935            let signal_off = constant(0, 1);
936        );
937
938        // Registers to save the done signal from each child.
939        let mut done_regs = Vec::with_capacity(s.stmts.len());
940
941        // For each child, build the enabling logic.
942        for con in &s.stmts {
943            let group = match con {
944                // Do not compile enables
945                ir::Control::Enable(ir::Enable { group, .. }) => {
946                    Rc::clone(group)
947                }
948                // Compile complex schedule and return the group.
949                _ => {
950                    let mut sch = Schedule::from(&mut builder);
951                    sch.calculate_states(con, self.early_transitions)?;
952                    sch.realize_schedule(self.dump_fsm)
953                }
954            };
955
956            // Build circuitry to enable and disable this group.
957            structure!(builder;
958                let pd = prim std_reg(1);
959            );
960            let group_go = !(guard!(pd["out"] | group["done"]));
961            let group_done = guard!(group["done"]);
962
963            // Save the done condition in a register.
964            let assigns = build_assignments!(builder;
965                group["go"] = group_go ? signal_on["out"];
966                pd["in"] = group_done ? signal_on["out"];
967                pd["write_en"] = group_done ? signal_on["out"];
968            );
969            par_group.borrow_mut().assignments.extend(assigns);
970            done_regs.push(pd)
971        }
972
973        // Done condition for this group
974        let done_guard = done_regs
975            .clone()
976            .into_iter()
977            .map(|r| guard!(r["out"]))
978            .fold(ir::Guard::True, ir::Guard::and);
979
980        // CLEANUP: Reset the registers once the group is finished.
981        let mut cleanup = done_regs
982            .into_iter()
983            .flat_map(|r| {
984                build_assignments!(builder;
985                    r["in"] = done_guard ? signal_off["out"];
986                    r["write_en"] = done_guard ? signal_on["out"];
987                )
988            })
989            .collect::<Vec<_>>();
990        builder
991            .component
992            .continuous_assignments
993            .append(&mut cleanup);
994
995        // Done conditional for this group.
996        let done = builder.build_assignment(
997            par_group.borrow().get("done"),
998            signal_on.borrow().get("out"),
999            done_guard,
1000        );
1001        par_group.borrow_mut().assignments.push(done);
1002
1003        // Add NODE_ID to compiled group.
1004        let mut en = ir::Control::enable(par_group);
1005        let node_id = s.attributes.get(NODE_ID).unwrap();
1006        en.get_mut_attributes().insert(NODE_ID, node_id);
1007
1008        Ok(Action::change(en))
1009    }
1010
1011    fn finish(
1012        &mut self,
1013        comp: &mut ir::Component,
1014        sigs: &LibrarySignatures,
1015        _comps: &[ir::Component],
1016    ) -> VisResult {
1017        let control = Rc::clone(&comp.control);
1018        // IRPrinter::write_control(&control.borrow(), 0, &mut std::io::stderr());
1019        let mut builder = ir::Builder::new(comp, sigs);
1020        let mut sch = Schedule::from(&mut builder);
1021        // Add assignments for the final states
1022        sch.calculate_states(&control.borrow(), self.early_transitions)?;
1023        let comp_group = sch.realize_schedule(self.dump_fsm);
1024
1025        Ok(Action::change(ir::Control::enable(comp_group)))
1026    }
1027}