1use super::{BindMap, ImportsMap, PackageLoader};
2use crate::{
3 diagnostics::{Diagnostics, DiagnosticsConfig},
4 fun::{
5 parser::ParseBook, Adt, AdtCtr, Book, Definition, HvmDefinition, Name, Pattern, Source, SourceKind, Term,
6 },
7 imp::{self, Expr, MatchArm, Stmt},
8 imports::packages::Packages,
9 maybe_grow,
10};
11use indexmap::{map::Entry, IndexMap};
12use itertools::Itertools;
13
14impl ParseBook {
15 pub fn load_imports(
29 self,
30 mut loader: impl PackageLoader,
31 diag_config: DiagnosticsConfig,
32 ) -> Result<Book, Diagnostics> {
33 let diag = &mut Diagnostics::new(diag_config);
34 let pkgs = &mut Packages::new(self);
35 let mut book = pkgs.load_imports(&mut loader, diag)?;
38
39 book.apply_imports(None, diag, pkgs)?;
41 diag.fatal(())?;
42 eprint!("{}", diag);
43
44 let mut book = book.to_fun()?;
46
47 book.desugar_ctr_use();
49
50 Ok(book)
51 }
52
53 fn apply_imports(
56 &mut self,
57 main_imports: Option<&ImportsMap>,
58 diag: &mut Diagnostics,
59 pkgs: &mut Packages,
60 ) -> Result<(), Diagnostics> {
61 self.load_packages(main_imports, diag, pkgs)?;
62 self.apply_import_binds(main_imports, pkgs);
63 Ok(())
64 }
65
66 fn load_packages(
69 &mut self,
70 main_imports: Option<&ImportsMap>,
71 diag: &mut Diagnostics,
72 pkgs: &mut Packages,
73 ) -> Result<(), Diagnostics> {
74 let sources = self.import_ctx.sources().into_iter().cloned().collect_vec();
75
76 for src in sources {
77 let Some(package) = pkgs.books.swap_remove(&src) else { continue };
78 let mut package = package.into_inner();
79
80 let main_imports = main_imports.unwrap_or(&self.import_ctx.map);
83
84 package.apply_imports(Some(main_imports), diag, pkgs)?;
85
86 package.apply_adts(&src, main_imports);
88 package.apply_defs(&src, main_imports);
89
90 let Book { defs, hvm_defs, adts, .. } = package.to_fun()?;
91
92 for (name, adt) in adts {
95 let adts = pkgs.loaded_adts.entry(src.clone()).or_default();
96 adts.insert(name.clone(), adt.ctrs.keys().cloned().collect_vec());
97 self.add_imported_adt(name, adt, diag);
98 }
99
100 for def in defs.into_values() {
102 self.add_imported_def(def, diag);
103 }
104
105 for def in hvm_defs.into_values() {
107 self.add_imported_hvm_def(def, diag);
108 }
109 }
110 Ok(())
111 }
112
113 fn apply_import_binds(&mut self, main_imports: Option<&ImportsMap>, pkgs: &Packages) {
117 let main_imports = main_imports.unwrap_or(&self.import_ctx.map);
120 let mut local_imports = BindMap::new();
121 let mut adt_imports = BindMap::new();
122
123 'outer: for (bind, src) in self.import_ctx.map.binds.iter().rev() {
125 if self.contains_def(bind) | self.ctrs.contains_key(bind) | self.adts.contains_key(bind) {
126 continue;
130 }
131
132 let nam = if main_imports.contains_source(src) { src.clone() } else { Name::new(format!("__{}", src)) };
133
134 for pkg in self.import_ctx.sources() {
137 if let Some(book) = pkgs.loaded_adts.get(pkg) {
138 if let Some(ctrs) = book.get(&nam) {
139 for ctr in ctrs.iter().rev() {
140 let full_ctr_name = ctr.split("__").nth(1).unwrap_or(ctr.as_ref());
141 let ctr_name = full_ctr_name.strip_prefix(src.as_ref()).unwrap();
142 let bind = Name::new(format!("{}{}", bind, ctr_name));
143 local_imports.insert(bind, ctr.clone());
144 }
145 adt_imports.insert(bind.clone(), nam.clone());
147 continue 'outer;
148 }
149 }
150 }
151
152 local_imports.insert(bind.clone(), nam);
154 }
155
156 for (_, def) in self.local_defs_mut() {
157 def.apply_binds(true, &local_imports);
158 def.apply_type_binds(&adt_imports);
159 }
160 }
161
162 fn apply_adts(&mut self, src: &Name, main_imports: &ImportsMap) {
166 let adts = std::mem::take(&mut self.adts);
167 let mut new_adts = IndexMap::new();
168 let mut adts_map = vec![];
169 let mut ctrs_map = IndexMap::new();
170 let mut new_ctrs = IndexMap::new();
171
172 for (mut name, mut adt) in adts {
175 if adt.source.is_local() {
176 adt.source.kind = SourceKind::Imported;
177 let old_name = name.clone();
178 name = Name::new(format!("{}/{}", src, name));
179
180 let mangle_name = !main_imports.contains_source(&name);
181 let mut mangle_adt_name = mangle_name;
182
183 for (old_nam, ctr) in std::mem::take(&mut adt.ctrs) {
184 let mut ctr_name = Name::new(format!("{}/{}", src, old_nam));
185
186 let mangle_ctr = mangle_name && !main_imports.contains_source(&ctr_name);
187
188 if mangle_ctr {
189 mangle_adt_name = true;
190 ctr_name = Name::new(format!("__{}", ctr_name));
191 }
192
193 let ctr = AdtCtr { name: ctr_name.clone(), ..ctr };
194 new_ctrs.insert(ctr_name.clone(), name.clone());
195 ctrs_map.insert(old_nam, ctr_name.clone());
196 adt.ctrs.insert(ctr_name, ctr);
197 }
198
199 if mangle_adt_name {
200 name = Name::new(format!("__{}", name));
201 }
202
203 adt.name = name.clone();
204 adts_map.push((old_name, name.clone()));
205 }
206
207 new_adts.insert(name.clone(), adt);
208 }
209
210 for (_, adt) in &mut new_adts {
212 for (_, ctr) in &mut adt.ctrs {
213 for (from, to) in &adts_map {
214 ctr.typ.subst_ctr(from, to);
215 }
216 }
217 }
218
219 let adts_map = adts_map.into_iter().collect::<IndexMap<_, _>>();
220 for (_, def) in self.local_defs_mut() {
221 def.apply_binds(true, &ctrs_map);
223
224 def.apply_type_binds(&adts_map);
226 }
227
228 self.adts = new_adts;
229 self.ctrs = new_ctrs;
230 }
231
232 fn apply_defs(&mut self, src: &Name, main_imports: &ImportsMap) {
235 let mut canonical_map: IndexMap<_, _> = IndexMap::new();
236
237 for (_, def) in self.local_defs_mut() {
240 def.canonicalize_name(src, main_imports, &mut canonical_map);
241 }
242
243 for (_, def) in self.local_defs_mut() {
245 def.apply_binds(false, &canonical_map);
246 def.source_mut().kind = SourceKind::Imported;
247 }
248 }
249}
250
251impl ParseBook {
253 pub fn top_level_names(&self) -> impl Iterator<Item = &Name> {
254 let imp_defs = self.imp_defs.keys();
255 let fun_defs = self.fun_defs.keys();
256 let hvm_defs = self.hvm_defs.keys();
257 let adts = self.adts.keys();
258 let ctrs = self.ctrs.keys();
259
260 imp_defs.chain(fun_defs).chain(hvm_defs).chain(adts).chain(ctrs)
261 }
262
263 fn add_imported_adt(&mut self, nam: Name, adt: Adt, diag: &mut Diagnostics) {
264 if self.adts.get(&nam).is_some() {
265 let err = format!("The imported datatype '{nam}' conflicts with the datatype '{nam}'.");
266 diag.add_book_error(err);
267 } else {
268 for ctr in adt.ctrs.keys() {
269 if self.contains_def(ctr) {
270 let err = format!("The imported constructor '{ctr}' conflicts with the definition '{ctr}'.");
271 diag.add_book_error(err);
272 }
273 match self.ctrs.entry(ctr.clone()) {
274 Entry::Vacant(e) => _ = e.insert(nam.clone()),
275 Entry::Occupied(e) => {
276 let ctr = e.key();
277 let err = format!("The imported constructor '{ctr}' conflicts with the constructor '{ctr}'.");
278 diag.add_book_error(err);
279 }
280 }
281 }
282 self.adts.insert(nam, adt);
283 }
284 }
285
286 fn add_imported_def(&mut self, def: Definition, diag: &mut Diagnostics) {
287 if !self.has_def_conflict(&def.name, diag) {
288 self.fun_defs.insert(def.name.clone(), def);
289 }
290 }
291
292 fn add_imported_hvm_def(&mut self, def: HvmDefinition, diag: &mut Diagnostics) {
293 if !self.has_def_conflict(&def.name, diag) {
294 self.hvm_defs.insert(def.name.clone(), def);
295 }
296 }
297
298 fn has_def_conflict(&mut self, name: &Name, diag: &mut Diagnostics) -> bool {
299 if self.contains_def(name) {
300 let err = format!("The imported definition '{name}' conflicts with the definition '{name}'.");
301 diag.add_book_error(err);
302 true
303 } else if self.ctrs.contains_key(name) {
304 let err = format!("The imported definition '{name}' conflicts with the constructor '{name}'.");
305 diag.add_book_error(err);
306 true
307 } else {
308 false
309 }
310 }
311
312 fn local_defs_mut(&mut self) -> impl Iterator<Item = (&Name, &mut dyn Def)> {
313 let fun = self.fun_defs.iter_mut().map(|(nam, def)| (nam, def as &mut dyn Def));
314 let imp = self.imp_defs.iter_mut().map(|(nam, def)| (nam, def as &mut dyn Def));
315 let hvm = self.hvm_defs.iter_mut().map(|(nam, def)| (nam, def as &mut dyn Def));
316 fun.chain(imp).chain(hvm).filter(|(_, def)| def.source().is_local())
317 }
318}
319
320trait Def {
322 fn canonicalize_name(&mut self, src: &Name, main_imports: &ImportsMap, binds: &mut BindMap) {
323 let def_name = self.name_mut();
324 let mut new_name = Name::new(format!("{}/{}", src, def_name));
325
326 if !main_imports.contains_source(&new_name) {
327 new_name = Name::new(format!("__{}", new_name));
328 }
329
330 binds.insert(def_name.clone(), new_name.clone());
331 *def_name = new_name;
332 }
333
334 fn apply_binds(&mut self, maybe_constructor: bool, binds: &BindMap);
339
340 fn apply_type_binds(&mut self, binds: &BindMap);
341
342 fn source(&self) -> &Source;
343 fn source_mut(&mut self) -> &mut Source;
344 fn name_mut(&mut self) -> &mut Name;
345}
346
347impl Def for Definition {
348 fn apply_binds(&mut self, maybe_constructor: bool, binds: &BindMap) {
349 fn rename_ctr_pattern(pat: &mut Pattern, binds: &BindMap) {
350 for pat in pat.children_mut() {
351 rename_ctr_pattern(pat, binds);
352 }
353 match pat {
354 Pattern::Ctr(nam, _) => {
355 if let Some(alias) = binds.get(nam) {
356 *nam = alias.clone();
357 }
358 }
359 Pattern::Var(Some(nam)) => {
360 if let Some(alias) = binds.get(nam) {
361 *nam = alias.clone();
362 }
363 }
364 _ => {}
365 }
366 }
367
368 for rule in &mut self.rules {
369 if maybe_constructor {
370 for pat in &mut rule.pats {
371 rename_ctr_pattern(pat, binds);
372 }
373 }
374 let bod = std::mem::take(&mut rule.body);
375 rule.body = bod.fold_uses(binds.iter().rev());
376 }
377 }
378
379 fn apply_type_binds(&mut self, binds: &BindMap) {
380 for (from, to) in binds.iter().rev() {
381 self.typ.subst_ctr(from, to);
382 for rule in &mut self.rules {
383 rule.body.subst_type_ctrs(from, to);
384 }
385 }
386 }
387
388 fn source(&self) -> &Source {
389 &self.source
390 }
391
392 fn source_mut(&mut self) -> &mut Source {
393 &mut self.source
394 }
395
396 fn name_mut(&mut self) -> &mut Name {
397 &mut self.name
398 }
399}
400
401impl Def for imp::Definition {
402 fn apply_binds(&mut self, _maybe_constructor: bool, binds: &BindMap) {
403 let bod = std::mem::take(&mut self.body);
404 self.body = bod.fold_uses(binds.iter().rev());
405 }
406
407 fn apply_type_binds(&mut self, binds: &BindMap) {
408 fn subst_type_ctrs_stmt(stmt: &mut Stmt, from: &Name, to: &Name) {
409 maybe_grow(|| match stmt {
410 Stmt::Assign { nxt, .. } => {
411 if let Some(nxt) = nxt {
412 subst_type_ctrs_stmt(nxt, from, to);
413 }
414 }
415 Stmt::InPlace { nxt, .. } => {
416 subst_type_ctrs_stmt(nxt, from, to);
417 }
418 Stmt::If { then, otherwise, nxt, .. } => {
419 subst_type_ctrs_stmt(then, from, to);
420 subst_type_ctrs_stmt(otherwise, from, to);
421 if let Some(nxt) = nxt {
422 subst_type_ctrs_stmt(nxt, from, to);
423 }
424 }
425 Stmt::Match { arms, nxt, .. } => {
426 for MatchArm { lft: _, rgt } in arms {
427 subst_type_ctrs_stmt(rgt, from, to);
428 }
429 if let Some(nxt) = nxt {
430 subst_type_ctrs_stmt(nxt, from, to);
431 }
432 }
433 Stmt::Switch { arms, nxt, .. } => {
434 for arm in arms {
435 subst_type_ctrs_stmt(arm, from, to);
436 }
437 if let Some(nxt) = nxt {
438 subst_type_ctrs_stmt(nxt, from, to);
439 }
440 }
441 Stmt::Bend { step, base, nxt, .. } => {
442 subst_type_ctrs_stmt(step, from, to);
443 subst_type_ctrs_stmt(base, from, to);
444 if let Some(nxt) = nxt {
445 subst_type_ctrs_stmt(nxt, from, to);
446 }
447 }
448 Stmt::Fold { arms, nxt, .. } => {
449 for MatchArm { lft: _, rgt } in arms {
450 subst_type_ctrs_stmt(rgt, from, to);
451 }
452 if let Some(nxt) = nxt {
453 subst_type_ctrs_stmt(nxt, from, to);
454 }
455 }
456 Stmt::With { typ, bod, nxt } => {
457 if typ == from {
458 *typ = to.clone();
459 }
460 subst_type_ctrs_stmt(bod, from, to);
461 if let Some(nxt) = nxt {
462 subst_type_ctrs_stmt(nxt, from, to);
463 }
464 }
465 Stmt::Ask { nxt, .. } => {
466 if let Some(nxt) = nxt {
467 subst_type_ctrs_stmt(nxt, from, to);
468 }
469 }
470 Stmt::Return { .. } => {}
471 Stmt::Open { typ, nxt, .. } => {
472 if typ == from {
473 *typ = to.clone();
474 }
475 subst_type_ctrs_stmt(nxt, from, to);
476 }
477 Stmt::Use { nxt, .. } => {
478 subst_type_ctrs_stmt(nxt, from, to);
479 }
480 Stmt::LocalDef { def, nxt } => {
481 def.apply_type_binds(&[(from.clone(), to.clone())].into_iter().collect());
482 subst_type_ctrs_stmt(nxt, from, to);
483 }
484 Stmt::Err => {}
485 })
486 }
487 for (from, to) in binds.iter().rev() {
488 self.typ.subst_ctr(from, to);
489 subst_type_ctrs_stmt(&mut self.body, from, to);
490 }
491 }
492
493 fn source(&self) -> &Source {
494 &self.source
495 }
496
497 fn source_mut(&mut self) -> &mut Source {
498 &mut self.source
499 }
500
501 fn name_mut(&mut self) -> &mut Name {
502 &mut self.name
503 }
504}
505
506impl Def for HvmDefinition {
507 fn apply_binds(&mut self, _maybe_constructor: bool, _binds: &BindMap) {}
509
510 fn apply_type_binds(&mut self, binds: &BindMap) {
511 for (from, to) in binds.iter().rev() {
512 self.typ.subst_ctr(from, to);
513 }
514 }
515
516 fn source(&self) -> &Source {
517 &self.source
518 }
519
520 fn source_mut(&mut self) -> &mut Source {
521 &mut self.source
522 }
523
524 fn name_mut(&mut self) -> &mut Name {
525 &mut self.name
526 }
527
528 fn canonicalize_name(&mut self, src: &Name, main_imports: &ImportsMap, binds: &mut BindMap) {
529 let def_name = self.name_mut();
530 let mut new_name = Name::new(std::format!("{}/{}", src, def_name));
531
532 if !main_imports.contains_source(&new_name) {
533 new_name = Name::new(std::format!("__{}", new_name));
534 }
535
536 binds.insert(def_name.clone(), new_name.clone());
537 *def_name = new_name;
538 }
539}
540
541impl Term {
542 fn fold_uses<'a>(self, map: impl Iterator<Item = (&'a Name, &'a Name)>) -> Self {
543 map.fold(self, |acc, (bind, nam)| Self::Use {
544 nam: Some(bind.clone()),
545 val: Box::new(Self::Var { nam: nam.clone() }),
546 nxt: Box::new(acc),
547 })
548 }
549}
550
551impl Stmt {
552 fn fold_uses<'a>(self, map: impl Iterator<Item = (&'a Name, &'a Name)>) -> Self {
553 map.fold(self, |acc, (bind, nam)| Self::Use {
554 nam: bind.clone(),
555 val: Box::new(Expr::Var { nam: nam.clone() }),
556 nxt: Box::new(acc),
557 })
558 }
559}