darklua_core/rules/require/path_require_mode.rs

1use serde::{Deserialize, Serialize};
2
3use crate::frontend::DarkluaResult;
4use crate::nodes::{Arguments, FunctionCall, StringExpression};
5use crate::rules::require::path_utils::get_relative_path;
6use crate::rules::require::{match_path_require_call, path_utils, PathLocator};
7use crate::rules::{Context, RequireMode};
8use crate::utils;
9use crate::DarkluaError;
10
11use std::collections::HashMap;
12use std::ffi::OsStr;
13use std::path::{Path, PathBuf};
14
15use super::RequirePathLocator;
16
17/// A require mode for handling content from file system paths.
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19#[serde(deny_unknown_fields, rename_all = "snake_case")]
20pub struct PathRequireMode {
21    #[serde(
22        skip_serializing_if = "is_default_module_folder_name",
23        default = "get_default_module_folder_name"
24    )]
25    module_folder_name: String,
26    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
27    sources: HashMap<String, PathBuf>,
28    #[serde(default = "default_use_luau_configuration")]
29    use_luau_configuration: bool,
30    #[serde(skip)]
31    luau_rc_aliases: Option<HashMap<String, PathBuf>>,
32}
33
/// Default value of the `use_luau_configuration` field: Luau configuration files
/// are taken into account unless explicitly disabled.
fn default_use_luau_configuration() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}

38impl Default for PathRequireMode {
39    fn default() -> Self {
40        Self {
41            module_folder_name: get_default_module_folder_name(),
42            sources: Default::default(),
43            use_luau_configuration: default_use_luau_configuration(),
44            luau_rc_aliases: Default::default(),
45        }
46    }
47}
48
/// Module folder name used when none is configured.
const DEFAULT_MODULE_FOLDER_NAME: &str = "init";

/// Serde default for the `module_folder_name` field.
#[inline]
fn get_default_module_folder_name() -> String {
    DEFAULT_MODULE_FOLDER_NAME.to_owned()
}

/// Serde `skip_serializing_if` predicate: returns `true` when `value` is the default
/// module folder name, so it is left out of the serialized configuration.
///
/// Takes `&str` instead of `&String` (clippy `ptr_arg`); serde calls it with
/// `&String`, which deref-coerces to `&str` at the call site.
fn is_default_module_folder_name(value: &str) -> bool {
    value == DEFAULT_MODULE_FOLDER_NAME
}

