use std::mem;
use std::path::Path;
use erg_common::config::ErgConfig;
use erg_common::pathutil::squash;
use erg_common::python_util::BUILTIN_PYTHON_MODS;
use erg_common::traits::Locational;
use erg_common::Str;
use erg_common::{enum_unwrap, log};
use erg_parser::ast::{DefId, OperationKind};
use erg_parser::token::{Token, TokenKind, DOT, EQUAL};
use crate::context::Context;
use crate::ty::typaram::TyParam;
use crate::ty::value::ValueObj;
use crate::ty::HasType;
use erg_common::fresh::fresh_varname;
use crate::hir::*;
use crate::module::SharedModuleCache;
/// Links a lowered [`HIR`] with the HIRs of the Erg modules it imports.
///
/// Linking runs two passes over the top-level chunks (see [`HIRLinker::link`]):
/// 1. `import` / `pyimport` calls are rewritten. An Erg module whose HIR is found in the
///    module cache is inlined into the importing module. The path argument of a `pyimport`
///    is turned into a dotted module name.
/// 2. Attribute accesses whose type is a Python module are prefixed with an
///    [`Expr::Import`] of that accessor.
pub struct HIRLinker<'a> {
    /// Configuration of the module being linked (used for path resolution and REPL detection).
    cfg: &'a ErgConfig,
    /// Shared cache from which the HIRs of imported Erg modules are fetched.
    mod_cache: &'a SharedModuleCache,
}

impl<'a> HIRLinker<'a> {
    /// Creates a linker for a module compiled with `cfg`, reading imported modules from
    /// `mod_cache`.
    pub fn new(cfg: &'a ErgConfig, mod_cache: &'a SharedModuleCache) -> Self {
        Self { cfg, mod_cache }
    }

    /// Links `main` and returns the rewritten HIR.
    ///
    /// All imports are replaced first, then Python-module attribute accesses are resolved.
    /// The second pass also walks into the code of inlined Erg modules. A nested linker has
    /// already processed that code, so it may contain `Expr::Import` nodes.
    ///
    /// # Panics
    /// Panics on expressions that are not supported yet (some array/dict forms reach
    /// `todo!()`), and when an imported Erg module's path cannot be resolved.
    pub fn link(&self, mut main: HIR) -> HIR {
        log!(info "the linking process has started.");
        for chunk in main.module.iter_mut() {
            self.replace_import(chunk);
        }
        for chunk in main.module.iter_mut() {
            Self::resolve_pymod_path(chunk);
        }
        log!(info "linked: {main}");
        main
    }

