1use bumpalo::{Bump, collections::Vec as BumpVec};
2use fxhash::FxHashMap;
3use itertools::zip_eq;
4use thiserror::Error;
5
6use super::{
7 LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
8};
9use crate::v0::{RegionKind, ScopeClosure, table};
10use crate::v0::{
11 scope::{LinkTable, SymbolTable, VarTable},
12 table::{LinkIndex, NodeId, RegionId, TermId, VarId},
13};
14
15pub struct Context<'a> {
16 module: table::Module<'a>,
17 bump: &'a Bump,
18 vars: VarTable<'a>,
19 links: LinkTable<&'a str>,
20 symbols: SymbolTable<'a>,
21 imports: FxHashMap<SymbolName, NodeId>,
22 terms: FxHashMap<table::Term<'a>, TermId>,
23}
24
25impl<'a> Context<'a> {
26 pub fn new(bump: &'a Bump) -> Self {
27 Self {
28 module: table::Module::default(),
29 bump,
30 vars: VarTable::new(),
31 links: LinkTable::new(),
32 symbols: SymbolTable::new(),
33 imports: FxHashMap::default(),
34 terms: FxHashMap::default(),
35 }
36 }
37
38 pub fn resolve_module(&mut self, module: &'a Module) -> BuildResult<()> {
39 self.module.root = self.module.insert_region(table::Region::default());
40 self.symbols.enter(self.module.root);
41 self.links.enter(self.module.root);
42
43 let children = self.resolve_nodes(&module.root.children)?;
44 let meta = self.resolve_terms(&module.root.meta)?;
45
46 let (links, ports) = self.links.exit();
47 self.symbols.exit();
48 let scope = Some(table::RegionScope { links, ports });
49
50 let all_children = {
53 let mut all_children =
54 BumpVec::with_capacity_in(children.len() + self.imports.len(), self.bump);
55 all_children.extend(children);
56 all_children.extend(self.imports.drain().map(|(_, node)| node));
57 all_children.into_bump_slice()
58 };
59
60 self.module.regions[self.module.root.index()] = table::Region {
61 kind: RegionKind::Module,
62 sources: &[],
63 targets: &[],
64 children: all_children,
65 meta,
66 signature: None,
67 scope,
68 };
69
70 Ok(())
71 }
72
73 fn resolve_terms(&mut self, terms: &'a [Term]) -> BuildResult<&'a [TermId]> {
74 try_alloc_slice(self.bump, terms.iter().map(|term| self.resolve_term(term)))
75 }
76
77 fn resolve_term(&mut self, term: &'a Term) -> BuildResult<TermId> {
78 let term = match term {
79 Term::Wildcard => table::Term::Wildcard,
80 Term::Var(var_name) => table::Term::Var(self.resolve_var(var_name)?),
81 Term::Apply(symbol_name, terms) => {
82 let symbol_id = self.resolve_symbol_name(symbol_name);
83 let terms = self.resolve_terms(terms)?;
84 table::Term::Apply(symbol_id, terms)
85 }
86 Term::List(parts) => table::Term::List(self.resolve_seq_parts(parts)?),
87 Term::Literal(literal) => table::Term::Literal(literal.clone()),
88 Term::Tuple(parts) => table::Term::Tuple(self.resolve_seq_parts(parts)?),
89 Term::Func(region) => {
90 let region = self.resolve_region(region, ScopeClosure::Closed)?;
91 table::Term::Func(region)
92 }
93 };
94
95 Ok(*self
96 .terms
97 .entry(term.clone())
98 .or_insert_with(|| self.module.insert_term(term)))
99 }
100
101 fn resolve_seq_parts(&mut self, parts: &'a [SeqPart]) -> BuildResult<&'a [table::SeqPart]> {
102 try_alloc_slice(
103 self.bump,
104 parts.iter().map(|part| self.resolve_seq_part(part)),
105 )
106 }
107
108 fn resolve_seq_part(&mut self, part: &'a SeqPart) -> BuildResult<table::SeqPart> {
109 Ok(match part {
110 SeqPart::Item(term) => table::SeqPart::Item(self.resolve_term(term)?),
111 SeqPart::Splice(term) => table::SeqPart::Splice(self.resolve_term(term)?),
112 })
113 }
114
115 fn resolve_nodes(&mut self, nodes: &'a [Node]) -> BuildResult<&'a [NodeId]> {
116 let ids: &[_] = self.bump.alloc_slice_fill_with(nodes.len(), |_| {
118 self.module.insert_node(table::Node::default())
119 });
120
121 for (id, node) in zip_eq(ids, nodes) {
126 if let Some(symbol_name) = node.operation.symbol_name() {
127 self.symbols
128 .insert(symbol_name.as_ref(), *id)
129 .map_err(|_| ResolveError::DuplicateSymbol(symbol_name.clone()))?;
130 }
131 }
132
133 for (id, node) in zip_eq(ids, nodes) {
135 self.resolve_node(*id, node)?;
136 }
137
138 Ok(ids)
139 }
140
141 fn resolve_node(&mut self, node_id: NodeId, node: &'a Node) -> BuildResult<()> {
142 let inputs = self.resolve_links(&node.inputs)?;
143 let outputs = self.resolve_links(&node.outputs)?;
144
145 if node.operation.symbol().is_some() {
147 self.vars.enter(node_id);
148 }
149
150 let mut scope_closure = ScopeClosure::Open;
151
152 let operation = match &node.operation {
153 Operation::Invalid => table::Operation::Invalid,
154 Operation::Dfg => table::Operation::Dfg,
155 Operation::Cfg => table::Operation::Cfg,
156 Operation::Block => table::Operation::Block,
157 Operation::TailLoop => table::Operation::TailLoop,
158 Operation::Conditional => table::Operation::Conditional,
159 Operation::DefineFunc(symbol) => {
160 let symbol = self.resolve_symbol(symbol)?;
161 scope_closure = ScopeClosure::Closed;
162 table::Operation::DefineFunc(symbol)
163 }
164 Operation::DeclareFunc(symbol) => {
165 let symbol = self.resolve_symbol(symbol)?;
166 table::Operation::DeclareFunc(symbol)
167 }
168 Operation::DefineAlias(symbol, term) => {
169 let symbol = self.resolve_symbol(symbol)?;
170 let term = self.resolve_term(term)?;
171 table::Operation::DefineAlias(symbol, term)
172 }
173 Operation::DeclareAlias(symbol) => {
174 let symbol = self.resolve_symbol(symbol)?;
175 table::Operation::DeclareAlias(symbol)
176 }
177 Operation::DeclareConstructor(symbol) => {
178 let symbol = self.resolve_symbol(symbol)?;
179 table::Operation::DeclareConstructor(symbol)
180 }
181 Operation::DeclareOperation(symbol) => {
182 let symbol = self.resolve_symbol(symbol)?;
183 table::Operation::DeclareOperation(symbol)
184 }
185 Operation::Import(symbol_name) => table::Operation::Import {
186 name: symbol_name.as_ref(),
187 },
188 Operation::Custom(term) => {
189 let term = self.resolve_term(term)?;
190 table::Operation::Custom(term)
191 }
192 };
193
194 let meta = self.resolve_terms(&node.meta)?;
195 let regions = self.resolve_regions(&node.regions, scope_closure)?;
196
197 let signature = match &node.signature {
198 Some(signature) => Some(self.resolve_term(signature)?),
199 None => None,
200 };
201
202 if node.operation.symbol().is_some() {
204 self.vars.exit();
205 }
206
207 self.module.nodes[node_id.index()] = table::Node {
208 operation,
209 inputs,
210 outputs,
211 regions,
212 meta,
213 signature,
214 };
215
216 Ok(())
217 }
218
219 fn resolve_links(&mut self, links: &'a [LinkName]) -> BuildResult<&'a [LinkIndex]> {
220 try_alloc_slice(self.bump, links.iter().map(|link| self.resolve_link(link)))
221 }
222
223 fn resolve_link(&mut self, link: &'a LinkName) -> BuildResult<LinkIndex> {
224 Ok(self.links.use_link(link.as_ref()))
225 }
226
227 fn resolve_regions(
228 &mut self,
229 regions: &'a [Region],
230 scope_closure: ScopeClosure,
231 ) -> BuildResult<&'a [RegionId]> {
232 try_alloc_slice(
233 self.bump,
234 regions
235 .iter()
236 .map(|region| self.resolve_region(region, scope_closure)),
237 )
238 }
239
240 fn resolve_region(
241 &mut self,
242 region: &'a Region,
243 scope_closure: ScopeClosure,
244 ) -> BuildResult<RegionId> {
245 let meta = self.resolve_terms(®ion.meta)?;
246 let signature = match ®ion.signature {
247 Some(signature) => Some(self.resolve_term(signature)?),
248 None => None,
249 };
250
251 let region_id = self.module.insert_region(table::Region::default());
254
255 self.symbols.enter(region_id);
257
258 if ScopeClosure::Closed == scope_closure {
260 self.links.enter(region_id);
261 }
262
263 let sources = self.resolve_links(®ion.sources)?;
264 let targets = self.resolve_links(®ion.targets)?;
265 let children = self.resolve_nodes(®ion.children)?;
266
267 let scope = match scope_closure {
269 ScopeClosure::Open => None,
270 ScopeClosure::Closed => {
271 let (links, ports) = self.links.exit();
272 Some(table::RegionScope { links, ports })
273 }
274 };
275 self.symbols.exit();
276
277 self.module.regions[region_id.index()] = table::Region {
278 kind: region.kind,
279 sources,
280 targets,
281 children,
282 meta,
283 signature,
284 scope,
285 };
286
287 Ok(region_id)
288 }
289
290 fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> {
291 let name = symbol.name.as_ref();
292 let params = self.resolve_params(&symbol.params)?;
293 let constraints = self.resolve_terms(&symbol.constraints)?;
294 let signature = self.resolve_term(&symbol.signature)?;
295
296 Ok(self.bump.alloc(table::Symbol {
297 name,
298 params,
299 constraints,
300 signature,
301 }))
302 }
303
304 fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param<'a>]> {
310 try_alloc_slice(
311 self.bump,
312 params.iter().map(|param| self.resolve_param(param)),
313 )
314 }
315
316 fn resolve_param(&mut self, param: &'a Param) -> BuildResult<table::Param<'a>> {
321 let name = param.name.as_ref();
322 let r#type = self.resolve_term(¶m.r#type)?;
323
324 self.vars
325 .insert(param.name.as_ref())
326 .map_err(|_| ResolveError::DuplicateVar(param.name.clone()))?;
327
328 Ok(table::Param { name, r#type })
329 }
330
331 fn resolve_var(&self, var_name: &'a VarName) -> BuildResult<VarId> {
332 self.vars
333 .resolve(var_name.as_ref())
334 .map_err(|_| ResolveError::UnknownVar(var_name.clone()))
335 }
336
337 fn resolve_symbol_name(&mut self, symbol_name: &'a SymbolName) -> NodeId {
344 if let Ok(node) = self.symbols.resolve(symbol_name.as_ref()) {
345 return node;
346 }
347
348 *self.imports.entry(symbol_name.clone()).or_insert_with(|| {
349 self.module.insert_node(table::Node {
350 operation: table::Operation::Import {
351 name: symbol_name.as_ref(),
352 },
353 ..Default::default()
354 })
355 })
356 }
357
358 pub fn finish(self) -> table::Module<'a> {
359 self.module
360 }
361}
362
363#[derive(Debug, Clone, Error)]
365#[non_exhaustive]
366pub enum ResolveError {
367 #[error("unknown var: {0}")]
369 UnknownVar(VarName),
370 #[error("duplicate var: {0}")]
372 DuplicateVar(VarName),
373 #[error("duplicate symbol: {0}")]
375 DuplicateSymbol(SymbolName),
376}
377
378type BuildResult<T> = Result<T, ResolveError>;
379
380fn try_alloc_slice<T, E>(
381 bump: &Bump,
382 iter: impl IntoIterator<Item = Result<T, E>>,
383) -> Result<&[T], E> {
384 let iter = iter.into_iter();
385 let mut vec = BumpVec::with_capacity_in(iter.size_hint().0, bump);
386 for item in iter {
387 vec.push(item?);
388 }
389 Ok(vec.into_bump_slice())
390}