darklua_core/rules/require/path_require_mode.rs

1use serde::{Deserialize, Serialize};
2
3use crate::frontend::DarkluaResult;
4use crate::nodes::{Arguments, FunctionCall, StringExpression};
5use crate::rules::require::path_utils::get_relative_path;
6use crate::rules::require::{match_path_require_call, path_utils, PathLocator};
7use crate::rules::{Context, RequireMode};
8use crate::utils;
9use crate::DarkluaError;
10
11use std::collections::HashMap;
12use std::ffi::OsStr;
13use std::path::{Path, PathBuf};
14
15use super::RequirePathLocator;
16
17/// A require mode for handling content from file system paths.
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19#[serde(deny_unknown_fields, rename_all = "snake_case")]
20pub struct PathRequireMode {
21    #[serde(
22        skip_serializing_if = "is_default_module_folder_name",
23        default = "get_default_module_folder_name"
24    )]
25    module_folder_name: String,
26    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
27    sources: HashMap<String, PathBuf>,
28    #[serde(default = "default_use_luau_configuration")]
29    use_luau_configuration: bool,
30    #[serde(skip)]
31    luau_rc_aliases: Option<HashMap<String, PathBuf>>,
32}
33
/// Serde default for `use_luau_configuration`: the Luau configuration is consulted
/// unless the user turns it off explicitly.
fn default_use_luau_configuration() -> bool {
    true
}
37
38impl Default for PathRequireMode {
39    fn default() -> Self {
40        Self {
41            module_folder_name: get_default_module_folder_name(),
42            sources: Default::default(),
43            use_luau_configuration: default_use_luau_configuration(),
44            luau_rc_aliases: Default::default(),
45        }
46    }
47}
48
/// Module folder name used when none is configured.
const DEFAULT_MODULE_FOLDER_NAME: &str = "init";

/// Serde default for `module_folder_name`.
fn get_default_module_folder_name() -> String {
    DEFAULT_MODULE_FOLDER_NAME.to_owned()
}

/// Serde `skip_serializing_if` predicate for `module_folder_name`: returns `true` when the
/// value equals [`DEFAULT_MODULE_FOLDER_NAME`], so the default is not written out.
///
/// Takes `&str` rather than `&String` (clippy `ptr_arg`); the `&String` that serde passes
/// deref-coerces at the call site.
fn is_default_module_folder_name(value: &str) -> bool {
    value == DEFAULT_MODULE_FOLDER_NAME
}
59
60impl PathRequireMode {
61    /// Creates a new path require mode with the specified module folder name.
62    pub fn new(module_folder_name: impl Into<String>) -> Self {
63        Self {
64            module_folder_name: module_folder_name.into(),
65            sources: Default::default(),
66            use_luau_configuration: default_use_luau_configuration(),
67            luau_rc_aliases: Default::default(),
68        }
69    }
70
71    pub(crate) fn initialize(&mut self, context: &Context) -> Result<(), DarkluaError> {
72        if !self.use_luau_configuration {
73            self.luau_rc_aliases.take();
74            return Ok(());
75        }
76
77        if let Some(config) =
78            utils::find_luau_configuration(context.current_path(), context.resources())?
79        {
80            self.luau_rc_aliases.replace(config.aliases);
81        } else {
82            self.luau_rc_aliases.take();
83        }
84
85        Ok(())
86    }
87
88    pub(crate) fn module_folder_name(&self) -> &str {
89        &self.module_folder_name
90    }
91
92    pub(crate) fn get_source(&self, name: &str, rel: &Path) -> Option<PathBuf> {
93        self.sources
94            .get(name)
95            .map(|alias| rel.join(alias))
96            .or_else(|| {
97                self.luau_rc_aliases
98                    .as_ref()
99                    .and_then(|aliases| aliases.get(name))
100                    .map(ToOwned::to_owned)
101            })
102    }
103
104    pub(crate) fn find_require(
105        &self,
106        call: &FunctionCall,
107        context: &Context,
108    ) -> DarkluaResult<Option<PathBuf>> {
109        if let Some(literal_path) = match_path_require_call(call) {
110            let required_path =
111                RequirePathLocator::new(self, context.project_location(), context.resources())
112                    .find_require_path(literal_path, context.current_path())?;
113
114            Ok(Some(required_path))
115        } else {
116            Ok(None)
117        }
118    }
119
120    pub(crate) fn is_module_folder_name(&self, path: &Path) -> bool {
121        let expect_value = Some(self.module_folder_name.as_str());
122        path.file_name().and_then(OsStr::to_str) == expect_value
123            || path.file_stem().and_then(OsStr::to_str) == expect_value
124    }
125
126    pub(crate) fn generate_require(
127        &self,
128        require_path: &Path,
129        _current: &RequireMode,
130        context: &Context<'_, '_, '_>,
131    ) -> Result<Option<crate::nodes::Arguments>, crate::DarkluaError> {
132        let source_path = utils::normalize_path(context.current_path());
133        log::debug!(
134            "generate path require for `{}` from `{}`",
135            require_path.display(),
136            source_path.display(),
137        );
138
139        let mut generated_path = if path_utils::is_require_relative(require_path) {
140            require_path.to_path_buf()
141        } else {
142            let normalized_require_path = utils::normalize_path(require_path);
143            log::trace!(
144                " ⨽ adjust non-relative path `{}` (normalized to `{}`) from `{}`",
145                require_path.display(),
146                normalized_require_path.display(),
147                source_path.display()
148            );
149
150            let mut potential_aliases: Vec<_> = self
151                .sources
152                .iter()
153                .map(|(alias_name, alias_path)| {
154                    (
155                        alias_name,
156                        utils::normalize_path(context.project_location().join(alias_path)),
157                    )
158                })
159                .filter(|(_, alias_path)| normalized_require_path.starts_with(alias_path))
160                .inspect(|(alias_name, alias_path)| {
161                    log::trace!(
162                        "   > alias candidate `{}` (`{}`)",
163                        alias_name,
164                        alias_path.display()
165                    );
166                })
167                .collect();
168            potential_aliases.sort_by_cached_key(|(_, alias_path)| alias_path.components().count());
169
170            if let Some((alias_name, alias_path)) = potential_aliases.into_iter().next_back() {
171                let mut new_path = PathBuf::from(alias_name);
172
173                new_path.extend(
174                    normalized_require_path
175                        .components()
176                        .skip(alias_path.components().count()),
177                );
178
179                new_path
180            } else if let Some(relative_require_path) =
181                get_relative_path(&normalized_require_path, &source_path, true)?
182            {
183                log::trace!(
184                    "   found relative path from source: `{}`",
185                    relative_require_path.display()
186                );
187
188                if !(relative_require_path.starts_with(".")
189                    || relative_require_path.starts_with(".."))
190                {
191                    Path::new(".").join(relative_require_path)
192                } else {
193                    relative_require_path
194                }
195            } else {
196                normalized_require_path
197            }
198        };
199
200        if self.is_module_folder_name(&generated_path) {
201            generated_path.pop();
202        } else if matches!(generated_path.extension(), Some(extension) if extension == "lua" || extension == "luau")
203        {
204            generated_path.set_extension("");
205        }
206
207        path_utils::write_require_path(&generated_path).map(generate_require_arguments)
208    }
209}
210
211fn generate_require_arguments(value: String) -> Option<Arguments> {
212    Some(Arguments::default().with_argument(StringExpression::from_value(value)))
213}
214
#[cfg(test)]
mod test {
    use super::*;