    /// Recursively wraps every attribute accessor whose type is a Python module into
    /// `Compound([Import(accessor), accessor])`, so that an import of the module precedes
    /// the access.
    ///
    /// The inner object of an attribute is resolved before the attribute itself. A wrapped
    /// accessor is not revisited within the same call.
    fn resolve_pymod_path(expr: &mut Expr) {
        match expr {
            Expr::Lit(_) => {}
            Expr::Accessor(acc) => {
                if let Accessor::Attr(attr) = acc {
                    Self::resolve_pymod_path(&mut attr.obj);
                    if acc.ref_t().is_py_module() {
                        let import = Expr::Import(acc.clone());
                        // Move the original accessor into the compound, after its import.
                        *expr = Expr::Compound(Block::new(vec![import, mem::take(expr)]));
                    }
                }
            }
            Expr::Array(array) => match array {
                Array::Normal(arr) => {
                    for elem in arr.elems.pos_args.iter_mut() {
                        Self::resolve_pymod_path(&mut elem.expr);
                    }
                }
                Array::WithLength(arr) => {
                    Self::resolve_pymod_path(&mut arr.elem);
                    Self::resolve_pymod_path(&mut arr.len);
                }
                _ => todo!(),
            },
            Expr::Tuple(tuple) => match tuple {
                Tuple::Normal(tup) => {
                    for elem in tup.elems.pos_args.iter_mut() {
                        Self::resolve_pymod_path(&mut elem.expr);
                    }
                }
            },
            Expr::Set(set) => match set {
                Set::Normal(st) => {
                    for elem in st.elems.pos_args.iter_mut() {
                        Self::resolve_pymod_path(&mut elem.expr);
                    }
                }
                Set::WithLength(st) => {
                    Self::resolve_pymod_path(&mut st.elem);
                    Self::resolve_pymod_path(&mut st.len);
                }
            },
            Expr::Dict(dict) => match dict {
                Dict::Normal(dic) => {
                    for elem in dic.kvs.iter_mut() {
                        Self::resolve_pymod_path(&mut elem.key);
                        Self::resolve_pymod_path(&mut elem.value);
                    }
                }
                other => todo!("{other}"),
            },
            Expr::Record(record) => {
                for attr in record.attrs.iter_mut() {
                    for chunk in attr.body.block.iter_mut() {
                        Self::resolve_pymod_path(chunk);
                    }
                }
            }
            Expr::BinOp(binop) => {
                Self::resolve_pymod_path(&mut binop.lhs);
                Self::resolve_pymod_path(&mut binop.rhs);
            }
            Expr::UnaryOp(unaryop) => {
                Self::resolve_pymod_path(&mut unaryop.expr);
            }
            Expr::Call(call) => {
                Self::resolve_pymod_path(&mut call.obj);
                for arg in call.args.pos_args.iter_mut() {
                    Self::resolve_pymod_path(&mut arg.expr);
                }
                for arg in call.args.kw_args.iter_mut() {
                    Self::resolve_pymod_path(&mut arg.expr);
                }
            }
            Expr::Def(def) => {
                for chunk in def.body.block.iter_mut() {
                    Self::resolve_pymod_path(chunk);
                }
            }
            Expr::Lambda(lambda) => {
                for chunk in lambda.body.iter_mut() {
                    Self::resolve_pymod_path(chunk);
                }
            }
            Expr::ClassDef(class_def) => {
                for def in class_def.methods.iter_mut() {
                    Self::resolve_pymod_path(def);
                }
            }
            Expr::PatchDef(patch_def) => {
                for def in patch_def.methods.iter_mut() {
                    Self::resolve_pymod_path(def);
                }
            }
            Expr::ReDef(redef) => {
                for chunk in redef.block.iter_mut() {
                    Self::resolve_pymod_path(chunk);
                }
            }
            Expr::TypeAsc(tasc) => Self::resolve_pymod_path(&mut tasc.expr),
            Expr::Code(chunks) | Expr::Compound(chunks) => {
                for chunk in chunks.iter_mut() {
                    Self::resolve_pymod_path(chunk);
                }
            }
            // Already resolved. Reachable when walking the code of an Erg module inlined by
            // `replace_erg_import`: its nested linker has already run this pass on it.
            Expr::Import(_) => {}
            Expr::Dummy(_) => {}
        }
    }