60impl PathRequireMode {
61    /// Creates a new path require mode with the specified module folder name.
62    pub fn new(module_folder_name: impl Into<String>) -> Self {
63        Self {
64            module_folder_name: module_folder_name.into(),
65            sources: Default::default(),
66            use_luau_configuration: default_use_luau_configuration(),
67            luau_rc_aliases: Default::default(),
68        }
69    }
70
71    pub(crate) fn initialize(&mut self, context: &Context) -> Result<(), DarkluaError> {
72        if !self.use_luau_configuration {
73            self.luau_rc_aliases.take();
74            return Ok(());
75        }
76
77        if let Some(config) =
78            utils::find_luau_configuration(context.current_path(), context.resources())?
79        {
80            self.luau_rc_aliases.replace(config.aliases);
81        } else {
82            self.luau_rc_aliases.take();
83        }
84
85        Ok(())
86    }
87
88    pub(crate) fn module_folder_name(&self) -> &str {
89        &self.module_folder_name
90    }
91
92    pub(crate) fn get_source(&self, name: &str, rel: &Path) -> Option<PathBuf> {
93        log::trace!(
94            "lookup alias `{}` from `{}` (path mode)",
95            name,
96            rel.display()
97        );
98
99        self.sources
100            .get(name)
101            .map(|alias| rel.join(alias))
102            .or_else(|| {
103                self.luau_rc_aliases
104                    .as_ref()
105                    .and_then(|aliases| aliases.get(name))
106                    .map(ToOwned::to_owned)
107            })
108    }
109
110    pub(crate) fn find_require(
111        &self,
112        call: &FunctionCall,
113        context: &Context,
114    ) -> DarkluaResult<Option<PathBuf>> {
115        if let Some(literal_path) = match_path_require_call(call) {
116            let required_path =
117                RequirePathLocator::new(self, context.project_location(), context.resources())
118                    .find_require_path(literal_path, context.current_path())?;
119
120            Ok(Some(required_path))
121        } else {
122            Ok(None)
123        }
124    }
125
126    pub(crate) fn is_module_folder_name(&self, path: &Path) -> bool {
127        let expect_value = Some(self.module_folder_name.as_str());
128        path.file_name().and_then(OsStr::to_str) == expect_value
129            || path.file_stem().and_then(OsStr::to_str) == expect_value
130    }
131
132    pub(crate) fn generate_require(
133        &self,
134        require_path: &Path,
135        _current: &RequireMode,
136        context: &Context<'_, '_, '_>,
137    ) -> Result<Option<crate::nodes::Arguments>, crate::DarkluaError> {
138        let source_path = utils::normalize_path(context.current_path());
139        log::debug!(
140            "generate path require for `{}` from `{}`",
141            require_path.display(),
142            source_path.display(),
143        );
144
145        let mut generated_path = if path_utils::is_require_relative(require_path) {
146            require_path.to_path_buf()
147        } else {
148            let normalized_require_path = utils::normalize_path(require_path);
149            log::trace!(
150                " ⨽ adjust non-relative path `{}` (normalized to `{}`) from `{}`",
151                require_path.display(),
152                normalized_require_path.display(),
153                source_path.display()
154            );
155
156            let mut potential_aliases: Vec<_> = self
157                .sources
158                .iter()
159                .map(|(alias_name, alias_path)| {
160                    (
161                        alias_name,
162                        utils::normalize_path(context.project_location().join(alias_path)),
163                    )
164                })
165                .filter(|(_, alias_path)| normalized_require_path.starts_with(alias_path))
166                .inspect(|(alias_name, alias_path)| {
167                    log::trace!(
168                        "   > alias candidate `{}` (`{}`)",
169                        alias_name,
170                        alias_path.display()
171                    );
172                })
173                .collect();
174            potential_aliases.sort_by_cached_key(|(_, alias_path)| alias_path.components().count());
175
176            if let Some((alias_name, alias_path)) = potential_aliases.into_iter().next_back() {
177                let mut new_path = PathBuf::from(alias_name);
178
179                new_path.extend(
180                    normalized_require_path
181                        .components()
182                        .skip(alias_path.components().count()),
183                );
184
185                new_path
186            } else if let Some(relative_require_path) =
187                get_relative_path(&normalized_require_path, &source_path, true)?
188            {
189                log::trace!(
190                    "   found relative path from source: `{}`",
191                    relative_require_path.display()
192                );
193
194                if !(relative_require_path.starts_with(".")
195                    || relative_require_path.starts_with(".."))
196                {
197                    Path::new(".").join(relative_require_path)
198                } else {
199                    relative_require_path
200                }
201            } else {
202                normalized_require_path
203            }
204        };
205
206        if self.is_module_folder_name(&generated_path) {
207            generated_path.pop();
208        } else if matches!(generated_path.extension(), Some(extension) if extension == "lua" || extension == "luau")
209        {
210            generated_path.set_extension("");
211        }
212
213        path_utils::write_require_path(&generated_path).map(generate_require_arguments)
214    }
215}
216
217fn generate_require_arguments(value: String) -> Option<Arguments> {
218    Some(Arguments::default().with_argument(StringExpression::from_value(value)))
219}
220
#[cfg(test)]
mod test {
    use super::*;

    mod module_folder_name {
        use super::*;

        #[test]
        fn default_mode_uses_init() {
            assert_eq!(PathRequireMode::default().module_folder_name(), "init");
        }

        #[test]
        fn new_mode_uses_given_name() {
            assert_eq!(PathRequireMode::new("index").module_folder_name(), "index");
        }
    }

    mod is_module_folder_name {
        use super::*;

        #[test]
        fn default_mode_is_false_for_regular_name() {
            let require_mode = PathRequireMode::default();

            assert!(!require_mode.is_module_folder_name(Path::new("oops.lua")));
        }

        #[test]
        fn default_mode_is_true_for_init_lua() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("init.lua")));
        }

        #[test]
        fn default_mode_is_true_for_init_luau() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("init.luau")));
        }

        #[test]
        fn default_mode_is_true_for_folder_init_lua() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("folder/init.lua")));
        }

        #[test]
        fn default_mode_is_true_for_folder_init_luau() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("folder/init.luau")));
        }

        #[test]
        fn default_mode_is_true_for_extensionless_init() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("folder/init")));
        }

        #[test]
        fn default_mode_is_false_for_name_starting_with_init() {
            let require_mode = PathRequireMode::default();

            assert!(!require_mode.is_module_folder_name(Path::new("initial.lua")));
        }

        #[test]
        fn default_mode_is_false_for_file_inside_init_folder() {
            let require_mode = PathRequireMode::default();

            assert!(!require_mode.is_module_folder_name(Path::new("init/module.lua")));
        }

        #[test]
        fn custom_mode_is_true_for_custom_name() {
            let require_mode = PathRequireMode::new("index");

            assert!(require_mode.is_module_folder_name(Path::new("folder/index.luau")));
        }

        #[test]
        fn custom_mode_is_false_for_init_lua() {
            let require_mode = PathRequireMode::new("index");

            assert!(!require_mode.is_module_folder_name(Path::new("init.lua")));
        }
    }
}