jrsonnet_gcmodule_derive/
lib.rs

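//! Provides `#[derive(Trace)]` for structs and enums, implementing the
//! `jrsonnet_gcmodule::Trace` trait.
//!
//! Fields may carry a `#[trace(...)]` attribute with one of:
//!
//! - `skip`: the field is neither traced nor considered by `is_type_tracked`;
//! - `with(path)`: the field is traced by calling `path(field_ref, tracer)`
//!   instead of `Trace::trace`, and the type is treated as tracked;
//! - `tracking(force)` / `tracking(ignore)`: the field is traced normally, but
//!   always / never makes the type tracked.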
3//!
4//! # Example
5//!
6//! ```
7//! use jrsonnet_gcmodule_derive::Trace;
8//!
9//! #[derive(Trace)]
10//! struct S<T: jrsonnet_gcmodule::Trace> {
11//!     a: String,
12//!     b: Option<T>,
13//!
14//!     #[trace(skip)] // ignore this field for Trace.
15//!     c: MyType,
16//! }
17//!
18//! struct MyType;
19//! ```
20extern crate proc_macro;
21
22use proc_macro::TokenStream;
23use proc_macro2::TokenStream as TokenStream2;
24use quote::{format_ident, quote};
25use syn::{
26    parenthesized,
27    parse::{Parse, ParseStream},
28    parse_macro_input,
29    spanned::Spanned,
30    Attribute, Data, DeriveInput, Error, Field, Fields, Ident, Path, Result,
31};
32
33mod kw {
34    syn::custom_keyword!(trace);
35    syn::custom_keyword!(skip);
36    syn::custom_keyword!(with);
37    syn::custom_keyword!(tracking);
38    syn::custom_keyword!(ignore);
39    syn::custom_keyword!(force);
40}
41
42enum TraceAttr {
43    Skip,
44    With(Path),
45    TrackingForce(bool),
46}
47impl TraceAttr {
48    fn force_is_type_tracked(&self) -> Option<TokenStream2> {
49        match self {
50            Self::TrackingForce(v) => Some(quote! {#v}),
51            Self::Skip => Some(quote! {false}),
52            Self::With(_) => Some(quote! {true}),
53        }
54    }
55}
56impl Parse for TraceAttr {
57    fn parse(input: ParseStream) -> Result<Self> {
58        let lookahead = input.lookahead1();
59        if lookahead.peek(kw::skip) {
60            input.parse::<kw::skip>()?;
61            Ok(Self::Skip)
62        } else if lookahead.peek(kw::tracking) {
63            input.parse::<kw::tracking>()?;
64            let content;
65            parenthesized!(content in input);
66            let lookahead = content.lookahead1();
67            if lookahead.peek(kw::ignore) {
68                content.parse::<kw::ignore>()?;
69                Ok(Self::TrackingForce(false))
70            } else if lookahead.peek(kw::force) {
71                content.parse::<kw::force>()?;
72                Ok(Self::TrackingForce(true))
73            } else {
74                Err(lookahead.error())
75            }
76        } else if lookahead.peek(kw::with) {
77            input.parse::<kw::with>()?;
78            let content;
79            parenthesized!(content in input);
80            Ok(Self::With(content.parse()?))
81        } else {
82            Err(lookahead.error())
83        }
84    }
85}
86
87fn parse_attr<A: Parse, I>(attrs: &[Attribute], ident: I) -> Result<Option<A>>
88where
89    Ident: PartialEq<I>,
90{
91    let attrs = attrs
92        .iter()
93        .filter(|a| a.path().is_ident(&ident))
94        .collect::<Vec<_>>();
95    if attrs.len() > 1 {
96        return Err(Error::new(
97            attrs[1].span(),
98            "this attribute may be specified only once",
99        ));
100    } else if attrs.is_empty() {
101        return Ok(None);
102    }
103    let attr = attrs[0];
104    let attr = attr.parse_args::<A>()?;
105
106    Ok(Some(attr))
107}
108
109/// Returns impl for (trace, is_type_tracked)
110fn derive_fields(
111    trace_attr: &Option<TraceAttr>,
112    fields: &Fields,
113) -> Result<(TokenStream2, TokenStream2)> {
114    fn inner(names: &[Ident], fields: Vec<&Field>) -> Result<(TokenStream2, TokenStream2)> {
115        let attrs = fields
116            .iter()
117            .map(|f| parse_attr::<TraceAttr, _>(&f.attrs, "trace"))
118            .collect::<Result<Vec<_>>>()?;
119
120        let trace = names.iter().zip(attrs.iter()).filter_map(|(name, attr)| {
121            match attr {
122                Some(TraceAttr::Skip) => return None,
123                Some(TraceAttr::With(w)) => return Some(quote! {#w(#name, tracer)}),
124                _ => {}
125            }
126            Some(quote! {
127                ::jrsonnet_gcmodule::Trace::trace(#name, tracer)
128            })
129        });
130        let is_type_tracked = fields.iter().zip(attrs.iter()).filter_map(|(field, attr)| {
131            match attr {
132                Some(TraceAttr::Skip | TraceAttr::TrackingForce(false)) => return None,
133                Some(TraceAttr::With(_) | TraceAttr::TrackingForce(true)) => {
134                    return Some(quote! {true})
135                }
136                _ => {}
137            }
138            let ty = &field.ty;
139            Some(quote! {
140                <#ty as ::jrsonnet_gcmodule::Trace>::is_type_tracked()
141            })
142        });
143
144        let trace = quote! {
145            #(#trace;)*
146        };
147
148        Ok((
149            trace,
150            quote! {
151                #(if #is_type_tracked {return true;})*
152            },
153        ))
154    }
155    match fields {
156        Fields::Named(named) => {
157            if matches!(trace_attr, Some(TraceAttr::Skip)) {
158                return Ok((
159                    quote! {
160                        {...} => {}
161                    },
162                    quote! {},
163                ));
164            }
165            let force_is_type_tracked = trace_attr.as_ref().and_then(|a| a.force_is_type_tracked());
166
167            let names = named
168                .named
169                .iter()
170                .map(|i| i.ident.clone().unwrap())
171                .collect::<Vec<_>>();
172            let (trace, is_type_tracked) = inner(&names, named.named.iter().collect())?;
173            let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
174
175            Ok((
176                quote! {
177                    {#(#names),*} => {#trace}
178                },
179                is_type_tracked,
180            ))
181        }
182        Fields::Unnamed(unnamed) => {
183            if matches!(trace_attr, Some(TraceAttr::Skip)) {
184                return Ok((quote! {(...) => {}}, quote! {}));
185            }
186            let force_is_type_tracked = trace_attr.as_ref().and_then(|a| a.force_is_type_tracked());
187
188            let names = (0..unnamed.unnamed.len())
189                .map(|i| format_ident!("field_{}", i))
190                .collect::<Vec<_>>();
191            let (trace, is_type_tracked) = inner(&names, unnamed.unnamed.iter().collect())?;
192            let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
193
194            Ok((
195                quote! {
196                    (#(#names,)*) => {#trace}
197                },
198                is_type_tracked,
199            ))
200        }
201        Fields::Unit => Ok((
202            quote! {
203                => {}
204            },
205            quote! {},
206        )),
207    }
208}
209
210fn derive_trace(input: DeriveInput) -> Result<TokenStream2> {
211    let trace_attr = parse_attr::<TraceAttr, _>(&input.attrs, "trace")?;
212    if matches!(trace_attr, Some(TraceAttr::With(_))) {
213        return Err(Error::new(input.span(), "implement Trace instead"));
214    }
215    let ident = &input.ident;
216    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
217    if matches!(trace_attr, Some(TraceAttr::Skip)) {
218        return Ok(quote! {
219            impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
220                fn trace(&self, _tracer: &mut ::jrsonnet_gcmodule::Tracer) {
221                }
222                fn is_type_tracked() -> bool {
223                    false
224                }
225            }
226        });
227    }
228    let force_is_type_tracked = trace_attr.and_then(|a| a.force_is_type_tracked());
229    let (trace, is_type_tracked) = match &input.data {
230        Data::Struct(s) => {
231            let (trace, is_type_tracked) = derive_fields(&None, &s.fields)?;
232
233            (
234                quote! {
235                    Self #trace
236                },
237                quote! {
238                    #is_type_tracked
239                    false
240                },
241            )
242        }
243        Data::Enum(e) if e.variants.is_empty() => (quote! {_=>unreachable!()}, quote! {false}),
244        Data::Enum(e) => {
245            let variants = e
246                .variants
247                .iter()
248                .map(|v| {
249                    let name = &v.ident;
250                    let attr = parse_attr::<TraceAttr, _>(&v.attrs, "trace")?;
251                    let impls = derive_fields(&attr, &v.fields)?;
252                    Ok((name, impls)) as Result<_>
253                })
254                .collect::<Result<Vec<_>>>()?;
255
256            let trace = variants.iter().map(|(name, (trace, _))| {
257                quote! {
258                    Self::#name #trace
259                }
260            });
261            let is_type_tracked = variants.iter().map(|(_, (_, v))| v);
262
263            (
264                quote! {
265                    #(#trace),*
266                },
267                quote! {
268                    #(#is_type_tracked)*
269                    false
270                },
271            )
272        }
273
274        Data::Union(_) => return Err(Error::new(input.span(), "union is not supported")),
275    };
276    let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
277    Ok(quote! {
278        impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
279            #[allow(unused_variables)]
280            fn trace(&self, tracer: &mut ::jrsonnet_gcmodule::Tracer) {
281                match self {
282                    #trace
283                }
284            }
285            fn is_type_tracked() -> bool {
286                #is_type_tracked
287            }
288        }
289    })
290}
291
292#[proc_macro_derive(Trace, attributes(trace))]
293pub fn derive_trace_real(input: TokenStream) -> TokenStream {
294    let input = parse_macro_input!(input as DeriveInput);
295    match derive_trace(input) {
296        Ok(v) => v.into(),
297        Err(e) => e.to_compile_error().into(),
298    }
299}