darklua_core/rules/require/
path_iterator.rs

1use std::path::{Path, PathBuf};
2
3pub(crate) fn find_require_paths<'a, 'b, 'c>(
4    path: &'a Path,
5    module_folder_name: &'b str,
6) -> impl Iterator<Item = PathBuf> + 'c
7where
8    'a: 'c,
9    'b: 'c,
10{
11    PathIterator::new(path, module_folder_name)
12}
13
/// Lazily produces the candidate file paths that a require of `path` may resolve to.
///
/// If `path` already has an extension, `path` itself is the only candidate. Otherwise the
/// candidates are, in order:
///
/// 1. `path`
/// 2. `path` with its extension set to `luau`
/// 3. `path` with its extension set to `lua`
/// 4. `path` joined with `module_folder_name`
/// 5. candidate 4 with its extension set to `luau`
/// 6. candidate 4 with its extension set to `lua`
///
/// Candidates 5 and 6 are skipped when `module_folder_name` already carries an extension.
/// Once exhausted, the iterator keeps returning `None`.
struct PathIterator<'a, 'b> {
    path: &'a Path,
    /// Cached `path.extension().is_some()`; selects which candidate list is produced.
    has_extension: bool,
    module_folder_name: &'b str,
    /// Position of the next candidate to yield; only advanced when a candidate is produced.
    index: u8,
}

impl<'a, 'b> PathIterator<'a, 'b> {
    fn new(path: &'a Path, module_folder_name: &'b str) -> Self {
        let has_extension = path.extension().is_some();
        Self {
            path,
            has_extension,
            module_folder_name,
            index: 0,
        }
    }

    /// Builds the candidate at position `index`, or `None` when there is no such candidate.
    fn candidate_at(&self, index: u8) -> Option<PathBuf> {
        if self.has_extension {
            return (index == 0).then(|| self.path.to_path_buf());
        }

        match index {
            0 => Some(self.path.to_path_buf()),
            1 => Some(self.path.with_extension("luau")),
            2 => Some(self.path.with_extension("lua")),
            3 => Some(self.path.join(self.module_folder_name)),
            4 | 5 => {
                let mut module_path = self.path.join(self.module_folder_name);
                if module_path.extension().is_some() {
                    // `module_folder_name` already has an extension: no extension variants.
                    return None;
                }
                module_path.set_extension(if index == 4 { "luau" } else { "lua" });
                Some(module_path)
            }
            _ => None,
        }
    }
}

impl Iterator for PathIterator<'_, '_> {
    type Item = PathBuf;

    fn next(&mut self) -> Option<Self::Item> {
        // Leave `index` untouched on `None` so every later call also yields `None`.
        let candidate = self.candidate_at(self.index)?;
        self.index += 1;
        Some(candidate)
    }
}
66
#[cfg(test)]
mod test {
    //! Tests pinning the candidate order and exhaustion behaviour of `PathIterator`.

    use super::*;

    const ANY_FOLDER_NAME: &str = "test";
    const ANY_FOLDER_NAME_WITH_EXTENSION: &str = "test.luau";

    #[test]
    fn returns_exact_path_when_path_has_an_extension() {
        let source = Path::new("hello.lua");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(vec![source.to_path_buf()], iterator.collect::<Vec<_>>())
    }

    #[test]
    fn returns_paths_when_path_has_no_extension() {
        let source = Path::new("hello");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.with_extension("luau"),
                source.with_extension("lua"),
                source.join(ANY_FOLDER_NAME),
                source.join(ANY_FOLDER_NAME).with_extension("luau"),
                source.join(ANY_FOLDER_NAME).with_extension("lua"),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    #[test]
    fn returns_paths_when_path_has_no_extension_and_module_folder_name_has_an_extension() {
        let source = Path::new("hello");
        let iterator = PathIterator::new(source, ANY_FOLDER_NAME_WITH_EXTENSION);

        pretty_assertions::assert_eq!(
            vec![
                source.to_path_buf(),
                source.with_extension("luau"),
                source.with_extension("lua"),
                source.join(ANY_FOLDER_NAME_WITH_EXTENSION),
            ],
            iterator.collect::<Vec<_>>()
        )
    }

    // The `None` branches do not advance the internal index, so they are re-entered on
    // every call after exhaustion; pin that they keep yielding `None`.
    #[test]
    fn keeps_returning_none_once_exhausted() {
        for folder_name in [ANY_FOLDER_NAME, ANY_FOLDER_NAME_WITH_EXTENSION] {
            for source in [Path::new("hello"), Path::new("hello.lua")] {
                let mut iterator = PathIterator::new(source, folder_name);
                while iterator.next().is_some() {}

                pretty_assertions::assert_eq!(None, iterator.next());
                pretty_assertions::assert_eq!(None, iterator.next());
            }
        }
    }

    #[test]
    fn find_require_paths_yields_same_paths_as_path_iterator() {
        for folder_name in [ANY_FOLDER_NAME, ANY_FOLDER_NAME_WITH_EXTENSION] {
            let source = Path::new("hello");

            pretty_assertions::assert_eq!(
                PathIterator::new(source, folder_name).collect::<Vec<_>>(),
                find_require_paths(source, folder_name).collect::<Vec<_>>()
            )
        }
    }
}
115}