use super::{BindMap, ImportsMap, PackageLoader};
use crate::{
diagnostics::{Diagnostics, DiagnosticsConfig},
fun::{
parser::ParseBook, Adt, AdtCtr, Book, Definition, HvmDefinition, Name, Pattern, Source, SourceKind, Term,
},
imp::{self, Expr, MatchArm, Stmt},
imports::packages::Packages,
maybe_grow,
};
use indexmap::{map::Entry, IndexMap};
use itertools::Itertools;
impl ParseBook {
  /// Resolves all imports of this (main) book and lowers it into a [`Book`].
  ///
  /// Every imported package is fetched through `loader`, has its own imports
  /// resolved, and is merged into this book under canonical names. The local
  /// definitions are then rewritten to refer to those names, the book is
  /// lowered with `to_fun`, and `desugar_ctr_use` is run on the result.
  ///
  /// # Errors
  /// Returns the collected [`Diagnostics`] when loading, merging or lowering
  /// fails. Diagnostics that survive the `fatal` check are printed to stderr.
  pub fn load_imports(
    self,
    mut loader: impl PackageLoader,
    diag_config: DiagnosticsConfig,
  ) -> Result<Book, Diagnostics> {
    let diag = &mut Diagnostics::new(diag_config);
    let pkgs = &mut Packages::new(self);
    let mut book = pkgs.load_imports(&mut loader, diag)?;

    book.apply_imports(None, diag, pkgs)?;
    diag.fatal(())?;
    eprint!("{}", diag);

    let mut book = book.to_fun()?;
    book.desugar_ctr_use();

    Ok(book)
  }

  /// Merges every imported package into `self`, then rewrites the local
  /// definitions so that import binds refer to the imported names.
  ///
  /// `main_imports` is the import map of the main book; `None` means `self`
  /// is the main book.
  fn apply_imports(
    &mut self,
    main_imports: Option<&ImportsMap>,
    diag: &mut Diagnostics,
    pkgs: &mut Packages,
  ) -> Result<(), Diagnostics> {
    self.load_packages(main_imports, diag, pkgs)?;
    self.apply_import_binds(main_imports, pkgs);
    Ok(())
  }

  /// For each source imported by this book, takes its package out of `pkgs`,
  /// recursively resolves that package's own imports, renames its local ADTs
  /// and definitions relative to `src`, and adds the results to `self`.
  ///
  /// A source whose package is no longer in `pkgs.books` (it is removed the
  /// first time it is merged) is skipped.
  fn load_packages(
    &mut self,
    main_imports: Option<&ImportsMap>,
    diag: &mut Diagnostics,
    pkgs: &mut Packages,
  ) -> Result<(), Diagnostics> {
    let sources = self.import_ctx.sources().into_iter().cloned().collect_vec();

    for src in sources {
      let Some(package) = pkgs.books.swap_remove(&src) else { continue };
      let mut package = package.into_inner();

      // Canonical names are always decided against the main book's imports;
      // on the top-level call `self` is the main book. Kept inside the loop so
      // the borrow of `self.import_ctx` ends before `self` is mutated below.
      let main_imports = main_imports.unwrap_or(&self.import_ctx.map);

      package.apply_imports(Some(main_imports), diag, pkgs)?;
      package.apply_adts(&src, main_imports);
      package.apply_defs(&src, main_imports);

      let Book { defs, hvm_defs, adts, .. } = package.to_fun()?;

      for (name, adt) in adts {
        // Remember each ADT's constructors per source so that
        // `apply_import_binds` can bind them when the ADT is imported.
        let adts = pkgs.loaded_adts.entry(src.clone()).or_default();
        adts.insert(name.clone(), adt.ctrs.keys().cloned().collect_vec());
        self.add_imported_adt(name, adt, diag);
      }

      for def in defs.into_values() {
        self.add_imported_def(def, diag);
      }

      for def in hvm_defs.into_values() {
        self.add_imported_hvm_def(def, diag);
      }
    }

    Ok(())
  }