    mod is_module_folder_name {
        use super::*;

        #[test]
        fn default_mode_is_false_for_regular_name() {
            let require_mode = PathRequireMode::default();

            assert!(!require_mode.is_module_folder_name(Path::new("oops.lua")));
        }

        #[test]
        fn default_mode_is_true_for_init_lua() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("init.lua")));
        }

        #[test]
        fn default_mode_is_true_for_init_luau() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("init.luau")));
        }

        #[test]
        fn default_mode_is_true_for_folder_init_lua() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("folder/init.lua")));
        }

        #[test]
        fn default_mode_is_true_for_folder_init_luau() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("folder/init.luau")));
        }

        #[test]
        fn default_mode_is_true_for_folder_init_without_extension() {
            let require_mode = PathRequireMode::default();

            assert!(require_mode.is_module_folder_name(Path::new("folder/init")));
        }

        #[test]
        fn default_mode_is_false_for_name_starting_with_init() {
            let require_mode = PathRequireMode::default();

            assert!(!require_mode.is_module_folder_name(Path::new("initialize.lua")));
        }

        #[test]
        fn custom_mode_is_true_for_custom_name() {
            let require_mode = PathRequireMode::new("index");

            assert!(require_mode.is_module_folder_name(Path::new("folder/index.luau")));
        }

        #[test]
        fn custom_mode_is_false_for_init_lua() {
            let require_mode = PathRequireMode::new("index");

            assert!(!require_mode.is_module_folder_name(Path::new("folder/init.lua")));
        }
    }

    mod defaults {
        use super::*;

        #[test]
        fn default_mode_uses_init_module_folder_name() {
            assert_eq!(PathRequireMode::default().module_folder_name(), "init");
        }

        #[test]
        fn default_mode_equals_new_with_init() {
            assert_eq!(PathRequireMode::default(), PathRequireMode::new("init"));
        }
    }

    mod get_source {
        use super::*;

        #[test]
        fn returns_none_for_unknown_name() {
            let require_mode = PathRequireMode::default();

            assert_eq!(require_mode.get_source("pkg", Path::new("project")), None);
        }

        #[test]
        fn joins_configured_source_onto_base_path() {
            let mut require_mode = PathRequireMode::default();
            require_mode
                .sources
                .insert("pkg".to_owned(), PathBuf::from("packages"));

            assert_eq!(
                require_mode.get_source("pkg", Path::new("project")),
                Some(Path::new("project").join("packages"))
            );
        }

        #[test]
        fn configured_source_takes_precedence_over_luau_alias() {
            let mut require_mode = PathRequireMode::default();
            require_mode
                .sources
                .insert("pkg".to_owned(), PathBuf::from("packages"));
            let mut aliases = HashMap::new();
            aliases.insert("pkg".to_owned(), PathBuf::from("other"));
            require_mode.luau_rc_aliases = Some(aliases);

            assert_eq!(
                require_mode.get_source("pkg", Path::new("project")),
                Some(Path::new("project").join("packages"))
            );
        }

        #[test]
        fn falls_back_to_luau_alias_as_stored() {
            let mut require_mode = PathRequireMode::default();
            let mut aliases = HashMap::new();
            aliases.insert("pkg".to_owned(), PathBuf::from("vendor/pkg"));
            require_mode.luau_rc_aliases = Some(aliases);

            assert_eq!(
                require_mode.get_source("pkg", Path::new("project")),
                Some(PathBuf::from("vendor/pkg"))
            );
        }
    }
}