// calyx_opt/passes/simplify_with_control.rs

1use crate::analysis;
2use crate::traversal::{Action, Named, VisResult, Visitor};
3use calyx_ir::{self as ir, structure, GetAttributes, LibrarySignatures, RRC};
4use calyx_utils::{CalyxResult, Error};
5use std::collections::HashMap;
6use std::rc::Rc;
7
8#[derive(Default)]
9/// Transforms combinational groups into normal groups by registering the values
10/// read from the ports of cells used within the combinational group.
11///
12/// It also transforms (if,while)-with into semantically equivalent control programs that
13/// first enable a group that calculates and registers the ports defined by the combinational group
14/// execute the respective cond group and then execute the control operator.
15///
16/// # Example
17/// ```
18/// group comb_cond<"static"=0> {
19///     lt.right = 32'd10;
20///     lt.left = 32'd1;
21///     eq.right = r.out;
22///     eq.left = x.out;
23///     comb_cond[done] = 1'd1;
24/// }
25/// control {
26///     if lt.out with comb_cond {
27///         ...
28///     }
29///     while eq.out with comb_cond {
30///         ...
31///     }
32/// }
33/// ```
34/// into:
35/// ```
36/// group comb_cond<"static"=1> {
37///     lt.right = 32'd10;
38///     lt.left = 32'd1;
39///     eq.right = r.out;
40///     eq.left = x.out;
41///     lt_reg.in = lt.out
42///     lt_reg.write_en = 1'd1;
43///     eq_reg.in = eq.out;
44///     eq_reg.write_en = 1'd1;
45///     comb_cond[done] = lt_reg.done & eq_reg.done ? 1'd1;
46/// }
47/// control {
48///     seq {
49///       comb_cond;
50///       if lt_reg.out {
51///           ...
52///       }
53///     }
54///     seq {
55///       comb_cond;
56///       while eq_reg.out {
57///           ...
58///           comb_cond;
59///       }
60///     }
61/// }
62/// ```
63pub struct SimplifyWithControl {
64    // Mapping from (group_name, (cell_name, port_name)) -> (port, static_group).
65    port_rewrite: HashMap<PortInGroup, (RRC<ir::Port>, RRC<ir::StaticGroup>)>,
66}
67
68/// Represents (group_name, (cell_name, port_name))
69type PortInGroup = (ir::Id, ir::Canonical);
70
71impl Named for SimplifyWithControl {
72    fn name() -> &'static str {
73        "simplify-with-control"
74    }
75
76    fn description() -> &'static str {
77        "Transforms if-with and while-with to if and while"
78    }
79}
80
81impl Visitor for SimplifyWithControl {
82    fn start(
83        &mut self,
84        comp: &mut ir::Component,
85        sigs: &LibrarySignatures,
86        _comps: &[ir::Component],
87    ) -> VisResult {
88        let mut used_ports =
89            analysis::ControlPorts::<false>::from(&*comp.control.borrow());
90
91        // Early return if there are no combinational groups
92        if comp.comb_groups.is_empty() {
93            return Ok(Action::Stop);
94        }
95
96        // Detach the combinational groups from the component
97        let comb_groups = std::mem::take(&mut comp.comb_groups);
98        let mut builder = ir::Builder::new(comp, sigs);
99
100        // Groups generated by transforming combinational groups
101        let groups = comb_groups
102            .iter()
103            .map(|cg_ref| {
104                let name = cg_ref.borrow().name();
105                // Register the ports read by the combinational group's usages.
106                let used_ports = used_ports.remove(&name).ok_or_else(|| {
107                    Error::malformed_structure(format!(
108                        "values from combinational group `{}` never used",
109                        name
110                    ))
111                })?;
112
113                // Group generated to replace this comb group.
114                let group_ref = builder.add_static_group(name, 1);
115                let mut group = group_ref.borrow_mut();
116                // Attach assignmens from comb group
117                group.assignments = cg_ref
118                    .borrow_mut()
119                    .assignments
120                    .clone()
121                    .into_iter()
122                    .map(|x| x.into())
123                    .collect();
124
125                // Registers to save value for the group
126                let mut save_regs = Vec::with_capacity(used_ports.len());
127                for port in used_ports {
128                    // Register to save port value
129                    structure!(builder;
130                        let comb_reg = prim std_reg(port.borrow().width);
131                        let signal_on = constant(1, 1);
132                    );
133                    let write = builder.build_assignment(
134                        comb_reg.borrow().get("in"),
135                        Rc::clone(&port),
136                        ir::Guard::True,
137                    );
138                    let en = builder.build_assignment(
139                        comb_reg.borrow().get("write_en"),
140                        signal_on.borrow().get("out"),
141                        ir::Guard::True,
142                    );
143                    group.assignments.push(write);
144                    group.assignments.push(en);
145
146                    // Define mapping from this port to the register's output
147                    // value.
148                    self.port_rewrite.insert(
149                        (name, port.borrow().canonical().clone()),
150                        (
151                            Rc::clone(&comb_reg.borrow().get("out")),
152                            Rc::clone(&group_ref),
153                        ),
154                    );
155
156                    save_regs.push(comb_reg);
157                }
158
159                // No need for a done condition
160                drop(group);
161
162                Ok(group_ref)
163            })
164            .collect::<CalyxResult<Vec<_>>>()?;
165
166        for group in groups {
167            comp.get_static_groups_mut().add(group)
168        }
169
170        // Restore the combinational groups
171        comp.comb_groups = comb_groups;
172
173        Ok(Action::Continue)
174    }
175
176    fn finish_while(
177        &mut self,
178        s: &mut ir::While,
179        _comp: &mut ir::Component,
180        _sigs: &LibrarySignatures,
181        _comps: &[ir::Component],
182    ) -> VisResult {
183        if s.cond.is_none() {
184            return Ok(Action::Continue);
185        }
186
187        // Construct a new `while` statement
188        let key = (
189            s.cond.as_ref().unwrap().borrow().name(),
190            s.port.borrow().canonical(),
191        );
192        let (port_ref, cond_ref) = self.port_rewrite.get(&key).unwrap();
193        let cond_in_body = ir::Control::static_enable(Rc::clone(cond_ref));
194        let body = std::mem::replace(s.body.as_mut(), ir::Control::empty());
195        let new_body = ir::Control::seq(vec![body, cond_in_body]);
196        let mut while_ =
197            ir::Control::while_(Rc::clone(port_ref), None, Box::new(new_body));
198        let attrs = while_.get_mut_attributes();
199        *attrs = std::mem::take(&mut s.attributes);
200        let cond_before_body = ir::Control::static_enable(Rc::clone(cond_ref));
201        Ok(Action::change(ir::Control::seq(vec![
202            cond_before_body,
203            while_,
204        ])))
205    }
206
207    /// Transforms a `if-with` into a `seq-if` which first runs the cond group
208    /// and then the branch.
209    fn finish_if(
210        &mut self,
211        s: &mut ir::If,
212        _comp: &mut ir::Component,
213        _sigs: &LibrarySignatures,
214        _comps: &[ir::Component],
215    ) -> VisResult {
216        if s.cond.is_none() {
217            return Ok(Action::Continue);
218        }
219        // Construct a new `if` statement
220        let key = (
221            s.cond.as_ref().unwrap().borrow().name(),
222            s.port.borrow().canonical(),
223        );
224        let (port_ref, cond_ref) =
225            self.port_rewrite.get(&key).unwrap_or_else(|| {
226                panic!(
227                    "{}: Port `{}` in group `{}` doesn't have a rewrite",
228                    Self::name(),
229                    key.1,
230                    key.0
231                )
232            });
233        let tbranch =
234            std::mem::replace(s.tbranch.as_mut(), ir::Control::empty());
235        let fbranch =
236            std::mem::replace(s.fbranch.as_mut(), ir::Control::empty());
237        let mut if_ = ir::Control::if_(
238            Rc::clone(port_ref),
239            None,
240            Box::new(tbranch),
241            Box::new(fbranch),
242        );
243        let attrs = if_.get_mut_attributes();
244        *attrs = std::mem::take(&mut s.attributes);
245        let cond = ir::Control::static_enable(Rc::clone(cond_ref));
246        Ok(Action::change(ir::Control::seq(vec![cond, if_])))
247    }
248
249    fn finish(
250        &mut self,
251        comp: &mut ir::Component,
252        _sigs: &LibrarySignatures,
253        _comps: &[ir::Component],
254    ) -> VisResult {
255        if comp.is_static() {
256            let msg =
257                format!("Static Component {} has combinational groups which is not supported", comp.name);
258            return Err(Error::pass_assumption(Self::name(), msg)
259                .with_pos(&comp.attributes));
260        }
261        Ok(Action::Continue)
262    }
263}