mlua/luau/
require.rs

1use std::cell::RefCell;
2use std::collections::VecDeque;
3use std::ffi::CStr;
4use std::io::Result as IoResult;
5use std::ops::{Deref, DerefMut};
6use std::os::raw::{c_char, c_int, c_void};
7use std::path::{Component, Path, PathBuf};
8use std::result::Result as StdResult;
9use std::{env, fmt, fs, mem, ptr};
10
11use crate::error::{Error, Result};
12use crate::function::Function;
13use crate::state::{callback_error_ext, Lua};
14use crate::table::Table;
15use crate::types::MaybeSend;
16
/// An error that can occur during navigation in the Luau `require` system.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
#[derive(Debug, Clone)]
pub enum NavigateError {
    /// More than one candidate module matched the requested location.
    Ambiguous,
    /// Nothing usable exists at the requested location.
    NotFound,
    /// Any other failure, carrying the underlying [`Error`].
    Other(Error),
}
26
/// Converts the outcome of a navigation call into the status code understood by the Luau
/// require library; only non-navigation failures are passed through as `Err`.
#[cfg(feature = "luau")]
trait IntoNavigateResult {
    /// Maps `Ok`/`Ambiguous`/`NotFound` to a navigation status and forwards `Other` errors.
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult>;
}
31
#[cfg(feature = "luau")]
impl IntoNavigateResult for StdResult<(), NavigateError> {
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult> {
        let nav_err = match self {
            Ok(()) => return Ok(ffi::luarequire_NavigateResult::Success),
            Err(nav_err) => nav_err,
        };
        // Ambiguity and absence are regular navigation outcomes; anything else is a real error.
        match nav_err {
            NavigateError::Ambiguous => Ok(ffi::luarequire_NavigateResult::Ambiguous),
            NavigateError::NotFound => Ok(ffi::luarequire_NavigateResult::NotFound),
            NavigateError::Other(err) => Err(err),
        }
    }
}
43
44impl From<Error> for NavigateError {
45    fn from(err: Error) -> Self {
46        NavigateError::Other(err)
47    }
48}
49
/// Status code returned by the buffer-writing callbacks handed to the Luau require library.
#[cfg(feature = "luau")]
type WriteResult = ffi::luarequire_WriteResult;
52
/// A trait for handling module loading and navigation in the Luau `require` system.
///
/// The navigation methods ([`reset`](Require::reset), [`jump_to_alias`](Require::jump_to_alias),
/// [`to_parent`](Require::to_parent) and [`to_child`](Require::to_child)) move the
/// implementation's internal state; the remaining methods query that state.
#[cfg(any(feature = "luau", doc))]
#[cfg_attr(docsrs, doc(cfg(feature = "luau")))]
pub trait Require: MaybeSend {
    /// Returns `true` if "require" is permitted for the given chunk name.
    fn is_require_allowed(&self, chunk_name: &str) -> bool;

    /// Resets the internal state to point at the requirer module.
    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError>;

    /// Resets the internal state to point at an aliased module.
    ///
    /// This function receives an exact path from a configuration file.
    /// It's only called when an alias's path cannot be resolved relative to its
    /// configuration file.
    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError>;

    /// Navigates to the parent directory.
    fn to_parent(&mut self) -> StdResult<(), NavigateError>;

    /// Navigates to the given child directory.
    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError>;

    /// Returns whether the context is currently pointing at a module.
    fn has_module(&self) -> bool;

    /// Provides a cache key representing the current module.
    ///
    /// This function is only called if `has_module` returns true.
    fn cache_key(&self) -> String;

    /// Returns whether a configuration is present in the current context.
    fn has_config(&self) -> bool;

    /// Returns the contents of the configuration file in the current context.
    ///
    /// This function is only called if `has_config` returns true.
    fn config(&self) -> IoResult<Vec<u8>>;

