1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
/// Lua sources embedded at compile time from the crate's `stdlib/` directory.
/// `require("assay.<name>")` is served from here (as `<name>.lua`) by the
/// embedded-stdlib searcher when no on-disk module matched first.
static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");

/// Environment variable naming the global modules directory. It is consulted by
/// the filesystem searcher when no explicit path was passed to
/// [`create_vm_with_paths`].
pub const MODULES_PATH_ENV: &str = "ASSAY_MODULES_PATH";

/// Lua base-library globals that `sandbox` overwrites with `nil` in every VM.
const DANGEROUS_GLOBALS: &[&str] = &["load", "loadfile", "dofile"];
14
15fn lua_err(e: mlua::Error) -> anyhow::Error {
16 anyhow::anyhow!("{e}")
17}
18
19pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
20 create_vm_with_paths(client, None)
21}
22
23#[allow(dead_code)]
24pub fn create_vm_with_lib_path(client: reqwest::Client, lib_path: String) -> Result<Lua> {
25 create_vm_with_paths(client, Some(lib_path))
26}
27
28pub fn create_vm_with_paths(client: reqwest::Client, global_modules_path: Option<String>) -> Result<Lua> {
29 let libs = StdLib::ALL_SAFE;
30 let lua = Lua::new_with(libs, LuaOptions::default()).map_err(lua_err)?;
31 lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
32 sandbox(&lua).map_err(lua_err)?;
33 register_fs_loader(&lua, global_modules_path).map_err(lua_err)?;
34 register_stdlib_loader(&lua).map_err(lua_err)?;
35 builtins::register_all(&lua, client).map_err(lua_err)?;
36 Ok(lua)
37}
38
39fn sandbox(lua: &Lua) -> mlua::Result<()> {
40 let globals = lua.globals();
41 for name in DANGEROUS_GLOBALS {
42 globals.set(*name, mlua::Value::Nil)?;
43 }
44
45 let string_lib: mlua::Table = globals.get("string")?;
46 string_lib.set("dump", mlua::Value::Nil)?;
47
48 Ok(())
49}
50
51fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
52 let package: mlua::Table = lua.globals().get("package")?;
53 let searchers: mlua::Table = package.get("searchers")?;
54
55 let stdlib_searcher = lua.create_function(|lua, module_name: String| {
56 let path = if let Some(rest) = module_name.strip_prefix("assay.") {
57 format!("{rest}.lua")
58 } else {
59 return Ok(mlua::Value::String(
60 lua.create_string(format!("not an assay.* module: {module_name}"))?,
61 ));
62 };
63
64 match STDLIB_DIR.get_file(&path) {
65 Some(file) => {
66 let source = file
67 .contents_utf8()
68 .ok_or_else(|| mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8")))?;
69 let loader = lua
70 .load(source)
71 .set_name(format!("@assay/{path}"))
72 .into_function()?;
73 Ok(mlua::Value::Function(loader))
74 }
75 None => Ok(mlua::Value::String(
76 lua.create_string(format!("no embedded stdlib file: {path}"))?,
77 )),
78 }
79 })?;
80
81 let len = searchers.len()?;
82 searchers.set(len + 1, stdlib_searcher)?;
83
84 Ok(())
85}
86
87fn register_fs_loader(lua: &Lua, global_modules_path: Option<String>) -> mlua::Result<()> {
88 let package: mlua::Table = lua.globals().get("package")?;
89 let searchers: mlua::Table = package.get("searchers")?;
90
91 let fs_searcher = lua.create_function(move |lua, module_name: String| {
92 let filename = if let Some(rest) = module_name.strip_prefix("assay.") {
93 format!("{rest}.lua")
94 } else {
95 return Ok(mlua::Value::String(
96 lua.create_string(format!("not an assay.* module: {module_name}"))?,
97 ));
98 };
99
100 let project_path = std::path::Path::new("./modules").join(&filename);
102 if let Ok(source) = std::fs::read_to_string(&project_path) {
103 let loader = lua
104 .load(source)
105 .set_name(format!("@{}", project_path.display()))
106 .into_function()?;
107 return Ok(mlua::Value::Function(loader));
108 }
109
110 let global_path = if let Some(ref custom_path) = global_modules_path {
112 std::path::PathBuf::from(custom_path)
113 } else if let Ok(modules_env) = std::env::var(MODULES_PATH_ENV) {
114 std::path::PathBuf::from(modules_env)
115 } else if let Ok(home) = std::env::var("HOME") {
116 std::path::Path::new(&home).join(".assay/modules")
117 } else {
118 std::path::PathBuf::new()
120 };
121
122 if !global_path.as_os_str().is_empty() {
123 let global_file_path = global_path.join(&filename);
124 if let Ok(source) = std::fs::read_to_string(&global_file_path) {
125 let loader = lua
126 .load(source)
127 .set_name(format!("@{}", global_file_path.display()))
128 .into_function()?;
129 return Ok(mlua::Value::Function(loader));
130 }
131 }
132
133 Ok(mlua::Value::Nil)
136 })?;
137
138 let len = searchers.len()?;
139 searchers.set(len + 1, fs_searcher)?;
140
141 Ok(())
142}
143
144pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
145 if env.is_empty() {
146 return Ok(());
147 }
148 let globals = lua.globals();
149 let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
150 let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
151 for (k, v) in env {
152 check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
153 }
154 Ok(())
155}