1use std::{
2 borrow::Cow,
3 collections::HashSet,
4 path::{Path, PathBuf},
5};
6
7use anyhow::{anyhow, Result};
8use full_moon::{
9 tokenizer::{Token, TokenType},
10 visitors::Visitor,
11 LuaVersion,
12};
13use path_slash::PathBufExt;
14use pathdiff::diff_paths;
15use tokio::fs;
16
/// Ensures `path` is explicitly relative by prefixing it with `./`, unless it
/// already begins with a `.` or `..` component.
///
/// When no prefix is needed the input is returned borrowed, so the
/// already-prefixed case does not allocate. Note that `Path::join` replaces the
/// base when its argument is absolute, so an absolute `path` comes back
/// unchanged (as an owned copy).
fn make_relative(path: &Path) -> Cow<'_, Path> {
    // `Path::starts_with` compares whole components, so a name like `.hidden/x`
    // does NOT count as starting with `.` and still receives the `./` prefix.
    if path.starts_with(".") || path.starts_with("..") {
        Cow::Borrowed(path)
    } else {
        Cow::Owned(Path::new(".").join(path))
    }
}
25
/// Visitor state that records which names from a fixed candidate set occur
/// as identifier tokens in a parsed Lua AST.
#[derive(Debug)]
struct CollectUsedLibraries {
    /// Candidate library names to watch for.
    libraries: HashSet<String>,
    /// Candidates that have been encountered so far.
    used_libraries: HashSet<String>,
}

impl CollectUsedLibraries {
    /// Builds a collector watching for `libraries`, with nothing recorded yet.
    fn new(libraries: HashSet<String>) -> Self {
        let used_libraries = HashSet::default();
        Self {
            used_libraries,
            libraries,
        }
    }
}
40
41impl Visitor for CollectUsedLibraries {
42 fn visit_identifier(&mut self, identifier: &Token) {
43 if let TokenType::Identifier { identifier } = identifier.token_type() {
44 let identifier = identifier.to_string();
45 if self.libraries.contains(&identifier) {
46 self.used_libraries.insert(identifier);
47 }
48 }
49 }
50}
51
52pub struct Injector {
54 module_path: PathBuf,
55 exports: HashSet<String>,
56 removes: Option<Vec<String>>,
57 lua_version: LuaVersion,
58}
59
60impl Injector {
61 pub fn new(
62 module_path: PathBuf,
63 exports: HashSet<String>,
64 lua_version: LuaVersion,
65 removes: Option<Vec<String>>,
66 ) -> Self {
67 Self {
68 module_path,
69 exports,
70 removes,
71 lua_version,
72 }
73 }
74
75 pub fn module_path(&self) -> &PathBuf {
76 &self.module_path
77 }
78
79 pub fn removes(&self) -> &Option<Vec<String>> {
80 &self.removes
81 }
82
83 pub async fn inject(&self, source_path: &PathBuf) -> Result<()> {
84 let parent = source_path
85 .parent()
86 .ok_or(anyhow!("File path must have parent path"))?;
87 let require_path = diff_paths(self.module_path(), parent)
88 .ok_or(anyhow!("Couldn't resolve the require path"))?
89 .with_extension("");
90 let require_path = make_relative(&require_path).to_path_buf();
91
92 let code = fs::read_to_string(source_path).await?;
93
94 let mut lines: Vec<String> = code.lines().map(String::from).collect();
95 let mut libraries_texts: Vec<String> = Vec::new();
96
97 let ast = full_moon::parse_fallible(code.as_str(), self.lua_version)
98 .into_result()
99 .map_err(|errors| anyhow!("{:?}", errors))?;
100
101 let mut collect_used_libs = CollectUsedLibraries::new(self.exports.clone());
102 collect_used_libs.visit_ast(&ast);
103
104 for lib in collect_used_libs.used_libraries {
105 log::debug!("used library: {}", lib);
106 libraries_texts.push(format!(
107 "local {}=require'{}'.{} ",
108 lib,
109 require_path.to_slash_lossy(),
110 lib
111 ));
112 }
113
114 if let Some(removes) = self.removes() {
115 for lib in removes {
116 libraries_texts.push(format!("local {}=nil ", lib));
117 }
118 }
119
120 let libraries_text = libraries_texts.join("");
121 if let Some(first_line) = lines.get_mut(0) {
122 first_line.insert_str(0, &libraries_text);
123 } else {
124 lines.push(libraries_text);
125 }
126
127 let new_content = lines.join("\n");
128
129 log::debug!("injected source path: {:?}", source_path);
130
131 fs::write(source_path, new_content).await?;
132
133 Ok(())
134 }
135}