Skip to main content

assay/lua/
mod.rs

1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
8static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");
9
/// Name of the environment variable that, when set, replaces `~/.assay/modules` as the
/// global (per-user) directory searched for `assay.*` modules.
pub const MODULES_PATH_ENV: &str = "ASSAY_MODULES_PATH";

/// Global functions that `sandbox` sets to `nil` because they compile Lua code at runtime.
const DANGEROUS_GLOBALS: &[&str] = &[
    "load",     // compiles a chunk from a string or reader function
    "loadfile", // compiles a chunk read from a file
    "dofile",   // compiles and immediately runs a file
];
14
15fn lua_err(e: mlua::Error) -> anyhow::Error {
16    anyhow::anyhow!("{e}")
17}
18
19pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
20    create_vm_with_paths(client, None)
21}
22
23#[allow(dead_code)]
24pub fn create_vm_with_lib_path(client: reqwest::Client, lib_path: String) -> Result<Lua> {
25    create_vm_with_paths(client, Some(lib_path))
26}
27
28pub fn create_vm_with_paths(
29    client: reqwest::Client,
30    global_modules_path: Option<String>,
31) -> Result<Lua> {
32    let libs = StdLib::ALL_SAFE;
33    let lua = Lua::new_with(libs, LuaOptions::default()).map_err(lua_err)?;
34    lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
35    sandbox(&lua).map_err(lua_err)?;
36    register_fs_loader(&lua, global_modules_path).map_err(lua_err)?;
37    register_stdlib_loader(&lua).map_err(lua_err)?;
38    builtins::register_all(&lua, client).map_err(lua_err)?;
39    Ok(lua)
40}
41
42fn sandbox(lua: &Lua) -> mlua::Result<()> {
43    let globals = lua.globals();
44    for name in DANGEROUS_GLOBALS {
45        globals.set(*name, mlua::Value::Nil)?;
46    }
47
48    let string_lib: mlua::Table = globals.get("string")?;
49    string_lib.set("dump", mlua::Value::Nil)?;
50
51    Ok(())
52}
53
54fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
55    let package: mlua::Table = lua.globals().get("package")?;
56    let searchers: mlua::Table = package.get("searchers")?;
57
58    let stdlib_searcher = lua.create_function(|lua, module_name: String| {
59        let path = if let Some(rest) = module_name.strip_prefix("assay.") {
60            format!("{rest}.lua")
61        } else {
62            return Ok(mlua::Value::String(
63                lua.create_string(format!("not an assay.* module: {module_name}"))?,
64            ));
65        };
66
67        match STDLIB_DIR.get_file(&path) {
68            Some(file) => {
69                let source = file
70                    .contents_utf8()
71                    .ok_or_else(|| mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8")))?;
72                let loader = lua
73                    .load(source)
74                    .set_name(format!("@assay/{path}"))
75                    .into_function()?;
76                Ok(mlua::Value::Function(loader))
77            }
78            None => Ok(mlua::Value::String(
79                lua.create_string(format!("no embedded stdlib file: {path}"))?,
80            )),
81        }
82    })?;
83
84    let len = searchers.len()?;
85    searchers.set(len + 1, stdlib_searcher)?;
86
87    Ok(())
88}
89
90fn register_fs_loader(lua: &Lua, global_modules_path: Option<String>) -> mlua::Result<()> {
91    let package: mlua::Table = lua.globals().get("package")?;
92    let searchers: mlua::Table = package.get("searchers")?;
93
94    let fs_searcher = lua.create_function(move |lua, module_name: String| {
95        let filename = if let Some(rest) = module_name.strip_prefix("assay.") {
96            format!("{rest}.lua")
97        } else {
98            return Ok(mlua::Value::String(
99                lua.create_string(format!("not an assay.* module: {module_name}"))?,
100            ));
101        };
102
103        // Priority 1: ./modules/<name>.lua (per-project)
104        let project_path = std::path::Path::new("./modules").join(&filename);
105        if let Ok(source) = std::fs::read_to_string(&project_path) {
106            let loader = lua
107                .load(source)
108                .set_name(format!("@{}", project_path.display()))
109                .into_function()?;
110            return Ok(mlua::Value::Function(loader));
111        }
112
113        // Priority 2: ~/.assay/modules/<name>.lua (global user) or $ASSAY_MODULES_PATH
114        let global_path = if let Some(ref custom_path) = global_modules_path {
115            std::path::PathBuf::from(custom_path)
116        } else if let Ok(modules_env) = std::env::var(MODULES_PATH_ENV) {
117            std::path::PathBuf::from(modules_env)
118        } else if let Ok(home) = std::env::var("HOME") {
119            std::path::Path::new(&home).join(".assay/modules")
120        } else {
121            // No home directory available, skip global path
122            std::path::PathBuf::new()
123        };
124
125        if !global_path.as_os_str().is_empty() {
126            let global_file_path = global_path.join(&filename);
127            if let Ok(source) = std::fs::read_to_string(&global_file_path) {
128                let loader = lua
129                    .load(source)
130                    .set_name(format!("@{}", global_file_path.display()))
131                    .into_function()?;
132                return Ok(mlua::Value::Function(loader));
133            }
134        }
135
136        // Priority 3: Built-in modules are handled by register_stdlib_loader
137        // Return nil to fall through to the next searcher
138        Ok(mlua::Value::Nil)
139    })?;
140
141    let len = searchers.len()?;
142    searchers.set(len + 1, fs_searcher)?;
143
144    Ok(())
145}
146
147pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
148    if env.is_empty() {
149        return Ok(());
150    }
151    let globals = lua.globals();
152    let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
153    let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
154    for (k, v) in env {
155        check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
156    }
157    Ok(())
158}