darklua_core/rules/require/
path_iterator.rs

1use std::{
2    ffi::OsStr,
3    path::{Path, PathBuf},
4};
5
6pub(crate) fn find_require_paths<'a, 'b, 'c>(
7    path: &'a Path,
8    module_folder_name: &'b str,
9) -> impl Iterator<Item = PathBuf> + 'c
10where
11    'a: 'c,
12    'b: 'c,
13{
14    PathIterator::new(path, module_folder_name)
15}
16
/// Lazily produces the candidate paths for a required `path`.
///
/// Candidates are yielded in this order:
/// 1. `path` itself. When its extension is `lua` or `luau`, this is the only
///    candidate.
/// 2. If `path` has a file name: that file name with `.luau` appended, then
///    with `.lua` appended. The existing name is kept whole, so `a.b` gives
///    `a.b.luau`.
/// 3. `path` joined with `module_folder_name`.
/// 4. That joined path with the `luau` extension, then with `lua`. Iteration
///    stops before these when the joined path already has an extension.
///
/// Once it has returned `None`, the iterator keeps returning `None`.
struct PathIterator<'a, 'b> {
    /// The path as given by the caller; always the first candidate.
    path: &'a Path,
    /// Cached `path.extension()`.
    extension: Option<&'a OsStr>,
    /// Cached `path.file_name()`; `None` for paths such as `.` or `..`.
    file_name: Option<&'a OsStr>,
    /// Name joined onto `path` to build the folder-based candidates.
    module_folder_name: &'b str,
    /// Number of candidates yielded so far.
    index: u8,
}

impl<'a, 'b> PathIterator<'a, 'b> {
    fn new(path: &'a Path, module_folder_name: &'b str) -> Self {
        let extension = path.extension();
        let file_name = path.file_name();
        Self {
            path,
            extension,
            file_name,
            module_folder_name,
            index: 0,
        }
    }

    /// Computes the candidate for the current `index` without advancing.
    ///
    /// Returns `None` when no candidate remains.
    fn candidate(&self) -> Option<PathBuf> {
        if self.index == 0 {
            return Some(self.path.to_path_buf());
        }

        // A path that already names a Lua source file is not expanded further.
        if matches!(self.extension.and_then(OsStr::to_str), Some("luau" | "lua")) {
            return None;
        }

        // When the path has a file name, two suffixed variants of it come
        // before the module folder candidates, so the folder steps are shifted
        // by two in that case.
        let folder_step = match self.file_name {
            Some(name) if self.index <= 2 => {
                let suffix = if self.index == 1 { ".luau" } else { ".lua" };
                let mut sibling = name.to_os_string();
                sibling.push(suffix);
                return Some(self.path.with_file_name(sibling));
            }
            Some(_) => self.index - 2,
            None => self.index,
        };

        let mut folder_path = self.path.join(self.module_folder_name);
        match folder_step {
            1 => Some(folder_path),
            // Only add an extension when the joined path does not already
            // carry one.
            2 | 3 if folder_path.extension().is_none() => {
                folder_path.set_extension(if folder_step == 2 { "luau" } else { "lua" });
                Some(folder_path)
            }
            _ => None,
        }
    }
}

impl Iterator for PathIterator<'_, '_> {
    type Item = PathBuf;

    fn next(&mut self) -> Option<Self::Item> {
        // Advance only when a candidate exists, so an exhausted iterator keeps
        // landing on the same terminating step.
        let path = self.candidate()?;
        self.index += 1;
        Some(path)
    }
}
92
#[cfg(test)]
mod test {
    // Unit tests for `PathIterator` and its `find_require_paths` wrapper.
    use super::*;

    const ANY_FOLDER_NAME: &str = "test";
    const ANY_FOLDER_NAME_WITH_EXTENSION: &str = "test.luau";

    #[test]
    fn returns_exact_path_when_path_has_a_lua_extension() {
        let source = Path::new("hello.lua");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(vec![source.to_path_buf()], iterator.collect::<Vec<_>>())
    }

    #[test]
    fn returns_exact_path_when_path_has_a_luau_extension() {
        let source = Path::new("hello.luau");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(vec![source.to_path_buf()], iterator.collect::<Vec<_>>())
    }

    #[test]
    fn returns_paths_when_path_has_a_random_extension() {
        let source = Path::new("hello.global");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                PathBuf::from("hello.global.luau"),
                PathBuf::from("hello.global.lua"),
                source.join(ANY_FOLDER_NAME),
                source.join(ANY_FOLDER_NAME).with_extension("luau"),
                source.join(ANY_FOLDER_NAME).with_extension("lua"),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    #[test]
    fn returns_paths_when_path_has_no_extension() {
        let source = Path::new("hello");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.with_extension("luau"),
                source.with_extension("lua"),
                source.join(ANY_FOLDER_NAME),
                source.join(ANY_FOLDER_NAME).with_extension("luau"),
                source.join(ANY_FOLDER_NAME).with_extension("lua"),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    #[test]
    fn returns_paths_when_path_is_dot_luau() {
        let source = Path::new(".luau");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.with_extension("luau"),
                source.with_extension("lua"),
                source.join(ANY_FOLDER_NAME),
                source.join(ANY_FOLDER_NAME).with_extension("luau"),
                source.join(ANY_FOLDER_NAME).with_extension("lua"),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    #[test]
    fn returns_paths_when_path_is_parent() {
        let source = Path::new("..");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.join(ANY_FOLDER_NAME),
                source.join(ANY_FOLDER_NAME).with_extension("luau"),
                source.join(ANY_FOLDER_NAME).with_extension("lua"),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    #[test]
    fn returns_paths_when_path_is_current_directory() {
        let source = Path::new(".");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.join(ANY_FOLDER_NAME),
                source.join(ANY_FOLDER_NAME).with_extension("luau"),
                source.join(ANY_FOLDER_NAME).with_extension("lua"),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    #[test]
    fn returns_paths_when_path_has_no_extension_and_module_folder_name_has_an_extension() {
        let source = Path::new("hello");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME_WITH_EXTENSION);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.with_extension("luau"),
                source.with_extension("lua"),
                source.join(ANY_FOLDER_NAME_WITH_EXTENSION),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    // Covers the extension check on the branch for paths without a file name.
    #[test]
    fn returns_paths_when_path_is_parent_and_module_folder_name_has_an_extension() {
        let source = Path::new("..");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME_WITH_EXTENSION);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.join(ANY_FOLDER_NAME_WITH_EXTENSION),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    // The iterator does not advance on `None`, so it must stay exhausted.
    #[test]
    fn keeps_returning_none_once_exhausted() {
        let mut iterator = PathIterator::new(Path::new("hello"), ANY_FOLDER_NAME_WITH_EXTENSION);

        pretty_assertions::assert_eq!(4, iterator.by_ref().count());
        pretty_assertions::assert_eq!(None, iterator.next());
        pretty_assertions::assert_eq!(None, iterator.next());
    }

    #[test]
    fn find_require_paths_yields_the_same_paths_as_the_iterator() {
        let source = Path::new("hello");

        pretty_assertions::assert_eq!(
            PathIterator::new(source, ANY_FOLDER_NAME).collect::<Vec<_>>(),
            find_require_paths(source, ANY_FOLDER_NAME).collect::<Vec<_>>()
        )
    }
}