pub mod policy;
use crate::runtime::docs::{self, LuaDoc, LuaDocTyp};
use std::sync::Arc;
/// How a function listed in an [`ApiEntry`] is exposed inside the sandbox.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SafetyLevel {
    /// The original value is copied into the sandbox unchanged.
    Safe,
    /// The function is replaced by a wrapper that consults the policy on
    /// every call (see `wrap_unsafe_function`).
    Unsafe,
}
/// One node of a sandbox API specification.
///
/// A specification is a tree: functions are leaves, modules are tables that
/// are either filtered entry by entry or exposed as a whole.
#[derive(Clone, Debug)]
pub enum ApiEntry {
    /// A single value looked up by `name` in the table being processed.
    Function {
        /// Key of the function inside its enclosing table.
        name: &'static str,
        /// Whether the function is copied as-is or wrapped with a policy check.
        safety: SafetyLevel,
    },
    /// A table that is rebuilt from scratch; only the listed `entries` are
    /// carried over into the sandboxed copy.
    Module {
        /// Key of the module table inside its enclosing table.
        name: &'static str,
        /// Members of the module that should be carried over.
        entries: &'static [ApiEntry],
    },
    /// A table that is exposed as a whole, without filtering its members.
    SafeModule {
        /// Key of the module table inside its enclosing table.
        name: &'static str,
    },
}
impl ApiEntry {
    /// Builds a function entry whose calls are gated by the policy.
    pub const fn unsafe_function(name: &'static str) -> Self {
        Self::Function {
            name,
            safety: SafetyLevel::Unsafe,
        }
    }

    /// Builds a function entry that is exposed without any policy check.
    pub const fn safe_function(name: &'static str) -> Self {
        Self::Function {
            name,
            safety: SafetyLevel::Safe,
        }
    }

