1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
/// Lua sources of the `assay.*` standard library, embedded at compile time from the crate's
/// `stdlib/` directory and served by the searcher installed in `register_stdlib_loader`.
static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");

/// Environment variable naming the global directory that the filesystem module searcher
/// consults for `assay.*` modules when no explicit path is passed to `create_vm_with_paths`.
pub const MODULES_PATH_ENV: &str = "ASSAY_MODULES_PATH";
12
/// Base-library globals that `sandbox` clears because they compile or run arbitrary Lua source
/// (from a string or a file path) outside the controlled `require` searchers.
///
/// `loadstring` only exists on some backends (Lua 5.1 / LuaJIT / Luau, or 5.2 compat builds);
/// setting an absent global to `nil` is a no-op, so listing it is harmless elsewhere.
const DANGEROUS_GLOBALS: &[&str] = &["load", "loadstring", "loadfile", "dofile"];
14
15fn lua_err(e: mlua::Error) -> anyhow::Error {
16 anyhow::anyhow!("{e}")
17}
18
19pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
20 create_vm_with_paths(client, None)
21}
22
23#[allow(dead_code)]
24pub fn create_vm_with_lib_path(client: reqwest::Client, lib_path: String) -> Result<Lua> {
25 create_vm_with_paths(client, Some(lib_path))
26}
27
28pub fn create_vm_with_paths(
29 client: reqwest::Client,
30 global_modules_path: Option<String>,
31) -> Result<Lua> {
32 let libs = StdLib::ALL_SAFE;
33 let lua = Lua::new_with(libs, LuaOptions::default()).map_err(lua_err)?;
34 lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
35 sandbox(&lua).map_err(lua_err)?;
36 register_fs_loader(&lua, global_modules_path).map_err(lua_err)?;
37 register_stdlib_loader(&lua).map_err(lua_err)?;
38 builtins::register_all(&lua, client).map_err(lua_err)?;
39 Ok(lua)
40}
41
42fn sandbox(lua: &Lua) -> mlua::Result<()> {
43 let globals = lua.globals();
44 for name in DANGEROUS_GLOBALS {
45 globals.set(*name, mlua::Value::Nil)?;
46 }
47
48 let string_lib: mlua::Table = globals.get("string")?;
49 string_lib.set("dump", mlua::Value::Nil)?;
50
51 Ok(())
52}
53
54fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
55 let package: mlua::Table = lua.globals().get("package")?;
56 let searchers: mlua::Table = package.get("searchers")?;
57
58 let stdlib_searcher = lua.create_function(|lua, module_name: String| {
64 let rest = match module_name.strip_prefix("assay.") {
65 Some(r) => r,
66 None => {
67 return Ok(mlua::Value::String(
68 lua.create_string(format!("not an assay.* module: {module_name}"))?,
69 ));
70 }
71 };
72
73 let base = rest.replace('.', "/");
74 let candidates = [format!("{base}.lua"), format!("{base}/init.lua")];
75
76 for path in &candidates {
77 if let Some(file) = STDLIB_DIR.get_file(path) {
78 let source = file.contents_utf8().ok_or_else(|| {
79 mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8"))
80 })?;
81 let loader = lua
82 .load(source)
83 .set_name(format!("@assay/{path}"))
84 .into_function()?;
85 return Ok(mlua::Value::Function(loader));
86 }
87 }
88
89 Ok(mlua::Value::String(
90 lua.create_string(format!("no embedded stdlib file: {}", candidates[0]))?,
91 ))
92 })?;
93
94 let len = searchers.len()?;
95 searchers.set(len + 1, stdlib_searcher)?;
96
97 Ok(())
98}
99
100fn register_fs_loader(lua: &Lua, global_modules_path: Option<String>) -> mlua::Result<()> {
101 let package: mlua::Table = lua.globals().get("package")?;
102 let searchers: mlua::Table = package.get("searchers")?;
103
104 let fs_searcher = lua.create_function(move |lua, module_name: String| {
107 let rest = match module_name.strip_prefix("assay.") {
108 Some(r) => r,
109 None => {
110 return Ok(mlua::Value::String(
111 lua.create_string(format!("not an assay.* module: {module_name}"))?,
112 ));
113 }
114 };
115 let base = rest.replace('.', "/");
116 let candidates = [format!("{base}.lua"), format!("{base}/init.lua")];
117
118 let try_load = |dir: &std::path::Path| -> Option<(std::path::PathBuf, String)> {
119 for rel in &candidates {
120 let full = dir.join(rel);
121 if let Ok(source) = std::fs::read_to_string(&full) {
122 return Some((full, source));
123 }
124 }
125 None
126 };
127
128 if let Some((full, source)) = try_load(std::path::Path::new("./modules")) {
130 let loader = lua
131 .load(source)
132 .set_name(format!("@{}", full.display()))
133 .into_function()?;
134 return Ok(mlua::Value::Function(loader));
135 }
136
137 let global_path = if let Some(ref custom_path) = global_modules_path {
139 std::path::PathBuf::from(custom_path)
140 } else if let Ok(modules_env) = std::env::var(MODULES_PATH_ENV) {
141 std::path::PathBuf::from(modules_env)
142 } else if let Ok(home) = std::env::var("HOME") {
143 std::path::Path::new(&home).join(".assay/modules")
144 } else {
145 std::path::PathBuf::new()
146 };
147
148 if !global_path.as_os_str().is_empty()
149 && let Some((full, source)) = try_load(&global_path)
150 {
151 let loader = lua
152 .load(source)
153 .set_name(format!("@{}", full.display()))
154 .into_function()?;
155 return Ok(mlua::Value::Function(loader));
156 }
157
158 Ok(mlua::Value::Nil)
161 })?;
162
163 let len = searchers.len()?;
164 searchers.set(len + 1, fs_searcher)?;
165
166 Ok(())
167}
168
169pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
170 if env.is_empty() {
171 return Ok(());
172 }
173 let globals = lua.globals();
174 let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
175 let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
176 for (k, v) in env {
177 check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
178 }
179 Ok(())
180}