macro_visit/
lib.rs

//!
//! A small helper library that uses the `syn::visit::Visit` trait to find all macro calls in a project.
//!
//! While traversing, the library resolves `use` imports, so end users can
//! rename macros and mix macros with the same name from different crates.
//! It also follows `mod` declarations and provides some context to each macro handler.
//!
8
9use std::{cell::RefCell, collections::BTreeMap, path::PathBuf, rc::Rc};
10
11use proc_macro2::TokenStream;
12
13/// Macro visitor.
14///
15/// Handle all macro calls, and call appropriate function.
16/// on the way, it will find all `use` items, and add new imports to the list.
17///
18/// Creates new visitor for each function, to avoid mixed `use` items.
19///
20/// It uses lifetime to allow variable to be captured into closure.
21
22pub type RcMacro<'a> = Rc<RefCell<dyn FnMut(MacroContext, TokenStream) + 'a>>;
23pub type MacroMap<'a> = BTreeMap<String, RcMacro<'a>>;
24
25/// Provided context to macro handler.
26#[derive(Clone, Default, Debug)]
27pub struct MacroContext {
28    /// Path to modules from entrypoint.
29    pub mod_path: Vec<String>,
30    /// Filename of entrypoint.
31    pub entrypoint: String,
32    pub fn_call_name: Option<String>,
33    // Somewhere at module path, one used `#[path = "foo.rs"]`
34    pub used_path_attr: bool,
35    src_filedir: PathBuf,
36    // TODO: linenum/colnum
37}
38#[derive(Clone)]
39pub struct Visitor<'a> {
40    searched_imports: MacroMap<'a>,
41
42    context: MacroContext,
43}
44impl std::fmt::Debug for Visitor<'_> {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.debug_struct("Visitor").finish()
47    }
48}
49
50impl<'a> Default for Visitor<'a> {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56impl<'a> Visitor<'a> {
57    /// Creates empty visitor.
58    pub fn new() -> Self {
59        Self {
60            searched_imports: BTreeMap::new(),
61            context: MacroContext::default(),
62        }
63    }
64    /// Add macro implementation to the macro
65    pub fn add_macro(
66        &mut self,
67        imports: Vec<String>,
68        macro_call: impl FnMut(MacroContext, TokenStream) + 'a,
69    ) {
70        let macro_call = Rc::new(RefCell::new(macro_call));
71        for import in imports {
72            self.searched_imports.insert(import, macro_call.clone());
73        }
74    }
75    pub fn add_rc_macro(&mut self, imports: Vec<String>, macro_call: RcMacro<'a>) {
76        for import in imports {
77            self.searched_imports.insert(import, macro_call.clone());
78        }
79    }
80    // Visit file content.
81    pub fn visit_file_content(&mut self, content: &str) {
82        println!("Process file with context: {:?}", self.context);
83        println!("and imports: {:?}", self.searched_imports.keys());
84        let file = syn::parse_file(content).unwrap();
85        syn::visit::visit_file(self, &file)
86    }
87    /// Handle all *.rs files in src of project directory.
88    ///
89    /// `project_path` - is path to Cargo.toml of the project
90    pub fn visit_project<P: AsRef<std::path::Path>>(&self, entrypoint: P) {
91        let entrypoint = entrypoint.as_ref();
92
93        let content = std::fs::read_to_string(entrypoint).unwrap();
94
95        let mut src_filedir = entrypoint.to_path_buf();
96        src_filedir.pop();
97        let entrypoint_name = entrypoint
98            .file_stem()
99            .map(|c| c.to_string_lossy())
100            .unwrap_or_default();
101
102        Self {
103            context: MacroContext {
104                entrypoint: entrypoint_name.to_string(),
105                src_filedir,
106                ..MacroContext::default()
107            },
108            ..self.clone()
109        }
110        .visit_file_content(&content)
111    }
112    fn new_subcall(&self, fn_name: String) -> Self {
113        let mut new = self.clone();
114        new.context.fn_call_name = Some(fn_name);
115        new
116    }
117
118    // Hide current imports to parrent imports.
119    fn new_mod(&self, mod_name: String) -> Self {
120        let mut context = self.context.clone();
121        context.mod_path.push(mod_name);
122        Self {
123            searched_imports: self.searched_imports.clone(),
124            context,
125            ..self.clone()
126        }
127    }
128    fn get_macro(&self, path: syn::Path) -> Option<RcMacro<'a>> {
129        let path_str = path
130            .segments
131            .iter()
132            .map(|s| s.ident.to_string())
133            .collect::<Vec<_>>()
134            .join("::");
135        self.searched_imports.get(&path_str).cloned()
136    }
137}
138
139impl syn::visit::Visit<'_> for Visitor<'_> {
140    fn visit_use_tree(&mut self, node: &syn::UseTree) {
141        let mut new_imports = vec![];
142        for (import, macro_call) in &self.searched_imports {
143            let use_tree_form = use_tree_from_str(import);
144            let new = compare_use_tree(use_tree_form, node.clone());
145            if !new.is_empty() {
146                println!("Extending use tree with: {:?}", new);
147                new_imports.extend(new.into_iter().map(|i| (i, macro_call.clone())))
148            }
149        }
150        self.searched_imports.extend(new_imports);
151    }
152    fn visit_item_fn(&mut self, node: &syn::ItemFn) {
153        let mut new_visitor = self.new_subcall(node.sig.ident.to_string());
154        syn::visit::visit_item_fn(&mut new_visitor, node);
155    }
156
157    fn visit_impl_item_fn(&mut self, i: &syn::ImplItemFn) {
158        let mut new_visitor = self.new_subcall(i.sig.ident.to_string());
159        syn::visit::visit_impl_item_fn(&mut new_visitor, i);
160    }
161    fn visit_item_mod(&mut self, i: &syn::ItemMod) {
162        // get attrs #[path = "foo"];
163        let path_attr = i
164            .attrs
165            .iter()
166            .filter_map(|a| a.meta.require_name_value().ok())
167            .filter(|meta| meta.path.is_ident("path"))
168            .last();
169        let path_attr = path_attr.map(|a| match &a.value {
170            syn::Expr::Lit(syn::ExprLit {
171                lit: syn::Lit::Str(lit_str),
172                ..
173            }) => PathBuf::from(lit_str.value()),
174            _ => panic!("Expected literal string in path attribute"),
175        });
176        // Create new visitor for module
177        let mut mod_visitor = self.new_mod(i.ident.to_string());
178
179        println!("Found mod: {:?}", i.ident.to_string());
180        // Process items if content is present.
181        if let Some(c) = &i.content {
182            println!("Processing module with inner content");
183            for i in &c.1 {
184                mod_visitor.visit_item(i)
185            }
186            return;
187        }
188        // Process module that uses #[path = "foo"];
189        if let Some(path_attr) = path_attr {
190            let full_path = mod_visitor.context.src_filedir.join(path_attr);
191            println!(
192                "Processing module that uses #[path = \"{}\"]",
193                full_path.display()
194            );
195            let mut src_filedir = full_path.clone();
196            src_filedir.pop();
197            let mod_name = full_path
198                .file_stem()
199                .map(|c| c.to_string_lossy())
200                .unwrap_or_default()
201                .to_string();
202
203            mod_visitor.context.used_path_attr = true;
204            mod_visitor.context.src_filedir = src_filedir;
205            mod_visitor.context.mod_path = vec![];
206            let mod_path = resolve_module_path(&mod_visitor.context, &mod_name);
207
208            let content = std::fs::read_to_string(mod_path).unwrap();
209            mod_visitor.visit_file_content(&content);
210            return;
211        }
212
213        let mod_path = resolve_module_path(&mod_visitor.context, &i.ident.to_string());
214        println!(
215            "Processing regular module with content in path: {}",
216            mod_path.display()
217        );
218        // Or process file in case of `mod foo;` item.
219        let content = std::fs::read_to_string(mod_path).unwrap();
220        mod_visitor.visit_file_content(&content)
221    }
222
223    fn visit_macro(&mut self, i: &syn::Macro) {
224        if let Some(macro_impl) = self.get_macro(i.path.clone()) {
225            macro_impl.borrow_mut()(self.context.clone(), i.tokens.clone());
226        }
227    }
228}
229
230// Compare two paths, and return new one, if path was renamed.
231// Expect left path to be flat, and right might be nested.
232pub(crate) fn compare_use_tree(left: syn::UseTree, right: syn::UseTree) -> Vec<String> {
233    match (left, right) {
234        (syn::UseTree::Glob(_), _)
235        | (syn::UseTree::Group(_), _)
236        | (syn::UseTree::Rename(_), _) => {
237            panic!("Import path is not valid")
238        }
239        // If right is glob, then we remove prefix, and keep the rest import path as synonim.
240        (left_tree, syn::UseTree::Glob(_)) => {
241            vec![create_import_path(left_tree)]
242        }
243        // If right is group - traverse each group item.
244        (left_tree, syn::UseTree::Group(right_g)) => {
245            right_g.items.into_iter().flat_map(move |item| {
246                compare_use_tree(left_tree.clone(), item)
247            }).collect::<Vec<_>>()
248        }
249        // Name is terminal node,
250        // if it equal - we can use macro by its name without full path.
251        (syn::UseTree::Name(left_i), syn::UseTree::Name(right_i))
252        if right_i.ident == left_i.ident  =>
253        {
254            vec![create_import_path(syn::UseTree::Name(left_i))]
255        }
256        // Same but ident is renambed
257        (syn::UseTree::Name(left_i), syn::UseTree::Rename(right_r))
258        if right_r.ident == left_i.ident => {
259            vec![create_import_path(syn::UseTree::Name(
260                syn::UseName {
261                    ident: right_r.rename,
262                }))]
263        }
264        (syn::UseTree::Path(left_p), syn::UseTree::Name(right_i))
265        if right_i.ident == left_p.ident => {
266            vec![create_import_path(syn::UseTree::Path(left_p))]
267        }
268        (syn::UseTree::Path(left_p), syn::UseTree::Rename(right_r))
269        if right_r.ident == left_p.ident => {
270            let mut new_tree = left_p.clone();
271            new_tree.ident = right_r.rename;
272            vec![create_import_path(syn::UseTree::Path(new_tree))]
273        }
274        (syn::UseTree::Path(left_p), syn::UseTree::Path(right_p))
275        if right_p.ident == left_p.ident => {
276            // traverse deeper, while path is same
277            compare_use_tree(*left_p.tree, *right_p.tree)
278        }
279        (syn::UseTree::Path(_), syn::UseTree::Name(_))
280        | (syn::UseTree::Path(_), syn::UseTree::Rename(_))
281        | (syn::UseTree::Name(_), syn::UseTree::Name(_))
282        | (syn::UseTree::Name(_), syn::UseTree::Rename(_))
283        | (syn::UseTree::Path(_), syn::UseTree::Path(_))
284        // not comparable
285        | (syn::UseTree::Name(_), syn::UseTree::Path(_))
286         => {
287            // if path is different, then we can't add new synonim for this import.
288            vec![]
289        }
290    }
291}
292pub(crate) fn use_tree_from_str(path: &str) -> syn::UseTree {
293    syn::parse_str(path).unwrap()
294}
295
296pub(crate) fn create_import_path(remining: syn::UseTree) -> String {
297    let mut path = String::new();
298    match remining {
299        syn::UseTree::Name(ident) => {
300            path.push_str(&ident.ident.to_string());
301        }
302        syn::UseTree::Path(path_tree) => {
303            path.push_str(&path_tree.ident.to_string());
304            path.push_str("::");
305            path.push_str(&create_import_path(*path_tree.tree));
306        }
307        syn::UseTree::Rename(_) | syn::UseTree::Group(_) | syn::UseTree::Glob(_) => {
308            panic!("Import path is not valid")
309        }
310    }
311    path
312}
313
314// Resolve path to a mod, based on current module path and module_name.
315fn resolve_module_path(context: &MacroContext, mod_name: &str) -> PathBuf {
316    let mut mod_folder: PathBuf = context.src_filedir.to_path_buf();
317
318    if context.mod_path.len() > 1 {
319        for parent_mod in &context.mod_path[..context.mod_path.len() - 1] {
320            mod_folder.push(parent_mod);
321        }
322    }
323
324    let mod_path = mod_folder.join(format!("{mod_name}.rs"));
325    println!("Probing path: {:?}", mod_path);
326    if mod_path.exists() {
327        return mod_path;
328    } else {
329        let mut mod_path = mod_folder.join(mod_name);
330        mod_path.push("mod.rs");
331        println!("Probing path: {:?}", mod_path);
332        if mod_path.exists() {
333            return mod_path;
334        }
335    }
336    panic!(
337        "Cannot find module '{}' relative to path {:?}, src_dir: {}",
338        mod_name,
339        mod_folder,
340        context.src_filedir.display()
341    );
342}
343
#[cfg(test)]
mod test {
    use super::*;

