hugr_model/v0/ast/
resolve.rs

1use bumpalo::{collections::Vec as BumpVec, Bump};
2use fxhash::FxHashMap;
3use itertools::zip_eq;
4use thiserror::Error;
5
6use super::{
7    LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
8};
9use crate::v0::{
10    scope::{LinkTable, SymbolTable, VarTable},
11    table::{LinkIndex, NodeId, RegionId, TermId, VarId},
12};
13use crate::v0::{table, RegionKind, ScopeClosure};
14
15pub struct Context<'a> {
16    module: table::Module<'a>,
17    bump: &'a Bump,
18    vars: VarTable<'a>,
19    links: LinkTable<&'a str>,
20    symbols: SymbolTable<'a>,
21    imports: FxHashMap<SymbolName, NodeId>,
22}
23
24impl<'a> Context<'a> {
25    pub fn new(bump: &'a Bump) -> Self {
26        Self {
27            module: table::Module::default(),
28            bump,
29            vars: VarTable::new(),
30            links: LinkTable::new(),
31            symbols: SymbolTable::new(),
32            imports: FxHashMap::default(),
33        }
34    }
35
36    pub fn resolve_module(&mut self, module: &'a Module) -> BuildResult<()> {
37        self.module.root = self.module.insert_region(table::Region::default());
38        self.symbols.enter(self.module.root);
39        self.links.enter(self.module.root);
40
41        let children = self.resolve_nodes(&module.root.children)?;
42        let meta = self.resolve_terms(&module.root.meta)?;
43
44        let (links, ports) = self.links.exit();
45        self.symbols.exit();
46        let scope = Some(table::RegionScope { links, ports });
47
48        // Symbols that could not be resolved within the module still need to
49        // be represented by a node. This is why we add import nodes.
50        let all_children = {
51            let mut all_children =
52                BumpVec::with_capacity_in(children.len() + self.imports.len(), self.bump);
53            all_children.extend(children);
54            all_children.extend(self.imports.drain().map(|(_, node)| node));
55            all_children.into_bump_slice()
56        };
57
58        self.module.regions[self.module.root.index()] = table::Region {
59            kind: RegionKind::Module,
60            sources: &[],
61            targets: &[],
62            children: all_children,
63            meta,
64            signature: None,
65            scope,
66        };
67
68        Ok(())
69    }
70
71    fn resolve_terms(&mut self, terms: &'a [Term]) -> BuildResult<&'a [TermId]> {
72        try_alloc_slice(self.bump, terms.iter().map(|term| self.resolve_term(term)))
73    }
74
75    fn resolve_term(&mut self, term: &'a Term) -> BuildResult<TermId> {
76        let term = match term {
77            Term::Wildcard => table::Term::Wildcard,
78            Term::Var(var_name) => table::Term::Var(self.resolve_var(var_name)?),
79            Term::Apply(symbol_name, terms) => {
80                let symbol_id = self.resolve_symbol_name(symbol_name);
81                let terms = self.resolve_terms(terms)?;
82                table::Term::Apply(symbol_id, terms)
83            }
84            Term::List(parts) => table::Term::List(self.resolve_seq_parts(parts)?),
85            Term::Literal(literal) => table::Term::Literal(literal.clone()),
86            Term::Tuple(parts) => table::Term::Tuple(self.resolve_seq_parts(parts)?),
87            Term::Func(region) => {
88                let region = self.resolve_region(region, ScopeClosure::Closed)?;
89                table::Term::ConstFunc(region)
90            }
91            Term::ExtSet => table::Term::ExtSet(&[]),
92        };
93
94        Ok(self.module.insert_term(term))
95    }
96
97    fn resolve_seq_parts(&mut self, parts: &'a [SeqPart]) -> BuildResult<&'a [table::SeqPart]> {
98        try_alloc_slice(
99            self.bump,
100            parts.iter().map(|part| self.resolve_seq_part(part)),
101        )
102    }
103
104    fn resolve_seq_part(&mut self, part: &'a SeqPart) -> BuildResult<table::SeqPart> {
105        Ok(match part {
106            SeqPart::Item(term) => table::SeqPart::Item(self.resolve_term(term)?),
107            SeqPart::Splice(term) => table::SeqPart::Splice(self.resolve_term(term)?),
108        })
109    }
110
111    fn resolve_nodes(&mut self, nodes: &'a [Node]) -> BuildResult<&'a [NodeId]> {
112        // Allocate ids for all nodes by introducing placeholders into the module.
113        let ids: &[_] = self.bump.alloc_slice_fill_with(nodes.len(), |_| {
114            self.module.insert_node(table::Node::default())
115        });
116
117        // For those nodes that introduce symbols, we then associate the symbol
118        // with the id of the node. This serves as a form of forward declaration
119        // so that the symbol is visible in the current region regardless of the
120        // order of the nodes.
121        for (id, node) in zip_eq(ids, nodes) {
122            if let Some(symbol_name) = node.operation.symbol_name() {
123                self.symbols
124                    .insert(symbol_name.as_ref(), *id)
125                    .map_err(|_| ResolveError::DuplicateSymbol(symbol_name.clone()))?;
126            }
127        }
128
129        // Finally we can build the actual nodes.
130        for (id, node) in zip_eq(ids, nodes) {
131            self.resolve_node(*id, node)?;
132        }
133
134        Ok(ids)
135    }
136
137    fn resolve_node(&mut self, node_id: NodeId, node: &'a Node) -> BuildResult<()> {
138        let inputs = self.resolve_links(&node.inputs)?;
139        let outputs = self.resolve_links(&node.outputs)?;
140
141        // When the node introduces a symbol it also introduces a new variable scope.
142        if node.operation.symbol().is_some() {
143            self.vars.enter(node_id);
144        }
145
146        let mut scope_closure = ScopeClosure::Open;
147
148        let operation = match &node.operation {
149            Operation::Invalid => table::Operation::Invalid,
150            Operation::Dfg => table::Operation::Dfg,
151            Operation::Cfg => table::Operation::Cfg,
152            Operation::Block => table::Operation::Block,
153            Operation::TailLoop => table::Operation::TailLoop,
154            Operation::Conditional => table::Operation::Conditional,
155            Operation::DefineFunc(symbol) => {
156                let symbol = self.resolve_symbol(symbol)?;
157                scope_closure = ScopeClosure::Closed;
158                table::Operation::DefineFunc(symbol)
159            }
160            Operation::DeclareFunc(symbol) => {
161                let symbol = self.resolve_symbol(symbol)?;
162                table::Operation::DeclareFunc(symbol)
163            }
164            Operation::DefineAlias(symbol, term) => {
165                let symbol = self.resolve_symbol(symbol)?;
166                let term = self.resolve_term(term)?;
167                table::Operation::DefineAlias(symbol, term)
168            }
169            Operation::DeclareAlias(symbol) => {
170                let symbol = self.resolve_symbol(symbol)?;
171                table::Operation::DeclareAlias(symbol)
172            }
173            Operation::DeclareConstructor(symbol) => {
174                let symbol = self.resolve_symbol(symbol)?;
175                table::Operation::DeclareConstructor(symbol)
176            }
177            Operation::DeclareOperation(symbol) => {
178                let symbol = self.resolve_symbol(symbol)?;
179                table::Operation::DeclareOperation(symbol)
180            }
181            Operation::Import(symbol_name) => table::Operation::Import {
182                name: symbol_name.as_ref(),
183            },
184            Operation::Custom(term) => {
185                let term = self.resolve_term(term)?;
186                table::Operation::Custom(term)
187            }
188        };
189
190        let meta = self.resolve_terms(&node.meta)?;
191        let regions = self.resolve_regions(&node.regions, scope_closure)?;
192
193        let signature = match &node.signature {
194            Some(signature) => Some(self.resolve_term(signature)?),
195            None => None,
196        };
197
198        // We need to close the variable scope if we have opened one before.
199        if node.operation.symbol().is_some() {
200            self.vars.exit();
201        }
202
203        self.module.nodes[node_id.index()] = table::Node {
204            operation,
205            inputs,
206            outputs,
207            regions,
208            meta,
209            signature,
210        };
211
212        Ok(())
213    }
214
215    fn resolve_links(&mut self, links: &'a [LinkName]) -> BuildResult<&'a [LinkIndex]> {
216        try_alloc_slice(self.bump, links.iter().map(|link| self.resolve_link(link)))
217    }
218
219    fn resolve_link(&mut self, link: &'a LinkName) -> BuildResult<LinkIndex> {
220        Ok(self.links.use_link(link.as_ref()))
221    }
222
223    fn resolve_regions(
224        &mut self,
225        regions: &'a [Region],
226        scope_closure: ScopeClosure,
227    ) -> BuildResult<&'a [RegionId]> {
228        try_alloc_slice(
229            self.bump,
230            regions
231                .iter()
232                .map(|region| self.resolve_region(region, scope_closure)),
233        )
234    }
235
236    fn resolve_region(
237        &mut self,
238        region: &'a Region,
239        scope_closure: ScopeClosure,
240    ) -> BuildResult<RegionId> {
241        let meta = self.resolve_terms(&region.meta)?;
242        let signature = match &region.signature {
243            Some(signature) => Some(self.resolve_term(signature)?),
244            None => None,
245        };
246
247        // We insert a placeholder for the region in order to allocate a region
248        // id, which we need to track the region's scopes.
249        let region_id = self.module.insert_region(table::Region::default());
250
251        // Each region defines a new scope for symbols.
252        self.symbols.enter(region_id);
253
254        // If the region is closed, it also defines a new scope for links.
255        if ScopeClosure::Closed == scope_closure {
256            self.links.enter(region_id);
257        }
258
259        let sources = self.resolve_links(&region.sources)?;
260        let targets = self.resolve_links(&region.targets)?;
261        let children = self.resolve_nodes(&region.children)?;
262
263        // Close the region's scopes.
264        let scope = match scope_closure {
265            ScopeClosure::Open => None,
266            ScopeClosure::Closed => {
267                let (links, ports) = self.links.exit();
268                Some(table::RegionScope { links, ports })
269            }
270        };
271        self.symbols.exit();
272
273        self.module.regions[region_id.index()] = table::Region {
274            kind: region.kind,
275            sources,
276            targets,
277            children,
278            meta,
279            signature,
280            scope,
281        };
282
283        Ok(region_id)
284    }
285
286    fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> {
287        let name = symbol.name.as_ref();
288        let params = self.resolve_params(&symbol.params)?;
289        let constraints = self.resolve_terms(&symbol.constraints)?;
290        let signature = self.resolve_term(&symbol.signature)?;
291
292        Ok(self.bump.alloc(table::Symbol {
293            name,
294            params,
295            constraints,
296            signature,
297        }))
298    }
299
300    /// Builds symbol parameters.
301    ///
302    /// This incrementally inserts the names of the parameters into the current
303    /// variable scope, so that any parameter is in scope for each of its
304    /// succeeding parameters.
305    fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param<'a>]> {
306        try_alloc_slice(
307            self.bump,
308            params.iter().map(|param| self.resolve_param(param)),
309        )
310    }
311
312    /// Builds a symbol parameter.
313    ///
314    /// This inserts the name of the parameter into the current variable scope,
315    /// making the parameter accessible as a variable.
316    fn resolve_param(&mut self, param: &'a Param) -> BuildResult<table::Param<'a>> {
317        let name = param.name.as_ref();
318        let r#type = self.resolve_term(&param.r#type)?;
319
320        self.vars
321            .insert(param.name.as_ref())
322            .map_err(|_| ResolveError::DuplicateVar(param.name.clone()))?;
323
324        Ok(table::Param { name, r#type })
325    }
326
327    fn resolve_var(&self, var_name: &'a VarName) -> BuildResult<VarId> {
328        self.vars
329            .resolve(var_name.as_ref())
330            .map_err(|_| ResolveError::UnknownVar(var_name.clone()))
331    }
332
333    /// Resolves a symbol name and returns the node that introduces the symbol.
334    ///
335    /// When there is no symbol with this name in scope, we create a new import
336    /// node in the module and record that the symbol has been implicitly
337    /// imported. At the end of the building process, these import nodes are
338    /// inserted into the module's scope.
339    fn resolve_symbol_name(&mut self, symbol_name: &'a SymbolName) -> NodeId {
340        if let Ok(node) = self.symbols.resolve(symbol_name.as_ref()) {
341            return node;
342        }
343
344        *self.imports.entry(symbol_name.clone()).or_insert_with(|| {
345            self.module.insert_node(table::Node {
346                operation: table::Operation::Import {
347                    name: symbol_name.as_ref(),
348                },
349                ..Default::default()
350            })
351        })
352    }
353
354    pub fn finish(self) -> table::Module<'a> {
355        self.module
356    }
357}
358
359/// Error that may occur in [`Module::resolve`].
360#[derive(Debug, Clone, Error)]
361pub enum ResolveError {
362    /// Unknown variable.
363    #[error("unknown var: {0}")]
364    UnknownVar(VarName),
365    /// Duplicate variable definition in the same symbol.
366    #[error("duplicate var: {0}")]
367    DuplicateVar(VarName),
368    /// Duplicate symbol definition in the same region.
369    #[error("duplicate symbol: {0}")]
370    DuplicateSymbol(SymbolName),
371}
372
373type BuildResult<T> = Result<T, ResolveError>;
374
375fn try_alloc_slice<T, E>(
376    bump: &Bump,
377    iter: impl IntoIterator<Item = Result<T, E>>,
378) -> Result<&[T], E> {
379    let iter = iter.into_iter();
380    let mut vec = BumpVec::with_capacity_in(iter.size_hint().0, bump);
381    for item in iter {
382        vec.push(item?);
383    }
384    Ok(vec.into_bump_slice())
385}