1use bumpalo::{Bump, collections::Vec as BumpVec};
2use itertools::zip_eq;
3use rustc_hash::FxHashMap;
4use thiserror::Error;
5
6use super::{
7 LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
8};
9use crate::v0::{RegionKind, ScopeClosure, table};
10use crate::v0::{
11 scope::{LinkTable, SymbolTable, VarTable},
12 table::{LinkIndex, NodeId, RegionId, TermId, VarId},
13};
14
15pub struct Context<'a> {
16 module: table::Module<'a>,
17 bump: &'a Bump,
18 vars: VarTable<'a>,
19 links: LinkTable<&'a str>,
20 symbols: SymbolTable<'a>,
21 imports: FxHashMap<SymbolName, NodeId>,
22 terms: FxHashMap<table::Term<'a>, TermId>,
23}
24
25impl<'a> Context<'a> {
26 pub fn new(bump: &'a Bump) -> Self {
27 Self {
28 module: table::Module::default(),
29 bump,
30 vars: VarTable::new(),
31 links: LinkTable::new(),
32 symbols: SymbolTable::new(),
33 imports: FxHashMap::default(),
34 terms: FxHashMap::default(),
35 }
36 }
37
38 pub fn resolve_module(&mut self, module: &'a Module) -> BuildResult<()> {
39 self.module.root = self.module.insert_region(table::Region::default());
40 self.symbols.enter(self.module.root);
41 self.links.enter(self.module.root);
42
43 let children = self.resolve_nodes(&module.root.children)?;
44 let meta = self.resolve_terms(&module.root.meta)?;
45
46 let (links, ports) = self.links.exit();
47 self.symbols.exit();
48 let scope = Some(table::RegionScope { links, ports });
49
50 let all_children = {
53 let mut all_children =
54 BumpVec::with_capacity_in(children.len() + self.imports.len(), self.bump);
55 all_children.extend(children);
56 all_children.extend(self.imports.drain().map(|(_, node)| node));
57 all_children.into_bump_slice()
58 };
59
60 self.module.regions[self.module.root.index()] = table::Region {
61 kind: RegionKind::Module,
62 sources: &[],
63 targets: &[],
64 children: all_children,
65 meta,
66 signature: None,
67 scope,
68 };
69
70 Ok(())
71 }
72
73 fn resolve_terms(&mut self, terms: &'a [Term]) -> BuildResult<&'a [TermId]> {
74 try_alloc_slice(self.bump, terms.iter().map(|term| self.resolve_term(term)))
75 }
76
77 fn resolve_term(&mut self, term: &'a Term) -> BuildResult<TermId> {
78 let term = match term {
79 Term::Wildcard => table::Term::Wildcard,
80 Term::Var(var_name) => table::Term::Var(self.resolve_var(var_name)?),
81 Term::Apply(symbol_name, terms) => {
82 let symbol_id = self.resolve_symbol_name(symbol_name);
83 let terms = self.resolve_terms(terms)?;
84 table::Term::Apply(symbol_id, terms)
85 }
86 Term::List(parts) => table::Term::List(self.resolve_seq_parts(parts)?),
87 Term::Literal(literal) => table::Term::Literal(literal.clone()),
88 Term::Tuple(parts) => table::Term::Tuple(self.resolve_seq_parts(parts)?),
89 Term::Func(region) => {
90 let region = self.resolve_region(region, ScopeClosure::Closed)?;
91 table::Term::Func(region)
92 }
93 };
94
95 Ok(*self
96 .terms
97 .entry(term.clone())
98 .or_insert_with(|| self.module.insert_term(term)))
99 }
100
101 fn resolve_seq_parts(&mut self, parts: &'a [SeqPart]) -> BuildResult<&'a [table::SeqPart]> {
102 try_alloc_slice(
103 self.bump,
104 parts.iter().map(|part| self.resolve_seq_part(part)),
105 )
106 }
107
108 fn resolve_seq_part(&mut self, part: &'a SeqPart) -> BuildResult<table::SeqPart> {
109 Ok(match part {
110 SeqPart::Item(term) => table::SeqPart::Item(self.resolve_term(term)?),
111 SeqPart::Splice(term) => table::SeqPart::Splice(self.resolve_term(term)?),
112 })
113 }
114
115 fn resolve_nodes(&mut self, nodes: &'a [Node]) -> BuildResult<&'a [NodeId]> {
116 let ids: &[_] = self.bump.alloc_slice_fill_with(nodes.len(), |_| {
118 self.module.insert_node(table::Node::default())
119 });
120
121 for (id, node) in zip_eq(ids, nodes) {
126 if let Some(symbol_name) = node.operation.symbol_name() {
127 self.symbols
128 .insert(symbol_name.as_ref(), *id)
129 .map_err(|_| ResolveError::DuplicateSymbol(symbol_name.clone()))?;
130 }
131 }
132
133 for (id, node) in zip_eq(ids, nodes) {
135 self.resolve_node(*id, node)?;
136 }
137
138 Ok(ids)
139 }
140
141 fn resolve_node(&mut self, node_id: NodeId, node: &'a Node) -> BuildResult<()> {
142 let inputs = self.resolve_links(&node.inputs)?;
143 let outputs = self.resolve_links(&node.outputs)?;
144
145 if node.operation.symbol().is_some() {
147 self.vars.enter(node_id);
148 }
149
150 let mut scope_closure = ScopeClosure::Open;
151
152 let operation = match &node.operation {
153 Operation::Invalid => table::Operation::Invalid,
154 Operation::Dfg => table::Operation::Dfg,
155 Operation::Cfg => table::Operation::Cfg,
156 Operation::Block => table::Operation::Block,
157 Operation::TailLoop => table::Operation::TailLoop,
158 Operation::Conditional => table::Operation::Conditional,
159 Operation::DefineFunc(symbol) => {
160 let symbol = self.resolve_symbol(symbol)?;
161 scope_closure = ScopeClosure::Closed;
162 table::Operation::DefineFunc(symbol)
163 }
164 Operation::DeclareFunc(symbol) => {
165 let symbol = self.resolve_symbol(symbol)?;
166 table::Operation::DeclareFunc(symbol)
167 }
168 Operation::DefineAlias(symbol, term) => {
169 let symbol = self.resolve_symbol(symbol)?;
170 let term = self.resolve_term(term)?;
171 table::Operation::DefineAlias(symbol, term)
172 }
173 Operation::DeclareAlias(symbol) => {
174 let symbol = self.resolve_symbol(symbol)?;
175 table::Operation::DeclareAlias(symbol)
176 }
177 Operation::DeclareConstructor(symbol) => {
178 let symbol = self.resolve_symbol(symbol)?;
179 table::Operation::DeclareConstructor(symbol)
180 }
181 Operation::DeclareOperation(symbol) => {
182 let symbol = self.resolve_symbol(symbol)?;
183 table::Operation::DeclareOperation(symbol)
184 }
185 Operation::Import(symbol_name) => table::Operation::Import {
186 name: symbol_name.as_ref(),
187 },
188 Operation::Custom(term) => {
189 let term = self.resolve_term(term)?;
190 table::Operation::Custom(term)
191 }
192 };
193
194 let meta = self.resolve_terms(&node.meta)?;
195 let regions = self.resolve_regions(&node.regions, scope_closure)?;
196
197 let signature = match &node.signature {
198 Some(signature) => Some(self.resolve_term(signature)?),
199 None => None,
200 };
201
202 if node.operation.symbol().is_some() {
204 self.vars.exit();
205 }
206
207 self.module.nodes[node_id.index()] = table::Node {
208 operation,
209 inputs,
210 outputs,
211 regions,
212 meta,
213 signature,
214 };
215
216 Ok(())
217 }
218
219 fn resolve_links(&mut self, links: &'a [LinkName]) -> BuildResult<&'a [LinkIndex]> {
220 try_alloc_slice(self.bump, links.iter().map(|link| self.resolve_link(link)))
221 }
222
223 fn resolve_link(&mut self, link: &'a LinkName) -> BuildResult<LinkIndex> {
224 Ok(self.links.use_link(link.as_ref()))
225 }
226
227 fn resolve_regions(
228 &mut self,
229 regions: &'a [Region],
230 scope_closure: ScopeClosure,
231 ) -> BuildResult<&'a [RegionId]> {
232 try_alloc_slice(
233 self.bump,
234 regions
235 .iter()
236 .map(|region| self.resolve_region(region, scope_closure)),
237 )
238 }
239
240 fn resolve_region(
241 &mut self,
242 region: &'a Region,
243 scope_closure: ScopeClosure,
244 ) -> BuildResult<RegionId> {
245 let meta = self.resolve_terms(®ion.meta)?;
246 let signature = match ®ion.signature {
247 Some(signature) => Some(self.resolve_term(signature)?),
248 None => None,
249 };
250
251 let region_id = self.module.insert_region(table::Region::default());
254
255 self.symbols.enter(region_id);
257
258 if ScopeClosure::Closed == scope_closure {
260 self.links.enter(region_id);
261 }
262
263 let sources = self.resolve_links(®ion.sources)?;
264 let targets = self.resolve_links(®ion.targets)?;
265 let children = self.resolve_nodes(®ion.children)?;
266
267 let scope = match scope_closure {
269 ScopeClosure::Open => None,
270 ScopeClosure::Closed => {
271 let (links, ports) = self.links.exit();
272 Some(table::RegionScope { links, ports })
273 }
274 };
275 self.symbols.exit();
276
277 self.module.regions[region_id.index()] = table::Region {
278 kind: region.kind,
279 sources,
280 targets,
281 children,
282 meta,
283 signature,
284 scope,
285 };
286
287 Ok(region_id)
288 }
289
290 fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> {
291 let name = symbol.name.as_ref();
292 let visibility = &symbol.visibility;
293 let params = self.resolve_params(&symbol.params)?;
294 let constraints = self.resolve_terms(&symbol.constraints)?;
295 let signature = self.resolve_term(&symbol.signature)?;
296
297 Ok(self.bump.alloc(table::Symbol {
298 visibility,
299 name,
300 params,
301 constraints,
302 signature,
303 }))
304 }
305
306 fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param<'a>]> {
312 try_alloc_slice(
313 self.bump,
314 params.iter().map(|param| self.resolve_param(param)),
315 )
316 }
317
318 fn resolve_param(&mut self, param: &'a Param) -> BuildResult<table::Param<'a>> {
323 let name = param.name.as_ref();
324 let r#type = self.resolve_term(¶m.r#type)?;
325
326 self.vars
327 .insert(param.name.as_ref())
328 .map_err(|_| ResolveError::DuplicateVar(param.name.clone()))?;
329
330 Ok(table::Param { name, r#type })
331 }
332
333 fn resolve_var(&self, var_name: &'a VarName) -> BuildResult<VarId> {
334 self.vars
335 .resolve(var_name.as_ref())
336 .map_err(|_| ResolveError::UnknownVar(var_name.clone()))
337 }
338
339 fn resolve_symbol_name(&mut self, symbol_name: &'a SymbolName) -> NodeId {
346 if let Ok(node) = self.symbols.resolve(symbol_name.as_ref()) {
347 return node;
348 }
349
350 *self.imports.entry(symbol_name.clone()).or_insert_with(|| {
351 self.module.insert_node(table::Node {
352 operation: table::Operation::Import {
353 name: symbol_name.as_ref(),
354 },
355 ..Default::default()
356 })
357 })
358 }
359
360 pub fn finish(self) -> table::Module<'a> {
361 self.module
362 }
363}
364
365#[derive(Debug, Clone, Error)]
367#[non_exhaustive]
368#[error("Error resolving model module")]
369pub enum ResolveError {
370 #[error("unknown var: {0}")]
372 UnknownVar(VarName),
373 #[error("duplicate var: {0}")]
375 DuplicateVar(VarName),
376 #[error("duplicate symbol: {0}")]
378 DuplicateSymbol(SymbolName),
379}
380
381type BuildResult<T> = Result<T, ResolveError>;
382
383fn try_alloc_slice<T, E>(
384 bump: &Bump,
385 iter: impl IntoIterator<Item = Result<T, E>>,
386) -> Result<&[T], E> {
387 let iter = iter.into_iter();
388 let mut vec = BumpVec::with_capacity_in(iter.size_hint().0, bump);
389 for item in iter {
390 vec.push(item?);
391 }
392 Ok(vec.into_bump_slice())
393}
394
#[cfg(test)]
mod test {
    use std::str::FromStr as _;

    use bumpalo::Bump;

    use crate::v0::ast;

    /// The module root opens no variable scope, so a variable referenced in
    /// the root metadata must make resolution fail.
    #[test]
    fn vars_in_root_scope() {
        let source = "(hugr 0) (mod) (meta ?x)";
        let package = ast::Package::from_str(source).unwrap();
        let arena = Bump::new();
        let outcome = package.resolve(&arena);
        assert!(outcome.is_err());
    }
}