// calyx_opt/analysis/compute_static.rs

1use calyx_ir::{self as ir, GetAttributes};
2use std::collections::HashMap;
3use std::rc::Rc;
4
5/// Trait to propagate and extra "static" attributes through [ir::Control].
6/// Calling the update function ensures that the current program, as well as all
7/// sub-programs have a "static" attribute on them.
8/// Usage:
9/// ```
10/// use calyx::analysis::compute_static::WithStatic;
11/// let con: ir::Control = todo!(); // A complex control program
12/// con.update(&HashMap::new());    // Compute the static information for the program
13/// ```
14pub trait WithStatic
15where
16    Self: GetAttributes,
17{
18    /// Extra information needed to compute static information for this type.
19    type Info;
20
21    /// Compute the static information for the type if possible and add it to its attribute.
22    /// Implementors should instead implement [WithStatic::compute_static] and call this function
23    /// on sub-programs.
24    /// **Ensures**: All sub-programs of the type will also be updated.
25    fn update_static(&mut self, extra: &Self::Info) -> Option<u64> {
26        if let Some(time) = self.compute_static(extra) {
27            self.get_mut_attributes()
28                .insert(ir::NumAttr::Promotable, time);
29            Some(time)
30        } else {
31            None
32        }
33    }
34
35    /// Compute the static information for the type if possible and update all sub-programs.
36    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64>;
37}
38
39type CompTime = HashMap<ir::Id, u64>;
40
41impl WithStatic for ir::Control {
42    // Mapping from name of components to their latency information
43    type Info = CompTime;
44
45    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
46        match self {
47            ir::Control::Seq(seq) => seq.update_static(extra),
48            ir::Control::Par(par) => par.update_static(extra),
49            ir::Control::If(if_) => if_.update_static(extra),
50            ir::Control::While(wh) => wh.update_static(extra),
51            ir::Control::Repeat(rep) => rep.update_static(extra),
52            ir::Control::Invoke(inv) => inv.update_static(extra),
53            ir::Control::Enable(en) => en.update_static(&()),
54            ir::Control::Empty(_) => Some(0),
55            ir::Control::Static(sc) => Some(sc.get_latency()),
56        }
57    }
58}
59
60impl WithStatic for ir::Enable {
61    type Info = ();
62    fn compute_static(&mut self, _: &Self::Info) -> Option<u64> {
63        // Attempt to get the latency from the attribute on the enable first, or
64        // failing that, from the group.
65        self.attributes.get(ir::NumAttr::Promotable).or_else(|| {
66            self.group.borrow().attributes.get(ir::NumAttr::Promotable)
67        })
68    }
69}
70
71impl WithStatic for ir::Invoke {
72    type Info = CompTime;
73    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
74        self.attributes.get(ir::NumAttr::Promotable).or_else(|| {
75            let comp = self.comp.borrow().type_name()?;
76            extra.get(&comp).cloned()
77        })
78    }
79}
80
81/// Walk over a set of control statements and call `update_static` on each of them.
82/// Use a merge function to merge the results of the `update_static` calls.
83fn walk_static<T, F>(stmts: &mut [T], extra: &T::Info, merge: F) -> Option<u64>
84where
85    T: WithStatic,
86    F: Fn(u64, u64) -> u64,
87{
88    let mut latency = Some(0);
89    // This is implemented as a loop because we want to call `update_static` on
90    // each statement even if we cannot compute a total latency anymore.
91    for stmt in stmts.iter_mut() {
92        let stmt_latency = stmt.update_static(extra);
93        latency = match (latency, stmt_latency) {
94            (Some(l), Some(s)) => Some(merge(l, s)),
95            (_, _) => None,
96        }
97    }
98    latency
99}
100
101impl WithStatic for ir::Seq {
102    type Info = CompTime;
103    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
104        walk_static(&mut self.stmts, extra, |x, y| x + y)
105    }
106}
107
108impl WithStatic for ir::Par {
109    type Info = CompTime;
110    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
111        walk_static(&mut self.stmts, extra, std::cmp::max)
112    }
113}
114
115impl WithStatic for ir::If {
116    type Info = CompTime;
117    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
118        // Cannot compute latency information for `if`-`with`
119        let t_latency = self.tbranch.update_static(extra);
120        let f_latency = self.fbranch.update_static(extra);
121        if self.cond.is_some() {
122            log::debug!("Cannot compute latency for while-with");
123            return None;
124        }
125        match (t_latency, f_latency) {
126            (Some(t), Some(f)) => Some(std::cmp::max(t, f)),
127            (_, _) => None,
128        }
129    }
130}
131
132impl WithStatic for ir::While {
133    type Info = CompTime;
134    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
135        let b_time = self.body.update_static(extra)?;
136        // Cannot compute latency information for `while`-`with`
137        if self.cond.is_some() {
138            log::debug!("Cannot compute latency for while-with");
139            return None;
140        }
141        let bound = self.attributes.get(ir::NumAttr::Bound)?;
142        Some(bound * b_time)
143    }
144}
145
146impl WithStatic for ir::Repeat {
147    type Info = CompTime;
148    fn compute_static(&mut self, extra: &Self::Info) -> Option<u64> {
149        let b_time = self.body.update_static(extra)?;
150        let num_repeats = self.num_repeats;
151        Some(num_repeats * b_time)
152    }
153}
154
/// Conversion of a control construct into its `static` counterpart.
pub trait IntoStatic {
    /// The static control type produced by the conversion.
    type StaticCon;

