calyx_opt/passes/sync/
compile_sync.rs

1use crate::traversal::{Action, Named, VisResult, Visitor};
2use calyx_ir::RRC;
3use calyx_ir::{self as ir, GetAttributes};
4use calyx_ir::{build_assignments, guard, structure};
5use calyx_utils::{CalyxResult, Error};
6use linked_hash_map::LinkedHashMap;
7use std::collections::{HashMap, HashSet};
8use std::rc::Rc;
9
10#[derive(Default)]
11/// 1. loop through all control statements under "par" block to find # barriers
12///    needed and # members of each barrier
13/// 2. add all cells and groups needed
14/// 3. loop through all control statements, find the statements with @sync
15///    attribute and replace them with
16///     seq {
17///       <stmt>;
18///       incr_barrier_0_*;
19///       write_barrier_0_*;
20///       wait_*;
21///       restore_*;
22///     }
23///    or
24///     seq {
25///       <stmt>;
26///       incr_barrier_*_*;
27///       write_barrier_*_*;
28///       wait_*;
29///       wait_restore_*;
30///     }
31
32pub struct CompileSync {
33    barriers: BarrierMap,
34}
35
36/// the structure used to store cells and groups shared by one barrier
37type BarrierMap = LinkedHashMap<u64, ([RRC<ir::Cell>; 2], [RRC<ir::Group>; 3])>;
38
39impl Named for CompileSync {
40    fn name() -> &'static str {
41        "compile-sync"
42    }
43
44    fn description() -> &'static str {
45        "Implement barriers for statements marked with @sync attribute"
46    }
47}
48
49/// put into the count set the barrier indices appearing in the thread
50fn count_barriers(
51    s: &ir::Control,
52    count: &mut HashSet<u64>,
53) -> CalyxResult<()> {
54    match s {
55        ir::Control::Empty(_) => {
56            if let Some(n) = s.get_attributes().get(ir::NumAttr::Sync) {
57                count.insert(n);
58            }
59            Ok(())
60        }
61        ir::Control::Seq(seq) => {
62            for stmt in seq.stmts.iter() {
63                count_barriers(stmt, count)?;
64            }
65            Ok(())
66        }
67        ir::Control::While(ir::While { body, .. })
68        | ir::Control::Repeat(ir::Repeat { body, .. }) => {
69            count_barriers(body, count)?;
70            Ok(())
71        }
72        ir::Control::If(i) => {
73            count_barriers(&i.tbranch, count)?;
74            count_barriers(&i.fbranch, count)?;
75            Ok(())
76        }
77        ir::Control::Enable(e) => {
78            if s.get_attributes().get(ir::NumAttr::Sync).is_some() {
79                return Err(Error::malformed_control(
80                    "Enable or Invoke controls cannot be marked with @sync"
81                        .to_string(),
82                )
83                .with_pos(&e.attributes));
84            }
85            Ok(())
86        }
87        ir::Control::Invoke(i) => {
88            if s.get_attributes().get(ir::NumAttr::Sync).is_some() {
89                return Err(Error::malformed_control(
90                    "Enable or Invoke controls cannot be marked with @sync"
91                        .to_string(),
92                )
93                .with_pos(&i.attributes));
94            }
95            Ok(())
96        }
97        ir::Control::Par(_) => Ok(()),
98        ir::Control::Static(_) => Ok(()),
99    }
100}
101
102impl CompileSync {
103    fn build_barriers(
104        &mut self,
105        builder: &mut ir::Builder,
106        s: &mut ir::Control,
107        count: &mut HashMap<u64, u64>,
108    ) {
109        match s {
110            ir::Control::Empty(_) => {
111                if let Some(ref n) = s.get_attributes().get(ir::NumAttr::Sync) {
112                    if self.barriers.get(n).is_none() {
113                        self.add_shared_structure(builder, n);
114                    }
115                    let (cells, groups) = &self.barriers[n];
116                    let member_idx = count[n];
117
118                    let mut new_s =
119                        build_member(builder, cells, groups, &member_idx);
120                    std::mem::swap(s, &mut new_s);
121                }
122            }
123            ir::Control::Seq(seq) => {
124                // go through each control statement
125                // if @sync
126                // see if we already have the set of shared primitives and groups
127                // initialized
128                // True -> generate the inidividual groups and buikld the seq stmt
129                // False -> generate the shared groups, cells
130                //          put the shared groups in the barriermap
131                //          generate the individual groups
132                //          build the seq stmt
133                for stmt in seq.stmts.iter_mut() {
134                    self.build_barriers(builder, stmt, count);
135                }
136            }
137            ir::Control::While(w) => {
138                self.build_barriers(builder, &mut w.body, count);
139            }
140            ir::Control::If(i) => {
141                self.build_barriers(builder, &mut i.tbranch, count);
142                self.build_barriers(builder, &mut i.fbranch, count);
143            }
144            _ => {}
145        }
146    }
147
148    fn add_shared_structure(
149        &mut self,
150        builder: &mut ir::Builder,
151        barrier_idx: &u64,
152    ) {
153        structure!(builder;
154                let barrier = prim std_sync_reg(32);
155                let eq = prim std_eq(32);
156        );
157        let restore = build_restore(builder, &barrier);
158        let wait_restore = build_wait_restore(builder, &eq);
159        let clear_barrier = build_clear_barrier(builder, &barrier);
160        let shared_cells: [RRC<ir::Cell>; 2] = [barrier, eq];
161        let shared_groups: [RRC<ir::Group>; 3] =
162            [wait_restore, restore, clear_barrier];
163        let info = (shared_cells, shared_groups);
164        self.barriers.insert(*barrier_idx, info);
165    }
166}
167
168//put together the group to read and increment the barrier
169fn build_incr_barrier(
170    builder: &mut ir::Builder,
171    barrier: &RRC<ir::Cell>,
172    save: &RRC<ir::Cell>,
173    member_idx: &u64,
174) -> RRC<ir::Group> {
175    let group = builder.add_group("incr_barrier");
176    structure!(builder;
177        let incr = prim std_add(32);
178        let cst_1 = constant(1, 1);
179        let cst_2 = constant(1, 32););
180    let read_done_guard = guard!(barrier[format!("read_done_{member_idx}")]);
181    let assigns = build_assignments!(builder;
182        // barrier_*.read_en_0 = 1'd1;
183        barrier[format!("read_en_{member_idx}")] = ?cst_1["out"];
184        //incr_*_*.left = barrier_*.out_*;
185        incr["left"] = ? barrier[format!("out_{member_idx}")];
186        // incr_*_*.right = 32'd1;
187        incr["right"] = ? cst_2["out"];
188        // save_*_*.in = barrier_*.read_done_*? incr_1.out;
189        save["in"] = read_done_guard? incr["out"];
190        // save_*_*.write_en = barrier_*.read_done_*;
191        save["write_en"] = ? barrier[format!("read_done_{member_idx}")];
192        // incr_barrier_*_*[done] = save_*_*.done;
193        group["done"] = ?save["done"];
194    );
195
196    group.borrow_mut().assignments.extend(assigns);
197    group
198}
199
200// put together the group to write to the barrier after incrementing
201fn build_write_barrier(
202    builder: &mut ir::Builder,
203    barrier: &RRC<ir::Cell>,
204    save: &RRC<ir::Cell>,
205    member_idx: &u64,
206) -> RRC<ir::Group> {
207    let group = builder.add_group("write_barrier");
208    structure!(builder;
209    let cst_1 = constant(1, 1););
210    let assigns = build_assignments!(builder;
211        // barrier_*.write_en_* = 1'd1;
212        barrier[format!("write_en_{member_idx}")] = ?cst_1["out"];
213        // barrier_*.in_* = save_*_*.out;
214        barrier[format!("in_{member_idx}")] = ?save["out"];
215        // write_barrier_*_*[done] = barrier_*.write_done_*;
216        group["done"] = ?barrier[format!("write_done_{member_idx}")];
217    );
218    group.borrow_mut().assignments.extend(assigns);
219    group
220}
221
222// Put together the group to wait until the peek value reaches capacity.
223// We don't actually care about the value being written to the register.
224// We're only using it to make sure that the barrier has reached the expected
225// value.
226fn build_wait(builder: &mut ir::Builder, eq: &RRC<ir::Cell>) -> RRC<ir::Group> {
227    let group = builder.add_group("wt");
228    structure!(builder;
229    let wait_reg = prim std_reg(1);
230    let cst_1 = constant(1, 1););
231    let eq_guard = guard!(eq["out"]);
232    let assigns = build_assignments!(builder;
233        // wait_reg_*.in = eq_*.out;
234        // XXX(rachit): Since the value doesn't matter, can this just be zero?
235        wait_reg["in"] = ?eq["out"];
236        // wait_reg_*.write_en = eq_*.out? 1'd1;
237        wait_reg["write_en"] = eq_guard? cst_1["out"];
238        // wait_*[done] = wait_reg_*.done;
239        group["done"] = ?wait_reg["done"];);
240    group.borrow_mut().assignments.extend(assigns);
241    group
242}
243
244// put together the group to empty out the sync reg before resetting it to 0
245fn build_clear_barrier(
246    builder: &mut ir::Builder,
247    barrier: &RRC<ir::Cell>,
248) -> RRC<ir::Group> {
249    let group = builder.add_group("clear_barrier");
250    structure!(builder;
251    let cst_1 = constant(1, 1););
252    let assigns = build_assignments!(builder;
253    // barrier_*.read_en_0 = 1'd1;
254    barrier["read_en_0"] = ?cst_1["out"];
255    //clear_barrier_*[done] = barrier_1.read_done_0;
256    group["done"] = ?barrier["read_done_0"];
257    );
258    group.borrow_mut().assignments.extend(assigns);
259    group
260}
261
262// put together the group to restore the barrier to 0
263fn build_restore(
264    builder: &mut ir::Builder,
265    barrier: &RRC<ir::Cell>,
266) -> RRC<ir::Group> {
267    let group = builder.add_group("restore");
268    structure!(builder;
269    let cst_1 = constant(1,1);
270    let cst_2 = constant(0, 32););
271    let assigns = build_assignments!(builder;
272        // barrier_*.write_en_0 = 1'd1;
273        barrier["write_en_0"] = ?cst_1["out"];
274        // barrier_*.in_0 = 32'd0;
275        barrier["in_0"] = ?cst_2["out"];
276        // restore_*[done] = barrier_*.write_done_0;
277        group["done"] = ?barrier["write_done_0"];
278    );
279    group.borrow_mut().assignments.extend(assigns);
280    group
281}
282
283// Put together the group to wait for restorer to finish.
284// Like the wait group, we don't care about the value in the register
285// We just want to wait till the value in the barrier is set to 0.
286fn build_wait_restore(
287    builder: &mut ir::Builder,
288    eq: &RRC<ir::Cell>,
289) -> RRC<ir::Group> {
290    let group = builder.add_group("wait_restore");
291    structure!(builder;
292    let wait_restore_reg = prim std_reg(1);
293    let cst_1 = constant(1, 1););
294    let eq_guard = !guard!(eq["out"]);
295    let assigns = build_assignments!(builder;
296    // wait_restore_reg_*.in = !eq_*.out? 1'd1;
297    wait_restore_reg["in"] = eq_guard? cst_1["out"];
298    // wait_restore_reg_*.write_en = !eq_*.out? 1'd1;
299    wait_restore_reg["write_en"] = eq_guard? cst_1["out"];
300    //wait_restore_*[done] = wait_restore_reg_*.done;
301    group["done"] = ?wait_restore_reg["done"];
302    );
303    group.borrow_mut().assignments.extend(assigns);
304    group
305}
306
307// put together the sequence of groups that a barrier member requires
308fn build_member(
309    builder: &mut ir::Builder,
310    cells: &[RRC<ir::Cell>; 2],
311    groups: &[RRC<ir::Group>; 3],
312    member_idx: &u64,
313) -> ir::Control {
314    let mut stmts: Vec<ir::Control> = Vec::new();
315
316    let barrier = Rc::clone(&cells[0]);
317    let eq = Rc::clone(&cells[1]);
318    let wait_restore = Rc::clone(&groups[0]);
319    let restore = Rc::clone(&groups[1]);
320    let clear_barrier = Rc::clone(&groups[2]);
321
322    structure!(builder;
323        let save = prim std_reg(32););
324    let incr_barrier =
325        build_incr_barrier(builder, &barrier, &save, &(member_idx - 1));
326    let write_barrier =
327        build_write_barrier(builder, &barrier, &save, &(member_idx - 1));
328    let wait = build_wait(builder, &eq);
329
330    stmts.push(ir::Control::enable(incr_barrier));
331    stmts.push(ir::Control::enable(write_barrier));
332    stmts.push(ir::Control::enable(wait));
333    if member_idx == &1 {
334        stmts.push(ir::Control::enable(clear_barrier));
335        stmts.push(ir::Control::enable(restore));
336    } else {
337        stmts.push(ir::Control::enable(wait_restore));
338    }
339    ir::Control::seq(stmts)
340}
341
342impl Visitor for CompileSync {
343    fn finish_par(
344        &mut self,
345        s: &mut ir::Par,
346        comp: &mut ir::Component,
347        sigs: &ir::LibrarySignatures,
348        _comps: &[ir::Component],
349    ) -> VisResult {
350        let mut builder = ir::Builder::new(comp, sigs);
351        let mut barrier_count: HashMap<u64, u64> = HashMap::new();
352        for stmt in s.stmts.iter_mut() {
353            let mut cnt: HashSet<u64> = HashSet::new();
354            count_barriers(stmt, &mut cnt)?;
355            for barrier in cnt {
356                barrier_count
357                    .entry(barrier)
358                    .and_modify(|count| *count += 1)
359                    .or_insert(1);
360            }
361            self.build_barriers(&mut builder, stmt, &mut barrier_count);
362        }
363
364        if self.barriers.is_empty() {
365            return Ok(Action::Continue);
366        }
367
368        let mut init_barriers: Vec<ir::Control> = Vec::new();
369        for (n, (cells, groups)) in self.barriers.iter() {
370            let barrier = Rc::clone(&cells[0]);
371            let eq = Rc::clone(&cells[1]);
372            let restore = Rc::clone(&groups[1]);
373            let n_members = barrier_count.get(n).unwrap();
374            structure!(builder;
375                let num_members = constant(*n_members, 32);
376            );
377            // add continuous assignments
378            let assigns = build_assignments!(builder;
379            // eq_*.left = barrier_*.peek;
380            eq["left"] = ?barrier["peek"];
381            // eq_*.right = 32'd* (number of members);
382            eq["right"] = ?num_members["out"];
383            );
384            builder.component.continuous_assignments.extend(assigns);
385
386            init_barriers.push(ir::Control::enable(restore));
387        }
388
389        // wrap the par stmt in a seq stmt so that barriers are initialized
390        let mut changed_sequence: Vec<ir::Control> =
391            vec![ir::Control::par(init_barriers)];
392        let mut copied_par_stmts: Vec<ir::Control> = Vec::new();
393        for con in s.stmts.drain(..) {
394            copied_par_stmts.push(con);
395        }
396        changed_sequence.push(ir::Control::par(copied_par_stmts));
397
398        Ok(Action::change(ir::Control::seq(changed_sequence)))
399    }
400}