  /// Rewrites every local definition so that each import bind refers to the
  /// canonical (possibly `__`-prefixed) name of what it imports.
  ///
  /// Binds shadowed by a local definition, constructor or ADT are ignored.
  /// When a bind names an ADT loaded from one of this book's packages, each of
  /// its constructors is bound as well (`{bind}{ctr_suffix}`), and the ADT name
  /// itself is recorded as a type bind.
  fn apply_import_binds(&mut self, main_imports: Option<&ImportsMap>, pkgs: &Packages) {
    let main_imports = main_imports.unwrap_or(&self.import_ctx.map);

    let mut local_imports = BindMap::new();
    let mut adt_imports = BindMap::new();

    'outer: for (bind, src) in self.import_ctx.map.binds.iter().rev() {
      if self.contains_def(bind) || self.ctrs.contains_key(bind) || self.adts.contains_key(bind) {
        continue;
      }

      let nam = if main_imports.contains_source(src) { src.clone() } else { Name::new(format!("__{}", src)) };

      for pkg in self.import_ctx.sources() {
        if let Some(book) = pkgs.loaded_adts.get(pkg) {
          if let Some(ctrs) = book.get(&nam) {
            for ctr in ctrs.iter().rev() {
              // Only the leading mangling prefix is removed; splitting on every
              // `__` would pick the wrong segment when the path itself contains `__`.
              let full_ctr_name = ctr.strip_prefix("__").unwrap_or(ctr.as_ref());
              let ctr_name = full_ctr_name
                .strip_prefix(src.as_ref())
                .expect("an imported constructor's name must start with its ADT's qualified name");
              let ctr_bind = Name::new(format!("{}{}", bind, ctr_name));
              local_imports.insert(ctr_bind, ctr.clone());
            }

            adt_imports.insert(bind.clone(), nam.clone());
            continue 'outer;
          }
        }
      }

      local_imports.insert(bind.clone(), nam);
    }

    for (_, def) in self.local_defs_mut() {
      def.apply_binds(true, &local_imports);
      def.apply_type_binds(&adt_imports);
    }
  }

  /// Moves every local ADT of this (package) book to its canonical name
  /// `{src}/{name}`, prefixed with `__` unless the main book imports that path,
  /// and renames its constructors the same way. Local definitions are then
  /// rewritten to the new constructor and ADT names.
  ///
  /// `self.ctrs` is rebuilt from the renamed local constructors only.
  fn apply_adts(&mut self, src: &Name, main_imports: &ImportsMap) {
    let adts = std::mem::take(&mut self.adts);

    let mut new_adts = IndexMap::new();
    let mut adts_map = vec![];
    let mut ctrs_map = IndexMap::new();
    let mut new_ctrs = IndexMap::new();

    for (mut name, mut adt) in adts {
      if adt.source.is_local() {
        adt.source.kind = SourceKind::Imported;
        let old_name = name.clone();

        let qualified = Name::new(format!("{}/{}", src, name));
        // A constructor can only be mangled when its ADT is, so the ADT's final
        // name is fully determined here, before visiting the constructors.
        let mangle_name = !main_imports.contains_source(&qualified);
        name = if mangle_name { Name::new(format!("__{}", qualified)) } else { qualified };

        for (old_nam, ctr) in std::mem::take(&mut adt.ctrs) {
          let mut ctr_name = Name::new(format!("{}/{}", src, old_nam));
          if mangle_name && !main_imports.contains_source(&ctr_name) {
            ctr_name = Name::new(format!("__{}", ctr_name));
          }

          let ctr = AdtCtr { name: ctr_name.clone(), ..ctr };
          // Map the constructor to the ADT's *final* name so `self.ctrs` always
          // points at a key of `self.adts`, also when the ADT is mangled.
          new_ctrs.insert(ctr_name.clone(), name.clone());
          ctrs_map.insert(old_nam, ctr_name.clone());
          adt.ctrs.insert(ctr_name, ctr);
        }

        adt.name = name.clone();
        adts_map.push((old_name, name.clone()));
      }

      new_adts.insert(name, adt);
    }

    // Constructor field types may mention any of the renamed ADTs.
    for (_, adt) in &mut new_adts {
      for (_, ctr) in &mut adt.ctrs {
        for (from, to) in &adts_map {
          ctr.typ.subst_ctr(from, to);
        }
      }
    }

    let adts_map = adts_map.into_iter().collect::<IndexMap<_, _>>();
    for (_, def) in self.local_defs_mut() {
      def.apply_binds(true, &ctrs_map);
      def.apply_type_binds(&adts_map);
    }

    self.adts = new_adts;
    self.ctrs = new_ctrs;
  }

  /// Renames every local definition to its canonical name relative to `src`,
  /// rewrites references to the old names in all local definitions, and marks
  /// them as imported.
  fn apply_defs(&mut self, src: &Name, main_imports: &ImportsMap) {
    let mut canonical_map: IndexMap<_, _> = IndexMap::new();

    for (_, def) in self.local_defs_mut() {
      def.canonicalize_name(src, main_imports, &mut canonical_map);
    }

    for (_, def) in self.local_defs_mut() {
      def.apply_binds(false, &canonical_map);
      def.source_mut().kind = SourceKind::Imported;
    }
  }
}
impl ParseBook {
  /// Returns every name declared at the top level of this book: imp, fun and
  /// HVM definitions, then ADTs, then constructors.
  pub fn top_level_names(&self) -> impl Iterator<Item = &Name> {
    let imp_defs = self.imp_defs.keys();
    let fun_defs = self.fun_defs.keys();
    let hvm_defs = self.hvm_defs.keys();
    let adts = self.adts.keys();
    let ctrs = self.ctrs.keys();
    imp_defs.chain(fun_defs).chain(hvm_defs).chain(adts).chain(ctrs)
  }

