calyx_opt/passes/
compile_invoke.rs

1use crate::traversal::{
2    self, Action, ConstructVisitor, Named, VisResult, Visitor,
3};
4use calyx_ir::structure;
5use calyx_ir::{self as ir, Attributes, LibrarySignatures};
6use calyx_utils::{CalyxResult, Error};
7use ir::{Assignment, RRC, WRC};
8use itertools::Itertools;
9use std::collections::HashMap;
10use std::rc::Rc;
11
12use super::dump_ports;
13
14// given `cell_ref` returns the `go` port of the cell (if it only has one `go` port),
15// or an error otherwise
16fn get_go_port(cell_ref: ir::RRC<ir::Cell>) -> CalyxResult<ir::RRC<ir::Port>> {
17    let cell = cell_ref.borrow();
18    let name = cell.name();
19
20    // Get the go port
21    match cell.find_unique_with_attr(ir::NumAttr::Go) {
22        Ok(Some(p)) => Ok(p),
23        Ok(None) => Err(Error::malformed_control(format!(
24            "Invoked component `{name}` does not define a @go signal. Cannot compile the invoke",
25        ))),
26        Err(_) => {
27            Err(Error::malformed_control(format!(
28                "Invoked component `{name}` defines multiple @go signals. Cannot compile the invoke",
29            )))
30        }
31    }
32}
33
34// given inputs and outputs (of the invoke), and the `enable_assignments` (e.g., invoked_component.go = 1'd1)
35// and a cell, builds the assignments for the corresponding group
36fn build_assignments<T>(
37    inputs: &mut Vec<(ir::Id, ir::RRC<ir::Port>)>,
38    outputs: &mut Vec<(ir::Id, ir::RRC<ir::Port>)>,
39    builder: &mut ir::Builder,
40    cell: &ir::Cell,
41) -> Vec<ir::Assignment<T>> {
42    inputs
43        .drain(..)
44        .map(|(inp, p)| {
45            builder.build_assignment(cell.get(inp), p, ir::Guard::True)
46        })
47        .chain(outputs.drain(..).map(|(out, p)| {
48            builder.build_assignment(p, cell.get(out), ir::Guard::True)
49        }))
50        .collect()
51}
52
53#[derive(Default)]
54/// Map for storing added ports for each ref cell
55/// level of Hashmap represents:
56/// HashMap<-component name-, Hashmap<(-ref cell name-,-port name-), port>>;
57struct RefPortMap(HashMap<ir::Id, HashMap<ir::Canonical, RRC<ir::Port>>>);
58
59impl RefPortMap {
60    fn insert(
61        &mut self,
62        comp_name: ir::Id,
63        ports: HashMap<ir::Canonical, RRC<ir::Port>>,
64    ) {
65        self.0.insert(comp_name, ports);
66    }
67
68    fn get(
69        &self,
70        comp_name: &ir::Id,
71    ) -> Option<&HashMap<ir::Canonical, RRC<ir::Port>>> {
72        self.0.get(comp_name)
73    }
74
75    /// Get all of the newly added ports associated with a component that had
76    /// ref cells
77    fn get_ports(&self, comp_name: &ir::Id) -> Option<Vec<RRC<ir::Port>>> {
78        self.0
79            .get(comp_name)
80            .map(|map| map.values().cloned().collect())
81    }
82}
83
84/// Compiles [`ir::Invoke`] statements into an [`ir::Enable`] that runs the
85/// invoked component.
86pub struct CompileInvoke {
87    /// Mapping from component to the canonical port name of ref cell o
88    port_names: RefPortMap,
89    /// Mapping from the ports of cells that were removed to the new port on the
90    /// component signature.
91    removed: HashMap<ir::Canonical, ir::RRC<ir::Port>>,
92    /// Ref cells in the component. We hold onto these so that our references don't get invalidated
93    ref_cells: Vec<ir::RRC<ir::Cell>>,
94}
95
96impl ConstructVisitor for CompileInvoke {
97    fn from(_ctx: &ir::Context) -> CalyxResult<Self>
98    where
99        Self: Sized,
100    {
101        Ok(CompileInvoke {
102            port_names: RefPortMap::default(),
103            removed: HashMap::new(),
104            ref_cells: Vec::new(),
105        })
106    }
107
108    fn clear_data(&mut self) {
109        self.removed.clear();
110        self.ref_cells.clear()
111    }
112}
113
114impl Named for CompileInvoke {
115    fn name() -> &'static str {
116        "compile-invoke"
117    }
118
119    fn description() -> &'static str {
120        "Rewrites invoke statements to group enables"
121    }
122}
123
124impl CompileInvoke {
125    /// Given `ref_cells` of an invoke, returns `(inputs, outputs)` where
126    /// inputs are the corresponding inputs to the `invoke` and
127    /// outputs are the corresponding outputs to the `invoke`.
128    ///
129    /// Since this pass eliminates all ref cells in post order, we expect that
130    /// invoked component already had all of its ref cells removed.
131    fn ref_cells_to_ports<T>(
132        &mut self,
133        inv_cell: RRC<ir::Cell>,
134        ref_cells: impl Iterator<Item = (ir::Id, ir::RRC<ir::Cell>)>,
135    ) -> Vec<ir::Assignment<T>> {
136        let inv_comp = inv_cell.borrow().type_name().unwrap();
137        let mut assigns = Vec::new();
138        for (ref_cell_name, cell) in ref_cells {
139            log::debug!(
140                "Removing ref cell `{}` with {} ports",
141                ref_cell_name,
142                cell.borrow().ports.len()
143            );
144
145            // Mapping from canonical names of the ports of the ref cell to the
146            // new port defined on the signature of the component
147            let Some(comp_ports) = self.port_names.get(&inv_comp) else {
148                unreachable!("component `{}` invoked but not already visited by the pass", inv_comp)
149            };
150
151            // The type of the cell is the same as the ref cell so we can
152            // iterate over its ports and generate bindings for the ref cell.
153            for pr in &cell.borrow().ports {
154                let port = pr.borrow();
155                if port.has_attribute(ir::BoolAttr::Clk)
156                    || port.has_attribute(ir::BoolAttr::Reset)
157                {
158                    continue;
159                }
160
161                let canon = ir::Canonical::new(ref_cell_name, port.name);
162                let Some(comp_port) = comp_ports.get(&canon) else {
163                    unreachable!("port `{}` not found in the signature of {}. Known ports are: {}",
164                        canon,
165                        inv_comp,
166                        comp_ports.keys().map(|c| c.port.as_ref()).collect_vec().join(", ")
167                    )
168                };
169                // Get the port on the new cell with the same name as ref_port
170                let ref_port = inv_cell.borrow().get(comp_port.borrow().name);
171                log::debug!("Port `{}` -> `{}`", canon, ref_port.borrow().name);
172
173                let old_port = pr.borrow().canonical();
174                // If the port has been removed already, get the new port from the component's signature
175                let arg_port = if let Some(sig_pr) = self.removed.get(&old_port)
176                {
177                    log::debug!(
178                        "Port `{}` has been removed. Using `{}`",
179                        old_port,
180                        sig_pr.borrow().name
181                    );
182                    Rc::clone(sig_pr)
183                } else {
184                    Rc::clone(pr)
185                };
186
187                match port.direction {
188                    ir::Direction::Output => {
189                        log::debug!(
190                            "constructing: {} = {}",
191                            ref_port.borrow().canonical(),
192                            arg_port.borrow().canonical()
193                        );
194                        assigns.push(ir::Assignment::new(
195                            ref_port.clone(),
196                            arg_port,
197                        ));
198                    }
199                    ir::Direction::Input => {
200                        log::debug!(
201                            "constructing: {} = {}",
202                            arg_port.borrow().canonical(),
203                            ref_port.borrow().canonical(),
204                        );
205                        assigns.push(ir::Assignment::new(
206                            arg_port,
207                            ref_port.clone(),
208                        ));
209                    }
210                    _ => {
211                        unreachable!("Cell should have inout ports");
212                    }
213                }
214            }
215        }
216        assigns
217    }
218}
219
220impl Visitor for CompileInvoke {
221    fn iteration_order() -> crate::traversal::Order
222    where
223        Self: Sized,
224    {
225        traversal::Order::Post
226    }
227
228    fn start(
229        &mut self,
230        comp: &mut ir::Component,
231        _sigs: &LibrarySignatures,
232        _comps: &[ir::Component],
233    ) -> VisResult {
234        log::debug!("Visiting `{}`", comp.name);
235        // For all subcomponents that had a `ref` cell in them, we need to
236        // update their cell to have the new ports added from inlining the
237        // signatures of all the ref cells.
238        for cell in comp.cells.iter() {
239            let mut new_ports: Vec<RRC<ir::Port>> = Vec::new();
240            if let Some(name) = cell.borrow().type_name() {
241                if let Some(vec) = self.port_names.get_ports(&name) {
242                    log::debug!(
243                        "Updating ports of cell `{}' (type `{name}')",
244                        cell.borrow().name()
245                    );
246                    for p in vec.iter() {
247                        let new_port = ir::rrc(ir::Port {
248                            name: p.borrow().name,
249                            width: p.borrow().width,
250                            direction: p.borrow().direction.reverse(),
251                            parent: ir::PortParent::Cell(WRC::from(cell)),
252                            attributes: Attributes::default(),
253                        });
254                        new_ports.push(new_port);
255                    }
256                }
257            }
258            cell.borrow_mut().ports.extend(new_ports);
259        }
260
261        let dump_ports::DumpResults { cells, rewrites } =
262            dump_ports::dump_ports_to_signature(
263                comp,
264                |cell| cell.borrow().is_reference(),
265                true,
266            );
267
268        // Hold onto the cells so they don't get dropped.
269        self.ref_cells = cells;
270        self.removed = rewrites;
271
272        Ok(Action::Continue)
273    }
274
275    fn invoke(
276        &mut self,
277        s: &mut ir::Invoke,
278        comp: &mut ir::Component,
279        ctx: &LibrarySignatures,
280        _comps: &[ir::Component],
281    ) -> VisResult {
282        let mut builder = ir::Builder::new(comp, ctx);
283        let invoke_group = builder.add_group("invoke");
284        // Assigns representing the ref cell connections
285        invoke_group.borrow_mut().assignments.extend(
286            self.ref_cells_to_ports(Rc::clone(&s.comp), s.ref_cells.drain(..)),
287        );
288
289        // comp.go = 1'd1;
290        // invoke[done] = comp.done;
291        structure!(builder;
292            let one = constant(1, 1);
293        );
294        let cell = s.comp.borrow();
295        let go_port = get_go_port(Rc::clone(&s.comp))?;
296        let done_port = cell.find_unique_with_attr(ir::NumAttr::Done)?.unwrap();
297
298        // Build assignemnts
299        let go_assign = builder.build_assignment(
300            go_port,
301            one.borrow().get("out"),
302            ir::Guard::True,
303        );
304        let done_assign = builder.build_assignment(
305            invoke_group.borrow().get("done"),
306            done_port,
307            ir::Guard::True,
308        );
309
310        invoke_group
311            .borrow_mut()
312            .assignments
313            .extend(vec![go_assign, done_assign]);
314
315        // Generate argument assignments
316        let cell = &*s.comp.borrow();
317        let assigns = build_assignments(
318            &mut s.inputs,
319            &mut s.outputs,
320            &mut builder,
321            cell,
322        );
323        invoke_group.borrow_mut().assignments.extend(assigns);
324
325        // Add assignments from the attached combinational group
326        if let Some(cgr) = &s.comb_group {
327            let cg = &*cgr.borrow();
328            invoke_group
329                .borrow_mut()
330                .assignments
331                .extend(cg.assignments.iter().cloned())
332        }
333
334        // Copy "promotable" annotation from the `invoke` statement if present
335        if let Some(time) = s.attributes.get(ir::NumAttr::Promotable) {
336            invoke_group
337                .borrow_mut()
338                .attributes
339                .insert(ir::NumAttr::Promotable, time);
340        }
341
342        let mut en = ir::Enable {
343            group: invoke_group,
344            attributes: Attributes::default(),
345        };
346        if let Some(time) = s.attributes.get(ir::NumAttr::Promotable) {
347            en.attributes.insert(ir::NumAttr::Promotable, time);
348        }
349
350        Ok(Action::change(ir::Control::Enable(en)))
351    }
352
353    fn static_invoke(
354        &mut self,
355        s: &mut ir::StaticInvoke,
356        comp: &mut ir::Component,
357        ctx: &LibrarySignatures,
358        _comps: &[ir::Component],
359    ) -> VisResult {
360        let mut builder = ir::Builder::new(comp, ctx);
361
362        let invoke_group = builder.add_static_group("static_invoke", s.latency);
363
364        invoke_group.borrow_mut().assignments.extend(
365            self.ref_cells_to_ports(Rc::clone(&s.comp), s.ref_cells.drain(..)),
366        );
367
368        // comp.go = 1'd1;
369        structure!(builder;
370            let one = constant(1, 1);
371        );
372
373        // Get the go port
374        let go_port = get_go_port(Rc::clone(&s.comp))?;
375
376        // Checks whether compe is a static<n> component or an @interval(n) component.
377        let go_guard = if s
378            .comp
379            .borrow()
380            .ports
381            .iter()
382            .any(|port| port.borrow().attributes.has(ir::NumAttr::Interval))
383        {
384            // For @interval(n) components, we do not guard the comp.go
385            // We trigger the go signal for the entire interval.
386            ir::Guard::True
387        } else {
388            // For static<n> components, we guard the comp.go with %[0:1]
389            ir::Guard::Info(ir::StaticTiming::new((0, 1)))
390        };
391
392        // Build assignemnts
393        let go_assign: ir::Assignment<ir::StaticTiming> = builder
394            .build_assignment(go_port, one.borrow().get("out"), go_guard);
395        invoke_group.borrow_mut().assignments.push(go_assign);
396
397        // Generate argument assignments
398        let cell = &*s.comp.borrow();
399        let assigns = build_assignments(
400            &mut s.inputs,
401            &mut s.outputs,
402            &mut builder,
403            cell,
404        );
405        invoke_group.borrow_mut().assignments.extend(assigns);
406
407        if let Some(cgr) = &s.comb_group {
408            let cg = &*cgr.borrow();
409            invoke_group.borrow_mut().assignments.extend(
410                cg.assignments
411                    .iter()
412                    .cloned()
413                    .map(Assignment::from)
414                    .collect_vec(),
415            );
416        }
417
418        let en = ir::StaticEnable {
419            group: invoke_group,
420            attributes: Attributes::default(),
421        };
422
423        Ok(Action::StaticChange(Box::new(ir::StaticControl::Enable(
424            en,
425        ))))
426    }
427
428    fn finish(
429        &mut self,
430        comp: &mut ir::Component,
431        _sigs: &LibrarySignatures,
432        _comps: &[ir::Component],
433    ) -> VisResult {
434        let port_map = std::mem::take(&mut self.removed);
435        // Add the newly added port to the global port map
436        // Rewrite all of the ref cell ports
437        let rw = ir::Rewriter {
438            port_map,
439            ..Default::default()
440        };
441        rw.rewrite(comp);
442        self.port_names.insert(comp.name, rw.port_map);
443        Ok(Action::Continue)
444    }
445}