zc_derive/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use proc_macro2::{Span, TokenStream as TokenStream2};
5use quote::{quote, quote_spanned};
6use syn::spanned::Spanned;
7use syn::{
8    parse_macro_input, Attribute, Data, DeriveInput, Field, GenericParam, Ident, Lifetime,
9    LifetimeDef,
10};
11
12#[proc_macro_derive(Dependant, attributes(zc))]
13pub fn derive_dependant(input: TokenStream) -> TokenStream {
14    let input = parse_macro_input!(input as DeriveInput);
15    let name = &input.ident;
16    let lifetime_count = input.generics.lifetimes().count();
17    let derive_opts = match parse_derive_attrs(&input) {
18        Ok(opts) => opts,
19        Err(err) => return TokenStream::from(err),
20    };
21    let mut static_generics = input.generics.clone();
22    let mut dependant_generics = input.generics.clone();
23    let static_lifetime = Lifetime::new("'static", Span::call_site());
24    let dependant_lifetime = if lifetime_count == 0 {
25        let dependant_lifetime = Lifetime::new("'a", Span::call_site());
26        dependant_generics.params.insert(
27            0,
28            GenericParam::Lifetime(LifetimeDef::new(dependant_lifetime.clone())),
29        );
30        dependant_lifetime
31    } else if lifetime_count == 1 {
32        let first_lifetime_mut = static_generics.lifetimes_mut().next().unwrap();
33        let dependant_lifetime = first_lifetime_mut.lifetime.clone();
34        first_lifetime_mut.lifetime = static_lifetime;
35        dependant_lifetime
36    } else {
37        let message = format!(
38            "{} lifetimes on `{}` when only a single is valid on a `zc::Dependant`",
39            lifetime_count, name
40        );
41        let error = quote_spanned! { input.generics.span() => compile_error!(#message); };
42        return TokenStream::from(error);
43    };
44    let field_checks = impl_field_checks(&input, &derive_opts, &dependant_lifetime);
45    let impl_dependant_generics = dependant_generics.split_for_impl().0;
46    let ty_generic_static = static_generics.split_for_impl().1;
47    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
48    let dependant_impl = quote! {
49        impl #impl_generics #name #ty_generics #where_clause {
50            fn _zc_field_checks() {
51                #field_checks
52            }
53        }
54
55        unsafe impl #impl_dependant_generics ::zc::Dependant<#dependant_lifetime> for #name #ty_generics #where_clause {
56            type Static = #name #ty_generic_static;
57        }
58    };
59    TokenStream::from(dependant_impl)
60}
61
62fn impl_field_checks(input: &DeriveInput, opts: &DeriveOpts, lifetime: &Lifetime) -> TokenStream2 {
63    match &input.data {
64        Data::Struct(v) => field_checks(opts, v.fields.iter(), lifetime),
65        Data::Enum(v) => field_checks(
66            opts,
67            v.variants.iter().flat_map(|v| v.fields.iter()),
68            lifetime,
69        ),
70        Data::Union(_) => {
71            quote_spanned! { input.span() => compile_error!("deriving `zc::Dependant` is not supported for unions"); }
72        }
73    }
74}
75
76fn field_checks<'f>(
77    opts: &DeriveOpts,
78    fields: impl Iterator<Item = &'f Field>,
79    lifetime: &Lifetime,
80) -> TokenStream2 {
81    let mut checks = TokenStream2::new();
82    checks.extend(quote! {
83        pub fn copy_check<'a, T: Copy + 'a>() {};
84        pub fn dependant_check<'a, T: ::zc::Dependant<'a>>() {};
85    });
86    for field in fields {
87        let field_ty = &field.ty;
88        let field_opts = match parse_field_attrs(opts, field) {
89            Ok(opts) => opts,
90            Err(err) => return err,
91        };
92        checks.extend(match field_opts.guard {
93            CheckType::Copy => quote! {
94                copy_check::<#lifetime, #field_ty>();
95            },
96            CheckType::Default => quote! {
97                dependant_check::<#lifetime, #field_ty>();
98            },
99        });
100    }
101    checks
102}
103
/// Which compile-time check is generated for a field.
///
/// Chosen by a `#[zc(...)]` attribute on the type (applies to all fields) or on a single field.
#[derive(Clone, Copy)]
enum CheckType {
    /// The field type must be `Copy` (checked via the generated `copy_check`).
    Copy,
    /// The field type must implement `zc::Dependant` (checked via `dependant_check`); this is
    /// the check used when no attribute is given.
    Default,
}
109
110///////////////////////////////////////////////////////////////////////////////
111// DeriveOpts
112
113struct DeriveOpts {
114    check: CheckType,
115}
116
117fn parse_derive_attrs(input: &DeriveInput) -> Result<DeriveOpts, TokenStream2> {
118    let zc_attr_ident = Ident::new("zc", Span::call_site());
119    let zc_attrs = input
120        .attrs
121        .iter()
122        .filter(|attr| attr.path.get_ident() == Some(&zc_attr_ident));
123
124    let mut attrs = DeriveOpts {
125        check: CheckType::Default,
126    };
127
128    for attr in zc_attrs {
129        let attr_value = attr.tokens.to_string();
130
131        attrs.check = parse_guard_type(&attr, attr_value.as_str())?;
132    }
133
134    Ok(attrs)
135}
136
137///////////////////////////////////////////////////////////////////////////////
138// FieldOpts
139
140struct FieldOpts {
141    guard: CheckType,
142}
143
144fn parse_field_attrs(opts: &DeriveOpts, input: &Field) -> Result<FieldOpts, TokenStream2> {
145    let zc_attr_ident = Ident::new("zc", Span::call_site());
146    let zc_attrs = input
147        .attrs
148        .iter()
149        .filter(|attr| attr.path.get_ident() == Some(&zc_attr_ident));
150
151    let mut attrs = FieldOpts { guard: opts.check };
152
153    for attr in zc_attrs {
154        attrs.guard = parse_guard_type(&attr, attr.tokens.to_string().as_str())?;
155    }
156
157    Ok(attrs)
158}
159
160fn parse_guard_type(attr: &Attribute, attr_value: &str) -> Result<CheckType, TokenStream2> {
161    match attr_value {
162        r#"(check = "Copy")"# => Ok(CheckType::Copy),
163        r#"(guard = "Default")"# => Ok(CheckType::Default),
164        _ => Err(quote_spanned! { attr.span() => compile_error!("Unknown `zc` options"); }),
165    }
166}