hugr_model/v0/ast/
resolve.rs

1use bumpalo::{Bump, collections::Vec as BumpVec};
2use itertools::zip_eq;
3use rustc_hash::FxHashMap;
4use thiserror::Error;
5
6use super::{
7    LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
8};
9use crate::v0::{RegionKind, ScopeClosure, table};
10use crate::v0::{
11    scope::{LinkTable, SymbolTable, VarTable},
12    table::{LinkIndex, NodeId, RegionId, TermId, VarId},
13};
14
/// State used while resolving an AST [`Module`] into a [`table::Module`].
///
/// A context is created with [`Context::new`], driven with
/// [`Context::resolve_module`] and finally turned into the resulting table
/// module with [`Context::finish`].
pub struct Context<'a> {
    /// The table module under construction; returned by [`Context::finish`].
    module: table::Module<'a>,
    /// Arena from which the slices and symbols stored in the module are allocated.
    bump: &'a Bump,
    /// Variable scopes; a scope is entered for every node that introduces a symbol.
    vars: VarTable<'a>,
    /// Link scopes; a scope is entered for the module root and every closed region.
    links: LinkTable<&'a str>,
    /// Symbol scopes; a scope is entered for the module root and every region.
    symbols: SymbolTable<'a>,
    /// Implicit import nodes, keyed by the symbol name that could not be resolved.
    imports: FxHashMap<SymbolName, NodeId>,
    /// Cache from terms that were already inserted into the module to their ids.
    terms: FxHashMap<table::Term<'a>, TermId>,
}
24
25impl<'a> Context<'a> {
26    pub fn new(bump: &'a Bump) -> Self {
27        Self {
28            module: table::Module::default(),
29            bump,
30            vars: VarTable::new(),
31            links: LinkTable::new(),
32            symbols: SymbolTable::new(),
33            imports: FxHashMap::default(),
34            terms: FxHashMap::default(),
35        }
36    }
37
38    pub fn resolve_module(&mut self, module: &'a Module) -> BuildResult<()> {
39        self.module.root = self.module.insert_region(table::Region::default());
40        self.symbols.enter(self.module.root);
41        self.links.enter(self.module.root);
42
43        let children = self.resolve_nodes(&module.root.children)?;
44        let meta = self.resolve_terms(&module.root.meta)?;
45
46        let (links, ports) = self.links.exit();
47        self.symbols.exit();
48        let scope = Some(table::RegionScope { links, ports });
49
50        // Symbols that could not be resolved within the module still need to
51        // be represented by a node. This is why we add import nodes.
52        let all_children = {
53            let mut all_children =
54                BumpVec::with_capacity_in(children.len() + self.imports.len(), self.bump);
55            all_children.extend(children);
56            all_children.extend(self.imports.drain().map(|(_, node)| node));
57            all_children.into_bump_slice()
58        };
59
60        self.module.regions[self.module.root.index()] = table::Region {
61            kind: RegionKind::Module,
62            sources: &[],
63            targets: &[],
64            children: all_children,
65            meta,
66            signature: None,
67            scope,
68        };
69
70        Ok(())
71    }
72
73    fn resolve_terms(&mut self, terms: &'a [Term]) -> BuildResult<&'a [TermId]> {
74        try_alloc_slice(self.bump, terms.iter().map(|term| self.resolve_term(term)))
75    }
76
77    fn resolve_term(&mut self, term: &'a Term) -> BuildResult<TermId> {
78        let term = match term {
79            Term::Wildcard => table::Term::Wildcard,
80            Term::Var(var_name) => table::Term::Var(self.resolve_var(var_name)?),
81            Term::Apply(symbol_name, terms) => {
82                let symbol_id = self.resolve_symbol_name(symbol_name);
83                let terms = self.resolve_terms(terms)?;
84                table::Term::Apply(symbol_id, terms)
85            }
86            Term::List(parts) => table::Term::List(self.resolve_seq_parts(parts)?),
87            Term::Literal(literal) => table::Term::Literal(literal.clone()),
88            Term::Tuple(parts) => table::Term::Tuple(self.resolve_seq_parts(parts)?),
89            Term::Func(region) => {
90                let region = self.resolve_region(region, ScopeClosure::Closed)?;
91                table::Term::Func(region)
92            }
93        };
94
95        Ok(*self
96            .terms
97            .entry(term.clone())
98            .or_insert_with(|| self.module.insert_term(term)))
99    }
100
101    fn resolve_seq_parts(&mut self, parts: &'a [SeqPart]) -> BuildResult<&'a [table::SeqPart]> {
102        try_alloc_slice(
103            self.bump,
104            parts.iter().map(|part| self.resolve_seq_part(part)),
105        )
106    }
107
108    fn resolve_seq_part(&mut self, part: &'a SeqPart) -> BuildResult<table::SeqPart> {
109        Ok(match part {
110            SeqPart::Item(term) => table::SeqPart::Item(self.resolve_term(term)?),
111            SeqPart::Splice(term) => table::SeqPart::Splice(self.resolve_term(term)?),
112        })
113    }
114
115    fn resolve_nodes(&mut self, nodes: &'a [Node]) -> BuildResult<&'a [NodeId]> {
116        // Allocate ids for all nodes by introducing placeholders into the module.
117        let ids: &[_] = self.bump.alloc_slice_fill_with(nodes.len(), |_| {
118            self.module.insert_node(table::Node::default())
119        });
120
121        // For those nodes that introduce symbols, we then associate the symbol
122        // with the id of the node. This serves as a form of forward declaration
123        // so that the symbol is visible in the current region regardless of the
124        // order of the nodes.
125        for (id, node) in zip_eq(ids, nodes) {
126            if let Some(symbol_name) = node.operation.symbol_name() {
127                self.symbols
128                    .insert(symbol_name.as_ref(), *id)
129                    .map_err(|_| ResolveError::DuplicateSymbol(symbol_name.clone()))?;
130            }
131        }
132
133        // Finally we can build the actual nodes.
134        for (id, node) in zip_eq(ids, nodes) {
135            self.resolve_node(*id, node)?;
136        }
137
138        Ok(ids)
139    }
140
141    fn resolve_node(&mut self, node_id: NodeId, node: &'a Node) -> BuildResult<()> {
142        let inputs = self.resolve_links(&node.inputs)?;
143        let outputs = self.resolve_links(&node.outputs)?;
144
145        // When the node introduces a symbol it also introduces a new variable scope.
146        if node.operation.symbol().is_some() {
147            self.vars.enter(node_id);
148        }
149
150        let mut scope_closure = ScopeClosure::Open;
151
152        let operation = match &node.operation {
153            Operation::Invalid => table::Operation::Invalid,
154            Operation::Dfg => table::Operation::Dfg,
155            Operation::Cfg => table::Operation::Cfg,
156            Operation::Block => table::Operation::Block,
157            Operation::TailLoop => table::Operation::TailLoop,
158            Operation::Conditional => table::Operation::Conditional,
159            Operation::DefineFunc(symbol) => {
160                let symbol = self.resolve_symbol(symbol)?;
161                scope_closure = ScopeClosure::Closed;
162                table::Operation::DefineFunc(symbol)
163            }
164            Operation::DeclareFunc(symbol) => {
165                let symbol = self.resolve_symbol(symbol)?;
166                table::Operation::DeclareFunc(symbol)
167            }
168            Operation::DefineAlias(symbol, term) => {
169                let symbol = self.resolve_symbol(symbol)?;
170                let term = self.resolve_term(term)?;
171                table::Operation::DefineAlias(symbol, term)
172            }
173            Operation::DeclareAlias(symbol) => {
174                let symbol = self.resolve_symbol(symbol)?;
175                table::Operation::DeclareAlias(symbol)
176            }
177            Operation::DeclareConstructor(symbol) => {
178                let symbol = self.resolve_symbol(symbol)?;
179                table::Operation::DeclareConstructor(symbol)
180            }
181            Operation::DeclareOperation(symbol) => {
182                let symbol = self.resolve_symbol(symbol)?;
183                table::Operation::DeclareOperation(symbol)
184            }
185            Operation::Import(symbol_name) => table::Operation::Import {
186                name: symbol_name.as_ref(),
187            },
188            Operation::Custom(term) => {
189                let term = self.resolve_term(term)?;
190                table::Operation::Custom(term)
191            }
192        };
193
194        let meta = self.resolve_terms(&node.meta)?;
195        let regions = self.resolve_regions(&node.regions, scope_closure)?;
196
197        let signature = match &node.signature {
198            Some(signature) => Some(self.resolve_term(signature)?),
199            None => None,
200        };
201
202        // We need to close the variable scope if we have opened one before.
203        if node.operation.symbol().is_some() {
204            self.vars.exit();
205        }
206
207        self.module.nodes[node_id.index()] = table::Node {
208            operation,
209            inputs,
210            outputs,
211            regions,
212            meta,
213            signature,
214        };
215
216        Ok(())
217    }
218
219    fn resolve_links(&mut self, links: &'a [LinkName]) -> BuildResult<&'a [LinkIndex]> {
220        try_alloc_slice(self.bump, links.iter().map(|link| self.resolve_link(link)))
221    }
222
223    fn resolve_link(&mut self, link: &'a LinkName) -> BuildResult<LinkIndex> {
224        Ok(self.links.use_link(link.as_ref()))
225    }
226
227    fn resolve_regions(
228        &mut self,
229        regions: &'a [Region],
230        scope_closure: ScopeClosure,
231    ) -> BuildResult<&'a [RegionId]> {
232        try_alloc_slice(
233            self.bump,
234            regions
235                .iter()
236                .map(|region| self.resolve_region(region, scope_closure)),
237        )
238    }
239
240    fn resolve_region(
241        &mut self,
242        region: &'a Region,
243        scope_closure: ScopeClosure,
244    ) -> BuildResult<RegionId> {
245        let meta = self.resolve_terms(&region.meta)?;
246        let signature = match &region.signature {
247            Some(signature) => Some(self.resolve_term(signature)?),
248            None => None,
249        };
250
251        // We insert a placeholder for the region in order to allocate a region
252        // id, which we need to track the region's scopes.
253        let region_id = self.module.insert_region(table::Region::default());
254
255        // Each region defines a new scope for symbols.
256        self.symbols.enter(region_id);
257
258        // If the region is closed, it also defines a new scope for links.
259        if ScopeClosure::Closed == scope_closure {
260            self.links.enter(region_id);
261        }
262
263        let sources = self.resolve_links(&region.sources)?;
264        let targets = self.resolve_links(&region.targets)?;
265        let children = self.resolve_nodes(&region.children)?;
266
267        // Close the region's scopes.
268        let scope = match scope_closure {
269            ScopeClosure::Open => None,
270            ScopeClosure::Closed => {
271                let (links, ports) = self.links.exit();
272                Some(table::RegionScope { links, ports })
273            }
274        };
275        self.symbols.exit();
276
277        self.module.regions[region_id.index()] = table::Region {
278            kind: region.kind,
279            sources,
280            targets,
281            children,
282            meta,
283            signature,
284            scope,
285        };
286
287        Ok(region_id)
288    }
289
290    fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> {
291        let name = symbol.name.as_ref();
292        let visibility = &symbol.visibility;
293        let params = self.resolve_params(&symbol.params)?;
294        let constraints = self.resolve_terms(&symbol.constraints)?;
295        let signature = self.resolve_term(&symbol.signature)?;
296
297        Ok(self.bump.alloc(table::Symbol {
298            visibility,
299            name,
300            params,
301            constraints,
302            signature,
303        }))
304    }
305
306    /// Builds symbol parameters.
307    ///
308    /// This incrementally inserts the names of the parameters into the current
309    /// variable scope, so that any parameter is in scope for each of its
310    /// succeeding parameters.
311    fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param<'a>]> {
312        try_alloc_slice(
313            self.bump,
314            params.iter().map(|param| self.resolve_param(param)),
315        )
316    }
317
318    /// Builds a symbol parameter.
319    ///
320    /// This inserts the name of the parameter into the current variable scope,
321    /// making the parameter accessible as a variable.
322    fn resolve_param(&mut self, param: &'a Param) -> BuildResult<table::Param<'a>> {
323        let name = param.name.as_ref();
324        let r#type = self.resolve_term(&param.r#type)?;
325
326        self.vars
327            .insert(param.name.as_ref())
328            .map_err(|_| ResolveError::DuplicateVar(param.name.clone()))?;
329
330        Ok(table::Param { name, r#type })
331    }
332
333    fn resolve_var(&self, var_name: &'a VarName) -> BuildResult<VarId> {
334        self.vars
335            .resolve(var_name.as_ref())
336            .map_err(|_| ResolveError::UnknownVar(var_name.clone()))
337    }
338
339    /// Resolves a symbol name and returns the node that introduces the symbol.
340    ///
341    /// When there is no symbol with this name in scope, we create a new import
342    /// node in the module and record that the symbol has been implicitly
343    /// imported. At the end of the building process, these import nodes are
344    /// inserted into the module's scope.
345    fn resolve_symbol_name(&mut self, symbol_name: &'a SymbolName) -> NodeId {
346        if let Ok(node) = self.symbols.resolve(symbol_name.as_ref()) {
347            return node;
348        }
349
350        *self.imports.entry(symbol_name.clone()).or_insert_with(|| {
351            self.module.insert_node(table::Node {
352                operation: table::Operation::Import {
353                    name: symbol_name.as_ref(),
354                },
355                ..Default::default()
356            })
357        })
358    }
359
360    pub fn finish(self) -> table::Module<'a> {
361        self.module
362    }
363}
364
/// Error that may occur in [`Module::resolve`].
///
/// Each variant carries the offending name and has its own display message.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
#[error("Error resolving model module")]
pub enum ResolveError {
    /// A variable was used that is not defined in any enclosing variable scope.
    #[error("unknown var: {0}")]
    UnknownVar(VarName),
    /// Duplicate variable definition in the same symbol.
    #[error("duplicate var: {0}")]
    DuplicateVar(VarName),
    /// Duplicate symbol definition in the same region.
    #[error("duplicate symbol: {0}")]
    DuplicateSymbol(SymbolName),
}
380
/// Result type used throughout resolution; the error is always a [`ResolveError`].
type BuildResult<T> = Result<T, ResolveError>;
382
383fn try_alloc_slice<T, E>(
384    bump: &Bump,
385    iter: impl IntoIterator<Item = Result<T, E>>,
386) -> Result<&[T], E> {
387    let iter = iter.into_iter();
388    let mut vec = BumpVec::with_capacity_in(iter.size_hint().0, bump);
389    for item in iter {
390        vec.push(item?);
391    }
392    Ok(vec.into_bump_slice())
393}
394
#[cfg(test)]
mod test {
    use crate::v0::ast;
    use bumpalo::Bump;
    use std::str::FromStr as _;

    /// A variable used in the metadata of the module root has no enclosing
    /// symbol that could define it, so resolution must fail.
    #[test]
    fn vars_in_root_scope() {
        let source = "(hugr 0) (mod) (meta ?x)";
        let package = ast::Package::from_str(source).unwrap();
        let arena = Bump::new();
        let resolved = package.resolve(&arena);
        assert!(resolved.is_err());
    }
}