1use crate::traversal::{Action, Named, VisResult, Visitor};
2use calyx_ir::RRC;
3use calyx_ir::{self as ir, GetAttributes};
4use calyx_ir::{build_assignments, guard, structure};
5use calyx_utils::{CalyxResult, Error};
6use linked_hash_map::LinkedHashMap;
7use std::collections::{HashMap, HashSet};
8use std::rc::Rc;
9
10#[derive(Default)]
11pub struct CompileSync {
33 barriers: BarrierMap,
34}
35
36type BarrierMap = LinkedHashMap<u64, ([RRC<ir::Cell>; 2], [RRC<ir::Group>; 3])>;
38
39impl Named for CompileSync {
40 fn name() -> &'static str {
41 "compile-sync"
42 }
43
44 fn description() -> &'static str {
45 "Implement barriers for statements marked with @sync attribute"
46 }
47}
48
49fn count_barriers(
51 s: &ir::Control,
52 count: &mut HashSet<u64>,
53) -> CalyxResult<()> {
54 match s {
55 ir::Control::Empty(_) => {
56 if let Some(n) = s.get_attributes().get(ir::NumAttr::Sync) {
57 count.insert(n);
58 }
59 Ok(())
60 }
61 ir::Control::Seq(seq) => {
62 for stmt in seq.stmts.iter() {
63 count_barriers(stmt, count)?;
64 }
65 Ok(())
66 }
67 ir::Control::While(ir::While { body, .. })
68 | ir::Control::Repeat(ir::Repeat { body, .. }) => {
69 count_barriers(body, count)?;
70 Ok(())
71 }
72 ir::Control::If(i) => {
73 count_barriers(&i.tbranch, count)?;
74 count_barriers(&i.fbranch, count)?;
75 Ok(())
76 }
77 ir::Control::Enable(e) => {
78 if s.get_attributes().get(ir::NumAttr::Sync).is_some() {
79 return Err(Error::malformed_control(
80 "Enable or Invoke controls cannot be marked with @sync"
81 .to_string(),
82 )
83 .with_pos(&e.attributes));
84 }
85 Ok(())
86 }
87 ir::Control::Invoke(i) => {
88 if s.get_attributes().get(ir::NumAttr::Sync).is_some() {
89 return Err(Error::malformed_control(
90 "Enable or Invoke controls cannot be marked with @sync"
91 .to_string(),
92 )
93 .with_pos(&i.attributes));
94 }
95 Ok(())
96 }
97 ir::Control::Par(_) => Ok(()),
98 ir::Control::Static(_) => Ok(()),
99 }
100}
101
102impl CompileSync {
103 fn build_barriers(
104 &mut self,
105 builder: &mut ir::Builder,
106 s: &mut ir::Control,
107 count: &mut HashMap<u64, u64>,
108 ) {
109 match s {
110 ir::Control::Empty(_) => {
111 if let Some(ref n) = s.get_attributes().get(ir::NumAttr::Sync) {
112 if self.barriers.get(n).is_none() {
113 self.add_shared_structure(builder, n);
114 }
115 let (cells, groups) = &self.barriers[n];
116 let member_idx = count[n];
117
118 let mut new_s =
119 build_member(builder, cells, groups, &member_idx);
120 std::mem::swap(s, &mut new_s);
121 }
122 }
123 ir::Control::Seq(seq) => {
124 for stmt in seq.stmts.iter_mut() {
134 self.build_barriers(builder, stmt, count);
135 }
136 }
137 ir::Control::While(w) => {
138 self.build_barriers(builder, &mut w.body, count);
139 }
140 ir::Control::If(i) => {
141 self.build_barriers(builder, &mut i.tbranch, count);
142 self.build_barriers(builder, &mut i.fbranch, count);
143 }
144 _ => {}
145 }
146 }
147
148 fn add_shared_structure(
149 &mut self,
150 builder: &mut ir::Builder,
151 barrier_idx: &u64,
152 ) {
153 structure!(builder;
154 let barrier = prim std_sync_reg(32);
155 let eq = prim std_eq(32);
156 );
157 let restore = build_restore(builder, &barrier);
158 let wait_restore = build_wait_restore(builder, &eq);
159 let clear_barrier = build_clear_barrier(builder, &barrier);
160 let shared_cells: [RRC<ir::Cell>; 2] = [barrier, eq];
161 let shared_groups: [RRC<ir::Group>; 3] =
162 [wait_restore, restore, clear_barrier];
163 let info = (shared_cells, shared_groups);
164 self.barriers.insert(*barrier_idx, info);
165 }
166}
167
168fn build_incr_barrier(
170 builder: &mut ir::Builder,
171 barrier: &RRC<ir::Cell>,
172 save: &RRC<ir::Cell>,
173 member_idx: &u64,
174) -> RRC<ir::Group> {
175 let group = builder.add_group("incr_barrier");
176 structure!(builder;
177 let incr = prim std_add(32);
178 let cst_1 = constant(1, 1);
179 let cst_2 = constant(1, 32););
180 let read_done_guard = guard!(barrier[format!("read_done_{member_idx}")]);
181 let assigns = build_assignments!(builder;
182 barrier[format!("read_en_{member_idx}")] = ?cst_1["out"];
184 incr["left"] = ? barrier[format!("out_{member_idx}")];
186 incr["right"] = ? cst_2["out"];
188 save["in"] = read_done_guard? incr["out"];
190 save["write_en"] = ? barrier[format!("read_done_{member_idx}")];
192 group["done"] = ?save["done"];
194 );
195
196 group.borrow_mut().assignments.extend(assigns);
197 group
198}
199
200fn build_write_barrier(
202 builder: &mut ir::Builder,
203 barrier: &RRC<ir::Cell>,
204 save: &RRC<ir::Cell>,
205 member_idx: &u64,
206) -> RRC<ir::Group> {
207 let group = builder.add_group("write_barrier");
208 structure!(builder;
209 let cst_1 = constant(1, 1););
210 let assigns = build_assignments!(builder;
211 barrier[format!("write_en_{member_idx}")] = ?cst_1["out"];
213 barrier[format!("in_{member_idx}")] = ?save["out"];
215 group["done"] = ?barrier[format!("write_done_{member_idx}")];
217 );
218 group.borrow_mut().assignments.extend(assigns);
219 group
220}
221
222fn build_wait(builder: &mut ir::Builder, eq: &RRC<ir::Cell>) -> RRC<ir::Group> {
227 let group = builder.add_group("wt");
228 structure!(builder;
229 let wait_reg = prim std_reg(1);
230 let cst_1 = constant(1, 1););
231 let eq_guard = guard!(eq["out"]);
232 let assigns = build_assignments!(builder;
233 wait_reg["in"] = ?eq["out"];
236 wait_reg["write_en"] = eq_guard? cst_1["out"];
238 group["done"] = ?wait_reg["done"];);
240 group.borrow_mut().assignments.extend(assigns);
241 group
242}
243
244fn build_clear_barrier(
246 builder: &mut ir::Builder,
247 barrier: &RRC<ir::Cell>,
248) -> RRC<ir::Group> {
249 let group = builder.add_group("clear_barrier");
250 structure!(builder;
251 let cst_1 = constant(1, 1););
252 let assigns = build_assignments!(builder;
253 barrier["read_en_0"] = ?cst_1["out"];
255 group["done"] = ?barrier["read_done_0"];
257 );
258 group.borrow_mut().assignments.extend(assigns);
259 group
260}
261
262fn build_restore(
264 builder: &mut ir::Builder,
265 barrier: &RRC<ir::Cell>,
266) -> RRC<ir::Group> {
267 let group = builder.add_group("restore");
268 structure!(builder;
269 let cst_1 = constant(1,1);
270 let cst_2 = constant(0, 32););
271 let assigns = build_assignments!(builder;
272 barrier["write_en_0"] = ?cst_1["out"];
274 barrier["in_0"] = ?cst_2["out"];
276 group["done"] = ?barrier["write_done_0"];
278 );
279 group.borrow_mut().assignments.extend(assigns);
280 group
281}
282
283fn build_wait_restore(
287 builder: &mut ir::Builder,
288 eq: &RRC<ir::Cell>,
289) -> RRC<ir::Group> {
290 let group = builder.add_group("wait_restore");
291 structure!(builder;
292 let wait_restore_reg = prim std_reg(1);
293 let cst_1 = constant(1, 1););
294 let eq_guard = !guard!(eq["out"]);
295 let assigns = build_assignments!(builder;
296 wait_restore_reg["in"] = eq_guard? cst_1["out"];
298 wait_restore_reg["write_en"] = eq_guard? cst_1["out"];
300 group["done"] = ?wait_restore_reg["done"];
302 );
303 group.borrow_mut().assignments.extend(assigns);
304 group
305}
306
307fn build_member(
309 builder: &mut ir::Builder,
310 cells: &[RRC<ir::Cell>; 2],
311 groups: &[RRC<ir::Group>; 3],
312 member_idx: &u64,
313) -> ir::Control {
314 let mut stmts: Vec<ir::Control> = Vec::new();
315
316 let barrier = Rc::clone(&cells[0]);
317 let eq = Rc::clone(&cells[1]);
318 let wait_restore = Rc::clone(&groups[0]);
319 let restore = Rc::clone(&groups[1]);
320 let clear_barrier = Rc::clone(&groups[2]);
321
322 structure!(builder;
323 let save = prim std_reg(32););
324 let incr_barrier =
325 build_incr_barrier(builder, &barrier, &save, &(member_idx - 1));
326 let write_barrier =
327 build_write_barrier(builder, &barrier, &save, &(member_idx - 1));
328 let wait = build_wait(builder, &eq);
329
330 stmts.push(ir::Control::enable(incr_barrier));
331 stmts.push(ir::Control::enable(write_barrier));
332 stmts.push(ir::Control::enable(wait));
333 if member_idx == &1 {
334 stmts.push(ir::Control::enable(clear_barrier));
335 stmts.push(ir::Control::enable(restore));
336 } else {
337 stmts.push(ir::Control::enable(wait_restore));
338 }
339 ir::Control::seq(stmts)
340}
341
342impl Visitor for CompileSync {
343 fn finish_par(
344 &mut self,
345 s: &mut ir::Par,
346 comp: &mut ir::Component,
347 sigs: &ir::LibrarySignatures,
348 _comps: &[ir::Component],
349 ) -> VisResult {
350 let mut builder = ir::Builder::new(comp, sigs);
351 let mut barrier_count: HashMap<u64, u64> = HashMap::new();
352 for stmt in s.stmts.iter_mut() {
353 let mut cnt: HashSet<u64> = HashSet::new();
354 count_barriers(stmt, &mut cnt)?;
355 for barrier in cnt {
356 barrier_count
357 .entry(barrier)
358 .and_modify(|count| *count += 1)
359 .or_insert(1);
360 }
361 self.build_barriers(&mut builder, stmt, &mut barrier_count);
362 }
363
364 if self.barriers.is_empty() {
365 return Ok(Action::Continue);
366 }
367
368 let mut init_barriers: Vec<ir::Control> = Vec::new();
369 for (n, (cells, groups)) in self.barriers.iter() {
370 let barrier = Rc::clone(&cells[0]);
371 let eq = Rc::clone(&cells[1]);
372 let restore = Rc::clone(&groups[1]);
373 let n_members = barrier_count.get(n).unwrap();
374 structure!(builder;
375 let num_members = constant(*n_members, 32);
376 );
377 let assigns = build_assignments!(builder;
379 eq["left"] = ?barrier["peek"];
381 eq["right"] = ?num_members["out"];
383 );
384 builder.component.continuous_assignments.extend(assigns);
385
386 init_barriers.push(ir::Control::enable(restore));
387 }
388
389 let mut changed_sequence: Vec<ir::Control> =
391 vec![ir::Control::par(init_barriers)];
392 let mut copied_par_stmts: Vec<ir::Control> = Vec::new();
393 for con in s.stmts.drain(..) {
394 copied_par_stmts.push(con);
395 }
396 changed_sequence.push(ir::Control::par(copied_par_stmts));
397
398 Ok(Action::change(ir::Control::seq(changed_sequence)))
399 }
400}