Skip to main content

assay/lua/
mod.rs

1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
8static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");
9
/// Name of the environment variable that overrides where external Lua
/// libraries are searched for.
pub const LIB_PATH_ENV: &str = "ASSAY_LIB_PATH";

/// Directory searched for external Lua libraries when [`LIB_PATH_ENV`] is not
/// usable.
const DEFAULT_LIB_PATH: &str = "/libs";

/// Global names that are set to `nil` when a VM is sandboxed.
const DANGEROUS_GLOBALS: &[&'static str] = &[
    "load",
    "loadfile",
    "dofile",
    "collectgarbage",
    "print",
];
17
18fn lua_err(e: mlua::Error) -> anyhow::Error {
19    anyhow::anyhow!("{e}")
20}
21
22pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
23    let lib_path = std::env::var(LIB_PATH_ENV).unwrap_or_else(|_| DEFAULT_LIB_PATH.to_string());
24    create_vm_with_lib_path(client, lib_path)
25}
26
27pub fn create_vm_with_lib_path(client: reqwest::Client, lib_path: String) -> Result<Lua> {
28    let safe_libs = StdLib::ALL_SAFE ^ StdLib::IO ^ StdLib::OS;
29    let lua = Lua::new_with(safe_libs, LuaOptions::default()).map_err(lua_err)?;
30    lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
31    sandbox(&lua).map_err(lua_err)?;
32    register_stdlib_loader(&lua).map_err(lua_err)?;
33    register_fs_loader(&lua, lib_path).map_err(lua_err)?;
34    builtins::register_all(&lua, client).map_err(lua_err)?;
35    Ok(lua)
36}
37
38fn sandbox(lua: &Lua) -> mlua::Result<()> {
39    let globals = lua.globals();
40    for name in DANGEROUS_GLOBALS {
41        globals.set(*name, mlua::Value::Nil)?;
42    }
43
44    let string_lib: mlua::Table = globals.get("string")?;
45    string_lib.set("dump", mlua::Value::Nil)?;
46
47    Ok(())
48}
49
50fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
51    let package: mlua::Table = lua.globals().get("package")?;
52    let searchers: mlua::Table = package.get("searchers")?;
53
54    let stdlib_searcher = lua.create_function(|lua, module_name: String| {
55        let path = if let Some(rest) = module_name.strip_prefix("assay.") {
56            format!("{rest}.lua")
57        } else {
58            return Ok(mlua::Value::String(
59                lua.create_string(format!("not an assay.* module: {module_name}"))?,
60            ));
61        };
62
63        match STDLIB_DIR.get_file(&path) {
64            Some(file) => {
65                let source = file
66                    .contents_utf8()
67                    .ok_or_else(|| mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8")))?;
68                let loader = lua
69                    .load(source)
70                    .set_name(format!("@assay/{path}"))
71                    .into_function()?;
72                Ok(mlua::Value::Function(loader))
73            }
74            None => Ok(mlua::Value::String(
75                lua.create_string(format!("no embedded stdlib file: {path}"))?,
76            )),
77        }
78    })?;
79
80    let len = searchers.len()?;
81    searchers.set(len + 1, stdlib_searcher)?;
82
83    Ok(())
84}
85
86fn register_fs_loader(lua: &Lua, lib_path: String) -> mlua::Result<()> {
87    let package: mlua::Table = lua.globals().get("package")?;
88    let searchers: mlua::Table = package.get("searchers")?;
89
90    let fs_searcher = lua.create_function(move |lua, module_name: String| {
91        let filename = if let Some(rest) = module_name.strip_prefix("assay.") {
92            format!("{rest}.lua")
93        } else {
94            return Ok(mlua::Value::String(
95                lua.create_string(format!("not an assay.* module: {module_name}"))?,
96            ));
97        };
98
99        let full_path = std::path::Path::new(&lib_path).join(&filename);
100
101        match std::fs::read_to_string(&full_path) {
102            Ok(source) => {
103                let loader = lua
104                    .load(source)
105                    .set_name(format!("@{}", full_path.display()))
106                    .into_function()?;
107                Ok(mlua::Value::Function(loader))
108            }
109            Err(_) => Ok(mlua::Value::String(
110                lua.create_string(format!("no file at {}", full_path.display()))?,
111            )),
112        }
113    })?;
114
115    let len = searchers.len()?;
116    searchers.set(len + 1, fs_searcher)?;
117
118    Ok(())
119}
120
121pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
122    if env.is_empty() {
123        return Ok(());
124    }
125    let globals = lua.globals();
126    let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
127    let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
128    for (k, v) in env {
129        check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
130    }
131    Ok(())
132}