derive_visitor_macros/
lib.rs

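//! This is a utility crate for [derive-visitor](https://docs.rs/derive-visitor).
//!
//! It implements the `Visitor`, `VisitorMut`, `Drive` and `DriveMut` derive macros. The
//! generated code refers to items under `::derive_visitor`, so the derives are meant to be
//! used through that crate.
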
4#![warn(clippy::all)]
5#![warn(clippy::pedantic)]
6
7use convert_case::{Case, Casing};
8use itertools::Itertools;
9use proc_macro2::{Span, TokenStream};
10use quote::{quote, ToTokens};
11use std::{
12    collections::{hash_map::Entry, HashMap},
13    iter::IntoIterator,
14};
15use syn::token::Mut;
16use syn::{
17    parse_macro_input, parse_str, spanned::Spanned, Attribute, Data, DataEnum, DataStruct,
18    DeriveInput, Error, Field, Fields, Ident, Lit, LitStr, Member, Meta, MetaList, NestedMeta,
19    Path, Result, Variant,
20};
21
22#[proc_macro_derive(Visitor, attributes(visitor))]
23pub fn derive_visitor(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
24    expand_with(input, |stream| impl_visitor(stream, false))
25}
26
27#[proc_macro_derive(VisitorMut, attributes(visitor))]
28pub fn derive_visitor_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
29    expand_with(input, |stream| impl_visitor(stream, true))
30}
31
32#[proc_macro_derive(Drive, attributes(drive))]
33pub fn derive_drive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
34    expand_with(input, |stream| impl_drive(stream, false))
35}
36
37#[proc_macro_derive(DriveMut, attributes(drive))]
38pub fn derive_drive_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
39    expand_with(input, |stream| impl_drive(stream, true))
40}
41
42fn expand_with(
43    input: proc_macro::TokenStream,
44    handler: impl Fn(DeriveInput) -> Result<TokenStream>,
45) -> proc_macro::TokenStream {
46    let input = parse_macro_input!(input as DeriveInput);
47    handler(input)
48        .unwrap_or_else(|error| error.to_compile_error())
49        .into()
50}
51
52fn extract_meta(attrs: Vec<Attribute>, attr_name: &str) -> Result<Option<Meta>> {
53    let macro_attrs = attrs
54        .into_iter()
55        .filter(|attr| attr.path.is_ident(attr_name))
56        .collect::<Vec<Attribute>>();
57
58    if let Some(second) = macro_attrs.get(2) {
59        return Err(Error::new_spanned(second, "duplicate attribute"));
60    }
61
62    macro_attrs.first().map(Attribute::parse_meta).transpose()
63}
64
65#[derive(Default)]
66struct Params(HashMap<Path, Meta>);
67
68impl Params {
69    fn from_attrs(attrs: Vec<Attribute>, attr_name: &str) -> Result<Self> {
70        Ok(extract_meta(attrs, attr_name)?
71            .map(|meta| {
72                if let Meta::List(meta_list) = meta {
73                    Self::from_meta_list(meta_list)
74                } else {
75                    Err(Error::new_spanned(meta, "invalid attribute"))
76                }
77            })
78            .transpose()?
79            .unwrap_or_default())
80    }
81
82    fn from_meta_list(meta_list: MetaList) -> Result<Self> {
83        let mut params = HashMap::new();
84        for meta in meta_list.nested {
85            if let NestedMeta::Meta(meta) = meta {
86                let path = meta.path();
87                let entry = params.entry(path.clone());
88                if matches!(entry, Entry::Occupied(_)) {
89                    return Err(Error::new_spanned(path, "duplicate parameter"));
90                }
91                entry.or_insert(meta);
92            } else {
93                return Err(Error::new_spanned(meta, "invalid attribute"));
94            }
95        }
96        Ok(Self(params))
97    }
98
99    fn validate(&self, allowed_params: &[&str]) -> Result<()> {
100        for path in self.0.keys() {
101            if !allowed_params
102                .iter()
103                .any(|allowed_param| path.is_ident(allowed_param))
104            {
105                return Err(Error::new_spanned(
106                    path,
107                    format!(
108                        "unknown parameter, supported: {}",
109                        Itertools::intersperse(allowed_params.iter().copied(), ", ")
110                            .collect::<String>()
111                    ),
112                ));
113            }
114        }
115        Ok(())
116    }
117
118    fn param(&mut self, name: &str) -> Result<Option<Param>> {
119        self.0
120            .remove(&Ident::new(name, Span::call_site()).into())
121            .map(Param::from_meta)
122            .transpose()
123    }
124}
125
126impl Iterator for Params {
127    type Item = Result<Param>;
128    fn next(&mut self) -> Option<Self::Item> {
129        self.0
130            .keys()
131            .next()
132            .cloned()
133            .map(|path| Param::from_meta(self.0.remove(&path).unwrap()))
134    }
135}
136
137enum Param {
138    Unit(Path, Span),
139    StringLiteral(Path, Span, LitStr),
140    NestedParams(Path, Span, Params),
141}
142
143impl Param {
144    fn from_meta(meta: Meta) -> Result<Self> {
145        let path = meta.path().clone();
146        let span = meta.span();
147        match meta {
148            Meta::Path(_) => Ok(Param::Unit(path, span)),
149            Meta::List(meta_list) => Ok(Param::NestedParams(
150                path,
151                span,
152                Params::from_meta_list(meta_list)?,
153            )),
154            Meta::NameValue(name_value) => {
155                if let Lit::Str(lit_str) = name_value.lit {
156                    Ok(Param::StringLiteral(path, span, lit_str))
157                } else {
158                    Err(Error::new_spanned(name_value, "invalid parameter"))
159                }
160            }
161        }
162    }
163    fn path(&self) -> &Path {
164        match self {
165            Self::Unit(path, _)
166            | Self::StringLiteral(path, _, _)
167            | Self::NestedParams(path, _, _) => path,
168        }
169    }
170
171    fn span(&self) -> Span {
172        match self {
173            Self::Unit(_, span)
174            | Self::StringLiteral(_, span, _)
175            | Self::NestedParams(_, span, _) => *span,
176        }
177    }
178
179    fn unit(self) -> Result<()> {
180        if let Self::Unit(_, _) = self {
181            Ok(())
182        } else {
183            Err(Error::new(self.span(), "invalid parameter"))
184        }
185    }
186
187    fn string_literal(self) -> Result<LitStr> {
188        if let Self::StringLiteral(_, _, lit_str) = self {
189            Ok(lit_str)
190        } else {
191            Err(Error::new(self.span(), "invalid parameter"))
192        }
193    }
194}
195
196struct VisitorItemParams {
197    enter: Option<Ident>,
198    exit: Option<Ident>,
199}
200
201fn visitor_method_name_from_path(struct_path: &Path, event: &str) -> Ident {
202    let last_segment = struct_path.segments.last().unwrap();
203    Ident::new(
204        &format!(
205            "{}_{}",
206            event,
207            last_segment.ident.to_string().to_case(Case::Snake)
208        ),
209        Span::call_site(),
210    )
211}
212
213fn visitor_method_name_from_param(param: Param, path: &Path, event: &str) -> Result<Ident> {
214    match param {
215        Param::StringLiteral(_, _, lit_str) => lit_str.parse(),
216        Param::Unit(_, _) => Ok(visitor_method_name_from_path(path, event)),
217        Param::NestedParams(_, span, _) => Err(Error::new(span, "invalid parameter")),
218    }
219}
220
221fn impl_visitor(input: DeriveInput, mutable: bool) -> Result<TokenStream> {
222    let params = Params::from_attrs(input.attrs, "visitor")?
223        .map_ok(|param| {
224            let path = param.path().clone();
225
226            let item_params = match param {
227                Param::Unit(_, _) => VisitorItemParams {
228                    enter: Some(visitor_method_name_from_path(&path, "enter")),
229                    exit: Some(visitor_method_name_from_path(&path, "exit")),
230                },
231                Param::NestedParams(_, _, mut nested) => {
232                    nested.validate(&["enter", "exit"])?;
233                    VisitorItemParams {
234                        enter: nested
235                            .param("enter")?
236                            .map(|param| visitor_method_name_from_param(param, &path, "enter"))
237                            .transpose()?,
238                        exit: nested
239                            .param("exit")?
240                            .map(|param| visitor_method_name_from_param(param, &path, "exit"))
241                            .transpose()?,
242                    }
243                }
244                Param::StringLiteral(_, _, lit) => {
245                    return Err(Error::new_spanned(lit, "invalid attribute"));
246                }
247            };
248            Ok((path, item_params))
249        })
250        .flatten()
251        .collect::<Result<HashMap<Path, VisitorItemParams>>>()?;
252
253    match input.data {
254        Data::Enum(enum_) => {
255            for variant in enum_.variants {
256                if let Some(attr) = variant.attrs.first() {
257                    return Err(Error::new_spanned(
258                        attr,
259                        "#[visitor] attribute can only be applied to enum or struct",
260                    ));
261                }
262                for field in variant.fields {
263                    if let Some(attr) = field.attrs.first() {
264                        return Err(Error::new_spanned(
265                            attr,
266                            "#[visitor] attribute can only be applied to enum or struct",
267                        ));
268                    }
269                }
270            }
271        }
272        Data::Struct(struct_) => {
273            for field in struct_.fields {
274                if let Some(attr) = field.attrs.first() {
275                    return Err(Error::new_spanned(
276                        attr,
277                        "#[visitor] attribute can only be applied to enum or struct",
278                    ));
279                }
280            }
281        }
282        Data::Union(union_) => {
283            return Err(Error::new_spanned(
284                union_.union_token,
285                "unions are not supported",
286            ));
287        }
288    }
289
290    let name = input.ident;
291    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
292    let routes = params
293        .into_iter()
294        .map(|(path, item_params)| visitor_route(&path, item_params, mutable));
295    let impl_trait = Ident::new(
296        if mutable { "VisitorMut" } else { "Visitor" },
297        Span::call_site(),
298    );
299    let mut_modifier = if mutable {
300        Some(Mut(Span::call_site()))
301    } else {
302        None
303    };
304    Ok(quote! {
305        impl #impl_generics ::derive_visitor::#impl_trait for #name #ty_generics #where_clause {
306            fn visit(&mut self, item: & #mut_modifier dyn ::std::any::Any, event: ::derive_visitor::Event) {
307                #(
308                    #routes
309                )*
310            }
311        }
312    })
313}
314
315fn visitor_route(path: &Path, item_params: VisitorItemParams, mutable: bool) -> TokenStream {
316    let enter = item_params.enter.map(|method_name| {
317        quote! {
318            ::derive_visitor::Event::Enter => {
319                self.#method_name(item);
320            }
321        }
322    });
323    let exit = item_params.exit.map(|method_name| {
324        quote! {
325            ::derive_visitor::Event::Exit => {
326                self.#method_name(item);
327            }
328        }
329    });
330
331    let method = Ident::new(
332        if mutable {
333            "downcast_mut"
334        } else {
335            "downcast_ref"
336        },
337        Span::call_site(),
338    );
339
340    quote! {
341        if let Some(item) = <dyn ::std::any::Any>::#method::<#path>(item) {
342            match event {
343                #enter
344                #exit
345                _ => {}
346            }
347        }
348    }
349}
350
351fn impl_drive(input: DeriveInput, mutable: bool) -> Result<TokenStream> {
352    let mut params = Params::from_attrs(input.attrs, "drive")?;
353    params.validate(&["skip"])?;
354
355    let skip_visit_self = params
356        .param("skip")?
357        .map(Param::unit)
358        .transpose()?
359        .is_some();
360
361    let name = input.ident;
362    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
363
364    let visitor = Ident::new(
365        if mutable { "VisitorMut" } else { "Visitor" },
366        Span::call_site(),
367    );
368
369    let enter_self = if skip_visit_self {
370        None
371    } else {
372        Some(quote! {
373            ::derive_visitor::#visitor::visit(visitor, self, ::derive_visitor::Event::Enter);
374        })
375    };
376
377    let exit_self = if skip_visit_self {
378        None
379    } else {
380        Some(quote! {
381            ::derive_visitor::#visitor::visit(visitor, self, ::derive_visitor::Event::Exit);
382        })
383    };
384
385    let drive_fields = match input.data {
386        Data::Struct(struct_) => drive_struct(struct_, mutable),
387        Data::Enum(enum_) => drive_enum(enum_, mutable),
388        Data::Union(union_) => {
389            return Err(Error::new_spanned(
390                union_.union_token,
391                "unions are not supported",
392            ));
393        }
394    }?;
395
396    let impl_trait = Ident::new(
397        if mutable { "DriveMut" } else { "Drive" },
398        Span::call_site(),
399    );
400    let method = Ident::new(
401        if mutable { "drive_mut" } else { "drive" },
402        Span::call_site(),
403    );
404    let mut_modifier = if mutable {
405        Some(Mut(Span::call_site()))
406    } else {
407        None
408    };
409
410    Ok(quote! {
411        impl #impl_generics ::derive_visitor::#impl_trait for #name #ty_generics #where_clause {
412            fn #method<V: ::derive_visitor::#visitor>(& #mut_modifier self, visitor: &mut V) {
413                #enter_self
414                #drive_fields
415                #exit_self
416            }
417        }
418    })
419}
420
421fn drive_struct(struct_: DataStruct, mutable: bool) -> Result<TokenStream> {
422    struct_
423        .fields
424        .into_iter()
425        .enumerate()
426        .map(|(index, field)| {
427            let member = field.ident.as_ref().map_or_else(
428                || Member::Unnamed(index.into()),
429                |ident| Member::Named(ident.clone()),
430            );
431            let mut_modifier = if mutable {
432                Some(Mut(Span::call_site()))
433            } else {
434                None
435            };
436            drive_field(&quote! { & #mut_modifier self.#member }, field, mutable)
437        })
438        .collect()
439}
440
441fn drive_enum(enum_: DataEnum, mutable: bool) -> Result<TokenStream> {
442    let variants = enum_
443        .variants
444        .into_iter()
445        .map(|x| drive_variant(x, mutable))
446        .collect::<Result<TokenStream>>()?;
447    Ok(quote! {
448        match self {
449            #variants
450            _ => {}
451        }
452    })
453}
454
455fn drive_variant(variant: Variant, mutable: bool) -> Result<TokenStream> {
456    let mut params = Params::from_attrs(variant.attrs, "drive")?;
457    params.validate(&["skip"])?;
458    if params.param("skip")?.map(Param::unit).is_some() {
459        return Ok(TokenStream::new());
460    }
461    let name = variant.ident;
462    let destructuring = destructure_fields(variant.fields.clone())?;
463    let fields = variant
464        .fields
465        .into_iter()
466        .enumerate()
467        .map(|(index, field)| {
468            drive_field(
469                &field
470                    .ident
471                    .clone()
472                    .unwrap_or_else(|| Ident::new(&format!("i{}", index), Span::call_site()))
473                    .to_token_stream(),
474                field,
475                mutable,
476            )
477        })
478        .collect::<Result<TokenStream>>()?;
479    Ok(quote! {
480        Self::#name#destructuring => {
481            #fields
482        }
483    })
484}
485
486fn destructure_fields(fields: Fields) -> Result<TokenStream> {
487    Ok(match fields {
488        Fields::Named(fields) => {
489            let field_list = fields
490                .named
491                .into_iter()
492                .map(|field| {
493                    let mut params = Params::from_attrs(field.attrs, "drive")?;
494                    let field_name = field.ident.unwrap();
495                    Ok(if params.param("skip")?.map(Param::unit).is_some() {
496                        quote! { #field_name: _ }
497                    } else {
498                        field_name.into_token_stream()
499                    })
500                })
501                .collect::<Result<Vec<TokenStream>>>()?;
502            quote! {
503                { #( #field_list ),* }
504            }
505        }
506        Fields::Unnamed(fields) => {
507            let field_list = fields
508                .unnamed
509                .into_iter()
510                .enumerate()
511                .map(|(index, field)| {
512                    let mut params = Params::from_attrs(field.attrs, "drive")?;
513                    Ok(if params.param("skip")?.map(Param::unit).is_some() {
514                        quote! { _ }
515                    } else {
516                        Ident::new(&format!("i{}", index), Span::call_site()).into_token_stream()
517                    })
518                })
519                .collect::<Result<Vec<TokenStream>>>()?;
520            quote! {
521                ( #( #field_list ),* )
522            }
523        }
524        Fields::Unit => TokenStream::new(),
525    })
526}
527
528fn drive_field(value_expr: &TokenStream, field: Field, mutable: bool) -> Result<TokenStream> {
529    let mut params = Params::from_attrs(field.attrs, "drive")?;
530    params.validate(&["skip", "with"])?;
531
532    if params.param("skip")?.map(Param::unit).is_some() {
533        return Ok(TokenStream::new());
534    }
535
536    let drive_fn = params.param("with")?.map_or_else(
537        || {
538            parse_str(if mutable {
539                "::derive_visitor::DriveMut::drive_mut"
540            } else {
541                "::derive_visitor::Drive::drive"
542            })
543        },
544        |param| param.string_literal()?.parse::<Path>(),
545    )?;
546
547    Ok(quote! {
548        #drive_fn(#value_expr, visitor);
549    })
550}