use std::{
env::{self, current_dir},
path::PathBuf,
sync::Arc,
};
use mlua::prelude::*;
use os_str_bytes::{OsStrBytes, RawOsStr};
pub fn create(lua: &Lua) -> LuaResult<()> {
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
return Ok(());
}
let pwd = lua.create_string(¤t_dir()?.to_raw_bytes())?;
lua.set_named_registry_value("require_pwd", pwd)?;
let debug: LuaTable = lua.globals().raw_get("debug")?;
let info: LuaFunction = debug.raw_get("info")?;
lua.set_named_registry_value("require_getinfo", info)?;
let require: LuaFunction = lua.globals().raw_get("require")?;
lua.set_named_registry_value("require_original", require)?;
let new_require = lua.create_function(|lua, require_path: LuaString| {
let require_pwd: LuaString = lua.named_registry_value("require_pwd")?;
let require_original: LuaFunction = lua.named_registry_value("require_original")?;
let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
let require_source: LuaString = require_getinfo.call((2, "s"))?;
let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes());
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
let mut path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str())
.parent()
.unwrap()
.join(raw_path.to_os_str());
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() {
path_relative_to_pwd = canonicalized.with_extension("");
}
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() {
path_relative_to_pwd = canonicalized.with_extension("");
}
if let Ok(stripped) = path_relative_to_pwd.strip_prefix(&raw_pwd_str.to_os_str()) {
path_relative_to_pwd = stripped.to_path_buf();
}
let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str());
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str);
match result {
Err(LuaError::CallbackError { traceback, cause }) => {
let before = format!(
"runtime error: cannot find '{}'",
path_relative_to_pwd.to_str().unwrap()
);
let after = format!(
"Invalid require path '{}' ({})",
require_path.to_str().unwrap(),
path_relative_to_pwd.to_str().unwrap()
);
let cause = Arc::new(LuaError::RuntimeError(
cause.to_string().replace(&before, &after),
));
Err(LuaError::CallbackError { traceback, cause })
}
Err(e) => Err(e),
Ok(result) => Ok(result),
}
})?;
lua.globals().raw_set("require", new_require)?;
Ok(())
}