1use std::cell::RefCell;
2use std::collections::VecDeque;
3use std::ffi::CStr;
4use std::io::Result as IoResult;
5use std::ops::{Deref, DerefMut};
6use std::os::raw::{c_char, c_int, c_void};
7use std::path::{Component, Path, PathBuf};
8use std::result::Result as StdResult;
9use std::{env, fmt, fs, mem, ptr};
10
11use crate::error::{Error, Result};
12use crate::function::Function;
13use crate::state::{callback_error_ext, Lua};
14use crate::table::Table;
15use crate::types::MaybeSend;
16
/// Outcome of a failed [`Require`] navigation step.
///
/// `Ambiguous` and `NotFound` are reported back to the Luau require machinery as status codes,
/// while `Other` carries an arbitrary error that is surfaced as a Lua error.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
#[derive(Debug, Clone)]
pub enum NavigateError {
    /// More than one candidate matched the requested location.
    Ambiguous,
    /// No candidate matched the requested location.
    NotFound,
    /// Any other failure; propagated as an error instead of a navigation status.
    Other(Error),
}
26
/// Converts a Rust-side navigation outcome into the status code used by the Luau require C API.
#[cfg(feature = "luau")]
trait IntoNavigateResult {
    /// Maps success and navigation failures to a status code; other errors come back as `Err`.
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult>;
}
31
#[cfg(feature = "luau")]
impl IntoNavigateResult for StdResult<(), NavigateError> {
    /// `Ok(())` becomes `Success`; `Ambiguous`/`NotFound` become their matching status codes;
    /// `Other(err)` is returned as `Err(err)` so the caller can raise it.
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult> {
        self.map(|()| ffi::luarequire_NavigateResult::Success)
            .or_else(|nav_err| match nav_err {
                NavigateError::Ambiguous => Ok(ffi::luarequire_NavigateResult::Ambiguous),
                NavigateError::NotFound => Ok(ffi::luarequire_NavigateResult::NotFound),
                NavigateError::Other(inner) => Err(inner),
            })
    }
}
43
44impl From<Error> for NavigateError {
45 fn from(err: Error) -> Self {
46 NavigateError::Other(err)
47 }
48}
49
/// Status returned by the require callbacks that copy strings into caller-provided buffers.
#[cfg(feature = "luau")]
type WriteResult = ffi::luarequire_WriteResult;
52
/// Customizes how Luau's `require` locates, identifies and loads modules.
///
/// The methods are driven by the Luau require C API callbacks installed by `init_config` in this
/// module. The navigation methods (`reset`, `jump_to_alias`, `to_parent`, `to_child`) move an
/// implementation-defined "current location". The query methods describe that location.
///
/// Navigation methods report failure with [`NavigateError::Ambiguous`] or
/// [`NavigateError::NotFound`]. A [`NavigateError::Other`] error is raised as a Lua error instead.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub trait Require: MaybeSend {
    /// Returns `true` if the chunk named `chunk_name` may call `require`.
    fn is_require_allowed(&self, chunk_name: &str) -> bool;

    /// Resets the current location to the requiring chunk identified by `chunk_name`.
    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError>;

    /// Moves the current location to the path that an alias resolves to.
    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError>;

    /// Moves the current location to its parent.
    fn to_parent(&mut self) -> StdResult<(), NavigateError>;

    /// Moves the current location to its child entry `name`.
    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError>;

    /// Returns `true` if a loadable module exists at the current location.
    fn has_module(&self) -> bool;

    /// Returns the key under which the current module's result is cached.
    ///
    /// The `require` function built by `create_require_function` stores each loader result
    /// under this key, so modules reporting the same key are loaded only once.
    fn cache_key(&self) -> String;

    /// Returns `true` if a configuration file exists at the current location.
    fn has_config(&self) -> bool;

    /// Returns the raw bytes of the configuration at the current location
    /// (for [`TextRequirer`] this is the `.luaurc` file).
    fn config(&self) -> IoResult<Vec<u8>>;

    /// Returns a function that, when called, produces the value of the current module.
    ///
    /// The generated `require` function stores a `nil` result as `true`.
    fn loader(&self, lua: &Lua) -> Result<Function>;
}
99
100impl fmt::Debug for dyn Require {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 write!(f, "<dyn Require>")
103 }
104}
105
106#[doc(hidden)]
108#[derive(Default, Debug)]
109pub struct TextRequirer {
110 abs_path: PathBuf,
111 rel_path: PathBuf,
112 module_path: PathBuf,
113}
114
115impl TextRequirer {
116 pub fn new() -> Self {
118 Self::default()
119 }
120
121 fn normalize_chunk_name(chunk_name: &str) -> &str {
122 if let Some((path, line)) = chunk_name.split_once(':') {
123 if line.parse::<u32>().is_ok() {
124 return path;
125 }
126 }
127 chunk_name
128 }
129
130 fn normalize_path(path: &Path) -> PathBuf {
132 let mut components = VecDeque::new();
133
134 for comp in path.components() {
135 match comp {
136 Component::Prefix(..) | Component::RootDir => {
137 components.push_back(comp);
138 }
139 Component::CurDir => {}
140 Component::ParentDir => {
141 if matches!(components.back(), None | Some(Component::ParentDir)) {
142 components.push_back(Component::ParentDir);
143 } else if matches!(components.back(), Some(Component::Normal(..))) {
144 components.pop_back();
145 }
146 }
147 Component::Normal(..) => components.push_back(comp),
148 }
149 }
150
151 if matches!(components.front(), None | Some(Component::Normal(..))) {
152 components.push_front(Component::CurDir);
153 }
154
155 components.into_iter().collect()
157 }
158
159 fn find_module(path: &Path) -> StdResult<PathBuf, NavigateError> {
160 let mut found_path = None;
161
162 if path.components().next_back() != Some(Component::Normal("init".as_ref())) {
163 let current_ext = (path.extension().and_then(|s| s.to_str()))
164 .map(|s| format!("{s}."))
165 .unwrap_or_default();
166 for ext in ["luau", "lua"] {
167 let candidate = path.with_extension(format!("{current_ext}{ext}"));
168 if candidate.is_file() && found_path.replace(candidate).is_some() {
169 return Err(NavigateError::Ambiguous);
170 }
171 }
172 }
173 if path.is_dir() {
174 for component in ["init.luau", "init.lua"] {
175 let candidate = path.join(component);
176 if candidate.is_file() && found_path.replace(candidate).is_some() {
177 return Err(NavigateError::Ambiguous);
178 }
179 }
180
181 if found_path.is_none() {
182 found_path = Some(PathBuf::new());
183 }
184 }
185
186 found_path.ok_or(NavigateError::NotFound)
187 }
188}
189
190impl Require for TextRequirer {
191 fn is_require_allowed(&self, chunk_name: &str) -> bool {
192 chunk_name.starts_with('@')
193 }
194
195 fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError> {
196 if !chunk_name.starts_with('@') {
197 return Err(NavigateError::NotFound);
198 }
199 let chunk_name = Self::normalize_chunk_name(&chunk_name[1..]);
200 let chunk_path = Self::normalize_path(chunk_name.as_ref());
201
202 if chunk_path.extension() == Some("rs".as_ref()) {
203 let chunk_filename = chunk_path.file_name().unwrap();
205 let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
206 self.abs_path = Self::normalize_path(&cwd.join(chunk_filename));
207 self.rel_path = ([Component::CurDir, Component::Normal(chunk_filename)].into_iter()).collect();
208 self.module_path = PathBuf::new();
209
210 return Ok(());
211 }
212
213 if chunk_path.is_absolute() {
214 let module_path = Self::find_module(&chunk_path)?;
215 self.abs_path = chunk_path.clone();
216 self.rel_path = chunk_path;
217 self.module_path = module_path;
218 } else {
219 let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
221 let abs_path = Self::normalize_path(&cwd.join(&chunk_path));
222 let module_path = Self::find_module(&abs_path)?;
223 self.abs_path = abs_path;
224 self.rel_path = chunk_path;
225 self.module_path = module_path;
226 }
227
228 Ok(())
229 }
230
231 fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError> {
232 let path = Self::normalize_path(path.as_ref());
233 let module_path = Self::find_module(&path)?;
234
235 self.abs_path = path.clone();
236 self.rel_path = path;
237 self.module_path = module_path;
238
239 Ok(())
240 }
241
242 fn to_parent(&mut self) -> StdResult<(), NavigateError> {
243 let mut abs_path = self.abs_path.clone();
244 if !abs_path.pop() {
245 return Err(NavigateError::NotFound);
246 }
247 let mut rel_parent = self.rel_path.clone();
248 rel_parent.pop();
249 let module_path = Self::find_module(&abs_path)?;
250
251 self.abs_path = abs_path;
252 self.rel_path = Self::normalize_path(&rel_parent);
253 self.module_path = module_path;
254
255 Ok(())
256 }
257
258 fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError> {
259 let abs_path = self.abs_path.join(name);
260 let rel_path = self.rel_path.join(name);
261 let module_path = Self::find_module(&abs_path)?;
262
263 self.abs_path = abs_path;
264 self.rel_path = rel_path;
265 self.module_path = module_path;
266
267 Ok(())
268 }
269
270 fn has_module(&self) -> bool {
271 self.module_path.is_file()
272 }
273
274 fn cache_key(&self) -> String {
275 self.module_path.display().to_string()
276 }
277
278 fn has_config(&self) -> bool {
279 self.abs_path.is_dir() && self.abs_path.join(".luaurc").is_file()
280 }
281
282 fn config(&self) -> IoResult<Vec<u8>> {
283 fs::read(self.abs_path.join(".luaurc"))
284 }
285
286 fn loader(&self, lua: &Lua) -> Result<Function> {
287 let name = format!("@{}", self.rel_path.display());
288 lua.load(&*self.module_path).set_name(name).into_function()
289 }
290}
291
292struct Context(Box<dyn Require>);
293
294impl Deref for Context {
295 type Target = dyn Require;
296
297 fn deref(&self) -> &Self::Target {
298 &*self.0
299 }
300}
301
302impl DerefMut for Context {
303 fn deref_mut(&mut self) -> &mut Self::Target {
304 &mut *self.0
305 }
306}
307
/// Immutably borrows the `RefCell<Context>` behind the raw pointer `$ctx`.
///
/// If the cell is currently mutably borrowed, it raises a Lua error through `luaL_error` on
/// `$state`. Must be expanded in an `unsafe` context, and `$ctx` must point to a live
/// `RefCell<Context>`.
macro_rules! try_borrow {
    ($state:expr, $ctx:expr) => {{
        let cell = &*($ctx as *const RefCell<Context>);
        match cell.try_borrow() {
            Ok(guard) => guard,
            Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")),
        }
    }};
}

