partiql_ast_macros/
lib.rs

1#![deny(rust_2018_idioms)]
2#![deny(clippy::all)]
3
4use darling::{FromDeriveInput, FromField, FromVariant};
5use inflector::Inflector;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote};
8use syn::{Data, Fields};
9use syn::{DeriveInput, Ident};
10
11#[proc_macro_derive(Visit, attributes(visit))]
12pub fn visit_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
14    let gen = impl_visit(&ast);
15    gen.into()
16}
17
18#[derive(FromDeriveInput)]
19#[darling(attributes(visit))]
20struct VisitItemOptions {
21    #[darling(rename = "skip_recurse")]
22    should_skip_recurse: Option<bool>,
23}
24
25fn should_skip_recurse(input: &syn::DeriveInput) -> bool {
26    VisitItemOptions::from_derive_input(input)
27        .expect("parse meta")
28        .should_skip_recurse
29        .unwrap_or(false)
30}
31
32#[derive(FromField, FromVariant)]
33#[darling(attributes(visit))]
34struct VisitFieldOptions {
35    #[darling(rename = "skip")]
36    should_skip: Option<bool>,
37}
38
39fn should_skip_field(field: &syn::Field) -> bool {
40    VisitFieldOptions::from_field(field)
41        .expect("parse meta")
42        .should_skip
43        .unwrap_or(false)
44}
45
46fn should_skip_variant(variant: &syn::Variant) -> bool {
47    VisitFieldOptions::from_variant(variant)
48        .expect("parse meta")
49        .should_skip
50        .unwrap_or(false)
51}
52
53fn impl_visit(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
54    let visit_fn_name = &ast.ident.to_string().to_snake_case();
55    let enter_fn_name = Ident::new(
56        &format!("enter_{visit_fn_name}"),
57        proc_macro2::Span::call_site(),
58    );
59    let exit_fn_name = Ident::new(
60        &format!("exit_{visit_fn_name}"),
61        proc_macro2::Span::call_site(),
62    );
63
64    let visit_children = (!should_skip_recurse(ast)).then(|| impl_visit_children(&ast));
65
66    let ast_name = &ast.ident;
67    quote! {
68        impl crate::visit::Visit for #ast_name {
69            fn visit<'v, V>(&'v self, v: &mut V) -> crate::visit::Traverse
70            where
71                V: crate::visit::Visitor<'v>,
72            {
73                if v.#enter_fn_name(self) == crate::visit::Traverse::Stop {
74                    return crate::visit::Traverse::Stop
75                }
76                #visit_children
77                v.#exit_fn_name(self)
78            }
79        }
80    }
81}
82
83fn impl_visit_children(ast: &&DeriveInput) -> TokenStream {
84    match &ast.data {
85        Data::Enum(e) => {
86            let enum_name = std::iter::repeat(&ast.ident);
87            let variants = {
88                e.variants
89                    .iter()
90                    .filter_map(|v| (!should_skip_variant(v)).then_some(&v.ident))
91            };
92
93            let variants = variants.collect::<Vec<_>>();
94            let non_exhaustive = variants.len() < e.variants.len();
95            let else_clause = non_exhaustive.then(|| {
96                quote! {
97                    _ => crate::visit::Traverse::Continue
98                }
99            });
100
101            quote! {
102                if match &self {
103                    #(#enum_name::#variants(child) => child.visit(v),)*
104                    #else_clause
105                } == crate::visit::Traverse::Stop {
106                    return crate::visit::Traverse::Stop
107                }
108            }
109        }
110        Data::Struct(s) => {
111            let fields: Vec<_> = match &s.fields {
112                Fields::Named(named) => named
113                    .named
114                    .iter()
115                    .filter_map(|f| {
116                        if should_skip_field(f) {
117                            None
118                        } else {
119                            f.ident.clone()
120                        }
121                    })
122                    .collect(),
123                Fields::Unnamed(unnamed) => unnamed
124                    .unnamed
125                    .iter()
126                    .enumerate()
127                    .filter_map(|(i, f)| {
128                        if should_skip_field(f) {
129                            None
130                        } else {
131                            Some(format_ident!("{}", i))
132                        }
133                    })
134                    .collect(),
135                Fields::Unit => vec![],
136            };
137            quote! {
138                #(if self.#fields.visit(v) == crate::visit::Traverse::Stop {
139                    return crate::visit::Traverse::Stop
140                })*
141            }
142        }
143        Data::Union(_) => panic!("Union not supported"),
144    }
145}