Skip to main content

jrsonnet_gcmodule_derive/
lib.rs

//! Provides `derive(Trace)` and `derive(Acyclic)` support for structs and enums,
//! implementing the `jrsonnet_gcmodule::Trace` trait.
3//!
4//! # Example
5//!
6//! ```
7//! use jrsonnet_gcmodule_derive::Trace;
8//!
9//! #[derive(Trace)]
10//! struct S<T: jrsonnet_gcmodule::Trace> {
11//!     a: String,
12//!     b: Option<T>,
13//!
14//!     #[trace(skip)] // ignore this field for Trace.
15//!     c: MyType,
16//! }
17//!
18//! struct MyType;
19//! ```
20extern crate proc_macro;
21
22use proc_macro::TokenStream;
23use proc_macro2::TokenStream as TokenStream2;
24use quote::{format_ident, quote};
25use syn::{
26    Attribute, Data, DeriveInput, Error, Field, Fields, Ident, Path, Result, parenthesized,
27    parse::{Parse, ParseStream},
28    parse_macro_input,
29    spanned::Spanned,
30};
31
32mod kw {
33    syn::custom_keyword!(trace);
34    syn::custom_keyword!(skip);
35    syn::custom_keyword!(with);
36    syn::custom_keyword!(tracking);
37    syn::custom_keyword!(ignore);
38    syn::custom_keyword!(force);
39}
40
41enum TraceAttr {
42    Skip,
43    With(Path),
44    TrackingForce(bool),
45}
46impl TraceAttr {
47    fn force_is_type_tracked(&self) -> TokenStream2 {
48        match self {
49            Self::TrackingForce(v) => quote! {#v},
50            Self::Skip => quote! {false},
51            Self::With(_) => quote! {true},
52        }
53    }
54}
55impl Parse for TraceAttr {
56    fn parse(input: ParseStream) -> Result<Self> {
57        let lookahead = input.lookahead1();
58        if lookahead.peek(kw::skip) {
59            input.parse::<kw::skip>()?;
60            Ok(Self::Skip)
61        } else if lookahead.peek(kw::tracking) {
62            input.parse::<kw::tracking>()?;
63            let content;
64            parenthesized!(content in input);
65            let lookahead = content.lookahead1();
66            if lookahead.peek(kw::ignore) {
67                content.parse::<kw::ignore>()?;
68                Ok(Self::TrackingForce(false))
69            } else if lookahead.peek(kw::force) {
70                content.parse::<kw::force>()?;
71                Ok(Self::TrackingForce(true))
72            } else {
73                Err(lookahead.error())
74            }
75        } else if lookahead.peek(kw::with) {
76            input.parse::<kw::with>()?;
77            let content;
78            parenthesized!(content in input);
79            Ok(Self::With(content.parse()?))
80        } else {
81            Err(lookahead.error())
82        }
83    }
84}
85
86fn parse_attr<A: Parse, I: ?Sized>(attrs: &[Attribute], ident: &I) -> Result<Option<A>>
87where
88    Ident: PartialEq<I>,
89{
90    let attrs = attrs
91        .iter()
92        .filter(|a| a.path().is_ident(ident))
93        .collect::<Vec<_>>();
94    if attrs.len() > 1 {
95        return Err(Error::new(
96            attrs[1].span(),
97            "this attribute may be specified only once",
98        ));
99    } else if attrs.is_empty() {
100        return Ok(None);
101    }
102    let attr = attrs[0];
103    let attr = attr.parse_args::<A>()?;
104
105    Ok(Some(attr))
106}
107
108/// Returns impl for (trace, `is_type_tracked`)
109fn derive_fields(
110    trace_attr: Option<&TraceAttr>,
111    fields: &Fields,
112) -> Result<(TokenStream2, TokenStream2)> {
113    fn inner(names: &[Ident], fields: &[&Field]) -> Result<(TokenStream2, TokenStream2)> {
114        let attrs = fields
115            .iter()
116            .map(|f| parse_attr::<TraceAttr, _>(&f.attrs, "trace"))
117            .collect::<Result<Vec<_>>>()?;
118
119        let trace = names.iter().zip(attrs.iter()).filter_map(|(name, attr)| {
120            match attr {
121                Some(TraceAttr::Skip) => return None,
122                Some(TraceAttr::With(w)) => return Some(quote! {#w(#name, tracer)}),
123                _ => {}
124            }
125            Some(quote! {
126                ::jrsonnet_gcmodule::Trace::trace(#name, tracer)
127            })
128        });
129        let is_type_tracked = fields.iter().zip(attrs.iter()).filter_map(|(field, attr)| {
130            match attr {
131                Some(TraceAttr::Skip | TraceAttr::TrackingForce(false)) => return None,
132                Some(TraceAttr::With(_) | TraceAttr::TrackingForce(true)) => {
133                    return Some(quote! {true});
134                }
135                _ => {}
136            }
137            let ty = &field.ty;
138            Some(quote! {
139                <#ty as ::jrsonnet_gcmodule::Trace>::is_type_tracked()
140            })
141        });
142
143        let trace = quote! {
144            #(#trace;)*
145        };
146
147        Ok((
148            trace,
149            quote! {
150                #(if #is_type_tracked {return true;})*
151            },
152        ))
153    }
154    match fields {
155        Fields::Named(named) => {
156            if matches!(trace_attr, Some(TraceAttr::Skip)) {
157                return Ok((
158                    quote! {
159                        {...} => {}
160                    },
161                    quote! {},
162                ));
163            }
164            let force_is_type_tracked = trace_attr.map(TraceAttr::force_is_type_tracked);
165
166            let field_names = named
167                .named
168                .iter()
169                .map(|i| i.ident.clone().unwrap())
170                .collect::<Vec<_>>();
171            let (trace, is_type_tracked) =
172                inner(&field_names, &named.named.iter().collect::<Vec<_>>())?;
173            let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
174
175            Ok((
176                quote! {
177                    {#(#field_names),*} => {#trace}
178                },
179                is_type_tracked,
180            ))
181        }
182        Fields::Unnamed(unnamed) => {
183            if matches!(trace_attr, Some(TraceAttr::Skip)) {
184                return Ok((quote! {(...) => {}}, quote! {}));
185            }
186            let force_is_type_tracked = trace_attr.map(TraceAttr::force_is_type_tracked);
187
188            let names = (0..unnamed.unnamed.len())
189                .map(|i| format_ident!("field_{}", i))
190                .collect::<Vec<_>>();
191            let (trace, is_type_tracked) =
192                inner(&names, &unnamed.unnamed.iter().collect::<Vec<_>>())?;
193            let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
194
195            Ok((
196                quote! {
197                    (#(#names,)*) => {#trace}
198                },
199                is_type_tracked,
200            ))
201        }
202        Fields::Unit => Ok((
203            quote! {
204                => {}
205            },
206            quote! {},
207        )),
208    }
209}
210
211fn derive_trace(input: &DeriveInput) -> Result<TokenStream2> {
212    let trace_attr = parse_attr::<TraceAttr, _>(&input.attrs, "trace")?;
213    if matches!(trace_attr, Some(TraceAttr::With(_))) {
214        return Err(Error::new(input.span(), "implement Trace instead"));
215    }
216    let ident = &input.ident;
217    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
218    if matches!(trace_attr, Some(TraceAttr::Skip)) {
219        return Ok(quote! {
220            impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
221                fn trace(&self, _tracer: &mut ::jrsonnet_gcmodule::Tracer) {
222                }
223                fn is_type_tracked() -> bool {
224                    false
225                }
226            }
227        });
228    }
229    let force_is_type_tracked = trace_attr.map(|a| a.force_is_type_tracked());
230    let (trace, is_type_tracked) = match &input.data {
231        Data::Struct(s) => {
232            let (trace, is_type_tracked) = derive_fields(None, &s.fields)?;
233
234            (
235                quote! {
236                    Self #trace
237                },
238                quote! {
239                    #is_type_tracked
240                    false
241                },
242            )
243        }
244        Data::Enum(e) if e.variants.is_empty() => (quote! {_=>unreachable!()}, quote! {false}),
245        Data::Enum(e) => {
246            let variants = e
247                .variants
248                .iter()
249                .map(|v| {
250                    let name = &v.ident;
251                    let attr = parse_attr::<TraceAttr, _>(&v.attrs, "trace")?;
252                    let impls = derive_fields(attr.as_ref(), &v.fields)?;
253                    Ok((name, impls)) as Result<_>
254                })
255                .collect::<Result<Vec<_>>>()?;
256
257            let trace = variants.iter().map(|(name, (trace, _))| {
258                quote! {
259                    Self::#name #trace
260                }
261            });
262            let is_type_tracked = variants.iter().map(|(_, (_, v))| v);
263
264            (
265                quote! {
266                    #(#trace),*
267                },
268                quote! {
269                    #(#is_type_tracked)*
270                    false
271                },
272            )
273        }
274
275        Data::Union(_) => return Err(Error::new(input.span(), "union is not supported")),
276    };
277    let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
278    Ok(quote! {
279        impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
280            #[allow(unused_variables, unused_assignments, clippy::used_underscore_binding)]
281            fn trace(&self, tracer: &mut ::jrsonnet_gcmodule::Tracer) {
282                match self {
283                    #trace
284                }
285            }
286            fn is_type_tracked() -> bool {
287                #is_type_tracked
288            }
289        }
290    })
291}
292
293#[proc_macro_derive(Trace, attributes(trace))]
294pub fn derive_trace_real(input: TokenStream) -> TokenStream {
295    let input = parse_macro_input!(input as DeriveInput);
296    match derive_trace(&input) {
297        Ok(v) => v.into(),
298        Err(e) => e.to_compile_error().into(),
299    }
300}
301fn assert_fields_acyclic(fields: &Fields) -> TokenStream2 {
302    fn inner(fields: &[&Field]) -> TokenStream2 {
303        let assert_field_acyclic = fields.iter().map(|field| {
304            let ty = &field.ty;
305            quote! {
306                let _: ::std::rc::Rc<#ty>;
307            }
308        });
309
310        quote! {
311            #(#assert_field_acyclic)*
312        }
313    }
314    match fields {
315        Fields::Named(named) => inner(&named.named.iter().collect::<Vec<_>>()),
316        Fields::Unnamed(unnamed) => inner(&unnamed.unnamed.iter().collect::<Vec<_>>()),
317        Fields::Unit => quote! {},
318    }
319}
320
321fn derive_acyclic(input: &DeriveInput) -> Result<TokenStream2> {
322    let ident = &input.ident;
323    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
324    let asserts = match &input.data {
325        Data::Struct(s) => assert_fields_acyclic(&s.fields),
326        Data::Enum(e) if e.variants.is_empty() => quote! {},
327        Data::Enum(e) => {
328            let variants = e
329                .variants
330                .iter()
331                .map(|v| {
332                    let impls = assert_fields_acyclic(&v.fields);
333                    Ok(impls)
334                })
335                .collect::<Result<Vec<_>>>()?;
336
337            quote! {
338                #(#variants)*
339            }
340        }
341
342        Data::Union(_) => return Err(Error::new(input.span(), "union is not supported")),
343    };
344    Ok(quote! {
345        impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
346            #[allow(unused_variables)]
347            fn trace(&self, tracer: &mut ::jrsonnet_gcmodule::Tracer) {}
348            fn is_type_tracked() -> bool {
349                false
350            }
351        }
352        unsafe impl #impl_generics ::jrsonnet_gcmodule::Acyclic for #ident #type_generics #where_clause {
353            fn assert_fields_are_acyclic(&self) {
354                #asserts
355            }
356        }
357    })
358}
359
360#[proc_macro_derive(Acyclic)]
361pub fn derive_acyclic_real(input: TokenStream) -> TokenStream {
362    let input = parse_macro_input!(input as DeriveInput);
363    match derive_acyclic(&input) {
364        Ok(v) => v.into(),
365        Err(e) => e.to_compile_error().into(),
366    }
367}