    /// Builds an entry for a module table that is exposed as a whole.
    pub const fn safe_module(name: &'static str) -> Self {
        Self::SafeModule { name }
    }
}
/// A complete sandbox specification: the top-level entries that remain
/// available as Lua globals once the sandbox has been applied.
pub type ApiSpec = &'static [ApiEntry];
/// Specification used when `apply_with_policy` is called without an explicit
/// `api_spec`.
///
/// Globals that are not listed here are not copied into the sandbox; the
/// tests in this file rely on that for `debug`, `coroutine` and `package`.
pub const DEFAULT_API_SPEC: ApiSpec = &[
// `os`: only `time` and `date` are exposed directly; every other listed
// function goes through the policy wrapper.
ApiEntry::Module {
name: "os",
entries: &[
ApiEntry::safe_function("time"),
ApiEntry::safe_function("date"),
ApiEntry::unsafe_function("execute"),
ApiEntry::unsafe_function("remove"),
ApiEntry::unsafe_function("rename"),
ApiEntry::unsafe_function("exit"),
ApiEntry::unsafe_function("getenv"),
ApiEntry::unsafe_function("setlocale"),
ApiEntry::unsafe_function("tmpname"),
],
},
// `io`: every listed function is policy-gated.
ApiEntry::Module {
name: "io",
entries: &[
ApiEntry::unsafe_function("open"),
ApiEntry::unsafe_function("close"),
ApiEntry::unsafe_function("read"),
ApiEntry::unsafe_function("write"),
ApiEntry::unsafe_function("flush"),
ApiEntry::unsafe_function("lines"),
ApiEntry::unsafe_function("input"),
ApiEntry::unsafe_function("output"),
ApiEntry::unsafe_function("popen"),
ApiEntry::unsafe_function("tmpfile"),
ApiEntry::unsafe_function("type"),
],
},
// Modules exposed as a whole, without per-function filtering.
ApiEntry::safe_module("string"),
ApiEntry::safe_module("table"),
ApiEntry::safe_module("math"),
ApiEntry::safe_module("utf8"),
// Policy-gated globals: code loading, metatable and raw table access, GC control.
ApiEntry::unsafe_function("load"),
ApiEntry::unsafe_function("loadstring"),
ApiEntry::unsafe_function("loadfile"),
ApiEntry::unsafe_function("dofile"),
ApiEntry::unsafe_function("require"),
ApiEntry::unsafe_function("getmetatable"),
ApiEntry::unsafe_function("setmetatable"),
ApiEntry::unsafe_function("rawget"),
ApiEntry::unsafe_function("rawset"),
ApiEntry::unsafe_function("rawequal"),
ApiEntry::unsafe_function("rawlen"),
ApiEntry::unsafe_function("collectgarbage"),
// Globals exposed without a policy check.
ApiEntry::safe_function("type"),
ApiEntry::safe_function("tonumber"),
ApiEntry::safe_function("tostring"),
ApiEntry::safe_function("print"),
ApiEntry::safe_function("ipairs"),
ApiEntry::safe_function("pairs"),
ApiEntry::safe_function("next"),
ApiEntry::safe_function("select"),
ApiEntry::safe_function("assert"),
ApiEntry::safe_function("error"),
ApiEntry::safe_function("pcall"),
ApiEntry::safe_function("xpcall"),
];
/// Sandboxes `lua` with [`DEFAULT_API_SPEC`] and `policy::DenyAllPolicy`,
/// then registers the documentation entries for the sandboxed API.
///
/// # Errors
///
/// Returns the first error raised while rebuilding the globals or while
/// registering the documentation.
pub fn apply(lua: &mlua::Lua) -> mlua::Result<()> {
    apply_with_policy(lua, Arc::new(policy::DenyAllPolicy), None)?;
    // Docs are registered after sandboxing because sandboxing wipes the globals.
    register_docs(lua)
}
/// Replaces the globals of `lua` with a sandboxed copy described by `api_spec`.
///
/// Every entry of the specification is resolved against the current globals.
/// Safe entries are copied, unsafe functions are wrapped so that each call is
/// checked against `policy`, and anything not listed is dropped.
///
/// # Parameters
///
/// * `lua` - the interpreter whose globals are rewritten in place.
/// * `policy` - consulted by the wrappers installed for unsafe functions.
/// * `api_spec` - specification to apply; `None` selects [`DEFAULT_API_SPEC`].
///
/// # Errors
///
/// Returns any error raised by mlua while building the sandboxed table,
/// clearing the globals or writing the new entries back.
pub fn apply_with_policy<P: policy::Policy + 'static>(
    lua: &mlua::Lua,
    policy: Arc<P>,
    api_spec: Option<ApiSpec>,
) -> mlua::Result<()> {
    let spec = api_spec.unwrap_or(DEFAULT_API_SPEC);
    let globals = lua.globals();

    // Build the complete replacement first: the entries are looked up in the
    // original globals, so those must stay intact until this has succeeded.
    let sandboxed = process_entries(lua, spec, policy, "", &globals)?;

    // Only now drop everything and write back what the spec allows.
    globals.clear()?;
    sandboxed.for_each(|key: mlua::Value, value: mlua::Value| globals.set(key, value))
}
/// Registers the documentation entries for the scopes and functions of the
/// sandboxed API via `docs::register`.
///
/// Entries are registered in the order listed; the first failure aborts the
/// remaining registrations and is returned to the caller.
fn register_docs(lua: &mlua::Lua) -> mlua::Result<()> {
    // (name, kind, description) for every documented item.
    let catalogue = [
        (
            "os",
            LuaDocTyp::Scope,
            "Operating system functions (sandboxed)",
        ),
        (
            "os.time",
            LuaDocTyp::Function,
            "Returns current Unix timestamp",
        ),
        (
            "os.date",
            LuaDocTyp::Function,
            "Formats date/time. Usage: os.date(format, time?)",
        ),
        ("string", LuaDocTyp::Scope, "String manipulation functions"),
        ("table", LuaDocTyp::Scope, "Table manipulation functions"),
        ("math", LuaDocTyp::Scope, "Mathematical functions"),
        ("utf8", LuaDocTyp::Scope, "UTF-8 support functions"),
    ];

    for (name, typ, description) in catalogue {
        docs::register(
            lua,
            &LuaDoc {
                name: name.to_string(),
                typ,
                description: description.to_string(),
            },
        )?;
    }
    Ok(())
}
/// Builds a fresh Lua table containing the sandboxed counterparts of `entries`.
///
/// Each entry is looked up in `source_table`:
///
/// * safe functions are copied unchanged;
/// * unsafe functions are replaced by a policy-checking wrapper whose name is
///   `name_prefix` and the entry name joined by a dot (or just the entry name
///   when the prefix is empty);
/// * modules are rebuilt recursively, using their own name as the new prefix;
/// * safe modules are copied as a whole.
///
/// Lookups that fail are skipped silently. Errors from creating tables or
/// wrappers, or from writing into the new table, are propagated.
fn process_entries<P: policy::Policy + 'static>(
    lua: &mlua::Lua,
    entries: &'static [ApiEntry],
    policy: Arc<P>,
    name_prefix: &str,
    source_table: &mlua::Table,
) -> mlua::Result<mlua::Table> {
    // Dotted path reported to the policy, e.g. "os.execute".
    let qualify = |leaf: &str| -> String {
        if name_prefix.is_empty() {
            leaf.to_string()
        } else {
            format!("{}.{}", name_prefix, leaf)
        }
    };

    let sandboxed = lua.create_table()?;
    for entry in entries {
        match entry {
            ApiEntry::Function { name, safety } => {
                if let Ok(value) = source_table.get::<mlua::Value>(*name) {
                    let exposed = match (*safety, value) {
                        (SafetyLevel::Unsafe, mlua::Value::Function(func)) => {
                            let guarded = wrap_unsafe_function(
                                lua,
                                &qualify(*name),
                                func,
                                Arc::clone(&policy),
                            )?;
                            mlua::Value::Function(guarded)
                        }
                        // Safe values are copied verbatim.
                        // NOTE(review): a non-function value listed as unsafe
                        // also lands here and is copied through unwrapped.
                        (_, untouched) => untouched,
                    };
                    sandboxed.set(*name, exposed)?;
                }
            }
            ApiEntry::Module { name, entries } => {
                if let Ok(original_module) = source_table.get::<mlua::Table>(*name) {
                    let nested = process_entries(
                        lua,
                        entries,
                        Arc::clone(&policy),
                        &qualify(*name),
                        &original_module,
                    )?;
                    sandboxed.set(*name, nested)?;
                }
            }
            ApiEntry::SafeModule { name } => {
                if let Ok(original_module) = source_table.get::<mlua::Table>(*name) {
                    sandboxed.set(*name, original_module)?;
                }
            }
        }
    }
    Ok(sandboxed)
}
/// Wraps `original_fn` in a Lua function that asks `policy` for permission on
/// every call.
///
/// The wrapper builds a `policy::Action::CallFunction` from `function_name`
/// and a copy of the call arguments. When the policy answers
/// `policy::Decision::Deny`, the reason is written to stderr and the wrapper
/// returns a single `nil` instead of raising a Lua error. For any other
/// decision the call is forwarded to `original_fn` and its results (or its
/// error) are returned unchanged.
///
/// # Errors
///
/// Returns an error if mlua fails to create the wrapper function.
pub fn wrap_unsafe_function<P: policy::Policy + 'static>(
    lua: &mlua::Lua,
    function_name: &str,
    original_fn: mlua::Function,
    policy: Arc<P>,
) -> mlua::Result<mlua::Function> {
    let guarded_name = String::from(function_name);
    lua.create_function(move |_lua, args: mlua::MultiValue| {
        // The action takes its own copy of the arguments; the originals are
        // still needed for the forwarded call below.
        let request = policy::Action::CallFunction {
            name: guarded_name.clone(),
            args: args.clone(),
        };
        match policy.check_access(&request) {
            policy::Decision::Deny(reason) => {
                eprintln!("Access denied: {}", reason);
                Ok(mlua::MultiValue::from_vec(vec![mlua::Value::Nil]))
            }
            _ => original_fn.call::<mlua::MultiValue>(args),
        }
    })
}
// Tests for the sandbox. Each test starts from a fresh interpreter and applies
// the default sandbox via `apply`.
#[cfg(test)]
mod tests {
use super::*;
use crate::runtime::output::with_output_capture;
// Functions marked safe in the default spec stay callable.
#[test]
fn test_safe_functions_work() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let time: i64 = lua.load("return os.time()").eval().unwrap();
assert!(time > 0);
let date: String = lua.load("return os.date('%Y')").eval().unwrap();
assert_eq!(date.len(), 4);
let result: i32 = lua.load("return tonumber('42')").eval().unwrap();
assert_eq!(result, 42);
}
// Modules listed as safe modules keep their functions.
#[test]
fn test_safe_modules_work() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let upper: String = lua.load("return string.upper('hello')").eval().unwrap();
assert_eq!(upper, "HELLO");
let sqrt: f64 = lua.load("return math.sqrt(16)").eval().unwrap();
assert_eq!(sqrt, 4.0);
let concat: String = lua
.load("return table.concat({'a', 'b', 'c'}, ',')")
.eval()
.unwrap();
assert_eq!(concat, "a,b,c");
}
// With the policy installed by `apply`, calls to policy-gated functions
// evaluate to nil instead of raising an error.
#[test]
fn test_unsafe_functions_return_nil() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let result: mlua::Value = lua.load("return os.execute('echo test')").eval().unwrap();
assert!(matches!(result, mlua::Value::Nil));
let result: mlua::Value = lua.load("return io.open('test.txt', 'r')").eval().unwrap();
assert!(matches!(result, mlua::Value::Nil));
let result: mlua::Value = lua.load("return load('return 1')").eval().unwrap();
assert!(matches!(result, mlua::Value::Nil));
}
// Globals that are absent from the default spec are nil after sandboxing.
#[test]
fn test_forbidden_globals_are_nil() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let debug: mlua::Value = lua.load("return debug").eval().unwrap();
assert!(matches!(debug, mlua::Value::Nil));
let coroutine: mlua::Value = lua.load("return coroutine").eval().unwrap();
assert!(matches!(coroutine, mlua::Value::Nil));
let package: mlua::Value = lua.load("return package").eval().unwrap();
assert!(matches!(package, mlua::Value::Nil));
}
// `ipairs` and `print` still work; the three `print` calls are expected to
// produce three captured output entries.
#[test]
fn test_basic_lua_functions_work() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let (result, output) = with_output_capture(&lua, |lua| {
lua.load(
r#"
local t = {10, 20, 30}
for i, v in ipairs(t) do
print(v)
end
"#,
)
.exec()
})
.unwrap();
assert!(result.is_ok());
assert_eq!(output.len(), 3);
}
// Globals defined by the host after sandboxing remain visible to scripts.
#[test]
fn test_custom_globals_after_sandboxing_persist() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
lua.globals()
.set("custom", lua.create_function(|_, ()| Ok(42)).unwrap())
.unwrap();
let result: i32 = lua.load("return custom()").eval().unwrap();
assert_eq!(result, 42);
}
// After `apply`, a global named `docs` exists and is a table.
#[test]
fn test_docs_registered() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let docs_type: String = lua.load("return type(docs)").eval().unwrap();
assert_eq!(docs_type, "table");
}
// Smoke test: every public item of this module can be named and used.
#[test]
fn test_public_api_surface() {
let lua = mlua::Lua::new();
apply(&lua).unwrap();
let lua2 = mlua::Lua::new();
let policy = Arc::new(policy::DenyAllPolicy);
apply_with_policy(&lua2, policy, None).unwrap();
let _entry = ApiEntry::safe_function("test");
let _safety = SafetyLevel::Safe;
let _spec: ApiSpec = &[];
let _default = DEFAULT_API_SPEC;
// The sandbox is already applied; define a plain Lua function and wrap it.
lua2.load("function test() return 42 end").exec().unwrap();
let func: mlua::Function = lua2.globals().get("test").unwrap();
let policy2 = Arc::new(policy::DenyAllPolicy);
let _wrapped = wrap_unsafe_function(&lua2, "test", func, policy2).unwrap();
}
}