    /// Recursively rewrites `import` / `pyimport` calls inside `expr`.
    ///
    /// Calls are classified by [`Call::additional_operation`]. A matched call is handed to
    /// `replace_erg_import` / `replace_py_import` and its arguments are not visited further.
    /// Every other call has its callee and arguments visited.
    fn replace_import(&self, expr: &mut Expr) {
        match expr {
            Expr::Lit(_) => {}
            Expr::Accessor(acc) => match acc {
                Accessor::Attr(attr) => {
                    self.replace_import(&mut attr.obj);
                }
                Accessor::Ident(_) => {}
            },
            Expr::Array(array) => match array {
                Array::Normal(arr) => {
                    for elem in arr.elems.pos_args.iter_mut() {
                        self.replace_import(&mut elem.expr);
                    }
                }
                Array::WithLength(arr) => {
                    self.replace_import(&mut arr.elem);
                    self.replace_import(&mut arr.len);
                }
                _ => todo!(),
            },
            Expr::Tuple(tuple) => match tuple {
                Tuple::Normal(tup) => {
                    for elem in tup.elems.pos_args.iter_mut() {
                        self.replace_import(&mut elem.expr);
                    }
                }
            },
            Expr::Set(set) => match set {
                Set::Normal(st) => {
                    for elem in st.elems.pos_args.iter_mut() {
                        self.replace_import(&mut elem.expr);
                    }
                }
                Set::WithLength(st) => {
                    self.replace_import(&mut st.elem);
                    self.replace_import(&mut st.len);
                }
            },
            Expr::Dict(dict) => match dict {
                Dict::Normal(dic) => {
                    for elem in dic.kvs.iter_mut() {
                        self.replace_import(&mut elem.key);
                        self.replace_import(&mut elem.value);
                    }
                }
                other => todo!("{other}"),
            },
            Expr::Record(record) => {
                for attr in record.attrs.iter_mut() {
                    for chunk in attr.body.block.iter_mut() {
                        self.replace_import(chunk);
                    }
                }
            }
            Expr::BinOp(binop) => {
                self.replace_import(&mut binop.lhs);
                self.replace_import(&mut binop.rhs);
            }
            Expr::UnaryOp(unaryop) => {
                self.replace_import(&mut unaryop.expr);
            }
            Expr::Call(call) => match call.additional_operation() {
                Some(OperationKind::Import) => {
                    self.replace_erg_import(expr);
                }
                Some(OperationKind::PyImport) => {
                    self.replace_py_import(expr);
                }
                _ => {
                    self.replace_import(&mut call.obj);
                    for arg in call.args.pos_args.iter_mut() {
                        self.replace_import(&mut arg.expr);
                    }
                    for arg in call.args.kw_args.iter_mut() {
                        self.replace_import(&mut arg.expr);
                    }
                }
            },
            Expr::Def(def) => {
                for chunk in def.body.block.iter_mut() {
                    self.replace_import(chunk);
                }
            }
            Expr::Lambda(lambda) => {
                for chunk in lambda.body.iter_mut() {
                    self.replace_import(chunk);
                }
            }
            Expr::ClassDef(class_def) => {
                for def in class_def.methods.iter_mut() {
                    self.replace_import(def);
                }
            }
            Expr::PatchDef(patch_def) => {
                for def in patch_def.methods.iter_mut() {
                    self.replace_import(def);
                }
            }
            Expr::ReDef(redef) => {
                for chunk in redef.block.iter_mut() {
                    self.replace_import(chunk);
                }
            }
            Expr::TypeAsc(tasc) => self.replace_import(&mut tasc.expr),
            Expr::Code(chunks) | Expr::Compound(chunks) => {
                for chunk in chunks.iter_mut() {
                    self.replace_import(chunk);
                }
            }
            // `Import` nodes are only created by `resolve_pymod_path`, which runs after this
            // pass. Their contents have therefore already been processed; treat them as leaves.
            Expr::Import(_) => {}
            Expr::Dummy(_) => {}
        }
    }

    /// Replaces an Erg `import` call with the inlined, linked code of the imported module.
    ///
    /// The call is rewritten into a compound expression shaped like:
    /// ```text
    /// tmp = #ModuleType(<module name>)
    /// tmp.__dict__.update(locals())
    /// exec(<linked module code>, tmp.__dict__)
    /// tmp
    /// ```
    /// If the module's HIR is not available in the cache, `expr` is left unchanged.
    ///
    /// # Panics
    /// Panics if the call's type parameter is not a string path, if the path cannot be
    /// resolved, or if the call has no `Path` argument.
    fn replace_erg_import(&self, expr: &mut Expr) {
        let line = expr.ln_begin().unwrap_or(0);
        let path =
            enum_unwrap!(expr.ref_t().typarams().remove(0), TyParam::Value:(ValueObj::Str:(_)));
        let path = Path::new(&path[..]);
        let path = Context::resolve_real_path(self.cfg, path).unwrap();
        // In the REPL the cache entry is cloned and kept (presumably so later inputs can import
        // the module again). Otherwise it is taken out of the cache.
        // NOTE(review): outside the REPL, a second import of the same module finds no entry
        // and is left unchanged — confirm this is intended.
        let hir_cfg = if self.cfg.input.is_repl() {
            self.mod_cache
                .get(path.as_path())
                .and_then(|entry| entry.hir.clone().map(|hir| (hir, entry.cfg().clone())))
        } else {
            self.mod_cache
                .remove(path.as_path())
                .and_then(|entry| entry.hir.map(|hir| (hir, entry.module.context.cfg.clone())))
        };
        let mod_name = enum_unwrap!(expr, Expr::Call)
            .args
            .get_left_or_key("Path")
            .unwrap();
        if let Some((hir, cfg)) = hir_cfg {
            // `HIRLinker::new` (not `Self::new`): the nested linker borrows the local `cfg`,
            // whose lifetime is shorter than `'a`.
            let linker = HIRLinker::new(&cfg, self.mod_cache);
            let hir = linker.link(hir);
            let code = Expr::Code(Block::new(Vec::from(hir.module)));
            let module_type =
                Expr::Accessor(Accessor::private_with_line(Str::ever("#ModuleType"), line));
            let args = Args::single(PosArg::new(mod_name.clone()));
            let block = Block::new(vec![module_type.call_expr(args)]);
            // `tmp = #ModuleType(<module name>)`
            let tmp = Identifier::private_with_line(Str::from(fresh_varname()), line);
            let mod_def = Expr::Def(Def::new(
                Signature::Var(VarSignature::new(tmp.clone(), None)),
                DefBody::new(EQUAL, block, DefId(0)),
            ));
            let module = Expr::Accessor(Accessor::Ident(tmp));
            let __dict__ = Identifier::public("__dict__");
            let m_dict = module.clone().attr_expr(__dict__);
            // `tmp.__dict__.update(locals())`
            let locals = Expr::Accessor(Accessor::public_with_line(Str::ever("locals"), line));
            let locals_call = locals.call_expr(Args::empty());
            let args = Args::single(PosArg::new(locals_call));
            let mod_update = Expr::Call(Call::new(
                m_dict.clone(),
                Some(Identifier::public("update")),
                args,
            ));
            // `exec(<module code>, tmp.__dict__)`
            let exec = Expr::Accessor(Accessor::public_with_line(Str::ever("exec"), line));
            let args = Args::pos_only(vec![PosArg::new(code), PosArg::new(m_dict)], None);
            let exec_code = exec.call_expr(args);
            // The trailing `tmp` makes the module object the value of the compound.
            let compound = Block::new(vec![mod_def, mod_update, exec_code, module]);
            *expr = Expr::Compound(compound);
        }
    }

