Skip to main content

mlua/luau/require/
fs.rs

1use std::collections::VecDeque;
2use std::io::Result as IoResult;
3use std::path::{Component, Path, PathBuf};
4use std::result::Result as StdResult;
5use std::{env, fs};
6
7use crate::error::Result;
8use crate::function::Function;
9use crate::state::Lua;
10
11use super::{NavigateError, Require};
12
/// Navigation state for the standard, filesystem-backed Luau `require-by-string` implementation.
#[derive(Debug, Default)]
pub struct TextRequirer {
    /// Absolute path of the current Luau module. This is a logical module path and does not
    /// necessarily name an existing file.
    abs_path: PathBuf,
    /// Path of the current Luau module as reported in chunk names. It is relative to the working
    /// directory unless the module was reached through an absolute path. Like `abs_path`, it is
    /// not mapped to a physical file.
    rel_path: PathBuf,
    /// The module file the current path resolved to, or `None` if nothing resolved.
    /// The file is `<path>.luau`/`<path>.lua` or a directory's `init.luau`/`init.lua`.
    resolved_path: Option<PathBuf>,
}
24
25impl TextRequirer {
26    /// The prefix used for chunk names in the require system.
27    /// Only chunk names starting with this prefix are allowed to be used in `require`.
28    const CHUNK_PREFIX: &str = "@";
29
30    /// The file extensions that are considered valid for Luau modules.
31    const FILE_EXTENSIONS: &[&str] = &["luau", "lua"];
32
33    /// The filename for the JSON configuration file.
34    const LUAURC_CONFIG_FILENAME: &str = ".luaurc";
35
36    /// The filename for the Luau configuration file.
37    const LUAU_CONFIG_FILENAME: &str = ".config.luau";
38
39    /// Creates a new `TextRequirer` instance.
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    fn normalize_chunk_name(chunk_name: &str) -> &str {
45        if let Some((path, line)) = chunk_name.rsplit_once(':') {
46            if line.parse::<u32>().is_ok() {
47                return path;
48            }
49        }
50        chunk_name
51    }
52
53    // Normalizes the path by removing unnecessary components
54    fn normalize_path(path: &Path) -> PathBuf {
55        let mut components = VecDeque::new();
56
57        for comp in path.components() {
58            match comp {
59                Component::Prefix(..) | Component::RootDir => {
60                    components.push_back(comp);
61                }
62                Component::CurDir => {}
63                Component::ParentDir => {
64                    if matches!(components.back(), None | Some(Component::ParentDir)) {
65                        components.push_back(Component::ParentDir);
66                    } else if matches!(components.back(), Some(Component::Normal(..))) {
67                        components.pop_back();
68                    }
69                }
70                Component::Normal(..) => components.push_back(comp),
71            }
72        }
73
74        if matches!(components.front(), None | Some(Component::Normal(..))) {
75            components.push_front(Component::CurDir);
76        }
77
78        // Join the components back together
79        components.into_iter().collect()
80    }
81
82    /// Resolve a Luau module path to a physical file or directory.
83    ///
84    /// Empty directories without init files are considered valid as "intermediate" directories.
85    fn resolve_module(path: &Path) -> StdResult<Option<PathBuf>, NavigateError> {
86        let mut found_path = None;
87
88        if path.components().next_back() != Some(Component::Normal("init".as_ref())) {
89            let current_ext = (path.extension().and_then(|s| s.to_str()))
90                .map(|s| format!("{s}."))
91                .unwrap_or_default();
92            for ext in Self::FILE_EXTENSIONS {
93                let candidate = path.with_extension(format!("{current_ext}{ext}"));
94                if candidate.is_file() && found_path.replace(candidate).is_some() {
95                    return Err(NavigateError::Ambiguous);
96                }
97            }
98        }
99        if path.is_dir() {
100            for component in Self::FILE_EXTENSIONS.iter().map(|ext| format!("init.{ext}")) {
101                let candidate = path.join(component);
102                if candidate.is_file() && found_path.replace(candidate).is_some() {
103                    return Err(NavigateError::Ambiguous);
104                }
105            }
106
107            if found_path.is_none() {
108                // Directories without init files are considered valid "intermediate" path
109                return Ok(None);
110            }
111        }
112
113        Ok(Some(found_path.ok_or(NavigateError::NotFound)?))
114    }
115}
116
117impl Require for TextRequirer {
118    fn is_require_allowed(&self, chunk_name: &str) -> bool {
119        chunk_name.starts_with(Self::CHUNK_PREFIX)
120    }
121
122    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError> {
123        if !chunk_name.starts_with(Self::CHUNK_PREFIX) {
124            return Err(NavigateError::NotFound);
125        }
126        let chunk_name = Self::normalize_chunk_name(&chunk_name[1..]);
127        let chunk_path = Self::normalize_path(chunk_name.as_ref());
128
129        if chunk_path.extension() == Some("rs".as_ref()) {
130            // Special case for Rust source files, reset to the current directory
131            let chunk_filename = chunk_path.file_name().unwrap();
132            let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
133            self.abs_path = Self::normalize_path(&cwd.join(chunk_filename));
134            self.rel_path = ([Component::CurDir, Component::Normal(chunk_filename)].into_iter()).collect();
135            self.resolved_path = None;
136
137            return Ok(());
138        }
139
140        if chunk_path.is_absolute() {
141            let resolved_path = Self::resolve_module(&chunk_path)?;
142            self.abs_path = chunk_path.clone();
143            self.rel_path = chunk_path;
144            self.resolved_path = resolved_path;
145        } else {
146            // Relative path
147            let cwd = env::current_dir().map_err(|_| NavigateError::NotFound)?;
148            let abs_path = Self::normalize_path(&cwd.join(&chunk_path));
149            let resolved_path = Self::resolve_module(&abs_path)?;
150            self.abs_path = abs_path;
151            self.rel_path = chunk_path;
152            self.resolved_path = resolved_path;
153        }
154
155        Ok(())
156    }
157
158    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError> {
159        let path = Self::normalize_path(path.as_ref());
160        let resolved_path = Self::resolve_module(&path)?;
161
162        self.abs_path = path.clone();
163        self.rel_path = path;
164        self.resolved_path = resolved_path;
165
166        Ok(())
167    }
168
169    fn to_parent(&mut self) -> StdResult<(), NavigateError> {
170        let mut abs_path = self.abs_path.clone();
171        if !abs_path.pop() {
172            // It's important to return `NotFound` if we reached the root, as it's a "recoverable" error if we
173            // cannot go beyond the root directory.
174            // Luau "require-by-string` has a special logic to search for config file to resolve aliases.
175            return Err(NavigateError::NotFound);
176        }
177        let mut rel_parent = self.rel_path.clone();
178        rel_parent.pop();
179        let resolved_path = Self::resolve_module(&abs_path)?;
180
181        self.abs_path = abs_path;
182        self.rel_path = Self::normalize_path(&rel_parent);
183        self.resolved_path = resolved_path;
184
185        Ok(())
186    }
187
188    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError> {
189        let abs_path = self.abs_path.join(name);
190        let rel_path = self.rel_path.join(name);
191        let resolved_path = Self::resolve_module(&abs_path)?;
192
193        self.abs_path = abs_path;
194        self.rel_path = rel_path;
195        self.resolved_path = resolved_path;
196
197        Ok(())
198    }
199
200    fn has_module(&self) -> bool {
201        (self.resolved_path.as_deref())
202            .map(Path::is_file)
203            .unwrap_or(false)
204    }
205
206    fn cache_key(&self) -> String {
207        self.resolved_path.as_deref().unwrap().display().to_string()
208    }
209
210    fn has_config(&self) -> bool {
211        self.abs_path.is_dir() && self.abs_path.join(Self::LUAURC_CONFIG_FILENAME).is_file()
212            || self.abs_path.is_dir() && self.abs_path.join(Self::LUAU_CONFIG_FILENAME).is_file()
213    }
214
215    fn config(&self) -> IoResult<Vec<u8>> {
216        if self.abs_path.join(Self::LUAURC_CONFIG_FILENAME).is_file() {
217            return fs::read(self.abs_path.join(Self::LUAURC_CONFIG_FILENAME));
218        }
219        fs::read(self.abs_path.join(Self::LUAU_CONFIG_FILENAME))
220    }
221
222    fn loader(&self, lua: &Lua) -> Result<Function> {
223        let name = format!("@{}", self.rel_path.display());
224        lua.load(self.resolved_path.as_deref().unwrap())
225            .set_name(name)
226            .into_function()
227    }
228}
229
#[cfg(test)]
mod tests {
    use std::path::Path;