  /// Adds an ADT merged from an imported package, registering its constructors.
  ///
  /// If an ADT named `nam` already exists, an error is reported and the import
  /// is dropped. Each constructor that clashes with an existing definition or
  /// constructor is reported as an error. A constructor that clashes only with
  /// a definition is still registered.
  fn add_imported_adt(&mut self, nam: Name, adt: Adt, diag: &mut Diagnostics) {
    if self.adts.contains_key(&nam) {
      let err = format!("The imported datatype '{nam}' conflicts with the datatype '{nam}'.");
      diag.add_book_error(err);
      return;
    }

    for ctr in adt.ctrs.keys() {
      if self.contains_def(ctr) {
        let err = format!("The imported constructor '{ctr}' conflicts with the definition '{ctr}'.");
        diag.add_book_error(err);
      }

      match self.ctrs.entry(ctr.clone()) {
        Entry::Vacant(e) => _ = e.insert(nam.clone()),
        Entry::Occupied(e) => {
          let ctr = e.key();
          let err = format!("The imported constructor '{ctr}' conflicts with the constructor '{ctr}'.");
          diag.add_book_error(err);
        }
      }
    }

    self.adts.insert(nam, adt);
  }

  /// Adds an imported fun definition unless its name clashes with an existing
  /// definition or constructor; a clash is reported to `diag`.
  fn add_imported_def(&mut self, def: Definition, diag: &mut Diagnostics) {
    if !self.has_def_conflict(&def.name, diag) {
      self.fun_defs.insert(def.name.clone(), def);
    }
  }

  /// Adds an imported HVM definition unless its name clashes with an existing
  /// definition or constructor; a clash is reported to `diag`.
  fn add_imported_hvm_def(&mut self, def: HvmDefinition, diag: &mut Diagnostics) {
    if !self.has_def_conflict(&def.name, diag) {
      self.hvm_defs.insert(def.name.clone(), def);
    }
  }

  /// Reports an error and returns `true` if `name` is already used by a
  /// definition or a constructor of this book. Only reads `self`.
  fn has_def_conflict(&self, name: &Name, diag: &mut Diagnostics) -> bool {
    if self.contains_def(name) {
      let err = format!("The imported definition '{name}' conflicts with the definition '{name}'.");
      diag.add_book_error(err);
      true
    } else if self.ctrs.contains_key(name) {
      let err = format!("The imported definition '{name}' conflicts with the constructor '{name}'.");
      diag.add_book_error(err);
      true
    } else {
      false
    }
  }

  /// Iterates mutably over every fun, imp and HVM definition whose source is
  /// local, viewed through the [`Def`] trait.
  fn local_defs_mut(&mut self) -> impl Iterator<Item = (&Name, &mut dyn Def)> {
    let fun = self.fun_defs.iter_mut().map(|(nam, def)| (nam, def as &mut dyn Def));
    let imp = self.imp_defs.iter_mut().map(|(nam, def)| (nam, def as &mut dyn Def));
    let hvm = self.hvm_defs.iter_mut().map(|(nam, def)| (nam, def as &mut dyn Def));
    fun.chain(imp).chain(hvm).filter(|(_, def)| def.source().is_local())
  }
}
/// Operations shared by fun, imp and HVM definitions that the import machinery
/// needs in order to rename a definition and rewrite the names it references.
trait Def {
  /// Renames the definition to `{src}/{name}`, additionally prefixed with `__`
  /// when that path is not a source of `main_imports`, and records the
  /// `old name -> new name` mapping in `binds`.
  fn canonicalize_name(&mut self, src: &Name, main_imports: &ImportsMap, binds: &mut BindMap) {
    let qualified = Name::new(format!("{}/{}", src, self.name_mut()));
    let canonical =
      if main_imports.contains_source(&qualified) { qualified } else { Name::new(format!("__{}", qualified)) };
    let previous = std::mem::replace(self.name_mut(), canonical.clone());
    binds.insert(previous, canonical);
  }

