Skip to main content

assay/lua/
mod.rs

1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
8static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");
9
/// Global names that `sandbox` overwrites with `nil` in every freshly created VM.
const DANGEROUS_GLOBALS: &[&str] = &[
    "load",           // compiles arbitrary code from strings
    "loadfile",       // compiles a file from disk
    "dofile",         // compiles and runs a file from disk
    "collectgarbage", // direct control over the collector
    "print",          // removed from the default environment (builtins may provide their own — TODO confirm)
];
11
12fn lua_err(e: mlua::Error) -> anyhow::Error {
13    anyhow::anyhow!("{e}")
14}
15
16pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
17    let safe_libs = StdLib::ALL_SAFE ^ StdLib::IO ^ StdLib::OS;
18    let lua = Lua::new_with(safe_libs, LuaOptions::default()).map_err(lua_err)?;
19    lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
20    sandbox(&lua).map_err(lua_err)?;
21    register_stdlib_loader(&lua).map_err(lua_err)?;
22    builtins::register_all(&lua, client).map_err(lua_err)?;
23    Ok(lua)
24}
25
26fn sandbox(lua: &Lua) -> mlua::Result<()> {
27    let globals = lua.globals();
28    for name in DANGEROUS_GLOBALS {
29        globals.set(*name, mlua::Value::Nil)?;
30    }
31
32    let string_lib: mlua::Table = globals.get("string")?;
33    string_lib.set("dump", mlua::Value::Nil)?;
34
35    Ok(())
36}
37
38fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
39    let package: mlua::Table = lua.globals().get("package")?;
40    let searchers: mlua::Table = package.get("searchers")?;
41
42    let stdlib_searcher = lua.create_function(|lua, module_name: String| {
43        let path = if let Some(rest) = module_name.strip_prefix("assay.") {
44            format!("{rest}.lua")
45        } else {
46            return Ok(mlua::Value::String(
47                lua.create_string(format!("not an assay.* module: {module_name}"))?,
48            ));
49        };
50
51        match STDLIB_DIR.get_file(&path) {
52            Some(file) => {
53                let source = file
54                    .contents_utf8()
55                    .ok_or_else(|| mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8")))?;
56                let loader = lua
57                    .load(source)
58                    .set_name(format!("@assay/{path}"))
59                    .into_function()?;
60                Ok(mlua::Value::Function(loader))
61            }
62            None => Ok(mlua::Value::String(
63                lua.create_string(format!("no embedded stdlib file: {path}"))?,
64            )),
65        }
66    })?;
67
68    let len = searchers.len()?;
69    searchers.set(len + 1, stdlib_searcher)?;
70
71    Ok(())
72}
73
74pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
75    if env.is_empty() {
76        return Ok(());
77    }
78    let globals = lua.globals();
79    let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
80    let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
81    for (k, v) in env {
82        check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
83    }
84    Ok(())
85}