1use calyx_ir::{self as ir};
2use std::collections::HashMap;
3use std::rc::Rc;
4
5#[derive(Debug, Default)]
6pub struct PromotionAnalysis {
7 static_group_name: HashMap<ir::Id, ir::Id>,
9}
10
11impl PromotionAnalysis {
/// Asserts that the latency computed during conversion (`actual`) equals the
/// latency read off the control statement beforehand (`inferred`).
///
/// # Panics
/// Panics when the two values differ.
fn check_latencies_match(actual: u64, inferred: u64) {
    assert_eq!(
        actual,
        inferred,
        "Inferred and Annotated Latencies do not match. Latency: {}. Inferred: {}",
        actual,
        inferred
    );
}
15
16 pub fn get_inferred_latency(c: &ir::Control) -> u64 {
17 let ir::Control::Static(sc) = c else {
18 let Some(latency) = c.get_attribute(ir::NumAttr::Promotable) else {
19 unreachable!("Called get_latency on control that is neither static nor promotable")
20 };
21 return latency;
22 };
23 sc.get_latency()
24 }
25
26 pub fn can_be_promoted(c: &ir::Control) -> bool {
29 c.is_static() || c.has_attribute(ir::NumAttr::Promotable)
30 }
31
32 fn construct_static_group(
35 &mut self,
36 builder: &mut ir::Builder,
37 group: ir::RRC<ir::Group>,
38 latency: u64,
39 ) -> ir::RRC<ir::StaticGroup> {
40 if let Some(s_name) = self.static_group_name.get(&group.borrow().name())
41 {
42 builder.component.find_static_group(*s_name).unwrap()
43 } else {
44 let sg = builder.add_static_group(group.borrow().name(), latency);
45 self.static_group_name
46 .insert(group.borrow().name(), sg.borrow().name());
47 for assignment in group.borrow().assignments.iter() {
48 if !(assignment.dst.borrow().is_hole()
50 && assignment.dst.borrow().name == "done")
51 {
52 sg.borrow_mut()
53 .assignments
54 .push(ir::Assignment::from(assignment.clone()));
55 }
56 }
57 Rc::clone(&sg)
58 }
59 }
60
61 pub fn convert_enable_to_static(
63 &mut self,
64 s: &mut ir::Enable,
65 builder: &mut ir::Builder,
66 ) -> ir::StaticControl {
67 s.attributes.remove(ir::NumAttr::Promotable);
68 ir::StaticControl::Enable(ir::StaticEnable {
69 group: self.construct_static_group(
71 builder,
72 Rc::clone(&s.group),
73 s.group
74 .borrow()
75 .get_attributes()
76 .unwrap()
77 .get(ir::NumAttr::Promotable)
78 .unwrap(),
79 ),
80 attributes: std::mem::take(&mut s.attributes),
81 })
82 }
83
84 pub fn convert_invoke_to_static(
86 &mut self,
87 s: &mut ir::Invoke,
88 ) -> ir::StaticControl {
89 assert!(
90 s.comb_group.is_none(),
91 "Shouldn't Promote to Static if there is a Comb Group",
92 );
93 let latency = s.attributes.get(ir::NumAttr::Promotable).unwrap();
94 s.attributes.remove(ir::NumAttr::Promotable);
95 let s_inv = ir::StaticInvoke {
96 comp: Rc::clone(&s.comp),
97 inputs: std::mem::take(&mut s.inputs),
98 outputs: std::mem::take(&mut s.outputs),
99 latency,
100 attributes: std::mem::take(&mut s.attributes),
101 ref_cells: std::mem::take(&mut s.ref_cells),
102 comb_group: std::mem::take(&mut s.comb_group),
103 };
104 ir::StaticControl::Invoke(s_inv)
105 }
106
107 pub fn convert_to_static(
110 &mut self,
111 c: &mut ir::Control,
112 builder: &mut ir::Builder,
113 ) -> ir::StaticControl {
114 assert!(
115 c.has_attribute(ir::NumAttr::Promotable) || c.is_static(),
116 "Called convert_to_static control that is neither static nor promotable"
117 );
118 let bound_attribute = c.get_attribute(ir::NumAttr::Bound);
121 let inferred_latency = Self::get_inferred_latency(c);
124 match c {
125 ir::Control::Empty(_) => ir::StaticControl::empty(),
126 ir::Control::Enable(s) => self.convert_enable_to_static(s, builder),
127 ir::Control::Seq(ir::Seq { stmts, attributes }) => {
128 attributes.remove(ir::NumAttr::Promotable);
130 attributes.insert(ir::NumAttr::Compactable, 1);
132 let static_stmts =
133 self.convert_vec_to_static(builder, std::mem::take(stmts));
134 let latency =
135 static_stmts.iter().map(|s| s.get_latency()).sum();
136 Self::check_latencies_match(latency, inferred_latency);
137 ir::StaticControl::Seq(ir::StaticSeq {
138 stmts: static_stmts,
139 attributes: std::mem::take(attributes),
140 latency,
141 })
142 }
143 ir::Control::Par(ir::Par { stmts, attributes }) => {
144 attributes.remove(ir::NumAttr::Promotable);
146 let static_stmts =
148 self.convert_vec_to_static(builder, std::mem::take(stmts));
149 let latency = static_stmts
151 .iter()
152 .map(|s| s.get_latency())
153 .max()
154 .unwrap_or_else(|| unreachable!("Empty Par Block"));
155 Self::check_latencies_match(latency, inferred_latency);
156 ir::StaticControl::Par(ir::StaticPar {
157 stmts: static_stmts,
158 attributes: ir::Attributes::default(),
159 latency,
160 })
161 }
162 ir::Control::Repeat(ir::Repeat {
163 body,
164 num_repeats,
165 attributes,
166 }) => {
167 attributes.remove(ir::NumAttr::Promotable);
169 let sc = self.convert_to_static(body, builder);
170 let latency = (*num_repeats) * sc.get_latency();
171 Self::check_latencies_match(latency, inferred_latency);
172 ir::StaticControl::Repeat(ir::StaticRepeat {
173 attributes: std::mem::take(attributes),
174 body: Box::new(sc),
175 num_repeats: *num_repeats,
176 latency,
177 })
178 }
179 ir::Control::While(ir::While {
180 body, attributes, ..
181 }) => {
182 attributes.remove(ir::NumAttr::Promotable);
184 attributes.remove(ir::NumAttr::Bound);
186 let sc = self.convert_to_static(body, builder);
187 let num_repeats = bound_attribute.unwrap_or_else(|| unreachable!("Called convert_to_static on a while loop without a bound"));
188 let latency = num_repeats * sc.get_latency();
189 Self::check_latencies_match(latency, inferred_latency);
190 ir::StaticControl::Repeat(ir::StaticRepeat {
191 attributes: std::mem::take(attributes),
192 body: Box::new(sc),
193 num_repeats,
194 latency,
195 })
196 }
197 ir::Control::If(ir::If {
198 port,
199 tbranch,
200 fbranch,
201 attributes,
202 ..
203 }) => {
204 attributes.remove(ir::NumAttr::Promotable);
206 let static_tbranch = self.convert_to_static(tbranch, builder);
207 let static_fbranch = self.convert_to_static(fbranch, builder);
208 let latency = std::cmp::max(
209 static_tbranch.get_latency(),
210 static_fbranch.get_latency(),
211 );
212 Self::check_latencies_match(latency, inferred_latency);
213 ir::StaticControl::static_if(
214 Rc::clone(port),
215 Box::new(static_tbranch),
216 Box::new(static_fbranch),
217 latency,
218 )
219 }
220 ir::Control::Static(_) => c.take_static_control(),
221 ir::Control::Invoke(s) => self.convert_invoke_to_static(s),
222 }
223 }
224
225 pub fn convert_vec_to_static(
228 &mut self,
229 builder: &mut ir::Builder,
230 control_vec: Vec<ir::Control>,
231 ) -> Vec<ir::StaticControl> {
232 control_vec
233 .into_iter()
234 .map(|mut c| self.convert_to_static(&mut c, builder))
235 .collect()
236 }
237}