1use super::AssignmentAnalysis;
2use crate::analysis::{compute_static::WithStatic, GraphAnalysis};
3use calyx_ir::{self as ir, GetAttributes, RRC};
4use ir::CellType;
5use itertools::Itertools;
6use std::collections::{HashMap, HashSet};
7
8#[derive(Debug)]
12pub struct GoDone {
13 ports: Vec<(ir::Id, ir::Id, u64)>,
14}
15
16impl GoDone {
17 pub fn new(ports: Vec<(ir::Id, ir::Id, u64)>) -> Self {
18 Self { ports }
19 }
20
21 pub fn is_go(&self, name: &ir::Id) -> bool {
23 self.ports.iter().any(|(go, _, _)| name == go)
24 }
25
26 pub fn is_done(&self, name: &ir::Id) -> bool {
28 self.ports.iter().any(|(_, done, _)| name == done)
29 }
30
31 pub fn get_latency(&self, go_port: &ir::Id) -> Option<u64> {
33 self.ports.iter().find_map(|(go, _, lat)| {
34 if go == go_port {
35 Some(*lat)
36 } else {
37 None
38 }
39 })
40 }
41
42 pub fn iter(&self) -> impl Iterator<Item = &(ir::Id, ir::Id, u64)> {
44 self.ports.iter()
45 }
46
47 pub fn is_empty(&self) -> bool {
49 self.ports.is_empty()
50 }
51
52 pub fn len(&self) -> usize {
54 self.ports.len()
55 }
56
57 pub fn get_ports(&self) -> &Vec<(ir::Id, ir::Id, u64)> {
59 &self.ports
60 }
61}
62
63impl From<&ir::Primitive> for GoDone {
64 fn from(prim: &ir::Primitive) -> Self {
65 let done_ports: HashMap<_, _> = prim
66 .find_all_with_attr(ir::NumAttr::Done)
67 .map(|pd| (pd.attributes.get(ir::NumAttr::Done), pd.name()))
68 .collect();
69
70 let go_ports = prim
71 .find_all_with_attr(ir::NumAttr::Go)
72 .filter_map(|pd| {
73 pd.attributes.get(ir::NumAttr::Interval).and_then(|st| {
75 done_ports
76 .get(&pd.attributes.get(ir::NumAttr::Go))
77 .map(|done_port| (pd.name(), *done_port, st))
78 })
79 })
80 .collect_vec();
81 GoDone::new(go_ports)
82 }
83}
84
85impl From<&ir::Cell> for GoDone {
86 fn from(cell: &ir::Cell) -> Self {
87 let done_ports: HashMap<_, _> = cell
88 .find_all_with_attr(ir::NumAttr::Done)
89 .map(|pr| {
90 let port = pr.borrow();
91 (port.attributes.get(ir::NumAttr::Done), port.name)
92 })
93 .collect();
94
95 let go_ports = cell
96 .find_all_with_attr(ir::NumAttr::Go)
97 .filter_map(|pr| {
98 let port = pr.borrow();
99 let st = match port.attributes.get(ir::NumAttr::Interval) {
101 Some(st) => Some(st),
102 None => port.attributes.get(ir::NumAttr::Promotable),
103 };
104 if let Some(static_latency) = st {
105 return done_ports
106 .get(&port.attributes.get(ir::NumAttr::Go))
107 .map(|done_port| {
108 (port.name, *done_port, static_latency)
109 });
110 }
111 None
112 })
113 .collect_vec();
114 GoDone::new(go_ports)
115 }
116}
117
118pub struct InferenceAnalysis {
121 pub latency_data: HashMap<ir::Id, GoDone>,
123 pub static_component_latencies: HashMap<ir::Id, u64>,
128
129 updated_components: HashSet<ir::Id>,
130}
131
132impl InferenceAnalysis {
133 pub fn from_ctx(ctx: &ir::Context) -> Self {
136 let mut latency_data = HashMap::new();
137 let mut static_component_latencies = HashMap::new();
138 for prim in ctx.lib.signatures() {
140 let prim_go_done = GoDone::from(prim);
141 if prim_go_done.len() == 1 {
142 static_component_latencies
143 .insert(prim.name, prim_go_done.get_ports()[0].2);
144 }
145 latency_data.insert(prim.name, GoDone::from(prim));
146 }
147 for comp in &ctx.components {
148 let comp_sig = comp.signature.borrow();
149
150 let done_ports: HashMap<_, _> = comp_sig
151 .find_all_with_attr(ir::NumAttr::Done)
152 .map(|pd| {
153 let pd_ref = pd.borrow();
154 (pd_ref.attributes.get(ir::NumAttr::Done), pd_ref.name)
155 })
156 .collect();
157
158 let go_ports = comp_sig
159 .find_all_with_attr(ir::NumAttr::Go)
160 .filter_map(|pd| {
161 let pd_ref = pd.borrow();
162 let st = match pd_ref.attributes.get(ir::NumAttr::Interval)
164 {
165 Some(st) => Some(st),
166 None => pd_ref.attributes.get(ir::NumAttr::Promotable),
167 };
168 if let Some(static_latency) = st {
169 return done_ports
170 .get(&pd_ref.attributes.get(ir::NumAttr::Go))
171 .map(|done_port| {
172 (pd_ref.name, *done_port, static_latency)
173 });
174 }
175 None
176 })
177 .collect_vec();
178
179 let go_done_comp = GoDone::new(go_ports);
180
181 if go_done_comp.len() == 1 {
182 static_component_latencies
183 .insert(comp.name, go_done_comp.get_ports()[0].2);
184 }
185 latency_data.insert(comp.name, go_done_comp);
186 }
187 InferenceAnalysis {
188 latency_data,
189 static_component_latencies,
190 updated_components: HashSet::new(),
191 }
192 }
193
194 pub fn add_component(
196 &mut self,
197 (comp_name, latency, go_done): (ir::Id, u64, GoDone),
198 ) {
199 self.latency_data.insert(comp_name, go_done);
200 self.static_component_latencies.insert(comp_name, latency);
201 }
202
203 pub fn remove_component(&mut self, comp_name: ir::Id) {
207 if self.latency_data.contains_key(&comp_name) {
208 self.updated_components.insert(comp_name);
211 }
212 self.latency_data.remove(&comp_name);
213 self.static_component_latencies.remove(&comp_name);
214 }
215
216 pub fn adjust_component(
220 &mut self,
221 (comp_name, adjusted_latency): (ir::Id, u64),
222 ) {
223 let mut updated = false;
225 self.latency_data.entry(comp_name).and_modify(|go_done| {
226 for (_, _, cur_latency) in &mut go_done.ports {
227 if *cur_latency != adjusted_latency {
229 *cur_latency = adjusted_latency;
230 updated = true;
231 }
232 }
233 });
234 self.static_component_latencies
235 .insert(comp_name, adjusted_latency);
236 if updated {
237 self.updated_components.insert(comp_name);
238 }
239 }
240
241 fn mem_wrt_dep_graph(&self, src: &ir::Port, dst: &ir::Port) -> bool {
246 match (&src.parent, &dst.parent) {
247 (
248 ir::PortParent::Cell(src_cell_wrf),
249 ir::PortParent::Cell(dst_cell_wrf),
250 ) => {
251 let src_rf = src_cell_wrf.upgrade();
252 let src_cell = src_rf.borrow();
253 let dst_rf = dst_cell_wrf.upgrade();
254 let dst_cell = dst_rf.borrow();
255 if let (Some(s_name), Some(d_name)) =
256 (src_cell.type_name(), dst_cell.type_name())
257 {
258 let data_src = self.latency_data.get(&s_name);
259 let data_dst = self.latency_data.get(&d_name);
260 if let (Some(dst_ports), Some(src_ports)) =
261 (data_dst, data_src)
262 {
263 return src_ports.is_done(&src.name)
264 && dst_ports.is_go(&dst.name);
265 }
266 }
267
268 if let (Some(d_name), ir::CellType::Constant { .. }) =
270 (dst_cell.type_name(), &src_cell.prototype)
271 {
272 if let Some(ports) = self.latency_data.get(&d_name) {
273 return ports.is_go(&dst.name);
274 }
275 }
276
277 false
278 }
279
280 (_, ir::PortParent::Group(_)) => dst.name == "done",
282
283 _ => false,
285 }
286 }
287
288 fn find_go_done_edges(
291 &self,
292 group: &ir::Group,
293 ) -> Vec<(RRC<ir::Port>, RRC<ir::Port>)> {
294 let rw_set = group.assignments.iter().analysis().cell_uses();
295 let mut go_done_edges: Vec<(RRC<ir::Port>, RRC<ir::Port>)> = Vec::new();
296
297 for cell_ref in rw_set {
298 let cell = cell_ref.borrow();
299 if let Some(ports) =
300 cell.type_name().and_then(|c| self.latency_data.get(&c))
301 {
302 go_done_edges.extend(
303 ports
304 .iter()
305 .map(|(go, done, _)| (cell.get(go), cell.get(done))),
306 )
307 }
308 }
309 go_done_edges
310 }
311
312 fn is_done_port_or_const(&self, port: &ir::Port) -> bool {
315 if let ir::PortParent::Cell(cwrf) = &port.parent {
316 let cr = cwrf.upgrade();
317 let cell = cr.borrow();
318 if let ir::CellType::Constant { val, .. } = &cell.prototype {
319 if *val > 0 {
320 return true;
321 }
322 } else if let Some(ports) =
323 cell.type_name().and_then(|c| self.latency_data.get(&c))
324 {
325 return ports.is_done(&port.name);
326 }
327 }
328 false
329 }
330
331 fn contains_dyn_writes(&self, graph: &GraphAnalysis) -> bool {
334 for port in &graph.ports() {
335 match &port.borrow().parent {
336 ir::PortParent::Cell(cell_wrf) => {
337 let cr = cell_wrf.upgrade();
338 let cell = cr.borrow();
339 if let Some(ports) =
340 cell.type_name().and_then(|c| self.latency_data.get(&c))
341 {
342 let name = &port.borrow().name;
343 if ports.is_go(name) {
344 for write_port in graph.writes_to(&port.borrow()) {
345 if !self
346 .is_done_port_or_const(&write_port.borrow())
347 {
348 log::debug!(
349 "`{}` is not a done port",
350 write_port.borrow().canonical(),
351 );
352 return true;
353 }
354 }
355 }
356 }
357 }
358 ir::PortParent::Group(_) => {
359 if port.borrow().name == "done" {
360 for write_port in graph.writes_to(&port.borrow()) {
361 if !self.is_done_port_or_const(&write_port.borrow())
362 {
363 log::debug!(
364 "`{}` is not a done port",
365 write_port.borrow().canonical(),
366 );
367 return true;
368 }
369 }
370 }
371 }
372
373 ir::PortParent::StaticGroup(_) => panic!("Have not decided how to handle static groups in infer-static-timing"),
375 }
376 }
377 false
378 }
379
380 fn contains_node_deg_gt_one(graph: &GraphAnalysis) -> bool {
382 for port in graph.ports() {
383 if graph.writes_to(&port.borrow()).count() > 1 {
384 return true;
385 }
386 }
387 false
388 }
389
390 fn infer_latency(&self, group: &ir::Group) -> Option<u64> {
394 log::debug!("Checking group `{}`", group.name());
416 let graph_unprocessed = GraphAnalysis::from(group);
417 if self.contains_dyn_writes(&graph_unprocessed) {
418 log::debug!("FAIL: contains dynamic writes");
419 return None;
420 }
421
422 let go_done_edges = self.find_go_done_edges(group);
423 let graph = graph_unprocessed
424 .edge_induced_subgraph(|src, dst| self.mem_wrt_dep_graph(src, dst))
425 .add_edges(&go_done_edges)
426 .remove_isolated_vertices();
427
428 if Self::contains_node_deg_gt_one(&graph) {
430 log::debug!("FAIL: Group contains multiple writes");
431 return None;
432 }
433
434 let mut tsort = graph.toposort();
435 let start = tsort.next()?;
436 let finish = tsort.last()?;
437
438 let paths = graph.paths(&start.borrow(), &finish.borrow());
439 if paths.is_empty() {
441 log::debug!("FAIL: No path between @go and @done port");
442 return None;
443 }
444 let first_path = paths.first().unwrap();
445
446 let mut latency_sum = 0;
448 for port in first_path {
449 if let ir::PortParent::Cell(cwrf) = &port.borrow().parent {
450 let cr = cwrf.upgrade();
451 let cell = cr.borrow();
452 if let Some(ports) =
453 cell.type_name().and_then(|c| self.latency_data.get(&c))
454 {
455 if let Some(latency) =
456 ports.get_latency(&port.borrow().name)
457 {
458 latency_sum += latency;
459 }
460 }
461 }
462 }
463
464 log::debug!("SUCCESS: Latency = {}", latency_sum);
465 Some(latency_sum)
466 }
467
468 pub fn get_possible_latency(c: &ir::Control) -> Option<u64> {
471 match c {
472 ir::Control::Static(sc) => Some(sc.get_latency()),
473 _ => c.get_attribute(ir::NumAttr::Promotable),
474 }
475 }
476
477 pub fn remove_promotable_from_seq(seq: &mut ir::Seq) {
478 for stmt in &mut seq.stmts {
479 Self::remove_promotable_attribute(stmt);
480 }
481 seq.get_mut_attributes().remove(ir::NumAttr::Promotable);
482 }
483
484 pub fn remove_promotable_attribute(c: &mut ir::Control) {
487 c.get_mut_attributes().remove(ir::NumAttr::Promotable);
488 match c {
489 ir::Control::Empty(_)
490 | ir::Control::Invoke(_)
491 | ir::Control::Enable(_)
492 | ir::Control::Static(_) => (),
493 ir::Control::While(ir::While { body, .. })
494 | ir::Control::Repeat(ir::Repeat { body, .. }) => {
495 Self::remove_promotable_attribute(body);
496 }
497 ir::Control::If(ir::If {
498 tbranch, fbranch, ..
499 }) => {
500 Self::remove_promotable_attribute(tbranch);
501 Self::remove_promotable_attribute(fbranch);
502 }
503 ir::Control::Seq(ir::Seq { stmts, .. })
504 | ir::Control::Par(ir::Par { stmts, .. }) => {
505 for stmt in stmts {
506 Self::remove_promotable_attribute(stmt);
507 }
508 }
509 }
510 }
511
512 pub fn fixup_seq(&self, seq: &mut ir::Seq) {
513 seq.update_static(&self.static_component_latencies);
514 }
515
516 pub fn fixup_par(&self, par: &mut ir::Par) {
517 par.update_static(&self.static_component_latencies);
518 }
519
520 pub fn fixup_if(&self, _if: &mut ir::If) {
521 _if.update_static(&self.static_component_latencies);
522 }
523
524 pub fn fixup_while(&self, _while: &mut ir::While) {
525 _while.update_static(&self.static_component_latencies);
526 }
527
528 pub fn fixup_repeat(&self, repeat: &mut ir::Repeat) {
529 repeat.update_static(&self.static_component_latencies);
530 }
531
532 pub fn fixup_ctrl(&self, ctrl: &mut ir::Control) {
533 ctrl.update_static(&self.static_component_latencies);
534 }
535
536 pub fn fixup_timing(&self, comp: &mut ir::Component) {
545 for group in comp.groups.iter() {
548 if group
552 .borrow_mut()
553 .assignments
554 .iter()
555 .analysis()
556 .cell_writes()
557 .any(|cell| match cell.borrow().prototype {
558 CellType::Component { name } => {
559 self.updated_components.contains(&name)
560 }
561 _ => false,
562 })
563 {
564 group
566 .borrow_mut()
567 .attributes
568 .remove(ir::NumAttr::Promotable);
569 }
570 }
571
572 for group in &mut comp.groups.iter() {
573 let latency_result = self.infer_latency(&group.borrow());
575 if let Some(latency) = latency_result {
576 group
577 .borrow_mut()
578 .attributes
579 .insert(ir::NumAttr::Promotable, latency);
580 }
581 }
582
583 Self::remove_promotable_attribute(&mut comp.control.borrow_mut());
586 comp.control
587 .borrow_mut()
588 .update_static(&self.static_component_latencies);
589 }
590}