// calyx_opt/analysis/inference_analysis.rs

1use super::AssignmentAnalysis;
2use crate::analysis::{compute_static::WithStatic, GraphAnalysis};
3use calyx_ir::{self as ir, GetAttributes, RRC};
4use ir::CellType;
5use itertools::Itertools;
6use std::collections::{HashMap, HashSet};
7
8/// Struct to store information about the go-done interfaces defined by a primitive.
9/// There is no default implementation because it will almost certainly be very
10/// unhelpful: you will want to use `from_ctx`.
11#[derive(Debug)]
12pub struct GoDone {
13    ports: Vec<(ir::Id, ir::Id, u64)>,
14}
15
16impl GoDone {
17    pub fn new(ports: Vec<(ir::Id, ir::Id, u64)>) -> Self {
18        Self { ports }
19    }
20
21    /// Returns true if this is @go port
22    pub fn is_go(&self, name: &ir::Id) -> bool {
23        self.ports.iter().any(|(go, _, _)| name == go)
24    }
25
26    /// Returns true if this is a @done port
27    pub fn is_done(&self, name: &ir::Id) -> bool {
28        self.ports.iter().any(|(_, done, _)| name == done)
29    }
30
31    /// Returns the latency associated with the provided @go port if present
32    pub fn get_latency(&self, go_port: &ir::Id) -> Option<u64> {
33        self.ports.iter().find_map(|(go, _, lat)| {
34            if go == go_port {
35                Some(*lat)
36            } else {
37                None
38            }
39        })
40    }
41
42    /// Iterate over the defined ports
43    pub fn iter(&self) -> impl Iterator<Item = &(ir::Id, ir::Id, u64)> {
44        self.ports.iter()
45    }
46
47    /// Iterate over the defined ports
48    pub fn is_empty(&self) -> bool {
49        self.ports.is_empty()
50    }
51
52    /// Iterate over the defined ports
53    pub fn len(&self) -> usize {
54        self.ports.len()
55    }
56
57    /// Iterate over the defined ports
58    pub fn get_ports(&self) -> &Vec<(ir::Id, ir::Id, u64)> {
59        &self.ports
60    }
61}
62
63impl From<&ir::Primitive> for GoDone {
64    fn from(prim: &ir::Primitive) -> Self {
65        let done_ports: HashMap<_, _> = prim
66            .find_all_with_attr(ir::NumAttr::Done)
67            .map(|pd| (pd.attributes.get(ir::NumAttr::Done), pd.name()))
68            .collect();
69
70        let go_ports = prim
71            .find_all_with_attr(ir::NumAttr::Go)
72            .filter_map(|pd| {
73                // Primitives only have @interval.
74                pd.attributes.get(ir::NumAttr::Interval).and_then(|st| {
75                    done_ports
76                        .get(&pd.attributes.get(ir::NumAttr::Go))
77                        .map(|done_port| (pd.name(), *done_port, st))
78                })
79            })
80            .collect_vec();
81        GoDone::new(go_ports)
82    }
83}
84
85impl From<&ir::Cell> for GoDone {
86    fn from(cell: &ir::Cell) -> Self {
87        let done_ports: HashMap<_, _> = cell
88            .find_all_with_attr(ir::NumAttr::Done)
89            .map(|pr| {
90                let port = pr.borrow();
91                (port.attributes.get(ir::NumAttr::Done), port.name)
92            })
93            .collect();
94
95        let go_ports = cell
96            .find_all_with_attr(ir::NumAttr::Go)
97            .filter_map(|pr| {
98                let port = pr.borrow();
99                // Get static interval thru either @interval or @promotable.
100                let st = match port.attributes.get(ir::NumAttr::Interval) {
101                    Some(st) => Some(st),
102                    None => port.attributes.get(ir::NumAttr::Promotable),
103                };
104                if let Some(static_latency) = st {
105                    return done_ports
106                        .get(&port.attributes.get(ir::NumAttr::Go))
107                        .map(|done_port| {
108                            (port.name, *done_port, static_latency)
109                        });
110                }
111                None
112            })
113            .collect_vec();
114        GoDone::new(go_ports)
115    }
116}
117
118/// Default implemnetation is almost certainly not helpful.
119/// You should probably use `from_ctx` instead.
120pub struct InferenceAnalysis {
121    /// component name -> vec<(go signal, done signal, latency)>
122    pub latency_data: HashMap<ir::Id, GoDone>,
123    /// Maps static component names to their latencies, but there can only
124    /// be one go port on the component. (This is a subset of the information
125    /// given by latency_data), and is helpful for inferring invokes.
126    /// Perhaps someday we should get rid of it and only make it one field.
127    pub static_component_latencies: HashMap<ir::Id, u64>,
128
129    updated_components: HashSet<ir::Id>,
130}
131
132impl InferenceAnalysis {
133    /// Builds FixUp struct from a ctx. Looks at all primitives and component
134    /// signatures to get latency information.
135    pub fn from_ctx(ctx: &ir::Context) -> Self {
136        let mut latency_data = HashMap::new();
137        let mut static_component_latencies = HashMap::new();
138        // Construct latency_data for each primitive
139        for prim in ctx.lib.signatures() {
140            let prim_go_done = GoDone::from(prim);
141            if prim_go_done.len() == 1 {
142                static_component_latencies
143                    .insert(prim.name, prim_go_done.get_ports()[0].2);
144            }
145            latency_data.insert(prim.name, GoDone::from(prim));
146        }
147        for comp in &ctx.components {
148            let comp_sig = comp.signature.borrow();
149
150            let done_ports: HashMap<_, _> = comp_sig
151                .find_all_with_attr(ir::NumAttr::Done)
152                .map(|pd| {
153                    let pd_ref = pd.borrow();
154                    (pd_ref.attributes.get(ir::NumAttr::Done), pd_ref.name)
155                })
156                .collect();
157
158            let go_ports = comp_sig
159                .find_all_with_attr(ir::NumAttr::Go)
160                .filter_map(|pd| {
161                    let pd_ref = pd.borrow();
162                    // Get static interval thru either @interval or @promotable.
163                    let st = match pd_ref.attributes.get(ir::NumAttr::Interval)
164                    {
165                        Some(st) => Some(st),
166                        None => pd_ref.attributes.get(ir::NumAttr::Promotable),
167                    };
168                    if let Some(static_latency) = st {
169                        return done_ports
170                            .get(&pd_ref.attributes.get(ir::NumAttr::Go))
171                            .map(|done_port| {
172                                (pd_ref.name, *done_port, static_latency)
173                            });
174                    }
175                    None
176                })
177                .collect_vec();
178
179            let go_done_comp = GoDone::new(go_ports);
180
181            if go_done_comp.len() == 1 {
182                static_component_latencies
183                    .insert(comp.name, go_done_comp.get_ports()[0].2);
184            }
185            latency_data.insert(comp.name, go_done_comp);
186        }
187        InferenceAnalysis {
188            latency_data,
189            static_component_latencies,
190            updated_components: HashSet::new(),
191        }
192    }
193
194    /// Updates the component, given a component name and a new latency and GoDone object.
195    pub fn add_component(
196        &mut self,
197        (comp_name, latency, go_done): (ir::Id, u64, GoDone),
198    ) {
199        self.latency_data.insert(comp_name, go_done);
200        self.static_component_latencies.insert(comp_name, latency);
201    }
202
203    /// Updates the component, given a component name and a new latency.
204    /// Note that this expects that the component already is accounted for
205    /// in self.latency_data and self.static_component_latencies.
206    pub fn remove_component(&mut self, comp_name: ir::Id) {
207        if self.latency_data.contains_key(&comp_name) {
208            // To make inference as strong as possible, only update updated_components
209            // if we actually updated it.
210            self.updated_components.insert(comp_name);
211        }
212        self.latency_data.remove(&comp_name);
213        self.static_component_latencies.remove(&comp_name);
214    }
215
216    /// Updates the component, given a component name and a new latency.
217    /// Note that this expects that the component already is accounted for
218    /// in self.latency_data and self.static_component_latencies.
219    pub fn adjust_component(
220        &mut self,
221        (comp_name, adjusted_latency): (ir::Id, u64),
222    ) {
223        // Check whether we actually updated the component's latency.
224        let mut updated = false;
225        self.latency_data.entry(comp_name).and_modify(|go_done| {
226            for (_, _, cur_latency) in &mut go_done.ports {
227                // Updating components with latency data.
228                if *cur_latency != adjusted_latency {
229                    *cur_latency = adjusted_latency;
230                    updated = true;
231                }
232            }
233        });
234        self.static_component_latencies
235            .insert(comp_name, adjusted_latency);
236        if updated {
237            self.updated_components.insert(comp_name);
238        }
239    }
240
241    /// Return true if the edge (`src`, `dst`) meet one these criteria, and false otherwise:
242    ///   - `src` is an "out" port of a constant, and `dst` is a "go" port
243    ///   - `src` is a "done" port, and `dst` is a "go" port
244    ///   - `src` is a "done" port, and `dst` is the "done" port of a group
245    fn mem_wrt_dep_graph(&self, src: &ir::Port, dst: &ir::Port) -> bool {
246        match (&src.parent, &dst.parent) {
247            (
248                ir::PortParent::Cell(src_cell_wrf),
249                ir::PortParent::Cell(dst_cell_wrf),
250            ) => {
251                let src_rf = src_cell_wrf.upgrade();
252                let src_cell = src_rf.borrow();
253                let dst_rf = dst_cell_wrf.upgrade();
254                let dst_cell = dst_rf.borrow();
255                if let (Some(s_name), Some(d_name)) =
256                    (src_cell.type_name(), dst_cell.type_name())
257                {
258                    let data_src = self.latency_data.get(&s_name);
259                    let data_dst = self.latency_data.get(&d_name);
260                    if let (Some(dst_ports), Some(src_ports)) =
261                        (data_dst, data_src)
262                    {
263                        return src_ports.is_done(&src.name)
264                            && dst_ports.is_go(&dst.name);
265                    }
266                }
267
268                // A constant writes to a cell: to be added to the graph, the cell needs to be a "done" port.
269                if let (Some(d_name), ir::CellType::Constant { .. }) =
270                    (dst_cell.type_name(), &src_cell.prototype)
271                {
272                    if let Some(ports) = self.latency_data.get(&d_name) {
273                        return ports.is_go(&dst.name);
274                    }
275                }
276
277                false
278            }
279
280            // Something is written to a group: to be added to the graph, this needs to be a "done" port.
281            (_, ir::PortParent::Group(_)) => dst.name == "done",
282
283            // If we encounter anything else, no need to add it to the graph.
284            _ => false,
285        }
286    }
287
288    /// Return a Vec of edges (`a`, `b`), where `a` is a "go" port and `b`
289    /// is a "done" port, and `a` and `b` have the same parent cell.
290    fn find_go_done_edges(
291        &self,
292        group: &ir::Group,
293    ) -> Vec<(RRC<ir::Port>, RRC<ir::Port>)> {
294        let rw_set = group.assignments.iter().analysis().cell_uses();
295        let mut go_done_edges: Vec<(RRC<ir::Port>, RRC<ir::Port>)> = Vec::new();
296
297        for cell_ref in rw_set {
298            let cell = cell_ref.borrow();
299            if let Some(ports) =
300                cell.type_name().and_then(|c| self.latency_data.get(&c))
301            {
302                go_done_edges.extend(
303                    ports
304                        .iter()
305                        .map(|(go, done, _)| (cell.get(go), cell.get(done))),
306                )
307            }
308        }
309        go_done_edges
310    }
311
312    /// Returns true if `port` is a "done" port, and we know the latency data
313    /// about `port`, or is a constant.
314    fn is_done_port_or_const(&self, port: &ir::Port) -> bool {
315        if let ir::PortParent::Cell(cwrf) = &port.parent {
316            let cr = cwrf.upgrade();
317            let cell = cr.borrow();
318            if let ir::CellType::Constant { val, .. } = &cell.prototype {
319                if *val > 0 {
320                    return true;
321                }
322            } else if let Some(ports) =
323                cell.type_name().and_then(|c| self.latency_data.get(&c))
324            {
325                return ports.is_done(&port.name);
326            }
327        }
328        false
329    }
330
331    /// Returns true if `graph` contains writes to "done" ports
332    /// that could have dynamic latencies, false otherwise.
333    fn contains_dyn_writes(&self, graph: &GraphAnalysis) -> bool {
334        for port in &graph.ports() {
335            match &port.borrow().parent {
336              ir::PortParent::Cell(cell_wrf) => {
337                  let cr = cell_wrf.upgrade();
338                  let cell = cr.borrow();
339                  if let Some(ports) =
340                      cell.type_name().and_then(|c| self.latency_data.get(&c))
341                  {
342                      let name = &port.borrow().name;
343                      if ports.is_go(name) {
344                          for write_port in graph.writes_to(&port.borrow()) {
345                              if !self
346                                  .is_done_port_or_const(&write_port.borrow())
347                              {
348                                  log::debug!(
349                                      "`{}` is not a done port",
350                                      write_port.borrow().canonical(),
351                                  );
352                                  return true;
353                              }
354                          }
355                      }
356                  }
357              }
358              ir::PortParent::Group(_) => {
359                  if port.borrow().name == "done" {
360                      for write_port in graph.writes_to(&port.borrow()) {
361                          if !self.is_done_port_or_const(&write_port.borrow())
362                          {
363                              log::debug!(
364                                  "`{}` is not a done port",
365                                  write_port.borrow().canonical(),
366                              );
367                              return true;
368                          }
369                      }
370                  }
371              }
372
373              ir::PortParent::StaticGroup(_) => // done ports of static groups should clearly NOT have static latencies
374              panic!("Have not decided how to handle static groups in infer-static-timing"),
375          }
376        }
377        false
378    }
379
380    /// Returns true if `graph` contains any nodes with degree > 1.
381    fn contains_node_deg_gt_one(graph: &GraphAnalysis) -> bool {
382        for port in graph.ports() {
383            if graph.writes_to(&port.borrow()).count() > 1 {
384                return true;
385            }
386        }
387        false
388    }
389
390    /// Attempts to infer the number of cycles starting when
391    /// `group[go]` is high, and port is high. If inference is
392    /// not possible, returns None.
393    fn infer_latency(&self, group: &ir::Group) -> Option<u64> {
394        // Creates a write dependency graph, which contains an edge (`a`, `b`) if:
395        //   - `a` is a "done" port, and writes to `b`, which is a "go" port
396        //   - `a` is a "done" port, and writes to `b`, which is the "done" port of this group
397        //   - `a` is an "out" port, and is a constant, and writes to `b`, a "go" port
398        //   - `a` is a "go" port, and `b` is a "done" port, and `a` and `b` share a parent cell
399        // Nodes that are not part of any edges that meet these criteria are excluded.
400        //
401        // For example, this group:
402        // ```
403        // group g1 {
404        //   a.in = 32'd1;
405        //   a.write_en = 1'd1;
406        //   g1[done] = a.done;
407        // }
408        // ```
409        // corresponds to this graph:
410        // ```
411        // constant(1) -> a.write_en
412        // a.write_en -> a.done
413        // a.done -> g1[done]
414        // ```
415        log::debug!("Checking group `{}`", group.name());
416        let graph_unprocessed = GraphAnalysis::from(group);
417        if self.contains_dyn_writes(&graph_unprocessed) {
418            log::debug!("FAIL: contains dynamic writes");
419            return None;
420        }
421
422        let go_done_edges = self.find_go_done_edges(group);
423        let graph = graph_unprocessed
424            .edge_induced_subgraph(|src, dst| self.mem_wrt_dep_graph(src, dst))
425            .add_edges(&go_done_edges)
426            .remove_isolated_vertices();
427
428        // Give up if a port has multiple writes to it.
429        if Self::contains_node_deg_gt_one(&graph) {
430            log::debug!("FAIL: Group contains multiple writes");
431            return None;
432        }
433
434        let mut tsort = graph.toposort();
435        let start = tsort.next()?;
436        let finish = tsort.last()?;
437
438        let paths = graph.paths(&start.borrow(), &finish.borrow());
439        // If there are no paths, give up.
440        if paths.is_empty() {
441            log::debug!("FAIL: No path between @go and @done port");
442            return None;
443        }
444        let first_path = paths.first().unwrap();
445
446        // Sum the latencies of each primitive along the path.
447        let mut latency_sum = 0;
448        for port in first_path {
449            if let ir::PortParent::Cell(cwrf) = &port.borrow().parent {
450                let cr = cwrf.upgrade();
451                let cell = cr.borrow();
452                if let Some(ports) =
453                    cell.type_name().and_then(|c| self.latency_data.get(&c))
454                {
455                    if let Some(latency) =
456                        ports.get_latency(&port.borrow().name)
457                    {
458                        latency_sum += latency;
459                    }
460                }
461            }
462        }
463
464        log::debug!("SUCCESS: Latency = {}", latency_sum);
465        Some(latency_sum)
466    }
467
468    /// Returns Some(latency) if a control statement has a latency, because
469    /// it is static or is has the @promotable attribute
470    pub fn get_possible_latency(c: &ir::Control) -> Option<u64> {
471        match c {
472            ir::Control::Static(sc) => Some(sc.get_latency()),
473            _ => c.get_attribute(ir::NumAttr::Promotable),
474        }
475    }
476
477    pub fn remove_promotable_from_seq(seq: &mut ir::Seq) {
478        for stmt in &mut seq.stmts {
479            Self::remove_promotable_attribute(stmt);
480        }
481        seq.get_mut_attributes().remove(ir::NumAttr::Promotable);
482    }
483
484    /// Removes the @promotable attribute from the control program.
485    /// Recursively visits the children of the control.
486    pub fn remove_promotable_attribute(c: &mut ir::Control) {
487        c.get_mut_attributes().remove(ir::NumAttr::Promotable);
488        match c {
489            ir::Control::Empty(_)
490            | ir::Control::Invoke(_)
491            | ir::Control::Enable(_)
492            | ir::Control::Static(_) => (),
493            ir::Control::While(ir::While { body, .. })
494            | ir::Control::Repeat(ir::Repeat { body, .. }) => {
495                Self::remove_promotable_attribute(body);
496            }
497            ir::Control::If(ir::If {
498                tbranch, fbranch, ..
499            }) => {
500                Self::remove_promotable_attribute(tbranch);
501                Self::remove_promotable_attribute(fbranch);
502            }
503            ir::Control::Seq(ir::Seq { stmts, .. })
504            | ir::Control::Par(ir::Par { stmts, .. }) => {
505                for stmt in stmts {
506                    Self::remove_promotable_attribute(stmt);
507                }
508            }
509        }
510    }
511
512    pub fn fixup_seq(&self, seq: &mut ir::Seq) {
513        seq.update_static(&self.static_component_latencies);
514    }
515
516    pub fn fixup_par(&self, par: &mut ir::Par) {
517        par.update_static(&self.static_component_latencies);
518    }
519
520    pub fn fixup_if(&self, _if: &mut ir::If) {
521        _if.update_static(&self.static_component_latencies);
522    }
523
524    pub fn fixup_while(&self, _while: &mut ir::While) {
525        _while.update_static(&self.static_component_latencies);
526    }
527
528    pub fn fixup_repeat(&self, repeat: &mut ir::Repeat) {
529        repeat.update_static(&self.static_component_latencies);
530    }
531
532    pub fn fixup_ctrl(&self, ctrl: &mut ir::Control) {
533        ctrl.update_static(&self.static_component_latencies);
534    }
535
536    /// "Fixes Up" the component. In particular:
537    /// 1. Removes @promotable annotations for any groups that write to any
538    /// `updated_components`.
539    /// 2. Try to re-infer groups' latencies.
540    /// 3. Removes all @promotable annotation from the control program.
541    /// 4. Re-infers the @promotable annotations for any groups or control.
542    /// Note that this only fixes up the component's ``internals''.
543    /// It does *not* fix the component's signature.
544    pub fn fixup_timing(&self, comp: &mut ir::Component) {
545        // Removing @promotable annotations for any groups that write to an updated_component,
546        // then try to re-infer the latency.
547        for group in comp.groups.iter() {
548            // This checks any group that writes to the component:
549            // We can probably switch this to any group that writes to the component's
550            // `go` port to be more precise analysis.
551            if group
552                .borrow_mut()
553                .assignments
554                .iter()
555                .analysis()
556                .cell_writes()
557                .any(|cell| match cell.borrow().prototype {
558                    CellType::Component { name } => {
559                        self.updated_components.contains(&name)
560                    }
561                    _ => false,
562                })
563            {
564                // Remove attribute from group.
565                group
566                    .borrow_mut()
567                    .attributes
568                    .remove(ir::NumAttr::Promotable);
569            }
570        }
571
572        for group in &mut comp.groups.iter() {
573            // Immediately try to re-infer the latency of the group.
574            let latency_result = self.infer_latency(&group.borrow());
575            if let Some(latency) = latency_result {
576                group
577                    .borrow_mut()
578                    .attributes
579                    .insert(ir::NumAttr::Promotable, latency);
580            }
581        }
582
583        // Removing @promotable annotations for the control flow, then trying
584        // to re-infer them.
585        Self::remove_promotable_attribute(&mut comp.control.borrow_mut());
586        comp.control
587            .borrow_mut()
588            .update_static(&self.static_component_latencies);
589    }
590}