    /// Convert `self` into static control.
    ///
    /// Returns `None` when the conversion is not possible. Implementations
    /// may move children out of `self` on success.
    fn make_static(&mut self) -> Option<Self::StaticCon>;
}
159
160impl IntoStatic for ir::Seq {
161    type StaticCon = ir::StaticSeq;
162    fn make_static(&mut self) -> Option<Self::StaticCon> {
163        let mut static_stmts: Vec<ir::StaticControl> = Vec::new();
164        let mut latency = 0;
165        for stmt in self.stmts.iter() {
166            if !matches!(stmt, ir::Control::Static(_)) {
167                log::debug!("Cannot build `static seq`. Control statement inside `seq` is not static");
168                return None;
169            }
170        }
171
172        for stmt in self.stmts.drain(..) {
173            let ir::Control::Static(sc) = stmt else {
174                unreachable!("We have already checked that all control statements are static")
175            };
176            latency += sc.get_latency();
177            static_stmts.push(sc);
178        }
179        Some(ir::StaticSeq {
180            stmts: static_stmts,
181            attributes: self.attributes.clone(),
182            latency,
183        })
184    }
185}
186
187impl IntoStatic for ir::Par {
188    type StaticCon = ir::StaticPar;
189    fn make_static(&mut self) -> Option<Self::StaticCon> {
190        let mut static_stmts: Vec<ir::StaticControl> = Vec::new();
191        let mut latency = 0;
192        for stmt in self.stmts.iter() {
193            if !matches!(stmt, ir::Control::Static(_)) {
194                log::debug!("Cannot build `static seq`. Control statement inside `seq` is not static");
195                return None;
196            }
197        }
198
199        for stmt in self.stmts.drain(..) {
200            let ir::Control::Static(sc) = stmt else {
201                unreachable!("We have already checked that all control statements are static")
202            };
203            latency = std::cmp::max(latency, sc.get_latency());
204            static_stmts.push(sc);
205        }
206        Some(ir::StaticPar {
207            stmts: static_stmts,
208            attributes: self.attributes.clone(),
209            latency,
210        })
211    }
212}
213
214impl IntoStatic for ir::If {
215    type StaticCon = ir::StaticIf;
216    fn make_static(&mut self) -> Option<Self::StaticCon> {
217        if !(self.tbranch.is_static() && self.fbranch.is_static()) {
218            return None;
219        };
220        let tb = std::mem::replace(&mut *self.tbranch, ir::Control::empty());
221        let fb = std::mem::replace(&mut *self.fbranch, ir::Control::empty());
222        let ir::Control::Static(sc_t) = tb else {
223            unreachable!("we have already checked tbranch to be static")
224        };
225        let ir::Control::Static(sc_f) = fb else {
226            unreachable!("we have already checker fbranch to be static")
227        };
228        let latency = std::cmp::max(sc_t.get_latency(), sc_f.get_latency());
229        Some(ir::StaticIf {
230            tbranch: Box::new(sc_t),
231            fbranch: Box::new(sc_f),
232            attributes: ir::Attributes::default(),
233            port: Rc::clone(&self.port),
234            latency,
235        })
236    }
237}