use std::collections::HashMap;
use std::path::{Path, PathBuf};
use mlua::{Lua, LuaSerdeExt, Result, Value};
use crate::sandbox::{FsSandbox, InitError, ReadError, SandboxedFs};
use crate::{ResolveError, Resolver};
/// Boxed constructor for a native (Rust-implemented) module: called with the Lua
/// state, it builds and returns the module's value. Stored by [`NativeResolver`].
type NativeFactory = Box<dyn Fn(&Lua) -> Result<Value> + Send + Sync>;
fn read_to_resolve_error(err: ReadError, name: &str, sanitized_path: &Path) -> ResolveError {
match err {
ReadError::Traversal { .. } => ResolveError::PathTraversal {
name: name.to_owned(),
},
ReadError::Io { source, .. } => ResolveError::Io {
path: sanitized_path.to_path_buf(),
source,
},
}
}
/// Resolves modules from Lua source strings held in memory.
///
/// Names are matched exactly (no separator translation). On every successful
/// lookup the stored source is loaded and evaluated, and the chunk's result is
/// returned as the module value.
#[derive(Debug, Clone, Default)]
pub struct MemoryResolver {
    /// Module name → Lua source text.
    modules: HashMap<String, String>,
}

impl MemoryResolver {
    /// Creates an empty resolver; equivalent to [`MemoryResolver::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `source` under `name`, replacing any source previously registered
    /// under that name. Consumes and returns `self` so calls can be chained.
    #[must_use = "`add` returns the updated resolver; dropping it discards the module"]
    pub fn add(mut self, name: impl Into<String>, source: impl Into<String>) -> Self {
        self.modules.insert(name.into(), source.into());
        self
    }
}

impl Resolver for MemoryResolver {
    /// Evaluates the source registered under `name` (using `name` as the chunk
    /// name), or returns `None` when nothing is registered under it.
    fn resolve(&self, lua: &Lua, name: &str) -> Option<Result<Value>> {
        let source = self.modules.get(name)?;
        Some(lua.load(source.as_str()).set_name(name).eval())
    }
}
/// Resolves modules whose value is produced by Rust code instead of Lua source.
///
/// Names are matched exactly. The registered factory is invoked on every
/// successful lookup and its result is returned as the module value.
#[derive(Default)]
pub struct NativeResolver {
    /// Module name → factory that builds the module value.
    modules: HashMap<String, NativeFactory>,
}

impl NativeResolver {
    /// Creates an empty resolver; equivalent to [`NativeResolver::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `factory` under `name`, replacing any factory previously registered
    /// under that name. Consumes and returns `self` so calls can be chained.
    #[must_use = "`add` returns the updated resolver; dropping it discards the module"]
    pub fn add(
        mut self,
        name: impl Into<String>,
        factory: impl Fn(&Lua) -> Result<Value> + Send + Sync + 'static,
    ) -> Self {
        self.modules.insert(name.into(), Box::new(factory));
        self
    }
}

// Factories are opaque closures, so `Debug` cannot be derived; show the registered
// module names instead.
impl std::fmt::Debug for NativeResolver {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NativeResolver")
            .field("modules", &self.modules.keys().collect::<Vec<_>>())
            .finish_non_exhaustive()
    }
}

impl Resolver for NativeResolver {
    /// Runs the factory registered under `name`, or returns `None` when nothing is
    /// registered under it.
    fn resolve(&self, lua: &Lua, name: &str) -> Option<Result<Value>> {
        let factory = self.modules.get(name)?;
        Some(factory(lua))
    }
}
/// Resolves Lua modules from files read through a [`SandboxedFs`].
///
/// A module name is turned into a relative path by replacing every
/// `module_separator` with `/` (so `lib.helper` becomes `lib/helper` with the
/// default separator). Two candidates are then tried, in order:
///
/// 1. `<relative>.<extension>` — e.g. `lib/helper.lua`
/// 2. `<relative>/<init_name>.<extension>` — e.g. `mypkg/init.lua`
///
/// Defaults for `extension`, `init_name` and `module_separator` come from
/// [`crate::LuaConvention::default`].
pub struct FsResolver {
    /// All file reads go through this sandbox.
    sandbox: Box<dyn SandboxedFs>,
    /// File extension, without the leading dot (the dot is added when building paths).
    extension: String,
    /// File stem of a package entry point (the `init` in `init.lua`).
    init_name: String,
    /// Character in module names that maps to a `/` path separator.
    module_separator: char,
}