    /// Registered path of the macro the visitor tests look for.
    const CSS_MACRO: &str = "rcss::file::css_module::css";
    /// Registered path used by the `compare_use_tree` tests.
    const CSS_STRUCT: &str = "rcss::file::css_module::css_struct";

    /// Registers a handler for `CSS_MACRO`, lets `drive` walk some syntax
    /// with the visitor, and reports whether the handler fired.
    fn css_macro_found(drive: impl FnOnce(&mut Visitor<'_>)) -> bool {
        let mut found = false;
        {
            let mut visitor = Visitor::new();
            visitor.add_macro(vec![CSS_MACRO.to_owned()], |_, _| found = true);
            drive(&mut visitor);
        } // the visitor, and with it the borrow of `found`, ends here
        found
    }

    /// New names under which `import` becomes reachable through `use_item`.
    fn aliases(import: &str, use_item: syn::ItemUse) -> Vec<String> {
        compare_use_tree(use_tree_from_str(import), use_item.tree)
    }

    // Check that Visitor can find macro call
    #[test]
    fn test_simple_macro_call() {
        let input: syn::Item =
            syn::parse_str(r#"rcss::file::css_module::css! { .my-class { color: red; } }"#)
                .unwrap();
        assert!(css_macro_found(|visitor| syn::visit::visit_item(visitor, &input)));
    }

    #[test]
    fn test_macro_inside_fn() {
        let input: syn::Item = syn::parse_quote!(
            fn test() {
                rcss::file::css_module::css! { .my-class { color: red; } }
            }
        );
        assert!(css_macro_found(|visitor| syn::visit::visit_item(visitor, &input)));
    }

    #[test]
    fn test_macro_inside_impl_fn() {
        let input: syn::File = syn::parse_quote!(
            impl Test {
                fn test() {
                    rcss::file::css_module::css! { .my-class { color: red; } }
                }
            }
        );
        assert!(css_macro_found(|visitor| syn::visit::visit_file(visitor, &input)));
    }

    #[test]
    fn test_macro_inside_fn_with_outer_and_inner_reimport() {
        let input: syn::File = syn::parse_quote!(
            use rcss::file;
            fn test() {
                use file::css_module;
                file::css_module::css! { .my-class { color: red; } }
            }
        );
        assert!(css_macro_found(|visitor| syn::visit::visit_file(visitor, &input)));
    }

    // Check that an import by name shortens the path.
    #[test]
    fn test_compare_use_by_name() {
        let found = aliases(CSS_STRUCT, syn::parse_quote! { use rcss::file; });
        assert_eq!(found, ["file::css_module::css_struct"]);
    }

    #[test]
    fn test_compare_use_in_group() {
        let found = aliases(
            CSS_STRUCT,
            syn::parse_quote! { use rcss::file::{css_module, scoped}; },
        );
        assert_eq!(found, ["css_module::css_struct"]);
    }

    #[test]
    fn test_compare_use_by_glob() {
        let found = aliases(CSS_STRUCT, syn::parse_quote! { use rcss::file::*; });
        assert_eq!(found, ["css_module::css_struct"]);
    }

    #[test]
    fn test_compare_use_by_glob_in_group() {
        let found = aliases(CSS_STRUCT, syn::parse_quote! { use rcss::file::{*, scoped}; });
        assert_eq!(found, ["css_module::css_struct"]);
    }

    #[test]
    fn test_compare_deep_group_with_glob() {
        let found = aliases(
            CSS_STRUCT,
            syn::parse_quote! { use rcss::file::{*, css_module::{css, *}}; },
        );
        assert_eq!(found, ["css_module::css_struct", "css_struct"]);
    }

    #[test]
    fn test_compare_with_rename() {
        let found = aliases(
            CSS_MACRO,
            syn::parse_quote! { use rcss::file::{*, css_module::{css as css2, *}}; },
        );
        assert_eq!(found, ["css_module::css", "css2", "css"]);
    }
}