// calyx_opt/passes/discover_external.rs

use crate::traversal::{Action, ConstructVisitor, Named, Visitor};
use calyx_ir as ir;
use calyx_utils::CalyxResult;
use ir::RRC;
use itertools::Itertools;
use linked_hash_map::LinkedHashMap;
use std::collections::{HashMap, HashSet};

/// A pass to detect cells that have been inlined into the top-level component
/// and turn them into real cells marked with [ir::BoolAttr::External].
///
/// Both fields are filled in from the pass options when the pass is
/// constructed.
pub struct DiscoverExternal {
    /// The default value used for parameters that cannot be inferred.
    default: u64,
    /// The suffix to be removed from an inferred prefix before it is used as
    /// the name of the new cell. `None` keeps the prefix unchanged.
    suffix: Option<String>,
}

18impl Named for DiscoverExternal {
19    fn name() -> &'static str {
20        "discover-external"
21    }
22
23    fn description() -> &'static str {
24        "Detect cells that have been inlined into a component's interface and turn them into @external cells"
25    }
26}
27
28impl ConstructVisitor for DiscoverExternal {
29    fn from(ctx: &ir::Context) -> CalyxResult<Self>
30    where
31        Self: Sized,
32    {
33        // Manual parsing because our options are not flags
34        let n = Self::name();
35        let given_opts: HashSet<_> = ctx
36            .extra_opts
37            .iter()
38            .filter_map(|opt| {
39                let mut splits = opt.split(':');
40                if splits.next() == Some(n) {
41                    splits.next()
42                } else {
43                    None
44                }
45            })
46            .collect();
47
48        let mut default = None;
49        let mut suffix = None;
50        for opt in given_opts {
51            let mut splits = opt.split('=');
52            let spl = splits.next();
53            // Search for the "default=<n>" option
54            if spl == Some("default") {
55                let Some(val) = splits.next().and_then(|v| v.parse().ok())
56                else {
57                    log::warn!("Failed to parse default value. Please specify using -x {}:default=<n>", n);
58                    continue;
59                };
60                log::info!("Setting default value to {}", val);
61
62                default = Some(val);
63            }
64            // Search for "strip-suffix=<str>" option
65            else if spl == Some("strip-suffix") {
66                let Some(suff) = splits.next() else {
67                    log::warn!("Failed to parse suffix. Please specify using -x {}:strip-suffix=<str>", n);
68                    continue;
69                };
70                log::info!("Setting suffix to {}", suff);
71
72                suffix = Some(suff.to_string());
73            }
74        }
75
76        Ok(Self {
77            default: default.unwrap_or(32),
78            suffix,
79        })
80    }
81
82    fn clear_data(&mut self) {
83        /* All data is shared */
84    }
85}
86
87impl Visitor for DiscoverExternal {
88    fn start(
89        &mut self,
90        comp: &mut ir::Component,
91        sigs: &ir::LibrarySignatures,
92        _comps: &[ir::Component],
93    ) -> crate::traversal::VisResult {
94        // Ignore non-toplevel components
95        if !comp.attributes.has(ir::BoolAttr::TopLevel) {
96            return Ok(Action::Stop);
97        }
98
99        // Group ports by longest common prefix
100        // NOTE(rachit): This is an awfully inefficient representation. We really
101        // want a TrieMap here.
102        let mut prefix_map: LinkedHashMap<String, HashSet<ir::Id>> =
103            LinkedHashMap::new();
104        for port in comp.signature.borrow().ports() {
105            let name = port.borrow().name;
106            let mut prefix = String::new();
107            // Walk over the port name and add it to the prefix map
108            for c in name.as_ref().chars() {
109                prefix.push(c);
110                if prefix == name.as_ref() {
111                    // We have reached the end of the name
112                    break;
113                }
114                // Remove prefix from name
115                let name = name.as_ref().strip_prefix(&prefix).unwrap();
116                prefix_map
117                    .entry(prefix.clone())
118                    .or_default()
119                    .insert(name.into());
120            }
121        }
122
123        // For all cells in the library, build a set of port names.
124        let mut prim_ports: LinkedHashMap<ir::Id, HashSet<ir::Id>> =
125            LinkedHashMap::new();
126        for prim in sigs.signatures() {
127            let hs = prim
128                .signature
129                .iter()
130                .filter(|p| {
131                    // Ignore clk and reset cells
132                    !p.attributes.has(ir::BoolAttr::Clk)
133                        && !p.attributes.has(ir::BoolAttr::Reset)
134                })
135                .map(|p| p.name())
136                .collect::<HashSet<_>>();
137            prim_ports.insert(prim.name, hs);
138        }
139
140        // For all prefixes, check if there is a primitive that matches the
141        // prefix. If there is, then we have an external cell.
142        let mut pre_to_prim: LinkedHashMap<String, ir::Id> =
143            LinkedHashMap::new();
144        for (prefix, ports) in prefix_map.iter() {
145            for (&prim, prim_ports) in prim_ports.iter() {
146                if prim_ports == ports {
147                    pre_to_prim.insert(prefix.clone(), prim);
148                }
149            }
150        }
151
152        // Collect all ports associated with a specific prefix
153        let mut port_map: LinkedHashMap<String, Vec<RRC<ir::Port>>> =
154            LinkedHashMap::new();
155        'outer: for port in &comp.signature.borrow().ports {
156            // If this matches a prefix, add it to the corresponding port map
157            for pre in pre_to_prim.keys() {
158                if port.borrow().name.as_ref().starts_with(pre) {
159                    port_map.entry(pre.clone()).or_default().push(port.clone());
160                    continue 'outer;
161                }
162            }
163        }
164
165        // Add external cells for all matching prefixes
166        let mut pre_to_cells = LinkedHashMap::new();
167        for (pre, &prim) in &pre_to_prim {
168            log::info!("Prefix {} matches primitive {}", pre, prim);
169            // Attempt to infer the parameters for the external cell
170            let prim_sig = sigs.get_primitive(prim);
171            let ports = &port_map[pre];
172            let mut params: LinkedHashMap<_, Option<u64>> = prim_sig
173                .params
174                .clone()
175                .into_iter()
176                .map(|p| (p, None))
177                .collect();
178
179            // Walk over the abstract port definition and attempt to match the bitwidths
180            for abs in &prim_sig.signature {
181                if let ir::Width::Param { value } = abs.width {
182                    // Find the corresponding port
183                    let port = ports
184                        .iter()
185                        .find(|p| {
186                            p.borrow()
187                                .name
188                                .as_ref()
189                                .ends_with(abs.name().as_ref())
190                        })
191                        .unwrap_or_else(|| {
192                            panic!("No port found for {}", abs.name())
193                        });
194                    // Update the value of the parameter
195                    let v = params.get_mut(&value).unwrap();
196                    if let Some(v) = v {
197                        if *v != port.borrow().width {
198                            log::warn!(
199                                "Mismatched bitwidths for {} in {}, defaulting to {}",
200                                pre,
201                                prim,
202                                self.default
203                            );
204                            *v = self.default;
205                        }
206                    } else {
207                        *v = Some(port.borrow().width);
208                    }
209                }
210            }
211
212            let param_values = params
213                .into_iter()
214                .map(|(_, v)| {
215                    if let Some(v) = v {
216                        v
217                    } else {
218                        log::warn!(
219                            "Unable to infer parameter value for {} in {}, defaulting to {}",
220                            pre,
221                            prim,
222                            self.default
223                        );
224                        self.default
225                    }
226                })
227                .collect_vec();
228
229            let mut builder = ir::Builder::new(comp, sigs);
230            // Remove the suffix from the cell name
231            let name = if let Some(suf) = &self.suffix {
232                pre.strip_suffix(suf).unwrap_or(pre)
233            } else {
234                pre
235            };
236            let cell = builder.add_primitive(name, prim, &param_values);
237            cell.borrow_mut()
238                .attributes
239                .insert(ir::BoolAttr::External, 1);
240            pre_to_cells.insert(pre.clone(), cell);
241        }
242
243        // Rewrite the ports mentioned in the component signature and remove them
244        let mut rewrites: ir::rewriter::PortRewriteMap = HashMap::new();
245        for (pre, ports) in port_map {
246            // let prim = sigs.get_primitive(pre_to_prim[&pre]);
247            let cr = pre_to_cells[&pre].clone();
248            let cell = cr.borrow();
249            let cell_ports = cell.ports();
250            // Iterate over ports with the same names.
251            for pr in ports {
252                let port = pr.borrow();
253                let cp = cell_ports
254                    .iter()
255                    .find(|p| {
256                        port.name.as_ref().ends_with(p.borrow().name.as_ref())
257                    })
258                    .unwrap_or_else(|| {
259                        panic!("No port found for {}", port.name)
260                    });
261                rewrites.insert(port.canonical(), cp.clone());
262            }
263        }
264
265        comp.for_each_assignment(|assign| {
266            assign.for_each_port(|port| {
267                rewrites.get(&port.borrow().canonical()).cloned()
268            })
269        });
270        comp.for_each_static_assignment(|assign| {
271            assign.for_each_port(|port| {
272                rewrites.get(&port.borrow().canonical()).cloned()
273            })
274        });
275
276        // Remove all ports from the signature that match a prefix
277        comp.signature.borrow_mut().ports.retain(|p| {
278            !pre_to_prim
279                .keys()
280                .any(|pre| p.borrow().name.as_ref().starts_with(pre))
281        });
282
283        // Purely structural pass
284        Ok(Action::Stop)
285    }
286}