Skip to main content

jrsonnet_gcmodule_derive/
lib.rs

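//! Provides `#[derive(Trace)]` and `#[derive(Acyclic)]` for structs and enums,
//! generating implementations of the `jrsonnet_gcmodule::Trace` trait (and, for
//! `Acyclic`, of `jrsonnet_gcmodule::Acyclic`).
//!
//! Fields accept `#[trace(skip)]`, `#[trace(with(path))]`,
//! `#[trace(tracking(force))]` and `#[trace(tracking(ignore))]`.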
3//!
4//! # Example
5//!
6//! ```
7//! use jrsonnet_gcmodule_derive::Trace;
8//!
9//! #[derive(Trace)]
10//! struct S<T: jrsonnet_gcmodule::Trace> {
11//!     a: String,
12//!     b: Option<T>,
13//!
14//!     #[trace(skip)] // ignore this field for Trace.
15//!     c: MyType,
16//! }
17//!
18//! struct MyType;
19//! ```
20extern crate proc_macro;
21
22use proc_macro::TokenStream;
23use proc_macro2::TokenStream as TokenStream2;
24use quote::{format_ident, quote};
25use syn::{
26    Attribute, Data, DeriveInput, Error, Field, Fields, Ident, Path, Result, parenthesized,
27    parse::{Parse, ParseStream},
28    parse_macro_input,
29    spanned::Spanned,
30};
31
32mod kw {
33    syn::custom_keyword!(trace);
34    syn::custom_keyword!(skip);
35    syn::custom_keyword!(with);
36    syn::custom_keyword!(tracking);
37    syn::custom_keyword!(ignore);
38    syn::custom_keyword!(force);
39}
40
41enum TraceAttr {
42    Skip,
43    With(Path),
44    TrackingForce(bool),
45}
46impl TraceAttr {
47    fn force_is_type_tracked(&self) -> Option<TokenStream2> {
48        match self {
49            Self::TrackingForce(v) => Some(quote! {#v}),
50            Self::Skip => Some(quote! {false}),
51            Self::With(_) => Some(quote! {true}),
52        }
53    }
54}
55impl Parse for TraceAttr {
56    fn parse(input: ParseStream) -> Result<Self> {
57        let lookahead = input.lookahead1();
58        if lookahead.peek(kw::skip) {
59            input.parse::<kw::skip>()?;
60            Ok(Self::Skip)
61        } else if lookahead.peek(kw::tracking) {
62            input.parse::<kw::tracking>()?;
63            let content;
64            parenthesized!(content in input);
65            let lookahead = content.lookahead1();
66            if lookahead.peek(kw::ignore) {
67                content.parse::<kw::ignore>()?;
68                Ok(Self::TrackingForce(false))
69            } else if lookahead.peek(kw::force) {
70                content.parse::<kw::force>()?;
71                Ok(Self::TrackingForce(true))
72            } else {
73                Err(lookahead.error())
74            }
75        } else if lookahead.peek(kw::with) {
76            input.parse::<kw::with>()?;
77            let content;
78            parenthesized!(content in input);
79            Ok(Self::With(content.parse()?))
80        } else {
81            Err(lookahead.error())
82        }
83    }
84}
85
86fn parse_attr<A: Parse, I>(attrs: &[Attribute], ident: I) -> Result<Option<A>>
87where
88    Ident: PartialEq<I>,
89{
90    let attrs = attrs
91        .iter()
92        .filter(|a| a.path().is_ident(&ident))
93        .collect::<Vec<_>>();
94    if attrs.len() > 1 {
95        return Err(Error::new(
96            attrs[1].span(),
97            "this attribute may be specified only once",
98        ));
99    } else if attrs.is_empty() {
100        return Ok(None);
101    }
102    let attr = attrs[0];
103    let attr = attr.parse_args::<A>()?;
104
105    Ok(Some(attr))
106}
107
108/// Returns impl for (trace, is_type_tracked)
109fn derive_fields(
110    trace_attr: &Option<TraceAttr>,
111    fields: &Fields,
112) -> Result<(TokenStream2, TokenStream2)> {
113    fn inner(names: &[Ident], fields: Vec<&Field>) -> Result<(TokenStream2, TokenStream2)> {
114        let attrs = fields
115            .iter()
116            .map(|f| parse_attr::<TraceAttr, _>(&f.attrs, "trace"))
117            .collect::<Result<Vec<_>>>()?;
118
119        let trace = names.iter().zip(attrs.iter()).filter_map(|(name, attr)| {
120            match attr {
121                Some(TraceAttr::Skip) => return None,
122                Some(TraceAttr::With(w)) => return Some(quote! {#w(#name, tracer)}),
123                _ => {}
124            }
125            Some(quote! {
126                ::jrsonnet_gcmodule::Trace::trace(#name, tracer)
127            })
128        });
129        let is_type_tracked = fields.iter().zip(attrs.iter()).filter_map(|(field, attr)| {
130            match attr {
131                Some(TraceAttr::Skip | TraceAttr::TrackingForce(false)) => return None,
132                Some(TraceAttr::With(_) | TraceAttr::TrackingForce(true)) => {
133                    return Some(quote! {true});
134                }
135                _ => {}
136            }
137            let ty = &field.ty;
138            Some(quote! {
139                <#ty as ::jrsonnet_gcmodule::Trace>::is_type_tracked()
140            })
141        });
142
143        let trace = quote! {
144            #(#trace;)*
145        };
146
147        Ok((
148            trace,
149            quote! {
150                #(if #is_type_tracked {return true;})*
151            },
152        ))
153    }
154    match fields {
155        Fields::Named(named) => {
156            if matches!(trace_attr, Some(TraceAttr::Skip)) {
157                return Ok((
158                    quote! {
159                        {...} => {}
160                    },
161                    quote! {},
162                ));
163            }
164            let force_is_type_tracked = trace_attr.as_ref().and_then(|a| a.force_is_type_tracked());
165
166            let names = named
167                .named
168                .iter()
169                .map(|i| i.ident.clone().unwrap())
170                .collect::<Vec<_>>();
171            let (trace, is_type_tracked) = inner(&names, named.named.iter().collect())?;
172            let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
173
174            Ok((
175                quote! {
176                    {#(#names),*} => {#trace}
177                },
178                is_type_tracked,
179            ))
180        }
181        Fields::Unnamed(unnamed) => {
182            if matches!(trace_attr, Some(TraceAttr::Skip)) {
183                return Ok((quote! {(...) => {}}, quote! {}));
184            }
185            let force_is_type_tracked = trace_attr.as_ref().and_then(|a| a.force_is_type_tracked());
186
187            let names = (0..unnamed.unnamed.len())
188                .map(|i| format_ident!("field_{}", i))
189                .collect::<Vec<_>>();
190            let (trace, is_type_tracked) = inner(&names, unnamed.unnamed.iter().collect())?;
191            let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
192
193            Ok((
194                quote! {
195                    (#(#names,)*) => {#trace}
196                },
197                is_type_tracked,
198            ))
199        }
200        Fields::Unit => Ok((
201            quote! {
202                => {}
203            },
204            quote! {},
205        )),
206    }
207}
208
209fn derive_trace(input: DeriveInput) -> Result<TokenStream2> {
210    let trace_attr = parse_attr::<TraceAttr, _>(&input.attrs, "trace")?;
211    if matches!(trace_attr, Some(TraceAttr::With(_))) {
212        return Err(Error::new(input.span(), "implement Trace instead"));
213    }
214    let ident = &input.ident;
215    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
216    if matches!(trace_attr, Some(TraceAttr::Skip)) {
217        return Ok(quote! {
218            impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
219                fn trace(&self, _tracer: &mut ::jrsonnet_gcmodule::Tracer) {
220                }
221                fn is_type_tracked() -> bool {
222                    false
223                }
224            }
225        });
226    }
227    let force_is_type_tracked = trace_attr.and_then(|a| a.force_is_type_tracked());
228    let (trace, is_type_tracked) = match &input.data {
229        Data::Struct(s) => {
230            let (trace, is_type_tracked) = derive_fields(&None, &s.fields)?;
231
232            (
233                quote! {
234                    Self #trace
235                },
236                quote! {
237                    #is_type_tracked
238                    false
239                },
240            )
241        }
242        Data::Enum(e) if e.variants.is_empty() => (quote! {_=>unreachable!()}, quote! {false}),
243        Data::Enum(e) => {
244            let variants = e
245                .variants
246                .iter()
247                .map(|v| {
248                    let name = &v.ident;
249                    let attr = parse_attr::<TraceAttr, _>(&v.attrs, "trace")?;
250                    let impls = derive_fields(&attr, &v.fields)?;
251                    Ok((name, impls)) as Result<_>
252                })
253                .collect::<Result<Vec<_>>>()?;
254
255            let trace = variants.iter().map(|(name, (trace, _))| {
256                quote! {
257                    Self::#name #trace
258                }
259            });
260            let is_type_tracked = variants.iter().map(|(_, (_, v))| v);
261
262            (
263                quote! {
264                    #(#trace),*
265                },
266                quote! {
267                    #(#is_type_tracked)*
268                    false
269                },
270            )
271        }
272
273        Data::Union(_) => return Err(Error::new(input.span(), "union is not supported")),
274    };
275    let is_type_tracked = force_is_type_tracked.unwrap_or(is_type_tracked);
276    Ok(quote! {
277        impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
278            #[allow(unused_variables, unused_assignments)]
279            fn trace(&self, tracer: &mut ::jrsonnet_gcmodule::Tracer) {
280                match self {
281                    #trace
282                }
283            }
284            fn is_type_tracked() -> bool {
285                #is_type_tracked
286            }
287        }
288    })
289}
290
291#[proc_macro_derive(Trace, attributes(trace))]
292pub fn derive_trace_real(input: TokenStream) -> TokenStream {
293    let input = parse_macro_input!(input as DeriveInput);
294    match derive_trace(input) {
295        Ok(v) => v.into(),
296        Err(e) => e.to_compile_error().into(),
297    }
298}
299fn assert_fields_acyclic(fields: &Fields) -> Result<TokenStream2> {
300    fn inner(fields: Vec<&Field>) -> TokenStream2 {
301        let assert_field_acyclic = fields.iter().map(|field| {
302            let ty = &field.ty;
303            quote! {
304                let _: ::std::rc::Rc<#ty>;
305            }
306        });
307
308        quote! {
309            #(#assert_field_acyclic)*
310        }
311    }
312    match fields {
313        Fields::Named(named) => Ok(inner(named.named.iter().collect())),
314        Fields::Unnamed(unnamed) => Ok(inner(unnamed.unnamed.iter().collect())),
315        Fields::Unit => Ok(quote! {}),
316    }
317}
318
319fn derive_acyclic(input: DeriveInput) -> Result<TokenStream2> {
320    let ident = &input.ident;
321    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
322    let asserts = match &input.data {
323        Data::Struct(s) => assert_fields_acyclic(&s.fields)?,
324        Data::Enum(e) if e.variants.is_empty() => quote! {},
325        Data::Enum(e) => {
326            let variants = e
327                .variants
328                .iter()
329                .map(|v| {
330                    let impls = assert_fields_acyclic(&v.fields)?;
331                    Ok(impls)
332                })
333                .collect::<Result<Vec<_>>>()?;
334
335            quote! {
336                #(#variants)*
337            }
338        }
339
340        Data::Union(_) => return Err(Error::new(input.span(), "union is not supported")),
341    };
342    Ok(quote! {
343        impl #impl_generics ::jrsonnet_gcmodule::Trace for #ident #type_generics #where_clause {
344            #[allow(unused_variables)]
345            fn trace(&self, tracer: &mut ::jrsonnet_gcmodule::Tracer) {}
346            fn is_type_tracked() -> bool {
347                false
348            }
349        }
350        unsafe impl #impl_generics ::jrsonnet_gcmodule::Acyclic for #ident #type_generics #where_clause {
351            fn assert_fields_are_acyclic(&self) {
352                #asserts
353            }
354        }
355    })
356}
357
358#[proc_macro_derive(Acyclic)]
359pub fn derive_acyclic_real(input: TokenStream) -> TokenStream {
360    let input = parse_macro_input!(input as DeriveInput);
361    match derive_acyclic(input) {
362        Ok(v) => v.into(),
363        Err(e) => e.to_compile_error().into(),
364    }
365}