1use std::cell::RefCell;
2use std::collections::VecDeque;
3use std::ffi::CStr;
4use std::io::Result as IoResult;
5use std::ops::{Deref, DerefMut};
6use std::os::raw::{c_char, c_int, c_void};
7use std::path::{Component, Path, PathBuf};
8use std::result::Result as StdResult;
9use std::{env, fmt, fs, mem, ptr};
10
11use crate::error::{Error, Result};
12use crate::function::Function;
13use crate::state::{callback_error_ext, Lua};
14use crate::table::Table;
15use crate::types::MaybeSend;
16
/// The ways in which moving around the module tree can fail.
///
/// `Ambiguous` and `NotFound` are translated into the matching FFI navigation statuses by
/// `IntoNavigateResult`, whereas `Other` is surfaced as a regular [`Error`].
#[derive(Debug, Clone)]
pub enum NavigateError {
    /// More than one candidate matched the requested location.
    Ambiguous,
    /// Nothing matched the requested location.
    NotFound,
    /// Any other failure, carrying the underlying error.
    Other(Error),
}
24
25trait IntoNavigateResult {
26 fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult>;
27}
28
29impl IntoNavigateResult for StdResult<(), NavigateError> {
30 fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult> {
31 match self {
32 Ok(()) => Ok(ffi::luarequire_NavigateResult::Success),
33 Err(NavigateError::Ambiguous) => Ok(ffi::luarequire_NavigateResult::Ambiguous),
34 Err(NavigateError::NotFound) => Ok(ffi::luarequire_NavigateResult::NotFound),
35 Err(NavigateError::Other(err)) => Err(err),
36 }
37 }
38}
39
40impl From<Error> for NavigateError {
41 fn from(err: Error) -> Self {
42 NavigateError::Other(err)
43 }
44}
45
/// Shorthand for the FFI status returned by the buffer-writing callbacks.
type WriteResult = ffi::luarequire_WriteResult;
47
/// Navigation and loading hooks that back the `require` function.
///
/// An implementation tracks a "current location": the navigation methods move it and the
/// query methods describe what is found there. The FFI callbacks installed by `init_config`
/// forward to these methods.
pub trait Require {
    /// Returns whether `require` may be used from the chunk named `chunk_name`.
    fn is_require_allowed(&self, chunk_name: &str) -> bool;

    /// Resets the current location to the chunk named `chunk_name`.
    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError>;

    /// Moves the current location to `path`, the target of an alias.
    // NOTE(review): presumably `path` comes from a configuration file — confirm against the
    // Luau require-by-string documentation.
    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError>;

    /// Moves the current location to its parent.
    fn to_parent(&mut self) -> StdResult<(), NavigateError>;

    /// Moves the current location to the child called `name`.
    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError>;

    /// Returns whether a module exists at the current location.
    fn has_module(&self) -> bool;

    /// Returns a key identifying the module at the current location.
    ///
    /// The `require` function uses this key to cache the value produced by the loader.
    // NOTE(review): `TextRequirer` panics here when no module is resolved, so callers are
    // presumably expected to check `has_module` first — confirm.
    fn cache_key(&self) -> String;

    /// Returns whether a configuration is available at the current location.
    fn has_config(&self) -> bool;

    /// Returns the raw bytes of the configuration at the current location.
    fn config(&self) -> IoResult<Vec<u8>>;

