// assay/lua/mod.rs

1pub mod async_bridge;
2pub mod builtins;
3
4use anyhow::Result;
5use include_dir::{Dir, include_dir};
6use mlua::{Lua, LuaOptions, StdLib};
7
8static STDLIB_DIR: Dir = include_dir!("$CARGO_MANIFEST_DIR/stdlib");
9
10/// Environment variable to override the global module search path.
11pub const MODULES_PATH_ENV: &str = "ASSAY_MODULES_PATH";
12
13const DANGEROUS_GLOBALS: &[&str] = &["load", "loadfile", "dofile"];
14
15fn lua_err(e: mlua::Error) -> anyhow::Error {
16    anyhow::anyhow!("{e}")
17}
18
19pub fn create_vm(client: reqwest::Client) -> Result<Lua> {
20    create_vm_with_paths(client, None)
21}
22
23#[allow(dead_code)]
24pub fn create_vm_with_lib_path(client: reqwest::Client, lib_path: String) -> Result<Lua> {
25    create_vm_with_paths(client, Some(lib_path))
26}
27
28pub fn create_vm_with_paths(
29    client: reqwest::Client,
30    global_modules_path: Option<String>,
31) -> Result<Lua> {
32    let libs = StdLib::ALL_SAFE;
33    let lua = Lua::new_with(libs, LuaOptions::default()).map_err(lua_err)?;
34    lua.set_memory_limit(64 * 1024 * 1024).map_err(lua_err)?;
35    sandbox(&lua).map_err(lua_err)?;
36    register_fs_loader(&lua, global_modules_path).map_err(lua_err)?;
37    register_stdlib_loader(&lua).map_err(lua_err)?;
38    builtins::register_all(&lua, client).map_err(lua_err)?;
39    Ok(lua)
40}
41
42fn sandbox(lua: &Lua) -> mlua::Result<()> {
43    let globals = lua.globals();
44    for name in DANGEROUS_GLOBALS {
45        globals.set(*name, mlua::Value::Nil)?;
46    }
47
48    let string_lib: mlua::Table = globals.get("string")?;
49    string_lib.set("dump", mlua::Value::Nil)?;
50
51    Ok(())
52}
53
54fn register_stdlib_loader(lua: &Lua) -> mlua::Result<()> {
55    let package: mlua::Table = lua.globals().get("package")?;
56    let searchers: mlua::Table = package.get("searchers")?;
57
58    // Resolves `require("assay.ory.kratos")` -> "ory/kratos.lua" by replacing
59    // dots with slashes, matching standard Lua package loading convention.
60    // Tries "<path>.lua" first, then falls back to "<path>/init.lua" so
61    // both `stdlib/ory.lua` (flat convenience wrapper) and
62    // `stdlib/ory/kratos.lua` (nested submodule) resolve correctly.
63    let stdlib_searcher = lua.create_function(|lua, module_name: String| {
64        let rest = match module_name.strip_prefix("assay.") {
65            Some(r) => r,
66            None => {
67                return Ok(mlua::Value::String(
68                    lua.create_string(format!("not an assay.* module: {module_name}"))?,
69                ));
70            }
71        };
72
73        let base = rest.replace('.', "/");
74        let candidates = [format!("{base}.lua"), format!("{base}/init.lua")];
75
76        for path in &candidates {
77            if let Some(file) = STDLIB_DIR.get_file(path) {
78                let source = file.contents_utf8().ok_or_else(|| {
79                    mlua::Error::runtime(format!("stdlib {path}: invalid UTF-8"))
80                })?;
81                let loader = lua
82                    .load(source)
83                    .set_name(format!("@assay/{path}"))
84                    .into_function()?;
85                return Ok(mlua::Value::Function(loader));
86            }
87        }
88
89        Ok(mlua::Value::String(
90            lua.create_string(format!("no embedded stdlib file: {}", candidates[0]))?,
91        ))
92    })?;
93
94    let len = searchers.len()?;
95    searchers.set(len + 1, stdlib_searcher)?;
96
97    Ok(())
98}
99
100fn register_fs_loader(lua: &Lua, global_modules_path: Option<String>) -> mlua::Result<()> {
101    let package: mlua::Table = lua.globals().get("package")?;
102    let searchers: mlua::Table = package.get("searchers")?;
103
104    // Same dotted-path resolution as the stdlib loader: `assay.ory.kratos`
105    // -> "ory/kratos.lua", falling back to "ory/kratos/init.lua".
106    let fs_searcher = lua.create_function(move |lua, module_name: String| {
107        let rest = match module_name.strip_prefix("assay.") {
108            Some(r) => r,
109            None => {
110                return Ok(mlua::Value::String(
111                    lua.create_string(format!("not an assay.* module: {module_name}"))?,
112                ));
113            }
114        };
115        let base = rest.replace('.', "/");
116        let candidates = [format!("{base}.lua"), format!("{base}/init.lua")];
117
118        let try_load = |dir: &std::path::Path| -> Option<(std::path::PathBuf, String)> {
119            for rel in &candidates {
120                let full = dir.join(rel);
121                if let Ok(source) = std::fs::read_to_string(&full) {
122                    return Some((full, source));
123                }
124            }
125            None
126        };
127
128        // Priority 1: ./modules/<path>.lua (per-project)
129        if let Some((full, source)) = try_load(std::path::Path::new("./modules")) {
130            let loader = lua
131                .load(source)
132                .set_name(format!("@{}", full.display()))
133                .into_function()?;
134            return Ok(mlua::Value::Function(loader));
135        }
136
137        // Priority 2: $ASSAY_MODULES_PATH or ~/.assay/modules/<path>.lua
138        let global_path = if let Some(ref custom_path) = global_modules_path {
139            std::path::PathBuf::from(custom_path)
140        } else if let Ok(modules_env) = std::env::var(MODULES_PATH_ENV) {
141            std::path::PathBuf::from(modules_env)
142        } else if let Ok(home) = std::env::var("HOME") {
143            std::path::Path::new(&home).join(".assay/modules")
144        } else {
145            std::path::PathBuf::new()
146        };
147
148        if !global_path.as_os_str().is_empty()
149            && let Some((full, source)) = try_load(&global_path)
150        {
151            let loader = lua
152                .load(source)
153                .set_name(format!("@{}", full.display()))
154                .into_function()?;
155            return Ok(mlua::Value::Function(loader));
156        }
157
158        // Priority 3: Built-in modules are handled by register_stdlib_loader
159        // Return nil to fall through to the next searcher
160        Ok(mlua::Value::Nil)
161    })?;
162
163    let len = searchers.len()?;
164    searchers.set(len + 1, fs_searcher)?;
165
166    Ok(())
167}
168
169pub fn inject_env(lua: &Lua, env: &std::collections::HashMap<String, String>) -> Result<()> {
170    if env.is_empty() {
171        return Ok(());
172    }
173    let globals = lua.globals();
174    let env_table: mlua::Table = globals.get("env").map_err(lua_err)?;
175    let check_env: mlua::Table = env_table.get("_check_env").map_err(lua_err)?;
176    for (k, v) in env {
177        check_env.set(k.as_str(), v.as_str()).map_err(lua_err)?;
178    }
179    Ok(())
180}