dalbit_core/
injector.rs

1use std::{
2    borrow::Cow,
3    collections::HashSet,
4    path::{Path, PathBuf},
5};
6
7use anyhow::{anyhow, Result};
8use full_moon::{
9    tokenizer::{Token, TokenType},
10    visitors::Visitor,
11    LuaVersion,
12};
13use path_slash::PathBufExt;
14use pathdiff::diff_paths;
15use tokio::fs;
16
/// Ensures `path` is spelled as an explicit relative path.
///
/// Paths whose first component is `.` or `..` are returned unchanged (borrowed).
/// Any other path is prefixed with `./` (owned). Because `Path::join` replaces the
/// base when its argument is absolute, an absolute `path` also comes back unchanged,
/// as an owned copy.
fn make_relative(path: &Path) -> Cow<'_, Path> {
    if path.starts_with(".") || path.starts_with("..") {
        Cow::Borrowed(path)
    } else {
        Cow::Owned(Path::new(".").join(path))
    }
}
25
26#[derive(Debug)]
27struct CollectUsedLibraries {
28    libraries: HashSet<String>,
29    used_libraries: HashSet<String>,
30}
31
32impl CollectUsedLibraries {
33    fn new(libraries: HashSet<String>) -> Self {
34        Self {
35            libraries,
36            used_libraries: HashSet::new(),
37        }
38    }
39}
40
41impl Visitor for CollectUsedLibraries {
42    fn visit_identifier(&mut self, identifier: &Token) {
43        if let TokenType::Identifier { identifier } = identifier.token_type() {
44            let identifier = identifier.to_string();
45            if self.libraries.contains(&identifier) {
46                self.used_libraries.insert(identifier);
47            }
48        }
49    }
50}
51
52/// Injector that injects module's export which is a table constructor.
53pub struct Injector {
54    module_path: PathBuf,
55    exports: HashSet<String>,
56    removes: Option<Vec<String>>,
57    lua_version: LuaVersion,
58}
59
60impl Injector {
61    pub fn new(
62        module_path: PathBuf,
63        exports: HashSet<String>,
64        lua_version: LuaVersion,
65        removes: Option<Vec<String>>,
66    ) -> Self {
67        Self {
68            module_path,
69            exports,
70            removes,
71            lua_version,
72        }
73    }
74
75    pub fn module_path(&self) -> &PathBuf {
76        &self.module_path
77    }
78
79    pub fn removes(&self) -> &Option<Vec<String>> {
80        &self.removes
81    }
82
83    pub async fn inject(&self, source_path: &PathBuf) -> Result<()> {
84        let parent = source_path
85            .parent()
86            .ok_or(anyhow!("File path must have parent path"))?;
87        let require_path = diff_paths(self.module_path(), parent)
88            .ok_or(anyhow!("Couldn't resolve the require path"))?
89            .with_extension("");
90        let require_path = make_relative(&require_path).to_path_buf();
91
92        let code = fs::read_to_string(source_path).await?;
93
94        let mut lines: Vec<String> = code.lines().map(String::from).collect();
95        let mut libraries_texts: Vec<String> = Vec::new();
96
97        let ast = full_moon::parse_fallible(code.as_str(), self.lua_version)
98            .into_result()
99            .map_err(|errors| anyhow!("{:?}", errors))?;
100
101        let mut collect_used_libs = CollectUsedLibraries::new(self.exports.clone());
102        collect_used_libs.visit_ast(&ast);
103
104        for lib in collect_used_libs.used_libraries {
105            log::debug!("used library: {}", lib);
106            libraries_texts.push(format!(
107                "local {}=require'{}'.{} ",
108                lib,
109                require_path.to_slash_lossy(),
110                lib
111            ));
112        }
113
114        if let Some(removes) = self.removes() {
115            for lib in removes {
116                libraries_texts.push(format!("local {}=nil ", lib));
117            }
118        }
119
120        let libraries_text = libraries_texts.join("");
121        if let Some(first_line) = lines.get_mut(0) {
122            first_line.insert_str(0, &libraries_text);
123        } else {
124            lines.push(libraries_text);
125        }
126
127        let new_content = lines.join("\n");
128
129        log::debug!("injected source path: {:?}", source_path);
130
131        fs::write(source_path, new_content).await?;
132
133        Ok(())
134    }
135}