1use bumpalo::{collections::Vec as BumpVec, Bump};
2use fxhash::FxHashMap;
3use itertools::zip_eq;
4use thiserror::Error;
5
6use super::{
7 LinkName, Module, Node, Operation, Param, Region, SeqPart, Symbol, SymbolName, Term, VarName,
8};
9use crate::v0::{
10 scope::{LinkTable, SymbolTable, VarTable},
11 table::{LinkIndex, NodeId, RegionId, TermId, VarId},
12};
13use crate::v0::{table, RegionKind, ScopeClosure};
14
15pub struct Context<'a> {
16 module: table::Module<'a>,
17 bump: &'a Bump,
18 vars: VarTable<'a>,
19 links: LinkTable<&'a str>,
20 symbols: SymbolTable<'a>,
21 imports: FxHashMap<SymbolName, NodeId>,
22}
23
24impl<'a> Context<'a> {
25 pub fn new(bump: &'a Bump) -> Self {
26 Self {
27 module: table::Module::default(),
28 bump,
29 vars: VarTable::new(),
30 links: LinkTable::new(),
31 symbols: SymbolTable::new(),
32 imports: FxHashMap::default(),
33 }
34 }
35
36 pub fn resolve_module(&mut self, module: &'a Module) -> BuildResult<()> {
37 self.module.root = self.module.insert_region(table::Region::default());
38 self.symbols.enter(self.module.root);
39 self.links.enter(self.module.root);
40
41 let children = self.resolve_nodes(&module.root.children)?;
42 let meta = self.resolve_terms(&module.root.meta)?;
43
44 let (links, ports) = self.links.exit();
45 self.symbols.exit();
46 let scope = Some(table::RegionScope { links, ports });
47
48 let all_children = {
51 let mut all_children =
52 BumpVec::with_capacity_in(children.len() + self.imports.len(), self.bump);
53 all_children.extend(children);
54 all_children.extend(self.imports.drain().map(|(_, node)| node));
55 all_children.into_bump_slice()
56 };
57
58 self.module.regions[self.module.root.index()] = table::Region {
59 kind: RegionKind::Module,
60 sources: &[],
61 targets: &[],
62 children: all_children,
63 meta,
64 signature: None,
65 scope,
66 };
67
68 Ok(())
69 }
70
71 fn resolve_terms(&mut self, terms: &'a [Term]) -> BuildResult<&'a [TermId]> {
72 try_alloc_slice(self.bump, terms.iter().map(|term| self.resolve_term(term)))
73 }
74
75 fn resolve_term(&mut self, term: &'a Term) -> BuildResult<TermId> {
76 let term = match term {
77 Term::Wildcard => table::Term::Wildcard,
78 Term::Var(var_name) => table::Term::Var(self.resolve_var(var_name)?),
79 Term::Apply(symbol_name, terms) => {
80 let symbol_id = self.resolve_symbol_name(symbol_name);
81 let terms = self.resolve_terms(terms)?;
82 table::Term::Apply(symbol_id, terms)
83 }
84 Term::List(parts) => table::Term::List(self.resolve_seq_parts(parts)?),
85 Term::Literal(literal) => table::Term::Literal(literal.clone()),
86 Term::Tuple(parts) => table::Term::Tuple(self.resolve_seq_parts(parts)?),
87 Term::Func(region) => {
88 let region = self.resolve_region(region, ScopeClosure::Closed)?;
89 table::Term::ConstFunc(region)
90 }
91 Term::ExtSet => table::Term::ExtSet(&[]),
92 };
93
94 Ok(self.module.insert_term(term))
95 }
96
97 fn resolve_seq_parts(&mut self, parts: &'a [SeqPart]) -> BuildResult<&'a [table::SeqPart]> {
98 try_alloc_slice(
99 self.bump,
100 parts.iter().map(|part| self.resolve_seq_part(part)),
101 )
102 }
103
104 fn resolve_seq_part(&mut self, part: &'a SeqPart) -> BuildResult<table::SeqPart> {
105 Ok(match part {
106 SeqPart::Item(term) => table::SeqPart::Item(self.resolve_term(term)?),
107 SeqPart::Splice(term) => table::SeqPart::Splice(self.resolve_term(term)?),
108 })
109 }
110
111 fn resolve_nodes(&mut self, nodes: &'a [Node]) -> BuildResult<&'a [NodeId]> {
112 let ids: &[_] = self.bump.alloc_slice_fill_with(nodes.len(), |_| {
114 self.module.insert_node(table::Node::default())
115 });
116
117 for (id, node) in zip_eq(ids, nodes) {
122 if let Some(symbol_name) = node.operation.symbol_name() {
123 self.symbols
124 .insert(symbol_name.as_ref(), *id)
125 .map_err(|_| ResolveError::DuplicateSymbol(symbol_name.clone()))?;
126 }
127 }
128
129 for (id, node) in zip_eq(ids, nodes) {
131 self.resolve_node(*id, node)?;
132 }
133
134 Ok(ids)
135 }
136
137 fn resolve_node(&mut self, node_id: NodeId, node: &'a Node) -> BuildResult<()> {
138 let inputs = self.resolve_links(&node.inputs)?;
139 let outputs = self.resolve_links(&node.outputs)?;
140
141 if node.operation.symbol().is_some() {
143 self.vars.enter(node_id);
144 }
145
146 let mut scope_closure = ScopeClosure::Open;
147
148 let operation = match &node.operation {
149 Operation::Invalid => table::Operation::Invalid,
150 Operation::Dfg => table::Operation::Dfg,
151 Operation::Cfg => table::Operation::Cfg,
152 Operation::Block => table::Operation::Block,
153 Operation::TailLoop => table::Operation::TailLoop,
154 Operation::Conditional => table::Operation::Conditional,
155 Operation::DefineFunc(symbol) => {
156 let symbol = self.resolve_symbol(symbol)?;
157 scope_closure = ScopeClosure::Closed;
158 table::Operation::DefineFunc(symbol)
159 }
160 Operation::DeclareFunc(symbol) => {
161 let symbol = self.resolve_symbol(symbol)?;
162 table::Operation::DeclareFunc(symbol)
163 }
164 Operation::DefineAlias(symbol, term) => {
165 let symbol = self.resolve_symbol(symbol)?;
166 let term = self.resolve_term(term)?;
167 table::Operation::DefineAlias(symbol, term)
168 }
169 Operation::DeclareAlias(symbol) => {
170 let symbol = self.resolve_symbol(symbol)?;
171 table::Operation::DeclareAlias(symbol)
172 }
173 Operation::DeclareConstructor(symbol) => {
174 let symbol = self.resolve_symbol(symbol)?;
175 table::Operation::DeclareConstructor(symbol)
176 }
177 Operation::DeclareOperation(symbol) => {
178 let symbol = self.resolve_symbol(symbol)?;
179 table::Operation::DeclareOperation(symbol)
180 }
181 Operation::Import(symbol_name) => table::Operation::Import {
182 name: symbol_name.as_ref(),
183 },
184 Operation::Custom(term) => {
185 let term = self.resolve_term(term)?;
186 table::Operation::Custom(term)
187 }
188 };
189
190 let meta = self.resolve_terms(&node.meta)?;
191 let regions = self.resolve_regions(&node.regions, scope_closure)?;
192
193 let signature = match &node.signature {
194 Some(signature) => Some(self.resolve_term(signature)?),
195 None => None,
196 };
197
198 if node.operation.symbol().is_some() {
200 self.vars.exit();
201 }
202
203 self.module.nodes[node_id.index()] = table::Node {
204 operation,
205 inputs,
206 outputs,
207 regions,
208 meta,
209 signature,
210 };
211
212 Ok(())
213 }
214
215 fn resolve_links(&mut self, links: &'a [LinkName]) -> BuildResult<&'a [LinkIndex]> {
216 try_alloc_slice(self.bump, links.iter().map(|link| self.resolve_link(link)))
217 }
218
219 fn resolve_link(&mut self, link: &'a LinkName) -> BuildResult<LinkIndex> {
220 Ok(self.links.use_link(link.as_ref()))
221 }
222
223 fn resolve_regions(
224 &mut self,
225 regions: &'a [Region],
226 scope_closure: ScopeClosure,
227 ) -> BuildResult<&'a [RegionId]> {
228 try_alloc_slice(
229 self.bump,
230 regions
231 .iter()
232 .map(|region| self.resolve_region(region, scope_closure)),
233 )
234 }
235
236 fn resolve_region(
237 &mut self,
238 region: &'a Region,
239 scope_closure: ScopeClosure,
240 ) -> BuildResult<RegionId> {
241 let meta = self.resolve_terms(®ion.meta)?;
242 let signature = match ®ion.signature {
243 Some(signature) => Some(self.resolve_term(signature)?),
244 None => None,
245 };
246
247 let region_id = self.module.insert_region(table::Region::default());
250
251 self.symbols.enter(region_id);
253
254 if ScopeClosure::Closed == scope_closure {
256 self.links.enter(region_id);
257 }
258
259 let sources = self.resolve_links(®ion.sources)?;
260 let targets = self.resolve_links(®ion.targets)?;
261 let children = self.resolve_nodes(®ion.children)?;
262
263 let scope = match scope_closure {
265 ScopeClosure::Open => None,
266 ScopeClosure::Closed => {
267 let (links, ports) = self.links.exit();
268 Some(table::RegionScope { links, ports })
269 }
270 };
271 self.symbols.exit();
272
273 self.module.regions[region_id.index()] = table::Region {
274 kind: region.kind,
275 sources,
276 targets,
277 children,
278 meta,
279 signature,
280 scope,
281 };
282
283 Ok(region_id)
284 }
285
286 fn resolve_symbol(&mut self, symbol: &'a Symbol) -> BuildResult<&'a table::Symbol<'a>> {
287 let name = symbol.name.as_ref();
288 let params = self.resolve_params(&symbol.params)?;
289 let constraints = self.resolve_terms(&symbol.constraints)?;
290 let signature = self.resolve_term(&symbol.signature)?;
291
292 Ok(self.bump.alloc(table::Symbol {
293 name,
294 params,
295 constraints,
296 signature,
297 }))
298 }
299
300 fn resolve_params(&mut self, params: &'a [Param]) -> BuildResult<&'a [table::Param<'a>]> {
306 try_alloc_slice(
307 self.bump,
308 params.iter().map(|param| self.resolve_param(param)),
309 )
310 }
311
312 fn resolve_param(&mut self, param: &'a Param) -> BuildResult<table::Param<'a>> {
317 let name = param.name.as_ref();
318 let r#type = self.resolve_term(¶m.r#type)?;
319
320 self.vars
321 .insert(param.name.as_ref())
322 .map_err(|_| ResolveError::DuplicateVar(param.name.clone()))?;
323
324 Ok(table::Param { name, r#type })
325 }
326
327 fn resolve_var(&self, var_name: &'a VarName) -> BuildResult<VarId> {
328 self.vars
329 .resolve(var_name.as_ref())
330 .map_err(|_| ResolveError::UnknownVar(var_name.clone()))
331 }
332
333 fn resolve_symbol_name(&mut self, symbol_name: &'a SymbolName) -> NodeId {
340 if let Ok(node) = self.symbols.resolve(symbol_name.as_ref()) {
341 return node;
342 }
343
344 *self.imports.entry(symbol_name.clone()).or_insert_with(|| {
345 self.module.insert_node(table::Node {
346 operation: table::Operation::Import {
347 name: symbol_name.as_ref(),
348 },
349 ..Default::default()
350 })
351 })
352 }
353
354 pub fn finish(self) -> table::Module<'a> {
355 self.module
356 }
357}
358
359#[derive(Debug, Clone, Error)]
361pub enum ResolveError {
362 #[error("unknown var: {0}")]
364 UnknownVar(VarName),
365 #[error("duplicate var: {0}")]
367 DuplicateVar(VarName),
368 #[error("duplicate symbol: {0}")]
370 DuplicateSymbol(SymbolName),
371}
372
373type BuildResult<T> = Result<T, ResolveError>;
374
375fn try_alloc_slice<T, E>(
376 bump: &Bump,
377 iter: impl IntoIterator<Item = Result<T, E>>,
378) -> Result<&[T], E> {
379 let iter = iter.into_iter();
380 let mut vec = BumpVec::with_capacity_in(iter.size_hint().0, bump);
381 for item in iter {
382 vec.push(item?);
383 }
384 Ok(vec.into_bump_slice())
385}