// assemble_build/function_finder.rs

1use std::fs::File as StdFile;
2use std::io::Read;
3use std::path::{Path, PathBuf};
4use syn::parse::{Parse, ParseStream};
5
6use syn::visit::Visit;
7use syn::{parse2, ItemFn, ItemMod, LitStr, Token, Visibility};
8
9/// Finds _all_ functions in a project
10pub struct FunctionFinder {
11    all_functions: Vec<(ModuleData, ItemFn)>,
12}
13
/// Location of one module inside the crate being scanned.
#[derive(Debug, Clone)]
pub struct ModuleData {
    /// Module path from the crate root; the first segment is the package name.
    full_path: Vec<String>,
    /// This module's own name (for modules built through the `child_module` /
    /// `inner_child_module` helpers, the last segment of `full_path`).
    id: String,
    /// Source file that contains this module's items.
    file_path: PathBuf,
}

impl ModuleData {
    /// Creates module data from its full path, its own name and its source file.
    pub fn new(full_path: Vec<String>, id: String, path: PathBuf) -> Self {
        Self {
            full_path,
            id,
            file_path: path,
        }
    }

    /// Returns the data for the out-of-line child module `id` whose source
    /// lives in `path`.
    fn child_module(&self, id: String, path: PathBuf) -> Self {
        let mut full_path = self.full_path.clone();
        full_path.push(id.clone());
        Self::new(full_path, id, path)
    }

    /// Returns the data for the module reached by descending through the
    /// inline `mod name { ... }` blocks named in `ids` (outermost first).
    ///
    /// The result shares this module's source file, and its `id` is the
    /// innermost module's name. An empty `ids` yields a copy of `self`
    /// instead of panicking.
    fn inner_child_module(&self, ids: &[String]) -> Self {
        let mut full_path = self.full_path.clone();
        full_path.extend_from_slice(ids);
        // The module's own name is the innermost (last) segment, matching
        // what `child_module` produces.
        let id = ids.last().unwrap_or(&self.id).clone();
        Self::new(full_path, id, self.file_path.clone())
    }
}
42
43impl FunctionFinder {
44    /// The path is starting file to begin the search.
45    pub fn find_all(path: &Path, package_name: String) -> Self {
46        let mut module_stack = Vec::new();
47        let mut found = Vec::new();
48
49        module_stack.push(ModuleData::new(
50            vec![package_name.clone()],
51            package_name,
52            path.to_path_buf(),
53        ));
54
55        while let Some(module) = module_stack.pop() {
56            println!("Parsing file: {:?}", module.file_path);
57
58            let module_name = module.id.clone();
59            let mut lib_file = StdFile::open(&module.file_path).unwrap();
60
61            let mut content = String::new();
62            lib_file.read_to_string(&mut content).unwrap();
63            let parsed = syn::parse_file(&content).unwrap();
64
65            let mut visitor = ModuleVisitor::new(module_name.clone());
66            visitor.visit_file(&parsed);
67
68            let ModuleVisitor {
69                functions, modules, ..
70            } = visitor;
71            found.extend(functions.into_iter().map(|(modules, fun)| {
72                if modules.is_empty() {
73                    (module.clone(), fun.clone())
74                } else {
75                    (module.inner_child_module(modules.as_slice()), fun.clone())
76                }
77            }));
78
79            for child_module in modules {
80                let ident = &child_module.ident;
81                let path: Option<String> = child_module
82                    .attrs
83                    .iter()
84                    .find(|attr| attr.path.is_ident("path"))
85                    .and_then(|attr| {
86                        let meta = parse2::<PathAssign>(attr.tokens.clone()).ok()?;
87                        Some(meta.path.value())
88                    });
89
90                let parent_dir = module.file_path.parent().unwrap();
91                let next_file = if let Some(path) = path {
92                    // path specified using #[path = ".."] attribute
93                    // see [here](https://doc.rust-lang.org/reference/items/modules.html#the-path-attribute)
94                    // for how modules from paths are determined
95                    parent_dir.join(path)
96                } else {
97                    // use default path
98                    let canonical = if module.file_path.ends_with("mod.rs") {
99                        parent_dir.join(ident.to_string()).with_extension("rs")
100                    } else {
101                        parent_dir
102                            .join(&module_name)
103                            .join(ident.to_string())
104                            .with_extension("rs")
105                    };
106                    if !canonical.exists() {
107                        canonical
108                            .with_extension("")
109                            .join("mod")
110                            .with_extension("rs")
111                    } else {
112                        canonical
113                    }
114                };
115
116                module_stack.push(module.child_module(ident.to_string(), next_file));
117            }
118        }
119
120        Self {
121            all_functions: found,
122        }
123    }
124
125    pub fn found(&self) -> impl Iterator<Item = &(ModuleData, ItemFn)> {
126        self.all_functions.iter()
127    }
128
129    /// Finds public function ids
130    pub fn pub_function_ids(&self) -> impl Iterator<Item = String> + '_ {
131        self.found()
132            .filter(|(_, fun)| matches!(&fun.vis, Visibility::Public(_)))
133            .map(|(data, fun)| {
134                let module_id = data.full_path.join("::");
135                format!("{module_id}::{}", fun.sig.ident)
136            })
137    }
138}
139
140struct ModuleVisitor<'l> {
141    _module: String,
142    inner_modules: Vec<String>,
143    /// The found functions
144    functions: Vec<(Vec<String>, &'l ItemFn)>,
145    /// Non-parsed_modules
146    modules: Vec<&'l ItemMod>,
147}
148
149impl<'l> ModuleVisitor<'l> {
150    pub fn new(module: String) -> Self {
151        Self {
152            _module: module,
153            inner_modules: vec![],
154            functions: Default::default(),
155            modules: Default::default(),
156        }
157    }
158}
159
160impl<'ast> Visit<'ast> for ModuleVisitor<'ast> {
161    fn visit_item_fn(&mut self, i: &'ast ItemFn) {
162        self.functions.push((self.inner_modules.clone(), i));
163    }
164
165    fn visit_item_mod(&mut self, module: &'ast ItemMod) {
166        self.inner_modules.push(module.ident.to_string());
167        match &module.content {
168            None => {
169                self.modules.push(module);
170            }
171            Some((_, items)) => {
172                for item in items {
173                    self.visit_item(item);
174                }
175            }
176        }
177        self.inner_modules.pop();
178    }
179}
180
181#[derive(Debug)]
182struct PathAssign {
183    path: LitStr,
184}
185
186impl Parse for PathAssign {
187    fn parse(input: ParseStream) -> syn::Result<Self> {
188        input.parse::<Token![=]>()?;
189        let lit = input.parse()?;
190        Ok(Self { path: lit })
191    }
192}