/// Mutably borrows the `RefCell<Context>` behind the raw pointer `$ctx`.
///
/// If the cell is already borrowed, it raises a Lua error through `luaL_error` on `$state`.
/// Must be expanded in an `unsafe` context, and `$ctx` must point to a live `RefCell<Context>`.
macro_rules! try_borrow_mut {
    ($state:expr, $ctx:expr) => {{
        let cell = &*($ctx as *const RefCell<Context>);
        match cell.try_borrow_mut() {
            Ok(guard) => guard,
            Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")),
        }
    }};
}
325
/// Fills a `luarequire_Configuration` with callbacks that forward to the [`Require`]
/// implementation reached through the callbacks' `ctx` pointer.
///
/// Every callback treats `ctx` as a `*const RefCell<Context>` (via `try_borrow!` /
/// `try_borrow_mut!`). If `config` is null, nothing is written.
#[cfg(feature = "luau")]
pub(super) unsafe extern "C-unwind" fn init_config(config: *mut ffi::luarequire_Configuration) {
    if config.is_null() {
        return;
    }

    // Returns `false` for a null chunk name; otherwise asks `Require::is_require_allowed`.
    unsafe extern "C-unwind" fn is_require_allowed(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> bool {
        if requirer_chunkname.is_null() {
            return false;
        }

        let this = try_borrow!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        this.is_require_allowed(&chunk_name)
    }

    // The navigation callbacks below move the `RefMut` guard into the closure, so the borrow ends
    // when the closure returns. NOTE(review): presumably this ensures the borrow is released
    // before `callback_error_ext` raises any resulting Lua error; verify against its
    // implementation. Unlike `is_require_allowed`, they assume the incoming C strings are
    // non-null.
    unsafe extern "C-unwind" fn reset(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.reset(&chunk_name).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn jump_to_alias(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        path: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let path = CStr::from_ptr(path).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.jump_to_alias(&path).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn to_parent(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_parent().into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn to_child(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        name: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let name = CStr::from_ptr(name).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_child(&name).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn is_module_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let this = try_borrow!(state, ctx);
        this.has_module()
    }

    // The chunk name and load name are always empty strings: the `load` callback below ignores
    // the `_chunkname` / `_loadname` it receives and names the chunk via `Require::loader`.
    unsafe extern "C-unwind" fn get_chunkname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    unsafe extern "C-unwind" fn get_loadname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    // Copies `Require::cache_key` into the provided buffer.
    unsafe extern "C-unwind" fn get_cache_key(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let this = try_borrow!(state, ctx);
        let cache_key = this.cache_key();
        write_to_buffer(buffer, buffer_size, size_out, cache_key.as_bytes())
    }

    unsafe extern "C-unwind" fn is_config_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let this = try_borrow!(state, ctx);
        this.has_config()
    }

    // Reads the configuration through `Require::config`. The `?` converts an I/O error into the
    // crate error, which `callback_error_ext` handles. The config is re-read on every call.
    unsafe extern "C-unwind" fn get_config(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let this = try_borrow!(state, ctx);
        let config = callback_error_ext(state, ptr::null_mut(), true, move |_, _| Ok(this.config()?));
        write_to_buffer(buffer, buffer_size, size_out, &config)
    }

    // Pushes the function returned by `Require::loader` and reports one value on the stack.
    unsafe extern "C-unwind" fn load(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        _path: *const c_char,
        _chunkname: *const c_char,
        _loadname: *const c_char,
    ) -> c_int {
        let this = try_borrow!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |extra, _| {
            let rawlua = (*extra).raw_lua();
            let loader = this.loader(rawlua.lua())?;
            rawlua.push(loader)?;
            Ok(1)
        })
    }

    (*config).is_require_allowed = is_require_allowed;
    (*config).reset = reset;
    (*config).jump_to_alias = jump_to_alias;
    (*config).to_parent = to_parent;
    (*config).to_child = to_child;
    (*config).is_module_present = is_module_present;
    (*config).get_chunkname = get_chunkname;
    (*config).get_loadname = get_loadname;
    (*config).get_cache_key = get_cache_key;
    (*config).is_config_present = is_config_present;
    // No dedicated alias callback is provided; only the optional `get_config` is set.
    (*config).get_alias = None;
    (*config).get_config = Some(get_config);
    (*config).load = load;
}
476
/// Copies `data` into `buffer` as a NUL-terminated C string.
///
/// A terminator is appended unless `data` already ends with `0`. The number of bytes required
/// (terminator included) is always stored in `*size_out`. When that exceeds `buffer_size`,
/// nothing is copied and `BufferTooSmall` is returned.
///
/// # Safety
/// `size_out` must be valid for writes, and `buffer` must be valid for writes of `buffer_size`
/// bytes.
#[cfg(feature = "luau")]
unsafe fn write_to_buffer(
    buffer: *mut c_char,
    buffer_size: usize,
    size_out: *mut usize,
    data: &[u8],
) -> WriteResult {
    let needs_terminator = data.last() != Some(&0);
    let required = data.len() + usize::from(needs_terminator);
    *size_out = required;

    if buffer_size < required {
        return WriteResult::BufferTooSmall;
    }

    ptr::copy_nonoverlapping(data.as_ptr(), buffer.cast::<u8>(), data.len());
    if needs_terminator {
        buffer.add(data.len()).write(0);
    }
    WriteResult::Success
}
497
/// Creates a Lua `require` function whose module resolution is driven by `require`.
///
/// The returned function (a Lua chunk named `=__mlua_require`):
/// 1. raises an error unless its argument is a string;
/// 2. returns the entry for `path` in the registered-modules registry table, if there is one;
/// 3. otherwise calls `proxyrequire` (configured by `init_config`) to obtain a loader. It returns
///    the result cached under the module's cache key, or calls the loader, caches the result
///    (`nil` is stored as `true`) and returns it.
///
/// The loader cache is a registry subtable (`__MLUA_LOADER_CACHE`), so it is shared by every
/// `require` function created for the same Lua state.
#[cfg(feature = "luau")]
pub fn create_require_function<R: Require + 'static>(lua: &Lua, require: R) -> Result<Function> {
    // Pushes the `source` of the nearest non-C function, searching from stack level 2 upward.
    // Raises a Lua error if the stack runs out first. NOTE(review): level 2 presumably skips
    // this function and the `__mlua_require` chunk, so the result is the chunk that called
    // `require`; confirm.
    unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int {
        let mut ar: ffi::lua_Debug = mem::zeroed();
        for level in 2.. {
            if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 {
                ffi::luaL_error(state, cstr!("require is not supported in this context"));
            }
            if CStr::from_ptr(ar.what) != c"C" {
                break;
            }
        }
        ffi::lua_pushstring(state, ar.source);
        1
    }

    // Pushes `Require::cache_key` for the current navigation state. The context pointer is
    // read from upvalue 1 (the `RefCell<Context>` userdata captured below).
    unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int {
        let ctx = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1));
        let ctx = try_borrow!(state, ctx);
        let cache_key = ctx.cache_key();
        ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len());
        1
    }

    // The context userdata becomes the sole upvalue of `get_cache_key`; its raw address is also
    // handed to `proxyrequire`. NOTE(review): the userdata's lifetime therefore appears to hinge
    // on the `get_cache_key` closure (both live in the same environment table below); verify.
    // The push order must match the tuple order returned by `exec_raw`.
    let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe {
        lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| {
            let context = Context(Box::new(require));
            let context_ptr = ffi::lua_newuserdata_t(state, RefCell::new(context));
            ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1);
            ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file"));
            ffi::luarequire_pushproxyrequire(state, init_config, context_ptr as *mut _);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__MLUA_LOADER_CACHE"));
        })
    }?;

    // Minimal `error` for the sandboxed environment: prefixes argument 1 with the caller's
    // location (`luaL_where(state, 1)`) and raises it.
    unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int {
        ffi::luaL_where(state, 1);
        ffi::lua_pushvalue(state, 1);
        ffi::lua_concat(state, 2);
        ffi::lua_error(state);
    }

    // Minimal `type` for the sandboxed environment: pushes the type name of argument 1.
    unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int {
        ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1)));
        1
    }

    let (error, r#type) = unsafe {
        lua.exec_raw::<(Function, Function)>((), move |state| {
            ffi::lua_pushcfunctiond(state, error, cstr!("error"));
            ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
        })
    }?;

    // The Lua wrapper runs in a dedicated environment that exposes only these seven names.
    let env = lua.create_table_with_capacity(0, 7)?;
    env.raw_set("get_cache_key", get_cache_key)?;
    env.raw_set("find_current_file", find_current_file)?;
    env.raw_set("proxyrequire", proxyrequire)?;
    env.raw_set("REGISTERED_MODULES", registered_modules)?;
    env.raw_set("LOADER_CACHE", loader_cache)?;
    env.raw_set("error", error)?;
    env.raw_set("type", r#type)?;

    lua.load(
        r#"
        local path = ...
        if type(path) ~= "string" then
            error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")")
        end

        -- Check if the module (path) is explicitly registered
        local maybe_result = REGISTERED_MODULES[path]
        if maybe_result ~= nil then
            return maybe_result
        end

        local loader = proxyrequire(path, find_current_file())
        local cache_key = get_cache_key()
        -- Check if the loader result is already cached
        local result = LOADER_CACHE[cache_key]
        if result ~= nil then
            return result
        end

        -- Call the loader function and cache the result
        result = loader()
        if result == nil then
            result = true
        end
        LOADER_CACHE[cache_key] = result
        return result
        "#,
    )
    .try_cache()
    .set_name("=__mlua_require")
    .set_environment(env)
    .into_function()
}
598
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::TextRequirer;

    /// Input paths paired with their expected lexical normalization.
    const NORMALIZE_CASES: &[(&str, &str)] = &[
        // Empty / current-directory inputs and plain relative paths are anchored with `./`.
        ("", "./"),
        (".", "./"),
        ("a/relative/path", "./a/relative/path"),
        // Trailing slashes, doubled separators and `.` components are removed.
        ("./remove/extraneous/symbols/", "./remove/extraneous/symbols"),
        ("./remove/extraneous//symbols", "./remove/extraneous/symbols"),
        ("./remove/extraneous/symbols/.", "./remove/extraneous/symbols"),
        ("./remove/extraneous/./symbols", "./remove/extraneous/symbols"),
        ("../remove/extraneous/symbols/", "../remove/extraneous/symbols"),
        ("../remove/extraneous//symbols", "../remove/extraneous/symbols"),
        ("../remove/extraneous/symbols/.", "../remove/extraneous/symbols"),
        ("../remove/extraneous/./symbols", "../remove/extraneous/symbols"),
        ("/remove/extraneous/symbols/", "/remove/extraneous/symbols"),
        ("/remove/extraneous//symbols", "/remove/extraneous/symbols"),
        ("/remove/extraneous/symbols/.", "/remove/extraneous/symbols"),
        ("/remove/extraneous/./symbols", "/remove/extraneous/symbols"),
        // `..` cancels the preceding normal component.
        ("./remove/me/..", "./remove"),
        ("./remove/me/../", "./remove"),
        ("../remove/me/..", "../remove"),
        ("../remove/me/../", "../remove"),
        ("/remove/me/..", "/remove"),
        ("/remove/me/../", "/remove"),
        // Leading `..` are preserved; `..` directly under the root is dropped.
        ("./..", "../"),
        ("./../", "../"),
        ("../..", "../../"),
        ("../../", "../../"),
        ("/../", "/"),
    ];

    #[test]
    fn test_path_normalize() {
        NORMALIZE_CASES.iter().for_each(|&(input, expected)| {
            let normalized = TextRequirer::normalize_path(Path::new(input));
            assert_eq!(
                normalized.as_path(),
                Path::new(expected),
                "wrong normalization for {input}"
            );
        });
    }
}
647}