impl FsResolver {
    /// Creates a resolver backed by an [`FsSandbox`] rooted at `root`, using the
    /// default Lua convention.
    ///
    /// # Errors
    ///
    /// Propagates any [`InitError`] returned by [`FsSandbox::new`].
    pub fn new(root: impl Into<PathBuf>) -> std::result::Result<Self, InitError> {
        let fs = FsSandbox::new(root)?;
        Ok(Self::with_sandbox(fs))
    }

    /// Creates a resolver over an arbitrary sandbox, using the default Lua convention.
    #[must_use]
    pub fn with_sandbox(sandbox: impl SandboxedFs + 'static) -> Self {
        let conv = crate::LuaConvention::default();
        Self {
            sandbox: Box::new(sandbox),
            extension: conv.extension.to_owned(),
            init_name: conv.init_name.to_owned(),
            module_separator: conv.module_separator,
        }
    }

    /// Replaces extension, init name and separator with those of `conv`.
    /// Individual `with_*` calls made afterwards still override single settings.
    #[must_use]
    pub fn with_convention(self, conv: crate::LuaConvention) -> Self {
        Self {
            extension: conv.extension.to_owned(),
            init_name: conv.init_name.to_owned(),
            module_separator: conv.module_separator,
            ..self
        }
    }

    /// Sets the module file extension (without the leading dot).
    #[must_use]
    pub fn with_extension(mut self, ext: impl Into<String>) -> Self {
        self.extension = ext.into();
        self
    }

    /// Sets the file stem used for package entry points (`<name>/<init_name>.<ext>`).
    #[must_use]
    pub fn with_init_name(mut self, name: impl Into<String>) -> Self {
        self.init_name = name.into();
        self
    }

    /// Sets the character in module names that maps to a path separator.
    #[must_use]
    pub fn with_module_separator(mut self, sep: char) -> Self {
        self.module_separator = sep;
        self
    }
}

impl Resolver for FsResolver {
    /// Evaluates the first candidate file that exists and returns its value.
    ///
    /// Returns `None` when no candidate exists. If a read fails, returns
    /// `Some(Err(..))` immediately, without trying the remaining candidates.
    fn resolve(&self, lua: &Lua, name: &str) -> Option<Result<Value>> {
        let relative = name.replace(self.module_separator, "/");
        // NOTE(review): an empty name, or one starting with the separator (".foo"),
        // yields a rooted candidate such as "/init.lua" or "/foo.lua". This code
        // relies on the `SandboxedFs` implementation to reject such paths —
        // confirm every implementation does.
        let candidates = [
            PathBuf::from(format!("{relative}.{}", self.extension)),
            PathBuf::from(format!("{relative}/{}.{}", self.init_name, self.extension)),
        ];
        for candidate in &candidates {
            match self.sandbox.read(candidate) {
                Ok(Some(file)) => {
                    // The candidate path becomes the chunk name so Lua error
                    // messages point at the file that was loaded.
                    let source_name = candidate.display().to_string();
                    return Some(lua.load(file.content).set_name(source_name).eval());
                }
                Ok(None) => continue,
                Err(e) => {
                    return Some(Err(mlua::Error::external(read_to_resolve_error(
                        e, name, candidate,
                    ))));
                }
            }
        }
        None
    }
}
/// Parser that turns the text content of an asset file into a Lua value.
type AssetParserFn = Box<dyn Fn(&Lua, &str) -> Result<Value> + Send + Sync>;