    use super::TextRequirer;

    #[test]
    fn test_path_normalize() {
        for (input, expected) in [
            // Basic formatting checks
            ("", "./"),
            (".", "./"),
            ("a/relative/path", "./a/relative/path"),
            // Paths containing extraneous '.' and '/' symbols
            ("./remove/extraneous/symbols/", "./remove/extraneous/symbols"),
            ("./remove/extraneous//symbols", "./remove/extraneous/symbols"),
            ("./remove/extraneous/symbols/.", "./remove/extraneous/symbols"),
            ("./remove/extraneous/./symbols", "./remove/extraneous/symbols"),
            ("../remove/extraneous/symbols/", "../remove/extraneous/symbols"),
            ("../remove/extraneous//symbols", "../remove/extraneous/symbols"),
            ("../remove/extraneous/symbols/.", "../remove/extraneous/symbols"),
            ("../remove/extraneous/./symbols", "../remove/extraneous/symbols"),
            ("/remove/extraneous/symbols/", "/remove/extraneous/symbols"),
            ("/remove/extraneous//symbols", "/remove/extraneous/symbols"),
            ("/remove/extraneous/symbols/.", "/remove/extraneous/symbols"),
            ("/remove/extraneous/./symbols", "/remove/extraneous/symbols"),
            // Paths containing '..'
            ("./remove/me/..", "./remove"),
            ("./remove/me/../", "./remove"),
            ("../remove/me/..", "../remove"),
            ("../remove/me/../", "../remove"),
            ("/remove/me/..", "/remove"),
            ("/remove/me/../", "/remove"),
            ("./..", "../"),
            ("./../", "../"),
            ("../..", "../../"),
            ("../../", "../../"),
            // '..' that cancels every name collapses to the current directory
            ("a/..", "./"),
            // '..' left over after cancelling climbs above the current directory
            ("./a/../..", "../"),
            // '..' disappears if path is absolute and component is non-erasable
            ("/../", "/"),
        ] {
            let path = TextRequirer::normalize_path(input.as_ref());
            assert_eq!(
                &path,
                expected.as_ref() as &Path,
                "wrong normalization for {input}"
            );
        }
    }

    #[test]
    fn test_chunk_name_normalize() {
        for (input, expected) in [
            // No line suffix: unchanged
            ("a/b.luau", "a/b.luau"),
            // Numeric line suffix is stripped
            ("a/b.luau:12", "a/b.luau"),
            // Only the last ':'-separated segment is considered
            ("a:b:3", "a:b"),
            // Non-numeric or empty suffixes are kept
            ("a/b.luau:", "a/b.luau:"),
            ("a/b.luau:x", "a/b.luau:x"),
            ("a/b.luau:-1", "a/b.luau:-1"),
            // Windows drive letters are not mistaken for line numbers
            ("C:\\path\\main.luau", "C:\\path\\main.luau"),
        ] {
            assert_eq!(
                TextRequirer::normalize_chunk_name(input),
                expected,
                "wrong chunk name normalization for {input}"
            );
        }
    }
}