1use std::cell::RefCell;
2use std::mem::{replace, take};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5
6use erg_common::config::ErgConfig;
7use erg_common::dict::Dict as Dic;
8use erg_common::fresh::SharedFreshNameGenerator;
9use erg_common::log;
10use erg_common::pathutil::squash;
11use erg_common::traits::{Locational, Stream};
12use erg_common::Str;
13
14use erg_parser::ast::{DefId, OperationKind};
15use erg_parser::token::{Token, TokenKind, DOT, EQUAL};
16
17use crate::ty::value::ValueObj;
18use crate::ty::HasType;
19
20use crate::hir::*;
21use crate::module::SharedModuleCache;
22
23pub struct Mod {
24 variable: Expr,
25 definition: Expr,
26}
27
28impl Mod {
29 const fn new(variable: Expr, definition: Expr) -> Self {
30 Self {
31 variable,
32 definition,
33 }
34 }
35}
36
37pub struct HIRLinker<'a> {
40 cfg: &'a ErgConfig,
41 mod_cache: &'a SharedModuleCache,
42 removed_mods: Rc<RefCell<Dic<PathBuf, Mod>>>,
43 fresh_gen: SharedFreshNameGenerator,
44}
45
46impl<'a> HIRLinker<'a> {
47 pub fn new(cfg: &'a ErgConfig, mod_cache: &'a SharedModuleCache) -> Self {
48 Self {
49 cfg,
50 mod_cache,
51 removed_mods: Rc::new(RefCell::new(Dic::new())),
52 fresh_gen: SharedFreshNameGenerator::new("hir_linker"),
53 }
54 }
55
56 fn inherit(&self, cfg: &'a ErgConfig) -> Self {
57 Self {
58 cfg,
59 mod_cache: self.mod_cache,
60 removed_mods: self.removed_mods.clone(),
61 fresh_gen: self.fresh_gen.clone(),
62 }
63 }
64
65 pub fn link(&self, mut main: HIR) -> HIR {
66 log!(info "the linking process has started.");
67 for chunk in main.module.iter_mut() {
68 self.replace_import(chunk);
69 }
70 for (i, module) in self.removed_mods.borrow_mut().values_mut().enumerate() {
72 main.module.insert(i, take(&mut module.definition));
73 }
74 for chunk in main.module.iter_mut() {
75 Self::resolve_pymod_path(chunk);
76 }
77 log!(info "linked:\n{main}");
78 main
79 }
80
81 fn link_child(&self, mut hir: HIR) -> HIR {
82 for chunk in hir.module.iter_mut() {
83 self.replace_import(chunk);
84 }
85 for chunk in hir.module.iter_mut() {
86 Self::resolve_pymod_path(chunk);
87 }
88 hir
89 }
90
91 fn resolve_pymod_path(expr: &mut Expr) {
113 match expr {
114 Expr::Literal(_) => {}
115 Expr::Accessor(acc) => {
116 if let Accessor::Attr(attr) = acc {
117 Self::resolve_pymod_path(&mut attr.obj);
118 if acc.ref_t().is_py_module() {
119 let import = Expr::Import(acc.clone());
120 *expr = Expr::Compound(Block::new(vec![import, take(expr)]));
121 }
122 }
123 }
124 Expr::List(list) => match list {
125 List::Normal(lis) => {
126 for elem in lis.elems.pos_args.iter_mut() {
127 Self::resolve_pymod_path(&mut elem.expr);
128 }
129 }
130 List::WithLength(lis) => {
131 Self::resolve_pymod_path(&mut lis.elem);
132 if let Some(len) = lis.len.as_deref_mut() {
133 Self::resolve_pymod_path(len);
134 }
135 }
136 _ => todo!(),
137 },
138 Expr::Tuple(tuple) => match tuple {
139 Tuple::Normal(tup) => {
140 for elem in tup.elems.pos_args.iter_mut() {
141 Self::resolve_pymod_path(&mut elem.expr);
142 }
143 }
144 },
145 Expr::Set(set) => match set {
146 Set::Normal(st) => {
147 for elem in st.elems.pos_args.iter_mut() {
148 Self::resolve_pymod_path(&mut elem.expr);
149 }
150 }
151 Set::WithLength(st) => {
152 Self::resolve_pymod_path(&mut st.elem);
153 Self::resolve_pymod_path(&mut st.len);
154 }
155 },
156 Expr::Dict(dict) => match dict {
157 Dict::Normal(dic) => {
158 for elem in dic.kvs.iter_mut() {
159 Self::resolve_pymod_path(&mut elem.key);
160 Self::resolve_pymod_path(&mut elem.value);
161 }
162 }
163 other => todo!("{other}"),
164 },
165 Expr::Record(record) => {
166 for attr in record.attrs.iter_mut() {
167 for chunk in attr.body.block.iter_mut() {
168 Self::resolve_pymod_path(chunk);
169 }
170 }
171 }
172 Expr::BinOp(binop) => {
173 Self::resolve_pymod_path(&mut binop.lhs);
174 Self::resolve_pymod_path(&mut binop.rhs);
175 }
176 Expr::UnaryOp(unaryop) => {
177 Self::resolve_pymod_path(&mut unaryop.expr);
178 }
179 Expr::Call(call) => {
180 Self::resolve_pymod_path(&mut call.obj);
181 for arg in call.args.pos_args.iter_mut() {
182 Self::resolve_pymod_path(&mut arg.expr);
183 }
184 for arg in call.args.kw_args.iter_mut() {
185 Self::resolve_pymod_path(&mut arg.expr);
186 }
187 }
188 Expr::Def(def) => {
189 for chunk in def.body.block.iter_mut() {
190 Self::resolve_pymod_path(chunk);
191 }
192 }
193 Expr::Lambda(lambda) => {
194 for chunk in lambda.body.iter_mut() {
195 Self::resolve_pymod_path(chunk);
196 }
197 }
198 Expr::ClassDef(class_def) => {
199 for def in class_def.all_methods_mut() {
200 Self::resolve_pymod_path(def);
201 }
202 }
203 Expr::PatchDef(patch_def) => {
204 for def in patch_def.methods.iter_mut() {
205 Self::resolve_pymod_path(def);
206 }
207 }
208 Expr::ReDef(redef) => {
209 for chunk in redef.block.iter_mut() {
211 Self::resolve_pymod_path(chunk);
212 }
213 }
214 Expr::TypeAsc(tasc) => Self::resolve_pymod_path(&mut tasc.expr),
215 Expr::Code(chunks) | Expr::Compound(chunks) => {
216 for chunk in chunks.iter_mut() {
217 Self::resolve_pymod_path(chunk);
218 }
219 }
220 Expr::Import(_) => {}
221 Expr::Dummy(_) => {}
222 }
223 }
224
225 fn replace_import(&self, expr: &mut Expr) {
226 match expr {
227 Expr::Literal(_) => {}
228 Expr::Accessor(acc) => {
229 match acc {
234 Accessor::Attr(attr) => {
235 self.replace_import(&mut attr.obj);
236 if attr.ident.inspect() == "__file__"
237 && attr.ident.vi.def_loc.module.is_none()
238 {
239 *expr = self.__file__();
240 }
241 }
242 Accessor::Ident(ident) => match &ident.inspect()[..] {
243 "module" => {
244 *expr = Self::self_module();
245 }
246 "global" => {
247 *expr = Expr::from(Identifier::static_public("__builtins__"));
248 }
249 "__file__" if ident.vi.def_loc.module.is_none() => {
250 *expr = self.__file__();
251 }
252 _ => {}
253 },
254 }
255 }
256 Expr::List(list) => match list {
257 List::Normal(lis) => {
258 for elem in lis.elems.pos_args.iter_mut() {
259 self.replace_import(&mut elem.expr);
260 }
261 }
262 List::WithLength(lis) => {
263 self.replace_import(&mut lis.elem);
264 if let Some(len) = lis.len.as_deref_mut() {
265 self.replace_import(len);
266 }
267 }
268 _ => todo!(),
269 },
270 Expr::Tuple(tuple) => match tuple {
271 Tuple::Normal(tup) => {
272 for elem in tup.elems.pos_args.iter_mut() {
273 self.replace_import(&mut elem.expr);
274 }
275 }
276 },
277 Expr::Set(set) => match set {
278 Set::Normal(st) => {
279 for elem in st.elems.pos_args.iter_mut() {
280 self.replace_import(&mut elem.expr);
281 }
282 }
283 Set::WithLength(st) => {
284 self.replace_import(&mut st.elem);
285 self.replace_import(&mut st.len);
286 }
287 },
288 Expr::Dict(dict) => match dict {
289 Dict::Normal(dic) => {
290 for elem in dic.kvs.iter_mut() {
291 self.replace_import(&mut elem.key);
292 self.replace_import(&mut elem.value);
293 }
294 }
295 other => todo!("{other}"),
296 },
297 Expr::Record(record) => {
298 for attr in record.attrs.iter_mut() {
299 for chunk in attr.body.block.iter_mut() {
300 self.replace_import(chunk);
301 }
302 }
303 }
304 Expr::BinOp(binop) => {
305 self.replace_import(&mut binop.lhs);
306 self.replace_import(&mut binop.rhs);
307 }
308 Expr::UnaryOp(unaryop) => {
309 self.replace_import(&mut unaryop.expr);
310 }
311 Expr::Call(call) => match call.additional_operation() {
312 Some(OperationKind::Import) => {
313 self.replace_erg_import(expr);
314 }
315 Some(OperationKind::PyImport) => {
316 self.replace_py_import(expr);
317 }
318 _ => {
319 self.replace_import(&mut call.obj);
320 for arg in call.args.pos_args.iter_mut() {
321 self.replace_import(&mut arg.expr);
322 }
323 if let Some(arg) = call.args.var_args.as_deref_mut() {
324 self.replace_py_import(&mut arg.expr);
325 }
326 for arg in call.args.kw_args.iter_mut() {
327 self.replace_import(&mut arg.expr);
328 }
329 }
330 },
331 Expr::Def(def) => {
332 for chunk in def.body.block.iter_mut() {
333 self.replace_import(chunk);
334 }
335 }
336 Expr::Lambda(lambda) => {
337 for chunk in lambda.body.iter_mut() {
338 self.replace_import(chunk);
339 }
340 }
341 Expr::ClassDef(class_def) => {
342 for def in class_def.all_methods_mut() {
343 self.replace_import(def);
344 }
345 }
346 Expr::PatchDef(patch_def) => {
347 for def in patch_def.methods.iter_mut() {
348 self.replace_import(def);
349 }
350 }
351 Expr::ReDef(redef) => {
352 for chunk in redef.block.iter_mut() {
354 self.replace_import(chunk);
355 }
356 }
357 Expr::TypeAsc(tasc) => self.replace_import(&mut tasc.expr),
358 Expr::Code(chunks) | Expr::Compound(chunks) => {
359 for chunk in chunks.iter_mut() {
360 self.replace_import(chunk);
361 }
362 }
363 Expr::Import(_) => unreachable!(),
364 Expr::Dummy(_) => {}
365 }
366 }
367
368 fn self_module() -> Expr {
369 let __import__ = Identifier::static_public("__import__");
370 let __name__ = Identifier::static_public("__name__");
371 Expr::from(__import__).call1(Expr::from(__name__))
372 }
373
374 fn __file__(&self) -> Expr {
375 let path = self.cfg.input.path().to_path_buf();
376 let token = Token::new_fake(
377 TokenKind::StrLit,
378 format!(
379 "\"{}\"",
380 path.canonicalize().unwrap_or(path).to_string_lossy()
381 ),
382 0,
383 0,
384 0,
385 );
386 let lit = Literal::try_from(token).unwrap();
387 Expr::from(lit)
388 }
389
390 fn replace_erg_import(&self, expr: &mut Expr) {
402 let line = expr.ln_begin().unwrap_or(0);
403 let Some(path) = expr.ref_t().module_path() else {
404 unreachable!()
405 };
406 if matches!((path.canonicalize(), self.cfg.input.path().canonicalize()), (Ok(l), Ok(r)) if l == r)
412 {
413 *expr = Self::self_module();
414 return;
415 }
416 let hir_cfg = if self.cfg.input.is_repl() {
418 self.mod_cache
419 .get(path.as_path())
420 .and_then(|entry| entry.hir.clone().map(|hir| (hir, entry.cfg().clone())))
421 } else {
422 self.mod_cache
423 .remove(path.as_path())
424 .and_then(|entry| entry.hir.map(|hir| (hir, entry.module.context.cfg.clone())))
425 };
426 let Expr::Call(call) = expr else {
427 log!(err "{expr}");
428 return;
429 };
430 let Some(mod_name) = call.args.get_left_or_key("Path") else {
431 log!(err "{call}");
432 return;
433 };
434 if let Some((hir, cfg)) = hir_cfg {
437 *expr = self.modularize(mod_name.clone(), hir, cfg, line, path);
438 } else if let Some(module) = self.removed_mods.borrow().get(&path) {
439 *expr = module.variable.clone();
440 }
441 }
442
443 fn modularize(
444 &self,
445 mod_name: Expr,
446 hir: HIR,
447 cfg: ErgConfig,
448 line: u32,
449 path: PathBuf,
450 ) -> Expr {
451 let tmp = Identifier::private_with_line(self.fresh_gen.fresh_varname(), line);
452 let mod_var = Expr::Accessor(Accessor::Ident(tmp.clone()));
453 let module_type =
454 Expr::Accessor(Accessor::private_with_line(Str::ever("#ModuleType"), line));
455 let args = Args::single(PosArg::new(mod_name));
456 let block = Block::new(vec![module_type.call_expr(args)]);
457 let mod_def = Expr::Def(Def::new(
458 Signature::Var(VarSignature::global(tmp, None)),
459 DefBody::new(EQUAL, block, DefId(0)),
460 ));
461 self.removed_mods
462 .borrow_mut()
463 .insert(path, Mod::new(mod_var.clone(), mod_def));
464 let linker = self.inherit(&cfg);
465 let hir = linker.link_child(hir);
466 let code = Expr::Code(Block::new(Vec::from(hir.module)));
467 let __dict__ = Identifier::static_public("__dict__");
468 let m_dict = mod_var.clone().attr_expr(__dict__);
469 let locals = Expr::Accessor(Accessor::public_with_line(Str::ever("locals"), line));
470 let locals_call = locals.call_expr(Args::empty());
471 let args = Args::single(PosArg::new(locals_call));
472 let mod_update = Expr::Call(Call::new(
473 m_dict.clone(),
474 Some(Identifier::static_public("update")),
475 args,
476 ));
477 let exec = Expr::Accessor(Accessor::public_with_line(Str::ever("exec"), line));
478 let args = Args::pos_only(vec![PosArg::new(code), PosArg::new(m_dict)], None);
479 let exec_code = exec.call_expr(args);
480 let compound = Block::new(vec![mod_update, exec_code, mod_var]);
481 Expr::Compound(compound)
482 }
483
484 fn replace_py_import(&self, expr: &mut Expr) {
492 let args = if let Expr::Call(call) = expr {
493 &mut call.args
494 } else {
495 log!(err "{expr}");
496 return;
497 };
498 let Some(Expr::Literal(mod_name_lit)) = args.remove_left_or_key("Path") else {
499 log!(err "{args}");
500 return;
501 };
502 let ValueObj::Str(mod_name_str) = mod_name_lit.value.clone() else {
503 log!(err "{mod_name_lit}");
504 return;
505 };
506 let mut dir = self.cfg.input.dir();
507 let mod_path = self
508 .cfg
509 .input
510 .resolve_decl_path(Path::new(&mod_name_str[..]), self.cfg)
511 .unwrap();
512 if !mod_path
513 .canonicalize()
514 .unwrap()
515 .starts_with(dir.canonicalize().unwrap())
516 {
517 dir = PathBuf::new();
518 }
519 let mod_name_str = if let Some(stripped) = mod_name_str.strip_prefix("./") {
520 stripped
521 } else {
522 &mod_name_str
523 };
524 dir.push(mod_name_str);
525 let dir = squash(dir);
526 let mut comps = dir.components();
527 let _first = comps.next().unwrap();
528 let path = dir.to_string_lossy().replace(['/', '\\'], ".");
529 let token = Token::new_fake(
530 TokenKind::StrLit,
531 format!("\"{path}\""),
532 mod_name_lit.ln_begin().unwrap(),
533 mod_name_lit.col_begin().unwrap(),
534 mod_name_lit.col_end().unwrap(),
535 );
536 let mod_name = Expr::Literal(Literal::try_from(token).unwrap());
537 args.insert_pos(0, PosArg::new(mod_name));
538 let line = expr.ln_begin().unwrap_or(0);
539 for attr in comps {
540 *expr =
541 replace(expr, Expr::Code(Block::empty())).attr_expr(Identifier::public_with_line(
542 DOT,
543 Str::rc(attr.as_os_str().to_str().unwrap()),
544 line,
545 ));
546 }
547 }
548}