// solar_macros/visitor.rs

use proc_macro2::{Group, TokenStream, TokenTree};
use quote::{TokenStreamExt, quote};
use syn::{
    Attribute, Block, FnArg, Generics, Ident, Pat, PatIdent, Stmt, Token, TraitItem, Visibility,
    braced, parse::Parse, spanned::Spanned,
};
8pub struct Input {
9    attrs: Vec<Attribute>,
10    vis: Visibility,
11    trait_token: Token![trait],
12    name: Ident,
13    mut_name: Option<Ident>,
14    generics: Generics,
15    items: TokenStream,
16}
18impl Parse for Input {
19    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
20        let attrs = input.call(Attribute::parse_outer)?;
21        let vis = input.parse()?;
22        let trait_token = input.parse()?;
23        let name: Ident = input.parse()?;
24        let mut_name: Option<Ident> = input.parse()?;
25        let generics: Generics = input.parse()?;
26
27        let content;
28        braced!(content in input);
29        let items = content.parse()?;
30
31        Ok(Self { attrs, vis, trait_token, name, mut_name, generics, items })
32    }
33}
35impl Input {
36    pub fn expand(&self) -> TokenStream {
37        let Self { attrs, vis, trait_token, name, mut_name, generics, items } = self;
38
39        let expand = |nonmut_items: TokenStream, mut_items: Option<TokenStream>| {
40            let mut_trait = mut_items.map(|mut_items| {
41                quote! {
42                    #(#attrs)*
43                    #vis #trait_token #mut_name #generics {
44                        #mut_items
45                    }
46                }
47            });
48            quote! {
49                #(#attrs)*
50                #vis #trait_token #name #generics {
51                    #nonmut_items
52                }
53
54                #mut_trait
55            }
56        };
57
58        let (nonmut_items, mut_items) = expand_streams(items);
59        // Better IDE support.
60        let fallback = || expand(nonmut_items.clone(), None);
61        let Ok(mut nonmut_trait_items) = parse_trait_items(nonmut_items.clone()) else {
62            return fallback();
63        };
64        let Ok(mut mut_trait_items) = parse_trait_items(mut_items) else {
65            return fallback();
66        };
67
68        for item in &mut mut_trait_items {
69            if let TraitItem::Fn(f) = item {
70                f.sig.ident = Ident::new(&format!("{}_mut", f.sig.ident), f.sig.ident.span());
71            }
72        }
73
74        add_walk_fns(&mut mut_trait_items);
75        add_walk_fns(&mut nonmut_trait_items);
76
77        expand(
78            quote! { #(#nonmut_trait_items)* },
79            mut_name.is_some().then(|| quote! { #(#mut_trait_items)* }),
80        )
81    }
82}
84// (nonmut, mut)
85// nonmut skips `#mut` and mut includes `#mut` as `mut`
86fn expand_streams(tts: &TokenStream) -> (TokenStream, TokenStream) {
87    let mut nonmut_tts = TokenStream::new();
88    let mut mut_tts = TokenStream::new();
89    let mut tt_iter = tts.clone().into_iter();
90    while let Some(tt) = tt_iter.next() {
91        match tt {
92            TokenTree::Group(group) => {
93                let (nm, m) = expand_streams(&group.stream());
94                let group = |stream| {
95                    let mut g = Group::new(group.delimiter(), stream);
96                    g.set_span(group.span());
97                    g
98                };
99                nonmut_tts.append(group(nm));
100                mut_tts.append(group(m));
101            }
102            TokenTree::Punct(punct)
103                if punct.as_char() == '#' && tt_iter.clone().next().is_some_and(is_token_mut) =>
104            {
105                let mut_token = tt_iter.next().unwrap();
106                mut_tts.append(mut_token);
107            }
108            TokenTree::Punct(punct)
109                if punct.as_char() == '#'
110                    && tt_iter.clone().next().is_some_and(is_token_onlymut) =>
111            {
112                let _onlymut_token = tt_iter.next().unwrap();
113                let TokenTree::Group(group) = tt_iter.next().unwrap() else { continue };
114                mut_tts.extend(group.stream());
115            }
116            TokenTree::Ident(id)
117                if tt_iter.clone().next().is_some_and(is_token_hash)
118                    && tt_iter.clone().nth(1).is_some_and(is_token_underscore_mut) =>
119            {
120                let _ = tt_iter.next();
121                let _ = tt_iter.next();
122                mut_tts.append(Ident::new(&format!("{id}_mut"), id.span()));
123                nonmut_tts.append(id);
124            }
125            tt => {
126                nonmut_tts.append(tt.clone());
127                mut_tts.append(tt);
128            }
129        }
130    }
131    (nonmut_tts, mut_tts)
132}
134fn is_token_hash(tt: TokenTree) -> bool {
135    if let TokenTree::Punct(punct) = tt {
136        return punct.as_char() == '#';
137    }
138    false
139}
141fn is_token_mut(tt: TokenTree) -> bool {
142    if let TokenTree::Ident(ident) = tt {
143        return ident == "mut";
144    }
145    false
146}
148fn is_token_onlymut(tt: TokenTree) -> bool {
149    if let TokenTree::Ident(ident) = tt {
150        return ident == "onlymut";
151    }
152    false
153}
155fn is_token_underscore_mut(tt: TokenTree) -> bool {
156    if let TokenTree::Ident(ident) = tt {
157        return ident == "_mut";
158    }
159    false
160}
162fn parse_trait_items(tts: TokenStream) -> Result<Vec<TraitItem>, syn::Error> {
163    struct TraitItems(Vec<TraitItem>);
164    impl Parse for TraitItems {
165        fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
166            let mut items = vec![];
167            while !input.is_empty() {
168                items.push(input.parse()?);
169            }
170            Ok(Self(items))
171        }
172    }
173    Ok(syn::parse2::<TraitItems>(tts)?.0)
174}
176// fn visit_... { stmts @ ... } -> fn visit_... { self.walk_...(...) }
177// + fn walk_... { #stmts }
178fn add_walk_fns(items: &mut Vec<TraitItem>) {
179    for i in 0..items.len() {
180        let item = &mut items[i];
181        if let TraitItem::Fn(f) = item {
182            let name = f.sig.ident.to_string();
183            let Some(name) = name.strip_prefix("visit_") else { continue };
184            let walk_name = Ident::new(&format!("walk_{name}"), f.sig.ident.span());
185
186            let mut walk_fn = f.clone();
187            let Some(body) = &mut f.default else { continue };
188            f.attrs.push(syn::parse_quote!(#[inline]));
189
190            let args = f.sig.inputs.iter().filter_map(|arg| {
191                Some(match arg {
192                    FnArg::Receiver(_rec) => return None,
193                    FnArg::Typed(pat) => match &*pat.pat {
194                        Pat::Ident(ident) => {
195                            let id = &ident.ident;
196                            quote!(#id)
197                        }
198                        _ => return None,
199                    },
200                })
201            });
202            let call_walk = syn::parse_quote! {
203                self.#walk_name(#(#args),*)
204            };
205            let call_walk_stmt = Stmt::Expr(call_walk, None);
206            let walk_stmts = std::mem::replace(&mut body.stmts, vec![call_walk_stmt]);
207
208            walk_fn.sig.ident = walk_name;
209            walk_fn.default = Some(Block { brace_token: body.brace_token, stmts: walk_stmts });
210            items.push(TraitItem::Fn(walk_fn));
211        }
212    }
213}