// erg_compiler/link_hir.rs

1use std::cell::RefCell;
2use std::mem::{replace, take};
3use std::path::{Path, PathBuf};
4use std::rc::Rc;
5
6use erg_common::config::ErgConfig;
7use erg_common::dict::Dict as Dic;
8use erg_common::fresh::SharedFreshNameGenerator;
9use erg_common::log;
10use erg_common::pathutil::squash;
11use erg_common::traits::{Locational, Stream};
12use erg_common::Str;
13
14use erg_parser::ast::{DefId, OperationKind};
15use erg_parser::token::{Token, TokenKind, DOT, EQUAL};
16
17use crate::ty::value::ValueObj;
18use crate::ty::HasType;
19
20use crate::hir::*;
21use crate::module::SharedModuleCache;
22
23pub struct Mod {
24    variable: Expr,
25    definition: Expr,
26}
27
28impl Mod {
29    const fn new(variable: Expr, definition: Expr) -> Self {
30        Self {
31            variable,
32            definition,
33        }
34    }
35}
36
37/// Link code using the module cache.
38/// Erg links all non-Python modules into a single pyc file.
39pub struct HIRLinker<'a> {
40    cfg: &'a ErgConfig,
41    mod_cache: &'a SharedModuleCache,
42    removed_mods: Rc<RefCell<Dic<PathBuf, Mod>>>,
43    fresh_gen: SharedFreshNameGenerator,
44}
45
46impl<'a> HIRLinker<'a> {
47    pub fn new(cfg: &'a ErgConfig, mod_cache: &'a SharedModuleCache) -> Self {
48        Self {
49            cfg,
50            mod_cache,
51            removed_mods: Rc::new(RefCell::new(Dic::new())),
52            fresh_gen: SharedFreshNameGenerator::new("hir_linker"),
53        }
54    }
55
56    fn inherit(&self, cfg: &'a ErgConfig) -> Self {
57        Self {
58            cfg,
59            mod_cache: self.mod_cache,
60            removed_mods: self.removed_mods.clone(),
61            fresh_gen: self.fresh_gen.clone(),
62        }
63    }
64
65    pub fn link(&self, mut main: HIR) -> HIR {
66        log!(info "the linking process has started.");
67        for chunk in main.module.iter_mut() {
68            self.replace_import(chunk);
69        }
70        // declare all modules first (due to cyclic modules)
71        for (i, module) in self.removed_mods.borrow_mut().values_mut().enumerate() {
72            main.module.insert(i, take(&mut module.definition));
73        }
74        for chunk in main.module.iter_mut() {
75            Self::resolve_pymod_path(chunk);
76        }
77        log!(info "linked:\n{main}");
78        main
79    }
80
81    fn link_child(&self, mut hir: HIR) -> HIR {
82        for chunk in hir.module.iter_mut() {
83            self.replace_import(chunk);
84        }
85        for chunk in hir.module.iter_mut() {
86            Self::resolve_pymod_path(chunk);
87        }
88        hir
89    }
90
91    /// ```erg
92    /// urllib = pyimport "urllib"
93    /// urllib.request.urlopen! "https://example.com"
94    /// ```
95    /// ↓
96    /// ```python
97    /// urllib = __import__("urllib")
98    /// import urllib.request
99    /// urllib.request.urlopen("https://example.com")
100    /// ```
101    /// other example:
102    /// ```erg
103    /// mpl = pyimport "matplotlib"
104    /// mpl.pyplot.plot! [1, 2, 3]
105    /// ```
106    /// ↓
107    /// ```python
108    /// mpl = __import__("matplotlib")
109    /// import matplotlib.pyplot # mpl.pyplot.foo is now allowed
110    /// mpl.pyplot.plot([1, 2, 3])
111    /// ```
112    fn resolve_pymod_path(expr: &mut Expr) {
113        match expr {
114            Expr::Literal(_) => {}
115            Expr::Accessor(acc) => {
116                if let Accessor::Attr(attr) = acc {
117                    Self::resolve_pymod_path(&mut attr.obj);
118                    if acc.ref_t().is_py_module() {
119                        let import = Expr::Import(acc.clone());
120                        *expr = Expr::Compound(Block::new(vec![import, take(expr)]));
121                    }
122                }
123            }
124            Expr::List(list) => match list {
125                List::Normal(lis) => {
126                    for elem in lis.elems.pos_args.iter_mut() {
127                        Self::resolve_pymod_path(&mut elem.expr);
128                    }
129                }
130                List::WithLength(lis) => {
131                    Self::resolve_pymod_path(&mut lis.elem);
132                    if let Some(len) = lis.len.as_deref_mut() {
133                        Self::resolve_pymod_path(len);
134                    }
135                }
136                _ => todo!(),
137            },
138            Expr::Tuple(tuple) => match tuple {
139                Tuple::Normal(tup) => {
140                    for elem in tup.elems.pos_args.iter_mut() {
141                        Self::resolve_pymod_path(&mut elem.expr);
142                    }
143                }
144            },
145            Expr::Set(set) => match set {
146                Set::Normal(st) => {
147                    for elem in st.elems.pos_args.iter_mut() {
148                        Self::resolve_pymod_path(&mut elem.expr);
149                    }
150                }
151                Set::WithLength(st) => {
152                    Self::resolve_pymod_path(&mut st.elem);
153                    Self::resolve_pymod_path(&mut st.len);
154                }
155            },
156            Expr::Dict(dict) => match dict {
157                Dict::Normal(dic) => {
158                    for elem in dic.kvs.iter_mut() {
159                        Self::resolve_pymod_path(&mut elem.key);
160                        Self::resolve_pymod_path(&mut elem.value);
161                    }
162                }
163                other => todo!("{other}"),
164            },
165            Expr::Record(record) => {
166                for attr in record.attrs.iter_mut() {
167                    for chunk in attr.body.block.iter_mut() {
168                        Self::resolve_pymod_path(chunk);
169                    }
170                }
171            }
172            Expr::BinOp(binop) => {
173                Self::resolve_pymod_path(&mut binop.lhs);
174                Self::resolve_pymod_path(&mut binop.rhs);
175            }
176            Expr::UnaryOp(unaryop) => {
177                Self::resolve_pymod_path(&mut unaryop.expr);
178            }
179            Expr::Call(call) => {
180                Self::resolve_pymod_path(&mut call.obj);
181                for arg in call.args.pos_args.iter_mut() {
182                    Self::resolve_pymod_path(&mut arg.expr);
183                }
184                for arg in call.args.kw_args.iter_mut() {
185                    Self::resolve_pymod_path(&mut arg.expr);
186                }
187            }
188            Expr::Def(def) => {
189                for chunk in def.body.block.iter_mut() {
190                    Self::resolve_pymod_path(chunk);
191                }
192            }
193            Expr::Lambda(lambda) => {
194                for chunk in lambda.body.iter_mut() {
195                    Self::resolve_pymod_path(chunk);
196                }
197            }
198            Expr::ClassDef(class_def) => {
199                for def in class_def.all_methods_mut() {
200                    Self::resolve_pymod_path(def);
201                }
202            }
203            Expr::PatchDef(patch_def) => {
204                for def in patch_def.methods.iter_mut() {
205                    Self::resolve_pymod_path(def);
206                }
207            }
208            Expr::ReDef(redef) => {
209                // REVIEW:
210                for chunk in redef.block.iter_mut() {
211                    Self::resolve_pymod_path(chunk);
212                }
213            }
214            Expr::TypeAsc(tasc) => Self::resolve_pymod_path(&mut tasc.expr),
215            Expr::Code(chunks) | Expr::Compound(chunks) => {
216                for chunk in chunks.iter_mut() {
217                    Self::resolve_pymod_path(chunk);
218                }
219            }
220            Expr::Import(_) => {}
221            Expr::Dummy(_) => {}
222        }
223    }
224
225    fn replace_import(&self, expr: &mut Expr) {
226        match expr {
227            Expr::Literal(_) => {}
228            Expr::Accessor(acc) => {
229                /*if acc.ref_t().is_py_module() {
230                    let import = Expr::Import(acc.clone());
231                    *expr = Expr::Compound(Block::new(vec![import, mem::take(expr)]));
232                }*/
233                match acc {
234                    Accessor::Attr(attr) => {
235                        self.replace_import(&mut attr.obj);
236                        if attr.ident.inspect() == "__file__"
237                            && attr.ident.vi.def_loc.module.is_none()
238                        {
239                            *expr = self.__file__();
240                        }
241                    }
242                    Accessor::Ident(ident) => match &ident.inspect()[..] {
243                        "module" => {
244                            *expr = Self::self_module();
245                        }
246                        "global" => {
247                            *expr = Expr::from(Identifier::static_public("__builtins__"));
248                        }
249                        "__file__" if ident.vi.def_loc.module.is_none() => {
250                            *expr = self.__file__();
251                        }
252                        _ => {}
253                    },
254                }
255            }
256            Expr::List(list) => match list {
257                List::Normal(lis) => {
258                    for elem in lis.elems.pos_args.iter_mut() {
259                        self.replace_import(&mut elem.expr);
260                    }
261                }
262                List::WithLength(lis) => {
263                    self.replace_import(&mut lis.elem);
264                    if let Some(len) = lis.len.as_deref_mut() {
265                        self.replace_import(len);
266                    }
267                }
268                _ => todo!(),
269            },
270            Expr::Tuple(tuple) => match tuple {
271                Tuple::Normal(tup) => {
272                    for elem in tup.elems.pos_args.iter_mut() {
273                        self.replace_import(&mut elem.expr);
274                    }
275                }
276            },
277            Expr::Set(set) => match set {
278                Set::Normal(st) => {
279                    for elem in st.elems.pos_args.iter_mut() {
280                        self.replace_import(&mut elem.expr);
281                    }
282                }
283                Set::WithLength(st) => {
284                    self.replace_import(&mut st.elem);
285                    self.replace_import(&mut st.len);
286                }
287            },
288            Expr::Dict(dict) => match dict {
289                Dict::Normal(dic) => {
290                    for elem in dic.kvs.iter_mut() {
291                        self.replace_import(&mut elem.key);
292                        self.replace_import(&mut elem.value);
293                    }
294                }
295                other => todo!("{other}"),
296            },
297            Expr::Record(record) => {
298                for attr in record.attrs.iter_mut() {
299                    for chunk in attr.body.block.iter_mut() {
300                        self.replace_import(chunk);
301                    }
302                }
303            }
304            Expr::BinOp(binop) => {
305                self.replace_import(&mut binop.lhs);
306                self.replace_import(&mut binop.rhs);
307            }
308            Expr::UnaryOp(unaryop) => {
309                self.replace_import(&mut unaryop.expr);
310            }
311            Expr::Call(call) => match call.additional_operation() {
312                Some(OperationKind::Import) => {
313                    self.replace_erg_import(expr);
314                }
315                Some(OperationKind::PyImport) => {
316                    self.replace_py_import(expr);
317                }
318                _ => {
319                    self.replace_import(&mut call.obj);
320                    for arg in call.args.pos_args.iter_mut() {
321                        self.replace_import(&mut arg.expr);
322                    }
323                    if let Some(arg) = call.args.var_args.as_deref_mut() {
324                        self.replace_py_import(&mut arg.expr);
325                    }
326                    for arg in call.args.kw_args.iter_mut() {
327                        self.replace_import(&mut arg.expr);
328                    }
329                }
330            },
331            Expr::Def(def) => {
332                for chunk in def.body.block.iter_mut() {
333                    self.replace_import(chunk);
334                }
335            }
336            Expr::Lambda(lambda) => {
337                for chunk in lambda.body.iter_mut() {
338                    self.replace_import(chunk);
339                }
340            }
341            Expr::ClassDef(class_def) => {
342                for def in class_def.all_methods_mut() {
343                    self.replace_import(def);
344                }
345            }
346            Expr::PatchDef(patch_def) => {
347                for def in patch_def.methods.iter_mut() {
348                    self.replace_import(def);
349                }
350            }
351            Expr::ReDef(redef) => {
352                // REVIEW:
353                for chunk in redef.block.iter_mut() {
354                    self.replace_import(chunk);
355                }
356            }
357            Expr::TypeAsc(tasc) => self.replace_import(&mut tasc.expr),
358            Expr::Code(chunks) | Expr::Compound(chunks) => {
359                for chunk in chunks.iter_mut() {
360                    self.replace_import(chunk);
361                }
362            }
363            Expr::Import(_) => unreachable!(),
364            Expr::Dummy(_) => {}
365        }
366    }
367
368    fn self_module() -> Expr {
369        let __import__ = Identifier::static_public("__import__");
370        let __name__ = Identifier::static_public("__name__");
371        Expr::from(__import__).call1(Expr::from(__name__))
372    }
373
374    fn __file__(&self) -> Expr {
375        let path = self.cfg.input.path().to_path_buf();
376        let token = Token::new_fake(
377            TokenKind::StrLit,
378            format!(
379                "\"{}\"",
380                path.canonicalize().unwrap_or(path).to_string_lossy()
381            ),
382            0,
383            0,
384            0,
385        );
386        let lit = Literal::try_from(token).unwrap();
387        Expr::from(lit)
388    }
389
390    /// ```erg
391    /// x = import "mod"
392    /// ```
393    /// ↓
394    /// ```python
395    /// x =
396    ///     _x = ModuleType("mod")
397    ///     _x.__dict__.update(locals()) # `Nat`, etc. are in locals but treated as globals, so they cannot be put in the third argument of exec.
398    ///     exec(code, _x.__dict__)  # `code` is the mod's content
399    ///     _x
400    /// ```
401    fn replace_erg_import(&self, expr: &mut Expr) {
402        let line = expr.ln_begin().unwrap_or(0);
403        let Some(path) = expr.ref_t().module_path() else {
404            unreachable!()
405        };
406        // # module.er
407        // self = import "module"
408        // ↓
409        // # module.er
410        // self = __import__(__name__)
411        if matches!((path.canonicalize(), self.cfg.input.path().canonicalize()), (Ok(l), Ok(r)) if l == r)
412        {
413            *expr = Self::self_module();
414            return;
415        }
416        // In the case of REPL, entries cannot be used up
417        let hir_cfg = if self.cfg.input.is_repl() {
418            self.mod_cache
419                .get(path.as_path())
420                .and_then(|entry| entry.hir.clone().map(|hir| (hir, entry.cfg().clone())))
421        } else {
422            self.mod_cache
423                .remove(path.as_path())
424                .and_then(|entry| entry.hir.map(|hir| (hir, entry.module.context.cfg.clone())))
425        };
426        let Expr::Call(call) = expr else {
427            log!(err "{expr}");
428            return;
429        };
430        let Some(mod_name) = call.args.get_left_or_key("Path") else {
431            log!(err "{call}");
432            return;
433        };
434        // let sig = option_enum_unwrap!(&def.sig, Signature::Var)
435        //    .unwrap_or_else(|| todo!("module subroutines are not allowed"));
436        if let Some((hir, cfg)) = hir_cfg {
437            *expr = self.modularize(mod_name.clone(), hir, cfg, line, path);
438        } else if let Some(module) = self.removed_mods.borrow().get(&path) {
439            *expr = module.variable.clone();
440        }
441    }
442
443    fn modularize(
444        &self,
445        mod_name: Expr,
446        hir: HIR,
447        cfg: ErgConfig,
448        line: u32,
449        path: PathBuf,
450    ) -> Expr {
451        let tmp = Identifier::private_with_line(self.fresh_gen.fresh_varname(), line);
452        let mod_var = Expr::Accessor(Accessor::Ident(tmp.clone()));
453        let module_type =
454            Expr::Accessor(Accessor::private_with_line(Str::ever("#ModuleType"), line));
455        let args = Args::single(PosArg::new(mod_name));
456        let block = Block::new(vec![module_type.call_expr(args)]);
457        let mod_def = Expr::Def(Def::new(
458            Signature::Var(VarSignature::global(tmp, None)),
459            DefBody::new(EQUAL, block, DefId(0)),
460        ));
461        self.removed_mods
462            .borrow_mut()
463            .insert(path, Mod::new(mod_var.clone(), mod_def));
464        let linker = self.inherit(&cfg);
465        let hir = linker.link_child(hir);
466        let code = Expr::Code(Block::new(Vec::from(hir.module)));
467        let __dict__ = Identifier::static_public("__dict__");
468        let m_dict = mod_var.clone().attr_expr(__dict__);
469        let locals = Expr::Accessor(Accessor::public_with_line(Str::ever("locals"), line));
470        let locals_call = locals.call_expr(Args::empty());
471        let args = Args::single(PosArg::new(locals_call));
472        let mod_update = Expr::Call(Call::new(
473            m_dict.clone(),
474            Some(Identifier::static_public("update")),
475            args,
476        ));
477        let exec = Expr::Accessor(Accessor::public_with_line(Str::ever("exec"), line));
478        let args = Args::pos_only(vec![PosArg::new(code), PosArg::new(m_dict)], None);
479        let exec_code = exec.call_expr(args);
480        let compound = Block::new(vec![mod_update, exec_code, mod_var]);
481        Expr::Compound(compound)
482    }
483
484    /// ```erg
485    /// x = pyimport "x" # called from dir "a"
486    /// ```
487    /// ↓
488    /// ```python
489    /// x = __import__("a.x").x
490    /// ```
491    fn replace_py_import(&self, expr: &mut Expr) {
492        let args = if let Expr::Call(call) = expr {
493            &mut call.args
494        } else {
495            log!(err "{expr}");
496            return;
497        };
498        let Some(Expr::Literal(mod_name_lit)) = args.remove_left_or_key("Path") else {
499            log!(err "{args}");
500            return;
501        };
502        let ValueObj::Str(mod_name_str) = mod_name_lit.value.clone() else {
503            log!(err "{mod_name_lit}");
504            return;
505        };
506        let mut dir = self.cfg.input.dir();
507        let mod_path = self
508            .cfg
509            .input
510            .resolve_decl_path(Path::new(&mod_name_str[..]), self.cfg)
511            .unwrap();
512        if !mod_path
513            .canonicalize()
514            .unwrap()
515            .starts_with(dir.canonicalize().unwrap())
516        {
517            dir = PathBuf::new();
518        }
519        let mod_name_str = if let Some(stripped) = mod_name_str.strip_prefix("./") {
520            stripped
521        } else {
522            &mod_name_str
523        };
524        dir.push(mod_name_str);
525        let dir = squash(dir);
526        let mut comps = dir.components();
527        let _first = comps.next().unwrap();
528        let path = dir.to_string_lossy().replace(['/', '\\'], ".");
529        let token = Token::new_fake(
530            TokenKind::StrLit,
531            format!("\"{path}\""),
532            mod_name_lit.ln_begin().unwrap(),
533            mod_name_lit.col_begin().unwrap(),
534            mod_name_lit.col_end().unwrap(),
535        );
536        let mod_name = Expr::Literal(Literal::try_from(token).unwrap());
537        args.insert_pos(0, PosArg::new(mod_name));
538        let line = expr.ln_begin().unwrap_or(0);
539        for attr in comps {
540            *expr =
541                replace(expr, Expr::Code(Block::empty())).attr_expr(Identifier::public_with_line(
542                    DOT,
543                    Str::rc(attr.as_os_str().to_str().unwrap()),
544                    line,
545                ));
546        }
547    }
548}