Skip to main content

assay/lua/
mod.rs

1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
8static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");
9
/// Environment variable to override the global module search path.
pub const MODULES_PATH_ENV: &str = "ASSAY_MODULES_PATH";

/// Globals cleared from every VM because they compile or run Lua code supplied at
/// runtime (a source string, a file path, or bytecode).
///
/// `loadstring` only exists in older Lua dialects (e.g. Lua 5.1 / LuaJIT); clearing a
/// global that is absent is a no-op, so listing it keeps the sandbox closed regardless
/// of which Lua version is linked.
const DANGEROUS_GLOBALS: &[&str] = &["load", "loadfile", "dofile", "loadstring"];

15fn lua_err(e: mlua::Error) -> anyhow::Error {
16    anyhow::anyhow!("{e}")
17}
18
19pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
20    create_vm_with_paths(client, None)
21}
22
23#[allow(dead_code)]
24pub fn create_vm_with_lib_path(client: reqwest::Client, lib_path: String) -> Result<Lua> {
25    create_vm_with_paths(client, Some(lib_path))
26}
27
28pub fn create_vm_with_paths(client: reqwest::Client, global_modules_path: Option<String>) -> Result<Lua> {
29    let libs = StdLib::ALL_SAFE;
30    let lua = Lua::new_with(libs, LuaOptions::default()).map_err(lua_err)?;
31    lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
32    sandbox(&lua).map_err(lua_err)?;
33    register_fs_loader(&lua, global_modules_path).map_err(lua_err)?;
34    register_stdlib_loader(&lua).map_err(lua_err)?;
35    builtins::register_all(&lua, client).map_err(lua_err)?;
36    Ok(lua)
37}
38
39fn sandbox(lua: &Lua) -> mlua::Result<()> {
40    let globals = lua.globals();
41    for name in DANGEROUS_GLOBALS {
42        globals.set(*name, mlua::Value::Nil)?;
43    }
44
45    let string_lib: mlua::Table = globals.get("string")?;
46    string_lib.set("dump", mlua::Value::Nil)?;
47
48    Ok(())
49}
50
51fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
52    let package: mlua::Table = lua.globals().get("package")?;
53    let searchers: mlua::Table = package.get("searchers")?;
54
55    let stdlib_searcher = lua.create_function(|lua, module_name: String| {
56        let path = if let Some(rest) = module_name.strip_prefix("assay.") {
57            format!("{rest}.lua")
58        } else {
59            return Ok(mlua::Value::String(
60                lua.create_string(format!("not an assay.* module: {module_name}"))?,
61            ));
62        };
63
64        match STDLIB_DIR.get_file(&path) {
65            Some(file) => {
66                let source = file
67                    .contents_utf8()
68                    .ok_or_else(|| mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8")))?;
69                let loader = lua
70                    .load(source)
71                    .set_name(format!("@assay/{path}"))
72                    .into_function()?;
73                Ok(mlua::Value::Function(loader))
74            }
75            None => Ok(mlua::Value::String(
76                lua.create_string(format!("no embedded stdlib file: {path}"))?,
77            )),
78        }
79    })?;
80
81    let len = searchers.len()?;
82    searchers.set(len + 1, stdlib_searcher)?;
83
84    Ok(())
85}
86
87fn register_fs_loader(lua: &Lua, global_modules_path: Option<String>) -> mlua::Result<()> {
88    let package: mlua::Table = lua.globals().get("package")?;
89    let searchers: mlua::Table = package.get("searchers")?;
90
91    let fs_searcher = lua.create_function(move |lua, module_name: String| {
92        let filename = if let Some(rest) = module_name.strip_prefix("assay.") {
93            format!("{rest}.lua")
94        } else {
95            return Ok(mlua::Value::String(
96                lua.create_string(format!("not an assay.* module: {module_name}"))?,
97            ));
98        };
99
100        // Priority 1: ./modules/<name>.lua (per-project)
101        let project_path = std::path::Path::new("./modules").join(&filename);
102        if let Ok(source) = std::fs::read_to_string(&project_path) {
103            let loader = lua
104                .load(source)
105                .set_name(format!("@{}", project_path.display()))
106                .into_function()?;
107            return Ok(mlua::Value::Function(loader));
108        }
109
110        // Priority 2: ~/.assay/modules/<name>.lua (global user) or $ASSAY_MODULES_PATH
111        let global_path = if let Some(ref custom_path) = global_modules_path {
112            std::path::PathBuf::from(custom_path)
113        } else if let Ok(modules_env) = std::env::var(MODULES_PATH_ENV) {
114            std::path::PathBuf::from(modules_env)
115        } else if let Ok(home) = std::env::var("HOME") {
116            std::path::Path::new(&home).join(".assay/modules")
117        } else {
118            // No home directory available, skip global path
119            std::path::PathBuf::new()
120        };
121
122        if !global_path.as_os_str().is_empty() {
123            let global_file_path = global_path.join(&filename);
124            if let Ok(source) = std::fs::read_to_string(&global_file_path) {
125                let loader = lua
126                    .load(source)
127                    .set_name(format!("@{}", global_file_path.display()))
128                    .into_function()?;
129                return Ok(mlua::Value::Function(loader));
130            }
131        }
132
133        // Priority 3: Built-in modules are handled by register_stdlib_loader
134        // Return nil to fall through to the next searcher
135        Ok(mlua::Value::Nil)
136    })?;
137
138    let len = searchers.len()?;
139    searchers.set(len + 1, fs_searcher)?;
140
141    Ok(())
142}
143
144pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
145    if env.is_empty() {
146        return Ok(());
147    }
148    let globals = lua.globals();
149    let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
150    let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
151    for (k, v) in env {
152        check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
153    }
154    Ok(())
155}