    /// Rewrites the path argument of a `pyimport` call into a dotted module name.
    ///
    /// Names listed in `BUILTIN_PYTHON_MODS` are put back unchanged as the first positional
    /// argument. Any other path is processed as follows:
    /// - a leading `./` is stripped;
    /// - the path is joined onto the input's directory and squashed;
    /// - its `/` and `\` separators are replaced with `.`;
    /// - the result is inserted as the first positional argument.
    ///
    /// Finally the call is wrapped in one attribute access per path component after the first.
    /// This is presumably because the call is lowered to Python's `__import__`, which returns
    /// the top-level package for a dotted name — TODO confirm against codegen.
    ///
    /// # Panics
    /// Panics if the call has no literal string `Path` argument, or if the squashed path has
    /// no components.
    fn replace_py_import(&self, expr: &mut Expr) {
        let mut dir = self.cfg.input.dir();
        let args = &mut enum_unwrap!(expr, Expr::Call).args;
        let mod_name_lit = enum_unwrap!(args.remove_left_or_key("Path").unwrap(), Expr::Lit);
        let mod_name_str = enum_unwrap!(mod_name_lit.value.clone(), ValueObj::Str);
        if BUILTIN_PYTHON_MODS.contains(&&mod_name_str[..]) {
            // Restore the name in front of any remaining positional arguments, matching the
            // non-builtin branch below (`push_pos` would append it after them).
            args.insert_pos(0, PosArg::new(Expr::Lit(mod_name_lit)));
            return;
        }
        let mod_name_str = if let Some(stripped) = mod_name_str.strip_prefix("./") {
            stripped
        } else {
            &mod_name_str
        };
        dir.push(mod_name_str);
        let dir = squash(dir);
        let mut comps = dir.components();
        // The first component is the root of the dotted name; only later ones become
        // attribute accesses below.
        let _first = comps.next().unwrap();
        let path = dir.to_string_lossy().replace(['/', '\\'], ".");
        let token = Token::new(
            TokenKind::StrLit,
            path,
            mod_name_lit.ln_begin().unwrap(),
            mod_name_lit.col_begin().unwrap(),
        );
        let mod_name = Expr::Lit(Literal::try_from(token).unwrap());
        args.insert_pos(0, PosArg::new(mod_name));
        let line = expr.ln_begin().unwrap_or(0);
        for attr in comps {
            *expr = mem::replace(expr, Expr::Code(Block::empty())).attr_expr(
                Identifier::public_with_line(
                    DOT,
                    Str::rc(attr.as_os_str().to_str().unwrap()),
                    line,
                ),
            );
        }
    }
}