calyx_opt/analysis/domination_analysis/dominator_map.rs

1use crate::analysis::{
2    domination_analysis::{
3        node_analysis::{NodeReads, NodeSearch},
4        static_par_domination::StaticParDomination,
5    },
6    ControlId, ShareSet,
7};
8use calyx_ir as ir;
9use ir::GenericControl;
10use std::collections::{HashMap, HashSet};
11use std::fmt::Debug;
12
13const NODE_ID: ir::Attribute =
14    ir::Attribute::Internal(ir::InternalAttr::NODE_ID);
15const BEGIN_ID: ir::Attribute =
16    ir::Attribute::Internal(ir::InternalAttr::BEGIN_ID);
17const END_ID: ir::Attribute = ir::Attribute::Internal(ir::InternalAttr::END_ID);
18
19/// Builds a Domination Map for the control program. It maps nodes to sets of
20/// nodes. Here is what is included as a "node" in the domination map:
21/// - Invokes
22/// - Enables
23/// - While Guards
24/// - If Guards
25/// - "End" If nodes, representing the place we're at in the program after the if
26/// statement has just finished. This doesn't correspond to any actual Calyx code, but is
27/// just a conceptualization we use to reason about domination.
28/// Note that seqs and pars will *not* be included in the domination map.
29///
30/// Here is the algorithm we use to build the domination map.
31/// - Start with an emtpy map.
32/// - Visit each node n in the control program, and set:
33/// - dom(n) = {U dom(p) for each predecessor p of n} U {n}. In other words, take the
34/// dominators of each predecessor of n, and union them together. Then add n to
35/// this set, and set this set as the dominators of n.
36/// - (Another clarification): by "predecessors" of node n we mean the set of nodes
37/// that could be the most recent node executed when n begins to execute.
38/// - If we visit every node of the control program and the map has not changed,
39/// then we are done. If it has changed, then we visit each node again to repeat
40/// the process.
41///
42/// The reason why we can take the union (rather than intersection) of the
43/// dominators of each predecessor is because we know each predecessor of each
44/// node must (rather than may) be executed.
45/// There are two exceptions to this general rule, and we have special cases in
46/// our algorithm to deal with them.
47///
48/// 1) The While Guard
49/// The last node(s) in the while body are predecessor(s) of the while guard but
50/// are not guaranteed to be executed. So, we can think of the while guard's
51/// predecessors as being split in two groups: the "body predecessors" that are not guaranteed to
52/// be executed before the while guard and the "outside predecessors" that are
53/// outside the body of the while loop and are guaranteed to be executed before
54/// the while loop guard.
55/// Here we take:
56/// dom(while guard) = U(dom(outside preds)) U {while guard}
57///
58/// Justification:
59/// dom(while guard) is a subset of U(dom(outside preds)) U {while guard}
60/// Suppose n dominates the while guard. Every path to the while guard must end in
61/// 1) outside pred -> while guard OR 2) body pred -> while guard. But for choice 2)
62/// we know the path was really something like outside pred -> while guard -> body
63/// -> while guard... body -> while guard. Since n dominates the while guard
64/// we know that it *cannot* be in the while body. Therefore, since every path to the
65/// while guard is in the form outside pred -> [possibly while guard + some other
66/// while body statements] -> while guard, we know that n must either dominate
67/// outside pred or be the while guard itself.
68///
69/// dom(outside preds) U {while guard} is a subset of dom(while guard)
70/// Suppose n dominates outside preds. Since we already established that every
71/// path to the while guard involves going through otuside preds, we know that
72/// n dominates the while guard.
73///
74/// 2) "End Node" of If Statements
75/// In this case, *neither* of the predecessor sets (the set in the tbranch or
76/// the set in the fbranch) are guaranteed to be executed.
77/// Here we take:
78/// dom(end node) = dom(if guard) U {end node}.
79///
80/// Justification:
81/// dom(end node) is a subset of dom(if guard) U {end node}.
82/// If n dominates the end node, then it either a) is the end node itself, or b) must
83/// dominate the if guard. Justification for b)
84/// Every possible path to the if guard must be followed by
85/// if guard -> tbranch/fbranch -> end node. We also know that n must exist
86/// outside the tbranch/fbranch (if it was inside either branch, it wouldn't
87/// dominate the end node). Therefore, since we know that n must have appeared somewhere
88/// before if_guard on the path to end node, we know n dominates the if guard.
89///
90/// dom(if guard) U {end node} is a subset of dom(end node)
91/// If n dominates the if guard or is itself the end node, then it is very easy to
92/// see how it will dominate the end node.
93#[derive(Default)]
94pub struct DominatorMap {
95    /// Map from node (either invokes, enables, or if/while ports) ids to the ids of nodes that dominate it
96    pub map: HashMap<u64, HashSet<u64>>,
97    /// Maps ids of control stmts, to the "last" nodes in them. By "last" is meant
98    /// the final node that will be executed in them. For invokes and enables, it
99    /// will be themselves, for while statements it will be the while guard,
100    /// and for if statements it will be the "if" nods. For pars in seqs, you
101    /// have to look inside the children to see what their "last" nodes are.
102    pub exits_map: HashMap<u64, HashSet<u64>>,
103    /// an analysis to help domination across static pars
104    /// static pars give us more precise timing guarantees and therefore allow
105    /// us to more aggresively assign dominators
106    pub static_par_domination: StaticParDomination,
107    pub component_name: ir::Id,
108}
109
110impl Debug for DominatorMap {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        //must sort the hashmap and hashsets in order to get consistent ordering
113        writeln!(
114            f,
115            "The numbers in the domination map refer to the BEGIN_ID, END_ID, and NODE_ID attributes \nthat are attached to each non-empty control statement when the domination map is built. \nTo see which ID's refer to which control statement, look at the Calyx Program, which should \nbe printed along with the map when it is printed."
116        )?;
117        writeln!(
118            f,
119            "Domination Map for component \"{}\"  {{",
120            self.component_name
121        )?;
122        let map = self.map.clone();
123        let mut vec1: Vec<(u64, HashSet<u64>)> = map.into_iter().collect();
124        vec1.sort_by(|(k1, _), (k2, _)| k1.cmp(k2));
125        for (k, hs) in vec1.into_iter() {
126            write!(f, "Node: {:?} --", k)?;
127            let mut vec = hs.into_iter().collect::<Vec<_>>();
128            vec.sort_unstable();
129            writeln!(f, " Dominators: {:?}", vec)?;
130        }
131        write!(f, "}}")
132    }
133}
134
135#[inline]
136fn get_id_static<const BEGIN: bool>(c: &ir::StaticControl) -> u64 {
137    let v = match c {
138        ir::StaticControl::If(_) => {
139            if BEGIN {
140                c.get_attribute(BEGIN_ID)
141            } else {
142                c.get_attribute(END_ID)
143            }
144        }
145        _ => c.get_attribute(NODE_ID),
146    };
147    v.unwrap_or_else(|| unreachable!(
148            "get_id() shouldn't be called on control stmts that don't have id numbering"
149    ))
150}
151
152// Given a control, gets its associated id. For if statments, gets the
153// beginning id if begin_id is true and end_id if begin_id is false.
154// Should not be called on empty control
155// statements or any other statements that don't have an id numbering.
156#[inline]
157fn get_id<const BEGIN: bool>(c: &ir::Control) -> u64 {
158    let v = match c {
159        ir::Control::If(_) | ir::Control::Static(ir::StaticControl::If(_)) => {
160            if BEGIN {
161                c.get_attribute(BEGIN_ID)
162            } else {
163                c.get_attribute(END_ID)
164            }
165        }
166        _ => c.get_attribute(NODE_ID),
167    };
168    v.unwrap_or_else(|| unreachable!(
169            "get_id() shouldn't be called on control stmts that don't have id numbering"
170    ))
171}
172
173fn matches_key_static(sc: &ir::StaticControl, key: u64) -> bool {
174    if get_id_static::<true>(sc) == key {
175        return true;
176    }
177    //could match the end id of an if statement as well
178    if let Some(end) = sc.get_attribute(END_ID) {
179        key == end
180    } else {
181        false
182    }
183}
184
185// Given a control stmt c and a key, returns true if c matches key, false
186// otherwise. For if stmts return true if key matches either begin or end id.
187fn matches_key(c: &ir::Control, key: u64) -> bool {
188    if get_id::<true>(c) == key {
189        return true;
190    }
191    //could match the end id of an if statement as well
192    if let Some(end) = c.get_attribute(END_ID) {
193        key == end
194    } else {
195        false
196    }
197}
198
199fn get_final_static(sc: &ir::StaticControl) -> HashSet<u64> {
200    let mut hs = HashSet::new();
201    match sc {
202        ir::StaticControl::Empty(_) => (),
203        ir::StaticControl::Enable(_) | ir::StaticControl::Invoke(_) => {
204            hs.insert(ControlId::get_guaranteed_attribute_static(sc, NODE_ID));
205        }
206        ir::StaticControl::Repeat(ir::StaticRepeat {
207            body,
208            num_repeats,
209            ..
210        }) => {
211            // `Repeat 0` statements are essentially just Control::empty() stmts
212            // and therefore do not have "final" nodes
213            if *num_repeats != 0 {
214                return get_final_static(body);
215            }
216        }
217        ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => {
218            return get_final_static(stmts[..].last().unwrap_or_else(|| {
219                panic!(
220                    "error: empty Static Seq block. TODO: Make Static Seq work on collapse-control pass."
221                )
222            }));
223        }
224        ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
225            for stmt in stmts {
226                let stmt_final = get_final_static(stmt);
227                hs = hs.union(&stmt_final).copied().collect()
228            }
229        }
230        ir::StaticControl::If(_) => {
231            hs.insert(ControlId::get_guaranteed_attribute_static(sc, END_ID));
232        }
233    }
234    hs
235}
236
237// Gets the "final" nodes in control c. Used to build exits_map.
238fn get_final(c: &ir::Control) -> HashSet<u64> {
239    let mut hs = HashSet::new();
240    match c {
241        ir::Control::Empty(_) => (),
242        ir::Control::Invoke(_)
243        | ir::Control::Enable(_)
244        | ir::Control::While(_) => {
245            hs.insert(ControlId::get_guaranteed_attribute(c, NODE_ID));
246        }
247        ir::Control::Repeat(ir::Repeat {
248            body, num_repeats, ..
249        }) => {
250            // `Repeat 0` statements are essentially just Control::empty() stmts
251            // and therefore do not have "final" nodes
252            if *num_repeats != 0 {
253                return get_final(body);
254            }
255        }
256        ir::Control::If(_) => {
257            hs.insert(ControlId::get_guaranteed_attribute(c, END_ID));
258        }
259        ir::Control::Seq(ir::Seq { stmts, .. }) => {
260            return get_final(stmts[..].last().unwrap_or_else(|| {
261                panic!("error: empty Seq block. Run collapse-control pass.")
262            }));
263        }
264        ir::Control::Par(ir::Par { stmts, .. }) => {
265            for stmt in stmts {
266                let stmt_final = get_final(stmt);
267                hs = hs.union(&stmt_final).copied().collect()
268            }
269        }
270        ir::Control::Static(s) => return get_final_static(s),
271    }
272    hs
273}
274
275impl DominatorMap {
276    /// Construct a domination map.
277    pub fn new(control: &mut ir::Control, component_name: ir::Id) -> Self {
278        ControlId::compute_unique_ids(control, 0, true);
279        let mut map = DominatorMap {
280            map: HashMap::new(),
281            exits_map: HashMap::new(),
282            static_par_domination: StaticParDomination::new(
283                control,
284                component_name,
285            ),
286            component_name,
287        };
288        map.build_exit_map(control);
289        map.build_map(control);
290        map
291    }
292
293    fn build_exit_map_static(&mut self, sc: &ir::StaticControl) {
294        match sc {
295            ir::StaticControl::Enable(_) | ir::StaticControl::Invoke(_) => {
296                let id =
297                    ControlId::get_guaranteed_attribute_static(sc, NODE_ID);
298                self.exits_map.insert(id, HashSet::from([id]));
299            }
300            ir::StaticControl::Repeat(ir::StaticRepeat { body, .. }) => {
301                let id =
302                    ControlId::get_guaranteed_attribute_static(sc, NODE_ID);
303                self.exits_map.insert(id, get_final_static(sc));
304                self.build_exit_map_static(body);
305            }
306            ir::StaticControl::Seq(ir::StaticSeq { stmts, .. })
307            | ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
308                for stmt in stmts {
309                    self.build_exit_map_static(stmt);
310                }
311                let id =
312                    ControlId::get_guaranteed_attribute_static(sc, NODE_ID);
313                self.exits_map.insert(id, get_final_static(sc));
314            }
315            ir::StaticControl::If(ir::StaticIf {
316                tbranch, fbranch, ..
317            }) => {
318                let begin_id =
319                    ControlId::get_guaranteed_attribute_static(sc, BEGIN_ID);
320                let end_id =
321                    ControlId::get_guaranteed_attribute_static(sc, END_ID);
322                self.exits_map.insert(begin_id, HashSet::from([end_id]));
323                self.exits_map.insert(end_id, HashSet::from([end_id]));
324                self.build_exit_map_static(tbranch);
325                self.build_exit_map_static(fbranch);
326            }
327            ir::StaticControl::Empty(_) => (),
328        }
329    }
330
331    // Builds the "exit map" of c. This is getting what will be the final "node"
332    // executed in c.
333    fn build_exit_map(&mut self, c: &ir::Control) {
334        match c {
335            ir::Control::Empty(_) => (),
336            ir::Control::Invoke(_) | ir::Control::Enable(_) => {
337                let id = ControlId::get_guaranteed_attribute(c, NODE_ID);
338                self.exits_map.insert(id, HashSet::from([id]));
339            }
340            ir::Control::While(ir::While { body, .. }) => {
341                let id = ControlId::get_guaranteed_attribute(c, NODE_ID);
342                self.exits_map.insert(id, HashSet::from([id]));
343                self.build_exit_map(body);
344            }
345            ir::Control::Repeat(ir::Repeat { body, .. }) => {
346                let id = ControlId::get_guaranteed_attribute(c, NODE_ID);
347                self.exits_map.insert(id, get_final(body));
348                self.build_exit_map(body);
349            }
350            ir::Control::If(ir::If {
351                tbranch, fbranch, ..
352            }) => {
353                let begin_id = ControlId::get_guaranteed_attribute(c, BEGIN_ID);
354                let end_id = ControlId::get_guaranteed_attribute(c, END_ID);
355                self.exits_map.insert(begin_id, HashSet::from([end_id]));
356                self.exits_map.insert(end_id, HashSet::from([end_id]));
357                self.build_exit_map(tbranch);
358                self.build_exit_map(fbranch);
359            }
360            ir::Control::Seq(ir::Seq { stmts, .. })
361            | ir::Control::Par(ir::Par { stmts, .. }) => {
362                for stmt in stmts {
363                    self.build_exit_map(stmt);
364                }
365                let id = ControlId::get_guaranteed_attribute(c, NODE_ID);
366                self.exits_map.insert(id, get_final(c));
367            }
368            ir::Control::Static(sc) => self.build_exit_map_static(sc),
369        }
370    }
371
372    // Builds the domination map by running update_map() until the map
373    // stops changing.
374    fn build_map(&mut self, main_c: &mut ir::Control) {
375        let mut og_map = self.map.clone();
376        self.update_map(main_c, 0, &HashSet::new());
377        while og_map != self.map {
378            og_map = self.map.clone();
379            self.update_map(main_c, 0, &HashSet::new());
380        }
381        self.update_static_dominators();
382    }
383
384    // updates static dominators based on self.static_par_domination
385    // this can more aggresively add dominators to the map by
386    // using the timing guarantees of static par
387    fn update_static_dominators(&mut self) {
388        let new_static_domminators =
389            self.static_par_domination.get_static_dominators();
390        for (node_id, node_dominators) in new_static_domminators {
391            let cur_dominators = self.map.entry(node_id).or_default();
392            cur_dominators.extend(node_dominators);
393        }
394    }
395
396    // Given an id and its predecessors pred, and a domination map d_map, updates
397    // d_map accordingly (i.e. the union of all dominators of the predecessors
398    // plus itself).
399    fn update_node(&mut self, pred: &HashSet<u64>, id: u64) {
400        let mut union: HashSet<u64> = HashSet::new();
401        for id in pred.iter() {
402            if let Some(dominators) = self.map.get(id) {
403                union = union.union(dominators).copied().collect();
404            }
405        }
406        union.insert(id);
407        self.map.insert(id, union);
408    }
409
410    fn update_map_static(
411        &mut self,
412        main_sc: &ir::StaticControl,
413        cur_id: u64,
414        pred: &HashSet<u64>,
415    ) {
416        match Self::get_static_control(cur_id, main_sc) {
417            Some(GenericControl::Dynamic(_)) => {
418                unreachable!("should never get dynamic from get_static_control")
419            }
420            None => (),
421            Some(GenericControl::Static(sc)) => match sc {
422                ir::StaticControl::Empty(_) => (),
423                ir::StaticControl::Enable(_) | ir::StaticControl::Invoke(_) => {
424                    self.update_node(pred, cur_id);
425                }
426                ir::StaticControl::Repeat(ir::StaticRepeat {
427                    body,
428                    num_repeats,
429                    ..
430                }) => {
431                    if *num_repeats != 0 {
432                        let body_id = get_id_static::<true>(body);
433                        self.update_map_static(main_sc, body_id, pred);
434                    }
435                }
436                ir::StaticControl::Seq(ir::StaticSeq { stmts, .. }) => {
437                    let mut p = pred;
438                    let mut nxt: HashSet<u64>;
439                    for stmt in stmts {
440                        let id = get_id_static::<true>(stmt);
441                        self.update_map_static(main_sc, id, p);
442                        // updating the predecessors for the next stmt we iterate
443                        nxt = self
444                            .exits_map
445                            .get(&id)
446                            .unwrap_or(
447                                // If the exits map is empty, then it means the
448                                // current stmt is `Repeat 0`/Empty.
449                                // So the predecessors for the nxt stmt are the
450                                // same as the predecessors for the current stmt.
451                                pred,
452                            )
453                            .clone();
454                        p = &nxt;
455                    }
456                }
457                ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
458                    for stmt in stmts {
459                        let id = get_id_static::<true>(stmt);
460                        self.update_map_static(main_sc, id, pred);
461                    }
462                }
463                ir::StaticControl::If(ir::StaticIf {
464                    tbranch,
465                    fbranch,
466                    ..
467                }) => {
468                    //updating the if guard
469                    self.update_node(pred, cur_id);
470
471                    //building a set w/ just the if_guard id in it
472                    let if_guard_set = HashSet::from([cur_id]);
473
474                    //updating the tbranch
475                    let t_id = get_id_static::<true>(tbranch);
476                    self.update_map_static(main_sc, t_id, &if_guard_set);
477
478                    // If the false branch is present, update the map
479                    if !matches!(**fbranch, ir::StaticControl::Empty(_)) {
480                        let f_id = get_id_static::<true>(fbranch);
481                        self.update_map_static(main_sc, f_id, &if_guard_set);
482                    }
483
484                    let end_id =
485                        ControlId::get_guaranteed_attribute_static(sc, END_ID);
486                    self.update_node(&if_guard_set, end_id)
487                }
488            },
489        }
490    }
491
492    // Looks through each "node" in the "graph" and updates the dominators accordingly
493    fn update_map(
494        &mut self,
495        main_c: &ir::Control,
496        cur_id: u64,
497        pred: &HashSet<u64>,
498    ) {
499        match Self::get_control(cur_id, main_c) {
500            None => (),
501            Some(GenericControl::Dynamic(c)) => {
502                match c {
503                    ir::Control::Empty(_) => {
504                        unreachable!(
505                            "should not pattern match agaisnt empty in update_map()"
506                        )
507                    }
508                    ir::Control::Invoke(_)
509                    | ir::Control::Enable(_) => {
510                        self.update_node(pred, cur_id);
511                    }
512                    ir::Control::Seq(ir::Seq { stmts, .. }) => {
513                        let mut p = pred;
514                        let mut nxt: HashSet<u64>;
515                        for stmt in stmts {
516                            let id = get_id::<true>(stmt);
517                            self.update_map(main_c, id, p);
518                            nxt = self
519                                .exits_map
520                                .get(&id)
521                                .unwrap_or(pred
522                                    // If the exits map is empty, then it means the
523                                    // current stmt is `Repeat 0`/Empty.
524                                    // So the predecessors for the nxt stmt are the
525                                    // same as the predecessors for the current stmt
526                                ).clone();
527                            p = &nxt;
528                        }
529                    }
530                    ir::Control::Par(ir::Par { stmts, .. }) => {
531                        for stmt in stmts {
532                            let id = get_id::<true>(stmt);
533                            self.update_map(main_c, id, pred);
534                        }
535                    }
536                    ir::Control::Repeat(ir::Repeat { body, num_repeats, .. }) => {
537                        if *num_repeats != 0 {
538                            let body_id = get_id::<true>(body);
539                            self.update_map(main_c, body_id, pred);
540                        }
541                    }
542                    // Keep in mind that NODE_IDs attached to while loops/if statements
543                    // refer to the while/if guard, and as we pattern match against a while
544                    // or if statement, the control statement refers to the "guard",
545                    // which includes their combinational group and the conditional port
546                    // So (for example) if a while loop has NODE_ID = 10, then "node 10"
547                    // refers to the while guard-- comb group and conditional port-- but not the body.
548                    ir::Control::While(ir::While { body, .. }) => {
549                        self.update_node(pred, cur_id);
550                        // updating the while body
551                        let body_id = get_id::<true>(body);
552                        self.update_map(
553                            main_c,
554                            body_id,
555                            &HashSet::from([cur_id]),
556                        );
557                    }
558                    ir::Control::If(ir::If {
559                        tbranch, fbranch, ..
560                    }) => {
561                        //updating the if guard
562                        self.update_node(pred, cur_id);
563
564                        //building a set w/ just the if_guard id in it
565                        let if_guard_set = HashSet::from([cur_id]);
566
567                        //updating the tbranch
568                        let t_id = get_id::<true>(tbranch);
569                        self.update_map(main_c, t_id, &if_guard_set);
570
571                        // If the false branch is present, update the map
572                        if !matches!(**fbranch, ir::Control::Empty(_)) {
573                            let f_id = get_id::<true>(fbranch);
574                            self.update_map(main_c, f_id, &if_guard_set);
575                        }
576
577                        let end_id =
578                            ControlId::get_guaranteed_attribute(c, END_ID);
579                        self.update_node(&if_guard_set, end_id)
580                    }
581                    ir::Control::Static(_) => panic!("when matching c in GenericControl::Dynamic(c), c shouldn't be Static Control")
582                };
583            }
584            Some(GenericControl::Static(sc)) => {
585                let static_id = get_id_static::<true>(sc);
586                self.update_map_static(sc, static_id, pred);
587            }
588        }
589    }
590
591    pub fn get_static_control(
592        id: u64,
593        sc: &ir::StaticControl,
594    ) -> Option<GenericControl> {
595        if matches!(sc, ir::StaticControl::Empty(_)) {
596            return None;
597        }
598        if matches_key_static(sc, id) {
599            return Some(GenericControl::from(sc));
600        };
601        match sc {
602            ir::StaticControl::Empty(_)
603            | ir::StaticControl::Enable(_)
604            | ir::StaticControl::Invoke(_) => None,
605            ir::StaticControl::Repeat(ir::StaticRepeat { body, .. }) => {
606                Self::get_static_control(id, body)
607            }
608            ir::StaticControl::Seq(ir::StaticSeq { stmts, .. })
609            | ir::StaticControl::Par(ir::StaticPar { stmts, .. }) => {
610                for stmt in stmts {
611                    match Self::get_static_control(id, stmt) {
612                        None => (),
613                        Some(GenericControl::Dynamic(_)) => {
614                            unreachable!("Got a GenericControl::Dynamic when we called get_static_control")
615                        }
616                        Some(GenericControl::Static(sc)) => {
617                            return Some(GenericControl::from(sc))
618                        }
619                    }
620                }
621                None
622            }
623            ir::StaticControl::If(ir::StaticIf {
624                tbranch, fbranch, ..
625            }) => {
626                match Self::get_static_control(id, tbranch) {
627                    Some(GenericControl::Dynamic(_)) => {
628                        unreachable!("Got a GenericControl::Dynamic when we called get_static_control")
629                    }
630                    Some(GenericControl::Static(sc)) => {
631                        return Some(GenericControl::from(sc))
632                    }
633                    None => (),
634                }
635                match Self::get_static_control(id, fbranch) {
636                    Some(GenericControl::Dynamic(_)) => {
637                        unreachable!("Got a GenericControl::Dynamic when we called get_static_control")
638                    }
639                    Some(GenericControl::Static(sc)) => {
640                        return Some(GenericControl::from(sc))
641                    }
642                    None => (),
643                };
644                None
645            }
646        }
647    }
648
649    /// Given a control c and an id, finds the control statement within c that
650    /// has id, if it exists. If it doesn't, return None.
651    pub fn get_control(id: u64, c: &ir::Control) -> Option<GenericControl> {
652        if matches!(c, ir::Control::Empty(_)) {
653            return None;
654        }
655        if matches_key(c, id) {
656            return Some(GenericControl::from(c));
657        }
658        match c {
659            ir::Control::Empty(_)
660            | ir::Control::Invoke(_)
661            | ir::Control::Enable(_) => None,
662            ir::Control::Seq(ir::Seq { stmts, .. })
663            | ir::Control::Par(ir::Par { stmts, .. }) => {
664                for stmt in stmts {
665                    match Self::get_control(id, stmt) {
666                        None => (),
667                        Some(GenericControl::Dynamic(c)) => {
668                            return Some(GenericControl::from(c))
669                        }
670                        Some(GenericControl::Static(sc)) => {
671                            return Some(GenericControl::from(sc))
672                        }
673                    }
674                }
675                None
676            }
677            ir::Control::Repeat(ir::Repeat { body, .. }) => {
678                Self::get_control(id, body)
679            }
680            ir::Control::If(ir::If {
681                tbranch, fbranch, ..
682            }) => {
683                match Self::get_control(id, tbranch) {
684                    Some(GenericControl::Dynamic(c)) => {
685                        return Some(GenericControl::from(c))
686                    }
687                    Some(GenericControl::Static(sc)) => {
688                        return Some(GenericControl::from(sc))
689                    }
690                    None => (),
691                }
692                match Self::get_control(id, fbranch) {
693                    Some(GenericControl::Dynamic(c)) => {
694                        return Some(GenericControl::from(c))
695                    }
696                    Some(GenericControl::Static(sc)) => {
697                        return Some(GenericControl::from(sc))
698                    }
699                    None => (),
700                };
701                None
702            }
703            ir::Control::While(ir::While { body, .. }) => {
704                Self::get_control(id, body)
705            }
706            ir::Control::Static(sc) => Self::get_static_control(id, sc),
707        }
708    }
709
710    // Given a set of nodes, gets the control in main_control that corresponds
711    // to the node. If there is a node in the set not corresponding to a control
712    // statement in main_control, then it gives an unreachable! error.
713    // Returns two vectors: controls, static_controls
714    // (the dynamic and static nodes)
715    pub fn get_control_nodes<'a>(
716        nodes: &HashSet<u64>,
717        main_control: &'a ir::Control,
718    ) -> (Vec<&'a ir::Control>, Vec<&'a ir::StaticControl>) {
719        let mut controls: Vec<&ir::Control> = Vec::new();
720        let mut static_controls: Vec<&ir::StaticControl> = Vec::new();
721        for node in nodes {
722            match Self::get_control(*node, main_control) {
723                Some(GenericControl::Static(sc)) => static_controls.push(sc),
724                Some(GenericControl::Dynamic(c)) => controls.push(c),
725                None => {
726                    unreachable!("No control statement for ID {}", node)
727                }
728            }
729        }
730        (controls, static_controls)
731    }
732
733    // Gets the reads of shareable cells in node
734    // Assumes the control statements in comp have been given NODE_IDs in the same
735    // style of the domination map NODE_ID stuff.
736    pub fn get_node_reads(
737        node: &u64,
738        comp: &mut ir::Component,
739        shareset: &ShareSet,
740    ) -> HashSet<ir::Id> {
741        NodeReads::get_reads_of_node(node, comp, shareset)
742    }
743
744    // Returns whether key is guaranteed to be written in at least one of nodes
745    // Assumes the control statements in comp have been given NODE_IDs in the same
746    // style of the domination map NODE_ID stuff.
747    pub fn key_written_guaranteed(
748        key: ir::Id,
749        nodes: &HashSet<u64>,
750        comp: &mut ir::Component,
751    ) -> bool {
752        let search_struct = NodeSearch::new(key);
753        search_struct.is_written_guaranteed(nodes, comp)
754    }
755}