yachtsql_sqlparser_derive/
lib.rs1use proc_macro2::TokenStream;
19use quote::{format_ident, quote, quote_spanned, ToTokens};
20use syn::spanned::Spanned;
21use syn::{
22 parse::{Parse, ParseStream},
23 parse_macro_input, parse_quote, Attribute, Data, DeriveInput, Fields, GenericParam, Generics,
24 Ident, Index, LitStr, Meta, Token, Type, TypePath,
25};
26use syn::{Path, PathArguments};
27
28#[proc_macro_derive(VisitMut, attributes(visit))]
30pub fn derive_visit_mut(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
31 derive_visit(
32 input,
33 &VisitType {
34 visit_trait: quote!(VisitMut),
35 visitor_trait: quote!(VisitorMut),
36 modifier: Some(quote!(mut)),
37 },
38 )
39}
40
41#[proc_macro_derive(Visit, attributes(visit))]
43pub fn derive_visit_immutable(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
44 derive_visit(
45 input,
46 &VisitType {
47 visit_trait: quote!(Visit),
48 visitor_trait: quote!(Visitor),
49 modifier: None,
50 },
51 )
52}
53
54struct VisitType {
55 visit_trait: TokenStream,
56 visitor_trait: TokenStream,
57 modifier: Option<TokenStream>,
58}
59
60fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_macro::TokenStream {
61 let input = parse_macro_input!(input as DeriveInput);
63 let name = input.ident;
64
65 let VisitType {
66 visit_trait,
67 visitor_trait,
68 modifier,
69 } = visit_type;
70
71 let attributes = Attributes::parse(&input.attrs);
72 let generics = add_trait_bounds(input.generics, visit_type);
74 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
75
76 let (pre_visit, post_visit) = attributes.visit(quote!(self));
77 let children = visit_children(&input.data, visit_type);
78
79 let expanded = quote! {
80 impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
84 #[cfg_attr(feature = "recursive-protection", recursive::recursive)]
85 fn visit<V: sqlparser::ast::#visitor_trait>(
86 &#modifier self,
87 visitor: &mut V
88 ) -> ::std::ops::ControlFlow<V::Break> {
89 #pre_visit
90 #children
91 #post_visit
92 ::std::ops::ControlFlow::Continue(())
93 }
94 }
95 };
96
97 proc_macro::TokenStream::from(expanded)
98}
99
100#[derive(Default)]
104struct Attributes {
105 with: Option<Ident>,
107}
108
109struct WithIdent {
110 with: Option<Ident>,
111}
112impl Parse for WithIdent {
113 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
114 let mut result = WithIdent { with: None };
115 let ident = input.parse::<Ident>()?;
116 if ident != "with" {
117 return Err(syn::Error::new(
118 ident.span(),
119 "Expected identifier to be `with`",
120 ));
121 }
122 input.parse::<Token!(=)>()?;
123 let s = input.parse::<LitStr>()?;
124 result.with = Some(format_ident!("{}", s.value(), span = s.span()));
125 Ok(result)
126 }
127}
128
129impl Attributes {
130 fn parse(attrs: &[Attribute]) -> Self {
131 let mut out = Self::default();
132 for attr in attrs {
133 if let Meta::List(ref metalist) = attr.meta {
134 if metalist.path.is_ident("visit") {
135 match syn::parse2::<WithIdent>(metalist.tokens.clone()) {
136 Ok(with_ident) => {
137 out.with = with_ident.with;
138 }
139 Err(e) => {
140 panic!("{}", e);
141 }
142 }
143 }
144 }
145 }
146 out
147 }
148
149 fn visit(&self, s: TokenStream) -> (Option<TokenStream>, Option<TokenStream>) {
151 let pre_visit = self.with.as_ref().map(|m| {
152 let m = format_ident!("pre_{}", m);
153 quote!(visitor.#m(#s)?;)
154 });
155 let post_visit = self.with.as_ref().map(|m| {
156 let m = format_ident!("post_{}", m);
157 quote!(visitor.#m(#s)?;)
158 });
159 (pre_visit, post_visit)
160 }
161}
162
163fn add_trait_bounds(mut generics: Generics, VisitType { visit_trait, .. }: &VisitType) -> Generics {
165 for param in &mut generics.params {
166 if let GenericParam::Type(ref mut type_param) = *param {
167 type_param
168 .bounds
169 .push(parse_quote!(sqlparser::ast::#visit_trait));
170 }
171 }
172 generics
173}
174
175fn visit_children(
177 data: &Data,
178 VisitType {
179 visit_trait,
180 modifier,
181 ..
182 }: &VisitType,
183) -> TokenStream {
184 match data {
185 Data::Struct(data) => match &data.fields {
186 Fields::Named(fields) => {
187 let recurse = fields.named.iter().map(|f| {
188 let name = &f.ident;
189 let is_option = is_option(&f.ty);
190 let attributes = Attributes::parse(&f.attrs);
191 if is_option && attributes.with.is_some() {
192 let (pre_visit, post_visit) = attributes.visit(quote!(value));
193 quote_spanned!(f.span() =>
194 if let Some(value) = &#modifier self.#name {
195 #pre_visit sqlparser::ast::#visit_trait::visit(value, visitor)?; #post_visit
196 }
197 )
198 } else {
199 let (pre_visit, post_visit) = attributes.visit(quote!(&#modifier self.#name));
200 quote_spanned!(f.span() =>
201 #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#name, visitor)?; #post_visit
202 )
203 }
204 });
205 quote! {
206 #(#recurse)*
207 }
208 }
209 Fields::Unnamed(fields) => {
210 let recurse = fields.unnamed.iter().enumerate().map(|(i, f)| {
211 let index = Index::from(i);
212 let attributes = Attributes::parse(&f.attrs);
213 let (pre_visit, post_visit) = attributes.visit(quote!(&self.#index));
214 quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(&#modifier self.#index, visitor)?; #post_visit)
215 });
216 quote! {
217 #(#recurse)*
218 }
219 }
220 Fields::Unit => {
221 quote!()
222 }
223 },
224 Data::Enum(data) => {
225 let statements = data.variants.iter().map(|v| {
226 let name = &v.ident;
227 match &v.fields {
228 Fields::Named(fields) => {
229 let names = fields.named.iter().map(|f| &f.ident);
230 let visit = fields.named.iter().map(|f| {
231 let name = &f.ident;
232 let attributes = Attributes::parse(&f.attrs);
233 let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
234 quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
235 });
236
237 quote!(
238 Self::#name { #(#names),* } => {
239 #(#visit)*
240 }
241 )
242 }
243 Fields::Unnamed(fields) => {
244 let names = fields.unnamed.iter().enumerate().map(|(i, f)| format_ident!("_{}", i, span = f.span()));
245 let visit = fields.unnamed.iter().enumerate().map(|(i, f)| {
246 let name = format_ident!("_{}", i);
247 let attributes = Attributes::parse(&f.attrs);
248 let (pre_visit, post_visit) = attributes.visit(name.to_token_stream());
249 quote_spanned!(f.span() => #pre_visit sqlparser::ast::#visit_trait::visit(#name, visitor)?; #post_visit)
250 });
251
252 quote! {
253 Self::#name ( #(#names),*) => {
254 #(#visit)*
255 }
256 }
257 }
258 Fields::Unit => {
259 quote! {
260 Self::#name => {}
261 }
262 }
263 }
264 });
265
266 quote! {
267 match self {
268 #(#statements),*
269 }
270 }
271 }
272 Data::Union(_) => unimplemented!(),
273 }
274}
275
276fn is_option(ty: &Type) -> bool {
277 if let Type::Path(TypePath {
278 path: Path { segments, .. },
279 ..
280 }) = ty
281 {
282 if let Some(segment) = segments.last() {
283 if segment.ident == "Option" {
284 if let PathArguments::AngleBracketed(args) = &segment.arguments {
285 return args.args.len() == 1;
286 }
287 }
288 }
289 }
290 false
291}