  /// Rewrites the names referenced by the definition according to `binds`.
  /// Implementations may use `maybe_constructor` to also rename patterns.
  fn apply_binds(&mut self, maybe_constructor: bool, binds: &BindMap);

  /// Rewrites the type (ADT) names used by the definition according to `binds`.
  fn apply_type_binds(&mut self, binds: &BindMap);

  /// Where the definition came from.
  fn source(&self) -> &Source;

  /// Mutable access to the definition's source information.
  fn source_mut(&mut self) -> &mut Source;

  /// Mutable access to the definition's name.
  fn name_mut(&mut self) -> &mut Name;
}
impl Def for Definition {
  /// Prepends a chain of `use bind = name` terms, one per entry of `binds`,
  /// to every rule body. When `maybe_constructor` is set, constructor and
  /// variable patterns found in `binds` are renamed as well.
  fn apply_binds(&mut self, maybe_constructor: bool, binds: &BindMap) {
    // Renames, children first, every constructor or variable pattern that has
    // an entry in `binds`.
    fn rename_ctr_pattern(pat: &mut Pattern, binds: &BindMap) {
      for child in pat.children_mut() {
        rename_ctr_pattern(child, binds);
      }
      if let Pattern::Ctr(nam, _) | Pattern::Var(Some(nam)) = pat {
        if let Some(alias) = binds.get(nam) {
          *nam = alias.clone();
        }
      }
    }

    for rule in self.rules.iter_mut() {
      if maybe_constructor {
        for pat in rule.pats.iter_mut() {
          rename_ctr_pattern(pat, binds);
        }
      }
      let body = std::mem::take(&mut rule.body);
      rule.body = body.fold_uses(binds.iter().rev());
    }
  }

  /// Renames ADT names in the definition's type and in every rule body. Each
  /// of those receives the binds in reverse order, one substitution at a time.
  fn apply_type_binds(&mut self, binds: &BindMap) {
    for (from, to) in binds.iter().rev() {
      self.typ.subst_ctr(from, to);
    }
    for rule in self.rules.iter_mut() {
      for (from, to) in binds.iter().rev() {
        rule.body.subst_type_ctrs(from, to);
      }
    }
  }

  fn source(&self) -> &Source {
    &self.source
  }

  fn source_mut(&mut self) -> &mut Source {
    &mut self.source
  }

