partiql_ast_macros/
lib.rs1#![deny(rust_2018_idioms)]
2#![deny(clippy::all)]
3
4use darling::{FromDeriveInput, FromField, FromVariant};
5use inflector::Inflector;
6use proc_macro2::TokenStream;
7use quote::{format_ident, quote};
8use syn::{Data, Fields};
9use syn::{DeriveInput, Ident};
10
11#[proc_macro_derive(Visit, attributes(visit))]
12pub fn visit_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
13 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
14 let gen = impl_visit(&ast);
15 gen.into()
16}
17
/// Options parsed (via darling) from the `#[visit(...)]` attribute on the item that
/// derives `Visit`.
#[derive(FromDeriveInput)]
#[darling(attributes(visit))]
struct VisitItemOptions {
    /// Value of the `skip_recurse` key inside `#[visit(...)]`; `None` when the key is absent.
    #[darling(rename = "skip_recurse")]
    should_skip_recurse: Option<bool>,
}
24
25fn should_skip_recurse(input: &syn::DeriveInput) -> bool {
26 VisitItemOptions::from_derive_input(input)
27 .expect("parse meta")
28 .should_skip_recurse
29 .unwrap_or(false)
30}
31
/// Options parsed (via darling) from the `#[visit(...)]` attribute on a field or on an
/// enum variant.
#[derive(FromField, FromVariant)]
#[darling(attributes(visit))]
struct VisitFieldOptions {
    /// Value of the `skip` key inside `#[visit(...)]`; `None` when the key is absent.
    #[darling(rename = "skip")]
    should_skip: Option<bool>,
}
38
39fn should_skip_field(field: &syn::Field) -> bool {
40 VisitFieldOptions::from_field(field)
41 .expect("parse meta")
42 .should_skip
43 .unwrap_or(false)
44}
45
46fn should_skip_variant(variant: &syn::Variant) -> bool {
47 VisitFieldOptions::from_variant(variant)
48 .expect("parse meta")
49 .should_skip
50 .unwrap_or(false)
51}
52
53fn impl_visit(ast: &syn::DeriveInput) -> proc_macro2::TokenStream {
54 let visit_fn_name = &ast.ident.to_string().to_snake_case();
55 let enter_fn_name = Ident::new(
56 &format!("enter_{visit_fn_name}"),
57 proc_macro2::Span::call_site(),
58 );
59 let exit_fn_name = Ident::new(
60 &format!("exit_{visit_fn_name}"),
61 proc_macro2::Span::call_site(),
62 );
63
64 let visit_children = (!should_skip_recurse(ast)).then(|| impl_visit_children(&ast));
65
66 let ast_name = &ast.ident;
67 quote! {
68 impl crate::visit::Visit for #ast_name {
69 fn visit<'v, V>(&'v self, v: &mut V) -> crate::visit::Traverse
70 where
71 V: crate::visit::Visitor<'v>,
72 {
73 if v.#enter_fn_name(self) == crate::visit::Traverse::Stop {
74 return crate::visit::Traverse::Stop
75 }
76 #visit_children
77 v.#exit_fn_name(self)
78 }
79 }
80 }
81}
82
83fn impl_visit_children(ast: &&DeriveInput) -> TokenStream {
84 match &ast.data {
85 Data::Enum(e) => {
86 let enum_name = std::iter::repeat(&ast.ident);
87 let variants = {
88 e.variants
89 .iter()
90 .filter_map(|v| (!should_skip_variant(v)).then_some(&v.ident))
91 };
92
93 let variants = variants.collect::<Vec<_>>();
94 let non_exhaustive = variants.len() < e.variants.len();
95 let else_clause = non_exhaustive.then(|| {
96 quote! {
97 _ => crate::visit::Traverse::Continue
98 }
99 });
100
101 quote! {
102 if match &self {
103 #(#enum_name::#variants(child) => child.visit(v),)*
104 #else_clause
105 } == crate::visit::Traverse::Stop {
106 return crate::visit::Traverse::Stop
107 }
108 }
109 }
110 Data::Struct(s) => {
111 let fields: Vec<_> = match &s.fields {
112 Fields::Named(named) => named
113 .named
114 .iter()
115 .filter_map(|f| {
116 if should_skip_field(f) {
117 None
118 } else {
119 f.ident.clone()
120 }
121 })
122 .collect(),
123 Fields::Unnamed(unnamed) => unnamed
124 .unnamed
125 .iter()
126 .enumerate()
127 .filter_map(|(i, f)| {
128 if should_skip_field(f) {
129 None
130 } else {
131 Some(format_ident!("{}", i))
132 }
133 })
134 .collect(),
135 Fields::Unit => vec![],
136 };
137 quote! {
138 #(if self.#fields.visit(v) == crate::visit::Traverse::Stop {
139 return crate::visit::Traverse::Stop
140 })*
141 }
142 }
143 Data::Union(_) => panic!("Union not supported"),
144 }
145}