/// Resolves non-Lua assets (JSON, SQL, …) by dispatching on the file extension.
///
/// The requested name is passed to the sandbox unchanged as the path to read.
/// Its extension selects the registered parser that converts the file text into
/// a Lua value. Names without an extension, or with an extension that has no
/// registered parser, are not handled (`resolve` returns `None`).
pub struct AssetResolver {
    /// All file reads go through this sandbox.
    sandbox: Box<dyn SandboxedFs>,
    /// Extension (without the leading dot, matched exactly) → parser.
    parsers: HashMap<String, AssetParserFn>,
}

impl AssetResolver {
    /// Creates a resolver backed by an [`FsSandbox`] rooted at `root`, with no
    /// parsers registered.
    ///
    /// # Errors
    ///
    /// Propagates any [`InitError`] returned by [`FsSandbox::new`].
    pub fn new(root: impl Into<PathBuf>) -> std::result::Result<Self, InitError> {
        let fs = FsSandbox::new(root)?;
        Ok(Self::with_sandbox(fs))
    }

    /// Creates a resolver over an arbitrary sandbox, with no parsers registered.
    #[must_use]
    pub fn with_sandbox(sandbox: impl SandboxedFs + 'static) -> Self {
        Self {
            sandbox: Box::new(sandbox),
            parsers: HashMap::new(),
        }
    }

    /// Registers `f` as the parser for files whose extension is exactly `ext`
    /// (no leading dot; the comparison is case-sensitive). A later registration
    /// for the same extension replaces the earlier one.
    #[must_use = "`parser` returns the updated resolver; dropping it discards the parser"]
    pub fn parser(
        mut self,
        ext: impl Into<String>,
        f: impl Fn(&Lua, &str) -> Result<Value> + Send + Sync + 'static,
    ) -> Self {
        self.parsers.insert(ext.into(), Box::new(f));
        self
    }
}
pub fn json_parser() -> impl Fn(&Lua, &str) -> Result<Value> + Send + Sync {
|lua, content| {
let json: serde_json::Value = serde_json::from_str(content).map_err(|e| {
mlua::Error::external(ResolveError::AssetParse {
source: Box::new(e),
})
})?;
lua.to_value(&json)
}
}
/// Returns an asset parser that hands the file text to Lua unchanged, as a Lua string.
pub fn text_parser() -> impl Fn(&Lua, &str) -> Result<Value> + Send + Sync {
    |lua, content| {
        let text = lua.create_string(content)?;
        Ok(Value::String(text))
    }
}
impl Resolver for AssetResolver {
fn resolve(&self, lua: &Lua, name: &str) -> Option<Result<Value>> {
let ext = Path::new(name).extension()?.to_str()?;
let parser = self.parsers.get(ext)?;
let asset_path = Path::new(name);
let file = match self.sandbox.read(asset_path) {
Ok(Some(file)) => file,
Ok(None) => return None,
Err(e) => {
return Some(Err(mlua::Error::external(read_to_resolve_error(
e, name, asset_path,
))));
}
};
Some(parser(lua, &file.content))
}
}
/// Routes names of the form `<prefix><separator><rest>` to an inner resolver,
/// which is asked to resolve `<rest>`.
///
/// For example, with prefix `sm` and the default separator, `sm.helper` is
/// resolved by the inner resolver as `helper`. Any name that does not start with
/// the prefix immediately followed by the separator is not handled (`None`). This
/// includes the bare prefix `sm` and names like `smtp`.
pub struct PrefixResolver {
    /// Leading namespace that selects this resolver.
    prefix: String,
    /// Character that must directly follow `prefix`.
    separator: char,
    /// Resolver that receives the remainder of the name.
    inner: Box<dyn Resolver>,
}