    /// Returns a loader function for the current module that, when called, loads the module
    /// and returns the result.
    ///
    /// The loader can be sync or async.
    /// This function is only called if `has_module` returns true.
    fn loader(&self, lua: &Lua) -> Result<Function>;
}
99
100impl fmt::Debug for dyn Require {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        write!(f, "<dyn Require>")
103    }
104}
105
/// The standard implementation of Luau `require` navigation, backed by the filesystem.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct TextRequirer {
    /// Filesystem path of the current navigation position, used for module and config lookup.
    abs_path: PathBuf,
    /// Path used to build the `@`-prefixed chunk name of loaded modules.
    rel_path: PathBuf,
    /// Resolved module file, or an empty path when the position is a directory without `init`.
    module_path: PathBuf,
}
114
115impl TextRequirer {
116    /// Creates a new `TextRequirer` instance.
117    pub fn new() -> Self {
118        Self::default()
119    }
120
121    fn normalize_chunk_name(chunk_name: &str) -> &str {
122        if let Some((path, line)) = chunk_name.split_once(':') {
123            if line.parse::<u32>().is_ok() {
124                return path;
125            }
126        }
127        chunk_name
128    }
129
130    // Normalizes the path by removing unnecessary components
131    fn normalize_path(path: &Path) -> PathBuf {
132        let mut components = VecDeque::new();
133
134        for comp in path.components() {
135            match comp {
136                Component::Prefix(..) | Component::RootDir => {
137                    components.push_back(comp);
138                }
139                Component::CurDir => {}
140                Component::ParentDir => {
141                    if matches!(components.back(), None | Some(Component::ParentDir)) {
142                        components.push_back(Component::ParentDir);
143                    } else if matches!(components.back(), Some(Component::Normal(..))) {
144                        components.pop_back();
145                    }
146                }
147                Component::Normal(..) => components.push_back(comp),
148            }
149        }
150
151        if matches!(components.front(), None | Some(Component::Normal(..))) {
152            components.push_front(Component::CurDir);
153        }
154
155        // Join the components back together
156        components.into_iter().collect()
157    }
158
159    fn find_module(path: &Path) -> StdResult<PathBuf, NavigateError> {
160        let mut found_path = None;
161
162        if path.components().next_back() != Some(Component::Normal("init".as_ref())) {
163            let current_ext = (path.extension().and_then(|s| s.to_str()))
164                .map(|s| format!("{s}."))
165                .unwrap_or_default();
166            for ext in ["luau", "lua"] {
167                let candidate = path.with_extension(format!("{current_ext}{ext}"));
168                if candidate.is_file() && found_path.replace(candidate).is_some() {
169                    return Err(NavigateError::Ambiguous);
170                }
171            }
172        }
173        if path.is_dir() {
174            for component in ["init.luau", "init.lua"] {
175                let candidate = path.join(component);
176                if candidate.is_file() && found_path.replace(candidate).is_some() {
177                    return Err(NavigateError::Ambiguous);
178                }
179            }
180
181            if found_path.is_none() {
182                found_path = Some(PathBuf::new());
183            }
184        }
185
186        found_path.ok_or(NavigateError::NotFound)
187    }
188}
189
190impl Require for TextRequirer {
191    fn is_require_allowed(&self, chunk_name: &str) -> bool {
192        chunk_name.starts_with('@')
193    }
194
195    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError> {
196        if !chunk_name.starts_with('@') {
197            return Err(NavigateError::NotFound);
198        }
199        let chunk_name = Self::normalize_chunk_name(&chunk_name[1..]);
200        let chunk_path = Self::normalize_path(chunk_name.as_ref());
201
202        if chunk_path.extension() == Some("rs".as_ref()) {
203            // Special case for Rust source files, reset to the current directory
204            let chunk_filename = chunk_path.file_name().unwrap();
205            let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
206            self.abs_path = Self::normalize_path(&cwd.join(chunk_filename));
207            self.rel_path = ([Component::CurDir, Component::Normal(chunk_filename)].into_iter()).collect();
208            self.module_path = PathBuf::new();
209
210            return Ok(());
211        }
212
213        if chunk_path.is_absolute() {
214            let module_path = Self::find_module(&chunk_path)?;
215            self.abs_path = chunk_path.clone();
216            self.rel_path = chunk_path;
217            self.module_path = module_path;
218        } else {
219            // Relative path
220            let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
221            let abs_path = Self::normalize_path(&cwd.join(&chunk_path));
222            let module_path = Self::find_module(&abs_path)?;
223            self.abs_path = abs_path;
224            self.rel_path = chunk_path;
225            self.module_path = module_path;
226        }
227
228        Ok(())
229    }
230
231    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError> {
232        let path = Self::normalize_path(path.as_ref());
233        let module_path = Self::find_module(&path)?;
234
235        self.abs_path = path.clone();
236        self.rel_path = path;
237        self.module_path = module_path;
238
239        Ok(())
240    }
241
242    fn to_parent(&mut self) -> StdResult<(), NavigateError> {
243        let mut abs_path = self.abs_path.clone();
244        if !abs_path.pop() {
245            return Err(NavigateError::NotFound);
246        }
247        let mut rel_parent = self.rel_path.clone();
248        rel_parent.pop();
249        let module_path = Self::find_module(&abs_path)?;
250
251        self.abs_path = abs_path;
252        self.rel_path = Self::normalize_path(&rel_parent);
253        self.module_path = module_path;
254
255        Ok(())
256    }
257
258    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError> {
259        let abs_path = self.abs_path.join(name);
260        let rel_path = self.rel_path.join(name);
261        let module_path = Self::find_module(&abs_path)?;
262
263        self.abs_path = abs_path;
264        self.rel_path = rel_path;
265        self.module_path = module_path;
266
267        Ok(())
268    }
269
270    fn has_module(&self) -> bool {
271        self.module_path.is_file()
272    }
273
274    fn cache_key(&self) -> String {
275        self.module_path.display().to_string()
276    }
277
278    fn has_config(&self) -> bool {
279        self.abs_path.is_dir() && self.abs_path.join(".luaurc").is_file()
280    }
281
282    fn config(&self) -> IoResult<Vec<u8>> {
283        fs::read(self.abs_path.join(".luaurc"))
284    }
285
286    fn loader(&self, lua: &Lua) -> Result<Function> {
287        let name = format!("@{}", self.rel_path.display());
288        lua.load(&*self.module_path).set_name(name).into_function()
289    }
290}
291
292struct Context(Box<dyn Require>);
293
294impl Deref for Context {
295    type Target = dyn Require;
296
297    fn deref(&self) -> &Self::Target {
298        &*self.0
299    }
300}
301
302impl DerefMut for Context {
303    fn deref_mut(&mut self) -> &mut Self::Target {
304        &mut *self.0
305    }
306}
307
/// Immutably borrows the `RefCell<Context>` behind the raw context pointer `$ctx`.
///
/// A conflicting borrow raises a Lua error through `luaL_error` instead of panicking.
/// Must be expanded in an `unsafe` context, and `$ctx` must point to a live `RefCell<Context>`.
macro_rules! try_borrow {
    ($state:expr, $ctx:expr) => {{
        let cell: &RefCell<Context> = &*($ctx as *const RefCell<Context>);
        match cell.try_borrow() {
            Ok(guard) => guard,
            Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")),
        }
    }};
}
316
/// Mutably borrows the `RefCell<Context>` behind the raw context pointer `$ctx`.
///
/// A conflicting borrow raises a Lua error through `luaL_error` instead of panicking.
/// Must be expanded in an `unsafe` context, and `$ctx` must point to a live `RefCell<Context>`.
macro_rules! try_borrow_mut {
    ($state:expr, $ctx:expr) => {{
        let cell: &RefCell<Context> = &*($ctx as *const RefCell<Context>);
        match cell.try_borrow_mut() {
            Ok(guard) => guard,
            Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")),
        }
    }};
}
325
/// Installs the Rust-side callbacks into a Luau `luarequire_Configuration`.
///
/// Every callback reinterprets its opaque `ctx` pointer as a `RefCell<Context>` and forwards
/// the request to the wrapped [`Require`] implementation. A null `config` is ignored.
#[cfg(feature = "luau")]
pub(super) unsafe extern "C-unwind" fn init_config(config: *mut ffi::luarequire_Configuration) {
    let Some(config) = config.as_mut() else {
        return;
    };

    unsafe extern "C-unwind" fn is_require_allowed(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> bool {
        // Without a requirer chunk name there is nothing to check against.
        if requirer_chunkname.is_null() {
            return false;
        }

        let requirer = try_borrow!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        requirer.is_require_allowed(&chunk_name)
    }

    unsafe extern "C-unwind" fn reset(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut requirer = try_borrow_mut!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        // `NavigateError::Other` becomes an `Err` that `callback_error_ext` reports.
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            requirer.reset(&chunk_name).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn jump_to_alias(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        path: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut requirer = try_borrow_mut!(state, ctx);
        let alias_path = CStr::from_ptr(path).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            requirer.jump_to_alias(&alias_path).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn to_parent(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ffi::luarequire_NavigateResult {
        let mut requirer = try_borrow_mut!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            requirer.to_parent().into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn to_child(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        name: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut requirer = try_borrow_mut!(state, ctx);
        let child_name = CStr::from_ptr(name).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            requirer.to_child(&child_name).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn is_module_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let requirer = try_borrow!(state, ctx);
        requirer.has_module()
    }

    // Shared by `get_chunkname` and `get_loadname`: both always report an empty string
    // (the `load` callback below ignores the chunk and load names it receives).
    unsafe extern "C-unwind" fn write_empty_name(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    unsafe extern "C-unwind" fn get_cache_key(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let requirer = try_borrow!(state, ctx);
        let key = requirer.cache_key();
        write_to_buffer(buffer, buffer_size, size_out, key.as_bytes())
    }

    unsafe extern "C-unwind" fn is_config_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let requirer = try_borrow!(state, ctx);
        requirer.has_config()
    }

    unsafe extern "C-unwind" fn get_config(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let requirer = try_borrow!(state, ctx);
        let contents = callback_error_ext(state, ptr::null_mut(), true, move |_, _| Ok(requirer.config()?));
        write_to_buffer(buffer, buffer_size, size_out, &contents)
    }

    unsafe extern "C-unwind" fn load(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        _path: *const c_char,
        _chunkname: *const c_char,
        _loadname: *const c_char,
    ) -> c_int {
        let requirer = try_borrow!(state, ctx);
        // Pushes the module's loader function (one return value) onto the Lua stack.
        callback_error_ext(state, ptr::null_mut(), true, move |extra, _| {
            let rawlua = (*extra).raw_lua();
            let loader = requirer.loader(rawlua.lua())?;
            rawlua.push(loader)?;
            Ok(1)
        })
    }

    config.is_require_allowed = is_require_allowed;
    config.reset = reset;
    config.jump_to_alias = jump_to_alias;
    config.to_parent = to_parent;
    config.to_child = to_child;
    config.is_module_present = is_module_present;
    config.get_chunkname = write_empty_name;
    config.get_loadname = write_empty_name;
    config.get_cache_key = get_cache_key;
    config.is_config_present = is_config_present;
    config.get_alias = None;
    config.get_config = Some(get_config);
    config.load = load;
}
476
/// Copies `data` into a C buffer supplied by the Luau require library, appending a NUL
/// terminator unless `data` already ends with one.
///
/// `*size_out` always receives the number of bytes required, including the terminator.
/// If that exceeds `buffer_size`, nothing is copied and `BufferTooSmall` is returned.
///
/// # Safety
///
/// `size_out` must be valid for writes, and `buffer` must be valid for writes of
/// `buffer_size` bytes.
#[cfg(feature = "luau")]
unsafe fn write_to_buffer(
    buffer: *mut c_char,
    buffer_size: usize,
    size_out: *mut usize,
    data: &[u8],
) -> WriteResult {
    // The C++ side treats the buffer as a `std::string::data()` buffer, so it must end with NUL.
    let needs_terminator = !matches!(data.last(), Some(0));
    let required = data.len() + usize::from(needs_terminator);
    *size_out = required;
    if buffer_size < required {
        return WriteResult::BufferTooSmall;
    }

    let dst = buffer.cast::<u8>();
    ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
    if needs_terminator {
        dst.add(data.len()).write(0);
    }
    WriteResult::Success
}
497
/// Creates the Lua `require` function backed by the given [`Require`] implementation.
///
/// The returned function:
/// 1. raises an error unless its argument is a string;
/// 2. returns the entry of the registered-modules registry table for that path, if any;
/// 3. otherwise resolves the module through Luau's `proxyrequire` (which drives `require`)
///    and returns the cached result stored under the module's cache key, if any;
/// 4. otherwise calls the loader, caches its result (`nil` is stored as `true`) and returns it.
#[cfg(feature = "luau")]
pub fn create_require_function<R: Require + 'static>(lua: &Lua, require: R) -> Result<Function> {
    // Pushes the `source` of the nearest non-C function on the call stack, starting at level 2.
    unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int {
        let mut ar: ffi::lua_Debug = mem::zeroed();
        for level in 2.. {
            if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 {
                ffi::luaL_error(state, cstr!("require is not supported in this context"));
            }
            if CStr::from_ptr(ar.what) != c"C" {
                break;
            }
        }
        ffi::lua_pushstring(state, ar.source);
        1
    }

    // Pushes the cache key of the module the requirer currently points at.
    unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int {
        let ctx = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1));
        // Release the `RefCell` borrow before calling back into Lua: `lua_pushlstring` can raise
        // an error, and leaving this frame non-locally while the guard is alive may leave the
        // context marked as borrowed for every later `require`.
        let cache_key = {
            let ctx = try_borrow!(state, ctx);
            ctx.cache_key()
        };
        ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len());
        1
    }

    let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe {
        lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| {
            let context = Context(Box::new(require));
            let context_ptr = ffi::lua_newuserdata_t(state, RefCell::new(context));
            // The context userdata becomes the single upvalue of `get_cache_key`.
            ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1);
            ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file"));
            ffi::luarequire_pushproxyrequire(state, init_config, context_ptr as *mut _);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__MLUA_LOADER_CACHE"));
        })
    }?;

    // Minimal `error`: prefixes the message with the caller's location and raises it.
    unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int {
        ffi::luaL_where(state, 1);
        ffi::lua_pushvalue(state, 1);
        ffi::lua_concat(state, 2);
        ffi::lua_error(state);
    }

    // Minimal `type`: returns the type name of the first argument.
    unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int {
        ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1)));
        1
    }

    let (error, r#type) = unsafe {
        lua.exec_raw::<(Function, Function)>((), move |state| {
            ffi::lua_pushcfunctiond(state, error, cstr!("error"));
            ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
        })
    }?;

    // Prepare environment for the "require" function
    let env = lua.create_table_with_capacity(0, 7)?;
    env.raw_set("get_cache_key", get_cache_key)?;
    env.raw_set("find_current_file", find_current_file)?;
    env.raw_set("proxyrequire", proxyrequire)?;
    env.raw_set("REGISTERED_MODULES", registered_modules)?;
    env.raw_set("LOADER_CACHE", loader_cache)?;
    env.raw_set("error", error)?;
    env.raw_set("type", r#type)?;

    lua.load(
        r#"
        local path = ...
        if type(path) ~= "string" then
            error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")")
        end

        -- Check if the module (path) is explicitly registered
        local maybe_result = REGISTERED_MODULES[path]
        if maybe_result ~= nil then
            return maybe_result
        end

        local loader = proxyrequire(path, find_current_file())
        local cache_key = get_cache_key()
        -- Check if the loader result is already cached
        local result = LOADER_CACHE[cache_key]
        if result ~= nil then
            return result
        end

        -- Call the loader function and cache the result
        result = loader()
        if result == nil then
            result = true
        end
        LOADER_CACHE[cache_key] = result
        return result
        "#,
    )
    .try_cache()
    .set_name("=__mlua_require")
    .set_environment(env)
    .into_function()
}
598
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::TextRequirer;

    /// `(input, expected)` pairs for `TextRequirer::normalize_path`.
    const NORMALIZE_CASES: &[(&str, &str)] = &[
        // Basic formatting checks
        ("", "./"),
        (".", "./"),
        ("a/relative/path", "./a/relative/path"),
        // Paths containing extraneous '.' and '/' symbols
        ("./remove/extraneous/symbols/", "./remove/extraneous/symbols"),
        ("./remove/extraneous//symbols", "./remove/extraneous/symbols"),
        ("./remove/extraneous/symbols/.", "./remove/extraneous/symbols"),
        ("./remove/extraneous/./symbols", "./remove/extraneous/symbols"),
        ("../remove/extraneous/symbols/", "../remove/extraneous/symbols"),
        ("../remove/extraneous//symbols", "../remove/extraneous/symbols"),
        ("../remove/extraneous/symbols/.", "../remove/extraneous/symbols"),
        ("../remove/extraneous/./symbols", "../remove/extraneous/symbols"),
        ("/remove/extraneous/symbols/", "/remove/extraneous/symbols"),
        ("/remove/extraneous//symbols", "/remove/extraneous/symbols"),
        ("/remove/extraneous/symbols/.", "/remove/extraneous/symbols"),
        ("/remove/extraneous/./symbols", "/remove/extraneous/symbols"),
        // Paths containing '..'
        ("./remove/me/..", "./remove"),
        ("./remove/me/../", "./remove"),
        ("../remove/me/..", "../remove"),
        ("../remove/me/../", "../remove"),
        ("/remove/me/..", "/remove"),
        ("/remove/me/../", "/remove"),
        ("./..", "../"),
        ("./../", "../"),
        ("../..", "../../"),
        ("../../", "../../"),
        // '..' disappears if path is absolute and component is non-erasable
        ("/../", "/"),
    ];

    #[test]
    fn test_path_normalize() {
        for &(input, expected) in NORMALIZE_CASES {
            let normalized = TextRequirer::normalize_path(Path::new(input));
            assert_eq!(
                normalized.as_path(),
                Path::new(expected),
                "wrong normalization for {input}"
            );
        }
    }
}