    /// Returns a function that, when called, loads the module at the current location.
    fn loader(&self, lua: &Lua) -> Result<Function>;
}
92
93impl fmt::Debug for dyn Require {
94 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
95 write!(f, "<dyn Require>")
96 }
97}
98
99#[derive(Default, Debug)]
101pub struct TextRequirer {
102 abs_path: PathBuf,
104 rel_path: PathBuf,
106 resolved_path: Option<PathBuf>,
109}
110
111impl TextRequirer {
112 const CHUNK_PREFIX: &str = "@";
115
116 const FILE_EXTENSIONS: &[&str] = &["luau", "lua"];
118
119 pub fn new() -> Self {
121 Self::default()
122 }
123
124 fn normalize_chunk_name(chunk_name: &str) -> &str {
125 if let Some((path, line)) = chunk_name.rsplit_once(':') {
126 if line.parse::<u32>().is_ok() {
127 return path;
128 }
129 }
130 chunk_name
131 }
132
133 fn normalize_path(path: &Path) -> PathBuf {
135 let mut components = VecDeque::new();
136
137 for comp in path.components() {
138 match comp {
139 Component::Prefix(..) | Component::RootDir => {
140 components.push_back(comp);
141 }
142 Component::CurDir => {}
143 Component::ParentDir => {
144 if matches!(components.back(), None | Some(Component::ParentDir)) {
145 components.push_back(Component::ParentDir);
146 } else if matches!(components.back(), Some(Component::Normal(..))) {
147 components.pop_back();
148 }
149 }
150 Component::Normal(..) => components.push_back(comp),
151 }
152 }
153
154 if matches!(components.front(), None | Some(Component::Normal(..))) {
155 components.push_front(Component::CurDir);
156 }
157
158 components.into_iter().collect()
160 }
161
162 fn resolve_module(path: &Path) -> StdResult<Option<PathBuf>, NavigateError> {
166 let mut found_path = None;
167
168 if path.components().next_back() != Some(Component::Normal("init".as_ref())) {
169 let current_ext = (path.extension().and_then(|s| s.to_str()))
170 .map(|s| format!("{s}."))
171 .unwrap_or_default();
172 for ext in Self::FILE_EXTENSIONS {
173 let candidate = path.with_extension(format!("{current_ext}{ext}"));
174 if candidate.is_file() && found_path.replace(candidate).is_some() {
175 return Err(NavigateError::Ambiguous);
176 }
177 }
178 }
179 if path.is_dir() {
180 for component in Self::FILE_EXTENSIONS.iter().map(|ext| format!("init.{ext}")) {
181 let candidate = path.join(component);
182 if candidate.is_file() && found_path.replace(candidate).is_some() {
183 return Err(NavigateError::Ambiguous);
184 }
185 }
186
187 if found_path.is_none() {
188 return Ok(None);
190 }
191 }
192
193 Ok(Some(found_path.ok_or(NavigateError::NotFound)?))
194 }
195}
196
197impl Require for TextRequirer {
198 fn is_require_allowed(&self, chunk_name: &str) -> bool {
199 chunk_name.starts_with(Self::CHUNK_PREFIX)
200 }
201
202 fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError> {
203 if !chunk_name.starts_with(Self::CHUNK_PREFIX) {
204 return Err(NavigateError::NotFound);
205 }
206 let chunk_name = Self::normalize_chunk_name(&chunk_name[1..]);
207 let chunk_path = Self::normalize_path(chunk_name.as_ref());
208
209 if chunk_path.extension() == Some("rs".as_ref()) {
210 let chunk_filename = chunk_path.file_name().unwrap();
212 let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
213 self.abs_path = Self::normalize_path(&cwd.join(chunk_filename));
214 self.rel_path = ([Component::CurDir, Component::Normal(chunk_filename)].into_iter()).collect();
215 self.resolved_path = None;
216
217 return Ok(());
218 }
219
220 if chunk_path.is_absolute() {
221 let resolved_path = Self::resolve_module(&chunk_path)?;
222 self.abs_path = chunk_path.clone();
223 self.rel_path = chunk_path;
224 self.resolved_path = resolved_path;
225 } else {
226 let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
228 let abs_path = Self::normalize_path(&cwd.join(&chunk_path));
229 let resolved_path = Self::resolve_module(&abs_path)?;
230 self.abs_path = abs_path;
231 self.rel_path = chunk_path;
232 self.resolved_path = resolved_path;
233 }
234
235 Ok(())
236 }
237
238 fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError> {
239 let path = Self::normalize_path(path.as_ref());
240 let resolved_path = Self::resolve_module(&path)?;
241
242 self.abs_path = path.clone();
243 self.rel_path = path;
244 self.resolved_path = resolved_path;
245
246 Ok(())
247 }
248
249 fn to_parent(&mut self) -> StdResult<(), NavigateError> {
250 let mut abs_path = self.abs_path.clone();
251 if !abs_path.pop() {
252 return Err(NavigateError::NotFound);
256 }
257 let mut rel_parent = self.rel_path.clone();
258 rel_parent.pop();
259 let resolved_path = Self::resolve_module(&abs_path)?;
260
261 self.abs_path = abs_path;
262 self.rel_path = Self::normalize_path(&rel_parent);
263 self.resolved_path = resolved_path;
264
265 Ok(())
266 }
267
268 fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError> {
269 let abs_path = self.abs_path.join(name);
270 let rel_path = self.rel_path.join(name);
271 let resolved_path = Self::resolve_module(&abs_path)?;
272
273 self.abs_path = abs_path;
274 self.rel_path = rel_path;
275 self.resolved_path = resolved_path;
276
277 Ok(())
278 }
279
280 fn has_module(&self) -> bool {
281 (self.resolved_path.as_deref())
282 .map(Path::is_file)
283 .unwrap_or(false)
284 }
285
286 fn cache_key(&self) -> String {
287 self.resolved_path.as_deref().unwrap().display().to_string()
288 }
289
290 fn has_config(&self) -> bool {
291 self.abs_path.is_dir() && self.abs_path.join(".luaurc").is_file()
292 }
293
294 fn config(&self) -> IoResult<Vec<u8>> {
295 fs::read(self.abs_path.join(".luaurc"))
296 }
297
298 fn loader(&self, lua: &Lua) -> Result<Function> {
299 let name = format!("@{}", self.rel_path.display());
300 lua.load(self.resolved_path.as_deref().unwrap())
301 .set_name(name)
302 .into_function()
303 }
304}
305
306struct Context(Box<dyn Require>);
307
308impl Deref for Context {
309 type Target = dyn Require;
310
311 fn deref(&self) -> &Self::Target {
312 &*self.0
313 }
314}
315
316impl DerefMut for Context {
317 fn deref_mut(&mut self) -> &mut Self::Target {
318 &mut *self.0
319 }
320}
321
// Immutably borrows the `RefCell<Context>` behind the raw pointer `$ctx`.
// If the cell is already mutably borrowed, a Lua error is raised on `$state`.
macro_rules! try_borrow {
    ($state:expr, $ctx:expr) => {{
        let cell = &*($ctx as *const RefCell<Context>);
        let Ok(guard) = cell.try_borrow() else {
            ffi::luaL_error($state, cstr!("require context is already borrowed"))
        };
        guard
    }};
}
330
// Mutably borrows the `RefCell<Context>` behind the raw pointer `$ctx`.
// If the cell is already borrowed, a Lua error is raised on `$state`.
macro_rules! try_borrow_mut {
    ($state:expr, $ctx:expr) => {{
        let cell = &*($ctx as *const RefCell<Context>);
        let Ok(guard) = cell.try_borrow_mut() else {
            ffi::luaL_error($state, cstr!("require context is already borrowed"))
        };
        guard
    }};
}
339
/// Populates `config` with the callbacks that drive `require` resolution.
///
/// Each callback treats its `ctx` pointer as a `*const RefCell<Context>` (through the
/// `try_borrow!`/`try_borrow_mut!` macros) and forwards to the boxed [`Require`]
/// implementation. A null `config` is left untouched.
pub(super) unsafe extern "C-unwind" fn init_config(config: *mut ffi::luarequire_Configuration) {
    if config.is_null() {
        return;
    }

    // Forwards to `Require::is_require_allowed`; a null chunk name is never allowed.
    unsafe extern "C-unwind" fn is_require_allowed(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> bool {
        if requirer_chunkname.is_null() {
            return false;
        }

        let this = try_borrow!(state, ctx);
        // Invalid UTF-8 in the chunk name is replaced rather than rejected.
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        this.is_require_allowed(&chunk_name)
    }

    // Forwards to `Require::reset`.
    // NOTE(review): `callback_error_ext` presumably turns an `Err` into a Lua error; the
    // borrow guard is moved into the closure, presumably so that it is released before
    // that happens — confirm against its definition.
    unsafe extern "C-unwind" fn reset(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.reset(&chunk_name).into_nav_result()
        })
    }

    // Forwards to `Require::jump_to_alias`.
    unsafe extern "C-unwind" fn jump_to_alias(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        path: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let path = CStr::from_ptr(path).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.jump_to_alias(&path).into_nav_result()
        })
    }

    // Forwards to `Require::to_parent`.
    unsafe extern "C-unwind" fn to_parent(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_parent().into_nav_result()
        })
    }

    // Forwards to `Require::to_child`.
    unsafe extern "C-unwind" fn to_child(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        name: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let name = CStr::from_ptr(name).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_child(&name).into_nav_result()
        })
    }

    // Forwards to `Require::has_module`.
    unsafe extern "C-unwind" fn is_module_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let this = try_borrow!(state, ctx);
        this.has_module()
    }

    // Always reports an empty chunk name (only the NUL terminator is written); the `load`
    // callback below ignores the chunk name it is given.
    unsafe extern "C-unwind" fn get_chunkname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    // Always reports an empty load name, for the same reason as `get_chunkname`.
    unsafe extern "C-unwind" fn get_loadname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    // Writes `Require::cache_key` into the caller's buffer.
    unsafe extern "C-unwind" fn get_cache_key(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let this = try_borrow!(state, ctx);
        let cache_key = this.cache_key();
        write_to_buffer(buffer, buffer_size, size_out, cache_key.as_bytes())
    }

    // Forwards to `Require::has_config`.
    unsafe extern "C-unwind" fn is_config_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let this = try_borrow!(state, ctx);
        this.has_config()
    }

    // Reads the configuration via `Require::config` and writes it into the caller's buffer.
    unsafe extern "C-unwind" fn get_config(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let this = try_borrow!(state, ctx);
        // `?` converts the I/O error into the crate error type before it is reported.
        let config = callback_error_ext(state, ptr::null_mut(), true, move |_, _| Ok(this.config()?));
        write_to_buffer(buffer, buffer_size, size_out, &config)
    }

    // Pushes the function returned by `Require::loader` onto the Lua stack and reports one
    // pushed value. The path and name arguments are not used.
    unsafe extern "C-unwind" fn load(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        _path: *const c_char,
        _chunkname: *const c_char,
        _loadname: *const c_char,
    ) -> c_int {
        let this = try_borrow!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |extra, _| {
            let rawlua = (*extra).raw_lua();
            let loader = this.loader(rawlua.lua())?;
            rawlua.push(state, loader)?;
            Ok(1)
        })
    }

    // Install the callbacks. Alias lookup is left unset; configuration contents are served
    // through `get_config` instead.
    (*config).is_require_allowed = is_require_allowed;
    (*config).reset = reset;
    (*config).jump_to_alias = jump_to_alias;
    (*config).to_parent = to_parent;
    (*config).to_child = to_child;
    (*config).is_module_present = is_module_present;
    (*config).get_chunkname = get_chunkname;
    (*config).get_loadname = get_loadname;
    (*config).get_cache_key = get_cache_key;
    (*config).is_config_present = is_config_present;
    (*config).get_alias = None;
    (*config).get_config = Some(get_config);
    (*config).load = load;
}
489
490unsafe fn write_to_buffer(
492 buffer: *mut c_char,
493 buffer_size: usize,
494 size_out: *mut usize,
495 data: &[u8],
496) -> WriteResult {
497 let is_null_terminated = data.last() == Some(&0);
499 *size_out = data.len() + if is_null_terminated { 0 } else { 1 };
500 if *size_out > buffer_size {
501 return WriteResult::BufferTooSmall;
502 }
503 ptr::copy_nonoverlapping(data.as_ptr(), buffer as *mut _, data.len());
504 if !is_null_terminated {
505 *buffer.add(data.len()) = 0;
506 }
507 WriteResult::Success
508}
509
/// Builds the Lua-callable `require` function backed by the given [`Require`] implementation.
///
/// The returned function is a Lua chunk whose environment contains a handful of native
/// helpers. It first consults the registered-modules table, then resolves the module through
/// the proxy `require`, and caches the loader's result under the module's cache key.
pub(super) fn create_require_function<R: Require + MaybeSend + 'static>(
    lua: &Lua,
    require: R,
) -> Result<Function> {
    // Pushes the source of the first non-C frame found from stack level 2 upwards, raising a
    // Lua error if the stack runs out first.
    unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int {
        let mut ar: ffi::lua_Debug = mem::zeroed();
        for level in 2.. {
            if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 {
                ffi::luaL_error(state, cstr!("require is not supported in this context"));
            }
            if CStr::from_ptr(ar.what) != c"C" {
                break;
            }
        }
        ffi::lua_pushstring(state, ar.source);
        1
    }

    // Pushes the cache key of the context stored in this closure's first upvalue.
    unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int {
        let ctx = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1));
        let ctx = try_borrow!(state, ctx);
        let cache_key = ctx.cache_key();
        ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len());
        1
    }

    // The order of the pushes below determines the order of the returned tuple.
    let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe {
        lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| {
            let context = Context(Box::new(require));
            // The context lives in a userdata, which becomes the upvalue of `get_cache_key`;
            // the same pointer is handed to the proxy `require` as its callback context.
            let context_ptr = ffi::lua_newuserdata_t(state, RefCell::new(context));
            ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1);
            ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file"));
            ffi::luarequire_pushproxyrequire(state, init_config, context_ptr as *mut _);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__ULUA_LOADER_CACHE"));
        })
    }?;

    // Raises argument 1 as an error, prefixed with the position reported by `luaL_where`.
    unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int {
        ffi::luaL_where(state, 1);
        ffi::lua_pushvalue(state, 1);
        ffi::lua_concat(state, 2);
        ffi::lua_error(state);
    }

    // Pushes the type name of argument 1.
    unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int {
        ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1)));
        1
    }

    // Returns the string argument with ASCII uppercase letters lowercased.
    unsafe extern "C-unwind" fn to_lowercase(state: *mut ffi::lua_State) -> c_int {
        let s = ffi::luaL_checkstring(state, 1);
        let s = CStr::from_ptr(s);
        if !s.to_bytes().iter().any(|&c| c.is_ascii_uppercase()) {
            // Nothing to change: hand back the value already on top of the stack.
            return 1;
        }
        callback_error_ext(state, ptr::null_mut(), true, |extra, _| {
            let s = (s.to_bytes().iter())
                .map(|&c| c.to_ascii_lowercase())
                .collect::<bstr::BString>();
            (*extra).raw_lua().push(state, s).map(|_| 1)
        })
    }

    let (error, r#type, to_lowercase) = unsafe {
        lua.exec_raw::<(Function, Function, Function)>((), move |state| {
            ffi::lua_pushcfunctiond(state, error, cstr!("error"));
            ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
            ffi::lua_pushcfunctiond(state, to_lowercase, cstr!("to_lowercase"));
        })
    }?;

    // The environment of the `require` chunk: it only sees the names set here.
    let env = lua.create_table_with_capacity(0, 7)?;
    env.raw_set("get_cache_key", get_cache_key)?;
    env.raw_set("find_current_file", find_current_file)?;
    env.raw_set("proxyrequire", proxyrequire)?;
    env.raw_set("REGISTERED_MODULES", registered_modules)?;
    env.raw_set("LOADER_CACHE", loader_cache)?;
    env.raw_set("error", error)?;
    env.raw_set("type", r#type)?;
    env.raw_set("to_lowercase", to_lowercase)?;

    lua.load(
        r#"
        local path = ...
        if type(path) ~= "string" then
            error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")")
        end

        -- Check if the module (path) is explicitly registered
        local maybe_result = REGISTERED_MODULES[to_lowercase(path)]
        if maybe_result ~= nil then
            return maybe_result
        end

        local loader = proxyrequire(path, find_current_file())
        local cache_key = get_cache_key()
        -- Check if the loader result is already cached
        local result = LOADER_CACHE[cache_key]
        if result ~= nil then
            return result
        end

        -- Call the loader function and cache the result
        result = loader()
        if result == nil then
            result = true
        end
        LOADER_CACHE[cache_key] = result
        return result
        "#,
    )
    .try_cache()
    .set_name("=__ulua_require")
    .set_environment(env)
    .into_function()
}
629
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::TextRequirer;

    /// `(input, expected)` pairs for `normalize_path`. Paths are compared component-wise, so
    /// a trailing `/` in `expected` is not significant.
    const NORMALIZE_CASES: &[(&str, &str)] = &[
        // Basic formatting
        ("", "./"),
        (".", "./"),
        ("a/relative/path", "./a/relative/path"),
        // Extraneous `.` and `/` symbols
        ("./remove/extraneous/symbols/", "./remove/extraneous/symbols"),
        ("./remove/extraneous//symbols", "./remove/extraneous/symbols"),
        ("./remove/extraneous/symbols/.", "./remove/extraneous/symbols"),
        ("./remove/extraneous/./symbols", "./remove/extraneous/symbols"),
        ("../remove/extraneous/symbols/", "../remove/extraneous/symbols"),
        ("../remove/extraneous//symbols", "../remove/extraneous/symbols"),
        ("../remove/extraneous/symbols/.", "../remove/extraneous/symbols"),
        ("../remove/extraneous/./symbols", "../remove/extraneous/symbols"),
        ("/remove/extraneous/symbols/", "/remove/extraneous/symbols"),
        ("/remove/extraneous//symbols", "/remove/extraneous/symbols"),
        ("/remove/extraneous/symbols/.", "/remove/extraneous/symbols"),
        ("/remove/extraneous/./symbols", "/remove/extraneous/symbols"),
        // Paths containing `..`
        ("./remove/me/..", "./remove"),
        ("./remove/me/../", "./remove"),
        ("../remove/me/..", "../remove"),
        ("../remove/me/../", "../remove"),
        ("/remove/me/..", "/remove"),
        ("/remove/me/../", "/remove"),
        ("./..", "../"),
        ("./../", "../"),
        ("../..", "../../"),
        ("../../", "../../"),
        // `..` directly below the root has nothing to climb to
        ("/../", "/"),
    ];

    #[test]
    fn test_path_normalize() {
        for &(input, expected) in NORMALIZE_CASES {
            let normalized = TextRequirer::normalize_path(Path::new(input));
            assert_eq!(
                normalized.as_path(),
                Path::new(expected),
                "wrong normalization for {input}"
            );
        }
    }
}