impl PrefixResolver {
    /// Creates a prefix router using the default convention's module separator.
    #[must_use]
    pub fn new(prefix: impl Into<String>, inner: impl Resolver + 'static) -> Self {
        Self {
            prefix: prefix.into(),
            separator: crate::LuaConvention::default().module_separator,
            inner: Box::new(inner),
        }
    }

    /// Uses the module separator of `conv`; no other convention setting applies here.
    #[must_use]
    pub fn with_convention(mut self, conv: crate::LuaConvention) -> Self {
        self.separator = conv.module_separator;
        self
    }

    /// Sets the character that must follow the prefix (e.g. `/` for `@std/http`).
    #[must_use]
    pub fn with_separator(mut self, separator: char) -> Self {
        self.separator = separator;
        self
    }
}

impl Resolver for PrefixResolver {
    /// Strips `<prefix><separator>` from `name` and delegates the rest to the inner
    /// resolver. Returns `None` if `name` does not start with that sequence.
    fn resolve(&self, lua: &Lua, name: &str) -> Option<Result<Value>> {
        // Stripping the prefix and then the separator with two borrowed slices is
        // equivalent to matching `prefix + separator`, but it does not allocate a
        // String on every lookup.
        let rest = name
            .strip_prefix(self.prefix.as_str())?
            .strip_prefix(self.separator)?;
        self.inner.resolve(lua, rest)
    }
}
// Unit tests for the resolvers. They use in-memory sandboxes (`MockSandbox`,
// `IoErrorSandbox`), so no real filesystem is touched.
#[cfg(test)]
mod tests {
use super::*;
use crate::sandbox::{FileContent, ReadError};
/// Resolves `name` and returns the value, panicking on `None` or `Err`.
fn must_resolve(resolver: &dyn Resolver, lua: &Lua, name: &str) -> Value {
match resolver.resolve(lua, name) {
Some(Ok(v)) => v,
Some(Err(e)) => panic!("resolve('{name}') returned Err: {e}"),
None => panic!("resolve('{name}') returned None"),
}
}
/// Resolves `name` expecting `Some(Err(_))`; returns the error's display text.
fn must_resolve_err(resolver: &dyn Resolver, lua: &Lua, name: &str) -> String {
match resolver.resolve(lua, name) {
Some(Err(e)) => e.to_string(),
Some(Ok(_)) => panic!("resolve('{name}') returned Ok, expected Err"),
None => panic!("resolve('{name}') returned None, expected Some(Err)"),
}
}
/// Reads field `key` of a table-valued `value` as `V`; panics if `value` is not a
/// table or the field cannot be converted.
fn get_field<V: mlua::FromLua>(value: &Value, key: impl mlua::IntoLua) -> V {
value
.as_table()
.expect("expected Table value")
.get::<V>(key)
.expect("table field access failed")
}
/// In-memory `SandboxedFs`: exact-path lookups into a map. It never returns an
/// error, so it performs no traversal checks.
struct MockSandbox {
files: HashMap<PathBuf, String>,
}
impl MockSandbox {
fn new() -> Self {
Self {
files: HashMap::new(),
}
}
fn with_file(mut self, path: impl Into<PathBuf>, content: &str) -> Self {
self.files.insert(path.into(), content.to_owned());
self
}
}
impl SandboxedFs for MockSandbox {
fn read(&self, relative: &Path) -> std::result::Result<Option<FileContent>, ReadError> {
match self.files.get(relative) {
Some(content) => Ok(Some(FileContent {
content: content.clone(),
resolved_path: relative.to_path_buf(),
})),
None => Ok(None),
}
}
}
#[test]
fn fs_resolver_dot_to_path_conversion() {
let mock = MockSandbox::new().with_file("lib/helper.lua", "return { name = 'mocked' }");
let resolver = FsResolver::with_sandbox(mock);
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "lib.helper");
assert_eq!(get_field::<String>(&value, "name"), "mocked");
}
#[test]
fn fs_resolver_init_lua_fallback() {
let mock = MockSandbox::new().with_file("mypkg/init.lua", "return { from_init = true }");
let resolver = FsResolver::with_sandbox(mock);
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "mypkg");
assert!(get_field::<bool>(&value, "from_init"));
}
#[test]
fn fs_resolver_miss_returns_none() {
let mock = MockSandbox::new();
let resolver = FsResolver::with_sandbox(mock);
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "nonexistent").is_none());
}
#[test]
fn fs_resolver_custom_extension() {
let mock = MockSandbox::new().with_file("lib/helper.luau", "return { name = 'luau_mod' }");
let resolver = FsResolver::with_sandbox(mock).with_extension("luau");
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "lib.helper");
assert_eq!(get_field::<String>(&value, "name"), "luau_mod");
}
#[test]
fn fs_resolver_custom_init_name() {
let mock = MockSandbox::new().with_file("mypkg/mod.lua", "return { from_mod = true }");
let resolver = FsResolver::with_sandbox(mock).with_init_name("mod");
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "mypkg");
assert!(get_field::<bool>(&value, "from_mod"));
}
// With a custom extension, a file carrying the default extension must not be picked up.
#[test]
fn fs_resolver_custom_extension_ignores_default() {
let mock = MockSandbox::new().with_file("helper.lua", "return 'wrong'");
let resolver = FsResolver::with_sandbox(mock).with_extension("luau");
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "helper").is_none());
}
#[test]
fn fs_resolver_with_convention_luau() {
let mock = MockSandbox::new()
.with_file("lib/helper.luau", "return { name = 'luau' }")
.with_file("pkg/init.luau", "return { pkg = true }");
let resolver = FsResolver::with_sandbox(mock).with_convention(crate::LuaConvention::LUAU);
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "lib.helper");
assert_eq!(get_field::<String>(&value, "name"), "luau");
let value = must_resolve(&resolver, &lua, "pkg");
assert!(get_field::<bool>(&value, "pkg"));
}
// A single-setting override applied after `with_convention` takes precedence.
#[test]
fn convention_then_override() {
let mock = MockSandbox::new().with_file("pkg/mod.luau", "return { ok = true }");
let resolver = FsResolver::with_sandbox(mock)
.with_convention(crate::LuaConvention::LUAU)
.with_init_name("mod");
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "pkg");
assert!(get_field::<bool>(&value, "ok"));
}
#[test]
fn lua_convention_default_is_lua54() {
assert_eq!(crate::LuaConvention::default(), crate::LuaConvention::LUA54);
}
#[test]
fn asset_resolver_json_to_table() {
let mock = MockSandbox::new().with_file("config.json", r#"{"port": 8080}"#);
let resolver = AssetResolver::with_sandbox(mock).parser("json", json_parser());
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "config.json");
assert_eq!(get_field::<i32>(&value, "port"), 8080);
}
#[test]
fn asset_resolver_text_to_string() {
let mock = MockSandbox::new().with_file("query.sql", "SELECT 1");
let resolver = AssetResolver::with_sandbox(mock).parser("sql", text_parser());
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "query.sql");
let s: String = lua.unpack(value).expect("unpack String failed");
assert_eq!(s, "SELECT 1");
}
// The file exists, but no parser is registered for ".xyz", so the resolver declines.
#[test]
fn asset_resolver_unregistered_ext_returns_none() {
let mock = MockSandbox::new().with_file("data.xyz", "stuff");
let resolver = AssetResolver::with_sandbox(mock).parser("json", json_parser());
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "data.xyz").is_none());
}
#[test]
fn asset_resolver_no_ext_returns_none() {
let mock = MockSandbox::new();
let resolver = AssetResolver::with_sandbox(mock);
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "noext").is_none());
}
#[test]
fn asset_resolver_custom_parser() {
let mock = MockSandbox::new().with_file("data.csv", "a,b,c");
let resolver = AssetResolver::with_sandbox(mock).parser("csv", |lua, content| {
let t = lua.create_table()?;
for (i, field) in content.split(',').enumerate() {
t.set(i + 1, lua.create_string(field)?)?;
}
Ok(Value::Table(t))
});
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "data.csv");
assert_eq!(get_field::<String>(&value, 1), "a");
}
/// Sandbox whose every read fails with an I/O error of the configured kind.
struct IoErrorSandbox {
kind: std::io::ErrorKind,
}
impl SandboxedFs for IoErrorSandbox {
fn read(&self, relative: &Path) -> std::result::Result<Option<FileContent>, ReadError> {
Err(ReadError::Io {
path: relative.to_path_buf(),
source: std::io::Error::new(self.kind, "mock I/O error"),
})
}
}
#[test]
fn fs_resolver_propagates_io_error() {
let resolver = FsResolver::with_sandbox(IoErrorSandbox {
kind: std::io::ErrorKind::PermissionDenied,
});
let lua = mlua::Lua::new();
let msg = must_resolve_err(&resolver, &lua, "anything");
assert!(
msg.contains("I/O error"),
"expected ResolveError::Io message: {msg}"
);
}
#[test]
fn asset_resolver_propagates_io_error() {
let resolver = AssetResolver::with_sandbox(IoErrorSandbox {
kind: std::io::ErrorKind::PermissionDenied,
})
.parser("json", json_parser());
let lua = mlua::Lua::new();
let msg = must_resolve_err(&resolver, &lua, "data.json");
assert!(
msg.contains("I/O error"),
"expected ResolveError::Io message: {msg}"
);
}
#[test]
fn prefix_strips_and_delegates() {
let inner = MemoryResolver::new().add("helper", "return 'from helper'");
let resolver = PrefixResolver::new("sm", inner);
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "sm.helper");
let s: String = lua.unpack(value).expect("unpack String failed");
assert_eq!(s, "from helper");
}
#[test]
fn prefix_non_matching_returns_none() {
let inner = MemoryResolver::new().add("helper", "return 'x'");
let resolver = PrefixResolver::new("sm", inner);
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "other.helper").is_none());
}
// The bare prefix (no separator, no remainder) is not routed to the inner resolver.
#[test]
fn prefix_exact_match_without_separator_returns_none() {
let inner = MemoryResolver::new().add("helper", "return 'x'");
let resolver = PrefixResolver::new("sm", inner);
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "sm").is_none());
}
// "smtp" starts with "sm" but not with "sm." — it must not be delegated as "tp".
#[test]
fn prefix_no_substring_match() {
let inner = MemoryResolver::new().add("tp", "return 'x'");
let resolver = PrefixResolver::new("sm", inner);
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "smtp").is_none());
}
#[test]
fn prefix_custom_separator() {
let inner = MemoryResolver::new().add("http", "return 'http mod'");
let resolver = PrefixResolver::new("@std", inner).with_separator('/');
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "@std/http");
let s: String = lua.unpack(value).expect("unpack String failed");
assert_eq!(s, "http mod");
}
// Only the prefix segment is stripped; the rest ("ui.button") is converted to a path by FsResolver.
#[test]
fn prefix_nested_name() {
let mock = MockSandbox::new().with_file("ui/button.lua", "return { name = 'button' }");
let resolver = PrefixResolver::new("game", FsResolver::with_sandbox(mock));
let lua = mlua::Lua::new();
let value = must_resolve(&resolver, &lua, "game.ui.button");
assert_eq!(get_field::<String>(&value, "name"), "button");
}
#[test]
fn prefix_inner_miss_returns_none() {
let inner = MemoryResolver::new().add("helper", "return 'x'");
let resolver = PrefixResolver::new("sm", inner);
let lua = mlua::Lua::new();
assert!(resolver.resolve(&lua, "sm.nonexistent").is_none());
}
}