  fn name_mut(&mut self) -> &mut Name {
    &mut self.name
  }
}
// Import-renaming hooks for imperative-syntax (`imp`) definitions.
impl Def for imp::Definition {
/// Wraps the body in one `use bind = name` statement per entry of `binds`.
/// `_maybe_constructor` is ignored: only the body is rewritten, never patterns.
fn apply_binds(&mut self, _maybe_constructor: bool, binds: &BindMap) {
let bod = std::mem::take(&mut self.body);
// Folding the reversed binds makes the first entry of `binds` the outermost `use`.
self.body = bod.fold_uses(binds.iter().rev());
}
/// Renames ADT names according to `binds` in the definition's type and in the
/// type names carried by statements of its body.
fn apply_type_binds(&mut self, binds: &BindMap) {
// Replaces `from` with `to` in the `typ` of `with` and `open` statements,
// recursing into every nested statement. Local definitions are delegated to
// their own `apply_type_binds`.
fn subst_type_ctrs_stmt(stmt: &mut Stmt, from: &Name, to: &Name) {
// NOTE(review): `maybe_grow` is defined elsewhere; presumably it guards this
// recursion against stack overflow on deeply nested bodies — confirm.
maybe_grow(|| match stmt {
Stmt::Assign { nxt, .. } => {
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::InPlace { nxt, .. } => {
subst_type_ctrs_stmt(nxt, from, to);
}
Stmt::If { then, otherwise, nxt, .. } => {
subst_type_ctrs_stmt(then, from, to);
subst_type_ctrs_stmt(otherwise, from, to);
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::Match { arms, nxt, .. } => {
// Arm constructor names (`lft`) are value names, handled by `apply_binds`.
for MatchArm { lft: _, rgt } in arms {
subst_type_ctrs_stmt(rgt, from, to);
}
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::Switch { arms, nxt, .. } => {
for arm in arms {
subst_type_ctrs_stmt(arm, from, to);
}
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::Bend { step, base, nxt, .. } => {
subst_type_ctrs_stmt(step, from, to);
subst_type_ctrs_stmt(base, from, to);
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::Fold { arms, nxt, .. } => {
for MatchArm { lft: _, rgt } in arms {
subst_type_ctrs_stmt(rgt, from, to);
}
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
// `with` names a type directly, so it is renamed in place.
Stmt::With { typ, bod, nxt } => {
if typ == from {
*typ = to.clone();
}
subst_type_ctrs_stmt(bod, from, to);
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::Ask { nxt, .. } => {
if let Some(nxt) = nxt {
subst_type_ctrs_stmt(nxt, from, to);
}
}
Stmt::Return { .. } => {}
// `open` also names a type directly.
Stmt::Open { typ, nxt, .. } => {
if typ == from {
*typ = to.clone();
}
subst_type_ctrs_stmt(nxt, from, to);
}
Stmt::Use { nxt, .. } => {
subst_type_ctrs_stmt(nxt, from, to);
}
// A local definition gets the same single substitution through its own
// `apply_type_binds`, which also covers its type and body.
Stmt::LocalDef { def, nxt } => {
def.apply_type_binds(&[(from.clone(), to.clone())].into_iter().collect());
subst_type_ctrs_stmt(nxt, from, to);
}
Stmt::Err => {}
})
}
// Binds are applied one at a time, last entry first.
for (from, to) in binds.iter().rev() {
self.typ.subst_ctr(from, to);
subst_type_ctrs_stmt(&mut self.body, from, to);
}
}
fn source(&self) -> &Source {
&self.source
}
fn source_mut(&mut self) -> &mut Source {
&mut self.source
}
fn name_mut(&mut self) -> &mut Name {
&mut self.name
}
}
impl Def for HvmDefinition {
  /// Value binds are not applied to HVM definitions; this is a no-op.
  fn apply_binds(&mut self, _maybe_constructor: bool, _binds: &BindMap) {}

  /// Renames ADT names in the definition's type according to `binds`,
  /// applying them one at a time, last entry first.
  fn apply_type_binds(&mut self, binds: &BindMap) {
    for (from, to) in binds.iter().rev() {
      self.typ.subst_ctr(from, to);
    }
  }

  fn source(&self) -> &Source {
    &self.source
  }

  fn source_mut(&mut self) -> &mut Source {
    &mut self.source
  }

  fn name_mut(&mut self) -> &mut Name {
    &mut self.name
  }

  // `canonicalize_name` is inherited from the trait's default implementation,
  // which performs exactly the `{src}/{name}` / `__` renaming needed here.
}
impl Term {
  /// Wraps `self` in `use bind = nam` terms, one per item of `map`.
  ///
  /// Each item wraps the term built so far, so the last item yielded by `map`
  /// becomes the outermost `use`.
  fn fold_uses<'a>(self, map: impl Iterator<Item = (&'a Name, &'a Name)>) -> Self {
    let mut term = self;
    for (bind, nam) in map {
      let val = Box::new(Self::Var { nam: nam.clone() });
      term = Self::Use { nam: Some(bind.clone()), val, nxt: Box::new(term) };
    }
    term
  }
}
impl Stmt {
  /// Wraps `self` in `use bind = nam` statements, one per item of `map`.
  ///
  /// Each item wraps the statement built so far, so the last item yielded by
  /// `map` becomes the outermost `use`.
  fn fold_uses<'a>(self, map: impl Iterator<Item = (&'a Name, &'a Name)>) -> Self {
    let mut stmt = self;
    for (bind, nam) in map {
      let val = Box::new(Expr::Var { nam: nam.clone() });
      stmt = Self::Use { nam: bind.clone(), val, nxt: Box::new(stmt) };
    }
    stmt
  }
}